├── assets ├── fig1.png ├── fig2.png └── fig5.png ├── en_data └── demo_domain │ ├── assistance_for_text_editting.json │ └── tree_meta_info.json ├── prompt_bank ├── prompt_write_assistance_enrich.txt ├── prompt_brainstorming_enrich.txt ├── prompt_math_enrich.txt ├── prompt_write_assistance_extend.txt ├── prompt_brainstorming_extend.txt └── prompt_math_extend.txt ├── train ├── utils.py ├── train.py └── zero_to_fp32.py ├── deepspeed_config └── deepspeed_zero3_offload_config.json ├── requirements.txt ├── eval ├── chatgpt_generate.py ├── prompt │ └── en_review_prompt_compare.jsonl ├── generate.py ├── auto_eval │ └── math │ │ └── math_equivalence.py ├── auto_eval.py └── chatgpt_score.py ├── LICENSE ├── README.md └── generate_instruction.py /assets/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanqiwan/Explore-Instruct/HEAD/assets/fig1.png -------------------------------------------------------------------------------- /assets/fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanqiwan/Explore-Instruct/HEAD/assets/fig2.png -------------------------------------------------------------------------------- /assets/fig5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanqiwan/Explore-Instruct/HEAD/assets/fig5.png -------------------------------------------------------------------------------- /en_data/demo_domain/assistance_for_text_editting.json: -------------------------------------------------------------------------------- 1 | { 2 | "raw_task_name": "Assistance for text editting", 3 | "task_name": "assistance_for_text_editting", 4 | "config_filename": "assistance_for_text_editting.json", 5 | "reason": "", 6 | "examples": [ 7 | { 8 | "instruction": "Rewrite this text in another way", 9 | "input": "People often think of cats as mysterious creatures because they are especially active at night. Also, their eyes can glow an eerie green light in the dark.", 10 | "output": "Cats are often considered mystical creatures because they are particularly active at night. Also, their eyes can emit a strange green glow in dark environments." 11 | }, 12 | { 13 | "instruction": "Extend the content from the brief description below.", 14 | "input": "This is a great phone with a high-resolution display.", 15 | "output": "This is an excellent phone with a high-performance processor for faster performance and smoother multitasking. Furthermore, it has a high-resolution display for crisp, vibrant images , allowing users to enjoy a better visual experience." 16 | } 17 | ] 18 | } -------------------------------------------------------------------------------- /prompt_bank/prompt_write_assistance_enrich.txt: -------------------------------------------------------------------------------- 1 | You are asked to generate a set of examples for a new subtask. Each example consists of an instruction, an input, and an output. 2 | 3 | Here are the requirements: 4 | 1. The skills required to perform a subtask belong to the skills required to perform the target task, and the former is a subset of the latter. 5 | 2. The instruction should be in English. The instruction should be 1 to 2 sentences long. Either an imperative sentence or a question is permitted. 6 | 3. You should create an appropriate input based on the instruction in an example. 
The input should involve realistic data and should not contain simple placeholders. The input should provide substantial content to make the instruction challenging but do not exceed 200 words in general. 7 | 4. The input should include detailed content of a passage or an article if instructed, but not any overview or description about it. 8 | 5. You should generate an appropriate output according to the instruction and depending on the input in an example. Make sure the output is less than 200 words in general. 9 | 6. The response you generated should conform to the following format: 10 | ### 11 | 1. Instruction: ____ 12 | Input: ____ 13 | Output: ____ 14 | ### 15 | 2. Instruction: ____ 16 | Input: ____ 17 | Output: ____ 18 | ### -------------------------------------------------------------------------------- /train/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import json 4 | 5 | 6 | def _make_w_io_base(f, mode: str): 7 | if not isinstance(f, io.IOBase): 8 | f_dirname = os.path.dirname(f) 9 | if f_dirname != "": 10 | os.makedirs(f_dirname, exist_ok=True) 11 | f = open(f, mode=mode) 12 | return f 13 | 14 | 15 | def _make_r_io_base(f, mode: str): 16 | if not isinstance(f, io.IOBase): 17 | f = open(f, mode=mode) 18 | return f 19 | 20 | 21 | def jdump(obj, f, mode="w", indent=4, default=str): 22 | """Dump a str or dictionary to a file in json format. 23 | 24 | Args: 25 | obj: An object to be written. 26 | f: A string path to the location on disk. 27 | mode: Mode for opening the file. 28 | indent: Indent for storing json dictionaries. 29 | default: A function to handle non-serializable entries; defaults to `str`. 30 | """ 31 | f = _make_w_io_base(f, mode) 32 | if isinstance(obj, (dict, list)): 33 | json.dump(obj, f, indent=indent, default=default) 34 | elif isinstance(obj, str): 35 | f.write(obj) 36 | else: 37 | raise ValueError(f"Unexpected type: {type(obj)}") 38 | f.close() 39 | 40 | 41 | def jload(f, mode="r"): 42 | """Load a .json file into a dictionary.""" 43 | f = _make_r_io_base(f, mode) 44 | jdict = json.load(f) 45 | f.close() 46 | return jdict 47 | -------------------------------------------------------------------------------- /prompt_bank/prompt_brainstorming_enrich.txt: -------------------------------------------------------------------------------- 1 | You are asked to generate a set of examples for a new subtask. Each example consists of an instruction, an input, and an output. 2 | 3 | Here are the requirements: 4 | 1. The skills required to perform a subtask belong to the skills required to perform the target task, and the former is a subset of the latter. 5 | 2. The instruction should be in English. The instruction should be 1 to 2 sentences long. Either an imperative sentence or a question is permitted. 6 | 3. Try not to repeat the verb for each instruction in the examples to maximize diversity. 7 | 4. You should create an appropriate input based on the instruction in an example. The input should involve realistic data and should not contain simple placeholders. The input should provide substantial content to make the instruction challenging but do not exceed 200 words in general. 8 | 5. Note that some instructions do not require input. For example, when an instruction asks about some general information of self-contained, eg: "What is the highest mountain in the world." or "Please list 5 different fruits.", it is not necessary to provide a specific context. 
In this case, we simply put "" in the input field. 9 | 6. You should generate an appropriate output according to the instruction and depending on the input in an example. Make sure the output is less than 200 words in general. 10 | 7. The response you generated should conform to the following format: 11 | ### 12 | Instruction: ____ 13 | Input: ____ 14 | Output: ____ 15 | ### 16 | Instruction: ____ 17 | Input: ____ 18 | Output: ____ 19 | ### -------------------------------------------------------------------------------- /deepspeed_config/deepspeed_zero3_offload_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "optimizer": { 11 | "type": "AdamW", 12 | "params": { 13 | "lr": "auto", 14 | "weight_decay": "auto" 15 | } 16 | }, 17 | "scheduler": { 18 | "type": "WarmupDecayLR", 19 | "params": { 20 | "warmup_min_lr": "auto", 21 | "warmup_max_lr": "auto", 22 | "warmup_num_steps": "auto", 23 | "total_num_steps": "auto" 24 | } 25 | }, 26 | "zero_optimization": { 27 | "stage": 3, 28 | "offload_optimizer": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "offload_param": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "overlap_comm": true, 37 | "contiguous_gradients": true, 38 | "reduce_bucket_size": "auto", 39 | "stage3_prefetch_bucket_size": "auto", 40 | "stage3_param_persistence_threshold": "auto", 41 | "sub_group_size": 1e9, 42 | "stage3_max_live_parameters": 1e9, 43 | "stage3_max_reuse_distance": 1e9, 44 | "stage3_gather_16bit_weights_on_model_save": "auto" 45 | }, 46 | "gradient_accumulation_steps": "auto", 47 | "gradient_clipping": "auto", 48 | "steps_per_print": 2000, 49 | "train_batch_size": "auto", 50 | "train_micro_batch_size_per_gpu": "auto", 51 | "wall_clock_breakdown": false 52 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | accelerate==0.17.1 3 | aiofiles==23.1.0 4 | aiohttp==3.8.4 5 | aiosignal==1.3.1 6 | altair==4.2.2 7 | anthropic==0.2.6 8 | anyio==3.6.2 9 | anytree==2.8.0 10 | appdirs==1.4.4 11 | async-timeout==4.0.2 12 | asyncio==3.4.3 13 | attrs==22.2.0 14 | blessed==1.20.0 15 | brotlipy==0.7.0 16 | certifi @ file:///opt/conda/conda-bld/certifi_1655968806487/work/certifi 17 | cffi==1.15.1 18 | charset-normalizer==3.1.0 19 | click==8.1.3 20 | datasets==2.10.1 21 | deepspeed==0.8.2 22 | dill==0.3.6 23 | docker-pycreds==0.4.0 24 | evaluate==0.4.0 25 | filelock==3.9.0 26 | fire==0.5.0 27 | fonttools==4.25.0 28 | frozenlist==1.3.3 29 | fsspec==2023.3.0 30 | future==0.18.3 31 | gitdb==4.0.10 32 | GitPython==3.1.31 33 | gpustat==1.0.0 34 | hjson==3.1.0 35 | huggingface-hub==0.13.2 36 | idna==3.4 37 | joblib==1.2.0 38 | jsonline==0.2.1 39 | jsonlines==3.1.0 40 | loguru==0.6.0 41 | multidict==6.0.4 42 | multiprocess==0.70.14 43 | munkres==1.1.4 44 | ninja==1.11.1 45 | nltk==3.8.1 46 | numpy==1.24.2 47 | nvidia-ml-py==11.495.46 48 | openai==0.27.2 49 | packaging==23.0 50 | pandas==1.5.3 51 | pathtools==0.1.2 52 | ply==3.11 53 | protobuf==4.22.1 54 | psutil==5.9.4 55 | py-cpuinfo==9.0.0 56 | pyarrow==11.0.0 57 | pycparser==2.21 58 | pydantic==1.10.6 59 | pynvml==11.5.0 60 | PyQt5-sip==12.11.0 61 | python-dateutil==2.8.2 62 | pytz==2022.7.1 63 | PyYAML==6.0 64 | ray==2.3.1 65 | 
regex==2022.10.31 66 | requests==2.28.2 67 | responses==0.18.0 68 | rouge-score==0.1.2 69 | scikit-learn==1.2.2 70 | scipy==1.9.3 71 | sentencepiece==0.1.97 72 | sentry-sdk==1.16.0 73 | setproctitle==1.3.2 74 | shortuuid==1.0.11 75 | six==1.16.0 76 | sklearn==0.0.post1 77 | smmap==5.0.0 78 | threadpoolctl==3.1.0 79 | tokenizers==0.13.2 80 | torch==1.13.1+cu117 81 | tqdm==4.65.0 82 | transformers @ git+https://github.com/zphang/transformers@ef61b1ba1a8ee9fd354b640b059c3474b676c0c5 83 | typing_extensions==4.5.0 84 | urllib3==1.26.15 85 | wandb==0.14.0 86 | wcwidth==0.2.6 87 | xxhash==3.2.0 88 | yarl==1.8.2 89 | -------------------------------------------------------------------------------- /prompt_bank/prompt_math_enrich.txt: -------------------------------------------------------------------------------- 1 | You are asked to generate a set of examples for a new subtask. Each example consists of an instruction, an input, and an output. 2 | 3 | Here are the requirements: 4 | 1. The skills required to perform a subtask belong to the skills required to perform the target task, and the former is a subset of the latter. 5 | 2. The instruction should be in English. The instruction should be 1 to 2 sentences long. Either an imperative sentence or a question is permitted. 6 | 3. Try not to repeat the verb for each instruction in the examples to maximize diversity. 7 | 4. You should create an appropriate input based on the instruction in an example. The input should be a math problem in latex format. The input should provide substantial content to make the instruction challenging but do not exceed 200 words in general. 8 | 5. Note that some instructions do not require input. For example, when an instruction contain a math question, eg: "Reduce to lowest terms: $- \\dfrac{1}{9} \\div \\dfrac{9}{5} = {?}$", it is not necessary to provide a specific context in input. In this case, we simply put "" in the input field. 9 | 6. You should generate an appropriate output according to the instruction and depending on the input in an example. 10 | 7. The output should in latex format. You should generate the "Explanation:" to the math problem first and then extract & show the "Answer:" (the final value of "Answer:" should be in the form \\boxed{value of "Answer:"}), eg: "Explanation: Dividing by a fraction is the same as multiplying by the reciprocal of the fraction. The reciprocal of $ \\dfrac{9}{5}$ is $ \\dfrac{5}{9}$ Therefore: $ - \\dfrac{1}{9} \\div \\dfrac{9}{5} = - \\dfrac{1}{9} \\times \\dfrac{5}{9} $ $ \\phantom{- \\dfrac{1}{9} \\times \\dfrac{5}{9}} = \\dfrac{-1 \\times 5}{9 \\times 9} $ $ \\phantom{- \\dfrac{1}{9} \\times \\dfrac{5}{9}} = \\dfrac{-5}{81} $. Answer: $\\boxed{\\dfrac{-5}{81}}$.". Make sure the output is less than 200 words in general. 11 | 8. The response you generated should conform to the following format: 12 | ### 13 | Instruction: ____ 14 | Input: ____ 15 | Output: ____ 16 | ### 17 | Instruction: ____ 18 | Input: ____ 19 | Output: ____ 20 | ### -------------------------------------------------------------------------------- /prompt_bank/prompt_write_assistance_extend.txt: -------------------------------------------------------------------------------- 1 | You are asked to propose some new subtasks for the target task given a list of existing subtasks and another list of existing peer tasks, then generate a set of examples for each new subtask. Each example consists of an instruction, an input, and an output. 2 | 3 | Here are the requirements: 4 | 1. 
The skills required to perform a subtask belong to the skills required to perform the target task, and the former is a subset of the latter. 5 | 2. The skills required to perform a peer task relate to the skills required to perform the target task. There is an intersection of the former and the latter. 6 | 3. The subtask and peer task should focus on common domains, not specific domains. 7 | 4. A new subtask is complementary to existing subtasks, and the addition of a new subtask is essential to the completion of the target task. 8 | 5. The new subtask should be different from the existing subtasks and peer tasks. The skills required for a new subtask should be designed to avoid overlapping with existing subtasks and peer tasks. 9 | 6. The instruction should be in English. 10 | 7. The instruction should be 1 to 2 sentences long. Either an imperative sentence or a question is permitted. 11 | 8. The instruction should not contain specific examples and detailed content. 12 | 9. Try not to repeat the verb for each instruction in the examples to maximize diversity. 13 | 10. The instruction should be able to complete by a GPT language model. For example, the instruction should not ask the assistant to create any visual or audio output. For another example, do not ask the assistant to wake you up at 5 pm or set a reminder because it cannot perform any action. 14 | 11. You should create an appropriate input based on the instruction in an example, but the input should not respond to the instruction. The input should involve realistic data and should not contain simple placeholders. The input should provide substantial content to make the instruction challenging but do not exceed 200 words in general. 15 | 12. The input should include detailed content of a passage or an article if instructed, but not any overview or description about it. 16 | 13. You should generate an appropriate output according to the instruction and depending on the input in an example. Make sure the output is less than 200 words in general. 17 | 14. The response you generated should conform to the following format: 18 | New subtask: ____ 19 | Reason: ____ 20 | Examples: 21 | ### 22 | 1. Instruction: ____ 23 | Input: ____ 24 | Output: ____ 25 | ### 26 | 2. Instruction: ____ 27 | Input: ____ 28 | Output: ____ 29 | ### -------------------------------------------------------------------------------- /prompt_bank/prompt_brainstorming_extend.txt: -------------------------------------------------------------------------------- 1 | You are asked to propose some new subtasks for the target task given a list of existing subtasks and another list of existing peer tasks, then generate a set of examples for each new subtask. Each example consists of an instruction, an input, and an output. 2 | 3 | Here are the requirements: 4 | 1. The skills required to perform a subtask belong to the skills required to perform the target task, and the former is a subset of the latter. 5 | 2. The skills required to perform a peer task relate to the skills required to perform the target task. There is an intersection of the former and the latter. 6 | 3. The subtask and peer task should focus on common domains, not specific domains. 7 | 4. A new subtask is complementary to existing subtasks, and the addition of a new subtask is essential to the completion of the target task. 8 | 5. The new subtask should be different from the existing subtasks and peer tasks. 
The skills required for a new subtask should be designed to avoid overlapping with existing subtasks and peer tasks. 9 | 6. The instruction should be in English. 10 | 7. The instruction should be 1 to 2 sentences long. Either an imperative sentence or a question is permitted. 11 | 8. The instruction should not contain specific examples and detailed content. 12 | 9. Try not to repeat the verb for each instruction in the examples to maximize diversity. 13 | 10. The instruction should be able to complete by a GPT language model. For example, the instruction should not ask the assistant to create any visual or audio output. For another example, do not ask the assistant to wake you up at 5 pm or set a reminder because it cannot perform any action. 14 | 11. You should create an appropriate input based on the instruction in an example, but the input should not respond to the instruction. The input should involve realistic data and should not contain simple placeholders. The input should provide substantial content to make the instruction challenging but do not exceed 200 words in general. 15 | 12. Note that some instructions do not require input. For example, when an instruction asks about some general information of self-contained, eg: "What is the highest mountain in the world." or "Please list 5 different fruits.", it is not necessary to provide a specific context. In this case, we simply put "" in the input field. 16 | 13. You should generate an appropriate output according to the instruction and depending on the input in an example. Make sure the output is less than 200 words in general. 17 | 14. The response you generated should conform to the following format: 18 | New subtask: ____ 19 | Reason: ____ 20 | Examples: 21 | ### 22 | Instruction: ____ 23 | Input: ____ 24 | Output: ____ 25 | ### 26 | Instruction: ____ 27 | Input: ____ 28 | Output: ____ 29 | ### -------------------------------------------------------------------------------- /prompt_bank/prompt_math_extend.txt: -------------------------------------------------------------------------------- 1 | You are asked to propose some new subtasks for the target task given a list of existing subtasks and another list of existing peer tasks, then generate a set of examples for each new subtask. Each example consists of an instruction, an input, and an output. 2 | 3 | Here are the requirements: 4 | 1. The skills required to perform a subtask belong to the skills required to perform the target task, and the former is a subset of the latter. 5 | 2. The skills required to perform a peer task relate to the skills required to perform the target task. There is an intersection of the former and the latter. 6 | 3. The subtask and peer task should focus on common domains, not specific domains. 7 | 4. A new subtask is complementary to existing subtasks, and the addition of a new subtask is essential to the completion of the target task. 8 | 5. The new subtask should be different from the existing subtasks and peer tasks. The skills required for a new subtask should be designed to avoid overlapping with existing subtasks and peer tasks. 9 | 6. The instruction should be in English. 10 | 7. The instruction should be 1 to 2 sentences long. Either an imperative sentence or a question is permitted. 11 | 8. Try not to repeat the verb for each instruction in the examples to maximize diversity. 12 | 9. The instruction should be able to complete by a GPT language model. 
For example, the instruction should not ask the assistant to create any visual or audio output. For another example, do not ask the assistant to wake you up at 5 pm or set a reminder because it cannot perform any action. 13 | 10. You should create an appropriate input based on the instruction in an example. The input should be a math problem in latex format. The input should provide substantial content to make the instruction challenging but do not exceed 200 words in general. 14 | 11. Note that some instructions do not require input. For example, when an instruction contain a math question, eg: "Reduce to lowest terms: $- \\dfrac{1}{9} \\div \\dfrac{9}{5} = {?}$", it is not necessary to provide a specific context in input. In this case, we simply put "" in the input field. 15 | 12. You should generate an appropriate output according to the instruction and depending on the input in an example. 16 | 13. The output should in latex format. You should generate the "Explanation:" to the math problem first and then extract & show the "Answer:" (the final value of "Answer:" should be in the form \\boxed{value of "Answer:"}), eg: "Explanation: Dividing by a fraction is the same as multiplying by the reciprocal of the fraction. The reciprocal of $ \\dfrac{9}{5}$ is $ \\dfrac{5}{9}$ Therefore: $ - \\dfrac{1}{9} \\div \\dfrac{9}{5} = - \\dfrac{1}{9} \\times \\dfrac{5}{9} $ $ \\phantom{- \\dfrac{1}{9} \\times \\dfrac{5}{9}} = \\dfrac{-1 \\times 5}{9 \\times 9} $ $ \\phantom{- \\dfrac{1}{9} \\times \\dfrac{5}{9}} = \\dfrac{-5}{81} $. Answer: $\\boxed{\\dfrac{-5}{81}}$.". Make sure the output is less than 200 words in general. 17 | 14. The response you generated should conform to the following format: 18 | New subtask: ____ 19 | Reason: ____ 20 | Examples: 21 | ### 22 | Instruction: ____ 23 | Input: ____ 24 | Output: ____ 25 | ### 26 | Instruction: ____ 27 | Input: ____ 28 | Output: ____ 29 | ### -------------------------------------------------------------------------------- /en_data/demo_domain/tree_meta_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "raw_task_name": "Assistance for text editting", 3 | "task_name": "assistance_for_text_editting", 4 | "config_filename": "assistance_for_text_editting.json", 5 | "reason": "", 6 | "children": [ 7 | { 8 | "raw_task_name": "Paraphrase", 9 | "task_name": "paraphrase", 10 | "config_filename": "paraphrase.json", 11 | "reason": "The purposes of paraphrasing are to show understanding, make ideas accessible, translate technical language, avoid plagiarism, compare perspectives, and aid memorization. The ability to paraphrase skillfully is a valuable tool for write assistance.", 12 | "children": [] 13 | }, 14 | { 15 | "raw_task_name": "Style transfer", 16 | "task_name": "style_transfer", 17 | "config_filename": "style_transfer.json", 18 | "reason": "It can improve writing efficiency and text diversity, while stimulating creativity and enriching the expression of text, bringing a better writing experience.", 19 | "children": [] 20 | }, 21 | { 22 | "raw_task_name": "Simplify language", 23 | "task_name": "simplify_language", 24 | "config_filename": "simplify_language.json", 25 | "reason": "The purpose of this task is to make the original language more accessible to the broad masses of the people. 
Justifications may include documentation needs for different audiences, such as educational materials, government announcements, etc.", 26 | "children": [] 27 | }, 28 | { 29 | "raw_task_name": "Text expansion", 30 | "task_name": "text_expansion", 31 | "config_filename": "text_expansion.json", 32 | "reason": "Text expansion refers to expanding a piece of text, adding content and details, to help writers better express their thoughts and opinions, and to make the text easier for readers to understand.", 33 | "children": [] 34 | }, 35 | { 36 | "raw_task_name": "Fix spelling and grammar", 37 | "task_name": "fix_spelling_and_grammar", 38 | "config_filename": "fix_spelling_and_grammar.json", 39 | "reason": "The purpose of fix spelling and grammar is to help users create more accurate, professional text and improve the quality of their writing.", 40 | "children": [] 41 | }, 42 | { 43 | "raw_task_name": "Language style simulation", 44 | "task_name": "language_style_simulation", 45 | "config_filename": "language_style_simulation.json", 46 | "reason": "The purpose of this task is to help users create more focused and expressive text, and to improve the accuracy and professionalism of their writing.", 47 | "children": [] 48 | }, 49 | { 50 | "raw_task_name": "Coherence and cohesion improvement", 51 | "task_name": "coherence_and_cohesion_improvement", 52 | "config_filename": "coherence_and_cohesion_improvement.json", 53 | "reason": "Detect and improve logical coherence in text and cohesion between paragraphs. The purpose of this task is to help users create more coherent, logical and organized text", 54 | "children": [] 55 | }, 56 | { 57 | "raw_task_name": "Writing enhancement", 58 | "task_name": "writing_enhancement", 59 | "config_filename": "writing_enhancement.json", 60 | "reason": "The purpose of write enhancement is to improve the expressiveness and professionalism of text to make it more lively, precise and readable.", 61 | "children": [] 62 | } 63 | ] 64 | } -------------------------------------------------------------------------------- /eval/chatgpt_generate.py: -------------------------------------------------------------------------------- 1 | """Get answer for gpt-3.5-turbo""" 2 | import argparse 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | import asyncio 8 | import time 9 | from typing import Any 10 | import openai 11 | 12 | MAX_API_RETRY = 5 13 | openai.api_key = "YOUR OPENAI API KEY" 14 | 15 | 16 | def get_json_list(file_path): 17 | file_path = os.path.expanduser(file_path) 18 | with open(file_path, 'r') as f: 19 | json_list = [] 20 | for line in f: 21 | json_list.append(json.loads(line)) 22 | return json_list 23 | 24 | 25 | async def dispatch_openai_requests( 26 | messages_list: list[list[dict[str, Any]]], 27 | model: str, 28 | temperature: float, 29 | max_tokens: int, 30 | ) -> list[str]: 31 | """Dispatches requests to OpenAI API asynchronously. 32 | 33 | Args: 34 | messages_list: List of messages to be sent to OpenAI ChatCompletion API. 35 | model: OpenAI model to use. 36 | temperature: Temperature to use for the model. 37 | max_tokens: Maximum number of tokens to generate. 38 | Returns: 39 | List of responses from OpenAI API. 
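        Example (illustrative sketch only; the message content is a placeholder, and this simply
        mirrors how get_completion() below drives this coroutine against the openai==0.27
        ChatCompletion API pinned in requirements.txt):
            responses = asyncio.run(
                dispatch_openai_requests(
                    messages_list=[[{"role": "user", "content": "Say hello."}]],
                    model="gpt-3.5-turbo",
                    temperature=0.0,
                    max_tokens=256,
                )
            )
            text = responses[0]["choices"][0]["message"]["content"]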
40 | """ 41 | async_responses = [ 42 | openai.ChatCompletion.acreate( 43 | model=model, 44 | messages=x, 45 | temperature=temperature, 46 | max_tokens=max_tokens, 47 | ) 48 | for x in messages_list 49 | ] 50 | return await asyncio.gather(*async_responses) 51 | 52 | 53 | def get_completion(messages_list: list, model: str, temperature: float = 0.0, max_tokens: int = 2048): 54 | for i in range(MAX_API_RETRY): 55 | try: 56 | completions = asyncio.run( 57 | dispatch_openai_requests( 58 | messages_list=messages_list, 59 | model=model, 60 | temperature=temperature, 61 | max_tokens=max_tokens, 62 | ) 63 | ) 64 | return completions 65 | except Exception as e: 66 | print(e) 67 | time.sleep(20) 68 | print(f'Failed after {MAX_API_RETRY} retries.') 69 | raise RuntimeError 70 | 71 | 72 | def get_prompt(qs, use_math_prompt): 73 | if use_math_prompt is False: 74 | prompt = qs 75 | else: 76 | prompt = "Given a math problem, you should generate the \"Explanation:\" to the math problem first and then extract & show the \"Answer:\" (the final value of \"Answer:\" should be in the form \\boxed{value of \"Answer:\"}). The output should in latex format." 77 | prompt += f"\nProblem: {qs}" 78 | return prompt 79 | 80 | 81 | def run_eval(model_id, input_file, output_file, decoding_args, use_math_prompt, batch_size): 82 | questions = get_json_list(input_file) 83 | if os.path.exists(output_file): 84 | curr_result = get_json_list(output_file) 85 | else: 86 | curr_result = [] 87 | for i in tqdm(range(len(curr_result), len(questions), batch_size)): 88 | batch_question = questions[i: i + batch_size] 89 | messages_list = [] 90 | for x in batch_question: 91 | qs = x["question"] 92 | prompt = get_prompt(qs, use_math_prompt) 93 | messages_list.append([ 94 | {"role": "user", 95 | "content": prompt}, 96 | ]) 97 | completions = get_completion(messages_list, model_id, **decoding_args) 98 | results = [completion['choices'][0]['message']['content'] for completion in completions] 99 | for idx, x in enumerate(batch_question): 100 | ans_id = shortuuid.uuid() 101 | ans = {"question_id": x["question_id"], 102 | "question": x["question"], 103 | "std_answer": x["std_answer"], 104 | "class": x["class"], 105 | "answer_id": ans_id, 106 | "answer": results[idx], 107 | "model_id": model_id, 108 | "metadata": decoding_args} 109 | with open(output_file, "a+") as fout: 110 | fout.write(json.dumps(ans) + '\n') 111 | 112 | 113 | if __name__ == "__main__": 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument("--model_id", type=str, required=True) 116 | parser.add_argument("--question_file", type=str, default="") 117 | parser.add_argument("--answer_file", type=str, default="answer.jsonl") 118 | parser.add_argument("--temperature", type=float, default=0.7) 119 | parser.add_argument("--max_tokens", type=int, default=1024) 120 | parser.add_argument("--use_math_prompt", action="store_true") 121 | parser.add_argument("--batch_size", type=int, default=1) 122 | args = parser.parse_args() 123 | print(args) 124 | decoding_args = {"temperature": args.temperature, "max_tokens": args.max_tokens} 125 | run_eval(args.model_id, args.question_file, args.answer_file, decoding_args, args.use_math_prompt, args.batch_size) 126 | -------------------------------------------------------------------------------- /eval/prompt/en_review_prompt_compare.jsonl: -------------------------------------------------------------------------------- 1 | {"class": "rewrite", "prompt": "We would like to request your feedback on the performance of two AI assistants in response 
to the user question displayed above.\nPlease evaluate the given four aspects: helpfulness, relevance, accuracy, level of details of their responses.\nPlease first clarify how each response achieves each aspect respectively.\nThen, provide a comparison on the overall performance between Assistant 1 and Assistant 2, and you need to clarify which one is better than or equal to another. Avoid any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.\nIn the last line, order the two assistants. Please output a single line ordering Assistant 1 and Assistant 2, where '>' means 'is better than' and '=' means 'is equal to'. The order should be consistent to your comparison. If there is not comparison that one is better, it is assumed they have equivalent overall performance ('=').", "demo_input_1": "", "demo_output_1": "", "demo_input_2": "", "demo_output_2": ""} 2 | {"class": "generation", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease evaluate the given four aspects: helpfulness, relevance, accuracy, level of details of their responses.\nPlease first clarify how each response achieves each aspect respectively.\nThen, provide a comparison on the overall performance between Assistant 1 and Assistant 2, and you need to clarify which one is better than or equal to another. Avoid any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.\nIn the last line, order the two assistants. Please output a single line ordering Assistant 1 and Assistant 2, where '>' means 'is better than' and '=' means 'is equal to'. The order should be consistent to your comparison. If there is not comparison that one is better, it is assumed they have equivalent overall performance ('=').", "demo_input_1": "", "demo_output_1": "", "demo_input_2": "", "demo_output_2": ""} 3 | {"class": "brainstorming", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease evaluate the given four aspects: helpfulness, relevance, accuracy, level of details of their responses.\nPlease first clarify how each response achieves each aspect respectively.\nThen, provide a comparison on the overall performance between Assistant 1 and Assistant 2, and you need to clarify which one is better than or equal to another. Avoid any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.\nIn the last line, order the two assistants. Please output a single line ordering Assistant 1 and Assistant 2, where '>' means 'is better than' and '=' means 'is equal to'. The order should be consistent to your comparison. If there is not comparison that one is better, it is assumed they have equivalent overall performance ('=').", "demo_input_1": "", "demo_output_1": "", "demo_input_2": "", "demo_output_2": ""} 4 | {"class": "code", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease evaluate the given four aspects: helpfulness, relevance, accuracy, level of details of their responses.\nPlease first clarify how each response achieves each aspect respectively.\nThen, provide a comparison on the overall performance between Assistant 1 and Assistant 2, and you need to clarify which one is better than or equal to another. 
Avoid any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.\nIn the last line, order the two assistants. Please output a single line ordering Assistant 1 and Assistant 2, where '>' means 'is better than' and '=' means 'is equal to'. The order should be consistent to your comparison. If there is not comparison that one is better, it is assumed they have equivalent overall performance ('=').", "demo_input_1": "", "demo_output_1": "", "demo_input_2": "", "demo_output_2": ""} 5 | {"class": "math", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease evaluate the given four aspects: helpfulness, relevance, accuracy, level of details of their responses.\nPlease first clarify how each response achieves each aspect respectively.\nThen, provide a comparison on the overall performance between Assistant 1 and Assistant 2, and you need to clarify which one is better than or equal to another. Avoid any potential bias and ensuring that the order in which the responses were presented does not affect your judgment.\nIn the last line, order the two assistants. Please output a single line ordering Assistant 1 and Assistant 2, where '>' means 'is better than' and '=' means 'is equal to'. The order should be consistent to your comparison. If there is not comparison that one is better, it is assumed they have equivalent overall performance ('=').", "demo_input_1": "", "demo_output_1": "", "demo_input_2": "", "demo_output_2": ""} -------------------------------------------------------------------------------- /eval/generate.py: -------------------------------------------------------------------------------- 1 | """Get answer for fine-tuned model""" 2 | import argparse 3 | from transformers import AutoTokenizer, AutoModelForCausalLM 4 | import torch 5 | import os 6 | import json 7 | from tqdm import tqdm 8 | import shortuuid 9 | import ray 10 | 11 | 12 | PROMPT_DICT_ALPACA = { 13 | "prompt_input": ( 14 | "Below is an instruction that describes a task, paired with an input that provides further context. " 15 | "Write a response that appropriately completes the request.\n\n" 16 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 17 | ), 18 | "prompt_no_input": ( 19 | "Below is an instruction that describes a task. 
" 20 | "Write a response that appropriately completes the request.\n\n" 21 | "### Instruction:\n{instruction}\n\n### Response:" 22 | ), 23 | } 24 | 25 | 26 | def run_eval(model_path, model_id, input_file, output_file, num_gpus, decoding_args, prompt_type): 27 | # split question file into num_gpus files 28 | ques_jsons = [] 29 | with open(os.path.expanduser(input_file), "r") as ques_file: 30 | for line in ques_file: 31 | ques_jsons.append(line) 32 | 33 | chunk_size = len(ques_jsons) // num_gpus 34 | ans_handles = [] 35 | for i in range(0, len(ques_jsons), chunk_size): 36 | ans_handles.append(get_model_answers.remote(model_path, model_id, ques_jsons[i:i + chunk_size], decoding_args, prompt_type)) 37 | 38 | ans_jsons = [] 39 | for ans_handle in ans_handles: 40 | ans_jsons.extend(ray.get(ans_handle)) 41 | 42 | with open(os.path.expanduser(output_file), "w") as ans_file: 43 | for line in ans_jsons: 44 | ans_file.write(json.dumps(line) + "\n") 45 | 46 | 47 | @ray.remote(num_gpus=1) 48 | @torch.inference_mode() 49 | def get_model_answers(model_path, model_id, question_jsons, decoding_args, prompt_type): 50 | tokenizer = AutoTokenizer.from_pretrained(model_path) 51 | model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).cuda() 52 | ans_jsons = [] 53 | for i, line in enumerate(tqdm(question_jsons)): 54 | ques_json = json.loads(line) 55 | idx = ques_json["question_id"] 56 | qs = ques_json["question"] 57 | if prompt_type == "alpaca": 58 | prompt = PROMPT_DICT_ALPACA["prompt_no_input"].format_map({"instruction": qs}) 59 | else: 60 | print(f"{prompt_type} is not supported.") 61 | raise NotImplementedError 62 | inputs = tokenizer([prompt]) 63 | try: 64 | output_ids = model.generate( 65 | torch.as_tensor(inputs.input_ids).cuda(), 66 | **decoding_args) 67 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 68 | print("----------") 69 | print(f"prompt: {prompt}") 70 | print("----------") 71 | print(f"outputs: {outputs}") 72 | if prompt_type == "alpaca": 73 | try: 74 | start_index = outputs.index("### Response:") 75 | except: 76 | start_index = len(prompt) 77 | outputs = outputs[start_index:].strip().lstrip("### Response:").strip() 78 | else: 79 | raise NotImplementedError 80 | print("----------") 81 | print(f"prediction: {outputs}") 82 | except: 83 | outputs = "garbage" 84 | ans_id = shortuuid.uuid() 85 | ans_jsons.append({"question_id": idx, 86 | "question": qs, 87 | "std_answer": ques_json["std_answer"], 88 | "class": ques_json["class"], 89 | "answer_id": ans_id, 90 | "answer": outputs, 91 | "model_id": model_id, 92 | "metadata": decoding_args}) 93 | return ans_jsons 94 | 95 | 96 | if __name__ == "__main__": 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument("--model_path", type=str, required=True) 99 | parser.add_argument("--model_id", type=str, required=True) 100 | parser.add_argument("--question_file", type=str, default="") 101 | parser.add_argument("--answer_file", type=str, default="answer.jsonl") 102 | parser.add_argument("--num_gpus", type=int, default=1) 103 | parser.add_argument("--do_sample", action="store_true") 104 | parser.add_argument("--num_beams", type=int, default=1) 105 | parser.add_argument("--temperature", type=float, default=0.7) 106 | parser.add_argument("--max_new_tokens", type=int, default=512) 107 | parser.add_argument("--prompt_type", type=str, default="alpaca") 108 | args = parser.parse_args() 109 | print(args) 110 | ray.init() 111 | decoding_args = {"do_sample": args.do_sample, "num_beams": args.num_beams, 112 
| "temperature": args.temperature, "max_new_tokens": args.max_new_tokens} 113 | run_eval(args.model_path, args.model_id, args.question_file, args.answer_file, args.num_gpus, decoding_args, args.prompt_type) 114 | -------------------------------------------------------------------------------- /eval/auto_eval/math/math_equivalence.py: -------------------------------------------------------------------------------- 1 | def _fix_fracs(string): 2 | substrs = string.split("\\frac") 3 | new_str = substrs[0] 4 | if len(substrs) > 1: 5 | substrs = substrs[1:] 6 | for substr in substrs: 7 | new_str += "\\frac" 8 | if substr[0] == "{": 9 | new_str += substr 10 | else: 11 | try: 12 | assert len(substr) >= 2 13 | except: 14 | return string 15 | a = substr[0] 16 | b = substr[1] 17 | if b != "{": 18 | if len(substr) > 2: 19 | post_substr = substr[2:] 20 | new_str += "{" + a + "}{" + b + "}" + post_substr 21 | else: 22 | new_str += "{" + a + "}{" + b + "}" 23 | else: 24 | if len(substr) > 2: 25 | post_substr = substr[2:] 26 | new_str += "{" + a + "}" + b + post_substr 27 | else: 28 | new_str += "{" + a + "}" + b 29 | string = new_str 30 | return string 31 | 32 | 33 | def _fix_a_slash_b(string): 34 | if len(string.split("/")) != 2: 35 | return string 36 | a = string.split("/")[0] 37 | b = string.split("/")[1] 38 | try: 39 | a = int(a) 40 | b = int(b) 41 | assert string == "{}/{}".format(a, b) 42 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 43 | return new_string 44 | except: 45 | return string 46 | 47 | 48 | def _remove_right_units(string): 49 | # "\\text{ " only ever occurs (at least in the val set) when describing units 50 | if "\\text{ " in string: 51 | splits = string.split("\\text{ ") 52 | assert len(splits) == 2 53 | return splits[0] 54 | else: 55 | return string 56 | 57 | 58 | def _fix_sqrt(string): 59 | if "\\sqrt" not in string: 60 | return string 61 | splits = string.split("\\sqrt") 62 | new_string = splits[0] 63 | for split in splits[1:]: 64 | if split[0] != "{": 65 | a = split[0] 66 | new_substr = "\\sqrt{" + a + "}" + split[1:] 67 | else: 68 | new_substr = "\\sqrt" + split 69 | new_string += new_substr 70 | return new_string 71 | 72 | 73 | def _strip_string(string): 74 | # linebreaks 75 | string = string.replace("\n", "") 76 | # print(string) 77 | 78 | # remove inverse spaces 79 | string = string.replace("\\!", "") 80 | # print(string) 81 | 82 | # replace \\ with \ 83 | string = string.replace("\\\\", "\\") 84 | # print(string) 85 | 86 | # replace tfrac and dfrac with frac 87 | string = string.replace("tfrac", "frac") 88 | string = string.replace("dfrac", "frac") 89 | # print(string) 90 | 91 | # remove \left and \right 92 | string = string.replace("\\left", "") 93 | string = string.replace("\\right", "") 94 | # print(string) 95 | 96 | # Remove circ (degrees) 97 | string = string.replace("^{\\circ}", "") 98 | string = string.replace("^\\circ", "") 99 | 100 | # remove dollar signs 101 | string = string.replace("\\$", "") 102 | 103 | # remove units (on the right) 104 | string = _remove_right_units(string) 105 | 106 | # remove percentage 107 | string = string.replace("\\%", "") 108 | string = string.replace("\%", "") 109 | 110 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." 
is the start of the string 111 | string = string.replace(" .", " 0.") 112 | string = string.replace("{.", "{0.") 113 | # if empty, return empty string 114 | if len(string) == 0: 115 | return string 116 | if string[0] == ".": 117 | string = "0" + string 118 | 119 | # to consider: get rid of e.g. "k = " or "q = " at beginning 120 | if len(string.split("=")) == 2: 121 | if len(string.split("=")[0]) <= 2: 122 | string = string.split("=")[1] 123 | 124 | # fix sqrt3 --> sqrt{3} 125 | string = _fix_sqrt(string) 126 | 127 | # remove spaces 128 | string = string.replace(" ", "") 129 | 130 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 131 | string = _fix_fracs(string) 132 | 133 | # manually change 0.5 --> \frac{1}{2} 134 | if string == "0.5": 135 | string = "\\frac{1}{2}" 136 | 137 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 138 | string = _fix_a_slash_b(string) 139 | 140 | return string 141 | 142 | 143 | def is_equiv(str1, str2, verbose=False): 144 | if str1 is None and str2 is None: 145 | print("WARNING: Both None") 146 | return True 147 | if str1 is None or str2 is None: 148 | return False 149 | 150 | try: 151 | ss1 = _strip_string(str1) 152 | ss2 = _strip_string(str2) 153 | if verbose: 154 | print(ss1, ss2) 155 | return ss1 == ss2 156 | except: 157 | return str1 == str2 -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | from dataclasses import dataclass, field 4 | from typing import Optional, Dict, Sequence 5 | 6 | import torch 7 | import transformers 8 | from torch.utils.data import Dataset 9 | from transformers import Trainer 10 | 11 | from utils import jload 12 | 13 | 14 | IGNORE_INDEX = -100 15 | DEFAULT_PAD_TOKEN = "[PAD]" 16 | DEFAULT_EOS_TOKEN = "</s>" 17 | DEFAULT_BOS_TOKEN = "<s>" 18 | DEFAULT_UNK_TOKEN = "<unk>" 19 | PROMPT_DICT_ALPACA = { 20 | "prompt_input": ( 21 | "Below is an instruction that describes a task, paired with an input that provides further context. " 22 | "Write a response that appropriately completes the request.\n\n" 23 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 24 | ), 25 | "prompt_no_input": ( 26 | "Below is an instruction that describes a task. " 27 | "Write a response that appropriately completes the request.\n\n" 28 | "### Instruction:\n{instruction}\n\n### Response:" 29 | ), 30 | } 31 | 32 | 33 | @dataclass 34 | class ModelArguments: 35 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 36 | 37 | 38 | @dataclass 39 | class DataArguments: 40 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 41 | prompt_type: str = field(default="alpaca", metadata={"help": "Prompt type."}) 42 | 43 | 44 | @dataclass 45 | class TrainingArguments(transformers.TrainingArguments): 46 | cache_dir: Optional[str] = field(default=None) 47 | optim: str = field(default="adamw_torch") 48 | model_max_length: int = field( 49 | default=512, 50 | metadata={"help": "Maximum sequence length.
Sequences will be right padded (and possibly truncated)."}, 51 | ) 52 | 53 | 54 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 55 | """Collects the state dict and dump to disk.""" 56 | state_dict = trainer.model.state_dict() 57 | if trainer.args.should_save: 58 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} 59 | del state_dict 60 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 61 | 62 | 63 | def smart_tokenizer_and_embedding_resize( 64 | special_tokens_dict: Dict, 65 | tokenizer: transformers.PreTrainedTokenizer, 66 | model: transformers.PreTrainedModel, 67 | ): 68 | """Resize tokenizer and embedding. 69 | 70 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 71 | """ 72 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 73 | model.resize_token_embeddings(len(tokenizer)) 74 | 75 | if num_new_tokens > 0: 76 | input_embeddings = model.get_input_embeddings().weight.data 77 | output_embeddings = model.get_output_embeddings().weight.data 78 | 79 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 80 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 81 | 82 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 83 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 84 | 85 | 86 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: 87 | """Tokenize a list of strings.""" 88 | tokenized_list = [ 89 | tokenizer( 90 | text, 91 | return_tensors="pt", 92 | padding="longest", 93 | max_length=tokenizer.model_max_length, 94 | truncation=True, 95 | ) 96 | for text in strings 97 | ] 98 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] 99 | input_ids_lens = labels_lens = [ 100 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list 101 | ] 102 | return dict( 103 | input_ids=input_ids, 104 | labels=labels, 105 | input_ids_lens=input_ids_lens, 106 | labels_lens=labels_lens, 107 | ) 108 | 109 | 110 | def preprocess( 111 | sources: Sequence[str], 112 | targets: Sequence[str], 113 | tokenizer: transformers.PreTrainedTokenizer, 114 | ) -> Dict: 115 | """Preprocess the data by tokenizing.""" 116 | examples = [s + t for s, t in zip(sources, targets)] 117 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] 118 | input_ids = examples_tokenized["input_ids"] 119 | labels = copy.deepcopy(input_ids) 120 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): 121 | label[:source_len] = IGNORE_INDEX 122 | return dict(input_ids=input_ids, labels=labels) 123 | 124 | 125 | class SupervisedDataset(Dataset): 126 | """Dataset for supervised fine-tuning.""" 127 | 128 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, prompt_type: str): 129 | super(SupervisedDataset, self).__init__() 130 | logging.warning("Loading data...") 131 | list_data_dict = jload(data_path) 132 | 133 | logging.warning("Formatting inputs...") 134 | if prompt_type == "alpaca": 135 | prompt_input, prompt_no_input = PROMPT_DICT_ALPACA["prompt_input"], PROMPT_DICT_ALPACA["prompt_no_input"] 136 | else: 137 | logging.warning(f"{prompt_type} is not supported.") 138 | raise NotImplementedError 139 | sources = [ 140 | prompt_input.format_map(example) if example.get("input", "") != "" else 
prompt_no_input.format_map(example) 141 | for example in list_data_dict 142 | ] 143 | targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] 144 | 145 | logging.warning("Tokenizing inputs... This may take some time...") 146 | data_dict = preprocess(sources, targets, tokenizer) 147 | 148 | self.input_ids = data_dict["input_ids"] 149 | self.labels = data_dict["labels"] 150 | 151 | def __len__(self): 152 | return len(self.input_ids) 153 | 154 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 155 | return dict(input_ids=self.input_ids[i], labels=self.labels[i]) 156 | 157 | 158 | @dataclass 159 | class DataCollatorForSupervisedDataset(object): 160 | """Collate examples for supervised fine-tuning.""" 161 | 162 | tokenizer: transformers.PreTrainedTokenizer 163 | 164 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 165 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 166 | input_ids = torch.nn.utils.rnn.pad_sequence( 167 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 168 | ) 169 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 170 | return dict( 171 | input_ids=input_ids, 172 | labels=labels, 173 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 174 | ) 175 | 176 | 177 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: 178 | """Make dataset and collator for supervised fine-tuning.""" 179 | train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path, prompt_type=data_args.prompt_type) 180 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 181 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) 182 | 183 | 184 | def train(): 185 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 186 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 187 | 188 | model = transformers.AutoModelForCausalLM.from_pretrained( 189 | model_args.model_name_or_path, 190 | cache_dir=training_args.cache_dir, 191 | ) 192 | 193 | tokenizer = transformers.AutoTokenizer.from_pretrained( 194 | model_args.model_name_or_path, 195 | cache_dir=training_args.cache_dir, 196 | model_max_length=training_args.model_max_length, 197 | padding_side="right", 198 | use_fast=False, 199 | ) 200 | if tokenizer.pad_token is None: 201 | smart_tokenizer_and_embedding_resize( 202 | special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), 203 | tokenizer=tokenizer, 204 | model=model, 205 | ) 206 | assert "llama" in model_args.model_name_or_path 207 | if "llama" in model_args.model_name_or_path: 208 | tokenizer.add_special_tokens( 209 | { 210 | "eos_token": DEFAULT_EOS_TOKEN, 211 | "bos_token": DEFAULT_BOS_TOKEN, 212 | "unk_token": DEFAULT_UNK_TOKEN, 213 | } 214 | ) 215 | 216 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) 217 | 218 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) 219 | trainer.train() 220 | trainer.save_model() 221 | trainer.save_state() 222 | 223 | 224 | if __name__ == "__main__": 225 | train() -------------------------------------------------------------------------------- /eval/auto_eval.py: -------------------------------------------------------------------------------- 1 | """currently only support "math" domain""" 2 | import os 3 | import json 4 | import 
random 5 | 6 | import shortuuid 7 | import tqdm 8 | import numpy as np 9 | import argparse 10 | import sys 11 | 12 | sys.path.append("./auto_eval/math/") 13 | from math_equivalence import is_equiv 14 | 15 | 16 | class AutoEval(object): 17 | def __init__(self, data_file): 18 | self.data_file = data_file 19 | 20 | 21 | class AutoEvalMath(AutoEval): 22 | def __init__(self, data_file="./auto_eval/math/MATH", eval_set_name="MATH"): 23 | super().__init__(data_file) 24 | self.eval_set_name = eval_set_name 25 | 26 | def evaluate(self, question_file, answer_file, debugging=False, is_alpaca=False, verbose=False): 27 | """evaluate model performance on 'eval_set_name' dataset 28 | from: https://github.com/hendrycks/math/blob/357963a7f5501a6c1708cf3f3fb0cdf525642761/modeling/evaluate_gpt3.py 29 | """ 30 | if not debugging: 31 | answers = [] 32 | with open(answer_file, 'r') as f_in: 33 | for line in f_in.readlines(): 34 | answers.append(json.loads(line)) 35 | else: 36 | answers = answer_file # debugging only! answer_file now is a list 37 | questions = [] 38 | with open(question_file, 'r') as f_in: 39 | for line in f_in.readlines(): 40 | questions.append(json.loads(line)) 41 | if self.eval_set_name == "MATH": 42 | model_predictions = [] 43 | std_answers = [] 44 | types = [] 45 | levels = [] 46 | cors = {} 47 | subject_cors = {} 48 | level_cors = {} 49 | correct = 0 50 | total = 0 51 | correct_idx = [] 52 | for idx, answer in enumerate(answers): 53 | std_answer = answer["std_answer"] 54 | model_prediction = answer["answer"] 55 | prob_level = questions[idx]["level"] 56 | prob_type = questions[idx]["type"] 57 | 58 | def _last_boxed_only_string(string): 59 | idx = string.rfind("\\boxed") 60 | if idx < 0: 61 | idx = string.rfind("\\fbox") 62 | if idx < 0: 63 | return None 64 | 65 | i = idx 66 | right_brace_idx = None 67 | num_left_braces_open = 0 68 | while i < len(string): 69 | if string[i] == "{": 70 | num_left_braces_open += 1 71 | if string[i] == "}": 72 | num_left_braces_open -= 1 73 | if num_left_braces_open == 0: 74 | right_brace_idx = i 75 | break 76 | i += 1 77 | 78 | if right_brace_idx == None: 79 | retval = None 80 | else: 81 | retval = string[idx:right_brace_idx + 1] 82 | 83 | return retval 84 | 85 | def _remove_boxed(string): 86 | left = "\\boxed{" 87 | try: 88 | assert string[:len(left)] == left 89 | assert string[-1] == "}" 90 | return string[len(left):-1] 91 | except: 92 | return None 93 | 94 | def _extract_answer(string): 95 | """extract answer for std answer / model prediction 96 | """ 97 | return _remove_boxed(_last_boxed_only_string(string)) 98 | 99 | def _extract_alpaca_answer(string): 100 | """extract answer for alpaca model prediction (could not generate the accurate \\boxed{} format) 101 | """ 102 | idx1 = string.rfind("is") 103 | if idx1 > 0: 104 | string = string[idx1:].lstrip("is").strip().replace("$", "").replace(".", "").strip() 105 | idx2 = string.rfind("=") 106 | if idx2 > 0: 107 | string = string[idx2:].lstrip("=").strip().replace("$", "").replace(".", "").strip() 108 | return string 109 | 110 | std_answer_tokens = _extract_answer(std_answer) 111 | if is_alpaca is True: 112 | model_prediction_tokens = _extract_alpaca_answer(model_prediction) 113 | else: 114 | model_prediction_tokens = _extract_answer(model_prediction) 115 | if model_prediction_tokens is None: 116 | model_prediction_tokens = _extract_alpaca_answer(model_prediction) 117 | levels.append(prob_level) 118 | types.append(prob_type) 119 | std_answers.append(std_answer_tokens) 120 | 
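# Collect the extracted prediction, then score it against the reference answer with the
# MATH-style is_equiv() string check below; any exception during comparison counts as incorrect.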
model_predictions.append(model_prediction_tokens) 121 | try: 122 | equiv = is_equiv(model_prediction_tokens, std_answer_tokens) 123 | except: 124 | equiv = False 125 | if (prob_level, prob_type) in cors: 126 | cors[(prob_level, prob_type)].append(equiv) 127 | else: 128 | cors[(prob_level, prob_type)] = [equiv] 129 | if prob_level in level_cors: 130 | level_cors[prob_level].append(equiv) 131 | else: 132 | if prob_level is not None: 133 | level_cors[prob_level] = [equiv] 134 | if prob_type in subject_cors: 135 | subject_cors[prob_type].append(equiv) 136 | else: 137 | if prob_type is not None: 138 | subject_cors[prob_type] = [equiv] 139 | if equiv: 140 | correct += 1 141 | correct_idx.append(idx) 142 | total += 1 143 | if verbose is False: 144 | print(str(correct) + "/" + str(total)) 145 | if verbose is False: 146 | for subject in ['Prealgebra', 'Algebra', 'Number Theory', 'Counting & Probability', 'Geometry', 147 | 'Intermediate Algebra', 'Precalculus']: 148 | for level in range(1, 6): 149 | key = (level, subject) 150 | if key not in cors.keys(): 151 | print("Skipping", key) 152 | continue 153 | cors_list = cors[key] 154 | print("{} Level {} Accuracy = {}/{} = {:.3f}".format(subject, level, np.sum(cors_list), 155 | len(cors_list), np.mean(cors_list))) 156 | print("#####################") 157 | for level in sorted(level_cors): 158 | if level not in level_cors.keys(): 159 | print("Skipping", level) 160 | continue 161 | cors_list = level_cors[level] 162 | print("Level {} Accuracy = {}/{} = {:.3f}".format(level, np.sum(cors_list), len(cors_list), 163 | np.mean(cors_list))) 164 | print("#####################") 165 | for subject in ['Prealgebra', 'Algebra', 'Number Theory', 'Counting & Probability', 'Geometry', 166 | 'Intermediate Algebra', 'Precalculus']: 167 | if subject not in subject_cors.keys(): 168 | print("Skipping", subject) 169 | continue 170 | cors_list = subject_cors[subject] 171 | print("{} Accuracy = {}/{} = {:.3f}".format(subject, np.sum(cors_list), len(cors_list), 172 | np.mean(cors_list))) 173 | print("#####################") 174 | print("Overall Accuracy = {}/{} = {:.3f}".format(correct, total, correct / total)) 175 | print("#####################") 176 | print(correct_idx) 177 | else: 178 | print(f"{self.eval_set_name} is not supported.") 179 | raise NotImplementedError 180 | 181 | 182 | def auto_eval_math(question_file, answer_file, debugging, is_alpaca, verbose): 183 | evaluator = AutoEvalMath() 184 | evaluator.evaluate(question_file, answer_file, debugging=debugging, is_alpaca=is_alpaca, verbose=verbose) 185 | 186 | 187 | def main(): 188 | parser = argparse.ArgumentParser() 189 | parser.add_argument("--question_file", type=str, default="question.jsonl") 190 | parser.add_argument("--answer_file", type=str, default="answer.jsonl") 191 | parser.add_argument("--debugging", action="store_true") 192 | parser.add_argument("--is_alpaca", action="store_true") 193 | parser.add_argument("--verbose", action="store_true") 194 | args = parser.parse_args() 195 | print(args) 196 | auto_eval_math(args.question_file, args.answer_file, args.debugging, args.is_alpaca, args.verbose) 197 | 198 | 199 | if __name__ == "__main__": 200 | main() 201 | -------------------------------------------------------------------------------- /eval/chatgpt_score.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import time 3 | import json 4 | import os 5 | import tqdm 6 | import re 7 | import argparse 8 | import asyncio 9 | from typing import Any 10 | 11 | 
MAX_API_RETRY = 5 12 | openai.api_key = "YOUR OPENAI API KEY" 13 | 14 | 15 | # ---------------------------------- utils ------------------------------------------ 16 | def get_json_list(file_path): 17 | file_path = os.path.expanduser(file_path) 18 | with open(file_path, 'r') as f: 19 | json_list = [] 20 | for line in f: 21 | json_list.append(json.loads(line)) 22 | return json_list 23 | 24 | 25 | async def dispatch_openai_requests( 26 | messages_list: list[list[dict[str, Any]]], 27 | model: str, 28 | temperature: float, 29 | max_tokens: int, 30 | ) -> list[str]: 31 | """Dispatches requests to OpenAI API asynchronously. 32 | 33 | Args: 34 | messages_list: List of messages to be sent to OpenAI ChatCompletion API. 35 | model: OpenAI model to use. 36 | temperature: Temperature to use for the model. 37 | max_tokens: Maximum number of tokens to generate. 38 | Returns: 39 | List of responses from OpenAI API. 40 | """ 41 | async_responses = [ 42 | openai.ChatCompletion.acreate( 43 | model=model, 44 | messages=x, 45 | temperature=temperature, 46 | max_tokens=max_tokens, 47 | ) 48 | for x in messages_list 49 | ] 50 | return await asyncio.gather(*async_responses) 51 | 52 | 53 | def get_completion(messages_list: list, model: str, temperature: float = 0.0): 54 | for i in range(MAX_API_RETRY): 55 | try: 56 | completions = asyncio.run( 57 | dispatch_openai_requests( 58 | messages_list=messages_list, 59 | model=model, 60 | temperature=temperature, 61 | max_tokens=2048, 62 | ) 63 | ) 64 | return completions 65 | except Exception as e: 66 | print(e) 67 | time.sleep(20) 68 | print(f'Failed after {MAX_API_RETRY} retries.') 69 | raise RuntimeError 70 | 71 | 72 | def calculate_order(scores): 73 | result = dict() 74 | result["Assistant_1_Win"] = sum([_ == [1, 2] for _ in scores]) 75 | result["Assistant_1_Lose"] = sum([_ == [2, 1] for _ in scores]) 76 | result["Tie"] = sum([_ == [0, 0] for _ in scores]) 77 | return result 78 | 79 | 80 | # ---------------------------------- comparison ------------------------------------------ 81 | def get_prompt_compare(question, answer_1, answer_2, template_prompt, examples): 82 | prompt = "" 83 | if len(examples) > 0: 84 | for exp in examples: 85 | prompt += exp["input"] 86 | prompt += "\n" 87 | prompt += exp["output"] 88 | prompt += "\n" 89 | prompt += f"You are a helpful and precise assistant for checking the quality of the answer.\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{template_prompt}\n\n" 90 | return prompt 91 | 92 | 93 | def post_process_compare(review): 94 | review = review.replace(">=", ">") 95 | review = review.replace("> =", ">") 96 | review = review.replace("=>", ">") 97 | review = review.replace("= >", ">") 98 | review = review.replace("<=", "<") 99 | review = review.replace("< =", "<") 100 | review = review.replace("=<", "<") 101 | review = review.replace("= <", "<") 102 | review = review.replace(">>", ">") 103 | review = review.replace("> >", ">") 104 | review = review.replace("<<", "<") 105 | review = review.replace("< <", "<") 106 | review = review.replace("==", "=") 107 | review = review.replace("= =", "=") 108 | Assistant_1_win = ["Assistant 1 > Assistant 2", "Assistant 2 < Assistant 1", 109 | "[Assistant 1] > [Assistant 2]", "[Assistant 2] < [Assistant 1]"] 110 | for x in Assistant_1_win: 111 | if x in review: 112 | return [1, 2] 113 | Assistant_2_win = ["Assistant 1 < Assistant 2", "Assistant 2 > Assistant 1", 114 | "[Assistant 1] < [Assistant 2]", 
"[Assistant 2] > [Assistant 1]"] 115 | for x in Assistant_2_win: 116 | if x in review: 117 | return [2, 1] 118 | tie = ["Assistant 1 = Assistant 2", "[Assistant 1] = [Assistant 2]", 119 | "Assistant 2 = Assistant 1", "[Assistant 2] = [Assistant 1]"] 120 | for x in tie: 121 | if x in review: 122 | return [0, 0] 123 | print(f"Error for processing: {review}") 124 | return [0, 0] 125 | 126 | 127 | def get_compare(input_file_1, input_file_2, output_file, prompt_file, target_classes, use_demo=False, 128 | model="gpt-3.5-turbo", temperature=0.0, batch_size=1): 129 | prompt_templates = get_json_list(prompt_file) 130 | input_examples_1 = get_json_list(input_file_1) 131 | input_examples_2 = get_json_list(input_file_2) 132 | assert len(input_examples_1) == len(input_examples_2) 133 | review_examples = [] 134 | for i in range(len(input_examples_1)): 135 | if input_examples_1[i]["class"] in target_classes and \ 136 | (input_examples_1[i]["answer"] != "garbage" and input_examples_2[i]["answer"] != "garbage"): 137 | review_example = dict() 138 | review_example["question_id"] = input_examples_1[i]["question_id"] 139 | review_example["question"] = input_examples_1[i]["question"] 140 | review_example["std_answer"] = input_examples_1[i]["std_answer"] 141 | review_example["class"] = input_examples_1[i]["class"] 142 | 143 | review_example["answer_id_1"] = input_examples_1[i]["answer_id"] 144 | review_example["answer_1"] = input_examples_1[i]["answer"] 145 | review_example["model_id_1"] = input_examples_1[i]["model_id"] 146 | 147 | review_example["answer_id_2"] = input_examples_2[i]["answer_id"] 148 | review_example["answer_2"] = input_examples_2[i]["answer"] 149 | review_example["model_id_2"] = input_examples_2[i]["model_id"] 150 | 151 | review_example["metadata"] = input_examples_1[i]["metadata"] 152 | review_examples.append(review_example) 153 | if os.path.exists(output_file): 154 | curr_result = get_json_list(output_file) 155 | else: 156 | curr_result = [] 157 | for i in tqdm.tqdm(range(len(curr_result), len(review_examples), batch_size)): 158 | examples = review_examples[i: i + batch_size] 159 | prompt_template = [] 160 | demo_examples = [] 161 | messages_list = [] 162 | for example in examples: 163 | demo_examples.append([]) 164 | for x in prompt_templates: 165 | if x["class"] == example["class"]: 166 | prompt_template.append(x["prompt"]) 167 | if use_demo is True and x["demo_input_1"] != "": 168 | demo_examples[-1].append({"input": x["demo_input_1"], "output": x["demo_output_1"]}) 169 | if use_demo is True and x["demo_input_2"] != "": 170 | demo_examples[-1].append({"input": x["demo_input_2"], "output": x["demo_output_2"]}) 171 | break 172 | prompt = get_prompt_compare(example["question"], example["answer_1"], example["answer_2"], 173 | prompt_template[-1], demo_examples[-1]) 174 | messages_list.append([ 175 | {"role": "user", 176 | "content": prompt}, 177 | ]) 178 | assert len(messages_list) == len(prompt_template) 179 | completions = get_completion(messages_list, model, temperature) 180 | results = [completion['choices'][0]['message']['content'] for completion in completions] 181 | scores = [post_process_compare(result) for result in results] 182 | for idx, example in enumerate(examples): 183 | example["review_result"] = results[idx] 184 | example["review_score"] = scores[idx] 185 | with open(output_file, "a+") as fout: 186 | fout.write(json.dumps(example) + '\n') 187 | 188 | 189 | def get_statistic_for_compare(input_file): 190 | review_results = get_json_list(input_file) 191 | scores = dict() 192 | 
scores["all"] = [] 193 | for example in review_results: 194 | if example["class"] not in scores: 195 | scores[example["class"]] = [] 196 | scores[example["class"]].append(example["review_score"]) 197 | scores["all"].append(example["review_score"]) 198 | final_result = dict() 199 | for key, val in scores.items(): 200 | choice = calculate_order(val) 201 | final_result[key] = choice 202 | print(f"-----------------------{review_results[0]['model_id_1']}------------------------") 203 | print(f"-----------------------{review_results[0]['model_id_2']}------------------------") 204 | print(f"win:tie:lose={final_result['all']['Assistant_1_Lose']}:{final_result['all']['Tie']}:{final_result['all']['Assistant_1_Win']}") 205 | print("beat rate:{:.2f}".format(final_result['all']['Assistant_1_Lose'] / (final_result['all']['Assistant_1_Lose'] + final_result['all']['Assistant_1_Win']) * 100)) 206 | 207 | 208 | def main(): 209 | parser = argparse.ArgumentParser() 210 | parser.add_argument("--answer_file", type=str, default="answer.jsonl") 211 | parser.add_argument("--baseline_file", type=str, default="baseline.jsonl") 212 | parser.add_argument("--review_file", type=str, default="review.jsonl") 213 | parser.add_argument("--prompt_file", type=str, default="prompt.jsonl") 214 | parser.add_argument("--target_classes", type=str, default="rewrite") 215 | parser.add_argument("--use_demo", action="store_true") 216 | parser.add_argument("--review_model", type=str, default="gpt-3.5-turbo") 217 | parser.add_argument("--batch_size", type=int, default=1) 218 | args = parser.parse_args() 219 | print(args) 220 | target_classes = args.target_classes.split(",") 221 | get_compare(args.baseline_file, args.answer_file, args.review_file, args.prompt_file, target_classes, 222 | use_demo=args.use_demo, model=args.review_model, batch_size=args.batch_size) 223 | get_statistic_for_compare(args.review_file) 224 | 225 | 226 | if __name__ == "__main__": 227 | main() 228 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |

3 | 4 |
5 | 6 | Explore-Instruct: Enhancing Domain-Specific Instruction Coverage through Active Exploration 7 | ----------------------------- 8 | 9 | Version 10 | License 11 | Stars 12 | Issues 13 | 14 | 15 |

| 📑 Paper | 16 | 🤗 HuggingFace Repo | 17 | 🐱 GitHub Repo | 18 |

19 | 20 | 21 | 22 | _**Fanqi Wan, Xinting Huang, Tao Yang, Xiaojun Quan, Wei Bi, Shuming Shi**_ 23 | 24 | 25 | 26 | 27 | 28 | _ Sun Yat-sen University, 29 | Tencent AI Lab_ 30 | 31 |
32 | 33 | 34 | ## News 35 | - **Oct 16, 2023:** 🔥 We're excited to announce that the Explore-Instruct datasets in the brainstorming, rewriting, and math domains are now available on 🤗 [Huggingface Datasets](https://huggingface.co/datasets?sort=trending&search=Explore_Instruct)! We have also released Explore-LM models, which are initialized from LLaMA-7B and fine-tuned on the Explore-Instruct data for each domain. You can find these models on 🤗 [Huggingface Models](https://huggingface.co/models?sort=trending&search=Explore-LM). Happy exploring and instructing! 36 | 37 | ## Contents 38 | 39 | - [Overview](#overview) 40 | - [Data Release](#data-release) 41 | - [Model Release](#model-release) 42 | - [Data Generation Process](#data-generation-process) 43 | - [Fine-tuning](#fine-tuning) 44 | - [Evaluation](#evaluation) 45 | - [Limitations](#limitations) 46 | - [Citation](#citation) 47 | 48 | ## Overview 49 | 50 | We propose Explore-Instruct, a novel approach to enhancing domain-specific instruction coverage. We posit that the domain space is inherently structured like a tree, reminiscent of ontologies in cognitive science. Drawing on classical search algorithms and the power of LLMs, Explore-Instruct actively traverses the domain space to generate instruction-tuning data, **without** requiring a predefined tree structure. Specifically, Explore-Instruct employs two strategic operations, lookahead and backtracking exploration: 51 | 52 | - **Lookahead** delves into a multitude of potential fine-grained sub-tasks, thereby mapping out a complex network of tasks. 53 | 54 | - **Backtracking** seeks alternative branches to widen the search boundary, hence extending the domain spectrum (a minimal sketch of this exploration loop is given below). 55 | 56 |
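To make the procedure concrete, here is a minimal, self-contained Python sketch of the exploration loop, included for illustration only. It is not the project's implementation: `propose_subtasks` is a hypothetical stand-in for the LLM query that `generate_instruction.py` sends to OpenAI or Claude, and the breadth/depth handling is reduced to a plain breadth-first traversal.

```
from dataclasses import dataclass, field
from typing import List


@dataclass
class TaskNode:
    """One task in the domain tree."""
    name: str
    depth: int
    children: List["TaskNode"] = field(default_factory=list)


def propose_subtasks(task_name: str, breadth: int) -> List[str]:
    """Hypothetical stand-in for the LLM query that proposes fine-grained sub-tasks."""
    return [f"{task_name} / sub-task {i}" for i in range(breadth)]


def explore(root: TaskNode, max_depth: int, breadth: int) -> List[TaskNode]:
    """Breadth-first exploration of the domain tree.

    Lookahead: each task is expanded into candidate sub-tasks.
    Backtracking: the queue returns to shallower tasks so that alternative
    branches are widened before the tree grows deeper.
    """
    frontier = [root]
    discovered: List[TaskNode] = []
    while frontier:
        node = frontier.pop(0)
        if node.depth >= max_depth:
            continue
        for name in propose_subtasks(node.name, breadth):
            child = TaskNode(name=name, depth=node.depth + 1)
            node.children.append(child)
            discovered.append(child)
            frontier.append(child)
    return discovered


if __name__ == "__main__":
    subtasks = explore(TaskNode(name="rewriting", depth=0), max_depth=2, breadth=3)
    print(f"Explored {len(subtasks)} sub-tasks")  # 3 at depth 1 + 9 at depth 2 = 12
```

In the actual pipeline, the `--extend_nums` and `--max_depth` arguments of `generate_instruction.py` control the breadth and depth of this traversal, and the subsequent pruning and filtering steps remove near-duplicate tasks and instructions based on ROUGE-L overlap.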

57 |
58 |

59 | 60 | ## Data Release 61 | 62 | We release the Explore-Instruct data in the brainstorming, rewriting, and math domains on 🤗 [Huggingface Datasets](https://huggingface.co/datasets?sort=trending&search=Explore_Instruct). Each domain includes two versions of the dataset: a basic version and an extended version. The basic version contains 10k instruction-tuning examples per domain, while the extended versions contain 16k, 32k, and 64k examples for the brainstorming, rewriting, and math domains, respectively. Each dataset is a structured data file in JSON format, consisting of a list of dictionaries in which each dictionary contains the following fields: 63 | 64 | - `instruction`: `str`, describes the task the model should perform. 65 | - `input`: `str`, optional context or input for the task. 66 | - `output`: `str`, the ground-truth output text for the given instruction and input. 67 | 68 | An illustrative record is sketched below, followed by the results of our data-centric analysis: 69 | 70 |
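The record below is a hypothetical illustration of this schema; it is not drawn from the released datasets, and the wording is invented purely to show how the three fields fit together (a brainstorming-style task with an empty `input`).

```
{
    "instruction": "Suggest three themes for a small town's summer festival.",
    "input": "",
    "output": "1. A riverside food and music evening with local bands. 2. A makers' market featuring craft workshops. 3. An outdoor film night showcasing community-made short films."
}
```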

71 |
72 |

73 | 74 | | Method | Brainstorming Unique V-N pairs | Rewriting Unique V-N pairs | Math Unique V-N pairs | 75 | |:--------------------------------|:----------------------------------:|:------------------------------:|:-------------------------:| 76 | | _Domain-Specific Human-Curated_ | 2 | 8 | 3 | 77 | | _Domain-Aware Self-Instruct_ | 781 | 1715 | 451 | 78 | | Explore-Instruct | **790** | **2015** | **917** | 79 | 80 | ## Model Release 81 | 82 | We release the Explore-LM models in the brainstorming, rewriting, and math domains on 🤗 [Huggingface Models](https://huggingface.co/models?sort=trending&search=Explore-LM). Each domain includes two versions of models: a basic version and an extended version, each trained on the corresponding version of the dataset. 83 | 84 | The results of automatic and human evaluation in the three domains are shown as follows: 85 | 86 | - Automatic evaluation: 87 | 88 | | Automatic Comparison in the Brainstorming Domain | Win:Tie:Lose | Beat Rate | 89 | |:-------------------------------------------------|:------------:|:---------:| 90 | | Explore-LM vs Domain-Curated-LM | 194:1:13 | 93.72 | 91 | | Explore-LM-Ext vs Domain-Curated-LM | 196:1:11 | 94.69 | 92 | | Explore-LM vs Domain-Instruct-LM | 114:56:38 | 75.00 | 93 | | Explore-LM-Ext vs Domain-Instruct-LM | 122:55:31 | 79.74 | 94 | | Explore-LM vs ChatGPT | 52:71:85 | 37.96 | 95 | | Explore-LM-Ext vs ChatGPT | 83:69:56 | 59.71 | 96 | 97 | 98 | | Automatic Comparison in the Rewriting Domain | Win:Tie:Lose | Beat Rate | 99 | |:---------------------------------------------|:------------:|:---------:| 100 | | Explore-LM vs Domain-Curated-LM | 50:38:6 | 89.29 | 101 | | Explore-LM-Ext vs Domain-Curated-LM | 53:37:4 | 92.98 | 102 | | Explore-LM vs Domain-Instruct-LM | 34:49:11 | 75.56 | 103 | | Explore-LM-Ext vs Domain-Instruct-LM | 35:53:6 | 85.37 | 104 | | Explore-LM vs ChatGPT | 11:59:24 | 31.43 | 105 | | Explore-LM-Ext vs ChatGPT | 12:56:26 | 31.58 | 106 | 107 | 108 | | Automatic Comparison in the Math Domain | Accuracy Rate | 109 | |:----------------------------------------|:-------------:| 110 | | Domain-Curated-LM | 3.4 | 111 | | Domain-Instruct-LM | 4.0 | 112 | | Explore-LM | 6.8 | 113 | | Explore-LM-Ext | 8.4 | 114 | | ChatGPT | 34.8 | 115 | 116 | - Human evaluation: 117 | 118 |

119 |
120 |

121 | 122 | ## Data Generation Process 123 | 124 | To generate the domain-specific instruction-tuning data, please follow the following commands step by step: 125 | 126 | ### Domain Space Exploration 127 | ``` 128 | python3 generate_instruction.py \ 129 | --action extend \ 130 | --save_dir ./en_data/demo_domain \ # input dir include current domain tree for exploration 131 | --out_dir ./en_data/demo_domain_exploration \ # output dir of the explored new domain tree 132 | --lang \ # currently support 'en' 133 | --domain demo_domain \ # domain for exploration 134 | --extend_nums ,..., \ # exploration breadth at each depth 135 | --max_depth \ # exploration depth 136 | --assistant_name # currently support openai and claude 137 | ``` 138 | 139 | ### Instruction-Tuning Data Generation 140 | ``` 141 | python3 generate_instruction.py \ 142 | --action enrich \ 143 | --save_dir ./en_data/demo_domain_exploration \ # input dir include current domain tree for data generation 144 | --out_dir ./en_data/demo_domain_generation \ # output dir of the domain tree with generated data 145 | --lang \ # currently support 'en' 146 | --domain demo_domain \ # domain for exploration 147 | --enrich_nums ,..., \ # data number for task at each depth 148 | --enrich_batch_size \ # batch size for data generation 149 | --assistant_name # currently support openai and claude 150 | ``` 151 | 152 | ### Task Pruning 153 | ``` 154 | python3 generate_instruction.py \ 155 | --action prune \ 156 | --save_dir ./en_data/demo_domain_generation \ # input dir include current domain tree for task pruning 157 | --out_dir ./en_data/demo_domain_pruning \ # output dir of the domain tree with 'pruned_subtasks_name.json' file 158 | --lang \ # currently support 'en' 159 | --domain demo_domain \ # domain for exploration 160 | --pruned_file ./en_data/demo_domain_pruning/pruned_subtasks_name.json \ # file of pruned tasks 161 | --prune_threshold \ # threshold of rouge-l overlap between task names 162 | --assistant_name # currently support openai and claude 163 | ``` 164 | 165 | ### Data Filtering 166 | ``` 167 | python3 generate_instruction.py \ 168 | --action filter \ 169 | --save_dir ./en_data/demo_domain_pruning \ # input dir include current domain tree for data filtering 170 | --out_dir ./en_data/demo_domain_filtering \ # output dir of the domain tree with fitered data 171 | --lang \ # currently support 'en' 172 | --domain demo_domain \ # domain for exploration 173 | --pruned_file ./en_data/demo_domain_pruning/pruned_subtasks_name.json \ # file of pruned tasks 174 | --filter_threshold \ # threshold of rouge-l overlap between instructions 175 | --assistant_name # currently support openai and claude 176 | ``` 177 | 178 | ### Data Sampling 179 | ``` 180 | python3 generate_instruction.py \ 181 | --action sample \ 182 | --save_dir ./en_data/demo_domain_filtering \ # input dir include current domain tree for data sampling 183 | --out_dir ./en_data/demo_domain_sampling \ # output dir of the domain tree with sampled data 184 | --lang \ # currently support 'en' 185 | --domain demo_domain \ # domain for exploration 186 | --pruned_file ./en_data/demo_domain_filtering/pruned_subtasks_name.json \ # file of pruned tasks 187 | --sample_example_num \ # number of sampled examples 188 | --sample_max_depth \ # max depth for data sampling 189 | --sample_use_pruned \ # do not sample from pruned tasks 190 | --assistant_name # currently support openai and claude 191 | ``` 192 | 193 | ## Fine-tuning 194 | 195 | We fine-tune LLaMA-7B with the following hyperparameters: 
196 | 197 | | Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay | 198 | |:----------------|-------------------:|---------------:|--------:|-----------:|--------------:| 199 | | LLaMA 7B | 128 | 2e-5 | 3 | 512 | 0 | 200 | 201 | To reproduce the training procedure, please use the following command: 202 | 203 | ``` 204 | deepspeed --num_gpus=8 ./train/train.py \ 205 | --deepspeed ./deepspeed_config/deepspeed_zero3_offload_config.json \ 206 | --model_name_or_path decapoda-research/llama-7b-hf \ 207 | --data_path ./en_data/demo_domain_sampling \ 208 | --fp16 True \ 209 | --output_dir ./training_results/explore-lm-7b-demo-domain \ 210 | --num_train_epochs 3 \ 211 | --per_device_train_batch_size 2 \ 212 | --per_device_eval_batch_size 2 \ 213 | --gradient_accumulation_steps 8 \ 214 | --evaluation_strategy "no" \ 215 | --model_max_length 512 \ 216 | --save_strategy "steps" \ 217 | --save_steps 2000 \ 218 | --save_total_limit 1 \ 219 | --learning_rate 2e-5 \ 220 | --weight_decay 0. \ 221 | --warmup_ratio 0.03 \ 222 | --lr_scheduler_type "cosine" \ 223 | --logging_steps 1 \ 224 | --prompt_type alpaca \ 225 | 2>&1 | tee ./training_logs/explore-lm-7b-demo-domain.log 226 | 227 | python3 ./train/zero_to_fp32.py \ 228 | --checkpoint_dir ./training_results/explore-lm-7b-demo-domain \ 229 | --output_file ./training_results/explore-lm-7b-demo-domain/pytorch_model.bin 230 | ``` 231 | 232 | ## Evaluation 233 | 234 | The evaluation datasets for different domains are as follows: 235 | - Brainstorming and Rewriting: From the corresponding categories in the translated test set of BELLE. ([en_eval_set.jsonl](./eval/question/en_eval_set.jsonl)) 236 | - Math: From randomly selected 500 questions from the test set of MATH. ([MATH_eval_set_sample.jsonl](./eval/question/MATH_eval_set_sample.jsonl)) 237 | 238 | The evaluation metrics for different domains are as follows: 239 | - Brainstorming and Rewriting: Both automatic and human evaluations following Vicuna. 240 | - Math: Accuracy Rate metric in solving math problems. 241 | 242 | The automatic evaluation commands for different domains are as follows: 243 | 244 | ``` 245 | # Brainstorming and Rewriting Domain 246 | 247 | # 1. Inference 248 | python3 ./eval/generate.py \ 249 | --model_id \ 250 | --model_path \ 251 | --question_file ./eval/question/en_eval_set.jsonl \ 252 | --answer_file ./eval/answer/.jsonl \ 253 | --num_gpus 8 \ 254 | --num_beams 1 \ 255 | --temperature 0.7 \ 256 | --max_new_tokens 512 \ 257 | --prompt_type alpaca \ 258 | --do_sample 259 | 260 | # 2. Evaluation 261 | python3 ./eval/chatgpt_score.py \ 262 | --baseline_file ./eval/answer/.jsonl \ # answer of baseline model to compare with 263 | --answer_file ./eval/answer/.jsonl \ # answer of evaluation model 264 | --review_file ./eval/review/_cp__.jsonl \ # review from chatgpt 265 | --prompt_file ./eval/prompt/en_review_prompt_compare.jsonl \ # evaluation prompt for chatgpt 266 | --target_classes \ # evaluation domain 267 | --batch_size \ 268 | --review_model "gpt-3.5-turbo-0301" 269 | ``` 270 | 271 | ``` 272 | # Math Domain 273 | 274 | # 1. Inference 275 | python3 ./eval/generate.py \ 276 | --model_id \ 277 | --model_path \ 278 | --question_file ./eval/question/MATH_eval_set_sample.jsonl \ 279 | --answer_file ./eval/answer/.jsonl \ 280 | --num_gpus 8 \ 281 | --num_beams 10 \ 282 | --temperature 1.0 \ 283 | --max_new_tokens 512 \ 284 | --prompt_type alpaca 285 | 286 | # 2. 
Evaluation 287 | python3 ./eval/auto_eval.py \ 288 | --question_file ./eval/question/MATH_eval_set_sample.jsonl \ 289 | --answer_file ./eval/answer/.jsonl # answer of evaluation model 290 | ``` 291 | 292 | ## Limitations 293 | 294 | Explore-Instruct is still under development and needs a lot of improvements. We acknowledge that our work focuses on the enhancement of domain-specific instruction coverage and does not address other aspects of instruction-tuning, such as the generation of complex and challenging instructions or the mitigation of toxic and harmful instructions. Future work is needed to explore the potential of our approach in these areas. 295 | 296 | ## Citation 297 | 298 | If you find this work is relevant with your research or applications, please feel free to cite our work! 299 | ``` 300 | @inproceedings{wan2023explore, 301 | title={Explore-Instruct: Enhancing Domain-Specific Instruction Coverage through Active Exploration}, 302 | author={Wan, Fanqi and Huang, Xinting and Yang, Tao and Quan, Xiaojun and Bi, Wei and Shi, Shuming}, 303 | booktitle={Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing}, 304 | pages={9435--9454}, 305 | year={2023} 306 | } 307 | ``` -------------------------------------------------------------------------------- /train/zero_to_fp32.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | '''Copyright The Microsoft DeepSpeed Team''' 3 | 4 | # This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets 5 | # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in 6 | # the future. Once extracted, the weights don't require DeepSpeed and can be used in any 7 | # application. 8 | # 9 | # example: python zero_to_fp32.py . pytorch_model.bin 10 | 11 | import argparse 12 | import torch 13 | import glob 14 | import math 15 | import os 16 | import re 17 | from collections import OrderedDict 18 | 19 | # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with 20 | # DeepSpeed data structures it has to be available in the current python environment. 
21 | from deepspeed.utils import logger 22 | from deepspeed.checkpoint.constants import (DS_VERSION, 23 | OPTIMIZER_STATE_DICT, 24 | SINGLE_PARTITION_OF_FP32_GROUPS, 25 | FP32_FLAT_GROUPS, 26 | ZERO_STAGE, 27 | PARTITION_COUNT, 28 | PARAM_SHAPES, 29 | BUFFER_NAMES) 30 | 31 | debug = 0 32 | 33 | # load to cpu 34 | device = torch.device('cpu') 35 | 36 | 37 | def atoi(text): 38 | return int(text) if text.isdigit() else text 39 | 40 | 41 | def natural_keys(text): 42 | ''' 43 | alist.sort(key=natural_keys) sorts in human order 44 | http://nedbatchelder.com/blog/200712/human_sorting.html 45 | (See Toothy's implementation in the comments) 46 | ''' 47 | return [atoi(c) for c in re.split(r'(\d+)', text)] 48 | 49 | 50 | def get_model_state_file(checkpoint_dir, zero_stage): 51 | if not os.path.isdir(checkpoint_dir): 52 | raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") 53 | 54 | # there should be only one file 55 | if zero_stage == 2: 56 | file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") 57 | elif zero_stage == 3: 58 | file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") 59 | 60 | if not os.path.exists(file): 61 | raise FileNotFoundError(f"can't find model states file at '{file}'") 62 | 63 | return file 64 | 65 | 66 | def get_optim_files(checkpoint_dir): 67 | # XXX: need to test that this simple glob rule works for multi-node setup too 68 | optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, 69 | "*_optim_states.pt")), 70 | key=natural_keys) 71 | 72 | if len(optim_files) == 0: 73 | raise FileNotFoundError( 74 | f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") 75 | 76 | return optim_files 77 | 78 | 79 | def parse_model_state(file): 80 | state_dict = torch.load(file, map_location=device) 81 | 82 | if BUFFER_NAMES not in state_dict: 83 | raise ValueError(f"{file} is not a model state checkpoint") 84 | buffer_names = state_dict[BUFFER_NAMES] 85 | if debug: 86 | print("Found buffers:", buffer_names) 87 | 88 | # recover just the buffers while restoring them to fp32 if they were saved in fp16 89 | buffers = { 90 | k: v.float() 91 | for k, 92 | v in state_dict["module"].items() if k in buffer_names 93 | } 94 | param_shapes = state_dict[PARAM_SHAPES] 95 | 96 | ds_version = state_dict.get(DS_VERSION, None) 97 | 98 | return buffers, param_shapes, ds_version 99 | 100 | 101 | def parse_optim_states(files, ds_checkpoint_dir): 102 | 103 | total_files = len(files) 104 | state_dicts = [] 105 | for f in files: 106 | state_dicts.append(torch.load(f, map_location=device)) 107 | 108 | if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: 109 | raise ValueError(f"{files[0]} is not a zero checkpoint") 110 | zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] 111 | world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] 112 | 113 | # For ZeRO-2 each param group can have different partition_count as data parallelism for expert 114 | # parameters can be different from data parallelism for non-expert parameters. So we can just 115 | # use the max of the partition_count to get the dp world_size. 116 | 117 | if type(world_size) is list: 118 | world_size = max(world_size) 119 | 120 | if world_size != total_files: 121 | raise ValueError( 122 | f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " 123 | "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." 
124 | ) 125 | 126 | # the groups are named differently in each stage 127 | if zero_stage == 2: 128 | fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS 129 | elif zero_stage == 3: 130 | fp32_groups_key = FP32_FLAT_GROUPS 131 | else: 132 | raise ValueError(f"unknown zero stage {zero_stage}") 133 | 134 | if zero_stage == 2: 135 | fp32_flat_groups = [ 136 | state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] 137 | for i in range(len(state_dicts)) 138 | ] 139 | elif zero_stage == 3: 140 | # if there is more than one param group, there will be multiple flattened tensors - one 141 | # flattened tensor per group - for simplicity merge them into a single tensor 142 | # 143 | # XXX: could make the script more memory efficient for when there are multiple groups - it 144 | # will require matching the sub-lists of param_shapes for each param group flattened tensor 145 | 146 | fp32_flat_groups = [ 147 | torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 148 | 0) for i in range(len(state_dicts)) 149 | ] 150 | 151 | return zero_stage, world_size, fp32_flat_groups 152 | 153 | 154 | def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): 155 | """ 156 | Returns fp32 state_dict reconstructed from ds checkpoint 157 | 158 | Args: 159 | - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) 160 | 161 | """ 162 | print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") 163 | 164 | optim_files = get_optim_files(ds_checkpoint_dir) 165 | zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) 166 | print( 167 | f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") 168 | 169 | model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) 170 | buffers, param_shapes, ds_version = parse_model_state(model_file) 171 | print(f'Parsing checkpoint created by deepspeed=={ds_version}') 172 | 173 | if zero_stage == 2: 174 | return _get_fp32_state_dict_from_zero2_checkpoint(world_size, 175 | param_shapes, 176 | fp32_flat_groups, 177 | buffers) 178 | elif zero_stage == 3: 179 | return _get_fp32_state_dict_from_zero3_checkpoint(world_size, 180 | param_shapes, 181 | fp32_flat_groups, 182 | buffers) 183 | 184 | 185 | def _get_fp32_state_dict_from_zero2_checkpoint(world_size, 186 | param_shapes, 187 | fp32_flat_groups, 188 | buffers): 189 | 190 | # Reconstruction protocol: 191 | # 192 | # XXX: document this 193 | 194 | if debug: 195 | for i in range(world_size): 196 | for j in range(len(fp32_flat_groups[0])): 197 | print( 198 | f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") 199 | 200 | # XXX: memory usage doubles here (zero2) 201 | num_param_groups = len(fp32_flat_groups[0]) 202 | merged_single_partition_of_fp32_groups = [] 203 | for i in range(num_param_groups): 204 | merged_partitions = [sd[i] for sd in fp32_flat_groups] 205 | full_single_fp32_vector = torch.cat(merged_partitions, 0) 206 | merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) 207 | avail_numel = sum([ 208 | full_single_fp32_vector.numel() 209 | for full_single_fp32_vector in merged_single_partition_of_fp32_groups 210 | ]) 211 | 212 | if debug: 213 | wanted_params = sum([len(shapes) for shapes in param_shapes]) 214 | wanted_numel = sum( 215 | [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) 216 | # not asserting if there is a mismatch due to possible padding 217 | print(f"Have {avail_numel} numels to process.") 218 | print(f"Need {wanted_numel} numels in 
{wanted_params} params.") 219 | 220 | state_dict = OrderedDict() 221 | 222 | # buffers 223 | state_dict.update(buffers) 224 | if debug: 225 | print(f"added {len(buffers)} buffers") 226 | 227 | # params 228 | # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support 229 | # out-of-core computing solution 230 | total_numel = 0 231 | total_params = 0 232 | for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): 233 | offset = 0 234 | avail_numel = full_single_fp32_vector.numel() 235 | for name, shape in shapes.items(): 236 | 237 | unpartitioned_numel = shape.numel() 238 | total_numel += unpartitioned_numel 239 | total_params += 1 240 | 241 | if debug: 242 | print( 243 | f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} " 244 | ) 245 | state_dict[name] = full_single_fp32_vector.narrow( 246 | 0, 247 | offset, 248 | unpartitioned_numel).view(shape) 249 | offset += unpartitioned_numel 250 | 251 | # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and 252 | # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex 253 | # paddings performed in the code it's almost impossible to predict the exact numbers w/o the 254 | # live optimizer object, so we are checking that the numbers are within the right range 255 | align_to = 2 * world_size 256 | 257 | def zero2_align(x): 258 | return align_to * math.ceil(x / align_to) 259 | 260 | if debug: 261 | print(f"original offset={offset}, avail_numel={avail_numel}") 262 | 263 | offset = zero2_align(offset) 264 | avail_numel = zero2_align(avail_numel) 265 | 266 | if debug: 267 | print(f"aligned offset={offset}, avail_numel={avail_numel}") 268 | 269 | # Sanity check 270 | if offset != avail_numel: 271 | raise ValueError( 272 | f"consumed {offset} numels out of {avail_numel} - something is wrong") 273 | 274 | print( 275 | f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" 276 | ) 277 | 278 | return state_dict 279 | 280 | 281 | def zero3_partitioned_param_info(unpartitioned_numel, world_size): 282 | remainder = unpartitioned_numel % world_size 283 | padding_numel = (world_size - remainder) if remainder else 0 284 | partitioned_numel = math.ceil(unpartitioned_numel / world_size) 285 | return partitioned_numel, padding_numel 286 | 287 | 288 | def _get_fp32_state_dict_from_zero3_checkpoint(world_size, 289 | param_shapes, 290 | fp32_flat_groups, 291 | buffers): 292 | 293 | # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each 294 | # param, re-consolidating each param, while dealing with padding if any 295 | 296 | avail_numel = fp32_flat_groups[0].numel() * world_size 297 | # merge list of dicts, preserving order 298 | param_shapes = {k: v for d in param_shapes for k, v in d.items()} 299 | 300 | if debug: 301 | for i in range(world_size): 302 | print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") 303 | 304 | wanted_params = len(param_shapes) 305 | wanted_numel = sum(shape.numel() for shape in param_shapes.values()) 306 | # not asserting if there is a mismatch due to possible padding 307 | print(f"Have {avail_numel} numels to process.") 308 | print(f"Need {wanted_numel} numels in {wanted_params} params.") 309 | 310 | state_dict = OrderedDict() 311 | 312 | # buffers 313 | state_dict.update(buffers) 314 | if debug: 315 | print(f"added {len(buffers)} buffers") 316 | 317 | # params 318 | # XXX: for huge models that 
can't fit into the host's RAM we will have to recode this to support 319 | # out-of-core computing solution 320 | offset = 0 321 | total_numel = 0 322 | total_params = 0 323 | for name, shape in param_shapes.items(): 324 | 325 | unpartitioned_numel = shape.numel() 326 | total_numel += unpartitioned_numel 327 | total_params += 1 328 | 329 | partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) 330 | 331 | if debug: 332 | print( 333 | f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" 334 | ) 335 | 336 | # XXX: memory usage doubles here 337 | state_dict[name] = torch.cat( 338 | tuple(fp32_flat_groups[i].narrow(0, 339 | offset, 340 | partitioned_numel) 341 | for i in range(world_size)), 342 | 0).narrow(0, 343 | 0, 344 | unpartitioned_numel).view(shape) 345 | offset += partitioned_numel 346 | 347 | offset *= world_size 348 | 349 | # Sanity check 350 | if offset != avail_numel: 351 | raise ValueError( 352 | f"consumed {offset} numels out of {avail_numel} - something is wrong") 353 | 354 | print( 355 | f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" 356 | ) 357 | 358 | return state_dict 359 | 360 | 361 | def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): 362 | """ 363 | Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with 364 | ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example 365 | via a model hub. 366 | 367 | Args: 368 | - ``checkpoint_dir``: path to the desired checkpoint folder 369 | - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` 370 | 371 | Returns: 372 | - pytorch ``state_dict`` 373 | 374 | Note: this approach may not work if your application doesn't have sufficient free CPU memory and 375 | you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with 376 | the checkpoint. 377 | 378 | A typical usage might be :: 379 | 380 | from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint 381 | # do the training and checkpoint saving 382 | state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu 383 | model = model.cpu() # move to cpu 384 | model.load_state_dict(state_dict) 385 | # submit to model hub or save the model to share with others 386 | 387 | In this example the ``model`` will no longer be usable in the deepspeed context of the same 388 | application. i.e. you will need to re-initialize the deepspeed engine, since 389 | ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. 390 | 391 | If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. 
392 | 393 | """ 394 | if tag is None: 395 | latest_path = os.path.join(checkpoint_dir, 'latest') 396 | if os.path.isfile(latest_path): 397 | with open(latest_path, 'r') as fd: 398 | tag = fd.read().strip() 399 | else: 400 | raise ValueError(f"Unable to find 'latest' file at {latest_path}") 401 | 402 | ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) 403 | 404 | if not os.path.isdir(ds_checkpoint_dir): 405 | raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") 406 | 407 | return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) 408 | 409 | 410 | def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): 411 | """ 412 | Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be 413 | loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. 414 | 415 | Args: 416 | - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) 417 | - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) 418 | - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` 419 | """ 420 | 421 | state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) 422 | print(f"Saving fp32 state dict to {output_file}") 423 | torch.save(state_dict, output_file) 424 | 425 | 426 | def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): 427 | """ 428 | 1. Put the provided model to cpu 429 | 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` 430 | 3. Load it into the provided model 431 | 432 | Args: 433 | - ``model``: the model object to update 434 | - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) 435 | - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` 436 | 437 | Returns: 438 | - ``model`: modified model 439 | 440 | Make sure you have plenty of CPU memory available before you call this function. If you don't 441 | have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it 442 | conveniently placed for you in the checkpoint folder. 443 | 444 | A typical usage might be :: 445 | 446 | from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint 447 | model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) 448 | # submit to model hub or save the model to share with others 449 | 450 | Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context 451 | of the same application. i.e. you will need to re-initialize the deepspeed engine, since 452 | ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. 
453 | 454 | """ 455 | logger.info(f"Extracting fp32 weights") 456 | state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) 457 | 458 | logger.info(f"Overwriting model with fp32 weights") 459 | model = model.cpu() 460 | model.load_state_dict(state_dict, strict=False) 461 | 462 | return model 463 | 464 | 465 | if __name__ == "__main__": 466 | 467 | parser = argparse.ArgumentParser() 468 | parser.add_argument( 469 | "--checkpoint_dir", 470 | type=str, 471 | help="path to the desired checkpoint folder, e.g., path/checkpoint-12") 472 | parser.add_argument( 473 | "--output_file", 474 | type=str, 475 | help= 476 | "path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)" 477 | ) 478 | parser.add_argument("-d", "--debug", action='store_true', help="enable debug") 479 | args = parser.parse_args() 480 | 481 | debug = args.debug 482 | 483 | convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file) 484 | -------------------------------------------------------------------------------- /generate_instruction.py: -------------------------------------------------------------------------------- 1 | """ 2 | generate domain-specific instruction-tuning data 3 | """ 4 | import copy 5 | import shutil 6 | import time 7 | import json 8 | import os 9 | import random 10 | import re 11 | import string 12 | import logging 13 | from functools import partial 14 | from multiprocessing import Pool 15 | from typing import List, Tuple, Dict, Union, Optional, Any 16 | import argparse 17 | 18 | import numpy as np 19 | import tqdm 20 | from rouge_score import rouge_scorer 21 | # import utils 22 | import openai 23 | import fire 24 | from transformers import GPT2TokenizerFast, AutoTokenizer, BertTokenizer 25 | from anytree import AnyNode, Node, PreOrderIter, LevelOrderIter 26 | from anytree.importer import DictImporter, JsonImporter 27 | from anytree.exporter import JsonExporter, DictExporter 28 | import asyncio 29 | 30 | level = logging.DEBUG 31 | # level = logger.INFO 32 | logger = logging.getLogger(__name__) 33 | logger.setLevel(level) 34 | c_handler = logging.StreamHandler() 35 | c_handler.setLevel(level) 36 | c_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 37 | c_handler.setFormatter(c_format) 38 | 39 | # Add handlers to the logger 40 | logger.addHandler(c_handler) 41 | 42 | meta_info_filename = "tree_meta_info.json" 43 | 44 | openai.api_key = "YOUR OPENAI API KEY" 45 | claude_key = "YOUR CLAUDE API KEY" 46 | 47 | 48 | class DomainTree(object): 49 | """docstring for DomainTree""" 50 | 51 | def __init__( 52 | self, 53 | root: AnyNode, 54 | unique_notd_id: str = "task_name", 55 | name_to_node: Dict[str, AnyNode] = None, 56 | **kwargs 57 | ): 58 | self.root = root 59 | self.unique_id = unique_notd_id 60 | self.name_to_node = name_to_node if name_to_node is not None \ 61 | else {getattr(node, self.unique_id): node for node in PreOrderIter(self.root)} 62 | 63 | # TODO: load from config file instead of hard-coded 64 | self.prepare_mine_hparams(**kwargs) 65 | self.prepare_prompt() 66 | self.prepare_tools() 67 | 68 | def extend_node_children( 69 | self, 70 | node_to_extend: Union[AnyNode, str], 71 | extend_num: int = None, 72 | extend_batch_size: int = None, 73 | ) -> List[AnyNode]: 74 | all_new_nodes = list() 75 | while len(all_new_nodes) < extend_num: 76 | gap = extend_num - len(all_new_nodes) 77 | logger.info(f"Already extended num: {len(all_new_nodes)}. 
This time gap: {gap}") 78 | new_nodes: List[AnyNode] = self._extend_node_children( 79 | node_to_extend, 80 | extend_num=max(gap, self.min_extend_num), 81 | extend_batch_size=extend_batch_size, 82 | ) 83 | all_new_nodes += new_nodes 84 | return all_new_nodes 85 | 86 | def _extend_node_children( 87 | self, 88 | node_to_extend: Union[AnyNode, str], 89 | extend_num: int = None, 90 | extend_batch_size: int = None, 91 | ) -> List[AnyNode]: 92 | """ 93 | Extending the given node's children 94 | """ 95 | extend_num = self.default_extend_num if extend_num is None else extend_num 96 | extend_batch_size = self.default_extend_batch_size if extend_batch_size is None else extend_batch_size 97 | # locate the node 98 | if type(node_to_extend) is str: 99 | node_to_extend: AnyNode = self.name_to_node[node_to_extend] 100 | node_name = getattr(node_to_extend, self.unique_id) 101 | current_scenario = f"extend_node_children for {node_name}" 102 | logger.info(current_scenario) 103 | # check its existing children 104 | existing_children = list(node_to_extend.children) 105 | existing_children_names = [getattr(child, self.unique_id) for child in node_to_extend.children] 106 | if len(existing_children) >= self.max_child_num: 107 | logger.warning(f"Failed trial to extend node {node_name}: already having {len(existing_children)} children") 108 | return [] 109 | if len(existing_children) + extend_num > self.max_child_num: 110 | logger.warning(f"Exceeding max_child_num if extending by {extend_num}, using remaining num instead.") 111 | extend_num = self.max_child_num - len(existing_children) 112 | # formulate the prompt and api request 113 | base_prompt = self.extend_node_prompt 114 | existing_siblings = list(node_to_extend.siblings) 115 | if extend_batch_size > 1: 116 | prompt = [] 117 | for _ in range(extend_batch_size): 118 | demonstrate_examples = self.get_demonstrate_examples(node_to_extend) 119 | prompt_tmp = self.encode_prompt( 120 | base_prompt=base_prompt, 121 | demonstrate_examples=demonstrate_examples, # Example triplet 122 | target_task=node_name, # name of node_to_extend 123 | existing_children=existing_children, # list of children of node_to_extend 124 | existing_siblings=existing_siblings, # list of siblings of node_to_extend 125 | num_examples_per_time=self.num_example_extend, # num of triplet for each new mined subtask 126 | extend_num=extend_num, 127 | # new_subtask=new_subtask, # only used when enriching nodes 128 | # new_subtask_reason=new_subtask_reason, # only used when enriching nodes 129 | target_children_num=self.max_child_num, 130 | ) 131 | prompt.append(prompt_tmp) 132 | else: 133 | demonstrate_examples = self.get_demonstrate_examples(node_to_extend) 134 | prompt = self.encode_prompt( 135 | base_prompt=base_prompt, 136 | demonstrate_examples=demonstrate_examples, # Example triplet 137 | target_task=node_name, # name of node_to_extend 138 | existing_children=existing_children, # list of children of node_to_extend 139 | existing_siblings=existing_siblings, # list of siblings of node_to_extend 140 | num_examples_per_time=self.num_example_extend, # num of triplet for each new mined subtask 141 | extend_num=extend_num, 142 | # new_subtask=new_subtask, # only used when enriching nodes 143 | # new_subtask_reason=new_subtask_reason, # only used when enriching nodes 144 | target_children_num=self.max_child_num, 145 | ) 146 | logger.debug(f"Num of existing_children: {len(existing_children)}") 147 | logger.debug(f"Num of existing_siblings: {len(existing_siblings)}") 148 | logger.debug(f"Final prompt: 
{prompt}") 149 | 150 | logger.info(f"{self.assistant_name}: Online querying...") 151 | result = self.request_func(prompt) 152 | logger.info("Received online querying results") 153 | logger.debug(f"raw request result: {result}") 154 | if result is None: 155 | logger.warning(f"Received error response!!") 156 | return [] 157 | if extend_batch_size > 1: 158 | for p, r in zip(prompt, result): 159 | self.write_query_log(p, r) 160 | else: 161 | self.write_query_log(prompt, result) 162 | 163 | # new_subtask, new_subtask_reason, new_instructions = \ 164 | # self.post_process_gpt3_response_extend(result) 165 | if extend_batch_size > 1: 166 | new_subtask, new_subtask_reason, new_instructions = [], [], [] 167 | for r in result: 168 | new_subtask_tmp, new_subtask_reason_tmp, new_instructions_tmp = self.postprocess_extend(r) 169 | new_subtask += new_subtask_tmp 170 | new_subtask_reason += new_subtask_reason_tmp 171 | new_instructions += new_instructions_tmp 172 | else: 173 | new_subtask, new_subtask_reason, new_instructions = \ 174 | self.postprocess_extend(result) 175 | logger.debug(f"len(new_subtask): {len(new_subtask)}; new_subtask: {new_subtask};") 176 | 177 | add_nodes = list() 178 | if not (len(new_subtask) == len(new_subtask_reason) == len(new_instructions)): 179 | logger.warning(f"{current_scenario}, encountering bad completion result") 180 | if extend_batch_size > 1: 181 | for p, r in zip(prompt, result): 182 | self.handle_bad_completion(r, 183 | prompt=p, 184 | request_scenario=current_scenario, ) 185 | else: 186 | self.handle_bad_completion( 187 | result, 188 | prompt=prompt, 189 | request_scenario=current_scenario, 190 | ) 191 | return add_nodes 192 | 193 | # save file and update tree 194 | for subtask, reason, examples in zip(new_subtask, new_subtask_reason, new_instructions): 195 | subtask_id = self.formalize_taskname(subtask) 196 | # node_save_file = os.path.join(self.general_outdir, f"{subtask_id}.json") 197 | node_info = { 198 | self.unique_id: subtask_id, 199 | "raw_task_name": subtask, 200 | "parent": node_to_extend, 201 | "reason": reason, 202 | "examples": examples, 203 | # "config_file": node_save_file, 204 | "config_filename": f"{subtask_id}.json", 205 | 206 | } 207 | new_node = self.add_node(node_info) 208 | add_nodes.append(new_node) 209 | return add_nodes 210 | 211 | def write_query_log(self, prompt: str, res: Dict[str, str]): 212 | with open(os.path.join(self.general_outdir, "query_log.jsonl"), "a+", encoding="utf-8") as f_out: 213 | query_log = { 214 | "prompt": prompt, 215 | "res": res, 216 | } 217 | json.dump( 218 | query_log, 219 | f_out, 220 | ensure_ascii=False, 221 | ) 222 | f_out.write("\n") 223 | 224 | def enrich_node_samples( 225 | self, 226 | node_to_enrich: AnyNode, 227 | enrich_num: int = None, 228 | enrich_batch_size: int = None, 229 | ) -> List[Dict[str, str]]: 230 | all_new_examples = list() 231 | while len(all_new_examples) < enrich_num: 232 | gap = enrich_num - len(all_new_examples) 233 | logger.info(f"Already enriched example num: {len(all_new_examples)}. 
This time gap: {gap}") 234 | new_examples: List[Dict[str, str]] = self._enrich_node_samples( 235 | node_to_enrich, 236 | enrich_num=max(gap, self.min_extend_num), 237 | enrich_batch_size=enrich_batch_size 238 | ) 239 | all_new_examples += new_examples 240 | return all_new_examples 241 | 242 | def _enrich_node_samples( 243 | self, 244 | node_to_enrich: AnyNode, 245 | enrich_num: int = None, 246 | enrich_batch_size: int = None, 247 | ) -> List[Dict[str, str]]: 248 | """ 249 | Given a mined node, add more belonging to this node 250 | """ 251 | enrich_num = self.default_enrich_num if enrich_num is None else enrich_num 252 | enrich_batch_size = self.default_enrich_batch_size if enrich_batch_size is None else enrich_batch_size 253 | # get existing num of example 254 | if type(node_to_enrich) is str: 255 | node_to_enrich: AnyNode = self.name_to_node[node_to_enrich] 256 | new_subtask = node_name = getattr(node_to_enrich, self.unique_id) 257 | 258 | current_scenario = f"enrich_node_samples for {node_name}" 259 | logger.info(current_scenario) 260 | 261 | new_subtask_reason = getattr(node_to_enrich, "reason", "") 262 | # hard-coded to handle enriching root node 263 | parent_node = node_to_enrich.parent if node_to_enrich.parent is not None else node_to_enrich 264 | 265 | parent_node_name = getattr(parent_node, self.unique_id) 266 | existing_examples = getattr(node_to_enrich, "examples", []) 267 | if len(existing_examples) >= self.max_example_num: 268 | logger.warning(f"Failed trial to enrich node {node_name}: already having {len(existing_examples)} examples") 269 | return 270 | if len(existing_examples) + enrich_num > self.max_example_num: 271 | logger.warning(f"Exceeding max_child_num if extending by {enrich_num}, using remaining num instead.") 272 | enrich_num = self.max_example_num - len(existing_examples) 273 | base_prompt = self.enrich_node_prompt 274 | new_instructions = list() 275 | existing_children = list(parent_node.children) 276 | existing_siblings = list(parent_node.siblings) 277 | if enrich_batch_size > 1: 278 | prompt = [] 279 | for _ in range(enrich_batch_size): 280 | demonstrate_examples = self.get_demonstrate_examples(node_to_enrich) 281 | # only sampling from the seed tasks 282 | prompt_tmp = self.encode_prompt( 283 | base_prompt=base_prompt, 284 | demonstrate_examples=demonstrate_examples, # Example triplet 285 | target_task=parent_node_name, # name of node_to_extend 286 | existing_children=existing_children, # list of children of node_to_extend, 287 | existing_siblings=existing_siblings, # list of siblings of node_to_extend 288 | num_examples_per_time=self.num_example_enrich, # num of triplet for each new mined subtask 289 | new_subtask=new_subtask, # only used when enriching nodes 290 | new_subtask_reason=new_subtask_reason, # only used when enriching nodes 291 | # extend_num=extend_num, # only used when extending nodes 292 | # target_children_num=self.max_child_num, # only used when extending nodes 293 | ) 294 | prompt.append(prompt_tmp) 295 | else: 296 | demonstrate_examples = self.get_demonstrate_examples(node_to_enrich) 297 | # only sampling from the seed tasks 298 | prompt = self.encode_prompt( 299 | base_prompt=base_prompt, 300 | demonstrate_examples=demonstrate_examples, # Example triplet 301 | target_task=parent_node_name, # name of node_to_extend 302 | existing_children=existing_children, # list of children of node_to_extend, 303 | existing_siblings=existing_siblings, # list of siblings of node_to_extend 304 | num_examples_per_time=self.num_example_enrich, # num of 
triplet for each new mined subtask 305 | new_subtask=new_subtask, # only used when enriching nodes 306 | new_subtask_reason=new_subtask_reason, # only used when enriching nodes 307 | # extend_num=extend_num, # only used when extending nodes 308 | # target_children_num=self.max_child_num, # only used when extending nodes 309 | ) 310 | logger.debug(f"Final prompt: {prompt}") 311 | logger.info(f"{self.assistant_name}: Online querying...") 312 | # completion = self.request_openai(prompt) 313 | # result = completion['choices'][0] 314 | result = self.request_func(prompt) 315 | logger.info("Received online querying results") 316 | if enrich_batch_size > 1: 317 | for p, r in zip(prompt, result): 318 | self.write_query_log(p, r) 319 | else: 320 | self.write_query_log(prompt, result) 321 | logger.debug(f"raw request result: {result}") 322 | # _, _, new_instructions = \ 323 | # self.post_process_gpt3_response_enrich(result, new_subtask, new_subtask_reason) 324 | if enrich_batch_size > 1: 325 | new_instructions = [] 326 | for r in result: 327 | _, _, new_instructions_tmp = self.postprocess_enrich(r, new_subtask, new_subtask_reason) 328 | new_instructions += new_instructions_tmp 329 | else: 330 | _, _, new_instructions = \ 331 | self.postprocess_enrich(result, new_subtask, new_subtask_reason) 332 | 333 | # update node config 334 | node_to_enrich.examples = existing_examples + new_instructions 335 | self.update_file(node_to_enrich, ) 336 | return new_instructions 337 | 338 | def request_openai_single(self, prompt: str) -> Dict[str, str]: 339 | response = None 340 | for trial_idx in range(self.max_retry_times): 341 | try: 342 | messages = [ 343 | {"role": "user", 344 | "content": prompt}, 345 | ] 346 | prompt_len = len(self.tokenizer(prompt)['input_ids']) 347 | completion = openai.ChatCompletion.create( 348 | messages=messages, 349 | max_tokens=4096 - 300 - prompt_len, 350 | **self.openai_kwargs, 351 | ) 352 | result = completion['choices'][0] 353 | response = { 354 | "raw_response": result.get("message", {}).get("content", ""), 355 | "stop_reason": result.get("finish_reason", ), 356 | } 357 | return response 358 | except Exception as e: 359 | logger.warning(str(e)) 360 | logger.warning(f"Trail No. {trial_idx + 1} Failed, now sleep and retrying...") 361 | time.sleep(self.request_sleep_time) 362 | return response 363 | 364 | async def dispatch_openai_requests( 365 | self, 366 | messages_list: List[List[Dict[str, Any]]], 367 | model: str, 368 | temperature: float, 369 | max_tokens: int, 370 | top_p: float, 371 | n: int, 372 | logit_bias: dict, 373 | ) -> List[str]: 374 | """Dispatches requests to OpenAI API asynchronously. 375 | 376 | Args: 377 | messages_list: List of messages to be sent to OpenAI ChatCompletion API. 378 | model: OpenAI model to use. 379 | temperature: Temperature to use for the model. 380 | max_tokens: Maximum number of tokens to generate. 381 | top_p: Top p to use for the model. 382 | n: Return sentence nums. 383 | logit_bias: logit bias. 384 | Returns: 385 | List of responses from OpenAI API. 
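        A minimal usage sketch (the literal values are illustrative; in this class the request
        parameters normally come from ``self.openai_kwargs``, see ``prepare_tools``)::

            messages_list = [[{"role": "user", "content": p}] for p in prompts]
            completions = asyncio.run(
                self.dispatch_openai_requests(
                    messages_list=messages_list,
                    model="gpt-3.5-turbo",
                    temperature=1.0,
                    max_tokens=1024,
                    top_p=1.0,
                    n=1,
                    logit_bias={"50256": -100},
                )
            )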
386 | """ 387 | async_responses = [ 388 | openai.ChatCompletion.acreate( 389 | model=model, 390 | messages=x, 391 | temperature=temperature, 392 | max_tokens=max_tokens, 393 | top_p=top_p, 394 | n=n, 395 | logit_bias=logit_bias, 396 | ) 397 | for x in messages_list 398 | ] 399 | return await asyncio.gather(*async_responses) 400 | 401 | def request_openai_dispatch(self, prompt_list: List[str]) -> List[Dict[str, str]]: 402 | responses = None 403 | for trial_idx in range(self.max_retry_times): 404 | try: 405 | messages_list = [[ 406 | {"role": "user", 407 | "content": prompt}, 408 | ] for prompt in prompt_list] 409 | prompt_len = max([len(self.tokenizer(prompt)['input_ids']) for prompt in prompt_list]) 410 | completions = asyncio.run( 411 | self.dispatch_openai_requests( 412 | messages_list=messages_list, 413 | max_tokens=4096 - 300 - prompt_len, 414 | **self.openai_kwargs, 415 | ) 416 | ) 417 | responses = [] 418 | for completion in completions: 419 | result = completion['choices'][0] 420 | response = { 421 | "raw_response": result.get("message", {}).get("content", ""), 422 | "stop_reason": result.get("finish_reason", ), 423 | } 424 | responses.append(response) 425 | return responses 426 | except Exception as e: 427 | logger.warning(str(e)) 428 | logger.warning(f"Trail No. {trial_idx + 1} Failed, now sleep and retrying...") 429 | time.sleep(self.request_sleep_time) 430 | return responses 431 | 432 | def request_openai(self, prompt): 433 | use_async = not type(prompt) is str 434 | if use_async: 435 | response = self.request_openai_dispatch(prompt) 436 | else: 437 | response = self.request_openai_single(prompt) 438 | return response 439 | 440 | def request_claude(self, prompt: str) -> Dict[str, str]: 441 | import anthropic 442 | result = None 443 | for trial_idx in range(self.max_retry_times): 444 | try: 445 | prompt = f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}" 446 | prompt_len = len(self.tokenizer(prompt)['input_ids']) 447 | resp = self.claude_client.completion( 448 | prompt=prompt, 449 | max_tokens_to_sample=4096 - 300 - prompt_len, 450 | **self.claude_kwargs, 451 | ) 452 | result = { 453 | "raw_response": resp["completion"], 454 | "stop_reason": resp["stop_reason"], 455 | } 456 | return result 457 | except Exception as e: 458 | logger.warning(str(e)) 459 | logger.warning(f"Trail No. 
{trial_idx + 1} Failed, now sleep and retrying...") 460 | time.sleep(self.request_sleep_time) 461 | return result 462 | 463 | def prepare_mine_hparams(self, **kwargs): 464 | """ 465 | Currently all hard-coded 466 | """ 467 | 468 | self.assistant_name = kwargs.get("assistant_name") if kwargs.get("assistant_name") else "openai" 469 | self.assistant_request_map = { 470 | "openai": self.request_openai, 471 | "claude": self.request_claude, 472 | } 473 | self.assistant_postprocess_map = { 474 | "openai": (self.post_process_gpt3_response_extend, self.post_process_gpt3_response_enrich), 475 | "claude": (self.post_process_gpt3_response_extend, self.post_process_gpt3_response_enrich), 476 | } 477 | self.postprocess_extend, self.postprocess_enrich = self.assistant_postprocess_map[self.assistant_name] 478 | self.request_func = self.assistant_request_map[self.assistant_name] 479 | self.default_demonstrate_num = 2 # num of demonstrate example, during extending & enriching 480 | self.num_example_extend = 2 # triplet per new subtask, during extending 481 | self.default_extend_num = 5 # total new subtask num per request, during extending 482 | self.default_enrich_num = 50 # cumulative enrich num, during enriching 483 | self.num_example_enrich = 10 # num example per request, during enriching 484 | self.max_child_num = 12 # maximum children num 485 | self.max_example_num = 60000 # maximum example num per node 486 | self.min_extend_num = 3 # minimum request subtask num, during extending 487 | self.min_enrich_num = 5 # minimum request example num, during enriching 488 | self.default_extend_batch_size = 1 # prompt batch size, during extending 489 | self.default_enrich_batch_size = 1 # prompt batch size, during enriching 490 | 491 | self.max_retry_times = 20 # max retry times for online querying 492 | self.request_sleep_time = 20 # sleep time after online query fail 493 | self.attribute_of_interest = [self.unique_id, "reason", "examples", "config_filename", "raw_task_name"] 494 | self.attribute_meta_save = [self.unique_id, "reason", "config_filename", "raw_task_name"] 495 | self.general_outdir = getattr(self.root, "general_outdir", 496 | f"./mined_data_{getattr(self.root, self.unique_id)}") 497 | os.makedirs(self.general_outdir, exist_ok=True) 498 | self.bad_res_outfile = os.path.join(self.general_outdir, "bad_completions.jsonl") 499 | 500 | def gather_all_examples(self, ) -> Dict[str, List[Dict[str, str]]]: 501 | node_examples_map, total_count = dict(), 0 502 | for node in PreOrderIter(self.root): 503 | node_exampels = getattr(node, "examples", []) 504 | node_examples_map[getattr(node, self.unique_id)] = node_exampels 505 | total_count += len(node_exampels) 506 | node_examples_map["total_count"] = total_count 507 | return node_examples_map 508 | 509 | def prepare_prompt(self, ): 510 | raise NotImplementedError 511 | 512 | def prepare_tools(self): 513 | raise NotImplementedError 514 | 515 | def get_demonstrate_examples(self, node_to_prepare: AnyNode) -> List[Dict[str, str]]: 516 | demonstrate_node = node_to_prepare.parent if node_to_prepare.parent else node_to_prepare 517 | demonstrate_examples_pool = getattr(demonstrate_node, "examples", []) 518 | demonstrate_examples = random.sample(demonstrate_examples_pool, self.default_demonstrate_num) \ 519 | if demonstrate_examples_pool else [] 520 | return demonstrate_examples 521 | 522 | def add_node(self, node_info: Dict[str, Union[str, List[str]]]) -> AnyNode: 523 | assert self.unique_id in node_info and "parent" in node_info, \ 524 | f"Invalid node to add: 
{node_info}: both .{self.unique_id} and .parent are required!" 525 | new_node = AnyNode( 526 | **node_info, 527 | ) 528 | self.update_file(new_node, file_type="node", update_mode="new_file") 529 | self.name_to_node[getattr(new_node, self.unique_id)] = new_node 530 | logger.info(f"Node added: {node_info[self.unique_id]}") 531 | return new_node 532 | 533 | def add_node_to_tree(self, new_node: AnyNode): 534 | self.update_file(new_node, file_type="node", update_mode="new_file") 535 | self.name_to_node[getattr(new_node, self.unique_id)] = new_node 536 | logger.info(f"Node added: {getattr(new_node, self.unique_id)}") 537 | 538 | def update_file( 539 | self, 540 | node: AnyNode, 541 | update_mode: str = "new_file", 542 | file_type: str = "node", 543 | ): 544 | if update_mode == "new_file": 545 | node_save_file = os.path.join(self.general_outdir, node.config_filename) 546 | with open(node_save_file, "w", encoding="utf-8") as f_out: 547 | save_dict = {k: v for k, v in vars(node).items() if k in self.attribute_of_interest} 548 | json.dump( 549 | save_dict, 550 | f_out, 551 | ensure_ascii=False, 552 | ) 553 | f_out.write("\n") 554 | 555 | def handle_bad_completion( 556 | self, 557 | result: Dict[str, str], 558 | **kwargs, 559 | ): 560 | result.update(**kwargs) 561 | with open(self.bad_res_outfile, "a+", encoding="utf-8") as f_out: 562 | json.dump( 563 | result, 564 | f_out, 565 | ensure_ascii=False, 566 | ) 567 | f_out.write("\n") 568 | 569 | def encode_prompt( 570 | self, 571 | base_prompt: str, 572 | demonstrate_examples: Optional[List[str]] = [], 573 | target_task: str = None, 574 | existing_children: List[AnyNode] = [], 575 | existing_siblings: List[AnyNode] = [], 576 | num_examples_per_time: int = None, 577 | extend_num: int = None, 578 | new_subtask: str = None, 579 | new_subtask_reason: str = None, 580 | target_children_num: int = None, 581 | ) -> str: 582 | raise NotImplementedError 583 | 584 | def formalize_taskname(self, taskname: str): 585 | raise NotImplementedError 586 | 587 | def save_to_local(self): 588 | """ 589 | save meta file and each node info 590 | """ 591 | # first check all node already saved to individual file 592 | for node in PreOrderIter(self.root): 593 | # config_file = getattr(node, "config_file") if getattr(node, "config_file", None)\ 594 | # else os.path.join(self.general_outdir, getattr(node, "config_filename")) 595 | config_file = os.path.join(self.general_outdir, getattr(node, "config_filename")) 596 | if not os.path.isfile(config_file): 597 | logger.warning( 598 | f"Node info for {getattr(node, self.unique_id)} is not saved! 
Now saving to {config_file}") 599 | self.update_file(node, file_type="node", update_mode="new_file") 600 | 601 | # export meta file (set attriter in order to ) 602 | exporter = DictExporter( 603 | attriter=lambda attrs: [(k, v) for k, v in attrs if k in self.attribute_meta_save] 604 | ) 605 | meta_info = exporter.export(self.root) 606 | with open(os.path.join(self.general_outdir, meta_info_filename), "w", encoding="utf-8") as f_out: 607 | json.dump( 608 | meta_info, 609 | f_out, 610 | ensure_ascii=False, 611 | ) 612 | f_out.write("\n") 613 | return 614 | 615 | def post_process_gpt3_response_enrich( 616 | self, 617 | response: str, 618 | current_new_subtask: str = None, 619 | current_new_subtask_reason: str = None, 620 | ) -> Tuple[Union[str, List[str]], Union[str, List[str]], List[Dict[str, str]]]: 621 | raise NotImplementedError 622 | 623 | @staticmethod 624 | def parse_node_config( 625 | node: AnyNode, 626 | config_field_name: str = "config_filename", 627 | **kwargs, 628 | ) -> AnyNode: 629 | config_filename = getattr(node, config_field_name, None) 630 | base_dir = kwargs.get("base_dir") 631 | config_file = os.path.join(base_dir, config_filename) 632 | if not config_filename or (not os.path.isfile(config_file)): 633 | logger.warning(f"Failed to load node info from {config_file}! Loading fail for {node}") 634 | return 635 | with open(config_file, encoding="utf-8") as f_in: 636 | task_config = json.load(f_in) 637 | for k, v in task_config.items(): 638 | # if getattr(node, config_field_name, None): 639 | # logger.warning(f"") 640 | setattr(node, k, v) 641 | 642 | def post_process_gpt3_response_extend( 643 | self, 644 | response: str 645 | ) -> Tuple[List[str], List[str], List[Dict[str, str]]]: 646 | raise NotImplementedError 647 | 648 | @classmethod 649 | def from_tree_dict( 650 | cls, 651 | domain_tree_dict: Dict[str, Union[str, List[Dict[str, str]]]] = {}, 652 | save_dir: str = None, 653 | out_dir: str = None, 654 | **kwargs, 655 | ): 656 | """ 657 | Currently only for debugging use! 
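        A debugging sketch (the keys mirror what ``save_to_local`` writes to
        ``tree_meta_info.json``; the paths and task name below are placeholders, and each node's
        ``config_filename`` is expected to exist under ``save_dir``)::

            tree_dict = {
                "task_name": "my_root_task",
                "raw_task_name": "My root task",
                "reason": "",
                "config_filename": "my_root_task.json",
            }
            tree = EnDomainTreeRewrite.from_tree_dict(
                tree_dict, save_dir="./seed_dir", out_dir="./mined_out"
            )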
658 | """ 659 | root_node = DictImporter().import_(domain_tree_dict) 660 | for node in PreOrderIter(root_node): 661 | cls.parse_node_config(node, base_dir=save_dir) 662 | if out_dir is not None: 663 | root_node.general_outdir = out_dir 664 | return cls(root_node, **kwargs) 665 | 666 | @classmethod 667 | def from_local_dir( 668 | cls, 669 | save_dir: str, 670 | out_dir: str = None, 671 | meta_file: str = meta_info_filename, 672 | **kwargs, 673 | ): 674 | # first load meta info, and infill each node info by loading config file 675 | 676 | with open(os.path.join(save_dir, meta_file)) as f_in: 677 | meta_info = json.load(f_in) 678 | root_node = DictImporter().import_(meta_info) 679 | for node in PreOrderIter(root_node): 680 | cls.parse_node_config(node, base_dir=save_dir) 681 | root_node.general_outdir = out_dir if out_dir is not None else save_dir 682 | return cls(root_node, **kwargs) 683 | 684 | 685 | class EnDomainTreeRewrite(DomainTree): 686 | 687 | def encode_prompt( 688 | self, 689 | base_prompt: str, 690 | demonstrate_examples: Optional[List[str]] = [], 691 | target_task: str = None, 692 | existing_children: List[AnyNode] = [], 693 | existing_siblings: List[AnyNode] = [], 694 | num_examples_per_time: int = None, 695 | extend_num: int = None, 696 | new_subtask: str = None, 697 | new_subtask_reason: str = None, 698 | target_children_num: int = None, 699 | ) -> str: 700 | prompt = base_prompt 701 | prompt += f"\nTarget task: {target_task}\n" 702 | if len(demonstrate_examples) > 0: 703 | prompt += "Examples:\n" 704 | for idx, task_dict in enumerate(demonstrate_examples): 705 | (instruction, input, output) = task_dict["instruction"], task_dict["input"], task_dict["output"] 706 | instruction = re.sub(r"\s+", " ", instruction).strip().rstrip(":") 707 | input = "" if input.lower() == "" else input 708 | prompt += "###\n" 709 | prompt += f"{idx + 1}. Instruction: {instruction}\n" 710 | prompt += f"Input: {input}\n" 711 | prompt += f"Output: {output}\n" 712 | prompt += "###\n" 713 | existing_children_names = [getattr(node, self.unique_id) for node in existing_children] 714 | existing_siblings_names = [getattr(node, self.unique_id) for node in existing_siblings] 715 | prompt += f"\nThe list of already existing subtasks for this target task is: {existing_children_names}.\n" 716 | prompt += f"The list of already existing peer tasks for this target task is: {existing_siblings_names}.\n" 717 | 718 | if target_children_num is not None: # for extending 719 | prompt += f"\nThe target task should be decomposed into a total of {target_children_num} diverse and complementary subtasks, " \ 720 | f"and there are {len(existing_children)} existing subtasks. 
" \ 721 | f"Generate {extend_num} new subtasks with the corresponding reason, then list {num_examples_per_time} examples of this new subtask:" 722 | else: # for enriching 723 | prompt += f"\nList {num_examples_per_time} examples of this new subtask below:" 724 | 725 | if new_subtask: # for enriching 726 | prompt += "\n" 727 | prompt += f"\nNew subtask: {new_subtask}\n" 728 | prompt += f"Reason: {new_subtask_reason}" 729 | return prompt 730 | 731 | def prepare_prompt(self, ): 732 | self.extend_node_prompt = open("./prompt_bank/prompt_write_assistance_extend.txt").read() + "\n" 733 | self.enrich_node_prompt = open("./prompt_bank/prompt_write_assistance_enrich.txt").read() + "\n" 734 | 735 | def prepare_tools(self): 736 | logger.info("Preparint tools...") 737 | self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") 738 | self.openai_kwargs = { 739 | "model": "gpt-3.5-turbo", # openai model type 740 | "temperature": 1.0, 741 | "top_p": 1.0, 742 | "n": 1, 743 | "logit_bias": {"50256": -100}, # prevent the <|endoftext|> token from being generated 744 | } 745 | if self.assistant_name == "claude": 746 | import anthropic 747 | self.claude_client = anthropic.Client(claude_key, proxy_url="http://127.0.0.1:2802") 748 | self.claude_kwargs = { 749 | "stop_sequences": [anthropic.HUMAN_PROMPT], 750 | "model": "claude-v1.3", # anthropic model type 751 | } 752 | 753 | def formalize_taskname(self, taskname: str): 754 | taskname = re.sub('[^A-Za-z0-9]+', ' ', taskname) 755 | taskname = taskname.strip().replace(" ", "_").lower() 756 | # dedup 757 | if taskname in self.name_to_node: 758 | for i in range(10): 759 | dedup_taskname = f"{taskname}_{i}" 760 | if dedup_taskname not in self.name_to_node: 761 | break 762 | 763 | taskname = dedup_taskname 764 | return taskname 765 | 766 | def post_process_gpt3_response_enrich( 767 | self, 768 | response: str, 769 | current_new_subtask: str = None, 770 | current_new_subtask_reason: str = None, 771 | ) -> Tuple[Union[str, List[str]], Union[str, List[str]], List[Dict[str, str]]]: 772 | if response is None: 773 | return None, current_new_subtask, current_new_subtask_reason 774 | stop_reason = response.get("stop_reason", "") 775 | raw_response = response.get("raw_response", "") 776 | 777 | raw_response = raw_response.replace("###", "").replace("### ", "").replace(" ###", "").replace(" ### ", "") 778 | new_subtask = current_new_subtask 779 | new_subtask_reason = current_new_subtask_reason 780 | new_subtask_examples = "\n" + raw_response.lstrip("Examples:") 781 | new_subtask_examples = re.split("\n+\d+\.\s+", new_subtask_examples) 782 | new_subtask_examples = new_subtask_examples[1:] 783 | instructions = [] 784 | for idx, inst in enumerate(new_subtask_examples): 785 | # if the decoding stops due to length, the last example is likely truncated so we discard it 786 | if idx == len(new_subtask_examples) - 1 and stop_reason == "length": 787 | continue 788 | idx += 1 789 | splitted_data = re.split("Instruction:|Input:|Output:", inst) 790 | if len(splitted_data) != 4: 791 | continue 792 | inst = splitted_data[1].strip() 793 | input = splitted_data[2].strip() 794 | input = "" if input.lower() == "" else input 795 | output = splitted_data[3].strip().strip() 796 | # filter out too short or too long instructions 797 | if len(inst.split()) < 3 or len(inst.split()) > 150: 798 | continue 799 | # filter based on keywords that are not suitable for language models. 
800 | blacklist = [ 801 | "image", 802 | "images", 803 | "graph", 804 | "graphs", 805 | "picture", 806 | "pictures", 807 | "file", 808 | "files", 809 | "map", 810 | "maps", 811 | "draw", 812 | "plot", 813 | "go to", 814 | "video", 815 | "audio", 816 | "music", 817 | "flowchart", 818 | "diagram", 819 | ] 820 | blacklist += [] 821 | if any(find_word_in_string(word, inst) for word in blacklist): 822 | continue 823 | # We found that the model tends to add "write a program" to some existing instructions, which lead to a lot of such instructions. 824 | # And it's a bit comfusing whether the model need to write a program or directly output the result. 825 | # Here we filter them out. 826 | # Note this is not a comprehensive filtering for all programming instructions. 827 | if inst.startswith("Write a program"): 828 | continue 829 | # filter those starting with punctuation 830 | if inst[0] in string.punctuation: 831 | continue 832 | # filter those starting with non-english character 833 | if not inst[0].isascii(): 834 | continue 835 | # filter un-complete input 836 | if input.startswith("<") and input.endswith(">"): 837 | continue 838 | if input.startswith("(") and input.endswith(")"): 839 | continue 840 | instructions.append({"instruction": inst, "input": input, "output": output}) 841 | return new_subtask, new_subtask_reason, instructions 842 | 843 | def post_process_gpt3_response_extend( 844 | self, 845 | response: str 846 | ) -> Tuple[List[str], List[str], List[List[Dict[str, str]]]]: 847 | stop_reason = response.get("stop_reason", "") 848 | raw_response = response.get("raw_response", "") 849 | 850 | raw_response = raw_response.replace("###", "").replace("### ", "").replace(" ###", "").replace(" ### ", "") 851 | raw_response = raw_response.replace("Example:", "Examples:") 852 | split_response = re.split("New subtask:|Reason:|Examples:", raw_response) 853 | split_response = split_response[1:] 854 | num_subtasks = len(split_response) // 3 855 | new_subtasks = [] 856 | new_subtasks_reason = [] 857 | new_subtasks_example = [] 858 | for i in range(num_subtasks): 859 | new_subtask = split_response[i * 3].strip() 860 | new_subtask_reason = split_response[i * 3 + 1].strip() 861 | new_subtask_examples = split_response[i * 3 + 2] 862 | new_subtask_examples = re.split("\n+\d+\.\s+", new_subtask_examples) 863 | new_subtask_examples = new_subtask_examples[1:] 864 | instructions = [] 865 | for idx, inst in enumerate(new_subtask_examples): 866 | # if the decoding stops due to length, the last example is likely truncated so we discard it 867 | if idx == len(new_subtask_examples) - 1 and stop_reason == "length": 868 | continue 869 | splitted_data = re.split("Instruction:|Input:|Output:", inst) 870 | if len(splitted_data) != 4: 871 | continue 872 | inst = splitted_data[1].strip() 873 | input = splitted_data[2].strip() 874 | input = "" if input.lower() == "" else input 875 | output = splitted_data[3].strip().strip() 876 | # filter out too short or too long instructions 877 | if len(inst.split()) < 3 or len(inst.split()) > 150: 878 | continue 879 | # filter based on keywords that are not suitable for language models. 
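                # For reference, this parser assumes the completion roughly follows the layout
                # requested by the extend prompt (a sketch; placeholder content):
                #
                #   New subtask: <subtask name>
                #   Reason: <why it belongs under the target task>
                #   Examples:
                #   1. Instruction: <instruction>
                #   Input: <input, possibly empty>
                #   Output: <output>
                #   ...
                #
                # ``re.split("New subtask:|Reason:|Examples:", ...)`` above relies on these three
                # field names appearing once per generated subtask, and each numbered example is
                # then split on "Instruction:" / "Input:" / "Output:".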
880 | blacklist = [ 881 | "image", 882 | "images", 883 | "graph", 884 | "graphs", 885 | "picture", 886 | "pictures", 887 | "file", 888 | "files", 889 | "map", 890 | "maps", 891 | "draw", 892 | "plot", 893 | "go to", 894 | "video", 895 | "audio", 896 | "music", 897 | "flowchart", 898 | "diagram", 899 | ] 900 | blacklist += [] 901 | if any(find_word_in_string(word, inst) for word in blacklist): 902 | continue 903 | # We found that the model tends to add "write a program" to some existing instructions, which lead to a lot of such instructions. 904 | # And it's a bit comfusing whether the model need to write a program or directly output the result. 905 | # Here we filter them out. 906 | # Note this is not a comprehensive filtering for all programming instructions. 907 | if inst.startswith("Write a program"): 908 | continue 909 | # filter those starting with punctuation 910 | if inst[0] in string.punctuation: 911 | continue 912 | # filter those starting with non-english character 913 | if not inst[0].isascii(): 914 | continue 915 | # filter un-complete input 916 | if input.startswith("<") and input.endswith(">"): 917 | continue 918 | if input.startswith("(") and input.endswith(")"): 919 | continue 920 | instructions.append({"instruction": inst, "input": input, "output": output}) 921 | new_subtasks.append(new_subtask) 922 | new_subtasks_reason.append(new_subtask_reason) 923 | new_subtasks_example.append(instructions) 924 | return new_subtasks, new_subtasks_reason, new_subtasks_example 925 | 926 | 927 | class EnDomainTreeBrainstorming(DomainTree): 928 | 929 | def encode_prompt( 930 | self, 931 | base_prompt: str, 932 | demonstrate_examples: Optional[List[str]] = [], 933 | target_task: str = None, 934 | existing_children: List[AnyNode] = [], 935 | existing_siblings: List[AnyNode] = [], 936 | num_examples_per_time: int = None, 937 | extend_num: int = None, 938 | new_subtask: str = None, 939 | new_subtask_reason: str = None, 940 | target_children_num: int = None, 941 | ) -> str: 942 | prompt = base_prompt 943 | prompt += f"\nTarget task: {target_task}\n" 944 | if len(demonstrate_examples) > 0: 945 | prompt += "Examples:\n" 946 | for idx, task_dict in enumerate(demonstrate_examples): 947 | (instruction, input, output) = task_dict["instruction"], task_dict["input"], task_dict["output"] 948 | instruction = re.sub(r"\s+", " ", instruction).strip().rstrip(":") 949 | input = "" if input.lower() == "" else input 950 | prompt += "###\n" 951 | prompt += f"Instruction: {instruction}\n" # without index 952 | prompt += f"Input: {input}\n" 953 | prompt += f"Output: {output}\n" 954 | prompt += "###\n" 955 | existing_children_names = [getattr(node, self.unique_id) for node in existing_children] 956 | existing_siblings_names = [getattr(node, self.unique_id) for node in existing_siblings] 957 | prompt += f"\nThe list of already existing subtasks for this target task is: {existing_children_names}.\n" 958 | prompt += f"The list of already existing peer tasks for this target task is: {existing_siblings_names}.\n" 959 | 960 | if target_children_num is not None: # for extending 961 | prompt += f"\nThe target task should be decomposed into a total of {target_children_num} diverse and complementary subtasks, " \ 962 | f"and there are {len(existing_children)} existing subtasks. 
" \ 963 | f"Generate {extend_num} new subtasks with the corresponding reason, then list {num_examples_per_time} examples of this new subtask:" 964 | else: # for enriching 965 | prompt += f"\nList {num_examples_per_time} examples of this new subtask below:" 966 | 967 | if new_subtask: # for enriching 968 | prompt += "\n" 969 | prompt += f"\nNew subtask: {new_subtask}\n" 970 | prompt += f"Reason: {new_subtask_reason}" 971 | return prompt 972 | 973 | def prepare_prompt(self, ): 974 | self.extend_node_prompt = open("./prompt_bank/prompt_brainstorming_extend.txt").read() + "\n" 975 | self.enrich_node_prompt = open("./prompt_bank/prompt_brainstorming_enrich.txt").read() + "\n" 976 | 977 | def prepare_tools(self): 978 | logger.info("Preparint tools...") 979 | self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") 980 | self.openai_kwargs = { 981 | "model": "gpt-3.5-turbo", # openai model type 982 | "temperature": 1.0, 983 | "top_p": 1.0, 984 | "n": 1, 985 | "logit_bias": {"50256": -100}, # prevent the <|endoftext|> token from being generated 986 | } 987 | if self.assistant_name == "claude": 988 | import anthropic 989 | self.claude_client = anthropic.Client(claude_key, proxy_url="http://127.0.0.1:2802") 990 | self.claude_kwargs = { 991 | "stop_sequences": [anthropic.HUMAN_PROMPT], 992 | "model": "claude-v1.3", # anthropic model type 993 | } 994 | 995 | def formalize_taskname(self, taskname: str): 996 | taskname = re.sub('[^A-Za-z0-9]+', ' ', taskname) 997 | taskname = taskname.strip().replace(" ", "_").lower() 998 | # dedup 999 | if taskname in self.name_to_node: 1000 | for i in range(10): 1001 | dedup_taskname = f"{taskname}_{i}" 1002 | if dedup_taskname not in self.name_to_node: 1003 | break 1004 | 1005 | taskname = dedup_taskname 1006 | return taskname 1007 | 1008 | def post_process_gpt3_response_enrich( 1009 | self, 1010 | response: str, 1011 | current_new_subtask: str = None, 1012 | current_new_subtask_reason: str = None, 1013 | ) -> Tuple[Union[str, List[str]], Union[str, List[str]], List[Dict[str, str]]]: 1014 | if response is None: 1015 | return None, current_new_subtask, current_new_subtask_reason 1016 | stop_reason = response.get("stop_reason", "") 1017 | raw_response = response.get("raw_response", "") 1018 | 1019 | raw_response = raw_response.replace("###", "").replace("### ", "").replace(" ###", "").replace(" ### ", "") 1020 | new_subtask = current_new_subtask 1021 | new_subtask_reason = current_new_subtask_reason 1022 | new_subtask_examples = "\n" + raw_response.lstrip("Examples:") 1023 | new_subtask_examples = re.split("Instruction:", new_subtask_examples) 1024 | new_subtask_examples = new_subtask_examples[1:] 1025 | instructions = [] 1026 | for idx, inst in enumerate(new_subtask_examples): 1027 | # if the decoding stops due to length, the last example is likely truncated so we discard it 1028 | if idx == len(new_subtask_examples) - 1 and stop_reason == "length": 1029 | continue 1030 | idx += 1 1031 | splitted_data = re.split("Instruction:|Input:|Output:", inst) 1032 | if len(splitted_data) != 3: 1033 | continue 1034 | inst = splitted_data[0].strip() 1035 | input = splitted_data[1].strip() 1036 | input = "" if input.lower() == "" else input 1037 | output = splitted_data[2].strip().strip() 1038 | # filter out too short or too long instructions 1039 | if len(inst.split()) < 3 or len(inst.split()) > 150: 1040 | continue 1041 | # filter based on keywords that are not suitable for language models. 
1042 | blacklist = [ 1043 | "image", 1044 | "images", 1045 | "graph", 1046 | "graphs", 1047 | "picture", 1048 | "pictures", 1049 | "file", 1050 | "files", 1051 | # "map", 1052 | # "maps", 1053 | "draw", 1054 | "plot", 1055 | "go to", 1056 | "video", 1057 | "audio", 1058 | "music", 1059 | "flowchart", 1060 | "diagram", 1061 | ] 1062 | blacklist += [] 1063 | if any(find_word_in_string(word, inst) for word in blacklist): 1064 | continue 1065 | # We found that the model tends to add "write a program" to some existing instructions, which lead to a lot of such instructions. 1066 | # And it's a bit comfusing whether the model need to write a program or directly output the result. 1067 | # Here we filter them out. 1068 | # Note this is not a comprehensive filtering for all programming instructions. 1069 | if inst.startswith("Write a program"): 1070 | continue 1071 | # filter those starting with punctuation 1072 | if inst[0] in string.punctuation: 1073 | continue 1074 | # filter those starting with non-english character 1075 | if not inst[0].isascii(): 1076 | continue 1077 | # filter un-complete input 1078 | if input.startswith("<") and input.endswith(">"): 1079 | continue 1080 | if input.startswith("(") and input.endswith(")"): 1081 | continue 1082 | instructions.append({"instruction": inst, "input": input, "output": output}) 1083 | return new_subtask, new_subtask_reason, instructions 1084 | 1085 | def post_process_gpt3_response_extend( 1086 | self, 1087 | response: str 1088 | ) -> Tuple[List[str], List[str], List[List[Dict[str, str]]]]: 1089 | stop_reason = response.get("stop_reason", "") 1090 | raw_response = response.get("raw_response", "") 1091 | 1092 | raw_response = raw_response.replace("###", "").replace("### ", "").replace(" ###", "").replace(" ### ", "") 1093 | raw_response = raw_response.replace("Example:", "Examples:") 1094 | split_response = re.split("New subtask:|Reason:|Examples:", raw_response) 1095 | split_response = split_response[1:] 1096 | num_subtasks = len(split_response) // 3 1097 | new_subtasks = [] 1098 | new_subtasks_reason = [] 1099 | new_subtasks_example = [] 1100 | for i in range(num_subtasks): 1101 | new_subtask = split_response[i * 3].strip() 1102 | new_subtask_reason = split_response[i * 3 + 1].strip() 1103 | new_subtask_examples = split_response[i * 3 + 2] 1104 | new_subtask_examples = re.split("Instruction:", new_subtask_examples) 1105 | new_subtask_examples = new_subtask_examples[1:] 1106 | instructions = [] 1107 | for idx, inst in enumerate(new_subtask_examples): 1108 | # if the decoding stops due to length, the last example is likely truncated so we discard it 1109 | if idx == len(new_subtask_examples) - 1 and stop_reason == "length": 1110 | continue 1111 | splitted_data = re.split("Instruction:|Input:|Output:", inst) 1112 | if len(splitted_data) != 3: 1113 | continue 1114 | inst = splitted_data[0].strip() 1115 | input = splitted_data[1].strip() 1116 | input = "" if input.lower() == "" else input 1117 | output = splitted_data[2].strip().strip() 1118 | # filter out too short or too long instructions 1119 | if len(inst.split()) < 3 or len(inst.split()) > 150: 1120 | continue 1121 | # filter based on keywords that are not suitable for language models. 
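                # ``find_word_in_string`` (defined near the bottom of this file) performs a
                # whole-word, case-insensitive match, so for example (illustrative strings):
                #   find_word_in_string("image", "Generate an image of a cat")  -> match, example is dropped
                #   find_word_in_string("map", "Draft a roadmap for the team")  -> no match, example is kept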
1122 | blacklist = [ 1123 | "image", 1124 | "images", 1125 | "graph", 1126 | "graphs", 1127 | "picture", 1128 | "pictures", 1129 | "file", 1130 | "files", 1131 | # "map", 1132 | # "maps", 1133 | "draw", 1134 | "plot", 1135 | "go to", 1136 | "video", 1137 | "audio", 1138 | "music", 1139 | "flowchart", 1140 | "diagram", 1141 | ] 1142 | blacklist += [] 1143 | if any(find_word_in_string(word, inst) for word in blacklist): 1144 | continue 1145 | # We found that the model tends to add "write a program" to some existing instructions, which lead to a lot of such instructions. 1146 | # And it's a bit comfusing whether the model need to write a program or directly output the result. 1147 | # Here we filter them out. 1148 | # Note this is not a comprehensive filtering for all programming instructions. 1149 | if inst.startswith("Write a program"): 1150 | continue 1151 | # filter those starting with punctuation 1152 | if inst[0] in string.punctuation: 1153 | continue 1154 | # filter those starting with non-english character 1155 | if not inst[0].isascii(): 1156 | continue 1157 | # filter un-complete input 1158 | if input.startswith("<") and input.endswith(">"): 1159 | continue 1160 | if input.startswith("(") and input.endswith(")"): 1161 | continue 1162 | instructions.append({"instruction": inst, "input": input, "output": output}) 1163 | new_subtasks.append(new_subtask) 1164 | new_subtasks_reason.append(new_subtask_reason) 1165 | new_subtasks_example.append(instructions) 1166 | return new_subtasks, new_subtasks_reason, new_subtasks_example 1167 | 1168 | 1169 | class EnDomainTreeMath(DomainTree): 1170 | 1171 | def encode_prompt( 1172 | self, 1173 | base_prompt: str, 1174 | demonstrate_examples: Optional[List[str]] = [], 1175 | target_task: str = None, 1176 | existing_children: List[AnyNode] = [], 1177 | existing_siblings: List[AnyNode] = [], 1178 | num_examples_per_time: int = None, 1179 | extend_num: int = None, 1180 | new_subtask: str = None, 1181 | new_subtask_reason: str = None, 1182 | target_children_num: int = None, 1183 | ) -> str: 1184 | prompt = base_prompt 1185 | prompt += f"\nTarget task: {target_task}\n" 1186 | if len(demonstrate_examples) > 0: 1187 | prompt += "Examples:\n" 1188 | for idx, task_dict in enumerate(demonstrate_examples): 1189 | (instruction, input, output) = task_dict["instruction"], task_dict["input"], task_dict["output"] 1190 | instruction = re.sub(r"\s+", " ", instruction).strip().rstrip(":") 1191 | input = "" if input.lower() == "" else input 1192 | prompt += "###\n" 1193 | prompt += f"Instruction: {instruction}\n" # without index 1194 | prompt += f"Input: {input}\n" 1195 | prompt += f"Output: {output}\n" 1196 | prompt += "###\n" 1197 | existing_children_names = [getattr(node, self.unique_id) for node in existing_children] 1198 | existing_siblings_names = [getattr(node, self.unique_id) for node in existing_siblings] 1199 | prompt += f"\nThe list of already existing subtasks for this target task is: {existing_children_names}.\n" 1200 | prompt += f"The list of already existing peer tasks for this target task is: {existing_siblings_names}.\n" 1201 | 1202 | if target_children_num is not None: # for extending 1203 | prompt += f"\nThe target task should be decomposed into a total of {target_children_num} diverse and complementary subtasks, " \ 1204 | f"and there are {len(existing_children)} existing subtasks. 
" \ 1205 | f"Generate {extend_num} new subtasks with the corresponding reason, then list {num_examples_per_time} examples of this new subtask:" 1206 | else: # for enriching 1207 | prompt += f"\nList {num_examples_per_time} examples of this new subtask below:" 1208 | 1209 | if new_subtask: # for enriching 1210 | prompt += "\n" 1211 | prompt += f"\nNew subtask: {new_subtask}\n" 1212 | prompt += f"Reason: {new_subtask_reason}" 1213 | return prompt 1214 | 1215 | def prepare_prompt(self, ): 1216 | self.extend_node_prompt = open("./prompt_bank/prompt_math_extend.txt").read() + "\n" 1217 | self.enrich_node_prompt = open("./prompt_bank/prompt_math_enrich.txt").read() + "\n" 1218 | 1219 | def prepare_tools(self): 1220 | logger.info("Preparint tools...") 1221 | self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") 1222 | self.openai_kwargs = { 1223 | "model": "gpt-3.5-turbo", # openai model type 1224 | "temperature": 1.0, 1225 | "top_p": 1.0, 1226 | "n": 1, 1227 | "logit_bias": {"50256": -100}, # prevent the <|endoftext|> token from being generated 1228 | } 1229 | if self.assistant_name == "claude": 1230 | import anthropic 1231 | self.claude_client = anthropic.Client(claude_key, proxy_url="http://127.0.0.1:2802") 1232 | self.claude_kwargs = { 1233 | "stop_sequences": [anthropic.HUMAN_PROMPT], 1234 | "model": "claude-v1.3", # anthropic model type 1235 | } 1236 | 1237 | def formalize_taskname(self, taskname: str): 1238 | taskname = re.sub('[^A-Za-z0-9]+', ' ', taskname) 1239 | taskname = taskname.strip().replace(" ", "_").lower() 1240 | # dedup 1241 | if taskname in self.name_to_node: 1242 | for i in range(10): 1243 | dedup_taskname = f"{taskname}_{i}" 1244 | if dedup_taskname not in self.name_to_node: 1245 | break 1246 | 1247 | taskname = dedup_taskname 1248 | return taskname 1249 | 1250 | def post_process_gpt3_response_enrich( 1251 | self, 1252 | response: str, 1253 | current_new_subtask: str = None, 1254 | current_new_subtask_reason: str = None, 1255 | ) -> Tuple[Union[str, List[str]], Union[str, List[str]], List[Dict[str, str]]]: 1256 | if response is None: 1257 | return None, current_new_subtask, current_new_subtask_reason 1258 | stop_reason = response.get("stop_reason", "") 1259 | raw_response = response.get("raw_response", "") 1260 | 1261 | raw_response = raw_response.replace("###", "").replace("### ", "").replace(" ###", "").replace(" ### ", "") 1262 | new_subtask = current_new_subtask 1263 | new_subtask_reason = current_new_subtask_reason 1264 | new_subtask_examples = "\n" + raw_response.lstrip("Examples:") 1265 | new_subtask_examples = re.split("Instruction:", new_subtask_examples) 1266 | new_subtask_examples = new_subtask_examples[1:] 1267 | instructions = [] 1268 | for idx, inst in enumerate(new_subtask_examples): 1269 | # if the decoding stops due to length, the last example is likely truncated so we discard it 1270 | if idx == len(new_subtask_examples) - 1 and stop_reason == "length": 1271 | continue 1272 | idx += 1 1273 | splitted_data = re.split("Instruction:|Input:|Output:", inst) 1274 | if len(splitted_data) != 3: 1275 | continue 1276 | inst = splitted_data[0].strip() 1277 | input = splitted_data[1].strip() 1278 | input = "" if input.lower() == "" else input 1279 | output = splitted_data[2].strip().strip() 1280 | if "Answer: " in output and "\\boxed" not in output[output.index("Answer: "):]: 1281 | answer_index = output.index("Answer: ") 1282 | answer = output[answer_index:].lstrip("Answer: ") 1283 | answer = answer[0] + "\\boxed{" + answer[1:-1] + "}" + answer[-1] 1284 | output 
= output[:answer_index] + "Answer: " + answer 1285 | # filter based on keywords that are not suitable for language models. 1286 | blacklist = [ 1287 | "image", 1288 | "images", 1289 | # "graph", 1290 | # "graphs", 1291 | "picture", 1292 | "pictures", 1293 | "file", 1294 | "files", 1295 | # "map", 1296 | # "maps", 1297 | "draw", 1298 | "plot", 1299 | "go to", 1300 | "video", 1301 | "audio", 1302 | "music", 1303 | "flowchart", 1304 | "diagram", 1305 | ] 1306 | blacklist += [] 1307 | if any(find_word_in_string(word, inst) for word in blacklist): 1308 | continue 1309 | # We found that the model tends to add "write a program" to some existing instructions, which lead to a lot of such instructions. 1310 | # And it's a bit comfusing whether the model need to write a program or directly output the result. 1311 | # Here we filter them out. 1312 | # Note this is not a comprehensive filtering for all programming instructions. 1313 | if inst.startswith("Write a program"): 1314 | continue 1315 | # filter those starting with punctuation 1316 | if inst[0] in string.punctuation: 1317 | continue 1318 | # filter those starting with non-english character 1319 | if not inst[0].isascii(): 1320 | continue 1321 | # filter un-complete input 1322 | if input.startswith("<") and input.endswith(">"): 1323 | continue 1324 | if input.startswith("(") and input.endswith(")"): 1325 | continue 1326 | instructions.append({"instruction": inst, "input": input, "output": output}) 1327 | return new_subtask, new_subtask_reason, instructions 1328 | 1329 | def post_process_gpt3_response_extend( 1330 | self, 1331 | response: str 1332 | ) -> Tuple[List[str], List[str], List[List[Dict[str, str]]]]: 1333 | stop_reason = response.get("stop_reason", "") 1334 | raw_response = response.get("raw_response", "") 1335 | 1336 | raw_response = raw_response.replace("###", "").replace("### ", "").replace(" ###", "").replace(" ### ", "") 1337 | raw_response = raw_response.replace("Example:", "Examples:") 1338 | split_response = re.split("New subtask:|Reason:|Examples:", raw_response) 1339 | split_response = split_response[1:] 1340 | num_subtasks = len(split_response) // 3 1341 | new_subtasks = [] 1342 | new_subtasks_reason = [] 1343 | new_subtasks_example = [] 1344 | for i in range(num_subtasks): 1345 | new_subtask = split_response[i * 3].strip() 1346 | new_subtask_reason = split_response[i * 3 + 1].strip() 1347 | new_subtask_examples = split_response[i * 3 + 2] 1348 | new_subtask_examples = re.split("Instruction:", new_subtask_examples) 1349 | new_subtask_examples = new_subtask_examples[1:] 1350 | instructions = [] 1351 | for idx, inst in enumerate(new_subtask_examples): 1352 | # if the decoding stops due to length, the last example is likely truncated so we discard it 1353 | if idx == len(new_subtask_examples) - 1 and stop_reason == "length": 1354 | continue 1355 | splitted_data = re.split("Instruction:|Input:|Output:", inst) 1356 | if len(splitted_data) != 3: 1357 | continue 1358 | inst = splitted_data[0].strip() 1359 | input = splitted_data[1].strip() 1360 | input = "" if input.lower() == "" else input 1361 | output = splitted_data[2].strip().strip() 1362 | if "Answer: " in output and "\\boxed" not in output[output.index("Answer: "):]: 1363 | answer_index = output.index("Answer: ") 1364 | answer = output[answer_index:].lstrip("Answer: ") 1365 | answer = answer[0] + "\\boxed{" + answer[1:-1] + "}" + answer[-1] 1366 | output = output[:answer_index] + "Answer: " + answer 1367 | # filter based on keywords that are not suitable for language models. 
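                # The answer-normalisation step a few lines above wraps a bare final answer in
                # ``\boxed{}`` when the model omitted it; e.g. an output ending in
                #   "... Answer: $42$"
                # becomes
                #   "... Answer: $\boxed{42}$"
                # (it assumes the answer is already wrapped in a pair of delimiter characters
                # such as ``$...$``; the value 42 is illustrative).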
1368 | blacklist = [ 1369 | "image", 1370 | "images", 1371 | # "graph", 1372 | # "graphs", 1373 | "picture", 1374 | "pictures", 1375 | "file", 1376 | "files", 1377 | # "map", 1378 | # "maps", 1379 | "draw", 1380 | "plot", 1381 | "go to", 1382 | "video", 1383 | "audio", 1384 | "music", 1385 | "flowchart", 1386 | "diagram", 1387 | ] 1388 | blacklist += [] 1389 | if any(find_word_in_string(word, inst) for word in blacklist): 1390 | continue 1391 | # We found that the model tends to add "write a program" to some existing instructions, which lead to a lot of such instructions. 1392 | # And it's a bit comfusing whether the model need to write a program or directly output the result. 1393 | # Here we filter them out. 1394 | # Note this is not a comprehensive filtering for all programming instructions. 1395 | if inst.startswith("Write a program"): 1396 | continue 1397 | # filter those starting with punctuation 1398 | if inst[0] in string.punctuation: 1399 | continue 1400 | # filter those starting with non-english character 1401 | if not inst[0].isascii(): 1402 | continue 1403 | # filter un-complete input 1404 | if input.startswith("<") and input.endswith(">"): 1405 | continue 1406 | if input.startswith("(") and input.endswith(")"): 1407 | continue 1408 | instructions.append({"instruction": inst, "input": input, "output": output}) 1409 | new_subtasks.append(new_subtask) 1410 | new_subtasks_reason.append(new_subtask_reason) 1411 | new_subtasks_example.append(instructions) 1412 | return new_subtasks, new_subtasks_reason, new_subtasks_example 1413 | 1414 | 1415 | def find_word_in_string(w, s): 1416 | return re.compile(r"\b({0})\b".format(w), flags=re.IGNORECASE).search(s) 1417 | 1418 | 1419 | def test_extending( 1420 | domain_tree: DomainTree, 1421 | max_depth: int = 2, 1422 | extend_nums: List = None, 1423 | extend_batch_size: int = None 1424 | ): 1425 | queue = [_ for _ in PreOrderIter(domain_tree.root)] 1426 | while len(queue) > 0: 1427 | node = queue.pop(0) 1428 | logger.info(f"Processing {node.task_name}, depth: {node.depth}") 1429 | if node.depth >= max_depth: 1430 | continue 1431 | new_nodes: List[AnyNode] = domain_tree.extend_node_children( 1432 | node, 1433 | extend_num=extend_nums[node.depth], 1434 | extend_batch_size=extend_batch_size, 1435 | ) 1436 | domain_tree.save_to_local() 1437 | for new_node in new_nodes: 1438 | if new_node.depth < max_depth: 1439 | queue.append(new_node) 1440 | domain_tree.save_to_local() 1441 | return 1442 | 1443 | 1444 | def test_enriching( 1445 | domain_tree: DomainTree, 1446 | enrich_nums: List = None, 1447 | enrich_batch_size: int = None, 1448 | ): 1449 | queue = [_ for _ in PreOrderIter(domain_tree.root)] 1450 | for node in queue: 1451 | logger.info(f"Processing {node.task_name}, depth: {node.depth}") 1452 | domain_tree.enrich_node_samples( 1453 | node, 1454 | enrich_num=enrich_nums[node.depth], 1455 | enrich_batch_size=enrich_batch_size, 1456 | ) 1457 | domain_tree.save_to_local() 1458 | return 1459 | 1460 | 1461 | def test_prune( 1462 | domain_tree: DomainTree, 1463 | prune_threshold: float, 1464 | num_cpus=1, 1465 | ): 1466 | queue = [_ for _ in LevelOrderIter(domain_tree.root)] 1467 | root = queue[0] 1468 | queue = queue[1:] 1469 | scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False) 1470 | all_subtasks = [root.raw_task_name] 1471 | all_subtask_tokens = [scorer._tokenizer.tokenize(subtask_name) for subtask_name in all_subtasks] 1472 | pruned_subtask_name = [] 1473 | for node in tqdm.tqdm(queue): 1474 | new_subtask_tokens = 
1419 | def test_extending(
1420 |         domain_tree: DomainTree,
1421 |         max_depth: int = 2,
1422 |         extend_nums: List = None,
1423 |         extend_batch_size: int = None
1424 | ):
1425 |     queue = [_ for _ in PreOrderIter(domain_tree.root)]
1426 |     while len(queue) > 0:
1427 |         node = queue.pop(0)
1428 |         logger.info(f"Processing {node.task_name}, depth: {node.depth}")
1429 |         if node.depth >= max_depth:
1430 |             continue
1431 |         new_nodes: List[AnyNode] = domain_tree.extend_node_children(
1432 |             node,
1433 |             extend_num=extend_nums[node.depth],
1434 |             extend_batch_size=extend_batch_size,
1435 |         )
1436 |         domain_tree.save_to_local()
1437 |         for new_node in new_nodes:
1438 |             if new_node.depth < max_depth:
1439 |                 queue.append(new_node)
1440 |     domain_tree.save_to_local()
1441 |     return
1442 |
1443 |
1444 | def test_enriching(
1445 |         domain_tree: DomainTree,
1446 |         enrich_nums: List = None,
1447 |         enrich_batch_size: int = None,
1448 | ):
1449 |     queue = [_ for _ in PreOrderIter(domain_tree.root)]
1450 |     for node in queue:
1451 |         logger.info(f"Processing {node.task_name}, depth: {node.depth}")
1452 |         domain_tree.enrich_node_samples(
1453 |             node,
1454 |             enrich_num=enrich_nums[node.depth],
1455 |             enrich_batch_size=enrich_batch_size,
1456 |         )
1457 |         domain_tree.save_to_local()
1458 |     return
1459 |
1460 |
1461 | def test_prune(
1462 |         domain_tree: DomainTree,
1463 |         prune_threshold: float,
1464 |         num_cpus=1,
1465 | ):
1466 |     queue = [_ for _ in LevelOrderIter(domain_tree.root)]
1467 |     root = queue[0]
1468 |     queue = queue[1:]
1469 |     scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)
1470 |     all_subtasks = [root.raw_task_name]
1471 |     all_subtask_tokens = [scorer._tokenizer.tokenize(subtask_name) for subtask_name in all_subtasks]
1472 |     pruned_subtask_name = []
1473 |     for node in tqdm.tqdm(queue):
1474 |         new_subtask_tokens = scorer._tokenizer.tokenize(node.raw_task_name)
1475 |         with Pool(num_cpus) as p:
1476 |             rouge_scores = p.map(
1477 |                 partial(rouge_scorer._score_lcs, new_subtask_tokens),
1478 |                 all_subtask_tokens,
1479 |             )
1480 |         rouge_scores = [score.fmeasure for score in rouge_scores]
1481 |         # most_similar_instructions = {
1482 |         #     all_subtasks[i]: rouge_scores[i] for i in np.argsort(rouge_scores)[-10:][::-1]
1483 |         # }
1484 |         if max(rouge_scores) > prune_threshold:  # pruning this subtask
1485 |             pruned_subtask_name.append(node.task_name)
1486 |             continue
1487 |         all_subtasks.append(node.raw_task_name)
1488 |         all_subtask_tokens.append(new_subtask_tokens)
1489 |     return pruned_subtask_name
1490 |
1491 |
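# Note (illustrative, with made-up subtask names): test_prune above keeps a running list
# of accepted subtask names and rejects any new name whose ROUGE-L F1 against an accepted
# one exceeds prune_threshold. For example, a candidate "formal email writing assistance"
# compared with an accepted "email writing assistance" shares a 3-token longest common
# subsequence over name lengths 4 and 3, so F1 = 2*3 / (4 + 3) ~= 0.86 > 0.7 and the
# candidate would be pruned; an unrelated name such as "poetry composition" scores near 0
# and would be kept.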
1492 | def test_filter(
1493 |         domain_tree: DomainTree,
1494 |         filter_threshold: float,
1495 |         num_cpus=64,
1496 | ):
1497 |     queue = [_ for _ in LevelOrderIter(domain_tree.root)]
1498 |     scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)
1499 |     logger.info("Start filtering.")
1500 |     for node in tqdm.tqdm(queue):
1501 |         keep = 0
1502 |         all_instruction_tokens = []
1503 |         all_examples = []
1504 |         for instruction_data_entry in tqdm.tqdm(node.examples):
1505 |             new_instruction_tokens = scorer._tokenizer.tokenize(instruction_data_entry["instruction"])
1506 |             with Pool(num_cpus) as p:
1507 |                 rouge_scores = p.map(
1508 |                     partial(rouge_scorer._score_lcs, new_instruction_tokens),
1509 |                     all_instruction_tokens,
1510 |                 )
1511 |             rouge_scores = [score.fmeasure for score in rouge_scores]
1512 |             # most_similar_instructions = {
1513 |             #     all_instructions[i]: rouge_scores[i] for i in np.argsort(rouge_scores)[-10:][::-1]
1514 |             # }
1515 |             if len(rouge_scores) != 0 and max(rouge_scores) > filter_threshold:
1516 |                 continue
1517 |             else:
1518 |                 keep += 1
1519 |                 all_examples.append(instruction_data_entry)
1520 |                 # all_instructions.append(instruction_data_entry["instruction"])
1521 |                 all_instruction_tokens.append(new_instruction_tokens)
1522 |         node.examples = all_examples
1523 |         logger.info(f"Subtask: {node.task_name}. Instances kept after filtering: {keep}")
1524 |         domain_tree.save_to_local()
1525 |         domain_tree.update_file(node)
1526 |     return
1527 |
1528 |
1529 | def test_sample(domain_tree: DomainTree,
1530 |                 export_file: str,
1531 |                 sample_example_num: int,
1532 |                 sample_max_depth: int,
1533 |                 sample_use_pruned: bool,
1534 |                 pruned_subtasks_name: List[str]
1535 |                 ):
1536 |     queue = [_ for _ in LevelOrderIter(domain_tree.root) if _.depth <= sample_max_depth]
1537 |     if sample_use_pruned:
1538 |         pruned_subtasks_name = [_.task_name for _ in queue if _.task_name in pruned_subtasks_name]
1539 |     else:
1540 |         pruned_subtasks_name = []
1541 |     logger.info(f"All subtasks: {len(queue)}. Pruned subtasks: {len(pruned_subtasks_name)}")
1542 |     # sample_example_per_node = sample_example_num // (len(queue) - len(pruned_subtasks_name)) + 1
1543 |     # logger.info(f"Sample examples per node: {sample_example_per_node}")
1544 |     data = []
1545 |     all_examples_num = 0
1546 |     for node in queue:
1547 |         all_examples_num += min(len(node.examples), 500)
1548 |     for node in queue:
1549 |         sample_example_per_node = min(int(sample_example_num * (min(len(node.examples), 500) / all_examples_num)) + 1,
1550 |                                       len(node.examples))
1551 |         logger.info(f"Sample examples num for {node.task_name}: {sample_example_per_node}")
1552 |         data += random.sample(node.examples, k=sample_example_per_node)
1553 |     data = data[:sample_example_num]
1554 |     if len(data) < sample_example_num:
1555 |         logger.info(f"Not enough examples; sampling {sample_example_num - len(data)} additional examples from the root node.")
1556 |         data += random.sample(queue[0].examples, k=sample_example_num - len(data))
1557 |     with open(export_file, "w") as fout:
1558 |         json.dump(data, fout, indent=4)
1559 |     logger.info(f"Total sampled examples: {len(data)}")
1560 |
1561 |
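# Note (illustrative, with made-up counts): test_sample above allocates the export budget
# proportionally, capping each subtask's contribution base at 500 examples. With
# sample_example_num = 300 and two subtasks holding 400 and 800 examples,
# all_examples_num = 400 + 500 = 900, the per-node quotas are int(300 * 400 / 900) + 1 = 134
# and int(300 * 500 / 900) + 1 = 167, and the concatenated list is truncated back to 300.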
1562 | def run_extend(args):
1563 |     save_dir = args.save_dir
1564 |     out_dir = args.out_dir
1565 |     assistant_name = args.assistant_name
1566 |     extend_nums = [int(_) for _ in args.extend_nums.split(",")]
1567 |     max_depth = args.max_depth
1568 |     extend_batch_size = args.extend_batch_size
1569 |     TreeFactory = args.tree_map[args.lang][args.domain]
1570 |     domain_tree: DomainTree = TreeFactory.from_local_dir(
1571 |         save_dir=save_dir,
1572 |         out_dir=out_dir,
1573 |         assistant_name=assistant_name,
1574 |     )
1575 |     test_extending(domain_tree,
1576 |                    max_depth=max_depth,
1577 |                    extend_nums=extend_nums,
1578 |                    extend_batch_size=extend_batch_size)
1579 |     return
1580 |
1581 |
1582 | def run_enrich(args):
1583 |     save_dir = args.save_dir
1584 |     out_dir = args.out_dir
1585 |     assistant_name = args.assistant_name
1586 |     enrich_nums = [int(_) for _ in args.enrich_nums.split(",")]
1587 |     enrich_batch_size = args.enrich_batch_size
1588 |     TreeFactory = args.tree_map[args.lang][args.domain]
1589 |     domain_tree: DomainTree = TreeFactory.from_local_dir(
1590 |         save_dir=save_dir,
1591 |         out_dir=out_dir,
1592 |         assistant_name=assistant_name,
1593 |     )
1594 |     test_enriching(
1595 |         domain_tree,
1596 |         enrich_nums=enrich_nums,
1597 |         enrich_batch_size=enrich_batch_size
1598 |     )
1599 |     return
1600 |
1601 |
1602 | def run_prune(args):
1603 |     assistant_name = args.assistant_name
1604 |     save_dir = args.save_dir
1605 |     out_dir = args.out_dir
1606 |     pruned_file = args.pruned_file
1607 |     TreeFactory = args.tree_map[args.lang][args.domain]
1608 |     domain_tree: DomainTree = TreeFactory.from_local_dir(save_dir=save_dir, out_dir=out_dir,
1609 |                                                          assistant_name=assistant_name)
1610 |     pruned_subtask_name = test_prune(domain_tree, args.prune_threshold)
1611 |     with open(pruned_file, "w") as fout:
1612 |         json.dump(pruned_subtask_name, fout)
1613 |     domain_tree.save_to_local()
1614 |
1615 |
1616 | def run_filter(args):
1617 |     assistant_name = args.assistant_name
1618 |     save_dir = args.save_dir
1619 |     out_dir = args.out_dir
1620 |     filter_threshold = args.filter_threshold
1621 |     TreeFactory = args.tree_map[args.lang][args.domain]
1622 |     domain_tree: DomainTree = TreeFactory.from_local_dir(save_dir=save_dir, out_dir=out_dir,
1623 |                                                          assistant_name=assistant_name)
1624 |     test_filter(domain_tree, filter_threshold)
1625 |     shutil.copy(args.pruned_file, os.path.join(out_dir, "pruned_subtasks_name.json"))
1626 |
1627 |
1628 | def run_sample(args):
1629 |     assistant_name = args.assistant_name
1630 |     save_dir = args.save_dir
1631 |     export_file = args.export_file
1632 |     sample_example_num = args.sample_example_num
1633 |     sample_max_depth = args.sample_max_depth
1634 |     sample_use_pruned = args.sample_use_pruned
1635 |     pruned_subtasks_name = json.loads(open(args.pruned_file, "r").read())
1636 |     TreeFactory = args.tree_map[args.lang][args.domain]
1637 |     domain_tree: DomainTree = TreeFactory.from_local_dir(save_dir=save_dir, out_dir=save_dir + "_tmp",
1638 |                                                          assistant_name=assistant_name)
1639 |     test_sample(domain_tree, export_file=export_file, sample_example_num=sample_example_num,
1640 |                 sample_max_depth=sample_max_depth, sample_use_pruned=sample_use_pruned,
1641 |                 pruned_subtasks_name=pruned_subtasks_name)
1642 |     os.rmdir(save_dir + "_tmp")
1643 |
1644 |
1645 | func_action_mapping = {
1646 |     "extend": run_extend,
1647 |     "enrich": run_enrich,
1648 |     "prune": run_prune,
1649 |     "filter": run_filter,
1650 |     "sample": run_sample,
1651 | }
1652 |
1653 |
1654 | def main():
1655 |     parser = argparse.ArgumentParser()
1656 |     parser.add_argument('--action', type=str, required=True, help='action to perform: extend, enrich, prune, filter, or sample')
1657 |     parser.add_argument('--save_dir', type=str, help='original tree dir')
1658 |     parser.add_argument('--out_dir', type=str, help='dir to save the tree after operations (recommended to use a new one instead of reusing save_dir)')
1659 |     parser.add_argument('--export_file', type=str, help='single json file to store all exported examples')
1660 |     parser.add_argument('--lang', type=str, default="en", help='either zh or en')
1661 |     parser.add_argument('--domain', type=str, default="rewrite", help='target domain')
1662 |     parser.add_argument('--extend_nums', type=str, help='number of children to add at each depth during extending (comma-separated)')
1663 |     parser.add_argument('--extend_batch_size', type=int, default=None, help='prompt batch size for extending')
1664 |     parser.add_argument('--max_depth', type=int, help='max depth for extending')
1665 |     parser.add_argument('--enrich_nums', type=str, help='number of examples to add at each depth during enriching (comma-separated)')
1666 |     parser.add_argument('--enrich_batch_size', type=int, default=None, help='prompt batch size for enriching')
1667 |     parser.add_argument('--prune_threshold', type=float, default=0.7, help='threshold for sub-task pruning')
1668 |     parser.add_argument('--pruned_file', type=str, default=None, help='file to store the pruned subtask name list')
1669 |     parser.add_argument('--filter_threshold', type=float, default=0.7, help='threshold for filtering examples')
1670 |     parser.add_argument('--sample_example_num', type=int, default=50000, help='number of examples to draw during sampling')
1671 |     parser.add_argument('--sample_max_depth', type=int, default=3, help='max depth for sampling')
1672 |     parser.add_argument('--sample_use_pruned', action="store_true", help='use the pruned subtask list during sampling')
1673 |     parser.add_argument('--assistant_name', type=str, help='assistant backend to use: either openai or claude')
1674 |     args = parser.parse_args()
1675 |     args.tree_map = {"en": {"rewrite": EnDomainTreeRewrite,
1676 |                             "brainstorming": EnDomainTreeBrainstorming,
1677 |                             "math": EnDomainTreeMath,
1678 |                             }}
1679 |     random.seed(42)
1680 |     print(args)
1681 |     # TODO: add further sanity checks
1682 |     func_action_mapping[args.action](args)
1683 |
1684 |
1685 | if __name__ == "__main__":
1686 |     main()
1687 |
--------------------------------------------------------------------------------
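Usage sketch (illustrative, not taken from the repository): the snippet below simply chains the --action values registered in func_action_mapping of generate_instruction.py via subprocess. Every path, count, and the choice of the openai backend is an assumed placeholder, and extend -> enrich -> prune -> filter -> sample is one plausible ordering suggested by the artifacts each action reads and writes (prune writes --pruned_file, filter copies it into its out_dir, and sample reads it).

import subprocess

# Assumes trees/rewrite already holds a seed task tree in the same layout as
# en_data/demo_domain, and that API credentials for the chosen backend are
# configured in the environment.
COMMON = ["python", "generate_instruction.py",
          "--lang", "en", "--domain", "rewrite", "--assistant_name", "openai"]


def run(action, *extra):
    # Each action name maps to one entry of func_action_mapping in generate_instruction.py.
    subprocess.run(COMMON + ["--action", action, *extra], check=True)


# 1. Grow the task tree: 10 children under the root, 5 under each depth-1 node.
run("extend", "--save_dir", "trees/rewrite", "--out_dir", "trees/rewrite_ext",
    "--extend_nums", "10,5", "--max_depth", "2")
# 2. Add instruction/input/output examples to every node, per depth.
run("enrich", "--save_dir", "trees/rewrite_ext", "--out_dir", "trees/rewrite_enriched",
    "--enrich_nums", "30,20,20")
# 3. Record near-duplicate subtask names (ROUGE-L above --prune_threshold).
run("prune", "--save_dir", "trees/rewrite_enriched", "--out_dir", "trees/rewrite_pruned",
    "--pruned_file", "trees/pruned_subtasks.json")
# 4. Drop near-duplicate examples within each subtask.
run("filter", "--save_dir", "trees/rewrite_pruned", "--out_dir", "trees/rewrite_filtered",
    "--pruned_file", "trees/pruned_subtasks.json")
# 5. Export a flat JSON training file.
run("sample", "--save_dir", "trees/rewrite_filtered",
    "--pruned_file", "trees/pruned_subtasks.json",
    "--export_file", "rewrite_data.json", "--sample_example_num", "10000",
    "--sample_max_depth", "3", "--sample_use_pruned")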