├── .gitignore
├── assets
│   └── images
│       ├── Benchmark Performance.jpg
│       ├── case-attire.jpg
│       ├── case-helpful.jpg
│       ├── case-onion.jpg
│       └── reformulating_datasets.jpg
├── eval
│   ├── data
│   │   ├── caps_boxes_coco2014_val_80.jsonl
│   │   ├── qa90_gpt4_answer.jsonl
│   │   ├── qa90_questions.jsonl
│   │   ├── qa90_questions_with_image.jsonl
│   │   └── rule.jsonfile
│   ├── eval_gpt_review_llava.py
│   ├── eval_gpt_review_unimm-bench.py
│   ├── gpt4_grpc.py
│   ├── summarize_gpt_llava_review.py
│   └── summarize_gpt_unimm-bench_review.py
├── muffin
│   ├── __init__.py
│   ├── constants.py
│   ├── conversation.py
│   ├── data
│   │   ├── data_processors.py
│   │   ├── datasets.py
│   │   ├── tsv_file.py
│   │   └── tsv_file_op.py
│   ├── eval
│   │   ├── muffin_chat.py
│   │   ├── muffin_inference_logp.py
│   │   └── muffin_vqa.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── beit3.py
│   │   ├── llava.py
│   │   ├── muffin.py
│   │   └── utils.py
│   ├── serve
│   │   ├── __init__.py
│   │   ├── cli.py
│   │   ├── controller.py
│   │   ├── examples
│   │   │   ├── extreme_ironing.jpg
│   │   │   └── waterview.jpg
│   │   ├── gateway
│   │   │   ├── README.md
│   │   │   └── nginx.conf
│   │   ├── gradio_css.py
│   │   ├── gradio_patch.py
│   │   ├── gradio_web_server.py
│   │   ├── model_worker.py
│   │   ├── muffin_model_worker.py
│   │   ├── register_worker.py
│   │   └── test_message.py
│   ├── train
│   │   ├── llama_flash_attn_monkey_patch.py
│   │   ├── train_mem_muffin.py
│   │   ├── train_muffin.py
│   │   ├── train_utils.py
│   │   └── trainers.py
│   └── utils.py
├── pyproject.toml
├── readme.md
├── script
│   ├── download_data.sh
│   ├── eval
│   │   ├── batch_gpt4_review.sh
│   │   ├── eval_muffin_inference_logp.sh
│   │   ├── eval_muffin_qa.sh
│   │   └── eval_muffin_qaFILE.sh
│   ├── init.sh
│   └── train
│       ├── ddpo_muffin_PATH.sh
│       ├── run_RLHFV.sh
│       ├── run_unimm-chat.sh
│       └── sft_muffin_PATH.sh
└── utils
    ├── __init__.py
    ├── diff_lib.py
    ├── logger.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | /data
2 | wandb
3 | __pycache__/
--------------------------------------------------------------------------------
/assets/images/Benchmark Performance.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Muffin/85d72b4b04035b6b1cf4168814739ae9e5e6e1dc/assets/images/Benchmark Performance.jpg
--------------------------------------------------------------------------------
/assets/images/case-attire.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Muffin/85d72b4b04035b6b1cf4168814739ae9e5e6e1dc/assets/images/case-attire.jpg
--------------------------------------------------------------------------------
/assets/images/case-helpful.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Muffin/85d72b4b04035b6b1cf4168814739ae9e5e6e1dc/assets/images/case-helpful.jpg
--------------------------------------------------------------------------------
/assets/images/case-onion.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Muffin/85d72b4b04035b6b1cf4168814739ae9e5e6e1dc/assets/images/case-onion.jpg
--------------------------------------------------------------------------------
/assets/images/reformulating_datasets.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/Muffin/85d72b4b04035b6b1cf4168814739ae9e5e6e1dc/assets/images/reformulating_datasets.jpg
--------------------------------------------------------------------------------
/eval/data/qa90_questions.jsonl:
-------------------------------------------------------------------------------- 1 | {"question_id": 0, "image": "000000441147.jpg", "text": "What is the color of the two suitcases in the image?", "category": "conv"} 2 | {"question_id": 1, "image": "000000441147.jpg", "text": "Analyze the image in a comprehensive and detailed manner.", "category": "detail"} 3 | {"question_id": 2, "image": "000000441147.jpg", "text": "What potential factors could make these suitcases valuable?", "category": "complex"} 4 | {"question_id": 3, "image": "000000353536.jpg", "text": "What are the main objects on the table in the image?", "category": "conv"} 5 | {"question_id": 4, "image": "000000353536.jpg", "text": "Describe the following image.", "category": "detail"} 6 | {"question_id": 5, "image": "000000353536.jpg", "text": "What activity might have recently taken place around this table and what could be its aftermath?", "category": "complex"} 7 | {"question_id": 6, "image": "000000506095.jpg", "text": "What is the main focus of the image?", "category": "conv"} 8 | {"question_id": 7, "image": "000000506095.jpg", "text": "What is this photo about'?", "category": "detail"} 9 | {"question_id": 8, "image": "000000506095.jpg", "text": "What could be a reason for the cat's interest in the laptop?", "category": "complex"} 10 | {"question_id": 9, "image": "000000514915.jpg", "text": "What color is the dog in the image?", "category": "conv"} 11 | {"question_id": 10, "image": "000000514915.jpg", "text": "Describe the following image.", "category": "detail"} 12 | {"question_id": 11, "image": "000000514915.jpg", "text": "Why might the dog have chosen to lay down on the floor in this location?", "category": "complex"} 13 | {"question_id": 12, "image": "000000056013.jpg", "text": "Where is the cart with luggage bags located?", "category": "conv"} 14 | {"question_id": 13, "image": "000000056013.jpg", "text": "Describe the following image.", "category": "detail"} 15 | {"question_id": 14, "image": "000000056013.jpg", "text": "What challenges might the staff encounter with handling luggage in the lobby?", "category": "complex"} 16 | {"question_id": 15, "image": "000000408439.jpg", "text": "What is the body of water seen in the image?", "category": "conv"} 17 | {"question_id": 16, "image": "000000408439.jpg", "text": "Can you elaborate on the elements of the picture provided?", "category": "detail"} 18 | {"question_id": 17, "image": "000000408439.jpg", "text": "What are potential reasons for the train to be traveling in such a scenic route?", "category": "complex"} 19 | {"question_id": 18, "image": "000000164255.jpg", "text": "What is the main object in the image?", "category": "conv"} 20 | {"question_id": 19, "image": "000000164255.jpg", "text": "What do you think is going on in this snapshot?", "category": "detail"} 21 | {"question_id": 20, "image": "000000164255.jpg", "text": "What purpose does the umbrella serve on the beach?", "category": "complex"} 22 | {"question_id": 21, "image": "000000034096.jpg", "text": "What stage is the house in regarding its construction?", "category": "conv"} 23 | {"question_id": 22, "image": "000000034096.jpg", "text": "What is this photo about'?", "category": "detail"} 24 | {"question_id": 23, "image": "000000034096.jpg", "text": "What may be a potential obstacle for the house construction in its current state?", "category": "complex"} 25 | {"question_id": 24, "image": "000000385873.jpg", "text": "How many pizzas are there in the image?", "category": "conv"} 26 | {"question_id": 25, "image": 
"000000385873.jpg", "text": "Explain the visual content of the image in great detail.", "category": "detail"} 27 | {"question_id": 26, "image": "000000385873.jpg", "text": "What are the possible challenges one might face when dividing these pizzas equitably among a group of people?", "category": "complex"} 28 | {"question_id": 27, "image": "000000367571.jpg", "text": "How many doughnuts are in the box?", "category": "conv"} 29 | {"question_id": 28, "image": "000000367571.jpg", "text": "Write a detailed description of the given image.", "category": "detail"} 30 | {"question_id": 29, "image": "000000367571.jpg", "text": "What unique taste combinations can be found in this box of doughnuts?", "category": "complex"} 31 | {"question_id": 30, "image": "000000319432.jpg", "text": "What is the man doing in the image?", "category": "conv"} 32 | {"question_id": 31, "image": "000000319432.jpg", "text": "Can you describe the main features of this image for me?", "category": "detail"} 33 | {"question_id": 32, "image": "000000319432.jpg", "text": "Why might the man be holding the clothes on hangers?", "category": "complex"} 34 | {"question_id": 33, "image": "000000225738.jpg", "text": "How many giraffes are depicted in the image?", "category": "conv"} 35 | {"question_id": 34, "image": "000000225738.jpg", "text": "Write a detailed description of the given image.", "category": "detail"} 36 | {"question_id": 35, "image": "000000225738.jpg", "text": "Why might these giraffes be gathering near the same tree?", "category": "complex"} 37 | {"question_id": 36, "image": "000000092109.jpg", "text": "What is the main focus of the image?", "category": "conv"} 38 | {"question_id": 37, "image": "000000092109.jpg", "text": "What is this photo about'?", "category": "detail"} 39 | {"question_id": 38, "image": "000000092109.jpg", "text": "What can be inferred about the giraffe's habitat from this image?", "category": "complex"} 40 | {"question_id": 39, "image": "000000506483.jpg", "text": "What kind of objects are included in the art installation?", "category": "conv"} 41 | {"question_id": 40, "image": "000000506483.jpg", "text": "Explain the visual content of the image in great detail.", "category": "detail"} 42 | {"question_id": 41, "image": "000000506483.jpg", "text": "How might the art installation encourage interaction among people?", "category": "complex"} 43 | {"question_id": 42, "image": "000000097131.jpg", "text": "What is the main focus of the image?", "category": "conv"} 44 | {"question_id": 43, "image": "000000097131.jpg", "text": "Can you elaborate on the elements of the picture provided?", "category": "detail"} 45 | {"question_id": 44, "image": "000000097131.jpg", "text": "What might be the reason for the car to be parked on the side of the road?", "category": "complex"} 46 | {"question_id": 45, "image": "000000052312.jpg", "text": "What is unusual about the elderly man's facial appearance in the image?", "category": "conv"} 47 | {"question_id": 46, "image": "000000052312.jpg", "text": "What's happening in the scene?", "category": "detail"} 48 | {"question_id": 47, "image": "000000052312.jpg", "text": "What message could be conveyed by the man's unusual facial adornment?", "category": "complex"} 49 | {"question_id": 48, "image": "000000258285.jpg", "text": "How many airplanes are visible in the image?", "category": "conv"} 50 | {"question_id": 49, "image": "000000258285.jpg", "text": "What do you see happening in this image?", "category": "detail"} 51 | {"question_id": 50, "image": "000000258285.jpg", 
"text": "What environmental impact might these planes' activity have on the birds?", "category": "complex"} 52 | {"question_id": 51, "image": "000000081552.jpg", "text": "What is the cat doing on the red couch?", "category": "conv"} 53 | {"question_id": 52, "image": "000000081552.jpg", "text": "Write a detailed description of the given image.", "category": "detail"} 54 | {"question_id": 53, "image": "000000081552.jpg", "text": "What factors may have led the cat to choose this location for sleep?", "category": "complex"} 55 | {"question_id": 54, "image": "000000066144.jpg", "text": "What type of car is parked next to the stop sign in the image?", "category": "conv"} 56 | {"question_id": 55, "image": "000000066144.jpg", "text": "Explain the visual content of the image in great detail.", "category": "detail"} 57 | {"question_id": 56, "image": "000000066144.jpg", "text": "What can be inferred about the traffic situation in the area?", "category": "complex"} 58 | {"question_id": 57, "image": "000000203629.jpg", "text": "What is the man doing in the image?", "category": "conv"} 59 | {"question_id": 58, "image": "000000203629.jpg", "text": "Can you describe the main features of this image for me?", "category": "detail"} 60 | {"question_id": 59, "image": "000000203629.jpg", "text": "What might be the reason behind the man's thumbs-up gesture while on the phone?", "category": "complex"} 61 | {"question_id": 60, "image": "000000109532.jpg", "text": "What breed is the dog in the image?", "category": "conv"} 62 | {"question_id": 61, "image": "000000109532.jpg", "text": "What are the key elements in this picture?", "category": "detail"} 63 | {"question_id": 62, "image": "000000109532.jpg", "text": "What measures may be taken by the dog's owner to make the dog feel comfortable in the room?", "category": "complex"} 64 | {"question_id": 63, "image": "000000431165.jpg", "text": "What is the color of the elephant in the image?", "category": "conv"} 65 | {"question_id": 64, "image": "000000431165.jpg", "text": "What do you think is going on in this snapshot?", "category": "detail"} 66 | {"question_id": 65, "image": "000000431165.jpg", "text": "What could be a potential concern for the baby elephant in this setting?", "category": "complex"} 67 | {"question_id": 66, "image": "000000020650.jpg", "text": "What type of food items are present on the orange plate?", "category": "conv"} 68 | {"question_id": 67, "image": "000000020650.jpg", "text": "What are the key elements in this picture?", "category": "detail"} 69 | {"question_id": 68, "image": "000000020650.jpg", "text": "What can be inferred about the individual's dietary choices?", "category": "complex"} 70 | {"question_id": 69, "image": "000000205183.jpg", "text": "What type of bird is pictured in the image?", "category": "conv"} 71 | {"question_id": 70, "image": "000000205183.jpg", "text": "What do you see happening in this image?", "category": "detail"} 72 | {"question_id": 71, "image": "000000205183.jpg", "text": "What challenges could the duck face while walking along the paved road?", "category": "complex"} 73 | {"question_id": 72, "image": "000000210299.jpg", "text": "Is the little boy wearing any safety gear while riding his bike?", "category": "conv"} 74 | {"question_id": 73, "image": "000000210299.jpg", "text": "Can you elaborate on the elements of the picture provided?", "category": "detail"} 75 | {"question_id": 74, "image": "000000210299.jpg", "text": "Why is the little boy's precaution noteworthy?", "category": "complex"} 76 | {"question_id": 75, 
"image": "000000515716.jpg", "text": "How many people are visible in the image?", "category": "conv"} 77 | {"question_id": 76, "image": "000000515716.jpg", "text": "What do you see happening in this image?", "category": "detail"} 78 | {"question_id": 77, "image": "000000515716.jpg", "text": "What might be the reason behind this gathering?", "category": "complex"} 79 | {"question_id": 78, "image": "000000460149.jpg", "text": "What is the main object featured in the image?", "category": "conv"} 80 | {"question_id": 79, "image": "000000460149.jpg", "text": "Describe the following image.", "category": "detail"} 81 | {"question_id": 80, "image": "000000460149.jpg", "text": "Imagine you are walking down this street; how would you know the time or which direction the sidewalk leads?", "category": "complex"} 82 | {"question_id": 81, "image": "000000293505.jpg", "text": "Is the person riding a motorbike or scooter?", "category": "conv"} 83 | {"question_id": 82, "image": "000000293505.jpg", "text": "Explain the visual content of the image in great detail.", "category": "detail"} 84 | {"question_id": 83, "image": "000000293505.jpg", "text": "What challenges might the woman on the motorcycle face in this situation?", "category": "complex"} 85 | {"question_id": 84, "image": "000000525439.jpg", "text": "What is the position of the skateboard in the image?", "category": "conv"} 86 | {"question_id": 85, "image": "000000525439.jpg", "text": "Can you elaborate on the elements of the picture provided?", "category": "detail"} 87 | {"question_id": 86, "image": "000000525439.jpg", "text": "What might have happened prior to this moment?", "category": "complex"} 88 | {"question_id": 87, "image": "000000534270.jpg", "text": "What type of hats are the man and woman wearing?", "category": "conv"} 89 | {"question_id": 88, "image": "000000534270.jpg", "text": "Can you describe the main features of this image for me?", "category": "detail"} 90 | {"question_id": 89, "image": "000000534270.jpg", "text": "What benefits do umbrella hats provide to the couple on this rainy day?", "category": "complex"} 91 | -------------------------------------------------------------------------------- /eval/data/rule.jsonfile: -------------------------------------------------------------------------------- 1 | { 2 | "coding": {"role": "Assistant", "prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. 
Then give extra comments starting from the next line."}, 3 | "math": {"role": "Assistant", "prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, 4 | "default": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, 5 | "conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, 6 | "detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. 
\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, 7 | "complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, 8 | "vqa_standard": {"role": "Assistant", "prompt": "We would like to request your feedback to evaluate the performance of an AI assistant in the response to an user question displayed above. The AI assistant is asked to look the image and answer the question. You need to give an overall score to the assistant's response to the question on a scale of 1 to 5, where a higher score indicates better overall performance. Please first output a single line containing only one value indicating the score for the assistant.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation. \n\n We will give you some additional information about the image and question for reference in the following (such as the expected answer, human answers and hints given by annotators). Note that the assistant can only see the image content and question text, all other reference informations are used to help you better understand the question and content of the image only. 
The major criteria is the correctness of the answer, you don't have to care about the conciseness or structure or other irrelevant factors of the answer."} 9 | } -------------------------------------------------------------------------------- /eval/eval_gpt_review_llava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import time 5 | import pathlib 6 | 7 | from gpt4_grpc import get_eval 8 | 9 | def parse_score(review): 10 | try: 11 | score_pair = review.split('\n')[0] 12 | score_pair = score_pair.replace(',', ' ') 13 | sp = score_pair.split(' ') 14 | if len(sp) == 2: 15 | return [float(sp[0]), float(sp[1])] 16 | else: 17 | print('error', review) 18 | return [-1, -1] 19 | except Exception as e: 20 | print(e) 21 | print('error', review) 22 | return [-1, -1] 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 27 | parser.add_argument('-q', '--question') 28 | parser.add_argument('-c', '--context') 29 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 30 | parser.add_argument('-r', '--rule') 31 | parser.add_argument('-o', '--output') 32 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 33 | args = parser.parse_args() 34 | 35 | chat = 'gpt-4-0314' 36 | 37 | f_q = open(os.path.expanduser(args.question)) 38 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 39 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 40 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 41 | 42 | 43 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 44 | image_to_context = {context['image']: context for context in context_list} 45 | 46 | 47 | reviewed_lines = [] 48 | if pathlib.Path(args.output).exists(): 49 | reviewed_lines = open(args.output).readlines()[:-1] 50 | print(f'Resume {args.output} from {len(reviewed_lines)}') 51 | review_file = open(f'{args.output}', 'w') 52 | if reviewed_lines: 53 | for line in reviewed_lines: 54 | review_file.write(line) 55 | review_file.flush() 56 | 57 | js_list = [] 58 | handles = [] 59 | for line_idx, (ques_js, ans1_js, ans2_js) in enumerate(zip(f_q, f_ans1, f_ans2)): 60 | if line_idx < len(reviewed_lines): 61 | continue 62 | ques = json.loads(ques_js) 63 | ans1 = json.loads(ans1_js) 64 | ans2 = json.loads(ans2_js) 65 | 66 | inst = image_to_context[ques['image']] 67 | cap_str = '\n'.join(inst['captions']) 68 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']]) 69 | 70 | category = json.loads(ques_js)['category'] 71 | if category in rule_dict: 72 | rule = rule_dict[category] 73 | else: 74 | assert False, f"Visual QA category not found in rule file: {category}." 
75 | prompt = rule['prompt'] 76 | role = rule['role'] 77 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' 78 | f'[Question]\n{ques["text"]}\n\n' 79 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 80 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 81 | f'[System]\n{prompt}\n\n') 82 | output = { 83 | 'id': line_idx, 84 | 'question_id': ques['question_id'], 85 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 86 | 'answer2_id': ans2.get('answer_id', 0), 87 | 'category': category} 88 | 89 | review = get_eval(chat, content, max_tokens=180) 90 | scores = parse_score(review) 91 | output['content'] = review 92 | output['tuple'] = scores 93 | review_file.write(json.dumps(output) + '\n') 94 | review_file.flush() 95 | 96 | # To avoid the rate limit set by OpenAI 97 | time.sleep(1) 98 | 99 | review_file.close() 100 | -------------------------------------------------------------------------------- /eval/eval_gpt_review_unimm-bench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import tqdm 6 | import time 7 | import pathlib 8 | 9 | from gpt4_grpc import get_eval 10 | 11 | 12 | def avg(lst): 13 | return sum(lst) / len(lst) 14 | 15 | 16 | def parse_score(review): 17 | try: 18 | score_str = review.split('\n')[0] 19 | return float(score_str) 20 | except Exception as e: 21 | print(e) 22 | print('error', review) 23 | return -1 24 | 25 | 26 | if __name__ == '__main__': 27 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 28 | parser.add_argument('-q', '--question') 29 | # parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 30 | parser.add_argument('-a', '--answer') 31 | parser.add_argument('-r', '--rule') 32 | parser.add_argument('-o', '--output') 33 | parser.add_argument('-l', '--limit', default=100, type=int) 34 | parser.add_argument('--max-tokens', type=int, default=1024, 35 | help='maximum number of tokens produced in the output') 36 | args = parser.parse_args() 37 | 38 | f_q = open(os.path.expanduser(args.question)) 39 | questions = json.load(f_q) 40 | category_lst = ["AOKVQA", "GQA", "OKVQA", "VQAv2"] 41 | for i, ques_js in enumerate(questions): 42 | ques_js["category"] = category_lst[i // 100] 43 | ans1_list = [json.loads(line) for line in open(os.path.expanduser(args.answer))] 44 | 45 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 46 | 47 | reviewed_lines = [] 48 | score_list = [] 49 | if pathlib.Path(args.output).exists(): 50 | reviewed_lines = open(args.output).readlines()[:-1] 51 | print(f'Resume {args.output} from {len(reviewed_lines)}') 52 | review_file = open(f'{args.output}', 'w') 53 | if reviewed_lines: 54 | for line in reviewed_lines: 55 | review_file.write(line) 56 | score_list.append(json.loads(line)['tuple']) 57 | review_file.flush() 58 | 59 | chat = 'gpt-4-0314' 60 | 61 | js_list = [] 62 | handles = [] 63 | category_sample_count = dict() 64 | 65 | for line_idx, (ques_js, ans1_js) in tqdm.tqdm(enumerate(zip(questions, ans1_list)), total=400): 66 | if line_idx < len(reviewed_lines): 67 | continue 68 | category = ques_js['category'] 69 | 70 | if category not in category_sample_count: 71 | category_sample_count[category] = 0 72 | if category_sample_count[category] == args.limit: 73 | continue 74 | category_sample_count[category] += 1 75 | ques, ans1 = ques_js, ans1_js 76 | 77 | rule = rule_dict["vqa_standard"] 78 | prompt = rule['prompt'] 79 | role = rule['role'] 80 | 81 | content = 
(f'[Question]\n\n{ques["question"]}\n\n' 82 | f'[{role} Response]\n{ans1["text"]}\n\n[End of {role} Response]\n\n' 83 | f'[System]\n{prompt}\n\n' 84 | f'[Expected Answer]\n{ques["answer"]}\n\n') 85 | 86 | if category == "AOKVQA": 87 | content = content + (f'[Rationale to Get Answer]\n{ques["context"]}\n\n' 88 | f'[Human Answers]\n{ques["Human Answers"].replace("[", " ").replace("]", "")}\n\n' 89 | ) 90 | elif category == "GQA": 91 | content = content + (f'[Human Answers]\n{ques["Human Answers"]}\n\n' 92 | ) 93 | elif category == "OKVQA": 94 | human_answers = ques['Human Answers'].replace( 95 | ']', '').replace('Human answers are: [', '[Human Answers]\n ') 96 | content = content + (f'{human_answers}\n\n' 97 | ) 98 | elif category == "ScienceQA": 99 | content = content + (f'[Hint]\n{ques["context"]}\n\n' 100 | f'[Human Answers]\n{ques["Human Answers"]}\n\n' 101 | ) 102 | elif category == "VizWiz": 103 | human_answers = ques['Human Answers'].replace( 104 | ']', '').replace('Human answers are: [', '[Human Answers]\n ') 105 | content = content + (f'[Image Description]\n{ques["context"]}\n\n' 106 | f'{human_answers}\n\n' 107 | ) 108 | elif category == "VQAv2": 109 | content = content + (f'[Human Answers]\n{ques["Human Answers"].replace("[", " ").replace("]", "")}\n\n' 110 | ) 111 | elif category == "WebQA": 112 | content = content + (f'[Webpage Title]\n{ques["context"]}\n\n' 113 | f'[Human Answers]\n{ques["answer"]}\n\n' 114 | ) 115 | output = { 116 | 'id': line_idx + 1, 117 | 'question_id': ans1.get('question_id', line_idx + 1), 118 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 119 | # 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 120 | 'category': category, 121 | 'input_content': content} 122 | 123 | review = get_eval(chat, content) 124 | 125 | interval = len(content.split()) // 50 126 | time.sleep(interval) 127 | 128 | score = parse_score(review) 129 | if score == -1: 130 | start = len(score_list) // args.limit * args.limit 131 | score = avg(score_list[start:]) 132 | print(f'Drop One Sample, use smoothed value of #{start} to #{len(score_list)} = {score:.2f}') 133 | 134 | score_list.append(score) 135 | 136 | output['content'] = review 137 | output['tuple'] = score 138 | review_file.write(json.dumps(output) + '\n') 139 | review_file.flush() 140 | 141 | review_file.close() 142 | -------------------------------------------------------------------------------- /eval/gpt4_grpc.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import time 3 | 4 | NUM_SECONDS_TO_SLEEP = 0.5 5 | 6 | 7 | # 'gpt-4-0314' 8 | def get_eval(model, content: str, 9 | chat_gpt_system='You are a helpful and precise assistant for checking the quality of the answer.', 10 | max_tokens: int=256, 11 | fail_limit=100): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model=model, 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': chat_gpt_system 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] -------------------------------------------------------------------------------- /eval/summarize_gpt_llava_review.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import 
defaultdict 4 | 5 | import numpy as np 6 | import sys 7 | import glob 8 | 9 | 10 | if __name__ == '__main__': 11 | base_dir = sys.argv[1] 12 | print(base_dir) 13 | 14 | patterns = ['*', '*/*', '*/*/*'] 15 | f_list = sum([list(glob.glob(os.path.join(base_dir, p))) for p in patterns], []) 16 | review_files = [x for x in f_list if x.endswith('.jsonl') and 'llava_test_gpt4' in x] 17 | 18 | for review_file in sorted(review_files): 19 | config = review_file.replace('gpt4_text_', '').replace('.jsonl', '') 20 | scores = defaultdict(list) 21 | print(f'GPT-4 vs. {config} #{len(list(open(review_file)))}') 22 | with open(review_file) as f: 23 | for review_str in f: 24 | review = json.loads(review_str) 25 | # filter failed case 26 | if review['tuple'][0] == -1: 27 | print(f'#### Skip fail Case') 28 | continue 29 | 30 | scores[review['category']].append(review['tuple']) 31 | scores['all'].append(review['tuple']) 32 | for k, v in scores.items(): 33 | stats = np.asarray(v).mean(0).tolist() 34 | stats = [round(x, 3) for x in stats] 35 | print(k, stats, round(stats[1]/stats[0]*100, 1)) 36 | print('=================================') 37 | -------------------------------------------------------------------------------- /eval/summarize_gpt_unimm-bench_review.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import sys 7 | import glob 8 | 9 | 10 | if __name__ == '__main__': 11 | base_dir = sys.argv[1] 12 | print(base_dir) 13 | 14 | patterns = ['*', '*/*', '*/*/*'] 15 | f_list = sum([list(glob.glob(os.path.join(base_dir, p))) for p in patterns], []) 16 | review_files = [x for x in f_list if x.endswith('.jsonl') and 'unimm-bench_gpt4_eval' in x and 'gpt4' in x] 17 | 18 | for review_file in sorted(review_files): 19 | config = review_file.replace('gpt4_text_', '').replace('.jsonl', '') 20 | scores = defaultdict(list) 21 | print(f'{config} #{len(list(open(review_file)))}') 22 | with open(review_file) as f: 23 | for review_str in f: 24 | try: 25 | review = json.loads(review_str) 26 | # filter failed case 27 | if review['content'].startswith('-1\n'): 28 | print(f'#### Skip fail Case') 29 | continue 30 | scores[review['category']].append(review['tuple']) 31 | scores['all'].append(review['tuple']) 32 | except: 33 | print(f'Error parsing {review_str}') 34 | for k, v in scores.items(): 35 | mean = np.asarray(v).mean() / 5 * 100 36 | print(f'{k:<7s} {mean: .2f}') 37 | print('=================================') 38 | -------------------------------------------------------------------------------- /muffin/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM, Beit3LlavaLlamaForCausalLM -------------------------------------------------------------------------------- /muffin/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
5 | -------------------------------------------------------------------------------- /muffin/conversation.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from enum import auto, Enum 3 | from typing import List, Tuple 4 | 5 | 6 | class SeparatorStyle(Enum): 7 | """Different separator style.""" 8 | SINGLE = auto() 9 | TWO = auto() 10 | 11 | 12 | @dataclasses.dataclass 13 | class Conversation: 14 | """A class that keeps all conversation history.""" 15 | system: str 16 | roles: List[str] 17 | messages: List[List[str]] 18 | offset: int 19 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 20 | sep: str = "###" 21 | sep2: str = None 22 | version: str = "Unknown" 23 | 24 | skip_next: bool = False 25 | 26 | def get_prompt(self): 27 | if self.sep_style == SeparatorStyle.SINGLE: 28 | ret = self.system + self.sep 29 | for role, message in self.messages: 30 | if message: 31 | if type(message) is tuple: 32 | message, _, _ = message 33 | ret += role + ": " + message + self.sep 34 | else: 35 | ret += role + ":" 36 | return ret 37 | elif self.sep_style == SeparatorStyle.TWO: 38 | seps = [self.sep, self.sep2] 39 | ret = self.system + seps[0] 40 | for i, (role, message) in enumerate(self.messages): 41 | if message: 42 | if type(message) is tuple: 43 | message, _, _ = message 44 | ret += role + ": " + message + seps[i % 2] 45 | else: 46 | ret += role + ":" 47 | return ret 48 | else: 49 | raise ValueError(f"Invalid style: {self.sep_style}") 50 | 51 | def append_message(self, role, message): 52 | self.messages.append([role, message]) 53 | 54 | def get_images(self, return_pil=False): 55 | images = [] 56 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 57 | if i % 2 == 0: 58 | if type(msg) is tuple: 59 | import base64 60 | from io import BytesIO 61 | from PIL import Image 62 | msg, image, image_process_mode = msg 63 | if image_process_mode == "Pad": 64 | def expand2square(pil_img, background_color=(122, 116, 104)): 65 | width, height = pil_img.size 66 | if width == height: 67 | return pil_img 68 | elif width > height: 69 | result = Image.new(pil_img.mode, (width, width), background_color) 70 | result.paste(pil_img, (0, (width - height) // 2)) 71 | return result 72 | else: 73 | result = Image.new(pil_img.mode, (height, height), background_color) 74 | result.paste(pil_img, ((height - width) // 2, 0)) 75 | return result 76 | image = expand2square(image) 77 | elif image_process_mode == "Crop": 78 | pass 79 | elif image_process_mode == "Resize": 80 | image = image.resize((224, 224)) 81 | else: 82 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}") 83 | max_hw, min_hw = max(image.size), min(image.size) 84 | aspect_ratio = max_hw / min_hw 85 | max_len, min_len = 800, 400 86 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 87 | longest_edge = int(shortest_edge * aspect_ratio) 88 | W, H = image.size 89 | if H > W: 90 | H, W = longest_edge, shortest_edge 91 | else: 92 | H, W = shortest_edge, longest_edge 93 | image = image.resize((W, H)) 94 | if return_pil: 95 | images.append(image) 96 | else: 97 | buffered = BytesIO() 98 | image.save(buffered, format="JPEG") 99 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 100 | images.append(img_b64_str) 101 | return images 102 | 103 | def to_gradio_chatbot(self): 104 | ret = [] 105 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 106 | if i % 2 == 0: 107 | if type(msg) is tuple: 108 | import base64 109 | from io import BytesIO 110 | msg, 
image, image_process_mode = msg 111 | max_hw, min_hw = max(image.size), min(image.size) 112 | aspect_ratio = max_hw / min_hw 113 | max_len, min_len = 800, 400 114 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 115 | longest_edge = int(shortest_edge * aspect_ratio) 116 | W, H = image.size 117 | if H > W: 118 | H, W = longest_edge, shortest_edge 119 | else: 120 | H, W = shortest_edge, longest_edge 121 | image = image.resize((W, H)) 122 | # image = image.resize((224, 224)) 123 | buffered = BytesIO() 124 | image.save(buffered, format="JPEG") 125 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 126 | img_str = f'user upload image' 127 | msg = msg.replace('', img_str) 128 | ret.append([msg, None]) 129 | else: 130 | ret[-1][-1] = msg 131 | return ret 132 | 133 | def copy(self): 134 | return Conversation( 135 | system=self.system, 136 | roles=self.roles, 137 | messages=[[x, y] for x, y in self.messages], 138 | offset=self.offset, 139 | sep_style=self.sep_style, 140 | sep=self.sep, 141 | sep2=self.sep2) 142 | 143 | def dict(self): 144 | if len(self.get_images()) > 0: 145 | return { 146 | "system": self.system, 147 | "roles": self.roles, 148 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], 149 | "offset": self.offset, 150 | "sep": self.sep, 151 | "sep2": self.sep2, 152 | } 153 | return { 154 | "system": self.system, 155 | "roles": self.roles, 156 | "messages": self.messages, 157 | "offset": self.offset, 158 | "sep": self.sep, 159 | "sep2": self.sep2, 160 | } 161 | 162 | 163 | conv_v1 = Conversation( 164 | system="A chat between a curious human and an artificial intelligence assistant. " 165 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 166 | roles=("Human", "Assistant"), 167 | messages=( 168 | ("Human", "Give three tips for staying healthy."), 169 | ("Assistant", 170 | "Sure, here are three tips for staying healthy:\n" 171 | "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. " 172 | "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, " 173 | "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or " 174 | "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening " 175 | "activities at least two days per week.\n" 176 | "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, " 177 | "vegetables, whole grains, lean proteins, and healthy fats can help support " 178 | "your overall health. Try to limit your intake of processed and high-sugar foods, " 179 | "and aim to drink plenty of water throughout the day.\n" 180 | "3. Get enough sleep: Getting enough quality sleep is essential for your physical " 181 | "and mental health. Adults should aim for seven to nine hours of sleep per night. " 182 | "Establish a regular sleep schedule and try to create a relaxing bedtime routine to " 183 | "help improve the quality of your sleep.") 184 | ), 185 | offset=2, 186 | sep_style=SeparatorStyle.SINGLE, 187 | sep="###", 188 | ) 189 | 190 | conv_v1_2 = Conversation( 191 | system="A chat between a curious human and an artificial intelligence assistant. 
" 192 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 193 | roles=("Human", "Assistant"), 194 | messages=( 195 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"), 196 | ("Assistant", 197 | "Renewable energy sources are those that can be replenished naturally in a relatively " 198 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " 199 | "Non-renewable energy sources, on the other hand, are finite and will eventually be " 200 | "depleted, such as coal, oil, and natural gas. Here are some key differences between " 201 | "renewable and non-renewable energy sources:\n" 202 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " 203 | "energy sources are finite and will eventually run out.\n" 204 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact " 205 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " 206 | "and other negative effects.\n" 207 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " 208 | "have lower operational costs than non-renewable sources.\n" 209 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " 210 | "locations than non-renewable sources.\n" 211 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " 212 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n" 213 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " 214 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") 215 | ), 216 | offset=2, 217 | sep_style=SeparatorStyle.SINGLE, 218 | sep="###", 219 | ) 220 | 221 | conv_vicuna_v1_1 = Conversation( 222 | system="A chat between a curious user and an artificial intelligence assistant. " 223 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 224 | roles=("USER", "ASSISTANT"), 225 | version="v1", 226 | messages=(), 227 | offset=0, 228 | sep_style=SeparatorStyle.TWO, 229 | sep=" ", 230 | sep2="", 231 | ) 232 | 233 | conv_bair_v1 = Conversation( 234 | system="BEGINNING OF CONVERSATION:", 235 | roles=("USER", "GPT"), 236 | messages=(), 237 | offset=0, 238 | sep_style=SeparatorStyle.TWO, 239 | sep=" ", 240 | sep2="", 241 | ) 242 | 243 | simple_conv = Conversation( 244 | system="You are LLaVA, a large language model trained by UW Madison WAIV Lab, based on LLaMA architecture." 245 | "You are designed to assist human with a variety of tasks using natural language." 246 | "Follow the instructions carefully.", 247 | roles=("Human", "Assistant"), 248 | messages=( 249 | ("Human", "Hi!"), 250 | ("Assistant", "Hi there! How can I help you today?\n") 251 | ), 252 | offset=2, 253 | sep_style=SeparatorStyle.SINGLE, 254 | sep="###", 255 | ) 256 | 257 | simple_conv_multimodal = Conversation( 258 | system="A chat between a curious user and an artificial intelligence assistant. " 259 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 260 | roles=("Human", "Assistant"), 261 | messages=( 262 | ), 263 | offset=0, 264 | sep_style=SeparatorStyle.SINGLE, 265 | sep="###", 266 | ) 267 | 268 | simple_conv_legacy = Conversation( 269 | system="You are LLaVA, a large language model trained by UW Madison WAIV Lab." 
270 | "You are designed to assist human with a variety of tasks using natural language." 271 | "Follow the instructions carefully.", 272 | roles=("Human", "Assistant"), 273 | messages=( 274 | ("Human", "Hi!\n\n### Response:"), 275 | ("Assistant", "Hi there! How can I help you today?\n") 276 | ), 277 | offset=2, 278 | sep_style=SeparatorStyle.SINGLE, 279 | sep="###", 280 | ) 281 | 282 | conv_llava_v1 = Conversation( 283 | system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab." 284 | "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 285 | "Follow the instructions carefully and explain your answers in detail.", 286 | roles=("USER", "ASSISTANT"), 287 | version="v1", 288 | messages=(), 289 | offset=0, 290 | sep_style=SeparatorStyle.TWO, 291 | sep=" ", 292 | sep2="", 293 | ) 294 | 295 | default_conversation = conv_v1_2 296 | conv_templates = { 297 | "default": conv_v1_2, 298 | "simple": simple_conv, 299 | "simple_legacy": simple_conv_legacy, 300 | "multimodal": simple_conv_multimodal, 301 | "llava_v1": conv_llava_v1, 302 | 303 | # fastchat 304 | "v1": conv_v1_2, 305 | "bair_v1": conv_bair_v1, 306 | "vicuna_v1_1": conv_vicuna_v1_1, 307 | } 308 | 309 | 310 | if __name__ == "__main__": 311 | print(default_conversation.get_prompt()) 312 | -------------------------------------------------------------------------------- /muffin/data/data_processors.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import glob 4 | import json 5 | import base64 6 | import random 7 | import pathlib 8 | 9 | from PIL import Image 10 | from typing import List 11 | 12 | class Register(dict): 13 | def __init__(self, *args, **kwargs): 14 | super(Register, self).__init__(*args, **kwargs) 15 | self._dict = {} 16 | 17 | def register(self, target): 18 | def add_register_item(keys, value): 19 | if not callable(value): 20 | raise Exception( 21 | f"Register object must be callable! But receice:{value} is not callable!") 22 | 23 | if not isinstance(keys, list): 24 | keys = [keys] 25 | 26 | for key in keys: 27 | if key in self._dict: 28 | print( 29 | f"error: \033[33m{value.__name__} has been registered before, so we will overriden it\033[0m") 30 | exit() 31 | 32 | self[key] = value 33 | return value 34 | 35 | if callable(target): 36 | return add_register_item(target.__name__, target) 37 | else: 38 | return lambda x: add_register_item(target, x) 39 | 40 | def __call__(self, target): 41 | return self.register(target) 42 | 43 | def __setitem__(self, key, value): 44 | self._dict[key] = value 45 | 46 | def __getitem__(self, key): 47 | return self._dict[key] 48 | 49 | def __contains__(self, key): 50 | return key in self._dict 51 | 52 | def __str__(self): 53 | return str(self._dict) 54 | 55 | def keys(self): 56 | return self._dict.keys() 57 | 58 | def values(self): 59 | return self._dict.values() 60 | 61 | def items(self): 62 | return self._dict.items() 63 | 64 | 65 | register_data_processor = Register() 66 | register_data_path = Register() 67 | 68 | 69 | def vqa_instruction_templates(question, idx=None): 70 | instructions = [ 71 | "{Question} A short answer to the question is", 72 | "Given the image, answer the following question with no more than three words. 
{Question}", 73 | "Based on the image, respond to this question with a short answer: {Question} Answer:", 74 | "Use the provided image to answer the question: {Question} Provide your answer as short as possible:", 75 | ] 76 | if idx is None: 77 | new_question = random.choice(instructions).replace("{Question}", question) 78 | else: 79 | new_question = instructions[idx].replace("{Question}", question) 80 | 81 | return new_question 82 | 83 | 84 | def caption_instruction_templates(): 85 | instructions = [ 86 | "Describe the image concisely.", 87 | "Provide a brief description of the given image.", 88 | "Offer a succinct explanation of the picture presented.", 89 | "Summarize the visual content of the image.", 90 | "Give a short and clear explanation of the subsequent image.", 91 | "Share a concise interpretation of the image provided.", 92 | "Present a compact description of the photo's key features.", 93 | "Relay a brief, clear account of the picture shown.", 94 | "Render a clear and concise summary of the photo.", 95 | "Write a terse but informative summary of the picture.", 96 | "Create a compact narrative representing the image presented." 97 | ] 98 | 99 | new_question = random.choice(instructions) 100 | 101 | return new_question 102 | 103 | 104 | def load_multimodal_conversation(text_b64, img_b64_buffer): 105 | map_role = { 106 | 'human': 'human', 107 | 'gpt': 'gpt' 108 | } 109 | 110 | text = base64.b64decode(text_b64).decode('utf-8') 111 | list_conv = json.loads(text) 112 | 113 | out: List[dict] = [] 114 | for idx, sentence in enumerate(list_conv): 115 | value = sentence['value'] 116 | 117 | if idx == 0 and '' not in value: 118 | value = f"\n{value}" 119 | if idx != 0 and '' in value: 120 | value = value.replace('', '') 121 | 122 | out.append({ 123 | 'from': map_role[sentence['from']], 124 | 'value': value 125 | }) 126 | 127 | img_io = io.BytesIO(base64.b64decode(img_b64_buffer)) 128 | img_io.seek(0) 129 | image = Image.open(img_io).convert('RGB') 130 | return image, out 131 | 132 | 133 | def b64_to_PIL_image(img_b64_buffer): 134 | img_io = io.BytesIO(base64.b64decode(img_b64_buffer)) 135 | img_io.seek(0) 136 | image = Image.open(img_io).convert('RGB') 137 | return image 138 | 139 | 140 | def wrap_qa_to_single_turn_multimodal_conv(answer, question): 141 | if '' not in question: 142 | question = f"\n{question}" 143 | 144 | out = [ 145 | {"from": "human", "value": question}, 146 | {"from": "gpt", "value": answer} 147 | ] 148 | return question, out 149 | 150 | 151 | def wrap_generation_single_turn_conv(out, template_func): 152 | conv = [ 153 | { 154 | "from": "human", 155 | "value": f"\n{template_func()}" 156 | 157 | }, 158 | { 159 | "from": "gpt", 160 | "value": out 161 | } 162 | ] 163 | return conv 164 | 165 | 166 | def wrap_caption_generation_single_turn_conv(out): 167 | return wrap_generation_single_turn_conv(out, caption_instruction_templates) 168 | 169 | 170 | def gather_data_files_by_glob(root: str, pattern='*.tsv'): 171 | filenames = [] 172 | 173 | for fullpath in glob.glob(f'{root}/{pattern}'): 174 | filename = fullpath.split('/')[-1] 175 | filenames.append(filename) 176 | return root, filenames 177 | 178 | 179 | @register_data_path('unimm-chat') 180 | def unimmchat_data_path(): 181 | data_dir = pathlib.Path(__file__).parent.resolve() / '../../data/unimm-chat' 182 | return gather_data_files_by_glob(data_dir, '*.tsv') 183 | 184 | 185 | @register_data_processor(['unimm-chat']) 186 | def unimmchat_processor(img_b64_buffer, text_b64, origin_dataset, origin_split, 
origin_split_inner_idx, img_path, 187 | intent, img_transformer=None): 188 | if intent == 'pretrain' or intent == 'sft': 189 | image, out = load_multimodal_conversation(text_b64, img_b64_buffer) 190 | 191 | metainfo = { 192 | "origin_dataset": origin_dataset, 193 | "origin_split": origin_split, 194 | "origin_idx": origin_split_inner_idx, 195 | "image_id": img_path, 196 | } 197 | 198 | return { 199 | 'image': image, 200 | 'conversations': out, 201 | 'idx': origin_split_inner_idx, 202 | 'metainfo': metainfo, 203 | } 204 | else: 205 | raise NotImplemented 206 | 207 | 208 | @register_data_processor('RLHF-V-Dataset') 209 | def dpo_cvpr_ncrp_vqa_processor(*args, **kwargs): 210 | return dpo_preference_processor(*args, **kwargs) 211 | 212 | @register_data_path('RLHF-V-Dataset') 213 | def dpo_cvpr_ncrp_vqa_path(): 214 | data_dir = pathlib.Path(__file__).parent.resolve() / '../../data/RLHF-V-Dataset' 215 | return gather_data_files_by_glob(data_dir, pattern='RLHF-V-Dataset_withlogp-1401.tsv') 216 | 217 | 218 | def dpo_preference_processor(img_b64_buffer, text_b64, origin_dataset, origin_split, origin_split_inner_idx, img_path, 219 | intent, img_transformer=None): 220 | if intent == 'pretrain' or intent == 'sft': 221 | text = base64.b64decode(text_b64).decode('utf-8') 222 | origin_split = base64.b64decode(origin_split).decode('utf-8') 223 | origin_split = json.loads(origin_split) 224 | list_conv = json.loads(text) 225 | 226 | assert len(list_conv) in [ 227 | 3, 4], f'length must be in [3, 4] for data w/ or w/o logps, bug got {len(list_conv)}' 228 | 229 | question = list_conv[0] 230 | if '' not in question: 231 | question = f"\n{question}" 232 | 233 | out_chosen = list_conv[1] 234 | out_rejected = list_conv[2] 235 | 236 | question = {"from": "human", "value": question} 237 | out_chosen = {"from": "gpt", "value": out_chosen} 238 | out_rejected = {"from": "gpt", "value": out_rejected} 239 | 240 | image = b64_to_PIL_image(img_b64_buffer) 241 | 242 | metainfo = { 243 | "origin_dataset": origin_dataset, 244 | "origin_split": origin_split, 245 | "origin_idx": origin_split_inner_idx, 246 | "image_id": img_path, 247 | } 248 | 249 | data_dict = { 250 | 'image': image, 251 | 'question': question, 252 | 'chosen': out_chosen, 253 | 'rejected': out_rejected, 254 | 'idx': origin_split_inner_idx, 255 | 'metainfo': metainfo, 256 | } 257 | 258 | if len(list_conv) == 4: 259 | (data_dict['ref_win_logp'], data_dict['ref_win_avg_logp'], data_dict['ref_win_per_token_logp'], 260 | data_dict['ref_rej_logp'], data_dict['ref_rej_avg_logp'], data_dict['ref_rej_per_token_logp']) = list_conv[3] 261 | 262 | return data_dict 263 | else: 264 | raise NotImplemented 265 | 266 | 267 | @register_data_path('vqav2-val') 268 | def vqav2_val_data_path(): 269 | data_dir = pathlib.Path(__file__).parent.resolve() / '../../data/VQAv2' 270 | _, filenames = gather_data_files_by_glob(data_dir) 271 | filenames = [f for f in filenames if 'val' in f] 272 | return data_dir, filenames 273 | 274 | 275 | @register_data_processor('vqav2-val') 276 | def vqav2_val_processor(img_b64_buffer, text_b64, origin_dataset, origin_split, origin_split_inner_idx, img_path, 277 | intent, img_transformer=None): 278 | if intent == 'eval': 279 | 280 | text = base64.b64decode(text_b64).decode('utf-8') 281 | origin_qa = json.loads(text) 282 | 283 | out: List[dict] = [] 284 | 285 | question = origin_qa["question"] 286 | answer = origin_qa["answer"] 287 | 288 | question, out = wrap_qa_to_single_turn_multimodal_conv(answer, question) 289 | 290 | image = 
b64_to_PIL_image(img_b64_buffer) 291 | 292 | metainfo = { 293 | "origin_dataset": origin_dataset, 294 | "origin_split": origin_split, 295 | "origin_idx": int(origin_split_inner_idx), 296 | "image_id": img_path, 297 | } 298 | 299 | return { 300 | 'image': image, 301 | 'conversations': out, 302 | 'idx': origin_split_inner_idx, 303 | 'metainfo': metainfo, 304 | 'origin_question': origin_qa["question"], 305 | } 306 | else: 307 | raise NotImplemented 308 | 309 | 310 | @register_data_path('vqav2-train') 311 | def vqav2_train_data_path(): 312 | data_dir = pathlib.Path(__file__).parent.resolve() / '../../data/VQAv2' 313 | _, filenames = gather_data_files_by_glob(data_dir) 314 | filenames = [f for f in filenames if 'train' in f] 315 | return data_dir, filenames 316 | 317 | 318 | @register_data_processor('vqav2-train') 319 | def vqav2_train_processor(img_b64_buffer, text_b64, origin_dataset, origin_split, origin_split_inner_idx, img_path, 320 | intent, img_transformer=None): 321 | if intent == 'pretrain' or intent == 'sft': 322 | 323 | text = base64.b64decode(text_b64).decode('utf-8') 324 | origin_qa = json.loads(text) 325 | 326 | out: List[dict] = [] 327 | 328 | question = origin_qa["question"] 329 | answer = origin_qa["answer"] 330 | question = vqa_instruction_templates(question) # vqa short answer template 331 | 332 | question, out = wrap_qa_to_single_turn_multimodal_conv(answer, question) 333 | 334 | image = b64_to_PIL_image(img_b64_buffer) 335 | 336 | metainfo = { 337 | "origin_dataset": origin_dataset, 338 | "origin_split": origin_split, 339 | "origin_idx": origin_split_inner_idx, 340 | "image_id": img_path, 341 | } 342 | 343 | return { 344 | 'image': image, 345 | 'conversations': out, 346 | 'idx': origin_split_inner_idx, 347 | 'metainfo': metainfo, 348 | } 349 | elif intent == 'eval': 350 | return vqav2_val_processor(img_b64_buffer, text_b64, origin_dataset, origin_split, origin_split_inner_idx, img_path, 351 | intent, img_transformer) 352 | else: 353 | raise NotImplemented 354 | -------------------------------------------------------------------------------- /muffin/data/datasets.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import json 4 | import torch 5 | import numpy 6 | import base64 7 | import pandas as pd 8 | import os.path as op 9 | import datasets as hf_datasets 10 | import torch.utils.data as torch_data 11 | 12 | from PIL import Image 13 | from typing import List, Iterator 14 | from muffin.data.tsv_file import TSVFile 15 | from muffin.data.data_processors import register_data_processor 16 | from muffin.eval.muffin_inference_logp import inference_logp 17 | 18 | def bytes_to_PIL_image(img_buffer): 19 | img_io = io.BytesIO(img_buffer) 20 | img_io.seek(0) 21 | image = Image.open(img_io).convert('RGB') 22 | return image 23 | 24 | def read_jsonl(file_path): 25 | with open(file_path, "r", encoding="utf-8") as file: 26 | return [json.loads(line) for line in file] 27 | 28 | class RLHFVDataset(torch_data.Dataset): 29 | def __init__(self, data_dir: str, ref_name: str, reference_model=None, 30 | tokenizer=None, image_token_len=None, img_processor=None, use_im_start_end=True): 31 | super().__init__() 32 | 33 | self.data_path = f'{data_dir}/{ref_name}_with_logp.parquet' 34 | 35 | if not op.exists(self.data_path): 36 | os.makedirs(data_dir, exist_ok=True) 37 | 38 | assert reference_model is not None, "`reference_model` is mandatory when logps do not exist." 
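# (Annotation added for clarity, not present in the original source.) When the
# cached parquet file is missing, the reference model is run once over the
# HaoyeZhang/RLHF-V-Dataset split loaded below via `inference_logp`, and the
# resulting per-sample log-probabilities are written to
# `{ref_name}_with_logp.parquet`; `__getitem__` later unpacks them into the
# `ref_win_*` / `ref_rej_*` fields used for preference training.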
39 | 40 | hf_data = hf_datasets.load_dataset("HaoyeZhang/RLHF-V-Dataset")['train'].cast_column("image", hf_datasets.Image(decode=False)) 41 | 42 | inference_logp(reference_model, tokenizer, hf_data, self.data_path, 43 | image_token_len, img_processor, use_im_start_end) 44 | 45 | torch.distributed.barrier() 46 | 47 | self.data = pd.read_parquet(self.data_path) 48 | 49 | # print(f'{torch.distributed.get_rank()} data len: {len(self.data)}') 50 | else: 51 | self.data = pd.read_parquet(self.data_path) 52 | 53 | def __len__(self): 54 | return len(self.data) 55 | 56 | def __getitem__(self, index): 57 | sample = self.data.iloc[index] 58 | text = json.loads(sample['text']) 59 | question = {'from': 'human', 'value': f"\n{text['question']}"} 60 | chosen = {'from': 'gpt', 'value': text['chosen']} 61 | rejected = {'from': 'gpt', 'value': text['rejected']} 62 | 63 | image = bytes_to_PIL_image(sample['image']['bytes']) 64 | 65 | metainfo = { 66 | "origin_dataset": sample['origin_dataset'], 67 | "origin_split": sample['origin_split'], 68 | "origin_idx": sample['idx'], 69 | "image_id": sample['image_path'], 70 | } 71 | 72 | data_dict = { 73 | 'image': image, 74 | "question": question, 75 | "chosen": chosen, 76 | "rejected": rejected, 77 | "idx": sample['idx'], 78 | "metainfo": metainfo 79 | } 80 | 81 | (data_dict['ref_win_logp'], data_dict['ref_win_avg_logp'], data_dict['ref_win_per_token_logp'], 82 | data_dict['ref_rej_logp'], data_dict['ref_rej_avg_logp'], data_dict['ref_rej_per_token_logp']) = text['logps'] 83 | 84 | return data_dict 85 | 86 | 87 | class MultimodalQADataset(torch_data.Dataset): 88 | def __init__(self, qa_file, question_process): 89 | ''' 90 | qa_file: jsonl file that each line is a dict like { 91 | 'image': b64img, 92 | 'question': question_text 93 | } 94 | ''' 95 | super().__init__() 96 | 97 | self.qa_file = qa_file 98 | self.qa_data = [json.loads(line) for line in open(self.qa_file)] 99 | if isinstance(self.qa_data[0], list): 100 | self.qa_data = self.qa_data[0] # unwrap one-line json question file 101 | 102 | self.question_process = question_process 103 | 104 | def __getitem__(self, index): 105 | item = self.qa_data[index] 106 | 107 | img_b64 = item['image'] 108 | image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert('RGB') 109 | 110 | raw_question = item['question'] 111 | question_text = self.question_process(raw_question) 112 | return { 113 | 'image': image, 114 | 'raw_question': raw_question, 115 | 'question': question_text 116 | } 117 | 118 | 119 | def __len__(self): 120 | return len(self.qa_data) 121 | 122 | 123 | class SingleDataSourceDataset(torch_data.Dataset): 124 | def __init__(self, ds_name, data_dir, tsv_filenames: List[str], intent='sft') -> None: 125 | super().__init__() 126 | 127 | self.data_dir = data_dir 128 | self.filenames = tsv_filenames 129 | self.ds_name = ds_name 130 | 131 | self.sizes = [] 132 | for filename in self.filenames: 133 | try: 134 | size = int(filename[:-4].split('-')[-1]) 135 | except: 136 | raise ValueError(f'TSV Data File {filename} is not valid, last component separated by `-` must be the number of sample in this file') 137 | self.sizes.append(size) 138 | 139 | self.file_border_index = [] 140 | self.prepare_border_index() 141 | 142 | self.files = self.filenames[:] 143 | self.intent = intent 144 | 145 | 146 | def prepare_border_index(self): 147 | self.file_border_index = [0] 148 | 149 | temp_sum = 0 150 | for size in self.sizes: 151 | temp_sum += size 152 | self.file_border_index.append(temp_sum) 153 | 154 | 155 | def 
get_file_idx_and_row_idx(self, index): 156 | found = False 157 | file_idx = -1 158 | 159 | for border_idx, border in enumerate(self.file_border_index): 160 | if index < border: 161 | file_idx = border_idx - 1 162 | found = True 163 | break 164 | if not found: 165 | raise ValueError(f'Index {index} out of range for {self.size_sum} border markers') 166 | 167 | offset = self.file_border_index[file_idx] 168 | row_idx = index - offset 169 | return file_idx, row_idx 170 | 171 | def __len__(self): 172 | return self.file_border_index[-1] 173 | 174 | def __getitem__(self, index): 175 | file_idx, row_idx = self.get_file_idx_and_row_idx(index) 176 | return self.fetch_sample(file_idx, row_idx) 177 | 178 | def fetch_sample(self, file_idx, row_idx): 179 | file = self.files[file_idx] 180 | if isinstance(file, str): 181 | self.prepare_file(file_idx) 182 | file = self.files[file_idx] 183 | 184 | assert isinstance(file, TSVFile), f'Expecting TSVFile but get {file} as {type(file)}' 185 | 186 | # tsv line as tuple 187 | sample = file[row_idx] 188 | ds_name, *values = sample 189 | 190 | # data dict 191 | sample = register_data_processor[self.ds_name](*values, intent=self.intent) 192 | 193 | if row_idx + 1 == len(file): 194 | del file 195 | self.files[file_idx] = self.filenames[file_idx] 196 | 197 | return sample 198 | 199 | def prepare_file(self, idx): 200 | filename = self.filenames[idx] 201 | file = TSVFile(op.join(self.data_dir, filename)) 202 | self.files[idx] = file 203 | 204 | 205 | class IterableSingleDataSourceDataset(torch_data.IterableDataset): 206 | def __init__(self) -> None: 207 | super().__init__() 208 | raise NotImplemented 209 | 210 | 211 | class MultiDataSourceDataset(torch_data.Dataset): 212 | def __init__(self, data_sources: List[SingleDataSourceDataset], data_source_weights: List[int]): 213 | super().__init__() 214 | 215 | self.ds_list = data_sources 216 | 217 | self.sum_weight = sum(data_source_weights) 218 | self.ds_weights = data_source_weights 219 | for weight in self.ds_weights: 220 | assert isinstance(weight, int), 'weight must be integer' 221 | 222 | self.offset2ds = {} 223 | self.offset2wt = {} 224 | self.offset2pd = {} 225 | self.prepare_offset2ds() 226 | 227 | ds_loops = [] 228 | for ds, wt in zip(self.ds_list, self.ds_weights): 229 | ds_loop = len(ds) // wt 230 | ds_loops.append(ds_loop) 231 | max_loop = max(ds_loops) 232 | self.size = max_loop * self.sum_weight 233 | 234 | def prepare_offset2ds(self): 235 | offset = 0 236 | for ds, weight in zip(self.ds_list, self.ds_weights): 237 | pd = offset 238 | for _ in range(weight): 239 | self.offset2ds[offset] = ds 240 | self.offset2wt[offset] = weight 241 | self.offset2pd[offset] = pd 242 | offset += 1 243 | 244 | def __getitem__(self, index): 245 | n_loop = index // self.sum_weight 246 | offset = index % self.sum_weight 247 | 248 | ds = self.offset2ds[offset] 249 | ds_inner_idx = n_loop * self.offset2wt[offset] + offset - self.offset2pd[offset] 250 | ds_inner_idx = ds_inner_idx % len(ds) 251 | 252 | return ds[ds_inner_idx] 253 | 254 | def __len__(self): 255 | return self.size 256 | 257 | 258 | class IterableMultiDataSourceDataset(torch_data.IterableDataset): 259 | def __init__(self, data_sources, data_source_weights): 260 | super().__init__() 261 | 262 | self.ds_list = data_sources 263 | 264 | sum_weight = sum(data_source_weights) 265 | self.ds_weights = [x / sum_weight for x in data_source_weights] 266 | 267 | self.ds_consumption = [] 268 | self.ds_sizes = [len(ds) for ds in self.ds_list] 269 | 270 | def __next__(self): 271 | ds_idx 
= numpy.random.choice(range(len(self.ds_list)), 1, p=self.ds_weights)[0] 272 | data_source = self.ds_list[ds_idx] 273 | 274 | self.ds_consumption[ds_idx] += 1 275 | if self.ds_consumption[ds_idx] % self.ds_sizes[ds_idx] == 0: 276 | self.report_consumption() 277 | 278 | sample = next(data_source) 279 | return sample 280 | 281 | def __iter__(self) -> Iterator: 282 | return self 283 | 284 | def __len__(self): 285 | return sum(self.ds_sizes) 286 | 287 | def report_consumption(self): 288 | for ds, consumption, size in zip(self.ds_list, self.ds_consumption, self.ds_sizes): 289 | print(f'Data {ds} consumption: {consumption / size:.2f} epoch', flush=True) 290 | 291 | -------------------------------------------------------------------------------- /muffin/data/tsv_file.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Microsoft Corporation. Licensed under the MIT license. 2 | import os 3 | import logging 4 | import os.path as op 5 | 6 | LARGEST_TSV_SIZE = 500_000 7 | 8 | # LARGEST_TSV_SIZE = 10_000 9 | 10 | def create_lineidx(filein, idxout): 11 | idxout_tmp = idxout + '.tmp' 12 | with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout: 13 | fsize = os.fstat(tsvin.fileno()).st_size 14 | fpos = 0 15 | while fpos != fsize: 16 | tsvout.write(str(fpos)+"\n") 17 | tsvin.readline() 18 | fpos = tsvin.tell() 19 | os.rename(idxout_tmp, idxout) 20 | 21 | 22 | def read_to_character(fp, c): 23 | result = [] 24 | while True: 25 | s = fp.read(32) 26 | assert s != '' 27 | if c in s: 28 | result.append(s[: s.index(c)]) 29 | break 30 | else: 31 | result.append(s) 32 | return ''.join(result) 33 | 34 | 35 | class TSVFile(object): 36 | def __init__(self, tsv_file, generate_lineidx=False): 37 | self.tsv_file = tsv_file 38 | self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' 39 | self._fp = None 40 | self._lineidx = None 41 | # the process always keeps the process which opens the file. 42 | # If the pid is not equal to the currrent pid, we will re-open the file. 
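# (Clarifying note, not in the original source.) `self.pid` records the pid of the
# process that opened the underlying file handle; `_ensure_tsv_opened` below
# reopens the TSV whenever the current pid differs, e.g. after a DataLoader
# worker fork, so each worker process holds its own file object.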
43 | self.pid = None 44 | # generate lineidx if not exist 45 | if not op.isfile(self.lineidx) and generate_lineidx: 46 | create_lineidx(self.tsv_file, self.lineidx) 47 | 48 | def __del__(self): 49 | if self._fp: 50 | self._fp.close() 51 | 52 | def __str__(self): 53 | return "TSVFile(tsv_file='{}')".format(self.tsv_file) 54 | 55 | def __repr__(self): 56 | return str(self) 57 | 58 | def num_rows(self): 59 | self._ensure_lineidx_loaded() 60 | assert len(self._lineidx) <= LARGEST_TSV_SIZE, f'Do not support TSVFile larger than {LARGEST_TSV_SIZE} yet' 61 | return len(self._lineidx) 62 | 63 | def seek(self, idx): 64 | self._ensure_tsv_opened() 65 | self._ensure_lineidx_loaded() 66 | try: 67 | pos = self._lineidx[idx] 68 | except: 69 | logging.info('{}-{}'.format(self.tsv_file, idx)) 70 | raise 71 | self._fp.seek(pos) 72 | return [s.strip() for s in self._fp.readline().split('\t')] 73 | 74 | def seek_first_column(self, idx): 75 | self._ensure_tsv_opened() 76 | self._ensure_lineidx_loaded() 77 | pos = self._lineidx[idx] 78 | self._fp.seek(pos) 79 | return read_to_character(self._fp, '\t') 80 | 81 | def get_key(self, idx): 82 | return self.seek_first_column(idx) 83 | 84 | def __getitem__(self, index): 85 | return self.seek(index) 86 | 87 | def __len__(self): 88 | return self.num_rows() 89 | 90 | def _ensure_lineidx_loaded(self): 91 | if self._lineidx is None: 92 | logging.debug('loading lineidx: {}'.format(self.lineidx)) 93 | with open(self.lineidx, 'r') as fp: 94 | self._lineidx = [int(i.strip()) for i in fp.readlines()] 95 | 96 | def _ensure_tsv_opened(self): 97 | if self._fp is None: 98 | self._fp = open(self.tsv_file, 'r') 99 | self.pid = os.getpid() 100 | 101 | if self.pid != os.getpid(): 102 | # logging.info('re-open {} because the process id changed'.format(self.tsv_file)) 103 | self._fp = open(self.tsv_file, 'r') 104 | self.pid = os.getpid() -------------------------------------------------------------------------------- /muffin/data/tsv_file_op.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Microsoft Corporation. Licensed under the MIT license. 
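# (Summary comment added by the editor, not in the original source.) Helpers for
# writing multimodal TSV shards: `tsv_writer` emits a single TSV plus a `.lineidx`
# file of byte offsets, while `multimodal_img_tsv_writer` validates base64 images
# and rolls over to a new shard every LARGEST_TSV_SIZE rows, renaming each shard
# to `<name>_<counter>-<num_rows>.tsv` so downstream dataset code can recover the
# sample count from the filename.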
2 | import os 3 | import io 4 | import errno 5 | import pandas 6 | import base64 7 | 8 | import os.path as op 9 | 10 | from PIL import Image 11 | from tqdm import tqdm 12 | from typing import List 13 | from muffin.data.tsv_file import TSVFile 14 | from muffin.data.tsv_file import LARGEST_TSV_SIZE 15 | 16 | 17 | def mkdir(path): 18 | try: 19 | os.makedirs(path) 20 | except OSError as e: 21 | if e.errno != errno.EEXIST: 22 | raise 23 | 24 | 25 | def tsv_writer(value_lines: List[List[str]], tsv_file, sep='\t'): 26 | mkdir(op.dirname(tsv_file)) 27 | lineidx_file = op.splitext(tsv_file)[0] + '.lineidx' 28 | 29 | idx = 0 30 | tsv_file_tmp = tsv_file + '.tmp' 31 | lineidx_file_tmp = lineidx_file + '.tmp' 32 | 33 | with open(tsv_file_tmp, 'w') as fp, open(lineidx_file_tmp, 'w') as fpidx: 34 | assert value_lines is not None 35 | 36 | line_value: List[str] 37 | for line_value in value_lines: 38 | assert line_value is not None 39 | line_value = [v if type(v) != bytes else v.decode('utf-8') for v in line_value] 40 | v = '{0}\n'.format(sep.join(map(str, line_value))) 41 | fp.write(v) 42 | fpidx.write(str(idx) + '\n') 43 | idx = idx + len(v) 44 | 45 | os.rename(tsv_file_tmp, tsv_file) 46 | os.rename(lineidx_file_tmp, lineidx_file) 47 | 48 | 49 | def b64img_ok(b64img): 50 | try: 51 | img_io = io.BytesIO(base64.b64decode(b64img)) 52 | img_io.seek(0) 53 | image = Image.open(img_io).convert('RGB') 54 | except: 55 | return False 56 | return True 57 | 58 | 59 | def write_line(fp, fpidx, line_value, sep, idx): 60 | for value in line_value: 61 | assert isinstance(value, str), f'{type(value)}-{value}' 62 | 63 | v = '{0}\n'.format(sep.join(line_value)) 64 | fp.write(v) 65 | fpidx.write(str(idx) + '\n') 66 | idx += len(v) 67 | return idx 68 | 69 | def open_new_file(base_name, ext, counter): 70 | f_name = f"{base_name}_{counter}" 71 | return open(f"{f_name}.{ext}", 'w'), f_name 72 | 73 | def multimodal_img_tsv_writer(value_lines, tsv_file, sep='\t', text_only=False): 74 | mkdir(op.dirname(tsv_file)) 75 | # lineidx_file = op.splitext(tsv_file)[0] + '.lineidx' 76 | 77 | tsv_name = op.splitext(tsv_file)[0] 78 | 79 | idx = 0 80 | num_row = 0 81 | file_counter = 0 82 | 83 | fp, fp_name = open_new_file(tsv_name, 'tsv.tmp', file_counter) 84 | fpidx, fpidx_name = open_new_file(tsv_name, 'lineidx.tmp', file_counter) 85 | 86 | assert value_lines is not None 87 | 88 | for dataset_name, img_buffer, text_b64, origin_dataset, origin_split, origin_split_inner_idx, img_path in value_lines: 89 | if not b64img_ok(img_buffer) and not text_only: 90 | print(f'Image value cannot be interpreted as b64 str of image: {origin_dataset} | {origin_split} | {origin_split_inner_idx}') 91 | continue 92 | # raise ValueError('Image value cannot be interpreted as b64 str of image') 93 | 94 | line_value = [dataset_name, img_buffer, text_b64, origin_dataset, 95 | origin_split, origin_split_inner_idx, img_path] 96 | 97 | if num_row >= LARGEST_TSV_SIZE: # LARGEST_TSV_SIZE 98 | fp.close() 99 | fpidx.close() 100 | os.rename(f"{fp_name}.tsv.tmp", f'{fp_name}-{num_row}.tsv') 101 | os.rename(f"{fpidx_name}.lineidx.tmp", f'{fpidx_name}-{num_row}.lineidx') 102 | 103 | file_counter += 1 104 | num_row = 0 105 | idx = 0 106 | fp, fp_name = open_new_file(tsv_name, 'tsv.tmp', file_counter) 107 | fpidx, fpidx_name = open_new_file(tsv_name, 'lineidx.tmp', file_counter) 108 | 109 | idx = write_line(fp, fpidx, line_value, sep=sep, idx=idx) 110 | num_row += 1 111 | 112 | fp.close() 113 | fpidx.close() 114 | 115 | os.rename(f"{fp_name}.tsv.tmp", 
f'{fp_name}-{num_row}.tsv') 116 | os.rename(f"{fpidx_name}.lineidx.tmp", f'{fpidx_name}-{num_row}.lineidx') 117 | 118 | 119 | # prev 120 | def multimodal_img_tsv_writer_prev(value_lines, tsv_file, sep='\t', text_only=False): 121 | mkdir(op.dirname(tsv_file)) 122 | lineidx_file = op.splitext(tsv_file)[0] + '.lineidx' 123 | 124 | idx = 0 125 | tsv_file_tmp = tsv_file + '.tmp' 126 | lineidx_file_tmp = lineidx_file + '.tmp' 127 | num_row = 0 128 | 129 | with open(tsv_file_tmp, 'w') as fp, open(lineidx_file_tmp, 'w') as fpidx: 130 | assert value_lines is not None 131 | 132 | for dataset_name, img_buffer, text_b64, origin_dataset, origin_split, origin_split_inner_idx, img_path in value_lines: 133 | if not b64img_ok(img_buffer) and not text_only: 134 | raise ValueError('Image value cannot be interpreted as b64 str of image') 135 | 136 | line_value = [dataset_name, img_buffer, text_b64, origin_dataset, 137 | origin_split, origin_split_inner_idx, img_path] 138 | for value in line_value: 139 | assert isinstance(value, str), f'{type(value)}-{value}' 140 | 141 | v = '{0}\n'.format(sep.join(line_value)) 142 | fp.write(v) 143 | fpidx.write(str(idx) + '\n') 144 | num_row += 1 145 | idx = idx + len(v) 146 | 147 | os.rename(tsv_file_tmp, f'{tsv_file}-{num_row}.tsv') 148 | os.rename(lineidx_file_tmp, f'{tsv_file}-{num_row}.lineidx') 149 | -------------------------------------------------------------------------------- /muffin/eval/muffin_chat.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from PIL import Image 4 | from muffin.eval.muffin_vqa import init_muffin, wrap_question_with_default_conv, KeywordsStoppingCriteria 5 | from muffin.train.train_utils import _add_speaker_and_signal, _tokenize_fn 6 | from muffin import conversation as conversation_lib 7 | 8 | class MuffinForSingleTurnChat: 9 | def __init__(self,model, img_processor, image_token_len, tokenizer) -> None: 10 | self.model = model 11 | self.image_token_len = image_token_len 12 | self.image_transform = img_processor 13 | self.tokenizer = tokenizer 14 | 15 | def decode(self, image, input_ids): 16 | keywords = ['###'] 17 | with torch.inference_mode(): 18 | num_beams = 3 19 | input_size = input_ids.shape[-1] 20 | stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_size) 21 | # print(f'Input: {self.tokenizer.batch_decode(input_ids)}' 22 | # f'input_ids: {input_ids}') 23 | 24 | output = self.model.generate( 25 | input_ids=input_ids.unsqueeze(0).cuda(), 26 | images=image.unsqueeze(0).half().cuda(), 27 | temperature=0.7, 28 | max_new_tokens=1024, 29 | num_beams=num_beams, 30 | # do_sample=True, 31 | output_scores=True, 32 | return_dict_in_generate=True, 33 | stopping_criteria=[stopping_criteria], 34 | repetition_penalty=1.1) 35 | 36 | response = self.tokenizer.decode(output.sequences[0][input_size:], skip_special_tokens=True) 37 | # print(f'raw response is {response}') 38 | if response.count('###'): 39 | response = response[: response.index('###')] 40 | if response.count('Assistant:'): 41 | response = response[response.index('Assistant:') + len('Assistant:'):] 42 | response = response.strip() 43 | return response 44 | 45 | def chat(self, image_path, question): 46 | image = Image.open(image_path).convert('RGB') 47 | question = wrap_question_with_default_conv(question, self.image_token_len) 48 | 49 | tokenized = self.tokenizer([question]) 50 | input_ids = torch.as_tensor(tokenized['input_ids'][0]) 51 | image = self.image_transform(image) 52 | 53 | return 
self.decode(image, input_ids) 54 | 55 | 56 | class MuffinForMultiTurnChat(MuffinForSingleTurnChat): 57 | def __init__(self, model, img_processor, image_token_len, tokenizer) -> None: 58 | super(MuffinForMultiTurnChat, self).__init__(model, img_processor, image_token_len, tokenizer) 59 | self.history = [] 60 | self.image = None 61 | 62 | def _update_history(self, question, out): 63 | self.history.append({ 64 | 'from': 'human', 65 | 'value': question 66 | }) 67 | self.history.append({ 68 | 'from': 'gpt', 69 | 'value': out 70 | }) 71 | 72 | def start_chat(self, image_path, raw_question): 73 | image = Image.open(image_path).convert('RGB') 74 | question = wrap_question_with_default_conv(raw_question, self.image_token_len) 75 | 76 | tokenized = self.tokenizer([question]) 77 | input_ids = torch.as_tensor(tokenized['input_ids'][0]) 78 | image = self.image_transform(image) 79 | 80 | out = self.decode(image, input_ids) 81 | self._update_history(raw_question, out) 82 | self.image = image 83 | return out 84 | 85 | def resume(self, question): 86 | if self.image is None or len(self.history) == 0: 87 | print(f'Please first start chat before resuming.') 88 | return '' 89 | conv = copy.deepcopy(self.history) + [{ 90 | 'from': 'human', 91 | 'value': question 92 | }] 93 | header = f"{conversation_lib.default_conversation.system}\n\n" 94 | conv = _add_speaker_and_signal(header, conv) 95 | conv = conv.strip() 96 | input_ids = _tokenize_fn([conv], self.tokenizer)['input_ids'][0] 97 | 98 | out = self.decode(self.image, input_ids) 99 | self._update_history(question, out) 100 | return out 101 | 102 | def clear(self): 103 | self.history = [] 104 | self.image = None 105 | 106 | 107 | # model, img_processor, image_token_len, tokenizer = init_muffin('/home/yutianyu/Muffin_checkpoints/SFT_exp/muffin_13b_SFT-Muffin_QA_win_SFT_combine-vqav2-train#dpo_sftwin_checked_1005-1026#dpo_sftwin_checked_1103-1106-1#1#1-beit3_large_patch16_448/checkpionts/checkpoint-20/') 108 | 109 | if __name__ == '__main__': 110 | model, img_processor, image_token_len, tokenizer = init_muffin('/home/yutianyu/Muffin_checkpoints/SFT_exp/muffin_13b_SFT-Muffin_QA_win_SFT_combine-vqav2-train#dpo_sftwin_checked_1005-1026#dpo_sftwin_checked_1103-1106-1#1#1-beit3_large_patch16_448/checkpionts/checkpoint-20/') 111 | chat_model = MuffinForSingleTurnChat(model, img_processor, image_token_len, tokenizer) -------------------------------------------------------------------------------- /muffin/eval/muffin_inference_logp.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | import tqdm 4 | import copy 5 | import torch 6 | import itertools 7 | import pandas as pd 8 | import torch.utils.data as torch_data 9 | import PIL.Image as PIL_image 10 | 11 | from functools import partial 12 | 13 | from muffin.train.train_utils import encode_multimodal_preference_sample, SFT_collator_fn 14 | 15 | 16 | def bytes_to_PIL_image(img_buffer): 17 | img_io = io.BytesIO(img_buffer) 18 | img_io.seek(0) 19 | image = PIL_image.open(img_io).convert('RGB') 20 | return image 21 | 22 | 23 | class InferenceSampler(torch.utils.data.sampler.Sampler): 24 | 25 | def __init__(self, size): 26 | self._size = int(size) 27 | assert size > 0 28 | self._rank = torch.distributed.get_rank() 29 | self._world_size = torch.distributed.get_world_size() 30 | self._local_indices = self._get_local_indices(size, self._world_size, 31 | self._rank) 32 | 33 | @staticmethod 34 | def _get_local_indices(total_size, world_size, rank): 35 | shard_size = 
total_size // world_size 36 | left = total_size % world_size 37 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 38 | 39 | begin = sum(shard_sizes[:rank]) 40 | end = min(sum(shard_sizes[:rank + 1]), total_size) 41 | return range(begin, end) 42 | 43 | def __iter__(self): 44 | yield from self._local_indices 45 | 46 | def __len__(self): 47 | return len(self._local_indices) 48 | 49 | 50 | def get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, return_per_token_logp=False, return_all=False) -> torch.FloatTensor: 51 | """Compute the log probabilities of the given labels under the given logits. 52 | 53 | Args: 54 | logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) 55 | labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length) 56 | Returns: 57 | A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 58 | """ 59 | assert logits.shape[:-1] == labels.shape 60 | 61 | labels = labels[:, 1:].clone() 62 | logits = logits[:, :-1, :] 63 | loss_mask = (labels != -100) 64 | 65 | # dummy token; we'll ignore the losses on these tokens later 66 | labels[labels == -100] = 0 67 | 68 | per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, 69 | index=labels.unsqueeze(2)).squeeze(2) 70 | 71 | log_prob = (per_token_logps * loss_mask).sum(-1) 72 | average_log_prob = log_prob / loss_mask.sum(-1) 73 | 74 | # print(per_token_logps.shape, labels.shape) 75 | if return_per_token_logp: 76 | return per_token_logps 77 | 78 | if return_all: 79 | return per_token_logps, log_prob, average_log_prob 80 | 81 | return log_prob, average_log_prob 82 | 83 | 84 | class PreferenceInferenceDataset(torch_data.Dataset): 85 | def __init__(self, 86 | data, 87 | tokenizer, 88 | image_token_len, 89 | img_processor, 90 | use_im_start_end=True): 91 | 92 | self.data = data 93 | 94 | self.mm_cfg = { 95 | 'image_processor': img_processor, 96 | 'is_multimodal': True, 97 | 'image_token_len': image_token_len, 98 | 'use_im_start_end': use_im_start_end 99 | } 100 | self.tokenizer = tokenizer 101 | 102 | def __getitem__(self, index): 103 | sample = self.data[index] 104 | metainfo = { 105 | "origin_dataset": sample['origin_dataset'], 106 | "origin_split": json.loads(sample['origin_split']), 107 | "origin_idx": sample['idx'], 108 | "image_id": sample['image_path'], 109 | } 110 | 111 | text = json.loads(sample['text']) 112 | question = {'from': 'human', 'value': f"\n{text['question']}"} 113 | chosen = {'from': 'gpt', 'value': text['chosen']} 114 | rejected = {'from': 'gpt', 'value': text['rejected']} 115 | 116 | image = bytes_to_PIL_image(sample['image']['bytes']) 117 | 118 | formated_sample = { 119 | 'image': image, 120 | "question": question, 121 | "chosen": chosen, 122 | "rejected": rejected, 123 | "idx": sample['idx'], 124 | "metainfo": metainfo 125 | } 126 | rej_data_dict, win_data_dict = encode_multimodal_preference_sample(formated_sample, self.tokenizer, self.mm_cfg) 127 | return rej_data_dict, win_data_dict 128 | 129 | def __len__(self): 130 | return len(self.data) 131 | 132 | 133 | def pretty_print(data_dict, tokenizer): 134 | input_ids = data_dict['input_ids'] 135 | input_str = tokenizer.decode(input_ids) 136 | print(f'input_ids.shape={input_ids.shape}\ninput_str is {input_str}') 137 | 138 | label_ids = data_dict['labels'] 139 | print(f'label_ids.shape={input_ids.shape}') 140 | for i, o in zip(input_ids, 
label_ids): 141 | i_tok = tokenizer.convert_ids_to_tokens(i.item()) 142 | o_tok = tokenizer.convert_ids_to_tokens(o.item()) if o.item() != -100 else '[SKIP]' 143 | print(f'{i_tok:10s} => {o_tok:10s}') 144 | 145 | 146 | def concate_pad(tensorA, tensorB, padding_value): 147 | out = torch.nn.utils.rnn.pad_sequence( 148 | list(tensorA) + list(tensorB), 149 | batch_first=True, 150 | padding_value=padding_value) 151 | return out 152 | 153 | 154 | def preference_collator_fn(instances, pad_token_id): 155 | rej_instances, win_instances = list(zip(*instances)) 156 | rej_batch = SFT_collator_fn(rej_instances, pad_token_id) 157 | win_batch = SFT_collator_fn(win_instances, pad_token_id) 158 | 159 | concatenated_input_ids = concate_pad(win_batch['input_ids'], rej_batch['input_ids'], pad_token_id) 160 | concatenated_labels = concate_pad(win_batch['labels'], rej_batch['labels'], -100) 161 | concatenated_attention_mask = concatenated_input_ids.ne(pad_token_id) 162 | 163 | batch = dict( 164 | concatenated_input_ids=concatenated_input_ids, 165 | concatenated_labels=concatenated_labels, 166 | concatenated_attention_mask=concatenated_attention_mask, 167 | win_input_ids=win_batch['input_ids'], 168 | rej_input_ids=rej_batch['input_ids'], 169 | win_labels=win_batch['labels'], 170 | rej_labels=rej_batch['labels'], 171 | win_attention_mask=win_batch['attention_mask'], 172 | rej_attention_mask=rej_batch['attention_mask'], 173 | images=win_batch['images'], 174 | ) 175 | return batch 176 | 177 | 178 | def get_multimodal_sample_logps(model, dataloader): 179 | win_logp_list = [] 180 | rej_logp_list = [] 181 | 182 | win_avg_logp_list = [] 183 | rej_avg_logp_list = [] 184 | 185 | win_per_token_logp_list = [] 186 | rej_per_token_logp_list = [] 187 | 188 | with torch.inference_mode(): 189 | for batch in tqdm.tqdm(dataloader): 190 | for key in ['win', 'rej']: 191 | input_ids = batch[f'{key}_input_ids'].cuda() 192 | labels = batch[f'{key}_labels'].cuda() 193 | attention_mask = batch[f'{key}_attention_mask'].cuda() 194 | 195 | output = model( 196 | input_ids=input_ids, 197 | labels=labels, 198 | attention_mask=attention_mask, 199 | images=batch['images'].half().cuda() 200 | ) 201 | per_token_logp, log_prob, average_log_prob = get_batch_logps(output.logits, labels, return_all=True) 202 | 203 | # print(per_token_logp.shape, input_ids.shape, labels.shape, flush=True) 204 | assert per_token_logp.size(1) >= input_ids.size(1) - 1 205 | per_token_logp = per_token_logp.tolist() 206 | # per_token_logp = [x[:input_ids[i].ne(tokenizer.pad_token_id).sum().item()] for i, x in enumerate(per_token_logp)] 207 | log_prob = log_prob.tolist() 208 | average_log_prob = average_log_prob.tolist() 209 | 210 | if key == 'win': 211 | win_logp_list += log_prob 212 | win_avg_logp_list += average_log_prob 213 | win_per_token_logp_list += per_token_logp 214 | else: 215 | rej_logp_list += log_prob 216 | rej_avg_logp_list += average_log_prob 217 | rej_per_token_logp_list += per_token_logp 218 | # print(f'{key} logits in {output.logits.shape}, logp in {log_prob.shape} avg_logp in {average_log_prob.shape}') 219 | 220 | return win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list 221 | 222 | 223 | def write_logp_to_preference_parquet(origin_data, cache_file, logps, overwrite_logps=False): 224 | out_data = [] 225 | 226 | for index in range(len(origin_data)): 227 | line = origin_data[index] 228 | logp_data = logps[index] 229 | 230 | new_line = copy.deepcopy(line) 231 | 232 | text = 
json.loads(new_line['text']) 233 | 234 | if 'logps' in text.keys(): 235 | assert overwrite_logps, 'Found existing logp data, pass overwrite_logps=True to force overwritting' 236 | text['logps'] = logp_data 237 | new_line['text'] = json.dumps(text) 238 | 239 | else: 240 | assert list(text.keys()) == ['question', 'chosen', 'rejected'], f'Undefined data structure, expecting [Q, Win, Rej], got {text.keys()}' 241 | text['logps'] = logp_data 242 | new_line['text'] = json.dumps(text) 243 | 244 | out_data.append(new_line) 245 | 246 | df = pd.DataFrame(out_data) 247 | 248 | if torch.distributed.get_rank() == 0: 249 | df.to_parquet(cache_file) 250 | 251 | torch.distributed.barrier() 252 | 253 | return df 254 | 255 | def inference_logp(model, tokenizer, hf_data, cache_file, image_token_len, img_processor, use_im_start_end): 256 | model = model.to(dtype=torch.bfloat16, device='cuda') 257 | dataset = PreferenceInferenceDataset(tokenizer=tokenizer, 258 | data = hf_data, 259 | image_token_len=image_token_len, 260 | img_processor=img_processor, 261 | use_im_start_end=use_im_start_end) 262 | collate_fn = partial(preference_collator_fn, pad_token_id=tokenizer.pad_token_id) 263 | dataloader = torch_data.DataLoader(dataset, batch_size=1, collate_fn=collate_fn, 264 | num_workers=5, shuffle=False, sampler=InferenceSampler(len(dataset))) 265 | 266 | outputs = get_multimodal_sample_logps(model, dataloader) # win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list 267 | 268 | world_size = torch.distributed.get_world_size() 269 | merged_outputs = [[None for _ in range(world_size)] for i in range(len(outputs))] 270 | for i in range(len(outputs)): 271 | torch.distributed.all_gather_object(merged_outputs[i], outputs[i]) 272 | merged_outputs[i] = [_ for _ in itertools.chain.from_iterable(merged_outputs[i])] 273 | 274 | win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list \ 275 | = merged_outputs 276 | 277 | logps = list(zip(win_logp_list, win_avg_logp_list, win_per_token_logp_list, rej_logp_list, rej_avg_logp_list, rej_per_token_logp_list)) 278 | 279 | df = write_logp_to_preference_parquet(dataset.data, cache_file, logps, overwrite_logps=False) 280 | 281 | torch.distributed.barrier() 282 | 283 | del model 284 | return df -------------------------------------------------------------------------------- /muffin/eval/muffin_vqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import argparse 5 | 6 | import torch.utils.data as torch_data 7 | 8 | import tqdm 9 | from functools import partial 10 | from transformers import AutoTokenizer, AutoConfig 11 | 12 | from muffin import Beit3LlavaLlamaForCausalLM 13 | from muffin.conversation import conv_templates 14 | from muffin.utils import disable_torch_init 15 | from muffin.model.utils import build_transform 16 | from transformers import StoppingCriteria 17 | from muffin.data.datasets import MultimodalQADataset 18 | 19 | 20 | DEFAULT_IMAGE_TOKEN = "" 21 | DEFAULT_IMAGE_PATCH_TOKEN = "" 22 | DEFAULT_IM_START_TOKEN = "" 23 | DEFAULT_IM_END_TOKEN = "" 24 | 25 | 26 | def patch_config(config): 27 | patch_dict = { 28 | "use_mm_proj": True, 29 | "mm_vision_tower": "openai/clip-vit-large-patch14", 30 | "mm_hidden_size": 1024 31 | } 32 | 33 | cfg = AutoConfig.from_pretrained(config) 34 | if not hasattr(cfg, "mm_vision_tower"): 35 | print(f'`mm_vision_tower` not found in `{config}`, 
applying patch and save to disk.') 36 | for k, v in patch_dict.items(): 37 | setattr(cfg, k, v) 38 | cfg.save_pretrained(config) 39 | 40 | 41 | def expand_question_into_multimodal(question_text, image_token_len, im_st_token, im_ed_token, im_patch_token): 42 | if '' in question_text: 43 | question_text = question_text.replace('', '') 44 | 45 | question_text = question_text + '\n' + im_st_token + im_patch_token * image_token_len + im_ed_token 46 | return question_text 47 | 48 | 49 | def wrap_question_with_default_conv(question_text, image_token_len): 50 | question_text = expand_question_into_multimodal( 51 | question_text, image_token_len, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN) 52 | conv = conv_templates['default'].copy() 53 | conv.messages = [] 54 | conv.sep = '\n###' 55 | 56 | conv.append_message(conv.roles[0], question_text) 57 | prompt = conv.get_prompt() 58 | return prompt 59 | 60 | 61 | def torch_pad_sequence(sequence, padding_value, batch_first=True, padding_side='right'): 62 | 63 | if padding_side == 'right': 64 | sequence = torch.nn.utils.rnn.pad_sequence( 65 | sequence, 66 | batch_first=batch_first, 67 | padding_value=padding_value) 68 | elif padding_side == 'left': 69 | sequence = torch.nn.utils.rnn.pad_sequence( 70 | [v.flip(-1) for v in sequence], 71 | batch_first=batch_first, 72 | padding_value=padding_value) 73 | sequence = sequence.flip(-1) 74 | else: 75 | raise NotImplementedError(f'padding_size={padding_side}') 76 | return sequence 77 | 78 | 79 | def qa_colloator_fn(data_list, tokenizer, img_transform): 80 | questions = [x['question'] for x in data_list] 81 | tokenized = tokenizer(questions) 82 | 83 | input_ids = [torch.as_tensor(v) for v in tokenized['input_ids']] 84 | input_ids = torch_pad_sequence(input_ids, tokenizer.pad_token_id, padding_side='left') 85 | 86 | attn_mask = [torch.as_tensor(v) for v in tokenized['attention_mask']] 87 | attn_mask = torch_pad_sequence(attn_mask, 0, padding_side='left') 88 | 89 | images = [img_transform(x['image']) for x in data_list] 90 | images = torch.stack(images) 91 | 92 | raw_questions = [x['raw_question'] for x in data_list] 93 | data = { 94 | 'images': images, 95 | 'input_ids': input_ids, 96 | 'attention_mask': attn_mask, 97 | 'raw_questions': raw_questions 98 | } 99 | 100 | if 'question_id' in data_list[0]: 101 | data['question_id'] = [x['question_id'] for x in data_list] 102 | 103 | return data 104 | 105 | 106 | class KeywordsStoppingCriteria(StoppingCriteria): 107 | def __init__(self, keywords, tokenizer, input_size): 108 | self.keywords = keywords 109 | self.tokenizer = tokenizer 110 | self.input_size = input_size 111 | 112 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 113 | for o in output_ids: 114 | o = self.tokenizer.decode(o[self.input_size:], skip_special_tokens=True) 115 | if all([keyword not in o for keyword in self.keywords]): 116 | return False 117 | return True 118 | 119 | def init_muffin(model_path, device = None): 120 | disable_torch_init() 121 | model_name = os.path.expanduser(model_path) 122 | print(f'Load muffin model and tokenizer from {model_name}') 123 | tokenizer = AutoTokenizer.from_pretrained(model_name) 124 | 125 | patch_config(model_name) 126 | model = Beit3LlavaLlamaForCausalLM.from_pretrained( 127 | model_name, torch_dtype=torch.float16).cuda() 128 | image_processor = build_transform( 129 | is_train=False, input_size=model.model.vision_tower.args.img_size) 130 | 131 | mm_use_im_start_end = getattr(model.config, 
"mm_use_im_start_end", False) 132 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 133 | if mm_use_im_start_end: 134 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 135 | 136 | vision_tower = model.model.vision_tower 137 | if device is not None: 138 | vision_tower.to(device=device, dtype=torch.float16) 139 | else: 140 | vision_tower.to(device='cuda', dtype=torch.float16) 141 | 142 | vision_config = model.model.vision_config 143 | vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] 144 | vision_config.use_im_start_end = mm_use_im_start_end 145 | if mm_use_im_start_end: 146 | vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids( 147 | [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) 148 | image_token_len = model.model.config.num_query 149 | 150 | return model, image_processor, image_token_len, tokenizer 151 | 152 | 153 | def eval_model(args): 154 | model, image_processor, image_token_len, tokenizer = init_muffin(args.model_name) 155 | 156 | answers_file = os.path.expanduser(args.answers_file) 157 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 158 | 159 | qa_dataset = MultimodalQADataset(args.question_file, partial( 160 | wrap_question_with_default_conv, image_token_len=image_token_len)) 161 | 162 | collate_fn = partial(qa_colloator_fn, tokenizer=tokenizer, img_transform=image_processor) 163 | dataloader = torch_data.DataLoader(qa_dataset, batch_size=1, collate_fn=collate_fn) 164 | 165 | keywords = ['###'] 166 | ans_file = open(answers_file, "w") 167 | question_idx = 0 168 | 169 | with torch.inference_mode(): 170 | for batch in tqdm.tqdm(dataloader, f'Generating answers'): 171 | num_beams = 3 172 | input_size = batch['input_ids'].shape[-1] 173 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_size) 174 | # print(f'Input: {tokenizer.batch_decode(batch["input_ids"])}' 175 | # f'input_ids: {batch["input_ids"]}' 176 | # f'attn_mask: {batch["attention_mask"]}') 177 | 178 | output = model.generate( 179 | input_ids=batch['input_ids'].cuda(), 180 | images=batch['images'].half().cuda(), 181 | attention_mask=batch['attention_mask'].cuda(), 182 | temperature=0.7, 183 | max_new_tokens=1024, 184 | num_beams=num_beams, 185 | # do_sample=True, 186 | output_scores=True, 187 | return_dict_in_generate=True, 188 | stopping_criteria=[stopping_criteria], 189 | repetition_penalty=1.1) 190 | 191 | for question, output_ids in zip(batch['raw_questions'], output.sequences): 192 | response = tokenizer.decode(output_ids[input_size:], skip_special_tokens=True) 193 | if response.count('###'): 194 | response = response[: response.index('###')] 195 | response = response.strip() 196 | # print(f'{question}, {response}\n') 197 | 198 | ans_file.write(json.dumps({ 199 | "question_id": question_idx, 200 | "prompt": question, 201 | "text": response, 202 | "model_id": args.model_name 203 | }) + "\n") 204 | ans_file.flush() 205 | question_idx += 1 206 | ans_file.close() 207 | 208 | 209 | if __name__ == "__main__": 210 | parser = argparse.ArgumentParser() 211 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 212 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 213 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 214 | parser.add_argument("--mm-projector", type=str, default=None) 215 | parser.add_argument("--vision-tower", type=str, default=None) 216 | 
parser.add_argument("--conv-mode", type=str, default="default") 217 | args = parser.parse_args() 218 | 219 | eval_model(args) 220 | -------------------------------------------------------------------------------- /muffin/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .llava import LlavaLlamaForCausalLM 2 | from .muffin import Beit3LlavaLlamaForCausalLM 3 | -------------------------------------------------------------------------------- /muffin/model/beit3.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beit3 4 | # Copyright (c) 2023 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # --------------------------------------------------------' 7 | 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_ 12 | from timm.models.registry import register_model 13 | 14 | from torchscale.model.BEiT3 import BEiT3 15 | from torchscale.architecture.config import EncoderConfig 16 | 17 | def trunc_normal_(tensor, mean=0., std=1.): 18 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 19 | 20 | 21 | def _get_base_config( 22 | img_size=224, patch_size=16, drop_path_rate=0, 23 | checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs 24 | ): 25 | return EncoderConfig( 26 | img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True, 27 | layernorm_embedding=False, normalize_output=True, no_output_layer=True, 28 | drop_path_rate=drop_path_rate, encoder_embed_dim=768, encoder_attention_heads=12, 29 | encoder_ffn_embed_dim=int(768 * mlp_ratio), encoder_layers=12, 30 | checkpoint_activations=checkpoint_activations, 31 | ) 32 | 33 | 34 | def _get_large_config( 35 | img_size=224, patch_size=16, drop_path_rate=0, 36 | checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs 37 | ): 38 | return EncoderConfig( 39 | img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True, 40 | layernorm_embedding=False, normalize_output=True, no_output_layer=True, 41 | drop_path_rate=drop_path_rate, encoder_embed_dim=1024, encoder_attention_heads=16, 42 | encoder_ffn_embed_dim=int(1024 * mlp_ratio), encoder_layers=24, 43 | checkpoint_activations=checkpoint_activations, 44 | ) 45 | 46 | 47 | class BEiT3Wrapper(nn.Module): 48 | def __init__(self, args, **kwargs): 49 | super().__init__() 50 | self.args = args 51 | self.beit3 = BEiT3(args) 52 | 53 | # self.apply(self._init_weights) # no longer necessary since we only use the pre-trained ckpt 54 | # self.mim_head = nn.Linear(1024, 8192) 55 | self.num_img_patches = self.beit3.vision_embed.num_position_embeddings() 56 | self.hidden_size = args.encoder_embed_dim 57 | 58 | def fix_init_weight(self): 59 | def rescale(param, layer_id): 60 | param.div_(math.sqrt(2.0 * layer_id)) 61 | 62 | for layer_id, layer in enumerate(self.blocks): 63 | rescale(layer.attn.proj.weight.data, layer_id + 1) 64 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 65 | 66 | def get_num_layers(self): 67 | return self.beit3.encoder.num_layers 68 | 69 | @torch.jit.ignore 70 | def no_weight_decay(self): 71 | return {'pos_embed', 'cls_token', 'beit3.encoder.embed_positions.A.weight', 
'beit3.vision_embed.cls_token', 'logit_scale'} 72 | 73 | def _init_weights(self, m): 74 | if isinstance(m, nn.Linear): 75 | trunc_normal_(m.weight, std=.02) 76 | if isinstance(m, nn.Linear) and m.bias is not None: 77 | nn.init.constant_(m.bias, 0) 78 | elif isinstance(m, nn.LayerNorm): 79 | nn.init.constant_(m.bias, 0) 80 | nn.init.constant_(m.weight, 1.0) 81 | 82 | def forward(self, pixel_values, query_embed=None, encode_image=False, img_feat_layer=-1, attn_mask=None): 83 | assert (query_embed is not None) ^ encode_image 84 | B = pixel_values.size(0) 85 | dtype = self.beit3.vision_embed.proj.weight.dtype 86 | pixel_values = pixel_values.to(dtype) 87 | token_embeddings = self.beit3.vision_embed(pixel_values) 88 | multiway_split_position = -1 89 | if query_embed is not None: 90 | query_embed = torch.stack([query_embed] * B) 91 | multiway_split_position = token_embeddings.size(1) 92 | token_embeddings = torch.cat([token_embeddings, query_embed], dim=1) 93 | 94 | outputs = self.beit3.encoder( 95 | src_tokens=None, 96 | token_embeddings=token_embeddings, 97 | multiway_split_position=multiway_split_position, 98 | return_all_hiddens=encode_image, 99 | attn_mask=attn_mask, 100 | ) 101 | vision_hidden_states = outputs["encoder_out"] 102 | if query_embed is not None: 103 | vision_hidden_states = vision_hidden_states[:, self.num_img_patches:] 104 | if encode_image: 105 | vision_hidden_states = outputs['encoder_states'][img_feat_layer][:, 1:self.num_img_patches] 106 | return vision_hidden_states 107 | 108 | @register_model 109 | def beit3_large_patch16_224(pretrained=False, **kwargs): 110 | args = _get_large_config(img_size=224, **kwargs) 111 | model = BEiT3Wrapper(args, **kwargs) 112 | return model 113 | 114 | @register_model 115 | def beit3_large_patch16_256(pretrained=False, **kwargs): 116 | args = _get_large_config(img_size=256, **kwargs) 117 | model = BEiT3Wrapper(args, **kwargs) 118 | return model 119 | 120 | @register_model 121 | def beit3_large_patch16_336(pretrained=False, **kwargs): 122 | args = _get_large_config(img_size=336, **kwargs) 123 | model = BEiT3Wrapper(args, **kwargs) 124 | return model 125 | 126 | @register_model 127 | def beit3_large_patch16_448(pretrained=False, **kwargs): 128 | args = _get_large_config(img_size=448, **kwargs) 129 | model = BEiT3Wrapper(args, **kwargs) 130 | return model 131 | 132 | @register_model 133 | def beit3_large_patch16_672(pretrained=False, **kwargs): 134 | args = _get_large_config(img_size=672, **kwargs) 135 | model = BEiT3Wrapper(args, **kwargs) 136 | return model 137 | 138 | @register_model 139 | def beit3_large_itc_patch16_224(pretrained=False, **kwargs): 140 | args = _get_large_config(img_size=224, **kwargs) 141 | model = BEiT3Wrapper(args, **kwargs) 142 | return model -------------------------------------------------------------------------------- /muffin/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/Muffin/85d72b4b04035b6b1cf4168814739ae9e5e6e1dc/muffin/serve/__init__.py -------------------------------------------------------------------------------- /muffin/serve/cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.serve.cli --model ~/model_weights/llama-7b 4 | """ 5 | import argparse 6 | import time 7 | 8 | import torch 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | 11 | from muffin.conversation import conv_templates, SeparatorStyle 12 | 13 | 14 | 
@torch.inference_mode() 15 | def generate_stream(tokenizer, model, params, device, 16 | context_len=2048, stream_interval=2): 17 | """Adapted from fastchat/serve/model_worker.py::generate_stream""" 18 | 19 | prompt = params["prompt"] 20 | l_prompt = len(prompt) 21 | temperature = float(params.get("temperature", 1.0)) 22 | max_new_tokens = int(params.get("max_new_tokens", 256)) 23 | stop_str = params.get("stop", None) 24 | 25 | input_ids = tokenizer(prompt).input_ids 26 | output_ids = list(input_ids) 27 | 28 | max_src_len = context_len - max_new_tokens - 8 29 | input_ids = input_ids[-max_src_len:] 30 | 31 | for i in range(max_new_tokens): 32 | if i == 0: 33 | out = model( 34 | torch.as_tensor([input_ids], device=device), use_cache=True) 35 | logits = out.logits 36 | past_key_values = out.past_key_values 37 | else: 38 | attention_mask = torch.ones( 39 | 1, past_key_values[0][0].shape[-2] + 1, device=device) 40 | out = model(input_ids=torch.as_tensor([[token]], device=device), 41 | use_cache=True, 42 | attention_mask=attention_mask, 43 | past_key_values=past_key_values) 44 | logits = out.logits 45 | past_key_values = out.past_key_values 46 | 47 | last_token_logits = logits[0][-1] 48 | if temperature < 1e-4: 49 | token = int(torch.argmax(last_token_logits)) 50 | else: 51 | probs = torch.softmax(last_token_logits / temperature, dim=-1) 52 | token = int(torch.multinomial(probs, num_samples=1)) 53 | 54 | output_ids.append(token) 55 | 56 | if token == tokenizer.eos_token_id: 57 | stopped = True 58 | else: 59 | stopped = False 60 | 61 | if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: 62 | output = tokenizer.decode(output_ids, skip_special_tokens=True) 63 | pos = output.rfind(stop_str, l_prompt) 64 | if pos != -1: 65 | output = output[:pos] 66 | stopped = True 67 | yield output 68 | 69 | if stopped: 70 | break 71 | 72 | del past_key_values 73 | 74 | 75 | def main(args): 76 | model_name = args.model_name 77 | num_gpus = args.num_gpus 78 | 79 | # Model 80 | if args.device == "cuda": 81 | kwargs = {"torch_dtype": torch.float16} 82 | if num_gpus == "auto": 83 | kwargs["device_map"] = "auto" 84 | else: 85 | num_gpus = int(num_gpus) 86 | if num_gpus != 1: 87 | kwargs.update({ 88 | "device_map": "auto", 89 | "max_memory": {i: "13GiB" for i in range(num_gpus)}, 90 | }) 91 | elif args.device == "cpu": 92 | kwargs = {} 93 | else: 94 | raise ValueError(f"Invalid device: {args.device}") 95 | 96 | tokenizer = AutoTokenizer.from_pretrained(model_name) 97 | model = AutoModelForCausalLM.from_pretrained(model_name, 98 | low_cpu_mem_usage=True, **kwargs) 99 | 100 | if args.device == "cuda" and num_gpus == 1: 101 | model.cuda() 102 | 103 | # Chat 104 | conv = conv_templates[args.conv_template].copy() 105 | while True: 106 | try: 107 | inp = input(f"{conv.roles[0]}: ") 108 | except EOFError: 109 | inp = "" 110 | if not inp: 111 | print("exit...") 112 | break 113 | 114 | conv.append_message(conv.roles[0], inp) 115 | conv.append_message(conv.roles[1], None) 116 | prompt = conv.get_prompt() 117 | 118 | params = { 119 | "model": model_name, 120 | "prompt": prompt, 121 | "temperature": args.temperature, 122 | "max_new_tokens": args.max_new_tokens, 123 | "stop": conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2, 124 | } 125 | 126 | print(f"{conv.roles[1]}: ", end="", flush=True) 127 | pre = 0 128 | for outputs in generate_stream(tokenizer, model, params, args.device): 129 | outputs = outputs[len(prompt) + 1:].strip() 130 | outputs = outputs.split(" ") 131 | now = len(outputs) 132 | if now 
- 1 > pre: 133 | print(" ".join(outputs[pre:now-1]), end=" ", flush=True) 134 | pre = now - 1 135 | print(" ".join(outputs[pre:]), flush=True) 136 | 137 | conv.messages[-1][-1] = " ".join(outputs) 138 | 139 | if args.debug: 140 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 141 | 142 | 143 | if __name__ == "__main__": 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 146 | parser.add_argument("--num-gpus", type=str, default="1") 147 | parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda") 148 | parser.add_argument("--conv-template", type=str, default="v1") 149 | parser.add_argument("--temperature", type=float, default=0.7) 150 | parser.add_argument("--max-new-tokens", type=int, default=512) 151 | parser.add_argument("--debug", action="store_true") 152 | args = parser.parse_args() 153 | main(args) 154 | -------------------------------------------------------------------------------- /muffin/serve/controller.py: -------------------------------------------------------------------------------- 1 | """ 2 | A controller manages distributed workers. 3 | It sends worker addresses to clients. 4 | """ 5 | import argparse 6 | import asyncio 7 | import dataclasses 8 | from enum import Enum, auto 9 | import json 10 | import logging 11 | import time 12 | from typing import List, Union 13 | import threading 14 | 15 | from fastapi import FastAPI, Request 16 | from fastapi.responses import StreamingResponse 17 | import numpy as np 18 | import requests 19 | import uvicorn 20 | 21 | from muffin.constants import CONTROLLER_HEART_BEAT_EXPIRATION 22 | from muffin.utils import build_logger, server_error_msg 23 | 24 | 25 | logger = build_logger("controller", "controller.log") 26 | 27 | 28 | class DispatchMethod(Enum): 29 | LOTTERY = auto() 30 | SHORTEST_QUEUE = auto() 31 | 32 | @classmethod 33 | def from_str(cls, name): 34 | if name == "lottery": 35 | return cls.LOTTERY 36 | elif name == "shortest_queue": 37 | return cls.SHORTEST_QUEUE 38 | else: 39 | raise ValueError(f"Invalid dispatch method") 40 | 41 | 42 | @dataclasses.dataclass 43 | class WorkerInfo: 44 | model_names: List[str] 45 | speed: int 46 | queue_length: int 47 | check_heart_beat: bool 48 | last_heart_beat: str 49 | 50 | 51 | def heart_beat_controller(controller): 52 | while True: 53 | time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) 54 | controller.remove_stable_workers_by_expiration() 55 | 56 | 57 | class Controller: 58 | def __init__(self, dispatch_method: str): 59 | # Dict[str -> WorkerInfo] 60 | self.worker_info = {} 61 | self.dispatch_method = DispatchMethod.from_str(dispatch_method) 62 | 63 | self.heart_beat_thread = threading.Thread( 64 | target=heart_beat_controller, args=(self,)) 65 | self.heart_beat_thread.start() 66 | 67 | logger.info("Init controller") 68 | 69 | def register_worker(self, worker_name: str, check_heart_beat: bool, 70 | worker_status: dict): 71 | if worker_name not in self.worker_info: 72 | logger.info(f"Register a new worker: {worker_name}") 73 | else: 74 | logger.info(f"Register an existing worker: {worker_name}") 75 | 76 | if not worker_status: 77 | worker_status = self.get_worker_status(worker_name) 78 | if not worker_status: 79 | return False 80 | 81 | self.worker_info[worker_name] = WorkerInfo( 82 | worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], 83 | check_heart_beat, time.time()) 84 | 85 | logger.info(f"Register done: {worker_name}, {worker_status}") 86 | return True 87 | 
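# (Illustrative note, not in the original source.) A `worker_status` payload is
# expected to carry the fields read above, roughly of the form
#   {"model_names": ["muffin-13b"], "speed": 1, "queue_length": 0}
# where "muffin-13b" is only a hypothetical model name; the same dict shape is
# returned by a worker's `/worker_get_status` endpoint (see `get_worker_status`).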
88 | def get_worker_status(self, worker_name: str): 89 | try: 90 | r = requests.post(worker_name + "/worker_get_status", timeout=5) 91 | except requests.exceptions.RequestException as e: 92 | logger.error(f"Get status fails: {worker_name}, {e}") 93 | return None 94 | 95 | if r.status_code != 200: 96 | logger.error(f"Get status fails: {worker_name}, {r}") 97 | return None 98 | 99 | return r.json() 100 | 101 | def remove_worker(self, worker_name: str): 102 | del self.worker_info[worker_name] 103 | 104 | def refresh_all_workers(self): 105 | old_info = dict(self.worker_info) 106 | self.worker_info = {} 107 | 108 | for w_name, w_info in old_info.items(): 109 | if not self.register_worker(w_name, w_info.check_heart_beat, None): 110 | logger.info(f"Remove stale worker: {w_name}") 111 | 112 | def list_models(self): 113 | model_names = set() 114 | 115 | for w_name, w_info in self.worker_info.items(): 116 | model_names.update(w_info.model_names) 117 | 118 | return list(model_names) 119 | 120 | def get_worker_address(self, model_name: str): 121 | if self.dispatch_method == DispatchMethod.LOTTERY: 122 | worker_names = [] 123 | worker_speeds = [] 124 | for w_name, w_info in self.worker_info.items(): 125 | if model_name in w_info.model_names: 126 | worker_names.append(w_name) 127 | worker_speeds.append(w_info.speed) 128 | worker_speeds = np.array(worker_speeds, dtype=np.float32) 129 | norm = np.sum(worker_speeds) 130 | if norm < 1e-4: 131 | return "" 132 | worker_speeds = worker_speeds / norm 133 | if True: # Directly return address 134 | pt = np.random.choice(np.arange(len(worker_names)), 135 | p=worker_speeds) 136 | worker_name = worker_names[pt] 137 | return worker_name 138 | 139 | # Check status before returning 140 | while True: 141 | pt = np.random.choice(np.arange(len(worker_names)), 142 | p=worker_speeds) 143 | worker_name = worker_names[pt] 144 | 145 | if self.get_worker_status(worker_name): 146 | break 147 | else: 148 | self.remove_worker(worker_name) 149 | worker_speeds[pt] = 0 150 | norm = np.sum(worker_speeds) 151 | if norm < 1e-4: 152 | return "" 153 | worker_speeds = worker_speeds / norm 154 | continue 155 | return worker_name 156 | elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: 157 | worker_names = [] 158 | worker_qlen = [] 159 | for w_name, w_info in self.worker_info.items(): 160 | if model_name in w_info.model_names: 161 | worker_names.append(w_name) 162 | worker_qlen.append(w_info.queue_length / w_info.speed) 163 | if len(worker_names) == 0: 164 | return "" 165 | min_index = np.argmin(worker_qlen) 166 | w_name = worker_names[min_index] 167 | self.worker_info[w_name].queue_length += 1 168 | logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") 169 | return w_name 170 | else: 171 | raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") 172 | 173 | def receive_heart_beat(self, worker_name: str, queue_length: int): 174 | if worker_name not in self.worker_info: 175 | logger.info(f"Receive unknown heart beat. {worker_name}") 176 | return False 177 | 178 | self.worker_info[worker_name].queue_length = queue_length 179 | self.worker_info[worker_name].last_heart_beat = time.time() 180 | logger.info(f"Receive heart beat. 
{worker_name}") 181 | return True 182 | 183 | def remove_stable_workers_by_expiration(self): 184 | expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION 185 | to_delete = [] 186 | for worker_name, w_info in self.worker_info.items(): 187 | if w_info.check_heart_beat and w_info.last_heart_beat < expire: 188 | to_delete.append(worker_name) 189 | 190 | for worker_name in to_delete: 191 | self.remove_worker(worker_name) 192 | 193 | def worker_api_generate_stream(self, params): 194 | worker_addr = self.get_worker_address(params["model"]) 195 | if not worker_addr: 196 | logger.info(f"no worker: {params['model']}") 197 | ret = { 198 | "text": server_error_msg, 199 | "error_code": 2, 200 | } 201 | yield json.dumps(ret).encode() + b"\0" 202 | 203 | try: 204 | response = requests.post(worker_addr + "/worker_generate_stream", 205 | json=params, stream=True, timeout=5) 206 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): 207 | if chunk: 208 | yield chunk + b"\0" 209 | except requests.exceptions.RequestException as e: 210 | logger.info(f"worker timeout: {worker_addr}") 211 | ret = { 212 | "text": server_error_msg, 213 | "error_code": 3, 214 | } 215 | yield json.dumps(ret).encode() + b"\0" 216 | 217 | 218 | # Let the controller act as a worker to achieve hierarchical 219 | # management. This can be used to connect isolated sub networks. 220 | def worker_api_get_status(self): 221 | model_names = set() 222 | speed = 0 223 | queue_length = 0 224 | 225 | for w_name in self.worker_info: 226 | worker_status = self.get_worker_status(w_name) 227 | if worker_status is not None: 228 | model_names.update(worker_status["model_names"]) 229 | speed += worker_status["speed"] 230 | queue_length += worker_status["queue_length"] 231 | 232 | return { 233 | "model_names": list(model_names), 234 | "speed": speed, 235 | "queue_length": queue_length, 236 | } 237 | 238 | 239 | app = FastAPI() 240 | 241 | 242 | @app.post("/register_worker") 243 | async def register_worker(request: Request): 244 | data = await request.json() 245 | controller.register_worker( 246 | data["worker_name"], data["check_heart_beat"], 247 | data.get("worker_status", None)) 248 | 249 | 250 | @app.post("/refresh_all_workers") 251 | async def refresh_all_workers(): 252 | models = controller.refresh_all_workers() 253 | 254 | 255 | @app.post("/list_models") 256 | async def list_models(): 257 | models = controller.list_models() 258 | return {"models": models} 259 | 260 | 261 | @app.post("/get_worker_address") 262 | async def get_worker_address(request: Request): 263 | data = await request.json() 264 | addr = controller.get_worker_address(data["model"]) 265 | return {"address": addr} 266 | 267 | 268 | @app.post("/receive_heart_beat") 269 | async def receive_heart_beat(request: Request): 270 | data = await request.json() 271 | exist = controller.receive_heart_beat( 272 | data["worker_name"], data["queue_length"]) 273 | return {"exist": exist} 274 | 275 | 276 | @app.post("/worker_generate_stream") 277 | async def worker_api_generate_stream(request: Request): 278 | params = await request.json() 279 | generator = controller.worker_api_generate_stream(params) 280 | return StreamingResponse(generator) 281 | 282 | 283 | @app.post("/worker_get_status") 284 | async def worker_api_get_status(request: Request): 285 | return controller.worker_api_get_status() 286 | 287 | 288 | if __name__ == "__main__": 289 | parser = argparse.ArgumentParser() 290 | parser.add_argument("--host", type=str, default="localhost") 291 | parser.add_argument("--port", 
type=int, default=21001) 292 | parser.add_argument("--dispatch-method", type=str, choices=[ 293 | "lottery", "shortest_queue"], default="shortest_queue") 294 | args = parser.parse_args() 295 | logger.info(f"args: {args}") 296 | 297 | controller = Controller(args.dispatch_method) 298 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 299 | -------------------------------------------------------------------------------- /muffin/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/Muffin/85d72b4b04035b6b1cf4168814739ae9e5e6e1dc/muffin/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /muffin/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/Muffin/85d72b4b04035b6b1cf4168814739ae9e5e6e1dc/muffin/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /muffin/serve/gateway/README.md: -------------------------------------------------------------------------------- 1 | # fastchat Nginx Gateway 2 | 3 | ## Purpose of the Gateway 4 | 5 | The Nginx gateway serves the following purposes: 6 | 7 | 1. Protects Gradio servers by acting as a firewall. 8 | 2. Facilitates dynamic mounting and unmounting of Gradio servers. 9 | 3. Provides load balancing for Gradio servers. 10 | 4. Offers additional security features, such as total connection limit. 11 | 5. Reduces attack surface by requiring only a single public port to be exposed for serving. 12 | 13 | ## Deployment and Updating of the Gateway 14 | 15 | ### Installing Nginx 16 | 17 | On Debian-based distributions (e.g., Ubuntu): 18 | 19 | ```bash 20 | sudo apt update 21 | sudo apt install nginx 22 | ``` 23 | On Red Hat-based distributions (e.g., CentOS, Fedora): 24 | 25 | ```bash 26 | sudo yum install epel-release 27 | sudo yum install nginx 28 | ``` 29 | 30 | ### Deployment 31 | 32 | Copy `nginx.conf` to `/etc/nginx/nginx.conf` (need sudo permission). 33 | 34 | Replace the port number 7860 in `server localhost:7860` with the port where you deploy the Gradio web server. 35 | 36 | Modify `upstream websocket` to configure Gradio servers behind the gateway. 37 | 38 | Lastly, update Nginx. 39 | 40 | 41 | ### HTTPS Deployment with a Public Domain URL 42 | 43 | Make sure you obtain the HTTPS certificate and the private key used to generate the certificate. 44 | 45 | Fill the addresses to your certificate and private key in the `[PATH_TO_SSL_CERT]` and `[PATH_TO_PRIVATE_KEY]` fields. 46 | 47 | If you have your own domain url to serve the chatbot, replace the chat.lmsys.org url with your own domain url. 48 | 49 | ### Updating 50 | 51 | Every time when `/etc/nginx/nginx.conf` is modified, you need to update the Nginx service: 52 | 53 | ```bash 54 | sudo nginx -t # check `/etc/nginx/nginx.conf` 55 | sudo systemctl reload nginx # restart Nginx service to load the new config 56 | sudo systemctl status nginx # check the status of the Nginx service. It should be active (running). 
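sudo nginx -s reload  # alternative: signal the running nginx master process directly, e.g. on hosts without systemd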
57 | ``` 58 | -------------------------------------------------------------------------------- /muffin/serve/gateway/nginx.conf: -------------------------------------------------------------------------------- 1 | user www-data; 2 | worker_processes auto; 3 | pid /run/nginx.pid; 4 | include /etc/nginx/modules-enabled/*.conf; 5 | 6 | events { 7 | worker_connections 1024; # maximum number of connections that a worker process can handle concurrently 8 | # multi_accept on; # enabling multi_accept can help improve performance under high load, but may increase the number of simultaneous connections that a worker process can handle 9 | 10 | } 11 | 12 | http { 13 | ## 14 | # Basic Settings 15 | ## 16 | 17 | sendfile on; # enable sendfile for performance optimization 18 | tcp_nopush on; # enable TCP no-pushing 19 | tcp_nodelay on; # enable TCP no-delay 20 | keepalive_timeout 65; # sets the timeout for keep-alive connections 21 | types_hash_max_size 2048; # maximum size of the types hash table 22 | # server_tokens off; # disable server token (i.e., server signature) in response headers to improve security 23 | 24 | # server_names_hash_bucket_size 64; 25 | # server_name_in_redirect off; 26 | 27 | include /etc/nginx/mime.types; # include MIME types file 28 | default_type application/octet-stream; # default MIME type for unknown file types 29 | 30 | ## 31 | # SSL Settings 32 | ## 33 | 34 | ssl_protocols TLSv1.2; # specify SSL/TLS protocols to use 35 | ssl_prefer_server_ciphers on; # prefer server ciphers over client ciphers 36 | 37 | ## 38 | # Logging Settings 39 | ## 40 | 41 | access_log /var/log/nginx/access.log; # path to access log file 42 | error_log /var/log/nginx/error.log; # path to error log file 43 | 44 | ## 45 | # Gzip Settings 46 | ## 47 | gzip on; # enable Gzip compression 48 | 49 | ## 50 | # Virtual Host Configs 51 | ## 52 | 53 | include /etc/nginx/conf.d/*.conf; # include all configuration files in conf.d directory 54 | include /etc/nginx/sites-enabled/*; # include all enabled sites configuration files 55 | 56 | # WebSocket Proxy: https://www.nginx.com/blog/websocket-nginx/ 57 | map $http_upgrade $connection_upgrade { 58 | default upgrade; 59 | '' close; 60 | } 61 | 62 | upstream websocket { 63 | ip_hash; # load balancing by IP to guarantee session persistence 64 | server localhost:7860; # The port should be the gradio web server port 65 | # server localhost:7861; # extra gradio server if more than one 66 | } 67 | 68 | limit_conn_status 429; 69 | limit_conn_zone $binary_remote_addr zone=perip:10m; # limit number of connections per IP 70 | limit_conn_zone $server_name zone=perserver:10m; # limit number of connections per server 71 | 72 | server { 73 | listen 443 ssl; # the listening port of our server 74 | ssl_certificate [PATH_TO_SSL_CERT]; 75 | ssl_certificate_key [PATH_TO_PRIVATE_KEY]; 76 | server_name chat.lmsys.org; # replace the url with your own domain url 77 | limit_conn perserver 1024; # connections per server 78 | location / { 79 | proxy_pass http://websocket; # proxy all requests to the defined upstream server 80 | limit_conn perip 5; # connections per IP 81 | proxy_set_header Host $host; # set the Host header for the upstream server 82 | proxy_set_header X-Real-IP $remote_addr; # set the client IP address as the real IP for the upstream server 83 | proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; # set the client IP addresses in the X-Forwarded-For header 84 | proxy_http_version 1.1; # use HTTP version 1.1 for upstream communication 85 | proxy_set_header Upgrade 
$http_upgrade; 86 | proxy_set_header Connection "Upgrade"; # set the Connection header to Upgrade to enable WebSocket communication 87 | } 88 | } 89 | 90 | # the following block routes all HTTP traffic to HTTPS via nginx 91 | server { 92 | listen 80; 93 | server_name chat.lmsys.org; 94 | return 301 https://chat.lmsys.org$request_uri; 95 | } 96 | 97 | } 98 | -------------------------------------------------------------------------------- /muffin/serve/gradio_css.py: -------------------------------------------------------------------------------- 1 | code_highlight_css = ( 2 | """ 3 | #chatbot .hll { background-color: #ffffcc } 4 | #chatbot .c { color: #408080; font-style: italic } 5 | #chatbot .err { border: 1px solid #FF0000 } 6 | #chatbot .k { color: #008000; font-weight: bold } 7 | #chatbot .o { color: #666666 } 8 | #chatbot .ch { color: #408080; font-style: italic } 9 | #chatbot .cm { color: #408080; font-style: italic } 10 | #chatbot .cp { color: #BC7A00 } 11 | #chatbot .cpf { color: #408080; font-style: italic } 12 | #chatbot .c1 { color: #408080; font-style: italic } 13 | #chatbot .cs { color: #408080; font-style: italic } 14 | #chatbot .gd { color: #A00000 } 15 | #chatbot .ge { font-style: italic } 16 | #chatbot .gr { color: #FF0000 } 17 | #chatbot .gh { color: #000080; font-weight: bold } 18 | #chatbot .gi { color: #00A000 } 19 | #chatbot .go { color: #888888 } 20 | #chatbot .gp { color: #000080; font-weight: bold } 21 | #chatbot .gs { font-weight: bold } 22 | #chatbot .gu { color: #800080; font-weight: bold } 23 | #chatbot .gt { color: #0044DD } 24 | #chatbot .kc { color: #008000; font-weight: bold } 25 | #chatbot .kd { color: #008000; font-weight: bold } 26 | #chatbot .kn { color: #008000; font-weight: bold } 27 | #chatbot .kp { color: #008000 } 28 | #chatbot .kr { color: #008000; font-weight: bold } 29 | #chatbot .kt { color: #B00040 } 30 | #chatbot .m { color: #666666 } 31 | #chatbot .s { color: #BA2121 } 32 | #chatbot .na { color: #7D9029 } 33 | #chatbot .nb { color: #008000 } 34 | #chatbot .nc { color: #0000FF; font-weight: bold } 35 | #chatbot .no { color: #880000 } 36 | #chatbot .nd { color: #AA22FF } 37 | #chatbot .ni { color: #999999; font-weight: bold } 38 | #chatbot .ne { color: #D2413A; font-weight: bold } 39 | #chatbot .nf { color: #0000FF } 40 | #chatbot .nl { color: #A0A000 } 41 | #chatbot .nn { color: #0000FF; font-weight: bold } 42 | #chatbot .nt { color: #008000; font-weight: bold } 43 | #chatbot .nv { color: #19177C } 44 | #chatbot .ow { color: #AA22FF; font-weight: bold } 45 | #chatbot .w { color: #bbbbbb } 46 | #chatbot .mb { color: #666666 } 47 | #chatbot .mf { color: #666666 } 48 | #chatbot .mh { color: #666666 } 49 | #chatbot .mi { color: #666666 } 50 | #chatbot .mo { color: #666666 } 51 | #chatbot .sa { color: #BA2121 } 52 | #chatbot .sb { color: #BA2121 } 53 | #chatbot .sc { color: #BA2121 } 54 | #chatbot .dl { color: #BA2121 } 55 | #chatbot .sd { color: #BA2121; font-style: italic } 56 | #chatbot .s2 { color: #BA2121 } 57 | #chatbot .se { color: #BB6622; font-weight: bold } 58 | #chatbot .sh { color: #BA2121 } 59 | #chatbot .si { color: #BB6688; font-weight: bold } 60 | #chatbot .sx { color: #008000 } 61 | #chatbot .sr { color: #BB6688 } 62 | #chatbot .s1 { color: #BA2121 } 63 | #chatbot .ss { color: #19177C } 64 | #chatbot .bp { color: #008000 } 65 | #chatbot .fm { color: #0000FF } 66 | #chatbot .vc { color: #19177C } 67 | #chatbot .vg { color: #19177C } 68 | #chatbot .vi { color: #19177C } 69 | #chatbot .vm { color: #19177C } 70 | #chatbot .il { color: 
#666666 } 71 | """) 72 | #.highlight { background: #f8f8f8; } 73 | 74 | -------------------------------------------------------------------------------- /muffin/serve/gradio_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adopted from https://github.com/gradio-app/gradio/blob/main/gradio/components.py 3 | Fix a markdown render problem. 4 | """ 5 | from __future__ import annotations 6 | 7 | from gradio.components import * 8 | from markdown2 import Markdown 9 | 10 | 11 | class _Keywords(Enum): 12 | NO_VALUE = "NO_VALUE" # Used as a sentinel to determine if nothing is provided as a argument for `value` in `Component.update()` 13 | FINISHED_ITERATING = "FINISHED_ITERATING" # Used to skip processing of a component's value (needed for generators + state) 14 | 15 | 16 | @document("style") 17 | class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): 18 | """ 19 | Displays a chatbot output showing both user submitted messages and responses. Supports a subset of Markdown including bold, italics, code, and images. 20 | Preprocessing: this component does *not* accept input. 21 | Postprocessing: expects function to return a {List[Tuple[str | None | Tuple, str | None | Tuple]]}, a list of tuples with user message and response messages. Messages should be strings, tuples, or Nones. If the message is a string, it can include Markdown. If it is a tuple, it should consist of (string filepath to image/video/audio, [optional string alt text]). Messages that are `None` are not displayed. 22 | 23 | Demos: chatbot_simple, chatbot_multimodal 24 | """ 25 | 26 | def __init__( 27 | self, 28 | value: List[Tuple[str | None, str | None]] | Callable | None = None, 29 | color_map: Dict[str, str] | None = None, # Parameter moved to Chatbot.style() 30 | *, 31 | label: str | None = None, 32 | every: float | None = None, 33 | show_label: bool = True, 34 | visible: bool = True, 35 | elem_id: str | None = None, 36 | elem_classes: List[str] | str | None = None, 37 | **kwargs, 38 | ): 39 | """ 40 | Parameters: 41 | value: Default value to show in chatbot. If callable, the function will be called whenever the app loads to set the initial value of the component. 42 | label: component name in interface. 43 | every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. 44 | show_label: if True, will display label. 45 | visible: If False, component will be hidden. 46 | elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. 47 | elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles. 48 | """ 49 | if color_map is not None: 50 | warnings.warn( 51 | "The 'color_map' parameter has been deprecated.", 52 | ) 53 | #self.md = utils.get_markdown_parser() 54 | self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"]) 55 | self.select: EventListenerMethod 56 | """ 57 | Event listener for when the user selects message from Chatbot. 58 | Uses event data gradio.SelectData to carry `value` referring to text of selected message, and `index` tuple to refer to [message, participant] index. 59 | See EventData documentation on how to use this event data. 
60 | """ 61 | 62 | IOComponent.__init__( 63 | self, 64 | label=label, 65 | every=every, 66 | show_label=show_label, 67 | visible=visible, 68 | elem_id=elem_id, 69 | elem_classes=elem_classes, 70 | value=value, 71 | **kwargs, 72 | ) 73 | 74 | def get_config(self): 75 | return { 76 | "value": self.value, 77 | "selectable": self.selectable, 78 | **IOComponent.get_config(self), 79 | } 80 | 81 | @staticmethod 82 | def update( 83 | value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE, 84 | label: str | None = None, 85 | show_label: bool | None = None, 86 | visible: bool | None = None, 87 | ): 88 | updated_config = { 89 | "label": label, 90 | "show_label": show_label, 91 | "visible": visible, 92 | "value": value, 93 | "__type__": "update", 94 | } 95 | return updated_config 96 | 97 | def _process_chat_messages( 98 | self, chat_message: str | Tuple | List | Dict | None 99 | ) -> str | Dict | None: 100 | if chat_message is None: 101 | return None 102 | elif isinstance(chat_message, (tuple, list)): 103 | mime_type = processing_utils.get_mimetype(chat_message[0]) 104 | return { 105 | "name": chat_message[0], 106 | "mime_type": mime_type, 107 | "alt_text": chat_message[1] if len(chat_message) > 1 else None, 108 | "data": None, # These last two fields are filled in by the frontend 109 | "is_file": True, 110 | } 111 | elif isinstance( 112 | chat_message, dict 113 | ): # This happens for previously processed messages 114 | return chat_message 115 | elif isinstance(chat_message, str): 116 | #return self.md.render(chat_message) 117 | return str(self.md.convert(chat_message)) 118 | else: 119 | raise ValueError(f"Invalid message for Chatbot component: {chat_message}") 120 | 121 | def postprocess( 122 | self, 123 | y: List[ 124 | Tuple[str | Tuple | List | Dict | None, str | Tuple | List | Dict | None] 125 | ], 126 | ) -> List[Tuple[str | Dict | None, str | Dict | None]]: 127 | """ 128 | Parameters: 129 | y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed. 130 | Returns: 131 | List of tuples representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information. 132 | """ 133 | if y is None: 134 | return [] 135 | processed_messages = [] 136 | for message_pair in y: 137 | assert isinstance( 138 | message_pair, (tuple, list) 139 | ), f"Expected a list of lists or list of tuples. Received: {message_pair}" 140 | assert ( 141 | len(message_pair) == 2 142 | ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}" 143 | processed_messages.append( 144 | ( 145 | #self._process_chat_messages(message_pair[0]), 146 | '
<pre style="font-family: var(--font)">' +
147 |                     message_pair[0] + "</pre>
", 148 | self._process_chat_messages(message_pair[1]), 149 | ) 150 | ) 151 | return processed_messages 152 | 153 | def style(self, height: int | None = None, **kwargs): 154 | """ 155 | This method can be used to change the appearance of the Chatbot component. 156 | """ 157 | if height is not None: 158 | self._style["height"] = height 159 | if kwargs.get("color_map") is not None: 160 | warnings.warn("The 'color_map' parameter has been deprecated.") 161 | 162 | Component.style( 163 | self, 164 | **kwargs, 165 | ) 166 | return self 167 | 168 | 169 | -------------------------------------------------------------------------------- /muffin/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /muffin/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from muffin.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | 
parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /muffin/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | from typing import List, Optional, Tuple 3 | 4 | import torch 5 | from torch import nn 6 | 7 | import transformers 8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 9 | 10 | from einops import rearrange 11 | 12 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | def forward( 16 | self, 17 | hidden_states: torch.Tensor, 18 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | output_attentions: bool = False, 21 | use_cache: bool = False, 22 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], 23 | Optional[Tuple[torch.Tensor]]]: 24 | """Input shape: Batch x Time x Channel 25 | 26 | attention_mask: [bsz, q_len] 27 | """ 28 | bsz, q_len, _ = hidden_states.size() 29 | 30 | query_states = self.q_proj(hidden_states).view( 31 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 32 | key_states = self.k_proj(hidden_states).view( 33 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 34 | value_states = self.v_proj(hidden_states).view( 35 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 36 | # [bsz, q_len, nh, hd] 37 | # [bsz, nh, q_len, hd] 38 | 39 | kv_seq_len = key_states.shape[-2] 40 | offset = 0 41 | if past_key_value is not None: 42 | offset = past_key_value[0].shape[-2] 43 | kv_seq_len += offset 44 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 45 | query_states, key_states = apply_rotary_pos_emb(query_states, 46 | key_states, 47 | cos, 48 | sin, 49 | offset=offset) 50 | # [bsz, nh, t, hd] 51 | assert not output_attentions, "output_attentions is not supported" 52 | assert not use_cache, "use_cache is not supported" 53 | assert past_key_value is None, "past_key_value is not supported" 54 | 55 | # Flash attention codes from 56 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 57 | 58 | # transform the data into the format required by flash attention 59 | qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd] 60 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 61 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 62 | # the attention_mask should be the same as the key_padding_mask 63 | key_padding_mask = attention_mask 64 | 65 | 66 | if key_padding_mask is None: 67 | qkv = rearrange(qkv, 'b s ... -> (b s) ...') 68 | max_s = q_len 69 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, 70 | device=qkv.device) 71 | output = flash_attn_unpadded_qkvpacked_func( 72 | qkv, cu_q_lens, max_s, 0.0, 73 | softmax_scale=None, causal=True 74 | ) 75 | output = rearrange(output, '(b s) ... 
-> b s ...', b=bsz) 76 |     else: 77 |         nheads = qkv.shape[-2] 78 |         x = rearrange(qkv, 'b s three h d -> b s (three h d)') 79 |         x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 80 |         x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) 81 |         output_unpad = flash_attn_unpadded_qkvpacked_func( 82 |             x_unpad, cu_q_lens, max_s, 0.0, 83 |             softmax_scale=None, causal=True 84 |         ) 85 |         output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), 86 |                                      indices, bsz, q_len), 87 |                            'b s (h d) -> b s h d', h=nheads) 88 |     return self.o_proj(rearrange(output, 89 |                                  'b s h d -> b s (h d)')), None, None 90 | 91 | 92 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 93 | # requires the attention mask to be the same as the key_padding_mask 94 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, 95 |                                     inputs_embeds, past_key_values_length): 96 |     # [bsz, seq_len] 97 |     return attention_mask 98 | 99 | 100 | def replace_llama_attn_with_flash_attn(): 101 |     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 102 |     transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 103 |
-------------------------------------------------------------------------------- /muffin/train/train_mem_muffin.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 6 | from muffin.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from muffin.train.train_muffin import train 11 | 12 | if __name__ == "__main__": 13 |     train() 14 |
-------------------------------------------------------------------------------- /muffin/train/train_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import copy 4 | import time 5 | import transformers 6 | 7 | import torch 8 | 9 | from typing import Dict, Optional, Sequence 10 | from muffin import conversation as conversation_lib 11 | 12 | IGNORE_INDEX = -100 13 | DEFAULT_IMAGE_TOKEN = "<image>" 14 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 15 | DEFAULT_IM_START_TOKEN = "<im_start>" 16 | DEFAULT_IM_END_TOKEN = "<im_end>" 17 | 18 | 19 | def _tokenize_fn(strings: Sequence[str], 20 |                  tokenizer: transformers.PreTrainedTokenizer) -> Dict: 21 |     """Tokenize a list of strings.""" 22 |     tokenized_list = [ 23 |         tokenizer( 24 |             text, 25 |             return_tensors="pt", 26 |             padding="longest", 27 |             max_length=tokenizer.model_max_length, 28 |             truncation=True, 29 |         ) for text in strings 30 |     ] 31 |     input_ids = labels = [ 32 |         tokenized.input_ids[0] for tokenized in tokenized_list 33 |     ] 34 |     input_ids_lens = labels_lens = [ 35 |         tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() 36 |         for tokenized in tokenized_list 37 |     ] 38 |     return dict( 39 |         input_ids=input_ids, 40 |         labels=labels, 41 |         input_ids_lens=input_ids_lens, 42 |         labels_lens=labels_lens, 43 |     ) 44 | 45 | def _mask_targets(target, tokenized_lens, speakers): 46 |     # cur_idx = 0 47 |     cur_idx = tokenized_lens[0] 48 |     tokenized_lens = tokenized_lens[1:] 49 |     target[:cur_idx] = IGNORE_INDEX 50 |     for tokenized_len, speaker in
zip(tokenized_lens, speakers): 51 | if speaker == "human": 52 | target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX 53 | cur_idx += tokenized_len 54 | 55 | 56 | def _add_speaker_and_signal(header, source, get_conversation=True): 57 | """Add speaker and start/end signal on each round.""" 58 | BEGIN_SIGNAL = "### " 59 | END_SIGNAL = "\n" 60 | conversation = header 61 | for sentence in source: 62 | from_str = sentence["from"] 63 | if from_str.lower() == "human": 64 | from_str = conversation_lib.default_conversation.roles[0] 65 | elif from_str.lower() == "gpt": 66 | from_str = conversation_lib.default_conversation.roles[1] 67 | else: 68 | from_str = 'unknown' 69 | sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + 70 | sentence["value"] + END_SIGNAL) 71 | if get_conversation: 72 | conversation += sentence["value"] 73 | conversation += BEGIN_SIGNAL 74 | return conversation 75 | 76 | 77 | def SFT_collator_fn(instances, pad_token_id): 78 | input_ids, labels = tuple([instance[key] for instance in instances] 79 | for key in ("input_ids", "labels")) 80 | input_ids = torch.nn.utils.rnn.pad_sequence( 81 | input_ids, 82 | batch_first=True, 83 | padding_value=pad_token_id) 84 | labels = torch.nn.utils.rnn.pad_sequence(labels, 85 | batch_first=True, 86 | padding_value=IGNORE_INDEX) 87 | batch = dict( 88 | input_ids=input_ids, 89 | labels=labels, 90 | attention_mask=input_ids.ne(pad_token_id), 91 | ) 92 | 93 | if 'image' in instances[0]: 94 | images = [instance['image'] for instance in instances] 95 | if all(x is not None and x.shape == images[0].shape for x in images): 96 | batch['images'] = torch.stack(images) 97 | else: 98 | batch['images'] = images 99 | return batch 100 | 101 | 102 | def encode_multimodal_preference_sample(source, tokenizer, multimodal_cfg): 103 | if isinstance(source['chosen'], list): 104 | win_conv = source['chosen'] 105 | rej_conv = source['rejected'] 106 | elif isinstance(source['chosen'], dict): 107 | win_conv = copy.deepcopy([source['question'], source["chosen"]]) 108 | rej_conv = copy.deepcopy([source['question'], source["rejected"]]) 109 | 110 | if 'image' in source: 111 | image = source['image'] 112 | image = multimodal_cfg['image_processor'](image) 113 | win_conv = expand_image_token(win_conv, multimodal_cfg) 114 | rej_conv = expand_image_token(rej_conv, multimodal_cfg) 115 | 116 | rej_data_dict = preprocess([rej_conv], tokenizer) 117 | rej_data_dict = dict(input_ids=rej_data_dict["input_ids"][0], 118 | labels=rej_data_dict["labels"][0]) 119 | 120 | win_data_dict = preprocess([win_conv], tokenizer) 121 | win_data_dict = dict(input_ids=win_data_dict["input_ids"][0], 122 | labels=win_data_dict["labels"][0]) 123 | 124 | # image exist in the data 125 | if 'image' in source: 126 | rej_data_dict['image'] = win_data_dict['image'] = image 127 | elif multimodal_cfg['is_multimodal']: 128 | # image does not exist in the data, but the model is multimodal 129 | crop_size = multimodal_cfg['image_processor'].crop_size 130 | rej_data_dict['image'] = win_data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) 131 | 132 | if 'ref_win_logp' in source: 133 | rej_data_dict['ref_rej_logp'] = source['ref_rej_logp'] 134 | win_data_dict['ref_win_logp'] = source['ref_win_logp'] 135 | rej_data_dict['ref_rej_avg_logp'] = source['ref_rej_avg_logp'] 136 | win_data_dict['ref_win_avg_logp'] = source['ref_win_avg_logp'] 137 | rej_data_dict['ref_rej_per_token_logp'] = source['ref_rej_per_token_logp'] 138 | win_data_dict['ref_win_per_token_logp'] = source['ref_win_per_token_logp'] 
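    # Note: the ref_*_logp / ref_*_per_token_logp fields are reference-model log-probabilities,
    # presumably precomputed offline (e.g. with muffin/eval/muffin_inference_logp.py) and later
    # consumed by the DPO trainer; both returned dicts share the same image tensor.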
139 | return rej_data_dict, win_data_dict 140 | 141 | def expand_image_token(source, multimodal_cfg) -> Dict: 142 | is_multimodal = multimodal_cfg['is_multimodal'] 143 | image_token_len = multimodal_cfg['image_token_len'] 144 | if not is_multimodal: 145 | return source 146 | 147 | for sentence in source: 148 | replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len 149 | if multimodal_cfg['use_im_start_end']: 150 | replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN 151 | sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) 152 | 153 | return source 154 | 155 | def encode_multimodal_sample(source, tokenizer, multimodal_cfg): 156 | conversation = copy.deepcopy(source["conversations"]) 157 | if 'image' in source: 158 | image = source['image'] 159 | image = multimodal_cfg['image_processor'](image) 160 | conversation = expand_image_token(conversation, multimodal_cfg) 161 | 162 | data_dict = preprocess([conversation], tokenizer) 163 | data_dict = dict(input_ids=data_dict["input_ids"][0], 164 | labels=data_dict["labels"][0]) 165 | 166 | # image exist in the data 167 | if 'image' in source: 168 | data_dict['image'] = image 169 | elif multimodal_cfg['is_multimodal']: 170 | # image does not exist in the data, but the model is multimodal 171 | crop_size = multimodal_cfg['image_processor'].crop_size 172 | data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) 173 | return data_dict 174 | 175 | 176 | def preprocess_v1( 177 | sources, 178 | tokenizer: transformers.PreTrainedTokenizer, 179 | ) -> Dict: 180 | conv = conversation_lib.default_conversation.copy() 181 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 182 | 183 | # Apply prompt templates 184 | conversations = [] 185 | for i, source in enumerate(sources): 186 | if roles[source[0]["from"]] != conv.roles[0]: 187 | # Skip the first one if it is not from human 188 | source = source[1:] 189 | 190 | conv.messages = [] 191 | for j, sentence in enumerate(source): 192 | role = roles[sentence["from"]] 193 | assert role == conv.roles[j % 2], f"{i}" 194 | conv.append_message(role, sentence["value"]) 195 | conversations.append(conv.get_prompt()) 196 | 197 | # Tokenize conversations 198 | input_ids = tokenizer( 199 | conversations, 200 | return_tensors="pt", 201 | padding="longest", 202 | max_length=tokenizer.model_max_length, 203 | truncation=True, 204 | ).input_ids 205 | targets = input_ids.clone() 206 | 207 | assert conv.sep_style == conversation_lib.SeparatorStyle.TWO 208 | 209 | # Mask targets 210 | sep = conv.sep + conv.roles[1] + ": " 211 | for conversation, target in zip(conversations, targets): 212 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 213 | 214 | rounds = conversation.split(conv.sep2) 215 | cur_len = 1 216 | target[:cur_len] = IGNORE_INDEX 217 | for i, rou in enumerate(rounds): 218 | if rou == "": 219 | break 220 | 221 | parts = rou.split(sep) 222 | if len(parts) != 2: 223 | break 224 | parts[0] += sep 225 | round_len = len(tokenizer(rou).input_ids) 226 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 227 | 228 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX 229 | 230 | cur_len += round_len 231 | target[cur_len:] = IGNORE_INDEX 232 | 233 | if cur_len < tokenizer.model_max_length: 234 | if cur_len != total_len: 235 | target[:] = IGNORE_INDEX 236 | print( 237 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
238 | f" (ignored)" 239 | ) 240 | 241 | return dict( 242 | input_ids=input_ids, 243 | labels=targets, 244 | ) 245 | 246 | 247 | def preprocess( 248 | sources: Sequence[str], 249 | tokenizer: transformers.PreTrainedTokenizer, 250 | ) -> Dict: 251 | """ 252 | Given a list of sources, each is a conversation list. This transform: 253 | 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; 254 | 2. Concatenate conversations together; 255 | 3. Tokenize the concatenated conversation; 256 | 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 257 | """ 258 | if conversation_lib.default_conversation.version == "v1": 259 | return preprocess_v1(sources, tokenizer) 260 | # add end signal and concatenate together 261 | conversations = [] 262 | for source in sources: 263 | header = f"{conversation_lib.default_conversation.system}\n\n" 264 | conversation = _add_speaker_and_signal(header, source) 265 | conversations.append(conversation) 266 | # tokenize conversations 267 | conversations_tokenized = _tokenize_fn(conversations, tokenizer) 268 | input_ids = conversations_tokenized["input_ids"] 269 | targets = copy.deepcopy(input_ids) 270 | for target, source in zip(targets, sources): 271 | tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], 272 | tokenizer)["input_ids_lens"] 273 | speakers = [sentence["from"] for sentence in source] 274 | _mask_targets(target, tokenized_lens, speakers) 275 | 276 | return dict(input_ids=input_ids, labels=targets) 277 | -------------------------------------------------------------------------------- /muffin/train/trainers.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS 3 | from transformers.trainer_pt_utils import get_parameter_names 4 | from transformers.utils.import_utils import is_sagemaker_mp_enabled 5 | from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES 6 | import os 7 | import torch 8 | import wandb 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | import torch.nn.functional as F 12 | 13 | from transformers import Trainer 14 | from typing import Any, Dict, Optional, Tuple, Union 15 | from torch import Tensor 16 | from torch.nn import Module 17 | from utils.utils import is_main_process 18 | from muffin.eval.muffin_inference_logp import get_batch_logps 19 | 20 | 21 | def unwrap_model(model: nn.Module) -> nn.Module: 22 | """ 23 | Recursively unwraps a model from potential containers (as used in distributed training). 24 | 25 | Args: 26 | model (`torch.nn.Module`): The model to unwrap. 
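    Illustrative example: for a DistributedDataParallel-wrapped model,
    unwrap_model(ddp_model) returns ddp_model.module, recursing through any
    further .module wrappers.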
27 | """ 28 | # since there could be multiple levels of wrapping, unwrap recursively 29 | if hasattr(model, "module"): 30 | return unwrap_model(model.module) 31 | else: 32 | return model 33 | 34 | 35 | class MuffinTrainer(Trainer): 36 | 37 | def create_optimizer_and_scheduler(self, num_training_steps: int): 38 | model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model 39 | decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS) 40 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 41 | 42 | def should_zero_lr(param_name: str): 43 | if 'beit3' in param_name: 44 | if '.A' in param_name: 45 | return True 46 | if 'beit3.vision_embed' in param_name: 47 | return True 48 | optimizer_grouped_parameters = [ 49 | { 50 | "params": [ 51 | p for n, p in model.named_parameters() if (p.requires_grad and should_zero_lr(n)) 52 | ], 53 | "weight_decay": self.args.weight_decay, 54 | "lr": 0.0, 55 | "initial_lr": 0.0 56 | }, 57 | { 58 | "params": [ 59 | p for n, p in model.named_parameters() if (n in decay_parameters and p.requires_grad and not should_zero_lr(n)) 60 | ], 61 | "weight_decay": self.args.weight_decay, 62 | }, 63 | { 64 | "params": [ 65 | p for n, p in model.named_parameters() if (n not in decay_parameters and p.requires_grad and not should_zero_lr(n)) 66 | ], 67 | "weight_decay": 0.0, 68 | }, 69 | ] 70 | for n, p in model.named_parameters(): 71 | # print(f'Check LR of {n}') 72 | if should_zero_lr(n) and is_main_process(): 73 | print(f'Zero LR params: {n}', flush=True) 74 | 75 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) 76 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 77 | 78 | self.scheduler = self.create_scheduler( 79 | num_training_steps=num_training_steps, optimizer=self.optimizer) 80 | print(f'LR schduler is ', self.scheduler) 81 | 82 | 83 | def dpo_loss(policy_chosen_logps: torch.FloatTensor, 84 | policy_rejected_logps: torch.FloatTensor, 85 | reference_chosen_logps: torch.FloatTensor, 86 | reference_rejected_logps: torch.FloatTensor, 87 | beta: float, 88 | reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 89 | """Compute the DPO loss for a batch of policy and reference model log probabilities. 90 | 91 | Args: 92 | policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) 93 | policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) 94 | reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) 95 | reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) 96 | beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0. 97 | reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. 98 | 99 | Returns: 100 | A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). 101 | The losses tensor contains the DPO loss for each example in the batch. 102 | The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. 
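    Concretely, the quantity computed below is, per example,
        loss = -logsigmoid(beta * ((policy_chosen_logps - policy_rejected_logps)
                                   - (reference_chosen_logps - reference_rejected_logps)))
    with the reference difference dropped when reference_free is True.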
103 | """ 104 | pi_logratios = policy_chosen_logps - policy_rejected_logps 105 | ref_logratios = reference_chosen_logps - reference_rejected_logps 106 | 107 | if reference_free: 108 | ref_logratios = 0 109 | 110 | logits = pi_logratios - ref_logratios 111 | 112 | losses = -F.logsigmoid(beta * logits) 113 | chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach() 114 | rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach() 115 | 116 | return losses, chosen_rewards, rejected_rewards 117 | 118 | 119 | def forward_DPO(model, input_ids, labels, attention_mask, images, **kwargs): 120 | token_weighted = kwargs.pop('token_weighted', False) 121 | dpo_use_average = kwargs.pop('dpo_use_average', False) 122 | 123 | output = model( 124 | input_ids=input_ids, 125 | labels=labels, 126 | attention_mask=attention_mask, 127 | images=images, 128 | **kwargs 129 | ) 130 | 131 | if token_weighted: 132 | token_log_prob = get_batch_logps(output.logits, labels, return_per_token_logp=True) 133 | return token_log_prob 134 | else: 135 | log_prob, average_log_prob = get_batch_logps(output.logits, labels, return_per_token_logp=False) 136 | if dpo_use_average: 137 | return average_log_prob 138 | return log_prob 139 | 140 | 141 | def compute_weighted_logp(per_token_logp, labels, token_weight, use_average): 142 | loss_mask = (labels[:, 1:].clone() != -100) 143 | # print(f'compute wlogp {labels.shape} {loss_mask.shape}, {token_weight.shape}, {per_token_logp.shape}') 144 | weighted_mask = token_weight * loss_mask 145 | logp = (per_token_logp * weighted_mask).sum(-1) 146 | 147 | average_logp = logp / weighted_mask.sum(-1) 148 | if use_average: 149 | return average_logp 150 | return logp 151 | 152 | 153 | class MuffinDPOTrainer(MuffinTrainer): 154 | 155 | def compute_loss(self, model: Module, inputs: dict, return_outputs=False): 156 | 157 | data_dict = inputs 158 | win_input_ids = data_dict.pop('win_input_ids') 159 | rej_input_ids = data_dict.pop('rej_input_ids') 160 | 161 | win_labels = data_dict.pop('win_labels') 162 | rej_labels = data_dict.pop('rej_labels') 163 | 164 | win_attention_mask = data_dict.pop('win_attention_mask') 165 | rej_attention_mask = data_dict.pop('rej_attention_mask') 166 | 167 | ref_win_avg_logp = data_dict.pop('ref_win_avg_logp') 168 | ref_rej_avg_logp = data_dict.pop('ref_rej_avg_logp') 169 | ref_win_logp = data_dict.pop('ref_win_logp') 170 | ref_rej_logp = data_dict.pop('ref_rej_logp') 171 | ref_win_per_token_logp = data_dict.pop('ref_win_per_token_logp') 172 | ref_rej_per_token_logp = data_dict.pop('ref_rej_per_token_logp') 173 | if self.args.dpo_use_average: 174 | ref_win_logp = ref_win_avg_logp 175 | ref_rej_logp = ref_rej_avg_logp 176 | 177 | beta = data_dict.pop('beta') 178 | images = data_dict.pop('images') 179 | 180 | concatenated_input_ids = data_dict.pop('concatenated_input_ids') 181 | concatenated_labels = data_dict.pop('concatenated_labels') 182 | concatenated_attention_mask = data_dict.pop('concatenated_attention_mask') 183 | concatenated_images = torch.cat([images, images], dim=0) 184 | 185 | win_token_weight = data_dict.pop('win_token_weight') 186 | rej_token_weight = data_dict.pop('rej_token_weight') 187 | concatenated_token_weight = data_dict.pop('concatenated_token_weight') 188 | 189 | concatenated_logp = forward_DPO(model, 190 | concatenated_input_ids, 191 | concatenated_labels, 192 | concatenated_attention_mask, 193 | concatenated_images, 194 | token_weighted=self.args.dpo_token_weighted, 195 | 
dpo_use_average=self.args.dpo_use_average, 196 | **data_dict) 197 | win_size = win_input_ids.shape[0] 198 | rej_size = rej_input_ids.shape[0] 199 | assert win_size == rej_size 200 | 201 | if self.args.dpo_token_weighted: 202 | ref_win_logp = compute_weighted_logp(ref_win_per_token_logp, win_labels, win_token_weight, self.args.dpo_use_average) 203 | ref_rej_logp = compute_weighted_logp(ref_rej_per_token_logp, rej_labels, rej_token_weight, self.args.dpo_use_average) 204 | concatenated_logp = compute_weighted_logp(concatenated_logp, concatenated_labels,concatenated_token_weight, self.args.dpo_use_average) 205 | 206 | if torch.any(torch.isnan(ref_win_logp)): 207 | print(f'ref_win_logp fail', flush=True) 208 | exit() 209 | if torch.any(torch.isnan(ref_rej_logp)): 210 | print(f'ref_rej_logp fail', flush=True) 211 | exit() 212 | if torch.any(torch.isnan(concatenated_logp)): 213 | print(f'concatenated_logp fail', flush=True) 214 | exit() 215 | 216 | 217 | policy_win_logp, policy_rej_logp = concatenated_logp.split([win_size, rej_size]) 218 | 219 | 220 | if self.args.past_index >= 0: 221 | raise NotImplementedError 222 | 223 | losses, chosen_rewards, rejected_rewards = dpo_loss(policy_win_logp, 224 | policy_rej_logp, 225 | ref_win_logp, 226 | ref_rej_logp, 227 | beta=beta) 228 | reward_accuracies = (chosen_rewards > rejected_rewards).float() 229 | # loss = losses.mean() 230 | 231 | # do SFT 232 | # loss = - policy_win_logp.mean() 233 | SFT_weight = float(os.environ.get('SFT_weight', 0.0)) 234 | DPO_weight = float(os.environ.get('DPO_weight', 1.0)) 235 | loss = DPO_weight * losses.mean() - SFT_weight * policy_win_logp.mean() 236 | # loss = DPO_weight * losses.mean() - SFT_weight * policy_rej_logp.mean() 237 | 238 | train_test = 'train' if model.training else 'test' 239 | metrics = {} 240 | metrics[f'rewards_{train_test}/chosen'] = self._nested_gather(chosen_rewards.mean()).mean().item() 241 | metrics[f'rewards_{train_test}/rejected'] = self._nested_gather(rejected_rewards.mean()).mean().item() 242 | metrics[f'rewards_{train_test}/accuracies'] = self._nested_gather(reward_accuracies.mean()).mean().item() 243 | metrics[f'rewards_{train_test}/margins'] = metrics[f'rewards_{train_test}/chosen'] - metrics[f'rewards_{train_test}/rejected'] 244 | metrics[f'logps_{train_test}/rejected'] = self._nested_gather(policy_rej_logp.mean()).mean().item() 245 | metrics[f'logps_{train_test}/chosen'] = self._nested_gather(policy_win_logp.mean()).mean().item() 246 | metrics[f'logps_{train_test}/ref_rejected'] = self._nested_gather(ref_rej_logp.mean()).mean().item() 247 | metrics[f'logps_{train_test}/ref_chosen'] = self._nested_gather(ref_win_logp.mean()).mean().item() 248 | # metrics[f'batch_size'] = len(win_labels) 249 | self.log(metrics) 250 | 251 | return loss 252 | -------------------------------------------------------------------------------- /muffin/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from muffin.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
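# These strings are user-facing fallback texts: the controller streams server_error_msg with
# error_code 2 (no available worker) or 3 (worker timeout) in muffin/serve/controller.py;
# moderation_msg is presumably surfaced by the web UI when violates_moderation() below flags an input.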
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True) 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
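    Returns True only when the moderation endpoint responds and flags the text;
    request failures or an unexpected response shape fall back to False (see the
    exception handlers below).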
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "muffin" 7 | version = "1.0" 8 | description = "MuffinMultimodal foundation models are found to be “out-of-the-box” multimodal interfaces for LLMs" 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.0.1", "torchvision==0.15.2", 17 | "tokenizers>=0.12.1,<0.14", "sentencepiece==0.1.99", "shortuuid", 18 | "accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.41.0", 19 | "pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2", 20 | "gradio==3.35.2", "gradio_client==0.2.9", 21 | "requests", "httpx==0.24.0", "uvicorn", "fastapi", 22 | "einops==0.6.1", "einops-exts==0.0.4", 23 | "openai==0.28.1", "timm==0.4.12", "torchscale==0.2.0", "opencv-python==4.5.5.62", 24 | "protobuf==4.24.3", 25 | 26 | ] 27 | 28 | [project.optional-dependencies] 29 | train = ["ninja", "wandb"] 30 | 31 | [project.urls] 32 | "Homepage" = "https://rlhf-v.github.io" 33 | "Bug Tracker" = "https://github.com/yiranyyu/RLHF-V/issues" 34 | 35 | [tool.setuptools.packages.find] 36 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 37 | 38 | [tool.wheel] 39 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # 🥞 Muffin 4 | 5 | **Multimodal foundation models are found to be “out-of-the-box” multimodal interfaces for LLMs** 6 |

7 | 🔥 Try It Now! • 8 | 🔧 Install • 9 | 🙌 Build Your Muffin • 10 | 📄 Our Paper 11 |

12 | 13 |
14 | 15 | 16 | ## Reformulating Vision-Language Foundation Models and Datasets Towards Universal Multimodal Assistants 17 | 18 | ### News 19 | 20 | * [12/06] We open-source the weights of Muffin trained with more SFT data at [Hugging Face](https://huggingface.co/Yirany/RLHF-V_v0_SFT/tree/main). The model achieves 80.0 on the VQAv2 test-dev split and shows strong chat ability. 21 | * [12/04] Our recent work [RLHF-V](https://rlhf-v.github.io/) is now released. It is built upon Muffin and achieves SoTA results in preventing hallucination! 22 | 23 | ### Models 24 | 25 | Multimodal foundation models (MFMs) are natives of multiple modalities and can serve as better bridges from different modalities to LLMs. This is because MFMs can naturally encode features from other modalities (such as vision and audio) into the same space as language, which consequently better activates the capabilities of LLMs. We also list some [examples](#examples) generated by Muffin to demonstrate its effectiveness. 26 | 27 | ### Datasets 28 | 29 | Below is a demonstration of the framework designed for constructing the UniMM-Chat dataset. UniMM-Chat combines various VL datasets to generate knowledge-intensive dialogues. Text spans highlighted in colors indicate the different pieces of knowledge from the original annotations that are required to answer the questions. 30 | 31 | ![Reformulating-Datasets](./assets/images/reformulating_datasets.jpg) 32 | 33 | ### Benchmarks 34 | 35 | ![Benchmark-Performance](./assets/images/Benchmark%20Performance.jpg) 36 | 37 | 38 | ## Examples 39 | 40 | We list some representative cases to demonstrate the power of Muffin. We refer readers to our paper for more examples, and you can deploy a web demo following the [instructions](#serving). 41 | 42 | ### Outside Knowledge Visual QA 43 | 44 | ![case](./assets/images/case-onion.jpg) 45 | 46 | 47 | ### Culture 48 | 49 | ![case](./assets/images/case-attire.jpg) 50 | 51 | ### Helpful 52 | 53 | ![case](./assets/images/case-helpful.jpg) 54 | 55 | 56 | ## Contents 57 | - [Data](#data) 58 | - [Install](#install) 59 | - [Muffin Weights](#muffin-weights) 60 | - [Serving](#serving) 61 | - [Evaluation](#evaluation) 62 | - [Fine-tuning](#fine-tuning) 63 | 64 | ## Data 65 | 66 | ### Pre-training Data 67 | 68 | The pre-training data used in this release are all public datasets, including CC-3M, CC-12M, COCO, Visual Genome, and LAION-COCO. 69 | 70 | ### Instruction Following Data 71 | 72 | We present the [UniMM-Chat](https://huggingface.co/datasets/Yirany/UniMM-Chat) dataset, which is constructed for visual instruction tuning and is expected to be helpful in improving models' capabilities to solve different tasks without harming their generation ability. 73 | 74 | We use both the UniMM-Chat and [LLaVA-Instruct-150K](https://github.com/haotian-liu/LLaVA#data-download) datasets during training. To download our language-image multimodal instruction-following dataset, please run the following script: 75 | 76 | ```bash 77 | bash ./script/download_data.sh 78 | ``` 79 | 80 | 81 | ## Install 82 | 83 | 1. Clone this repository and navigate to the source folder 84 | ```bash 85 | git clone https://github.com/thunlp/muffin 86 | cd muffin 87 | ``` 88 | 89 | 2. Download training data and install dependencies. 90 | 91 | ```Shell 92 | bash ./script/download_data.sh 93 | 94 | echo "Creating conda environment" 95 | conda create -n muffin python=3.10 96 | conda activate muffin 97 | 98 | echo "Installing dependencies" 99 | pip install -e .
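# (Optional) Sanity-check the environment before continuing. This probe is not part of the original setup;
# it only confirms that PyTorch imports correctly and reports whether CUDA is visible.
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"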
100 | 101 | # Install a specific version of transformers to make sure you can reproduce the experimental results in our paper 102 | git clone --recursive git@github.com:huggingface/transformers.git 103 | cd transformers 104 | git checkout a92e0ad2e20ef4ce28410b5e05c5d63a5a304e65 105 | pip install . 106 | cd .. 107 | ``` 108 | 109 | ### Training 110 | 111 | Install additional packages if you need to do training. 112 | 113 | ```Shell 114 | git clone --recursive https://github.com/Dao-AILab/flash-attention.git 115 | cd flash-attention 116 | 117 | # Uncomment the following line if you have CUDA version <= 11.4 118 | # git checkout ad11394 119 | 120 | MAX_JOBS=8 python setup.py install 121 | cd .. 122 | ``` 123 | 124 | ## Muffin Weights 125 | 126 | We release Muffin weights on [Hugging Face](https://huggingface.co/Yirany/Muffin-13B/tree/main). To load Muffin for inference: 127 | 128 | ```py 129 | from muffin.eval.muffin_vqa import init_muffin 130 | 131 | model, image_processor, image_token_len, tokenizer = init_muffin('Yirany/Muffin-13B') 132 | ``` 133 | 134 | 135 | ### Muffin pretrained weights 136 | 137 | We also provide the [pretrained Muffin weights](https://huggingface.co/Yirany/Muffin-13B-pretrained/tree/main) (uploading, will be available soon), which have not been trained on instruction-following data. 138 | 139 | ## Serving 140 | 141 | ### Web UI 142 | 143 | #### Launch a controller 144 | 145 | ```Shell 146 | python -m muffin.serve.controller --host 0.0.0.0 --port 10000 147 | ``` 148 | 149 | #### Launch a model worker 150 | ```Shell 151 | python -m muffin.serve.muffin_model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path Yirany/Muffin-13B --multi-modal 152 | ``` 153 | Wait until the process finishes loading the model and you see "Uvicorn running on ...". 154 | 155 | #### Launch a model worker (Multiple GPUs, when GPU VRAM <= 24GB) 156 | 157 | If the VRAM of your GPU is less than 24GB (e.g., RTX 3090, RTX 4090), you can try running it with multiple GPUs. 158 | 159 | ```Shell 160 | python -m muffin.serve.muffin_model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path Yirany/Muffin-13B --multi-modal --num-gpus 2 161 | ``` 162 | Wait until the process finishes loading the model and you see "Uvicorn running on ...". 163 | 164 | 165 | #### Launch a Gradio web server 166 | 167 | ```Shell 168 | python -m muffin.serve.gradio_web_server --controller http://localhost:10000 169 | ``` 170 | #### You can open your browser and chat with a model now. 171 | 172 | ## Evaluation 173 | 174 | Our GPT-assisted evaluation pipeline for multimodal modeling is provided for a comprehensive understanding of the capabilities of vision-language models. Please see our [paper](https://arxiv.org/abs/2310.00653) for more details. 175 | 176 | 1. Generate responses. 177 | 178 | ```Shell 179 | bash ./script/eval/eval_muffin_qa.sh your_checkpoint_dir 180 | ``` 181 | 182 | 2. Evaluate the generated responses. 183 | 184 | ```Shell 185 | bash ./script/eval/batch_gpt4_review.sh your_checkpoint_dir 186 | ``` 187 | 188 | 3.
Summarize the evaluation results 189 | 190 | ```Shell 191 | python ./eval/summarize_gpt_llava_review.py your_checkpoint_dir 192 | python ./eval/summarize_gpt_unimm-bench_review.py your_checkpoint_dir 193 | ``` 194 | 195 | ## Fine-tuning 196 | 197 | 198 | ```Shell 199 | bash ./script/train/run_unimm-chat.sh ./output/checkpoints master finetune_muffin ./data/coco_images 200 | ``` 201 | 202 | ## RLHF 203 | 204 | ```Shell 205 | ref_model=./RLHF-V_SFT_weight 206 | 207 | bash ./script/train/run_RLHFV.sh \ 208 | ./RLHFV_checkpoints/dpo_exp \ 209 | master \ 210 | RLHFV \ 211 | 1.1 \ 212 | $ref_model \ 213 | ./RLHF-V-Dataset \ 214 | RLHFV_SFT \ 215 | 2160 \ 216 | 360 \ 217 | 0.1 \ 218 | False \ 219 | True 220 | ``` 221 | 222 | ## Licenses 223 | 224 | 225 | [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE) 226 | [![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-red.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/DATA_LICENSE) 227 | 228 | **Usage and License Notices**: The data, code, and checkpoints are intended and licensed for research use only. They are also restricted to uses that follow the license agreements of LLaMA, Vicuna, and ChatGPT. The dataset is CC BY NC 4.0 (allowing only non-commercial use), and models trained using the dataset should not be used outside of research purposes. 229 | 230 | 231 | ## Acknowledgement 232 | 233 | - [LLaVA](https://github.com/haotian-liu/LLaVA/tree/main/llava): the codebase we built upon, and our base model Vicuna-13B with its amazing language capabilities! 234 | 235 | If you find Muffin useful for your research and applications, please cite using this BibTeX: 236 | ```bibtex 237 | @misc{yu2023muffin, 238 | title={Reformulating Vision-Language Foundation Models and Datasets Towards Universal Multimodal Assistants}, 239 | author={Tianyu Yu and Jinyi Hu and Yuan Yao and Haoye Zhang and Yue Zhao and Chongyi Wang and Shan Wang and Yinxv Pan and Jiao Xue and Dahai Li and Zhiyuan Liu and Hai-Tao Zheng and Maosong Sun}, 240 | publisher={arXiv:2310.00653}, 241 | year={2023}, 242 | } 243 | ``` 244 | -------------------------------------------------------------------------------- /script/download_data.sh: -------------------------------------------------------------------------------- 1 | echo "Downloading UniMM-Chat" 2 | mkdir -p ./data 3 | wget https://huggingface.co/datasets/Yirany/UniMM-Chat/resolve/main/UniMM-Chat-v0.json -O ./data/unimm-chat.json -------------------------------------------------------------------------------- /script/eval/batch_gpt4_review.sh: -------------------------------------------------------------------------------- 1 | SOURCE_DIR=$1 2 | limit=100 3 | C=0 4 | process_limit=2 5 | force=no_force 6 | 7 | while IFS= read -r -d '' -u 9 8 | do 9 | # echo $REPLY 10 | if [[ $REPLY == *$prefix*unimm-bench_answer.jsonl ]]; then 11 | echo "EVAL unimm-bench "$REPLY 12 | if [[ $force == force ]]; then 13 | rm -f $REPLY.unimm-bench_gpt4_eval.jsonl 14 | fi 15 | python ./eval/eval_gpt_review_unimm-bench.py \ 16 | --question ./data/unimm-bench.json \ 17 | --answer $REPLY \ 18 | --rule ./eval/data/rule.jsonfile \ 19 | --output $REPLY.unimm-bench_gpt4_eval.jsonl \ 20 | --limit $limit & 21 | sleep 5 22 | 23 | C=$((C+1)) 24 | echo "C=$C" 25 | if [[ $C == $process_limit ]]; then 26 | echo "Wait for next iteration" 27 | C=0 28 | wait 29 | fi 30 | fi 31 | done 9< <( find $SOURCE_DIR -type f -name "*answer*" -exec printf '%s\0' {} + ) 32 | 33 | wait 34 | 35 |
while IFS= read -r -d '' -u 9 36 | do 37 | # echo $REPLY 38 | if [[ $REPLY == *$prefix*llava_test_answer.jsonl ]]; then 39 | if [[ $force == force ]]; then 40 | rm -f $REPLY.llava_test_gpt4.jsonl 41 | fi 42 | echo "EVAL qa90 "$REPLY 43 | python ./eval/eval_gpt_review_llava.py \ 44 | --question ./eval/data/qa90_questions.jsonl \ 45 | --context ./eval/data/caps_boxes_coco2014_val_80.jsonl \ 46 | --answer-list \ 47 | ./eval/data/qa90_gpt4_answer.jsonl \ 48 | $REPLY \ 49 | --rule ./eval/data/rule.jsonfile \ 50 | --output $REPLY.llava_test_gpt4.jsonl & 51 | sleep 5 52 | 53 | C=$((C+1)) 54 | echo "C=$C" 55 | if [[ $C == $process_limit ]]; then 56 | echo "Wait for next iteration" 57 | C=0 58 | wait 59 | fi 60 | fi 61 | done 9< <( find $SOURCE_DIR -type f -name "*answer*" -exec printf '%s\0' {} + ) 62 | 63 | wait -------------------------------------------------------------------------------- /script/eval/eval_muffin_inference_logp.sh: -------------------------------------------------------------------------------- 1 | ###===> install dependencies 2 | export PYTHONPATH=$PYTHONPATH:`realpath .` 3 | export TORCH_DISTRIBUTED_DEBUG=DETAIL 4 | echo "pythonpath="$PYTHONPATH 5 | ###<=== 6 | 7 | 8 | ckpt_path=$1 9 | base_dir=$2 10 | to_process_tsv_list="$3 " 11 | 12 | echo $to_process_tsv_list 13 | 14 | # save_logp_name is the suffix appended to the logp file; by default we use 'dpo_with_rlhf-v-sft_logp_train' 15 | save_logp_name='dpo_with_rlhf-v-sft_logp_train' 16 | 17 | C=0 18 | 19 | for tsv_file in $to_process_tsv_list; 20 | do 21 | echo "PWD at `pwd` checkpoint: "$ckpt_path 22 | 23 | CUDA_VISIBLE_DEVICES=$C python ./muffin/eval/muffin_inference_logp.py \ 24 | --model-name $ckpt_path \ 25 | --data-dir $base_dir \ 26 | --tsv-file $tsv_file \ 27 | --logp-file $save_logp_name 28 | C=$((C+1)) 29 | echo "C=$C" 30 | if [[ $C == 8 ]]; then 31 | echo "Wait for next iteration" 32 | C=0 33 | wait 34 | fi 35 | done -------------------------------------------------------------------------------- /script/eval/eval_muffin_qa.sh: -------------------------------------------------------------------------------- 1 | ###===> install dependencies 2 | export PYTHONPATH=$PYTHONPATH:`realpath .` 3 | export TORCH_DISTRIBUTED_DEBUG=DETAIL 4 | echo "pythonpath="$PYTHONPATH 5 | ###<=== 6 | 7 | 8 | base_dir=$1 9 | to_process_ckpt_list="$1 " 10 | # to_process_ckpt_list+=" $base_dir/checkpoint-40 $base_dir/checkpoint-80 $base_dir/checkpoint-120 $base_dir/checkpoint-160" 11 | # to_process_ckpt_list+=" $base_dir/checkpoint-200 $base_dir/checkpoint-600 $base_dir/checkpoint-1000 $base_dir/checkpoint-1400 $base_dir/checkpoint-1800 $base_dir/checkpoint-2200 $base_dir/checkpoint-2600 $base_dir/checkpoint-3000" 12 | # to_process_ckpt_list+=" $base_dir/checkpoint-400 $base_dir/checkpoint-800 $base_dir/checkpoint-1200 $base_dir/checkpoint-1600 $base_dir/checkpoint-2000 $base_dir/checkpoint-2400 $base_dir/checkpoint-2800 $base_dir/checkpoint-3200" 13 | # to_process_ckpt_list+=" $base_dir/checkpoint-3600 $base_dir/checkpoint-4000 $base_dir/checkpoint-4400 $base_dir/checkpoint-4800 $base_dir/checkpoint-5200 $base_dir/checkpoint-5600 $base_dir/checkpoint-6000 $base_dir/checkpoint-6400" 14 | 15 | # ===========> LLaVA Test Set <============== 16 | 17 | answer_file_name="llava_test_answer.jsonl" 18 | 19 | filered_to_process_ckpt_list="" 20 | for ckpt in $to_process_ckpt_list; 21 | do 22 | [[ ! -d $ckpt ]] && continue 23 | 24 | echo $ckpt/$answer_file_name 25 | if [[ !
-f $ckpt/$answer_file_name ]]; then 26 | filered_to_process_ckpt_list=$filered_to_process_ckpt_list" "$ckpt 27 | fi 28 | # filered_to_process_ckpt_list=$filered_to_process_ckpt_list" "$ckpt 29 | done 30 | echo "Process these checkpoints: [$filered_to_process_ckpt_list]" 31 | 32 | 33 | C=0 34 | q_file=./eval/data/qa90_questions_with_image.jsonl 35 | 36 | for ckpt_path in $filered_to_process_ckpt_list; 37 | do 38 | answer_file=$ckpt_path/$answer_file_name 39 | echo "PWD at `pwd` checkpoint: "$ckpt_path" output to: "$answer_file 40 | 41 | echo "Start generating answers for $ckpt_path" 42 | CUDA_VISIBLE_DEVICES=$C python ./muffin/eval/muffin_vqa.py \ 43 | --model-name $ckpt_path \ 44 | --question-file $q_file \ 45 | --answers-file $answer_file & 46 | C=$((C+1)) 47 | echo "C=$C" 48 | if [[ $C == 8 ]]; then 49 | echo "Wait for next iteration" 50 | C=0 51 | wait 52 | fi 53 | done 54 | wait 55 | 56 | # =========> unimm-bench <============ 57 | 58 | answer_file_name="unimm-bench_answer.jsonl" 59 | eval_file_name="unimm-bench_gpt4_eval.jsonl" 60 | 61 | filered_to_process_ckpt_list="" 62 | for ckpt in $to_process_ckpt_list; 63 | do 64 | [[ ! -d $ckpt ]] && continue 65 | 66 | echo $ckpt/$answer_file_name 67 | if [[ ! -f $ckpt/$answer_file_name ]]; then 68 | filered_to_process_ckpt_list=$filered_to_process_ckpt_list" "$ckpt 69 | fi 70 | # filered_to_process_ckpt_list=$filered_to_process_ckpt_list" "$ckpt 71 | done 72 | echo "Process these checkpoints: [$filered_to_process_ckpt_list]" 73 | 74 | 75 | C=0 76 | q_file=/data/public/multimodal/multimodal_data/MMU_Benchmark/keep_400_vqa_eval.json 77 | 78 | for ckpt_path in $filered_to_process_ckpt_list; 79 | do 80 | answer_file=$ckpt_path/$answer_file_name 81 | echo "PWD at `pwd` checkpoint: "$ckpt_path" output to: "$answer_file 82 | 83 | CUDA_VISIBLE_DEVICES=$C python ./muffin/eval/muffin_vqa.py \ 84 | --model-name $ckpt_path \ 85 | --question-file $q_file \ 86 | --answers-file $answer_file & 87 | C=$((C+1)) 88 | echo "C=$C" 89 | if [[ $C == 8 ]]; then 90 | echo "Wait for next iteration" 91 | C=0 92 | wait 93 | fi 94 | done 95 | wait 96 | 97 | 98 | echo "========>Start GPT4 Evaluating<========" 99 | bash ./script/eval/batch_gpt4_review.sh $base_dir 4 100 | python ./eval/summarize_gpt_unimm-bench_review.py $base_dir > $base_dir/unimm-bench_scores.txt 101 | python ./eval/summarize_gpt_llava_review.py $base_dir >> $base_dir/llava_test_scores.txt 102 | 103 | # Print Log 104 | echo Scores are: 105 | cat $base_dir/unimm-bench_scores.txt 106 | cat $base_dir/llava_test_scores.txt 107 | echo done 108 | -------------------------------------------------------------------------------- /script/eval/eval_muffin_qaFILE.sh: -------------------------------------------------------------------------------- 1 | ###===> install dependencies 2 | export PYTHONPATH=$PYTHONPATH:`realpath .` 3 | export TORCH_DISTRIBUTED_DEBUG=DETAIL 4 | echo "pythonpath="$PYTHONPATH 5 | ###<=== 6 | 7 | to_process_ckpt_list="$1 " 8 | 9 | 10 | q_file=$2 11 | answer_file_name=$3 12 | 13 | filered_to_process_ckpt_list="" 14 | for ckpt in $to_process_ckpt_list; 15 | do 16 | [[ ! -d $ckpt ]] && continue 17 | 18 | echo $ckpt/$answer_file_name 19 | if [[ ! 
-f $ckpt/$answer_file_name ]]; then 20 | filered_to_process_ckpt_list=$filered_to_process_ckpt_list" "$ckpt 21 | fi 22 | # filered_to_process_ckpt_list=$filered_to_process_ckpt_list" "$ckpt 23 | done 24 | echo "Process these checkpoints: [$filered_to_process_ckpt_list]" 25 | 26 | 27 | C=0 28 | 29 | for ckpt_path in $filered_to_process_ckpt_list; 30 | do 31 | answer_file=$ckpt_path/$answer_file_name 32 | echo "PWD at `pwd` checkpoint: "$ckpt_path" output to: "$answer_file 33 | 34 | CUDA_VISIBLE_DEVICES=$C python ./muffin/eval/muffin_vqa.py \ 35 | --model-name $ckpt_path \ 36 | --question-file $q_file \ 37 | --answers-file $answer_file & 38 | C=$((C+1)) 39 | echo "C=$C" 40 | if [[ $C == 8 ]]; then 41 | echo "Wait for next iteration" 42 | C=0 43 | wait 44 | fi 45 | done 46 | wait 47 | echo "========>Done generating answers<========" 48 | 49 | -------------------------------------------------------------------------------- /script/init.sh: -------------------------------------------------------------------------------- 1 | ## Download UniMM-Chat Data 2 | bash download_data.sh 3 | 4 | ## Prepare Environment 5 | echo "Creating conda environment" 6 | conda create -n muffin python=3.10 7 | conda activate muffin 8 | 9 | echo "Installing dependencies" 10 | pip install -e . 11 | 12 | echo "Installing flash-attention" 13 | git clone --recursive https://github.com/Dao-AILab/flash-attention.git 14 | cd flash-attention 15 | 16 | # uncomment the following line if you have CUDA version < 11.6 17 | # git checkout 757058d 18 | 19 | MAX_JOBS=8 python setup.py install 20 | cd .. 21 | 22 | -------------------------------------------------------------------------------- /script/train/ddpo_muffin_PATH.sh: -------------------------------------------------------------------------------- 1 | 2 | ###===> Model config 3 | sft_vision_tower=$6 4 | llm_path=not_used 5 | ###<=== 6 | 7 | export TORCH_DISTRIBUTED_DEBUG=DETAIL 8 | echo "pythonpath="$PYTHONPATH 9 | sleep 5 10 | 11 | 12 | GPUS_PER_NODE=8 13 | NUM_NODE=${11} 14 | RDZV_ENDPOINT=${13} 15 | RUNNER="torchrun --nnodes=${NUM_NODE} --nproc_per_node=${GPUS_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${RDZV_ENDPOINT}" 16 | 17 | echo RUNNER=$RUNNER 18 | 19 | ###===> Start preparing data 20 | data_dir=$7 21 | ref_name=$8 22 | echo Data config: $data_dir $ref_name 23 | ###<=== 24 | 25 | 26 | ###===> Checkpointing 27 | num_epoch=10 # not used indeed 28 | 29 | num_save=16 30 | save_step=${10} 31 | max_step=$9 32 | task_name=muffin_13b_DPO 33 | exp_name=$5-$sft_data-$sft_data_weight-$sft_vision_tower 34 | sft_output_dir=${12}/$task_name-$exp_name/checkpionts 35 | sft_logging_dir=${12}/$task_name-$exp_name/log 36 | 37 | echo "sft_output_dir="$sft_output_dir" sft_logging_dir="$sft_logging_dir 38 | ###<=== 39 | 40 | ###===> SFT 41 | 42 | pretrain_ckpt=$1 43 | echo "Load from "$pretrain_ckpt 44 | 45 | $RUNNER ./muffin/train/train_mem_muffin.py \ 46 | --model_name_or_path $pretrain_ckpt \ 47 | --image_folder not_used \ 48 | --vision_tower $sft_vision_tower \ 49 | --pretrain_mm_mlp_adapter not_used \ 50 | --fully_tune True \ 51 | --mm_vision_select_layer -2 \ 52 | --mm_use_im_start_end True \ 53 | --bf16 True \ 54 | --output_dir $sft_output_dir \ 55 | --num_train_epochs $num_epoch \ 56 | --per_device_train_batch_size $2 \ 57 | --per_device_eval_batch_size 4 \ 58 | --gradient_accumulation_steps $3 \ 59 | --evaluation_strategy "no" \ 60 | --save_strategy "steps" \ 61 | --save_steps $save_step \ 62 | --save_total_limit $num_save \ 63 | --data_source_names '' \ 64 | 
--data_source_weights '' \ 65 | --data_dir $data_dir \ 66 | --ref_name $ref_name \ 67 | --max_steps $max_step \ 68 | --learning_rate 5e-7 \ 69 | --weight_decay 0. \ 70 | --warmup_ratio 0.03 \ 71 | --lr_scheduler_type "cosine" \ 72 | --logging_steps 5 \ 73 | --logging_dir $sft_logging_dir \ 74 | --tf32 True \ 75 | --fsdp "full_shard auto_wrap" \ 76 | --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ 77 | --model_max_length 2048 \ 78 | --gradient_checkpointing True \ 79 | --lazy_preprocess True \ 80 | --task DPO \ 81 | --report_to wandb \ 82 | --run_name $5 \ 83 | --dataloader_num_workers 10 \ 84 | --dpo_use_average ${14} \ 85 | --dpo_token_weighted ${15} \ 86 | --dpo_token_weight ${16} \ 87 | --dpo_beta ${17} 88 | ###<=== 89 | 90 | -------------------------------------------------------------------------------- /script/train/run_RLHFV.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --partition=gpu3-2 4 | #SBATCH --nodes=1 5 | #SBATCH --ntasks-per-node=8 6 | #SBATCH --gres=gpu:8 7 | #SBATCH --cpus-per-task=16 8 | #SBATCH --nodelist=g3013 9 | #SBATCH --output=./_temp/slurm_output/%j.%x.out 10 | 11 | echo Working Directory at `pwd` 12 | echo Bash at `which bash` 13 | echo Python at `which python` 14 | 15 | export PYTHONPATH=$PYTHONPATH:`realpath .` 16 | 17 | nvidia-smi 18 | 19 | root_dir=$1 # directory to save log and checkpoints 20 | slave_or_master=$2 21 | 22 | 23 | MASTER_ADDR=`hostname` 24 | MASTER_PORT=13245 25 | rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT 26 | 27 | echo $slave_or_master local_addr=$rdzv_endpoint 28 | 29 | bsz=1 30 | num_node=1 31 | grad_acc=1 32 | 33 | exp_name=$3 34 | max_step=$8 35 | save_step=$9 36 | 37 | data_dir=$6 38 | ref_name=$7 39 | 40 | dpo_use_average=${11} 41 | dpo_token_weighted=${12} 42 | dpo_token_weight=$4 43 | dpo_beta=${10} 44 | echo ddpo weight is $4 beta is $dpo_beta 45 | 46 | ref_model=$5 47 | 48 | bash ./script/train/ddpo_muffin_PATH.sh \ 49 | $ref_model \ 50 | $bsz \ 51 | $grad_acc \ 52 | not_used_param \ 53 | $exp_name \ 54 | beit3_large_patch16_448 \ 55 | $data_dir \ 56 | $ref_name \ 57 | $max_step \ 58 | $save_step \ 59 | $num_node \ 60 | $root_dir \ 61 | $rdzv_endpoint \ 62 | $dpo_use_average \ 63 | $dpo_token_weighted \ 64 | $dpo_token_weight \ 65 | $dpo_beta 66 | -------------------------------------------------------------------------------- /script/train/run_unimm-chat.sh: -------------------------------------------------------------------------------- 1 | echo Working Directory at `pwd` 2 | echo Bash at `which bash` 3 | echo Python at `which python` 4 | 5 | export PYTHONPATH=$PYTHONPATH:`realpath .` 6 | 7 | nvidia-smi 8 | 9 | root_dir=$1 # directory to save log and checkpoints 10 | slave_or_master=$2 11 | 12 | 13 | MASTER_ADDR=`hostname` 14 | MASTER_PORT=13245 15 | rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT 16 | 17 | echo $slave_or_master local_addr=$rdzv_endpoint 18 | 19 | if [[ $slave_or_master == "slave" ]]; then 20 | echo Slave rdzv_endpoint is "<${4}>" 21 | rdzv_endpoint=$4 22 | fi 23 | 24 | 25 | bsz=4 26 | num_node=1 27 | grad_acc=16 28 | 29 | exp_name=$3 30 | max_step=1600 31 | save_step=400 32 | sft_data="unimm-chat" 33 | 34 | image_folder=$4 35 | 36 | bash ./script/train/sft_muffin_PATH.sh \ 37 | /home/yutianyu/Muffin_checkpoints/310m_pretrain_100k_SFT_M3IT_2800_then_M3IT-LVA-UniMM-SYNTHEDOG_2800/ \ 38 | $bsz \ 39 | $grad_acc \ 40 | not_used_param \ 41 | $exp_name \ 42 | beit3_large_patch16_448 \ 43 | $sft_data \ 44 | 100 \ 45 | $max_step \ 46 | $save_step \ 47 
| $num_node \ 48 | $root_dir \ 49 | $rdzv_endpoint \ 50 | $image_folder -------------------------------------------------------------------------------- /script/train/sft_muffin_PATH.sh: -------------------------------------------------------------------------------- 1 | 2 | ###===> Model config 3 | sft_vision_tower=$6 4 | llm_path=not_used 5 | ###<=== 6 | 7 | export TORCH_DISTRIBUTED_DEBUG=DETAIL 8 | echo "pythonpath="$PYTHONPATH 9 | sleep 5 10 | 11 | image_folder=${14} 12 | 13 | GPUS_PER_NODE=8 14 | NUM_NODE=${11} 15 | RDZV_ENDPOINT=${13} 16 | RUNNER="torchrun --nnodes=${NUM_NODE} --nproc_per_node=${GPUS_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${RDZV_ENDPOINT}" 17 | 18 | echo RUNNER=$RUNNER 19 | 20 | ###===> Start preparing data 21 | sft_data=$7 22 | sft_data_weight=$8 23 | echo Data config: $sft_data $sft_data_weight 24 | ###<=== 25 | 26 | 27 | ###===> Checkpointing 28 | num_epoch=10 # not used indeed 29 | 30 | num_save=10 31 | save_step=${10} 32 | max_step=$9 33 | task_name=muffin_13b_SFT 34 | exp_name=$5-$sft_data-$sft_data_weight-$sft_vision_tower 35 | sft_output_dir=${12}/$task_name-$exp_name/checkpionts 36 | sft_logging_dir=${12}/$task_name-$exp_name/log 37 | 38 | echo "sft_output_dir="$sft_output_dir" sft_logging_dir="$sft_logging_dir 39 | ###<=== 40 | 41 | ###===> SFT 42 | 43 | pretrain_ckpt=$1 44 | echo "Load from "$pretrain_ckpt 45 | 46 | $RUNNER ./muffin/train/train_mem_muffin.py \ 47 | --model_name_or_path $pretrain_ckpt \ 48 | --image_folder $image_folder \ 49 | --vision_tower $sft_vision_tower \ 50 | --pretrain_mm_mlp_adapter not_used \ 51 | --fully_tune True \ 52 | --mm_vision_select_layer -2 \ 53 | --mm_use_im_start_end True \ 54 | --bf16 True \ 55 | --output_dir $sft_output_dir \ 56 | --num_train_epochs $num_epoch \ 57 | --per_device_train_batch_size $2 \ 58 | --per_device_eval_batch_size 4 \ 59 | --gradient_accumulation_steps $3 \ 60 | --evaluation_strategy "no" \ 61 | --save_strategy "steps" \ 62 | --save_steps $save_step \ 63 | --save_total_limit $num_save \ 64 | --data_source_names $sft_data \ 65 | --data_source_weights $sft_data_weight \ 66 | --max_steps $max_step \ 67 | --learning_rate 2e-5 \ 68 | --weight_decay 0. 
\ 69 | --warmup_ratio 0.03 \ 70 | --lr_scheduler_type "cosine" \ 71 | --logging_steps 10 \ 72 | --logging_dir $sft_logging_dir \ 73 | --tf32 True \ 74 | --fsdp "full_shard auto_wrap" \ 75 | --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ 76 | --model_max_length 2048 \ 77 | --gradient_checkpointing True \ 78 | --lazy_preprocess True \ 79 | --report_to wandb \ 80 | --run_name $5 \ 81 | --dataloader_num_workers 10 82 | ###<=== 83 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/Muffin/85d72b4b04035b6b1cf4168814739ae9e5e6e1dc/utils/__init__.py -------------------------------------------------------------------------------- /utils/diff_lib.py: -------------------------------------------------------------------------------- 1 | import re 2 | import difflib 3 | 4 | # SGR color constants 5 | # rene-d 2018 6 | 7 | 8 | class Colors: 9 | """ ANSI color codes """ 10 | BLACK = "\033[0;30m" 11 | RED = "\033[0;31m" 12 | GREEN = "\033[0;32m" 13 | BROWN = "\033[0;33m" 14 | BLUE = "\033[0;34m" 15 | PURPLE = "\033[0;35m" 16 | CYAN = "\033[0;36m" 17 | LIGHT_GRAY = "\033[0;37m" 18 | DARK_GRAY = "\033[1;30m" 19 | LIGHT_RED = "\033[1;31m" 20 | LIGHT_GREEN = "\033[1;32m" 21 | YELLOW = "\033[1;33m" 22 | LIGHT_BLUE = "\033[1;34m" 23 | LIGHT_PURPLE = "\033[1;35m" 24 | LIGHT_CYAN = "\033[1;36m" 25 | LIGHT_WHITE = "\033[1;37m" 26 | BOLD = "\033[1m" 27 | FAINT = "\033[2m" 28 | ITALIC = "\033[3m" 29 | UNDERLINE = "\033[4m" 30 | BLINK = "\033[5m" 31 | NEGATIVE = "\033[7m" 32 | CROSSED = "\033[9m" 33 | END = "\033[0m" 34 | # cancel SGR codes if we don't write to a terminal 35 | if not __import__("sys").stdout.isatty(): 36 | for _ in dir(): 37 | if isinstance(_, str) and _[0] != "_": 38 | locals()[_] = "" 39 | else: 40 | # set Windows console in VT mode 41 | if __import__("platform").system() == "Windows": 42 | kernel32 = __import__("ctypes").windll.kernel32 43 | kernel32.SetConsoleMode(kernel32.GetStdHandle(-11), 7) 44 | del kernel32 45 | 46 | 47 | def split_into_clauses(text): 48 | # Split the text into clauses with a regular expression. 49 | # The regex matches periods, exclamation marks, and question marks that are followed by whitespace, a newline, or the end of the text. 50 | clauses = re.split(r'(?<=[.!?,;(.")(!")(,")(?")])\s+', text) 51 | return clauses 52 | 53 | 54 | def split_into_words(text): 55 | # words = re.findall(r"[\w']+|[.,!?;]", text) 56 | words = text.split() 57 | return words 58 | 59 | 60 | def show_mark_compare_words(text1, text2): 61 | d = difflib.Differ() 62 | diff = d.compare(text1.split(" "), text2.split(" ")) 63 | return '\n'.join(diff) 64 | 65 | 66 | def show_mark_compare_substring(text1, text2): 67 | d = difflib.Differ() 68 | diff = d.compare(split_into_clauses(text1), split_into_clauses(text2)) 69 | return '\n'.join(diff) 70 | 71 | 72 | def complete_modification_spans(matches, length): 73 | i, j = 0, matches[0][0] 74 | out = [] 75 | for idx in range(0, len(matches)): 76 | out.append((i, j)) 77 | out.append(matches[idx]) 78 | if idx + 1 < len(matches): 79 | i, j = matches[idx][1], matches[idx + 1][0] 80 | else: 81 | i, j = matches[idx][1], length 82 | return out 83 | 84 | 85 | def colorize(raw_text, color): 86 | return f'{color}{raw_text}{Colors.END}' 87 | 88 | def split_mark(raw_text): 89 | return f'【【--{raw_text}--】】' 90 | 91 | def color_print_diff_single(seq, diff_spans, sep=' ', color=Colors.RED, use_split=False): 92 | seq = list(map(str, seq)) 93 | 94 | out = '' 95 | for idx, span in enumerate(diff_spans): 96 | text =
sep.join(seq[span[0]: span[1]]) 97 | if not text: 98 | continue 99 | if idx % 2 == 0: 100 | if use_split: 101 | out += f'{sep}{split_mark(text)}' 102 | else: 103 | out += f'{sep}{colorize(text, color)}' 104 | else: 105 | if use_split: 106 | out += f'{sep}{text}' 107 | else: 108 | out += f'{sep}{colorize(text, Colors.BLACK)}' 109 | out = out[len(sep):] 110 | print(f'{out}') 111 | 112 | 113 | def get_match_info(a_seq, b_seq, min_match_size=1): 114 | sm = difflib.SequenceMatcher(None, a_seq, b_seq) 115 | 116 | mb = sm.get_matching_blocks() 117 | 118 | mb = [m for m in mb[:-1] if m[2] >= min_match_size] + [mb[-1]] 119 | 120 | a_matches = [(x[0], x[0] + x[2]) for x in mb] 121 | b_matches = [(x[1], x[1] + x[2]) for x in mb] 122 | return a_matches, b_matches 123 | 124 | 125 | def span_not_empty(span): 126 | return span[0] != span[1] 127 | 128 | 129 | def join_by_space(seq): 130 | return ' '.join([str(x) for x in seq]) 131 | 132 | 133 | def generate_modification_mapping_impl(a_seq, b_seq, a_spans, b_spans, do_print=False): 134 | assert len(a_spans) == len(b_spans) 135 | mod_map = {} 136 | 137 | if do_print: 138 | print(a_spans) 139 | print(b_spans) 140 | 141 | for idx, (a_span, b_span) in enumerate(zip(a_spans, b_spans)): 142 | if idx % 2 == 1: 143 | continue 144 | a_text = join_by_space(a_seq[a_span[0]: a_span[1]]) 145 | b_text = join_by_space(b_seq[b_span[0]: b_span[1]]) 146 | if do_print: 147 | print(f'@{colorize(a_text, Colors.RED)}@ ==> @{colorize(b_text, Colors.GREEN)}@') 148 | 149 | if span_not_empty(a_span) and span_not_empty(b_span): 150 | mod_map[a_span] = b_span 151 | 152 | return mod_map 153 | 154 | 155 | def generate_modification_mapping(a_seq, b_seq, min_match_size=3, do_print=False): 156 | a_matches, b_matches = get_match_info(a_seq, b_seq, min_match_size=min_match_size) 157 | 158 | a_spans = complete_modification_spans(a_matches, len(a_seq)) 159 | b_spans = complete_modification_spans(b_matches, len(b_seq)) 160 | return generate_modification_mapping_impl(a_seq, b_seq, a_spans, b_spans, do_print=do_print) 161 | 162 | 163 | def spans2ids(spans): 164 | ids = [] 165 | for span in spans: 166 | ids += list(range(span[0], span[1])) 167 | return ids 168 | 169 | 170 | def get_diff_ids(a_seq, b_seq, min_match_size=3): 171 | mod_map = generate_modification_mapping(a_seq, b_seq, min_match_size=min_match_size) 172 | a_modification_spans = list(mod_map.keys()) 173 | b_modification_spans = list(mod_map.values()) 174 | 175 | a_ids = sorted(set(spans2ids(a_modification_spans))) 176 | b_ids = sorted(set(spans2ids(b_modification_spans))) 177 | return a_ids, b_ids 178 | 179 | 180 | def color_print_diff_pair(a_seq, b_seq, min_match_size=1, sep=' ', use_split=False): 181 | a_matches, b_matches = get_match_info(a_seq, b_seq, min_match_size=min_match_size) 182 | 183 | a_spans = complete_modification_spans(a_matches, len(a_seq)) 184 | b_spans = complete_modification_spans(b_matches, len(b_seq)) 185 | 186 | color_print_diff_single(a_seq, a_spans, sep, Colors.RED, use_split) 187 | color_print_diff_single(b_seq, b_spans, sep, Colors.GREEN, use_split) 188 | 189 | # %% 190 | if __name__ == '__main__': 191 | from transformers import AutoTokenizer 192 | tkz = AutoTokenizer.from_pretrained( 193 | '/home/yutianyu/Muffin_checkpoints/310m_pretrain_100k_SFT_M3IT_2800_then_M3IT-LVA-UniMM-SYNTHEDOG_2800/') 194 | 195 | text1 = "这张图片展示了一个带有大窗户的客厅,窗户外是城市的夜景。在其中一个窗户前面,有一台大型平板电视放在一个带玻璃面板的木质电视柜上,营造出一个舒适的娱乐区。在前景中,一些植物为空间增添了一丝绿意,与窗外的城市夜景形成了对比。除了这些元素,图片右下角附近有一个水龙头。客厅里散落着各种瓶子,可能是居住者使用的清洁用品或功能物品。总体而言,这个场景看起来很温馨舒适,适合放松和社交。" 196 | 
text2 = "这张图片展示了一个带有大窗户的客厅,窗户外是广阔的城市夜景,无数灯光照亮了整个夜晚。在其中一个窗户前面,有一台黑色的大型平板电视放在一个低矮的木质电视柜上,电视柜上有玻璃面板和开放式的搁板,营造出一个舒适的娱乐区。在电视旁边,可以看到一个黑色的音响条,暗示着一个增强音频体验的设置。在前景中,一些绿意盎然的植物带有宽大的叶子,为空间增添了一丝自然的绿意,与窗外人工的城市夜景形成了愉悦的视觉对比。除了这些元素,图片右下角附近有一个弯曲设计的银色水龙头,暗示着客厅内可能有一个厨房或吧台区域。客厅里散落着各种瓶子和一个水容器,可能是居住者使用的饮品供应或功能物品,比如放在台面上的蓝色水瓶和靠近地面的白色瓶子。总体而言,这个场景看起来有人居住,温馨舒适,适合放松和社交,沙发上的深色靠垫增添了一种随意使用和居住者个人风格的感觉。" 197 | 198 | # text2 = "The image depicts a peaceful scene of a herd of sheep grazing in a lush green field. There are at least six sheep visible in the scene, with some standing closer to the foreground and others scattered further back in the field. The sheep are of various sizes, indicating a mix of adults and younger members of the herd.\n\nIn addition to the sheep, there are some trees in the field. They are located in the background of the image. The color of these trees is somewhat yellowish, which may reflect the season in which the photo was taken. Among the three sheep in the foreground, the middle one has a black face and limbs, and the sheep on the left and the middle one also have red numbers on their bodies." 199 | # text1 = "The image depicts a peaceful scene of a herd of sheep grazing in a lush green field. There are a total of nine sheep visible in the scene, with some standing closer to the foreground and others scattered further back in the field. The sheep are of various sizes, indicating a mix of adults and younger members of the herd.\n\nIn addition to the sheep, there are two dogs present in the field. One dog is located towards the left side of the image, while the other is on the right side. The dogs appear to be herding the sheep, ensuring they stay together and move in the desired direction." 200 | 201 | min_word_match = 1 202 | # seq1 = split_into_words(text1) 203 | # seq2 = split_into_words(text2) 204 | seq1 = text1 205 | seq2 = text2 206 | color_print_diff_pair(seq1, seq2, min_word_match, sep='', use_split=True) 207 | 208 | # seq1 = split_into_clauses(text1) 209 | # seq2 = split_into_clauses(text2) 210 | # color_print_diff_pair(seq1, seq2) 211 | 212 | # tokens_1 = tkz.encode(text1) 213 | # tokens_2 = tkz.encode(text2) 214 | # color_print_diff_pair(tokens_1, tokens_2, min_word_match) 215 | 216 | # tokens_1 = tkz.tokenize(text1) 217 | # tokens_2 = tkz.tokenize(text2) 218 | # color_print_diff_pair(tokens_1, tokens_2, min_word_match) 219 | 220 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright @2019 AI, ZHIHU Inc. (zhihu.com) 5 | # 6 | # @author: J.G. Chen 7 | # @date: 12/05/2019 8 | # 9 | """ Init a logger with options from env variables. 10 | 11 | - set log level by ``LOG_LEVEL``, default: ``INFO``; 12 | - output log message to file by ``LOG_FILE``, default: output to stdout. 13 | 14 | TODO: 15 | support setting log level and log file from config file. 16 | """ 17 | import logging 18 | import os 19 | 20 | _LOG_FMT = "[%(asctime)s][%(levelname).1s][%(process)d-%(name)s-%(filename)s:%(lineno)s]- %(message)s" 21 | _DATE_FMT = "%Y-%m-%d,%H:%M:%S" 22 | 23 | _logging_level = { 24 | "CRITICAL": logging.CRITICAL, 25 | "ERROR": logging.ERROR, 26 | "WARNING": logging.WARNING, 27 | "INFO": logging.INFO, 28 | # Distributed Level, print log in main proc only by default, set this level to print all messages. 
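    # For example, running with LOG_LEVEL=DP keeps INFO-level verbosity while emitting messages from every rank instead of only the main process.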
29 | "DP": logging.INFO, 30 | "DEBUG": logging.DEBUG, 31 | None: logging.INFO, 32 | } 33 | 34 | _level = os.environ.get("LOG_LEVEL", "INFO").upper() 35 | 36 | 37 | class ShortNameFormatter(logging.Formatter): 38 | def format(self, record: logging.LogRecord): 39 | raw = record.name # save and restore for other formatters if desired 40 | parts = raw.split(".") 41 | record.name = ".".join(p[:3] for p in parts) if len(parts) > 1 else raw # keep first char for module name. 42 | result = super().format(record) 43 | record.name = raw 44 | return result 45 | 46 | 47 | class StyleAdapter(logging.LoggerAdapter): 48 | def __init__(self, logger, extra=None, style="default"): 49 | super().__init__(logger, extra or {}) 50 | self._style = style 51 | self._enable = self._enable() 52 | 53 | @classmethod 54 | def _enable(cls): 55 | # Note: to make this Logger more standalone, perform basic check without extra deps, e.g. tf/torch et al. 56 | worker = os.getenv("WORKER") 57 | rank = os.getenv("RANK") 58 | # not in DP/DDP mode or proc_id = "0" 59 | is_main = (not worker and not rank) or (worker == "0" or rank == "0") 60 | is_jeeves_job = os.getenv("JEEVES_JOB_ID") 61 | return _level in ["DEBUG", "DP"] or is_jeeves_job or is_main 62 | 63 | def _format(self, *msgs, color: str = None): 64 | if self._style == "legacy": 65 | if len(msgs) == 1: 66 | msg_str = msgs[0] 67 | else: 68 | msg_str = msgs[0] % msgs[1:] 69 | else: 70 | msg_str = ", ".join([str(msg) for msg in msgs]) 71 | 72 | if color: 73 | pass 74 | return msg_str 75 | 76 | def log(self, level, msg, *args, **kwargs): 77 | color = kwargs.pop("color", None) 78 | if self.isEnabledFor(level) and self._enable: 79 | msg, kwargs = self.process(msg, kwargs) 80 | msg_str = self._format(msg, *args, color=color) 81 | # noinspection PyProtectedMember 82 | self.logger._log(level, msg_str, (), **kwargs) 83 | 84 | 85 | def init_logger(name="ai", filename=os.environ.get("LOG_FILE", ""), fmt=_LOG_FMT, level=_level, style="legacy"): 86 | """init logger 87 | 88 | Args: 89 | name(str): optional, default: ai. 90 | filename(str): optional, default: "". Output log to file if specified, by default is set by env `LOG_FILE`. 91 | fmt(str): optional, default: _LOG_FMT 92 | level(str): optional, default: INFO 93 | style(str): optional, choice from ["print", "legacy"] 94 | - legacy: take first argument as a formatter, the remaining positional arguments as message values. 
95 | this is consistent with the constraint of `logging` pkg 96 | - print: all positional arguments are message values which will be concatenated with ", " 97 | 98 | Returns: 99 | a logger instance 100 | 101 | Examples: 102 | >>> log = init_logger("log2stdout", level="INFO") 103 | >>> log.error("info") 104 | """ 105 | logger = logging.getLogger(name) 106 | logger.setLevel(_logging_level[level]) 107 | if fmt: 108 | # formatter = logging.Formatter(fmt, datefmt=_DATE_FMT) 109 | formatter = ShortNameFormatter(fmt, datefmt=_DATE_FMT) 110 | else: 111 | formatter = None 112 | 113 | if not logger.hasHandlers(): 114 | handler = logging.StreamHandler() 115 | handler.setFormatter(formatter) 116 | logging.basicConfig(format=fmt, level=_logging_level[_level], handlers=[handler]) 117 | 118 | if filename: 119 | handler = logging.FileHandler(filename) 120 | handler.setFormatter(formatter) 121 | logger.addHandler(handler) 122 | 123 | return StyleAdapter(logger, style=style) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers (https://arxiv.org/abs/2208.06366) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beitv2 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Zhiliang Peng 7 | # Based on BEiT, timm, DeiT and DINO code bases 8 | # https://github.com/microsoft/unilm/tree/master/beit 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 10 | # https://github.com/facebookresearch/deit/ 11 | # https://github.com/facebookresearch/dino 12 | # --------------------------------------------------------' 13 | 14 | import torch 15 | import torch.distributed as dist 16 | 17 | import torch.nn as nn 18 | import matplotlib.pyplot as plt 19 | 20 | 21 | def weights_init(m): 22 | classname = m.__class__.__name__ 23 | if classname.find('Conv') != -1: 24 | nn.init.normal_(m.weight.data, 0.0, 0.02) 25 | elif classname.find('BatchNorm') != -1: 26 | nn.init.normal_(m.weight.data, 1.0, 0.02) 27 | nn.init.constant_(m.bias.data, 0) 28 | 29 | 30 | def plot_images(images: dict): 31 | x = images["input"] 32 | reconstruction = images["rec"] 33 | half_sample = images["half_sample"] 34 | new_sample = images["new_sample"] 35 | 36 | fig, axarr = plt.subplots(1, 4) 37 | axarr[0].imshow(x.cpu().detach().numpy()[0].transpose(1, 2, 0)) 38 | axarr[1].imshow(reconstruction.cpu().detach().numpy()[0].transpose(1, 2, 0)) 39 | axarr[2].imshow(half_sample.cpu().detach().numpy()[0].transpose(1, 2, 0)) 40 | axarr[3].imshow(new_sample.cpu().detach().numpy()[0].transpose(1, 2, 0)) 41 | plt.show() 42 | 43 | 44 | def get_model(model): 45 | if isinstance(model, torch.nn.DataParallel) \ 46 | or isinstance(model, torch.nn.parallel.DistributedDataParallel): 47 | return model.module 48 | else: 49 | return model 50 | 51 | 52 | def is_dist_avail_and_initialized(): 53 | if not dist.is_available(): 54 | return False 55 | if not dist.is_initialized(): 56 | return False 57 | return True 58 | 59 | 60 | def get_world_size(): 61 | if not is_dist_avail_and_initialized(): 62 | return 1 63 | return dist.get_world_size() 64 | 65 | 66 | def get_rank(): 67 | if not is_dist_avail_and_initialized(): 68 | return 0 69 | return dist.get_rank() 70 | 71 | 72 | def is_main_process(): 73 | return get_rank() == 0 74 | 75 | 76 | 
def save_on_master(*args, **kwargs): 77 | if is_main_process(): 78 | torch.save(*args, **kwargs) 79 | --------------------------------------------------------------------------------