├── .gitignore ├── .pre-commit-config.yaml ├── .pylintrc ├── LICENSE ├── LbT_poster.pdf ├── LbT_poster.png ├── README.md ├── examples └── config │ ├── code │ ├── chatgpt35_exam.yaml │ ├── chatgpt35_trgen.yaml │ ├── chatgpt35_trgen_debug.yaml │ ├── llama-3-70b_exam.yaml │ ├── llama-3-70b_trgen.yaml │ ├── llama-3-70b_trgen_debug.yaml │ ├── llama-3-8b_exam.yaml │ ├── llama-3-8b_trgen.yaml │ └── llama-3-8b_trgen_debug.yaml │ └── math │ ├── chatgpt35_greedy.yaml │ ├── chatgpt35_trgen.yaml │ ├── chatgpt4o-mini_exam.yaml │ ├── chatgpt4o_greedy.yaml │ ├── chatgpt4o_trgen.yaml │ ├── llama-3-70b_greedy.yaml │ ├── llama-3-70b_trgen.yaml │ ├── llama-3-8b_exam.yaml │ ├── llama-3-8b_greedy.yaml │ ├── llama-3-8b_trgen.yaml │ ├── mistral-7b_exam.yaml │ ├── mistral-7b_greedy.yaml │ └── mistral-7b_trgen.yaml ├── lbt ├── __init__.py ├── base.py ├── datasets_adapter │ ├── __init__.py │ ├── apps_utils │ │ └── testing_util.py │ ├── code_dataset.py │ ├── leetcode_sub │ │ ├── environment.py │ │ ├── leetcode.py │ │ └── types.py │ ├── math_dataset.py │ └── utils │ │ ├── __init__.py │ │ ├── add_test_cases.py │ │ ├── clean_leetcode.py │ │ ├── extract_tests.yaml │ │ ├── fetch_leetcode.py │ │ ├── format_leetcode.py │ │ ├── transform_code.py │ │ ├── utils_leetcode.py │ │ └── utils_llm.py ├── exam_maker.py ├── exam_scorer.py ├── models │ ├── __init__.py │ └── base.py ├── patch │ └── lmops-lbtm3.patch ├── qa_item.py ├── test.py └── utils │ ├── __init__.py │ ├── log.py │ └── registry.py ├── pyproject.toml ├── scripts ├── code │ ├── prepare_datasets.py │ ├── prepare_teaching_datasets.py │ └── search_rationale.py ├── code_exam.py ├── exam.py └── math │ ├── prepare_datasets.py │ ├── prepare_teaching_dataset.py │ └── search_rationale.py └── tests ├── test_math_extraction.py └── test_ques_similarity_exam_maker.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store 2 | 3 | results/ 4 | gitignoreresults/ 5 | .vscode/ 6 | build/ 7 | bkp/ 8 | 9 | *.egg-info 10 | 11 | *__pycache__/ 12 | cache-*.arrow 13 | 14 | *.csv 15 | *.jsonl 16 | *.pkl 17 | *.ipynb 18 | **/.DS_Store 19 | examples/leetcode 20 | 21 | *.env 22 | examples/apps 23 | examples/apps_bkp 24 | examples/code_contests 25 | examples/code_contests_bkp 26 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 23.3.0 4 | hooks: 5 | - id: black 6 | args: [--preview] 7 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MAIN] 2 | ignore = tests 3 | ignore-patterns = .*egg-info 4 | py-version = 3.7.2 5 | 6 | [MESSAGES CONTROL] 7 | disable= 8 | too-many-locals, 9 | too-many-arguments, 10 | attribute-defined-outside-init, 11 | invalid-name, 12 | missing-docstring, 13 | protected-access, 14 | too-few-public-methods, 15 | format, 16 | wildcard-import, 17 | bad-mcs-classmethod-argument, 18 | import-error, 19 | 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2024] [Xuefei Ning, Zifu Wang, Shiyao Li, Zinan Lin, Peiran Yao, Tianyu Fu, Matthew B. 
Blaschko, Guohao Dai, Huazhong Yang, Yu Wang] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LbT_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imagination-research/lbt/c14bf6e5799a06052f24caa1d41275d326d423c9/LbT_poster.pdf -------------------------------------------------------------------------------- /LbT_poster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imagination-research/lbt/c14bf6e5799a06052f24caa1d41275d326d423c9/LbT_poster.png -------------------------------------------------------------------------------- /examples/config/code/chatgpt35_exam.yaml: -------------------------------------------------------------------------------- 1 | student_model_cfgs: 2 | - model_type: azure_openai 3 | model_cfg: 4 | model: gpt-35-turbo-0613 5 | api_version: 2024-02-15-preview 6 | sample_cfg: 7 | top_p: 0 8 | 9 | # can be overrided by per-model sample_cfg 10 | general_student_sample_cfg: 11 | temperature: 0 12 | 13 | exam_maker_type: code_metainfo 14 | exam_maker_cfg: 15 | num_exam_questions: 300 16 | 17 | exam_prompter_type: basic 18 | exam_prompter_cfg: 19 | demo_template: "[[Question]]:\nHere is an example question, please understand it very carefully:\n{question}\nFirst, let's think step by step to find a complete problem-solving strategy. Then, write a python code based on the problem-solving strategy.\n\n[[RATIONALE]]:\n{rationale}\n\n[[Final Code]]:\n${answer}$\n" 20 | exam_template: "{demo}\n\n[[Question]]:\nPlease first understand the problem-solving approach in rationale of the aforementioned example, and then follow the example to solve the following similar type of problem:\n{question}\nFirst, let's think step by step to find a complete problem-solving strategy. 
Then, write a python code based on the problem-solving strategy.\n\n[[RATIONALE]]:\n" 21 | use_multi_round_conv: false 22 | 23 | exam_scorer_type: code 24 | 25 | teaching_plans: every -------------------------------------------------------------------------------- /examples/config/code/chatgpt35_trgen.yaml: -------------------------------------------------------------------------------- 1 | student_model_cfgs: 2 | - model_type: azure_openai 3 | model_cfg: 4 | model: gpt-35-turbo-0613 5 | api_version: 2024-02-15-preview 6 | sample_cfg: 7 | top_p: 1.0 8 | 9 | # can be overrided by per-model sample_cfg 10 | general_student_sample_cfg: 11 | temperature: 1.0 12 | 13 | exam_maker_type: fixed 14 | exam_maker_cfg: 15 | selected_indexes: # range(0, 1) 16 | 17 | exam_prompter_type: basic 18 | exam_prompter_cfg: 19 | demo_template: "[[Question]]:\n{question}\nFirst, let's think step by step to find a complete problem-solving strategy. Then, write a python code based on the problem-solving strategy.\n\n[ROLESWITCHING assistant:][[RATIONALE]]:\n{rationale}\n\n[[Final Code]]:\n${answer}$\n" 20 | exam_template: "[[Question]]:\n{question}\nFirst, let's think step by step to find a complete problem-solving strategy. Then, write a python code based on the problem-solving strategy.\n\n[ROLESWITCHING assistant:][[RATIONALE]]:\n" 21 | use_multi_round_conv: true 22 | stub_teaching_items: 23 | - {"question": "Write a python function 'has_close_elements(numbers: List[float], threshold: float) -> bool:' to solve the following problem: Check if in given list of numbers, are any two numbers closer to each other than given threshold.", "answer": "```\nfrom typing import List\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n for idx, elem in enumerate(numbers):\n for idx2, elem2 in enumerate(numbers):\n if idx != idx2:\n distance = abs(elem - elem2)\n if distance < threshold:\n return True\n return False\n```", "rationale": "You can use a brute-force approach to solve this problem. Let's think step by step:\n\n1. Double Loop Iteration: The code uses nested loops to compare each pair of elements in the list. The outer loop iterates through each element (elem) in the list, and the inner loop also iterates through each element (elem2) in the list. This ensures that every possible pair of elements is considered for comparison.\n\n2. Comparison and Distance Calculation: Within the nested loops, the code checks if the indices of the two elements being compared (idx and idx2) are not the same, ensuring that the code doesn't compare an element with itself. It then calculates the absolute difference (distance) between the two elements using the abs() function.\n\n3. Threshold Check: After calculating the distance between two elements, the code checks if this distance is less than the given threshold. If the distance is below the threshold, it means that the two elements are closer to each other than allowed by the threshold, and the function returns True immediately.\n\n4. Return False if No Close Elements Found: If the nested loops complete without finding any pair of elements that satisfy the condition of being closer than the threshold, the function returns False. 
This indicates that no such pair exists in the given list.\n"} 24 | 25 | exam_scorer_type: code 26 | 27 | teaching_plans: no demo -------------------------------------------------------------------------------- /examples/config/code/chatgpt35_trgen_debug.yaml: -------------------------------------------------------------------------------- 1 | student_model_cfgs: 2 | - model_type: azure_openai 3 | model_cfg: 4 | model: gpt-35-turbo-0613 5 | api_version: 2024-02-15-preview 6 | sample_cfg: 7 | top_p: 0 8 | 9 | # can be overrided by per-model sample_cfg 10 | general_student_sample_cfg: 11 | temperature: 0 12 | 13 | exam_maker_type: code_metainfo 14 | exam_maker_cfg: 15 | num_exam_questions: 300 16 | 17 | exam_prompter_type: basic 18 | exam_prompter_cfg: 19 | demo_template: "[[Question]]:\n\nHere is an example question, please understand it very carefully:\n\n{question}\n\nFirst, let''s think step by step to find a complete problem-solving strategy.\nThen, write a python code based on the problem-solving strategy.\n\n\n[[RATIONALE]]:\n\n{rationale}\n\n\n[[Final Code]]:\n\n${answer}$\n" 20 | exam_template: "{demo}\n\n[[Question]]:\n\nPlease first understand the problem-solving approach in rationale of the aforementioned\nexample, and then follow the example to solve the following similar type of problem:\n\n{question}\n\nFirst, let''s think step by step to find a complete problem-solving strategy.\nThen, write a python code based on the problem-solving strategy.\n\n\n[[RATIONALE]]:\n" 21 | debug_template: "[[Question]]:\n\n{question}\n\n[[RATIONALE]]:\n\n{rationale}\n\n[[Final Code]]:\n\n{answer}\n\nYou need to debug this code with the following rules:\n(1) If you think the provided code is correct, you must retrieve the original correct code.\n(2) If you think the provided code is incorrect, you debug the code and write the final bug-free code.\n(3) If there is no complete code, you must write a complete code based on the rationale.\n\nLet's think step by step and remember you **must** give me a complete python code finally.\n\n[ROLESWITCHING assistant:]\n" 22 | use_multi_round_conv: false 23 | 24 | exam_scorer_type: code 25 | 26 | teaching_plans: every 27 | 28 | debug: true -------------------------------------------------------------------------------- /examples/config/code/llama-3-70b_exam.yaml: -------------------------------------------------------------------------------- 1 | student_model_cfgs: 2 | - model_type: huggingface 3 | model_cfg: 4 | path: meta-llama/Meta-Llama-3-8B-Instruct 5 | name: llama-3-70b 6 | sample_cfg: 7 | num_return_sequences: 1 8 | 9 | # can be overrided by per-model sample_cfg 10 | general_student_sample_cfg: 11 | batch_size: 1 12 | num_return_sequences: 1 13 | do_sample: false 14 | temperature: 0.0 15 | eos_token_id: [128001, 128009] 16 | pad_token_id: 128001 17 | 18 | exam_maker_type: code_metainfo 19 | exam_maker_cfg: 20 | num_exam_questions: 300 21 | 22 | exam_prompter_type: basic 23 | exam_prompter_cfg: 24 | demo_template: "[[Question]]:\nHere is an example question, please understand it very carefully:\n{question}\nFirst, let's think step by step to find a complete problem-solving strategy. 
Then, write a python code based on the problem-solving strategy.\n\n[[RATIONALE]]:\n{rationale}\n\n[[Final Code]]:\n${answer}$\n" 25 | exam_template: "{demo}\n\n[[Question]]:\nPlease first understand the problem-solving approach in rationale of the aforementioned example, and then follow the example to solve the following similar type of problem:\n{question}\nFirst, let's think step by step to find a complete problem-solving strategy. Then, write a python code based on the problem-solving strategy.\n\n[[RATIONALE]]:\n" 26 | use_multi_round_conv: false 27 | 28 | exam_scorer_type: code 29 | 30 | teaching_plans: every -------------------------------------------------------------------------------- /examples/config/code/llama-3-70b_trgen.yaml: -------------------------------------------------------------------------------- 1 | student_model_cfgs: 2 | - model_type: huggingface 3 | model_cfg: 4 | path: meta-llama/Meta-Llama-3-70B-Instruct 5 | name: llama-3-70b 6 | sample_cfg: 7 | num_return_sequences: 1 8 | top_p: 0.9 9 | 10 | # can be overrided by per-model sample_cfg 11 | general_student_sample_cfg: 12 | batch_size: 1 13 | num_return_sequences: 1 14 | do_sample: true 15 | temperature: 0.6 16 | top_p: 0.9 17 | eos_token_id: [128001, 128009] 18 | pad_token_id: 128001 19 | 20 | 21 | exam_maker_type: fixed 22 | exam_maker_cfg: 23 | selected_indexes: # range(0, 20) 24 | 25 | exam_prompter_type: basic 26 | exam_prompter_cfg: 27 | demo_template: "[[Question]]:\n{question}\nFirst, let's think step by step to find a complete problem-solving strategy. Then, write a python code based on the problem-solving strategy.\n\n[ROLESWITCHING assistant:][[RATIONALE]]:\n{rationale}\n\n[[Final Code]]:\n${answer}$\n" 28 | exam_template: "[[Question]]:\n{question}\nFirst, let's think step by step to find a complete problem-solving strategy. Then, write a python code based on the problem-solving strategy.\n\n[ROLESWITCHING assistant:][[RATIONALE]]:\n" 29 | use_multi_round_conv: true 30 | stub_teaching_items: 31 | - {"question": "Write a python function 'has_close_elements(numbers: List[float], threshold: float) -> bool:' to solve the following problem: Check if in given list of numbers, are any two numbers closer to each other than given threshold.", "answer": "from typing import List\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n for idx, elem in enumerate(numbers):\n for idx2, elem2 in enumerate(numbers):\n if idx != idx2:\n distance = abs(elem - elem2)\n if distance < threshold:\n return True\n return False", "rationale": "You can use a brute-force approach to solve this problem. Let's think step by step:\n\n1. Double Loop Iteration: The code uses nested loops to compare each pair of elements in the list. The outer loop iterates through each element (elem) in the list, and the inner loop also iterates through each element (elem2) in the list. This ensures that every possible pair of elements is considered for comparison.\n\n2. Comparison and Distance Calculation: Within the nested loops, the code checks if the indices of the two elements being compared (idx and idx2) are not the same, ensuring that the code doesn't compare an element with itself. It then calculates the absolute difference (distance) between the two elements using the abs() function.\n\n3. Threshold Check: After calculating the distance between two elements, the code checks if this distance is less than the given threshold. 
If the distance is below the threshold, it means that the two elements are closer to each other than allowed by the threshold, and the function returns True immediately.\n\n4. Return False if No Close Elements Found: If the nested loops complete without finding any pair of elements that satisfy the condition of being closer than the threshold, the function returns False. This indicates that no such pair exists in the given list.\n"} 32 | 33 | exam_scorer_type: code 34 | 35 | teaching_plans: no demo -------------------------------------------------------------------------------- /examples/config/code/llama-3-70b_trgen_debug.yaml: -------------------------------------------------------------------------------- 1 | student_model_cfgs: 2 | - model_type: huggingface 3 | model_cfg: 4 | path: meta-llama/Meta-Llama-3-70B-Instruct 5 | name: llama-3-70b 6 | sample_cfg: 7 | num_return_sequences: 1 8 | 9 | # can be overrided by per-model sample_cfg 10 | general_student_sample_cfg: 11 | batch_size: 1 12 | num_return_sequences: 1 13 | do_sample: false 14 | temperature: 0.0 15 | eos_token_id: [128001, 128009] 16 | pad_token_id: 128001 17 | 18 | exam_maker_type: code_metainfo 19 | exam_maker_cfg: 20 | num_exam_questions: 300 21 | 22 | exam_prompter_type: basic 23 | exam_prompter_cfg: 24 | demo_template: "[[Question]]:\n\nHere is an example question, please understand it very carefully:\n\n{question}\n\nFirst, let''s think step by step to find a complete problem-solving strategy.\nThen, write a python code based on the problem-solving strategy.\n\n\n[[RATIONALE]]:\n\n{rationale}\n\n\n[[Final Code]]:\n\n${answer}$\n" 25 | exam_template: "{demo}\n\n[[Question]]:\n\nPlease first understand the problem-solving approach in rationale of the aforementioned\nexample, and then follow the example to solve the following similar type of problem:\n\n{question}\n\nFirst, let''s think step by step to find a complete problem-solving strategy.\nThen, write a python code based on the problem-solving strategy.\n\n\n[[RATIONALE]]:\n" 26 | debug_template: "[[Question]]:\n\n{question}\n\n[[RATIONALE]]:\n\n{rationale}\n\n[[Final Code]]:\n\n{answer}\n\nYou need to debug this code with the following rules:\n(1) If you think the provided code is correct, you must retrieve the original correct code.\n(2) If you think the provided code is incorrect, you debug the code and write the final bug-free code.\n(3) If there is no complete code, you must write a complete code based on the rationale.\n\nLet's think step by step and remember you **must** give me a complete python code finally.\n\n[ROLESWITCHING assistant:]\n" 27 | use_multi_round_conv: false 28 | 29 | exam_scorer_type: code 30 | 31 | teaching_plans: every 32 | 33 | debug: true -------------------------------------------------------------------------------- /examples/config/code/llama-3-8b_exam.yaml: -------------------------------------------------------------------------------- 1 | student_model_cfgs: 2 | - model_type: huggingface 3 | model_cfg: 4 | path: meta-llama/Meta-Llama-3-8B-Instruct 5 | name: llama-3-8b 6 | sample_cfg: 7 | num_return_sequences: 1 8 | 9 | # can be overrided by per-model sample_cfg 10 | general_student_sample_cfg: 11 | batch_size: 1 12 | num_return_sequences: 1 13 | do_sample: false 14 | temperature: 0.0 15 | eos_token_id: [128001, 128009] 16 | pad_token_id: 128001 17 | 18 | exam_maker_type: code_metainfo 19 | exam_maker_cfg: 20 | num_exam_questions: 300 21 | 22 | exam_prompter_type: basic 23 | exam_prompter_cfg: 24 | demo_template: "[[Question]]:\nHere 
is an example question, please understand it very carefully:\n{question}\nFirst, let's think step by step to find a complete problem-solving strategy. Then, write a python code based on the problem-solving strategy.\n\n[[RATIONALE]]:\n{rationale}\n\n[[Final Code]]:\n${answer}$\n" 25 | exam_template: "{demo}\n\n[[Question]]:\nPlease first understand the problem-solving approach in rationale of the aforementioned example, and then follow the example to solve the following similar type of problem:\n{question}\nFirst, let's think step by step to find a complete problem-solving strategy. Then, write a python code based on the problem-solving strategy.\n\n[[RATIONALE]]:\n" 26 | use_multi_round_conv: false 27 | 28 | exam_scorer_type: code 29 | 30 | teaching_plans: every -------------------------------------------------------------------------------- /examples/config/code/llama-3-8b_trgen.yaml: -------------------------------------------------------------------------------- 1 | student_model_cfgs: 2 | - model_type: huggingface 3 | model_cfg: 4 | path: meta-llama/Meta-Llama-3-8B-Instruct 5 | name: llama-3-8b 6 | sample_cfg: 7 | num_return_sequences: 1 8 | top_p: 0.9 9 | 10 | # can be overrided by per-model sample_cfg 11 | general_student_sample_cfg: 12 | batch_size: 1 13 | num_return_sequences: 1 14 | do_sample: true 15 | temperature: 0.6 16 | top_p: 0.9 17 | eos_token_id: [128001, 128009] 18 | pad_token_id: 128001 19 | 20 | 21 | exam_maker_type: fixed 22 | exam_maker_cfg: 23 | selected_indexes: # range(0, 20) 24 | 25 | exam_prompter_type: basic 26 | exam_prompter_cfg: 27 | demo_template: "[[Question]]:\n{question}\nFirst, let's think step by step to find a complete problem-solving strategy. Then, write a python code based on the problem-solving strategy.\n\n[ROLESWITCHING assistant:][[RATIONALE]]:\n{rationale}\n\n[[Final Code]]:\n${answer}$\n" 28 | exam_template: "[[Question]]:\n{question}\nFirst, let's think step by step to find a complete problem-solving strategy. Then, write a python code based on the problem-solving strategy.\n\n[ROLESWITCHING assistant:][[RATIONALE]]:\n" 29 | use_multi_round_conv: true 30 | stub_teaching_items: 31 | - {"question": "Write a python function 'has_close_elements(numbers: List[float], threshold: float) -> bool:' to solve the following problem: Check if in given list of numbers, are any two numbers closer to each other than given threshold.", "answer": "from typing import List\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n for idx, elem in enumerate(numbers):\n for idx2, elem2 in enumerate(numbers):\n if idx != idx2:\n distance = abs(elem - elem2)\n if distance < threshold:\n return True\n return False", "rationale": "You can use a brute-force approach to solve this problem. Let's think step by step:\n\n1. Double Loop Iteration: The code uses nested loops to compare each pair of elements in the list. The outer loop iterates through each element (elem) in the list, and the inner loop also iterates through each element (elem2) in the list. This ensures that every possible pair of elements is considered for comparison.\n\n2. Comparison and Distance Calculation: Within the nested loops, the code checks if the indices of the two elements being compared (idx and idx2) are not the same, ensuring that the code doesn't compare an element with itself. It then calculates the absolute difference (distance) between the two elements using the abs() function.\n\n3. 
Threshold Check: After calculating the distance between two elements, the code checks if this distance is less than the given threshold. If the distance is below the threshold, it means that the two elements are closer to each other than allowed by the threshold, and the function returns True immediately.\n\n4. Return False if No Close Elements Found: If the nested loops complete without finding any pair of elements that satisfy the condition of being closer than the threshold, the function returns False. This indicates that no such pair exists in the given list.\n"} 32 | 33 | exam_scorer_type: code 34 | 35 | teaching_plans: no demo -------------------------------------------------------------------------------- /examples/config/code/llama-3-8b_trgen_debug.yaml: -------------------------------------------------------------------------------- 1 | student_model_cfgs: 2 | - model_type: huggingface 3 | model_cfg: 4 | path: meta-llama/Meta-Llama-3-8B-Instruct 5 | name: llama-3-8b 6 | sample_cfg: 7 | num_return_sequences: 1 8 | 9 | # can be overrided by per-model sample_cfg 10 | general_student_sample_cfg: 11 | batch_size: 1 12 | num_return_sequences: 1 13 | do_sample: false 14 | temperature: 0.0 15 | eos_token_id: [128001, 128009] 16 | pad_token_id: 128001 17 | 18 | exam_maker_type: code_metainfo 19 | exam_maker_cfg: 20 | num_exam_questions: 300 21 | 22 | exam_prompter_type: basic 23 | exam_prompter_cfg: 24 | demo_template: "[[Question]]:\n\nHere is an example question, please understand it very carefully:\n\n{question}\n\nFirst, let''s think step by step to find a complete problem-solving strategy.\nThen, write a python code based on the problem-solving strategy.\n\n\n[[RATIONALE]]:\n\n{rationale}\n\n\n[[Final Code]]:\n\n${answer}$\n" 25 | exam_template: "{demo}\n\n[[Question]]:\n\nPlease first understand the problem-solving approach in rationale of the aforementioned\nexample, and then follow the example to solve the following similar type of problem:\n\n{question}\n\nFirst, let''s think step by step to find a complete problem-solving strategy.\nThen, write a python code based on the problem-solving strategy.\n\n\n[[RATIONALE]]:\n" 26 | debug_template: "[[Question]]:\n\n{question}\n\n[[RATIONALE]]:\n\n{rationale}\n\n[[Final Code]]:\n\n{answer}\n\nYou need to debug this code with the following rules:\n(1) If you think the provided code is correct, you must retrieve the original correct code.\n(2) If you think the provided code is incorrect, you debug the code and write the final bug-free code.\n(3) If there is no complete code, you must write a complete code based on the rationale.\n\nLet's think step by step and remember you **must** give me a complete python code finally.\n\n[ROLESWITCHING assistant:]\n" 27 | use_multi_round_conv: false 28 | 29 | exam_scorer_type: code 30 | 31 | teaching_plans: every 32 | 33 | debug: true -------------------------------------------------------------------------------- /examples/config/math/chatgpt35_greedy.yaml: -------------------------------------------------------------------------------- 1 | teaching_plans: no demo 2 | 3 | student_model_cfgs: 4 | - model_type: azure_openai 5 | model_cfg: 6 | model: gpt-35-turbo-0125 7 | name: chatgpt35 8 | api_key: null 9 | api_endpoint: null 10 | 11 | general_student_sample_cfg: 12 | temperature: 0 13 | 14 | exam_maker_type: fixed 15 | 16 | exam_prompter_type: basic 17 | exam_prompter_cfg: 18 | instruction: "Your task is to answer the last question below. Give step by step reasoning before you answer. 
When you're ready to answer, please wrap your answer and conclude using the format\n'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n" 19 | demo_template: "[[Question]]:\n{question}\n\n[[Solution]]:\nLet's think step by step.\n\n{rationale}\n\n[[Final Answer]]:\n${answer}$\n" 20 | exam_template: "{demo}\n\n\n[[Question]]:\n{question}\n\n[ROLESWITCHING assistant:][[Solution]]:\nLet's think step by step.\n\n" 21 | use_multi_round_conv: false 22 | stub_teaching_items: 23 | - {"question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.", "rationale": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\n\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\n\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", "answer": "[2,5)"} 24 | - {"question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12$, then find $\\det (\\mathbf{A} \\mathbf{B})$.", "rationale": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}$.", "answer": "24"} 25 | - {"question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", "rationale": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.\n\nIf he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \\Rightarrow\\qquad n&=480/30=\\boxed{16} \\end{align*}", "answer": "16"} 26 | - {"question": "If the system of equations \\begin{align*} 6x-4y&=a, \\\\ 6y-9x &=b. \\end{align*} has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b}$, assuming $b$ is nonzero.", "rationale": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a$$. Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}$$.", "answer": "-\\frac{2}{3}"} 27 | 28 | exam_scorer_type: math 29 | exam_scorer_cfg: 30 | recall_mode: false 31 | -------------------------------------------------------------------------------- /examples/config/math/chatgpt35_trgen.yaml: -------------------------------------------------------------------------------- 1 | teaching_plans: no demo 2 | 3 | student_model_cfgs: 4 | - model_type: azure_openai 5 | model_cfg: 6 | model: gpt-35-turbo-0125 7 | name: chatgpt35 8 | api_key: null 9 | api_endpoint: null 10 | 11 | general_student_sample_cfg: 12 | temperature: 0.7 13 | 14 | exam_maker_type: fixed 15 | 16 | exam_prompter_type: basic 17 | exam_prompter_cfg: 18 | instruction: "Your task is to answer the last question below. Give step by step reasoning before you answer. 
When you're ready to answer, please wrap your answer and conclude using the format\n'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n" 19 | demo_template: "[[Question]]:\n{question}\n\n[[Solution]]:\nLet's think step by step.\n\n{rationale}\n\n[[Final Answer]]:\n${answer}$\n" 20 | exam_template: "{demo}\n\n\n[[Question]]:\n{question}\n\n[ROLESWITCHING assistant:][[Solution]]:\nLet's think step by step.\n\n" 21 | use_multi_round_conv: false 22 | stub_teaching_items: 23 | - {"question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.", "rationale": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\n\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\n\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", "answer": "[2,5)"} 24 | - {"question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12$, then find $\\det (\\mathbf{A} \\mathbf{B})$.", "rationale": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}$.", "answer": "24"} 25 | - {"question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", "rationale": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.\n\nIf he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \\Rightarrow\\qquad n&=480/30=\\boxed{16} \\end{align*}", "answer": "16"} 26 | - {"question": "If the system of equations \\begin{align*} 6x-4y&=a, \\\\ 6y-9x &=b. \\end{align*} has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b}$, assuming $b$ is nonzero.", "rationale": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a$$. Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}$$.", "answer": "-\\frac{2}{3}"} 27 | 28 | exam_scorer_type: math 29 | exam_scorer_cfg: 30 | recall_mode: false 31 | -------------------------------------------------------------------------------- /examples/config/math/chatgpt4o-mini_exam.yaml: -------------------------------------------------------------------------------- 1 | student_model_cfgs: 2 | - model_type: azure_openai 3 | model_cfg: 4 | model: gpt-4o-mini-2024-07-18 5 | name: chatgpt4o-mini 6 | api_key: null 7 | api_endpoint: null 8 | 9 | general_student_sample_cfg: 10 | temperature: 0.7 11 | 12 | exam_maker_type: func 13 | exam_maker_cfg: 14 | num_exam_questions: 3 15 | num_repetitions: 3 16 | 17 | exam_prompter_type: basic 18 | exam_prompter_cfg: 19 | instruction: "Your task is to answer the last question below. Give step by step reasoning before you answer. 
When you're ready to answer, please wrap your answer and conclude using the format\n'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n" 20 | demo_template: "[[Question]]:\n{question}\n\n[[Solution]]:\nLet's think step by step.\n\n{rationale}\n\n[[Final Answer]]:\n${answer}$\n" 21 | exam_template: "{demo}\n\n\n[[Question]]:\n{question}\n\n[ROLESWITCHING assistant:][[Solution]]:\nLet's think step by step.\n\n" 22 | use_multi_round_conv: false 23 | stub_teaching_items: 24 | - {"question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.", "rationale": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\n\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\n\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", "answer": "[2,5)"} 25 | - {"question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12$, then find $\\det (\\mathbf{A} \\mathbf{B})$.", "rationale": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}$.", "answer": "24"} 26 | - {"question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", "rationale": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.\n\nIf he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \\Rightarrow\\qquad n&=480/30=\\boxed{16} \\end{align*}", "answer": "16"} 27 | - {"question": "If the system of equations \\begin{align*} 6x-4y&=a, \\\\ 6y-9x &=b. \\end{align*} has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b}$, assuming $b$ is nonzero.", "rationale": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a$$. Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}$$.", "answer": "-\\frac{2}{3}"} 28 | 29 | exam_scorer_type: math 30 | exam_scorer_cfg: 31 | recall_mode: false 32 | -------------------------------------------------------------------------------- /examples/config/math/chatgpt4o_greedy.yaml: -------------------------------------------------------------------------------- 1 | teaching_plans: no demo 2 | 3 | student_model_cfgs: 4 | - model_type: azure_openai 5 | model_cfg: 6 | model: gpt-4o-2024-08-06 7 | name: chatgpt4o 8 | api_key: null 9 | api_endpoint: null 10 | 11 | general_student_sample_cfg: 12 | temperature: 0 13 | 14 | exam_maker_type: fixed 15 | 16 | exam_prompter_type: basic 17 | exam_prompter_cfg: 18 | instruction: "Your task is to answer the last question below. Give step by step reasoning before you answer. 
When you're ready to answer, please wrap your answer and conclude using the format\n'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n" 19 | demo_template: "[[Question]]:\n{question}\n\n[[Solution]]:\nLet's think step by step.\n\n{rationale}\n\n[[Final Answer]]:\n${answer}$\n" 20 | exam_template: "{demo}\n\n\n[[Question]]:\n{question}\n\n[ROLESWITCHING assistant:][[Solution]]:\nLet's think step by step.\n\n" 21 | use_multi_round_conv: false 22 | stub_teaching_items: 23 | - {"question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.", "rationale": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\n\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\n\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", "answer": "[2,5)"} 24 | - {"question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12$, then find $\\det (\\mathbf{A} \\mathbf{B})$.", "rationale": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}$.", "answer": "24"} 25 | - {"question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", "rationale": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.\n\nIf he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \\Rightarrow\\qquad n&=480/30=\\boxed{16} \\end{align*}", "answer": "16"} 26 | - {"question": "If the system of equations \\begin{align*} 6x-4y&=a, \\\\ 6y-9x &=b. \\end{align*} has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b}$, assuming $b$ is nonzero.", "rationale": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a$$. Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}$$.", "answer": "-\\frac{2}{3}"} 27 | 28 | exam_scorer_type: math 29 | exam_scorer_cfg: 30 | recall_mode: false 31 | -------------------------------------------------------------------------------- /examples/config/math/chatgpt4o_trgen.yaml: -------------------------------------------------------------------------------- 1 | teaching_plans: no demo 2 | 3 | student_model_cfgs: 4 | - model_type: azure_openai 5 | model_cfg: 6 | model: gpt-4o-2024-08-06 7 | name: chatgpt4o 8 | api_key: null 9 | api_endpoint: null 10 | 11 | general_student_sample_cfg: 12 | temperature: 0.7 13 | 14 | exam_maker_type: fixed 15 | 16 | exam_prompter_type: basic 17 | exam_prompter_cfg: 18 | instruction: "Your task is to answer the last question below. Give step by step reasoning before you answer. 
When you're ready to answer, please wrap your answer and conclude using the format\n'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n" 19 | demo_template: "[[Question]]:\n{question}\n\n[[Solution]]:\nLet's think step by step.\n\n{rationale}\n\n[[Final Answer]]:\n${answer}$\n" 20 | exam_template: "{demo}\n\n\n[[Question]]:\n{question}\n\n[ROLESWITCHING assistant:][[Solution]]:\nLet's think step by step.\n\n" 21 | use_multi_round_conv: false 22 | stub_teaching_items: 23 | - {"question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.", "rationale": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\n\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\n\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", "answer": "[2,5)"} 24 | - {"question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12$, then find $\\det (\\mathbf{A} \\mathbf{B})$.", "rationale": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}$.", "answer": "24"} 25 | - {"question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", "rationale": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.\n\nIf he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \\Rightarrow\\qquad n&=480/30=\\boxed{16} \\end{align*}", "answer": "16"} 26 | - {"question": "If the system of equations \\begin{align*} 6x-4y&=a, \\\\ 6y-9x &=b. \\end{align*} has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b}$, assuming $b$ is nonzero.", "rationale": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a$$. Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}$$.", "answer": "-\\frac{2}{3}"} 27 | 28 | exam_scorer_type: math 29 | exam_scorer_cfg: 30 | recall_mode: false 31 | -------------------------------------------------------------------------------- /examples/config/math/llama-3-70b_greedy.yaml: -------------------------------------------------------------------------------- 1 | teaching_plans: no demo 2 | 3 | student_model_cfgs: 4 | - model_type: huggingface 5 | model_cfg: 6 | path: meta-llama/Meta-Llama-3-70B-Instruct 7 | name: llama-3-70b 8 | 9 | general_student_sample_cfg: 10 | batch_size: 8 11 | do_sample: false 12 | num_return_sequences: 1 13 | temperature: 0 14 | 15 | exam_maker_type: fixed 16 | 17 | exam_prompter_type: basic 18 | exam_prompter_cfg: 19 | instruction: "Your task is to answer the last question below. Give step by step reasoning before you answer. 
When you're ready to answer, please wrap your answer and conclude using the format\n'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n" 20 | demo_template: "[[Question]]:\n{question}\n\n[[Solution]]:\nLet's think step by step.\n\n{rationale}\n\n[[Final Answer]]:\n${answer}$\n" 21 | exam_template: "{demo}\n\n\n[[Question]]:\n{question}\n\n[ROLESWITCHING assistant:][[Solution]]:\nLet's think step by step.\n\n" 22 | use_multi_round_conv: false 23 | stub_teaching_items: 24 | - {"question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.", "rationale": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\n\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\n\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", "answer": "[2,5)"} 25 | - {"question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12$, then find $\\det (\\mathbf{A} \\mathbf{B})$.", "rationale": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}$.", "answer": "24"} 26 | - {"question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", "rationale": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.\n\nIf he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \\Rightarrow\\qquad n&=480/30=\\boxed{16} \\end{align*}", "answer": "16"} 27 | - {"question": "If the system of equations \\begin{align*} 6x-4y&=a, \\\\ 6y-9x &=b. \\end{align*} has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b}$, assuming $b$ is nonzero.", "rationale": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a$$. Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}$$.", "answer": "-\\frac{2}{3}"} 28 | 29 | exam_scorer_type: math 30 | exam_scorer_cfg: 31 | recall_mode: false 32 | -------------------------------------------------------------------------------- /examples/config/math/llama-3-70b_trgen.yaml: -------------------------------------------------------------------------------- 1 | teaching_plans: no demo 2 | 3 | student_model_cfgs: 4 | - model_type: huggingface 5 | model_cfg: 6 | path: meta-llama/Meta-Llama-3-70B-Instruct 7 | name: llama-3-70b 8 | 9 | general_student_sample_cfg: 10 | batch_size: 8 11 | do_sample: true 12 | num_return_sequences: 1 13 | top_k: 20 14 | top_p: 1.0 15 | temperature: 0.7 16 | 17 | exam_maker_type: fixed 18 | 19 | exam_prompter_type: basic 20 | exam_prompter_cfg: 21 | instruction: "Your task is to answer the last question below. Give step by step reasoning before you answer. 
When you're ready to answer, please wrap your answer and conclude using the format\n'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n" 22 | demo_template: "[[Question]]:\n{question}\n\n[[Solution]]:\nLet's think step by step.\n\n{rationale}\n\n[[Final Answer]]:\n${answer}$\n" 23 | exam_template: "{demo}\n\n\n[[Question]]:\n{question}\n\n[ROLESWITCHING assistant:][[Solution]]:\nLet's think step by step.\n\n" 24 | use_multi_round_conv: false 25 | stub_teaching_items: 26 | - {"question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.", "rationale": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\n\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\n\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", "answer": "[2,5)"} 27 | - {"question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12$, then find $\\det (\\mathbf{A} \\mathbf{B})$.", "rationale": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}$.", "answer": "24"} 28 | - {"question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", "rationale": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.\n\nIf he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \\Rightarrow\\qquad n&=480/30=\\boxed{16} \\end{align*}", "answer": "16"} 29 | - {"question": "If the system of equations \\begin{align*} 6x-4y&=a, \\\\ 6y-9x &=b. \\end{align*} has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b}$, assuming $b$ is nonzero.", "rationale": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a$$. Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}$$.", "answer": "-\\frac{2}{3}"} 30 | 31 | exam_scorer_type: math 32 | exam_scorer_cfg: 33 | recall_mode: false 34 | -------------------------------------------------------------------------------- /examples/config/math/llama-3-8b_exam.yaml: -------------------------------------------------------------------------------- 1 | student_model_cfgs: 2 | - model_type: huggingface 3 | model_cfg: 4 | path: meta-llama/Meta-Llama-3-8B-Instruct 5 | name: llama-3-8b 6 | 7 | general_student_sample_cfg: 8 | batch_size: 9 9 | do_sample: true 10 | num_return_sequences: 1 11 | top_k: 20 12 | top_p: 1.0 13 | temperature: 0.7 14 | 15 | exam_maker_type: func 16 | exam_maker_cfg: 17 | num_exam_questions: 3 18 | num_repetitions: 3 19 | 20 | exam_prompter_type: basic 21 | exam_prompter_cfg: 22 | instruction: "Your task is to answer the last question below. Give step by step reasoning before you answer. 
When you're ready to answer, please wrap your answer and conclude using the format\n'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n" 23 | demo_template: "[[Question]]:\n{question}\n\n[[Solution]]:\nLet's think step by step.\n\n{rationale}\n\n[[Final Answer]]:\n${answer}$\n" 24 | exam_template: "{demo}\n\n\n[[Question]]:\n{question}\n\n[ROLESWITCHING assistant:][[Solution]]:\nLet's think step by step.\n\n" 25 | use_multi_round_conv: false 26 | stub_teaching_items: 27 | - {"question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.", "rationale": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\n\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\n\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", "answer": "[2,5)"} 28 | - {"question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12$, then find $\\det (\\mathbf{A} \\mathbf{B})$.", "rationale": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}$.", "answer": "24"} 29 | - {"question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", "rationale": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.\n\nIf he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \\Rightarrow\\qquad n&=480/30=\\boxed{16} \\end{align*}", "answer": "16"} 30 | - {"question": "If the system of equations \\begin{align*} 6x-4y&=a, \\\\ 6y-9x &=b. \\end{align*} has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b}$, assuming $b$ is nonzero.", "rationale": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a$$. Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}$$.", "answer": "-\\frac{2}{3}"} 31 | 32 | exam_scorer_type: math 33 | exam_scorer_cfg: 34 | recall_mode: false 35 | -------------------------------------------------------------------------------- /examples/config/math/llama-3-8b_greedy.yaml: -------------------------------------------------------------------------------- 1 | teaching_plans: no demo 2 | 3 | student_model_cfgs: 4 | - model_type: huggingface 5 | model_cfg: 6 | path: meta-llama/Meta-Llama-3-8B-Instruct 7 | name: llama-3-8b 8 | 9 | general_student_sample_cfg: 10 | batch_size: 8 11 | do_sample: false 12 | num_return_sequences: 1 13 | temperature: 0 14 | 15 | exam_maker_type: fixed 16 | 17 | exam_prompter_type: basic 18 | exam_prompter_cfg: 19 | instruction: "Your task is to answer the last question below. Give step by step reasoning before you answer. 
When you're ready to answer, please wrap your answer and conclude using the format\n'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n" 20 | demo_template: "[[Question]]:\n{question}\n\n[[Solution]]:\nLet's think step by step.\n\n{rationale}\n\n[[Final Answer]]:\n${answer}$\n" 21 | exam_template: "{demo}\n\n\n[[Question]]:\n{question}\n\n[ROLESWITCHING assistant:][[Solution]]:\nLet's think step by step.\n\n" 22 | use_multi_round_conv: false 23 | stub_teaching_items: 24 | - {"question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.", "rationale": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\n\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\n\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", "answer": "[2,5)"} 25 | - {"question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12$, then find $\\det (\\mathbf{A} \\mathbf{B})$.", "rationale": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}$.", "answer": "24"} 26 | - {"question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", "rationale": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.\n\nIf he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \\Rightarrow\\qquad n&=480/30=\\boxed{16} \\end{align*}", "answer": "16"} 27 | - {"question": "If the system of equations \\begin{align*} 6x-4y&=a, \\\\ 6y-9x &=b. \\end{align*} has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b}$, assuming $b$ is nonzero.", "rationale": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a$$. Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}$$.", "answer": "-\\frac{2}{3}"} 28 | 29 | exam_scorer_type: math 30 | exam_scorer_cfg: 31 | recall_mode: false 32 | -------------------------------------------------------------------------------- /examples/config/math/llama-3-8b_trgen.yaml: -------------------------------------------------------------------------------- 1 | teaching_plans: no demo 2 | 3 | student_model_cfgs: 4 | - model_type: huggingface 5 | model_cfg: 6 | path: meta-llama/Meta-Llama-3-8B-Instruct 7 | name: llama-3-8b 8 | 9 | general_student_sample_cfg: 10 | batch_size: 8 11 | do_sample: true 12 | num_return_sequences: 1 13 | top_k: 20 14 | top_p: 1.0 15 | temperature: 0.7 16 | 17 | exam_maker_type: fixed 18 | 19 | exam_prompter_type: basic 20 | exam_prompter_cfg: 21 | instruction: "Your task is to answer the last question below. Give step by step reasoning before you answer. 
When you're ready to answer, please wrap your answer and conclude using the format\n'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n" 22 | demo_template: "[[Question]]:\n{question}\n\n[[Solution]]:\nLet's think step by step.\n\n{rationale}\n\n[[Final Answer]]:\n${answer}$\n" 23 | exam_template: "{demo}\n\n\n[[Question]]:\n{question}\n\n[ROLESWITCHING assistant:][[Solution]]:\nLet's think step by step.\n\n" 24 | use_multi_round_conv: false 25 | stub_teaching_items: 26 | - {"question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.", "rationale": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\n\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\n\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", "answer": "[2,5)"} 27 | - {"question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12$, then find $\\det (\\mathbf{A} \\mathbf{B})$.", "rationale": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}$.", "answer": "24"} 28 | - {"question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", "rationale": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.\n\nIf he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \\Rightarrow\\qquad n&=480/30=\\boxed{16} \\end{align*}", "answer": "16"} 29 | - {"question": "If the system of equations \\begin{align*} 6x-4y&=a, \\\\ 6y-9x &=b. \\end{align*} has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b}$, assuming $b$ is nonzero.", "rationale": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a$$. Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}$$.", "answer": "-\\frac{2}{3}"} 30 | 31 | exam_scorer_type: math 32 | exam_scorer_cfg: 33 | recall_mode: false 34 | -------------------------------------------------------------------------------- /examples/config/math/mistral-7b_exam.yaml: -------------------------------------------------------------------------------- 1 | student_model_cfgs: 2 | - model_type: huggingface 3 | model_cfg: 4 | path: mistralai/Mistral-7B-Instruct-v0.2 5 | name: mistral-7b 6 | 7 | general_student_sample_cfg: 8 | batch_size: 9 9 | do_sample: true 10 | num_return_sequences: 1 11 | top_k: 20 12 | top_p: 1.0 13 | temperature: 0.7 14 | 15 | exam_maker_type: func 16 | exam_maker_cfg: 17 | num_exam_questions: 3 18 | num_repetitions: 3 19 | 20 | exam_prompter_type: basic 21 | exam_prompter_cfg: 22 | instruction: "Your task is to answer the last question below. Give step by step reasoning before you answer. 
When you're ready to answer, please wrap your answer and conclude using the format\n'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n" 23 | demo_template: "[[Question]]:\n{question}\n\n[[Solution]]:\nLet's think step by step.\n\n{rationale}\n\n[[Final Answer]]:\n${answer}$\n" 24 | exam_template: "{demo}\n\n\n[[Question]]:\n{question}\n\n[ROLESWITCHING assistant:][[Solution]]:\nLet's think step by step.\n\n" 25 | use_multi_round_conv: false 26 | stub_teaching_items: 27 | - {"question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.", "rationale": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\n\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\n\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", "answer": "[2,5)"} 28 | - {"question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12$, then find $\\det (\\mathbf{A} \\mathbf{B})$.", "rationale": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}$.", "answer": "24"} 29 | - {"question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", "rationale": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.\n\nIf he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \\Rightarrow\\qquad n&=480/30=\\boxed{16} \\end{align*}", "answer": "16"} 30 | - {"question": "If the system of equations \\begin{align*} 6x-4y&=a, \\\\ 6y-9x &=b. \\end{align*} has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b}$, assuming $b$ is nonzero.", "rationale": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a$$. Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}$$.", "answer": "-\\frac{2}{3}"} 31 | 32 | exam_scorer_type: math 33 | exam_scorer_cfg: 34 | recall_mode: false 35 | -------------------------------------------------------------------------------- /examples/config/math/mistral-7b_greedy.yaml: -------------------------------------------------------------------------------- 1 | teaching_plans: no demo 2 | 3 | student_model_cfgs: 4 | - model_type: huggingface 5 | model_cfg: 6 | path: mistralai/Mistral-7B-Instruct-v0.2 7 | name: mistral-7b 8 | 9 | general_student_sample_cfg: 10 | batch_size: 8 11 | do_sample: false 12 | num_return_sequences: 1 13 | temperature: 0 14 | 15 | exam_maker_type: fixed 16 | 17 | exam_prompter_type: basic 18 | exam_prompter_cfg: 19 | instruction: "Your task is to answer the last question below. Give step by step reasoning before you answer. 
When you're ready to answer, please wrap your answer and conclude using the format\n'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n" 20 | demo_template: "[[Question]]:\n{question}\n\n[[Solution]]:\nLet's think step by step.\n\n{rationale}\n\n[[Final Answer]]:\n${answer}$\n" 21 | exam_template: "{demo}\n\n\n[[Question]]:\n{question}\n\n[ROLESWITCHING assistant:][[Solution]]:\nLet's think step by step.\n\n" 22 | use_multi_round_conv: false 23 | stub_teaching_items: 24 | - {"question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.", "rationale": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\n\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\n\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", "answer": "[2,5)"} 25 | - {"question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12$, then find $\\det (\\mathbf{A} \\mathbf{B})$.", "rationale": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}$.", "answer": "24"} 26 | - {"question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", "rationale": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.\n\nIf he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \\Rightarrow\\qquad n&=480/30=\\boxed{16} \\end{align*}", "answer": "16"} 27 | - {"question": "If the system of equations \\begin{align*} 6x-4y&=a, \\\\ 6y-9x &=b. \\end{align*} has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b}$, assuming $b$ is nonzero.", "rationale": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a$$. Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}$$.", "answer": "-\\frac{2}{3}"} 28 | 29 | exam_scorer_type: math 30 | exam_scorer_cfg: 31 | recall_mode: false 32 | -------------------------------------------------------------------------------- /examples/config/math/mistral-7b_trgen.yaml: -------------------------------------------------------------------------------- 1 | teaching_plans: no demo 2 | 3 | student_model_cfgs: 4 | - model_type: huggingface 5 | model_cfg: 6 | path: mistralai/Mistral-7B-Instruct-v0.2 7 | name: mistral-7b 8 | 9 | general_student_sample_cfg: 10 | batch_size: 8 11 | do_sample: true 12 | num_return_sequences: 1 13 | top_k: 20 14 | top_p: 1.0 15 | temperature: 0.7 16 | 17 | exam_maker_type: fixed 18 | 19 | exam_prompter_type: basic 20 | exam_prompter_cfg: 21 | instruction: "Your task is to answer the last question below. Give step by step reasoning before you answer. 
When you're ready to answer, please wrap your answer and conclude using the format\n'''\n[[Final Answer]]:\n$ANSWER$\n'''\n\n\n\n" 22 | demo_template: "[[Question]]:\n{question}\n\n[[Solution]]:\nLet's think step by step.\n\n{rationale}\n\n[[Final Answer]]:\n${answer}$\n" 23 | exam_template: "{demo}\n\n\n[[Question]]:\n{question}\n\n[ROLESWITCHING assistant:][[Solution]]:\nLet's think step by step.\n\n" 24 | use_multi_round_conv: false 25 | stub_teaching_items: 26 | - {"question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.", "rationale": "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\n\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\n\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", "answer": "[2,5)"} 27 | - {"question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12$, then find $\\det (\\mathbf{A} \\mathbf{B})$.", "rationale": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}$.", "answer": "24"} 28 | - {"question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", "rationale": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight.\n\nIf he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \\Rightarrow\\qquad n&=480/30=\\boxed{16} \\end{align*}", "answer": "16"} 29 | - {"question": "If the system of equations \\begin{align*} 6x-4y&=a, \\\\ 6y-9x &=b. \\end{align*} has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b}$, assuming $b$ is nonzero.", "rationale": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a$$. 
Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}$$.", "answer": "-\\frac{2}{3}"} 30 | 31 | exam_scorer_type: math 32 | exam_scorer_cfg: 33 | recall_mode: false 34 | -------------------------------------------------------------------------------- /lbt/__init__.py: -------------------------------------------------------------------------------- 1 | from lbt import utils 2 | from lbt.base import Component 3 | from lbt import datasets_adapter 4 | from lbt import models 5 | from lbt import exam_maker 6 | from lbt import exam_scorer 7 | -------------------------------------------------------------------------------- /lbt/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copied from https://github.com/walkerning/aw_nas 3 | 4 | from collections import OrderedDict 5 | import six 6 | from six import StringIO 7 | import yaml 8 | 9 | from lbt import utils 10 | from lbt.utils import RegistryMeta 11 | from lbt.utils import getLogger 12 | 13 | # Make yaml.safe_dump support OrderedDict 14 | yaml.add_representer( 15 | OrderedDict, 16 | lambda dumper, data: dumper.represent_mapping( 17 | "tag:yaml.org,2002:map", data.items() 18 | ), 19 | Dumper=yaml.dumper.SafeDumper, 20 | ) 21 | 22 | 23 | LOGGER = getLogger("registry") 24 | 25 | 26 | @six.add_metaclass(RegistryMeta) 27 | class Component: 28 | def __init__(self): 29 | self._logger = None 30 | 31 | @property 32 | def logger(self): 33 | if self._logger is None: 34 | self._logger = getLogger(self.__class__.__name__) 35 | return self._logger 36 | 37 | def __getstate__(self): 38 | state = self.__dict__.copy() 39 | if "_logger" in state: 40 | del state["_logger"] 41 | return state 42 | 43 | def __setstate__(self, state): 44 | self.__dict__.update(state) 45 | # set self._logger to None 46 | self._logger = None 47 | 48 | @classmethod 49 | def get_default_config(cls): 50 | return utils.get_default_argspec(cls.__init__) 51 | 52 | @classmethod 53 | def get_default_config_str(cls): 54 | stream = StringIO() 55 | cfg = OrderedDict(cls.get_default_config()) 56 | yaml.safe_dump(cfg, stream=stream, default_flow_style=False) 57 | return stream.getvalue() 58 | 59 | @classmethod 60 | def get_current_config_str(cls, cfg): 61 | stream = StringIO() 62 | whole_cfg = OrderedDict(cls.get_default_config()) 63 | whole_cfg.update(cfg) 64 | yaml.safe_dump(whole_cfg, stream=stream, default_flow_style=False) 65 | return stream.getvalue() 66 | 67 | @classmethod 68 | def init_from_cfg_file(cls, cfg_path, registry_name=None, **addi_kwargs): 69 | with open(cfg_path, "r") as rf: 70 | cfg = yaml.safe_load(rf) 71 | return cls.init_from_cfg(cfg, registry_name=registry_name, **addi_kwargs) 72 | 73 | @classmethod 74 | def init_from_cfg(cls, cfg, registry_name=None, **addi_kwargs): 75 | avail_registries = RegistryMeta.avail_tables() 76 | if not hasattr(cls, "REGISTRY"): 77 | # Component class 78 | if registry_name is not None: 79 | assert registry_name in avail_registries 80 | else: 81 | type_keys = [ 82 | key 83 | for key in cfg.keys() 84 | if key.endswith("_type") and key[:-5] in avail_registries 85 | ] 86 | assert len(type_keys) == 1 87 | registry_name = type_keys[0][:-5] 88 | LOGGER.info(f"Guess `registry_name={registry_name}` from the config.") 89 | elif not hasattr(cls, "NAME"): 90 | # Base classes that inherit `Component` 91 | assert registry_name is None or registry_name == cls.REGISTRY, ( 92 | f"This class `{cls.__name__}` is in registry `{cls.REGISTRY}`, do 
not" 93 | f" match `{registry_name}`. Either do not pass in the `registry_name`" 94 | " argument or pass in a matching registry name." 95 | ) 96 | registry_name = cls.REGISTRY 97 | type_ = cfg[registry_name + "_type"] 98 | true_cls = RegistryMeta.get_class(registry_name, type_) 99 | else: 100 | # Concrete class 101 | assert registry_name is None or registry_name == cls.REGISTRY, ( 102 | f"This class `{cls.__name__}` is in registry `{cls.REGISTRY}`, do not" 103 | f" match `{registry_name}`. Either do not pass in the `registry_name`" 104 | " argument or pass in a matching registry name." 105 | ) 106 | registry_name = cls.REGISTRY 107 | type_ = cls.NAME 108 | 109 | type_ = cfg[registry_name + "_type"] 110 | true_cls = RegistryMeta.get_class(registry_name, type_) 111 | LOGGER.info( 112 | "Component [%s] type: %s, %s.%s", 113 | registry_name, 114 | type_, 115 | true_cls.__module__, 116 | true_cls.__name__, 117 | ) 118 | 119 | class_cfg = cfg.get(registry_name + "_cfg", {}) 120 | class_cfg = class_cfg or {} 121 | # config items will override addi_args items 122 | addi_kwargs.update(class_cfg) 123 | 124 | whole_cfg_str = true_cls.get_current_config_str(class_cfg) 125 | LOGGER.info( 126 | "%s `%s` config:\n%s", 127 | registry_name, 128 | type_, 129 | utils._add_text_prefix(whole_cfg_str, " "), 130 | ) 131 | return true_cls(**addi_kwargs) 132 | -------------------------------------------------------------------------------- /lbt/datasets_adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from lbt.datasets_adapter.math_dataset import * 2 | from lbt.datasets_adapter.code_dataset import * 3 | -------------------------------------------------------------------------------- /lbt/datasets_adapter/apps_utils/testing_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | import io 6 | import faulthandler 7 | import platform 8 | 9 | # used for debugging to time steps 10 | from datetime import datetime 11 | 12 | # to run the solution files we're using a timing based approach 13 | import signal 14 | 15 | import numpy as np 16 | 17 | # for capturing the stdout 18 | from io import StringIO 19 | from typing import get_type_hints 20 | from typing import List, Tuple 21 | 22 | # used for testing the code that reads from input 23 | from unittest.mock import patch, mock_open 24 | 25 | from pyext import RuntimeModule 26 | 27 | from enum import Enum 28 | 29 | 30 | # stuff for setting up signal timer 31 | class TimeoutException(Exception): 32 | pass 33 | 34 | 35 | def timeout_handler(signum, frame): 36 | print("alarm went off") 37 | # return 38 | raise TimeoutException 39 | 40 | 41 | signal.signal(signal.SIGALRM, timeout_handler) 42 | timeout = 4 # seconds 43 | 44 | 45 | # used to capture stdout as a list 46 | # from https://stackoverflow.com/a/16571630/6416660 47 | # alternative use redirect_stdout() from contextlib 48 | class Capturing(list): 49 | def __enter__(self): 50 | self._stdout = sys.stdout 51 | sys.stdout = self._stringio = StringIO() 52 | # Make closing the StringIO a no-op 53 | self._stringio.close = lambda x: 1 54 | return self 55 | 56 | def __exit__(self, *args): 57 | self.extend(self._stringio.getvalue().splitlines()) 58 | del self._stringio # free up some memory 59 | sys.stdout = self._stdout 60 | 61 | 62 | def run_test(extracted_code: str = None, test: dict = None): 63 | assert extracted_code is not None 64 | assert test is not None 65 | 66 | # 
Disable functionalities that can make destructive changes to the test. 67 | # reliability_guard() 68 | 69 | results = [] 70 | sol = "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as np\nimport random\nimport heapq\nfrom heapq import *\n" 71 | 72 | tmp_code = extracted_code.split("\n") 73 | 74 | new_code = [] 75 | for x in tmp_code: 76 | if (not x.startswith("from ")) and (not x.startswith("import ")): 77 | new_code.append("\t" + x + "\n") 78 | else: 79 | new_code.append(x + "\n") 80 | tmp_code = new_code 81 | 82 | new_code = "" 83 | started = False 84 | for i in tmp_code: 85 | if i.startswith("\t") and not started: 86 | new_code += "stdin = sys.stdin\nstdout = sys.stdout\n" 87 | new_code += "def code():\n" 88 | new_code += i 89 | started = True 90 | elif started and ((i.startswith("from ")) or (i.startswith("import "))): 91 | new_code += "\t" + i 92 | else: 93 | new_code += i 94 | tmp_code = new_code 95 | 96 | sol += tmp_code 97 | 98 | method_name = "code" 99 | signal.alarm(timeout) 100 | try: 101 | tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol) 102 | tmp = tmp_sol 103 | signal.alarm(0) 104 | except Exception as e: 105 | signal.alarm(0) 106 | print(f"type 1 compilation error = {e}") 107 | # results.append(-2) 108 | results.append(False) 109 | return results 110 | signal.alarm(0) 111 | 112 | try: 113 | method = getattr(tmp, method_name) # get_attr second arg must be str 114 | except: 115 | signal.alarm(0) 116 | e = sys.exc_info() 117 | print(f"unable to get function error = {e}") 118 | # results.append(-2) 119 | results.append(False) 120 | return results 121 | 122 | for index, inputs in enumerate(test["inputs"]): 123 | # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list) 124 | try: 125 | if isinstance(inputs[0], dict): 126 | inputs = [{int(k): v for k, v in inputs[0].items()}] 127 | except: 128 | True 129 | 130 | try: 131 | if isinstance(test["outputs"][index], dict): 132 | test["outputs"][index] = [ 133 | {int(k): v for k, v in test["outputs"][index].items()} 134 | ] 135 | except: 136 | True 137 | try: 138 | if isinstance(test["outputs"][index][0], dict): 139 | test["outputs"][index] = [ 140 | {int(k): v for k, v in test["outputs"][index][0].items()} 141 | ] 142 | except: 143 | True 144 | 145 | faulthandler.enable() 146 | signal.alarm(timeout) 147 | passed = False 148 | 149 | if isinstance(inputs, list): 150 | inputs = "\n".join(inputs) 151 | if isinstance(test["outputs"][index], list): 152 | test["outputs"][index] = "\n".join(test["outputs"][index]) 153 | 154 | with Capturing() as output: 155 | try: 156 | call_method(method, inputs) 157 | # reset the alarm 158 | signal.alarm(0) 159 | passed = True 160 | except Exception as e: 161 | # runtime error or took too long 162 | signal.alarm(0) 163 | print( 164 | f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}" 165 | ) 166 | # results.append(-1) 167 | results.append(False) 168 | signal.alarm(0) 169 | 170 | if not passed: 171 | continue 172 | # else: 173 | # print(f"==> output = {output}, test outputs = {test['outputs'][index]}") 174 | 175 | if custom_compare_(output, test["outputs"][index]): 176 | tmp_result = True 177 | 
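# custom_compare_ matched the expected output: record the pass and skip the
# chain of fallback comparisons below for this test case.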
results.append(tmp_result) 178 | continue 179 | 180 | # ground truth sequences are expressed as lists not tuples 181 | if isinstance(output, tuple): 182 | output = list(output) 183 | 184 | tmp_result = False 185 | try: 186 | tmp_result = output == [test["outputs"][index]] 187 | if isinstance(test["outputs"][index], list): 188 | tmp_result = tmp_result or (output == test["outputs"][index]) 189 | if isinstance(output[0], str): 190 | tmp_result = tmp_result or ( 191 | [e.strip() for e in output] == test["outputs"][index] 192 | ) 193 | except Exception as e: 194 | print(f"Failed check1 exception = {e}") 195 | pass 196 | 197 | if tmp_result == True: 198 | results.append(tmp_result) 199 | continue 200 | 201 | # try one more time without \n 202 | if isinstance(test["outputs"][index], list): 203 | for tmp_index, i in enumerate(test["outputs"][index]): 204 | test["outputs"][index][tmp_index] = i.split("\n") 205 | test["outputs"][index][tmp_index] = [ 206 | x.strip() for x in test["outputs"][index][tmp_index] if x 207 | ] 208 | else: 209 | test["outputs"][index] = test["outputs"][index].split("\n") 210 | test["outputs"][index] = list(filter(len, test["outputs"][index])) 211 | test["outputs"][index] = list( 212 | map(lambda x: x.strip(), test["outputs"][index]) 213 | ) 214 | 215 | try: 216 | tmp_result = output == [test["outputs"][index]] 217 | if isinstance(test["outputs"][index], list): 218 | tmp_result = tmp_result or (output == test["outputs"][index]) 219 | except Exception as e: 220 | print(f"Failed check2 exception = {e}") 221 | pass 222 | 223 | if tmp_result == True: 224 | results.append(tmp_result) 225 | continue 226 | 227 | # try by converting the output into a split up list too 228 | if isinstance(output, list): 229 | output = list(filter(len, output)) 230 | 231 | if tmp_result == True: 232 | results.append(tmp_result) 233 | continue 234 | 235 | try: 236 | tmp_result = output == [test["outputs"][index]] 237 | if isinstance(test["outputs"][index], list): 238 | tmp_result = tmp_result or (output == test["outputs"][index]) 239 | except Exception as e: 240 | print(f"Failed check3 exception = {e}") 241 | pass 242 | 243 | try: 244 | output_float = [float(e) for e in output] 245 | gt_float = [float(e) for e in test["outputs"][index]] 246 | tmp_result = tmp_result or ( 247 | (len(output_float) == len(gt_float)) 248 | and np.allclose(output_float, gt_float) 249 | ) 250 | except Exception as e: 251 | pass 252 | try: 253 | if isinstance(output[0], list): 254 | output_float = [float(e) for e in output[0]] 255 | gt_float = [float(e) for e in test["outputs"][index][0]] 256 | tmp_result = tmp_result or ( 257 | (len(output_float) == len(gt_float)) 258 | and np.allclose(output_float, gt_float) 259 | ) 260 | except Exception as e: 261 | pass 262 | 263 | if tmp_result == True: 264 | results.append(tmp_result) 265 | continue 266 | 267 | # try by converting the stuff into split up list 268 | if isinstance(test["outputs"][index], list): 269 | for tmp_index, i in enumerate(test["outputs"][index]): 270 | test["outputs"][index][tmp_index] = set(i.split()) 271 | else: 272 | test["outputs"][index] = set(test["outputs"][index].split()) 273 | 274 | try: 275 | tmp_result = output == test["outputs"][index] 276 | except Exception as e: 277 | print(f"Failed check4 exception = {e}") 278 | continue 279 | 280 | if tmp_result == True: 281 | results.append(tmp_result) 282 | continue 283 | 284 | # try by converting the output into a split up list too 285 | if isinstance(output, list): 286 | for tmp_index, i in enumerate(output): 
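# tokenize each printed line so the final fallback below can compare
# outputs as order-insensitive sets of whitespace-separated tokens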
287 | output[tmp_index] = i.split() 288 | output = list(filter(len, output)) 289 | for tmp_index, i in enumerate(output): 290 | output[tmp_index] = set(i) 291 | else: 292 | output = output.split() 293 | output = list(filter(len, output)) 294 | output = set(output) 295 | 296 | try: 297 | tmp_result = set(frozenset(s) for s in output) == set( 298 | frozenset(s) for s in test["outputs"][index] 299 | ) 300 | except Exception as e: 301 | print(f"Failed check5 exception = {e}") 302 | 303 | # if they are all numbers, round so that similar numbers are treated as identical 304 | try: 305 | tmp_result = tmp_result or ( 306 | set(frozenset(round(float(t), 3) for t in s) for s in output) 307 | == set( 308 | frozenset(round(float(t), 3) for t in s) 309 | for s in test["outputs"][index] 310 | ) 311 | ) 312 | except Exception as e: 313 | print(f"Failed check6 exception = {e}") 314 | 315 | results.append(tmp_result) 316 | 317 | return results 318 | 319 | 320 | def custom_compare_(output, ground_truth): 321 | if isinstance(output, list): 322 | output_1 = "\n".join(output) 323 | if stripped_string_compare(output_1, ground_truth): 324 | return True 325 | 326 | if isinstance(output, list): 327 | output_2 = [o.lstrip().rstrip() for o in output] 328 | output_2 = "\n".join(output_2) 329 | if stripped_string_compare(output_2, ground_truth): 330 | return True 331 | 332 | return False 333 | 334 | 335 | def stripped_string_compare(s1, s2): 336 | s1 = s1.lstrip().rstrip() 337 | s2 = s2.lstrip().rstrip() 338 | return s1 == s2 339 | 340 | 341 | def call_method(method, inputs): 342 | if isinstance(inputs, list): 343 | inputs = "\n".join(inputs) 344 | 345 | inputs_line_iterator = iter(inputs.split("\n")) 346 | 347 | # sys.setrecursionlimit(10000) 348 | 349 | # @patch('builtins.input', side_effect=inputs.split("\n")) 350 | @patch("builtins.open", mock_open(read_data=inputs)) 351 | @patch("sys.stdin", StringIO(inputs)) 352 | @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator)) 353 | @patch("sys.stdin.readlines", lambda *args: inputs.split("\n")) 354 | @patch("sys.stdin.read", lambda *args: inputs) 355 | # @patch('sys.stdout.write', print) 356 | def _inner_call_method(_method): 357 | try: 358 | return _method() 359 | except SystemExit as e: 360 | pass 361 | finally: 362 | pass 363 | 364 | return _inner_call_method(method) 365 | 366 | 367 | def reliability_guard(maximum_memory_bytes=None): 368 | """ 369 | source: https://github.com/openai/human-eval 370 | This disables various destructive functions and prevents the generated code 371 | from interfering with the test (e.g. fork bomb, killing other processes, 372 | removing filesystem files, etc.) 373 | WARNING 374 | This function is NOT a security sandbox. Untrusted code, including, model- 375 | generated code, should not be blindly executed outside of one. See the 376 | Codex paper for more information about OpenAI's code sandbox, and proceed 377 | with caution. 
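A hypothetical usage sketch (not from the original file): call the guard in a
disposable worker process right before executing untrusted code, e.g.

    reliability_guard(maximum_memory_bytes=4 * 1024**3)  # assumed 4 GiB cap
    exec(untrusted_code, {})  # `untrusted_code` is a placeholder name

The patched builtins and os functions are never restored, so the worker
process should be discarded afterwards.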
378 | """ 379 | 380 | if maximum_memory_bytes is not None: 381 | import resource 382 | 383 | resource.setrlimit( 384 | resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes) 385 | ) 386 | resource.setrlimit( 387 | resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes) 388 | ) 389 | if not platform.uname().system == "Darwin": 390 | resource.setrlimit( 391 | resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes) 392 | ) 393 | 394 | faulthandler.disable() 395 | 396 | import builtins 397 | 398 | builtins.exit = None 399 | builtins.quit = None 400 | 401 | import os 402 | 403 | os.environ["OMP_NUM_THREADS"] = "1" 404 | 405 | os.kill = None 406 | os.system = None 407 | os.putenv = None 408 | os.remove = None 409 | os.removedirs = None 410 | os.rmdir = None 411 | os.fchdir = None 412 | os.setuid = None 413 | os.fork = None 414 | os.forkpty = None 415 | os.killpg = None 416 | os.rename = None 417 | os.renames = None 418 | os.truncate = None 419 | os.replace = None 420 | os.unlink = None 421 | os.fchmod = None 422 | os.fchown = None 423 | os.chmod = None 424 | os.chown = None 425 | os.chroot = None 426 | os.fchdir = None 427 | os.lchflags = None 428 | os.lchmod = None 429 | os.lchown = None 430 | os.getcwd = None 431 | os.chdir = None 432 | 433 | import shutil 434 | 435 | shutil.rmtree = None 436 | shutil.move = None 437 | shutil.chown = None 438 | 439 | import subprocess 440 | 441 | subprocess.Popen = None # type: ignore 442 | 443 | __builtins__["help"] = None 444 | 445 | import sys 446 | 447 | sys.modules["ipdb"] = None 448 | sys.modules["joblib"] = None 449 | sys.modules["resource"] = None 450 | sys.modules["psutil"] = None 451 | sys.modules["tkinter"] = None 452 | -------------------------------------------------------------------------------- /lbt/datasets_adapter/leetcode_sub/environment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from datetime import datetime 4 | 5 | import dotenv 6 | import gym 7 | import leetcode 8 | import leetcode.auth 9 | 10 | from .types import LeetCodeSubmission 11 | from .leetcode import id_from_slug 12 | 13 | dotenv.load_dotenv() 14 | 15 | 16 | class LeetCodeEnv(gym.Env): 17 | """ 18 | Gym environment for LeetCode submissions 19 | """ 20 | 21 | metadata = {"render.modes": ["human"]} 22 | 23 | def __init__(self, cooldown=0, csrf_token=None): 24 | super(LeetCodeEnv, self).__init__() 25 | self.__configure_leetcode(csrf_token) 26 | self.reward = False 27 | self.last_run = None 28 | self.cooldown = cooldown # To avoid rate limit 29 | 30 | def __configure_leetcode(self, csrf_token=None): 31 | configuration = leetcode.Configuration() 32 | 33 | # From Dev Tools/Application/Cookies/LEETCODE_SESSION 34 | leetcode_session = os.environ["LEETCODE_SESSION"] 35 | if csrf_token is None: 36 | csrf_token = os.environ["CSRF_TOKEN"] 37 | 38 | configuration.api_key["x-csrftoken"] = csrf_token 39 | configuration.api_key["csrftoken"] = csrf_token 40 | configuration.api_key["LEETCODE_SESSION"] = leetcode_session 41 | configuration.api_key["Referer"] = "https://leetcode.com" 42 | configuration.debug = False 43 | 44 | self.api_instance = leetcode.DefaultApi(leetcode.ApiClient(configuration)) 45 | 46 | def step(self, action: LeetCodeSubmission): 47 | """ 48 | Sends a submission to LeetCode and returns the result 49 | 50 | Args: 51 | action (LeetCodeSubmission): LeetCodeSubmission object 52 | 53 | Returns: 54 | status (str): 'Accepted' | 'Runtime Error'| 'Wrong Answer' | 'Submission 
Timed-Out' | 'Unknown' 55 | reward (bool): True if status is 'Accepted', False otherwise 56 | done (bool): True if status is 'Accepted', False otherwise 57 | submission_result (dict): LeetCode API response 58 | """ 59 | submission_result = self.__send_submission(action) 60 | 61 | reward, status = self.__calculate_reward(submission_result) 62 | 63 | self.reward = reward 64 | 65 | done = self.is_done() 66 | 67 | return status, reward, done, submission_result 68 | 69 | def reset(self): 70 | self.reward = False 71 | 72 | def __send_submission(self, sub: LeetCodeSubmission): 73 | self.__wait_for_cooldown() 74 | 75 | if sub.question_id is None: 76 | sub.question_id = id_from_slug(sub.question_slug, self.api_instance) 77 | 78 | submission = leetcode.Submission( 79 | judge_type="large", 80 | typed_code=sub.code, 81 | question_id=sub.question_id, 82 | test_mode=False, 83 | lang=sub.lang.value, 84 | ) 85 | 86 | submission_id = self.api_instance.problems_problem_submit_post( 87 | problem=sub.question_slug, body=submission 88 | ) 89 | 90 | time.sleep(sub.timeout) 91 | 92 | submission_result = self.api_instance.submissions_detail_id_check_get( 93 | id=submission_id.submission_id 94 | ) 95 | 96 | return submission_result 97 | 98 | def __calculate_reward(self, submission_result): 99 | if submission_result == {"state": "STARTED"}: 100 | status_msg = "Submission Timed-Out" 101 | 102 | elif ( 103 | "status" in submission_result.keys() 104 | and submission_result["status"] == "PENDING" 105 | ): 106 | status_msg = "Submission Timed-Out" 107 | 108 | elif "status_msg" in submission_result.keys(): 109 | status_msg = submission_result[ 110 | "status_msg" 111 | ] # 'Accepted' | 'Runtime Error'| 'Wrong Answer' 112 | 113 | else: 114 | status_msg = "Unknown" 115 | 116 | return status_msg == "Accepted", status_msg 117 | 118 | def __wait_for_cooldown(self): 119 | if self.last_run == None: 120 | self.last_run = datetime.now() 121 | else: 122 | while (datetime.now() - self.last_run).total_seconds() < self.cooldown: 123 | time.sleep(0.1) 124 | self.last_run = datetime.now() 125 | 126 | def is_done(self): 127 | return self.reward 128 | -------------------------------------------------------------------------------- /lbt/datasets_adapter/leetcode_sub/leetcode.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import leetcode 3 | 4 | 5 | def id_from_slug(slug: str, api_instance) -> str: 6 | """ 7 | Retrieves the id of the question with the given slug 8 | """ 9 | graphql_request = leetcode.GraphqlQuery( 10 | query=""" 11 | query getQuestionDetail($titleSlug: String!) { 12 | question(titleSlug: $titleSlug) { 13 | questionId 14 | } 15 | } 16 | """, 17 | variables={"titleSlug": slug}, 18 | operation_name="getQuestionDetail", 19 | ) 20 | response = ast.literal_eval(str(api_instance.graphql_post(body=graphql_request))) 21 | frontend_id = response["data"]["question"]["question_id"] 22 | return frontend_id 23 | 24 | 25 | def metadata_from_slug(slug: str, api_instance) -> str: 26 | """ 27 | Retrieves the metadata of the question with the given slug 28 | """ 29 | graphql_request = leetcode.GraphqlQuery( 30 | query=""" 31 | query getQuestionDetail($titleSlug: String!) 
{ 32 | question(titleSlug: $titleSlug) { 33 | metaData 34 | } 35 | } 36 | """, 37 | variables={"titleSlug": slug}, 38 | operation_name="getQuestionDetail", 39 | ) 40 | response = ast.literal_eval(str(api_instance.graphql_post(body=graphql_request))) 41 | metadata = response["data"]["question"] 42 | return metadata 43 | -------------------------------------------------------------------------------- /lbt/datasets_adapter/leetcode_sub/types.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from pydantic import BaseModel 3 | from enum import Enum 4 | 5 | 6 | class ProgrammingLanguage(Enum): 7 | """ 8 | Enum for valid LeetCodeSubmission programming languages 9 | """ 10 | 11 | CPP = "c++" 12 | JAVA = "java" 13 | PYTHON = "python" 14 | PYTHON3 = "python3" 15 | C = "c" 16 | C_SHARP = "c#" 17 | JAVASCRIPT = "javascript" 18 | RUBY = "ruby" 19 | SWIFT = "swift" 20 | GO = "go" 21 | SCALA = "scala" 22 | KOTLIN = "kotlin" 23 | RUST = "rust" 24 | PHP = "php" 25 | TYPESCRIPT = "typescript" 26 | RACKET = "racket" 27 | ERLANG = "erlang" 28 | ELIXIR = "elixir" 29 | DART = "dart" 30 | MYSQL = "mysql" 31 | MS_SQL_SERVER = "ms sql server" 32 | ORACLE = "oracle" 33 | 34 | 35 | class LeetCodeSubmission(BaseModel): 36 | """ 37 | Model for a Leetcode Code Submission 38 | """ 39 | 40 | code: str 41 | lang: ProgrammingLanguage 42 | question_slug: str 43 | # the id is resolved from `question_slug` when omitted 44 | question_id: Optional[str] = None 45 | timeout: int = 5 46 | -------------------------------------------------------------------------------- /lbt/datasets_adapter/math_dataset.py: -------------------------------------------------------------------------------- 1 | # from opencompass.registry import ICL_EVALUATORS # , TEXT_POSTPROCESSORS 2 | import re 3 | import random 4 | 5 | from datasets import concatenate_datasets 6 | 7 | from lbt.exam_scorer import BaseExamScorer 8 | from lbt.exam_maker import FixedExamMaker 9 | 10 | 11 | class MATHEvaluator: 12 | """ 13 | Copied from opencompass, as directly using opencompass interferes with logging. 
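Illustrative (hypothetical) checks, based on the normalization rules
implemented below:

    evaluator = MATHEvaluator()
    evaluator.is_equiv("\\dfrac{1}{2}", "0.5")  # True: dfrac -> frac, 0.5 -> \frac{1}{2}
    evaluator.is_equiv("[2, 5)", "[2,5)")  # True: spaces are stripped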
14 | """ 15 | 16 | def _fix_fracs(self, string): 17 | substrs = string.split("\\frac") 18 | new_str = substrs[0] 19 | if len(substrs) > 1: 20 | substrs = substrs[1:] 21 | for substr in substrs: 22 | new_str += "\\frac" 23 | if substr[0] == "{": 24 | new_str += substr 25 | else: 26 | try: 27 | assert len(substr) >= 2 28 | except AssertionError: 29 | return string 30 | a = substr[0] 31 | b = substr[1] 32 | if b != "{": 33 | if len(substr) > 2: 34 | post_substr = substr[2:] 35 | new_str += "{" + a + "}{" + b + "}" + post_substr 36 | else: 37 | new_str += "{" + a + "}{" + b + "}" 38 | else: 39 | if len(substr) > 2: 40 | post_substr = substr[2:] 41 | new_str += "{" + a + "}" + b + post_substr 42 | else: 43 | new_str += "{" + a + "}" + b 44 | string = new_str 45 | return string 46 | 47 | def _fix_a_slash_b(self, string): 48 | if len(string.split("/")) != 2: 49 | return string 50 | a = string.split("/")[0] 51 | b = string.split("/")[1] 52 | try: 53 | a = int(a) 54 | b = int(b) 55 | assert string == "{}/{}".format(a, b) 56 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 57 | return new_string 58 | except AssertionError: 59 | return string 60 | 61 | def _remove_right_units(self, string): 62 | # "\\text{ " only ever occurs (at least in the val set) when describing 63 | # units 64 | if "\\text{ " in string: 65 | splits = string.split("\\text{ ") 66 | assert len(splits) == 2 67 | return splits[0] 68 | else: 69 | return string 70 | 71 | def _fix_sqrt(self, string): 72 | if "\\sqrt" not in string: 73 | return string 74 | splits = string.split("\\sqrt") 75 | new_string = splits[0] 76 | for split in splits[1:]: 77 | if split[0] != "{": 78 | a = split[0] 79 | new_substr = "\\sqrt{" + a + "}" + split[1:] 80 | else: 81 | new_substr = "\\sqrt" + split 82 | new_string += new_substr 83 | return new_string 84 | 85 | def _strip_string(self, string): 86 | # linebreaks 87 | string = string.replace("\n", "") 88 | 89 | # remove inverse spaces 90 | string = string.replace("\\!", "") 91 | 92 | # replace \\ with \ 93 | string = string.replace("\\\\", "\\") 94 | 95 | # replace tfrac and dfrac with frac 96 | string = string.replace("tfrac", "frac") 97 | string = string.replace("dfrac", "frac") 98 | 99 | # remove \( and \) 100 | string = string.replace("\\(", "") 101 | string = string.replace("\\)", "") 102 | 103 | # remove \left and \right 104 | string = string.replace("\\left", "") 105 | string = string.replace("\\right", "") 106 | 107 | # Remove circ (degrees) 108 | string = string.replace("^{\\circ}", "") 109 | string = string.replace("^\\circ", "") 110 | 111 | # remove dollar signs 112 | string = string.replace("\\$", "") 113 | 114 | # remove units (on the right) 115 | string = self._remove_right_units(string) 116 | 117 | # remove percentage 118 | string = string.replace("\\%", "") 119 | string = string.replace("\%", "") # noqa: W605 120 | 121 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, 122 | # add "0" if "." is the start of the string 123 | string = string.replace(" .", " 0.") 124 | string = string.replace("{.", "{0.") 125 | # if empty, return empty string 126 | if len(string) == 0: 127 | return string 128 | if string[0] == ".": 129 | string = "0" + string 130 | 131 | # to consider: get rid of e.g. 
"k = " or "q = " at beginning 132 | if len(string.split("=")) == 2: 133 | if len(string.split("=")[0]) <= 2: 134 | string = string.split("=")[1] 135 | 136 | # fix sqrt3 --> sqrt{3} 137 | string = self._fix_sqrt(string) 138 | 139 | # remove spaces 140 | string = string.replace(" ", "") 141 | 142 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works 143 | # with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 144 | string = self._fix_fracs(string) 145 | 146 | # manually change 0.5 --> \frac{1}{2} 147 | if string == "0.5": 148 | string = "\\frac{1}{2}" 149 | 150 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix 151 | # in case the model output is X/Y 152 | string = self._fix_a_slash_b(string) 153 | 154 | return string 155 | 156 | def is_equiv(self, str1, str2, verbose=False): 157 | if str1 is None and str2 is None: 158 | print("WARNING: Both None") 159 | return True 160 | if str1 is None or str2 is None: 161 | return False 162 | 163 | try: 164 | ss1 = self._strip_string(str1) 165 | ss2 = self._strip_string(str2) 166 | if verbose: 167 | print(ss1, ss2) 168 | return ss1 == ss2 169 | except: # noqa 170 | return str1 == str2 171 | 172 | def can_recall(self, extracted_answer, gt_answer): 173 | str1 = extracted_answer 174 | str2 = gt_answer 175 | try: 176 | ss1 = self._strip_string(str1) 177 | ss2 = self._strip_string(str2) 178 | return ss2 in ss1 179 | except: 180 | return str2 in str1 181 | 182 | 183 | class MathExamScorer(BaseExamScorer): 184 | NAME = "math" 185 | 186 | def __init__(self, recall_mode=False): 187 | super().__init__() 188 | self.evaluator = MATHEvaluator() 189 | # ICL_EVALUATORS.build( 190 | # {"type": "opencompass.datasets.MATHEvaluator"} 191 | # ) 192 | self.recall_mode = recall_mode 193 | 194 | def score_exam_result(self, exam_gt_item, exam_result_item): 195 | gt_answer = self.post_process(exam_gt_item["answer"]) 196 | extracted_answer = self.post_process(exam_result_item["rationale"]) 197 | exam_result_item["answer"] = extracted_answer 198 | if self.recall_mode: 199 | is_correct = self.evaluator.can_recall(extracted_answer, gt_answer) 200 | else: 201 | is_correct = self.evaluator.is_equiv(extracted_answer, gt_answer) 202 | return float(is_correct) 203 | 204 | @staticmethod 205 | def _normalize_final_answer(final_answer: str) -> str: 206 | """Normalize a final answer to a quantitative reasoning question.""" 207 | RE_SUBSTITUTIONS = [ 208 | (r"\\le(?!ft)", r"<"), # replace \le as <, but do not change "\left" 209 | (r"(? 0: 288 | # final_answer = re.findall(r"finalansweris(.*)", final_answer)[-1] 289 | 290 | # if len(re.findall(r"oxed\{(.*?)\}", final_answer)) > 0: 291 | # final_answer = re.findall(r"oxed\{(.*?)\}", final_answer)[-1] 292 | 293 | # if len(re.findall(r"\$\$(.*?)\$\$", final_answer)) > 0: 294 | # final_answer = re.findall(r"\$(.*?)\$", final_answer)[-1] 295 | # final_answer = final_answer.strip() 296 | # if "rac" in final_answer and "\\frac" not in final_answer: 297 | # final_answer = final_answer.replace("rac", "\\frac") 298 | 299 | # Normalize shorthand TeX: 300 | # \fracab -> \frac{a}{b} 301 | # \frac{abc}{bef} -> \frac{abc}{bef} 302 | # \fracabc -> \frac{a}{b}c 303 | final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) 304 | 305 | final_answer = re.sub( 306 | r"(? 
\sqrt{a} 312 | # \sqrtab -> sqrt{a}b 313 | final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) 314 | 315 | final_answer = final_answer.replace("$", "") 316 | 317 | # Normalize 100,000 -> 100000 318 | if final_answer.replace(",", "").isdigit(): 319 | final_answer = final_answer.replace(",", "") 320 | 321 | return final_answer 322 | 323 | # Copy and modified from opencompass.datasets.math 324 | @staticmethod 325 | def post_process(text: str) -> str: 326 | matches = re.findall(r"\[\[Final Answer\]\]:\n([^\n]+)\n", text) 327 | if not matches: 328 | answer = text.strip().split("\n")[-1] 329 | else: 330 | answer = matches[0] 331 | return MathExamScorer._normalize_final_answer(answer) 332 | # for maybe_ans in text.split("."): 333 | # if "final answer" in maybe_ans.lower(): 334 | # return normalize_final_answer(maybe_ans) 335 | # return normalize_final_answer(text.split(".")[0]) 336 | 337 | 338 | class MathMetaInfoMaker(FixedExamMaker): 339 | NAME = "math_metainfo" 340 | 341 | def __init__( 342 | self, 343 | exam_bank_dataset, 344 | selected_indexes=None, 345 | same_subject=True, 346 | level_controls=[ 347 | "=" 348 | ], # Number: the corresponding level; =: the same level as the teaching question; >: higher/harder levels ...; <: lower/easier levels ... 349 | num_exam_questions=16, 350 | random=False, # False: choose the first `num_exam_questions` that satisfy the meta-info control; True: choose randomly 351 | ): 352 | super().__init__(exam_bank_dataset, selected_indexes) 353 | 354 | self.same_subject = same_subject 355 | self.level_controls = level_controls 356 | self.num_exam_questions = num_exam_questions 357 | self.random = random 358 | if self.level_controls: 359 | assert "level" in self.exam_selected_dataset.features 360 | if self.same_subject: 361 | assert "subject" in self.exam_selected_dataset.features 362 | 363 | @staticmethod 364 | def _is_int(num): 365 | try: 366 | int(num) 367 | except ValueError: 368 | return False 369 | return True 370 | 371 | def _parse_permit_levels(self, teaching_level): 372 | permit_levels = [] 373 | for control in self.level_controls: 374 | if self._is_int(control): 375 | permit_levels.append(int(control)) 376 | else: 377 | assert control in ["=", ">", "<"] 378 | if teaching_level is None: 379 | self.logger.warn( 380 | "The `level` feature of the teaching item is not set. Level" 381 | f" control `{control}` not supported." 382 | ) 383 | continue 384 | if control == "=": 385 | permit_levels.append(teaching_level) 386 | elif control == ">": 387 | permit_levels += list(range(teaching_level + 1, 6)) 388 | elif control == "<": 389 | permit_levels += list(range(1, teaching_level)) 390 | return permit_levels 391 | 392 | def make_exam_questions(self, teaching_items): 393 | num_list = self._get_num_exam_items(teaching_items, self.num_exam_questions) 394 | final_selected_dataset = None 395 | for t_item, num_exam in zip(teaching_items, num_list): 396 | if num_exam == 0: 397 | continue 398 | 399 | # Filter according to item["subject"] 400 | exam_selected_dataset = self.exam_selected_dataset 401 | if self.same_subject: 402 | if "subject" not in t_item: 403 | self.logger.warn( 404 | "The `subject` feature of the teaching item is not set." 
405 | ) 406 | continue 407 | exam_selected_dataset = self.exam_selected_dataset.filter( 408 | lambda exam_item: exam_item["subject"] == t_item["subject"] 409 | ) 410 | 411 | # Filter according to item["level"] 412 | permit_levels = self._parse_permit_levels(t_item.get("level", None)) 413 | exam_selected_dataset = exam_selected_dataset.filter( 414 | lambda exam_item: exam_item["level"] in permit_levels 415 | ) 416 | 417 | if exam_selected_dataset.num_rows > num_exam: 418 | # Select `num_exam` exam items from the filtered dataset 419 | if self.random: 420 | # Random `num_exam` items 421 | all_indexes = list(range(exam_selected_dataset.num_rows)) 422 | indexes = random.sample(population=all_indexes, k=num_exam) 423 | else: 424 | # First `num_exam` items 425 | indexes = list(range(num_exam)) 426 | exam_selected_dataset = exam_selected_dataset.select(indexes) 427 | elif exam_selected_dataset.num_rows < num_exam: 428 | # Only warning 429 | # FIXME: rewrite this function to ensure returning `num_exam` exam items 430 | self.logger.warn( 431 | "The size of the returned exam dataset would be smaller than the" 432 | f" set `num_exam_questions`: {self.num_exam_questions}" 433 | ) 434 | 435 | if final_selected_dataset is None: 436 | final_selected_dataset = exam_selected_dataset 437 | else: 438 | final_selected_dataset = concatenate_datasets( 439 | [final_selected_dataset, exam_selected_dataset] 440 | ) 441 | 442 | return final_selected_dataset 443 | 444 | 445 | if __name__ == "__main__": 446 | from pprint import pprint 447 | from datasets import load_from_disk 448 | 449 | t_dataset = load_from_disk("examples/rationale/data/math_500").to_list() 450 | top16_maker = MathMetaInfoMaker( 451 | "examples/rationale/data/math_12k/", 452 | level_controls=["=", 5], 453 | num_exam_questions=16, 454 | random=False, 455 | ) 456 | random16_maker = MathMetaInfoMaker( 457 | "examples/rationale/data/math_12k/", 458 | level_controls=["=", 5], 459 | num_exam_questions=16, 460 | random=True, 461 | ) 462 | for t_set in [t_dataset[:1], t_dataset[:2], t_dataset[:3]]: 463 | print(f"num teaching items: {len(t_set)}") 464 | pprint(t_set) 465 | e_set_top16 = top16_maker.make_exam_questions(t_set) 466 | e_set_random16 = random16_maker.make_exam_questions(t_set) 467 | print("top16:") 468 | pprint(e_set_top16.select_columns(["level", "subject", "unique_id"]).to_list()) 469 | print("random16:") 470 | pprint( 471 | e_set_random16.select_columns(["level", "subject", "unique_id"]).to_list() 472 | ) 473 | -------------------------------------------------------------------------------- /lbt/datasets_adapter/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imagination-research/lbt/c14bf6e5799a06052f24caa1d41275d326d423c9/lbt/datasets_adapter/utils/__init__.py -------------------------------------------------------------------------------- /lbt/datasets_adapter/utils/add_test_cases.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from .utils_llm import LanguageFunction 3 | import os 4 | import inspect 5 | 6 | UTILS_DIR = os.path.dirname(inspect.getabsfile(inspect.currentframe())) 7 | 8 | 9 | def extract_test_cases(dataset: pd.DataFrame, lang: str) -> pd.DataFrame: 10 | """ 11 | Add test cases to the dataset 12 | Adds columns: 'test_cases' (List[str]) 13 | """ 14 | dataset = dataset.copy() 15 | dataset.reset_index(inplace=True, drop=True) 16 | dataset["test_cases"] = None 17 | 18 | # extract 
test cases 19 | for ind, row in dataset.iterrows(): 20 | print(f"Extracting test cases for problem {ind+1}/{len(dataset)}") 21 | examples = extract_examples(row["description"]) 22 | function_signature = row["signature"] 23 | test_cases = examples_to_test_cases(examples, function_signature, lang) 24 | dataset.at[ind, "test_cases"] = test_cases 25 | return dataset 26 | 27 | 28 | def extract_examples(description): 29 | """ 30 | Extract a natural language representation of the examples from the description 31 | """ 32 | inputs = [l for l in description.split("\n") if l.strip().startswith("Input")] 33 | outputs = [ 34 | l.strip()[len("Output: "):]  # remove the prefix; str.strip("Output: ") would strip a character set from both ends 35 | for l in description.split("\n") 36 | if l.strip().startswith("Output") 37 | ] 38 | 39 | examples = [] 40 | 41 | for i, (input_str, output_str) in enumerate(zip(inputs, outputs)): 42 | example_str = f"Example {i+1}:\n{input_str}\nOutput: {output_str}" 43 | examples.append(example_str) 44 | return "\n\n".join(examples) 45 | 46 | 47 | def examples_to_test_cases( 48 | examples: str, function_signature: str, language: str 49 | ) -> str: 50 | """ 51 | Extract test cases from a natural language representation of the examples 52 | """ 53 | lang_function = LanguageFunction.from_yaml( 54 | os.path.join(UTILS_DIR, "extract_tests.yaml") 55 | ) 56 | response = lang_function( 57 | function_signature=function_signature, examples=examples, language=language 58 | ) 59 | test_cases = response["response"].split("\n") 60 | return test_cases 61 | -------------------------------------------------------------------------------- /lbt/datasets_adapter/utils/clean_leetcode.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import re 3 | 4 | 5 | def remove_class_dependent(dataset: pd.DataFrame) -> pd.DataFrame: 6 | """ 7 | Remove problems that depend on class definitions 8 | """ 9 | dataset = dataset.copy() 10 | no_defs_inds = [ 11 | ind 12 | for ind, row in dataset.iterrows() 13 | if row["cpp_snippet"].split(" ")[0] == "class" 14 | ] 15 | no_defs = dataset.iloc[no_defs_inds] 16 | return no_defs 17 | 18 | 19 | def remove_void(dataset: pd.DataFrame) -> pd.DataFrame: 20 | """ 21 | Remove problems that request a void implementation 22 | """ 23 | dataset = dataset.copy() 24 | # ret_inds = [ind for ind, row in dataset.iterrows() if '\"\"\"' in row['python3_snippet'].split('\n')[2]] 25 | ret_inds = [] 26 | for ind, row in dataset.iterrows(): 27 | if len(row["python3_snippet"].split("\n")) > 2: 28 | if '"""' in row["python3_snippet"].split("\n")[2]: 29 | ret_inds.append(ind) 30 | ret = dataset.drop(ret_inds) 31 | return ret 32 | 33 | 34 | def remove_class_impls(dataset: pd.DataFrame) -> pd.DataFrame: 35 | """ 36 | Remove problems that request a class implementation 37 | """ 38 | dataset = dataset.copy() 39 | function_name_regex = r"(?<=def\s)\w+" 40 | impl_inds = [ 41 | ind 42 | for ind, row in dataset.iterrows() 43 | if re.search(function_name_regex, row["python3_snippet"]).group(0) == "__init__" 44 | ] 45 | no_impl = dataset.drop(impl_inds) 46 | return no_impl 47 | 48 | 49 | def remove_examples(dataset: pd.DataFrame) -> pd.DataFrame: 50 | """ 51 | Return a copy of the dataset without examples in the descriptions 52 | """ 53 | dataset = dataset.copy() 54 | for ind, row in dataset.iterrows(): 55 | res = docstring_remove_empty(docstring_remove_examples(row["description"])) 56 | dataset.at[ind, "description"] = res 57 | 58 | return dataset 59 | 60 | 61 | def docstring_remove_examples(docstring: str): 62 | """ 63 | Remove the 
examples from the docstring 64 | """ 65 | lines = [l.strip() for l in docstring.split("\n")] 66 | for i, line in enumerate(lines): 67 | if "Example" in line: 68 | return "\n".join(lines[:i]) 69 | return docstring 70 | 71 | 72 | def docstring_remove_empty(desc: str): 73 | """ 74 | Remove empty lines from the docstring 75 | """ 76 | return "\n".join(line for line in desc.split("\n") if line.strip()) 77 | -------------------------------------------------------------------------------- /lbt/datasets_adapter/utils/extract_tests.yaml: -------------------------------------------------------------------------------- 1 | gpt_model_cfgs: 2 | - model_type: azure_openai 3 | model_cfg: 4 | model: gpt-35-turbo 5 | api_version: 2024-02-15-preview 6 | api_endpoint: https://infini-ai-east-us-2.openai.azure.com/ 7 | sample_cfg: 8 | top_p: 1.0 9 | temperature: 1.0 10 | 11 | function: 12 | exam_template: | 13 | FUNCTION SIGNATURE: 14 | {function_signature} 15 | 16 | PSEUDOCODE TESTS 17 | {examples} 18 | 19 | LANGUAGE: {language} 20 | 21 | stub_items: 22 | - role: "user" 23 | content: | 24 | FUNCTION SIGNATURE: 25 | def minReverseOperations(n: int, p: int, banned: List[int], k: int) -> List[int]: 26 | 27 | PSEUDOCODE TESTS 28 | Example 1: 29 | Input: n = 4, p = 0, banned = [1,2], k = 4 30 | Output: [0,-1,-1,1] 31 | 32 | Example 2: 33 | Input: n = 5, p = 0, banned = [2,4], k = 3 34 | Output: [0,-1,-1,-1,-1] 35 | 36 | Example 3: 37 | Input: n = 4, p = 2, banned = [0,1,3], k = 1 38 | Output: [-1,-1,0,-1] 39 | 40 | LANGUAGE: python 41 | 42 | - role: "assistant" 43 | content: | 44 | assert minReverseOperations(4, 0, [1,2], 4) == [0,-1,-1,1] 45 | assert minReverseOperations(5, 0, [2,4], 3) == [0,-1,-1,-1,-1] 46 | assert minReverseOperations(4, 2, [0,1,3], 1) == [-1,-1,0,-1] -------------------------------------------------------------------------------- /lbt/datasets_adapter/utils/fetch_leetcode.py: -------------------------------------------------------------------------------- 1 | import dotenv 2 | import pandas as pd 3 | import ast 4 | from bs4 import BeautifulSoup 5 | import leetcode 6 | import html2text 7 | import re 8 | import urllib.parse 9 | from typing import Dict 10 | 11 | import dotenv 12 | import html2text 13 | import leetcode 14 | import pandas as pd 15 | import requests 16 | from bs4 import BeautifulSoup 17 | 18 | h = html2text.HTML2Text() 19 | h.ignore_links = True 20 | h.ignore_images = True 21 | h.ignore_emphasis = True 22 | 23 | dotenv.load_dotenv() 24 | 25 | import ast 26 | import re 27 | from abc import ABC, abstractmethod 28 | from typing import List 29 | 30 | import astunparse 31 | 32 | ## ------------------------------------------------------------------------------------------------- 33 | ## LeetCode formatter 34 | ## ------------------------------------------------------------------------------------------------- 35 | 36 | 37 | class SubmissionFormatter(ABC): 38 | """ 39 | Class that converts between HumanEval and Leetcode submission formats. 
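For instance, the Python subclass below turns a LeetCode `class Solution`
method into a free function and back; a hypothetical round trip would be

    free_function = PythonSubmissionFormatter.to_humaneval(leetcode_snippet)
    leetcode_code = PythonSubmissionFormatter.to_leetcode(free_function)

where `leetcode_snippet` is a placeholder for a starter snippet fetched
from LeetCode.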
40 | """ 41 | 42 | @staticmethod 43 | @abstractmethod 44 | def to_leetcode(humaneval_snippet: str): 45 | """ 46 | Convert the string to leetcode format 47 | """ 48 | 49 | @staticmethod 50 | @abstractmethod 51 | def to_humaneval(leetcode_snippet: str): 52 | """ 53 | Convert the string to humaneval format 54 | """ 55 | 56 | @staticmethod 57 | @abstractmethod 58 | def add_docstring(snippet: str, description: str): 59 | """ 60 | Add a docstring to the snippet 61 | """ 62 | 63 | @staticmethod 64 | @abstractmethod 65 | def extract_signature(source: str) -> str: 66 | """ 67 | Extract the signature from the function 68 | """ 69 | 70 | 71 | class PythonSubmissionFormatter: 72 | @staticmethod 73 | def add_docstring(snippet: str, description: str): 74 | snippet = snippet.strip("\n") 75 | # Add 4 spaces to the beginning of every line 76 | description = "\n".join([" " * 4 + line for line in description.splitlines()]) 77 | docstring = f''' """ 78 | {description} 79 | """''' 80 | return f"{snippet}\n{docstring}\n" 81 | 82 | @staticmethod 83 | def to_humaneval(leetcode_snippet: str) -> str: 84 | try: 85 | tree = ast.parse(leetcode_snippet) 86 | except IndentationError: 87 | class_source = leetcode_snippet.strip() + "\n pass" 88 | tree = ast.parse(class_source) 89 | func_node = tree.body[0].body[0] 90 | func_node.args.args.pop(0) # Remove 'self' argument 91 | 92 | if isinstance(func_node.body[-1], ast.Pass): 93 | func_node.body.pop() 94 | 95 | new_tree = ast.Module(body=[func_node], type_ignores=[]) 96 | return f"{astunparse.unparse(new_tree).strip()}\n" 97 | 98 | @staticmethod 99 | def to_leetcode(humaneval_snippet: str, class_name: str = "Solution") -> str: 100 | # Get imports 101 | imports = "\n".join( 102 | PythonSubmissionFormatter.extract_imports(humaneval_snippet) 103 | ) 104 | # Remove imports 105 | # humaneval_snippet = re.sub(r"^from\s+\S+\s+import.*|^import.*", "", humaneval_snippet, flags=re.MULTILINE) 106 | try: 107 | tree = ast.parse(humaneval_snippet) 108 | except IndentationError: 109 | function_source = humaneval_snippet.strip() + "\n pass" 110 | tree = ast.parse(function_source) 111 | 112 | func_node = None 113 | for child in ast.iter_child_nodes(tree): 114 | if isinstance(child, ast.FunctionDef): 115 | func_node = child 116 | break 117 | 118 | docstring = ast.get_docstring(func_node) 119 | if docstring is not None: 120 | func_node.body.pop(0) 121 | 122 | if func_node.body and isinstance(func_node.body[-1], ast.Pass): 123 | func_node.body.pop() 124 | 125 | # Add 'self' argument back to the function 126 | self_arg = ast.arg(arg="self", annotation=None) 127 | func_node.args.args.insert(0, self_arg) 128 | class_node = ast.ClassDef( 129 | name=class_name, 130 | bases=[], 131 | keywords=[], 132 | body=[func_node], 133 | decorator_list=[], 134 | ) 135 | new_tree = ast.Module(body=[class_node], type_ignores=[]) 136 | return f"{imports}\n{astunparse.unparse(new_tree).strip()}\n" 137 | 138 | @staticmethod 139 | def extract_imports(source: str) -> List[str]: 140 | """ 141 | Extract top level imports 142 | """ 143 | standard_import = re.compile(r"^import (\w+(?:, \w+)*)") 144 | from_import = re.compile(r"^from (\w+) import (\w+(?:, \w+)*)") 145 | 146 | imports = [] 147 | 148 | for line in source.splitlines(): 149 | std_match = standard_import.match(line) 150 | from_match = from_import.match(line) 151 | 152 | if std_match: 153 | imports.append(std_match.group(0)) 154 | 155 | if from_match: 156 | imports.append(from_match.group(0)) 157 | 158 | return imports 159 | 160 | @staticmethod 161 | def 
extract_signature(source: str) -> str: 162 | return source.replace("def ", "", 1)[:-1] 163 | 164 | 165 | class RustSubmissionFormatter: 166 | @staticmethod 167 | def add_docstring(snippet: str, description: str): 168 | # Formatting the docstring in Rust style using /* */ 169 | rust_docstring = f"/*\n{description}\n*/" 170 | 171 | # Combining the docstring and the signature 172 | result = f"{rust_docstring}\n{snippet}" 173 | return result 174 | 175 | @staticmethod 176 | def extract_imports(source: str) -> List[str]: 177 | rust_import = re.compile(r"^use ([\w::]+(?:\s+as\s+\w+)?)(?:;\s*)?$") 178 | 179 | imports = [] 180 | 181 | for line in source.splitlines(): 182 | rust_match = rust_import.match(line) 183 | 184 | if rust_match: 185 | imports.append(rust_match.group(0).strip()) 186 | 187 | return imports 188 | 189 | @staticmethod 190 | def remove_imports(source: str) -> str: 191 | rust_import = re.compile(r"^use ([\w::]+(?:\s+as\s+\w+)?)(?:;\s*)?$") 192 | 193 | lines = source.splitlines() 194 | new_lines = [] 195 | for line in lines: 196 | if rust_import.match(line): 197 | print(f"Removing import: {line}") 198 | else: 199 | new_lines.append(line) 200 | 201 | return "\n".join(new_lines) 202 | 203 | @staticmethod 204 | def to_humaneval(leetcode_snippet: str) -> str: 205 | # Remove comments 206 | function_source = re.sub(r"//.*", "", leetcode_snippet) 207 | # Using the re.DOTALL flag to match across multiple lines 208 | function_source = re.sub(r"/\*.*?\*/", "", function_source, flags=re.DOTALL) 209 | 210 | # Remove solution class def 211 | function_source = re.sub(r"impl Solution \{\n", "", function_source) 212 | reversed_source = function_source[::-1] 213 | reversed_substituted = re.sub(r"\}", "", reversed_source, count=1) 214 | function_source = reversed_substituted[::-1] 215 | 216 | # Remove pub from function 217 | function_source = re.sub(r"pub ", "", function_source) 218 | 219 | # Unindent function 220 | whitespace = leading_whitespace_count(function_source) 221 | function_source = "\n".join( 222 | [line[whitespace:] for line in function_source.splitlines()] 223 | ) 224 | function_source = function_source.strip() 225 | 226 | # Remove whitespace from every line in the function 227 | return f"{function_source}\n" 228 | 229 | @staticmethod 230 | def to_leetcode(humaneval_snippet: str, struct_name: str = "Solution") -> str: 231 | imports = "\n".join(RustSubmissionFormatter.extract_imports(humaneval_snippet)) 232 | function_source = RustSubmissionFormatter.remove_imports(humaneval_snippet) 233 | 234 | function_source = re.sub(r"//.*", "", function_source) # Remove comments 235 | function_source = re.sub(r"/\*.*?\*/", "", function_source, flags=re.DOTALL) 236 | function_source = function_source.strip() 237 | function_source = re.sub( 238 | r"fn ", "pub fn ", function_source, count=1 239 | ) # Add pub to root function 240 | return f"{imports}\nimpl {struct_name} {{\n{function_source}\n}}\n" # Add impl struct_name { } around function 241 | 242 | @staticmethod 243 | def extract_signature(source: str) -> str: 244 | return source.strip("fn ").replace("{", "").replace("}", "").strip().strip("\n") 245 | 246 | 247 | ## ------------------------------------------------------------------------------------------------- 248 | ## Small Tools 249 | ## ------------------------------------------------------------------------------------------------- 250 | 251 | 252 | def leading_whitespace_count(s): 253 | # Split the string into lines and get the first line 254 | first_line = [l for l in s.splitlines() if l][0] 
if s else "" 255 | 256 | # Find the index of the first non-whitespace character 257 | non_whitespace_index = next( 258 | (i for i, char in enumerate(first_line) if not char.isspace()), None 259 | ) 260 | 261 | # If the entire line consists of whitespaces (or is empty), then return its length 262 | if non_whitespace_index is None: 263 | return len(first_line) 264 | 265 | return non_whitespace_index 266 | 267 | 268 | def format_integer(n): 269 | """Format the integer to have a length of 4 by padding with zeroes.""" 270 | return str(n).zfill(4)[:4] 271 | 272 | 273 | def get_info(question_slug: str, api_instance): 274 | """ 275 | Retrieves the metadata of the question with the given slug 276 | """ 277 | graphql_request = leetcode.GraphqlQuery( 278 | query=""" 279 | query getQuestionDetail($titleSlug: String!) { 280 | question(titleSlug: $titleSlug) { 281 | codeSnippets { 282 | lang 283 | langSlug 284 | code 285 | __typename 286 | } 287 | content 288 | title 289 | topicTags { 290 | name 291 | slug 292 | } 293 | } 294 | } 295 | """, 296 | variables={"titleSlug": question_slug}, 297 | operation_name="getQuestionDetail", 298 | ) 299 | response = ast.literal_eval(str(api_instance.graphql_post(body=graphql_request))) 300 | data = response["data"]["question"] 301 | return data 302 | 303 | 304 | ## ------------------------------------------------------------------------------------------------- 305 | ## Fetch datasets and solutions 306 | ## ------------------------------------------------------------------------------------------------- 307 | 308 | 309 | def fetch_solutions(dataset: pd.DataFrame, lang: str) -> pd.DataFrame: 310 | """ 311 | Fetch the solutions for the given lang 312 | """ 313 | dataset = dataset.copy() 314 | for ind, row in dataset.iterrows(): 315 | print(f"Fetching solution for problem {ind+1}/{len(dataset)}") 316 | try: 317 | solution = fetch_solution( 318 | row["frontend_question_id"], row["question_title"], lang 319 | ) 320 | except: 321 | solution = "N/A" 322 | dataset.at[ind, "solution"] = solution if solution is not None else "" 323 | return dataset 324 | 325 | 326 | def fetch_solution( 327 | frontend_question_id: int, question_title: str, lang: str = "python3" 328 | ): 329 | """Get the solution of the question from the LeetCode github repository.""" 330 | LANG_EXT_MAP = { 331 | "python3": "py", 332 | "java": "java", 333 | "cpp": "cpp", 334 | } 335 | 336 | if lang not in LANG_EXT_MAP: 337 | raise ValueError(f"Solutions not supported for Language {lang}") 338 | 339 | FORMATTER_MAP: Dict[str, SubmissionFormatter] = { 340 | "python3": PythonSubmissionFormatter, 341 | "rust": RustSubmissionFormatter, 342 | } 343 | question_id = format_integer(int(frontend_question_id)) 344 | 345 | url = f"https://raw.githubusercontent.com/walkccc/LeetCode/main/solutions/{question_id}. 
{question_title}/{question_id}.{LANG_EXT_MAP[lang]}" 346 | encoded_url = urllib.parse.quote(url, safe=":/") 347 | response = requests.get(encoded_url) 348 | if response.status_code == 404: 349 | return None 350 | return FORMATTER_MAP[lang].to_humaneval(response.text) 351 | 352 | 353 | def fetch_dataset(api_instance, topic="algorithms", difficulty=3, paid_only=False): 354 | """ 355 | Get the coding questions from leetcode 356 | """ 357 | question_infos = api_instance.api_problems_topic_get(topic=topic) 358 | print("Fetched question infos") 359 | 360 | questions = [ 361 | q 362 | for q in question_infos.stat_status_pairs 363 | if q.difficulty.level == difficulty and q.paid_only == paid_only 364 | ] 365 | 366 | df = pd.DataFrame() 367 | for ind, question in enumerate(questions): 368 | if ind > 999: 369 | break 370 | print(f"Fetching code snippets for problem {ind + 1}/{len(questions)}") 371 | question_slug = question.stat.question__title_slug 372 | info = get_info(question_slug, api_instance) 373 | snippets = info["code_snippets"] 374 | content = BeautifulSoup(info["content"], features="html.parser") 375 | text_content = h.handle(str(content)) 376 | text_content = "\n".join(line.lstrip() for line in text_content.split("\n")) 377 | text_content = re.sub("\n\n+", "\n\n", text_content) 378 | text_content = text_content.strip().strip("\n") 379 | tags = [] 380 | for tag in info["topic_tags"]: 381 | tags.append(tag["name"]) 382 | 383 | df.at[ind, "question_slug"] = question.stat.question__title_slug 384 | df.at[ind, "question_title"] = question.stat.question__title 385 | df.at[ind, "frontend_question_id"] = int(question.stat.frontend_question_id) 386 | df.at[ind, "question_id"] = int(question.stat.question_id) 387 | df.at[ind, "description"] = text_content 388 | df.at[ind, "tags"] = ",".join(tags) 389 | 390 | for snippet in snippets: 391 | df.at[ind, snippet["lang_slug"] + "_snippet"] = snippet["code"] 392 | 393 | return df 394 | -------------------------------------------------------------------------------- /lbt/datasets_adapter/utils/format_leetcode.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from .utils_leetcode import lines_to_jsonl 3 | from .fetch_leetcode import ( 4 | PythonSubmissionFormatter, 5 | RustSubmissionFormatter, 6 | SubmissionFormatter, 7 | ) 8 | 9 | FORMATTERS = { 10 | "python3": PythonSubmissionFormatter, 11 | "rust": RustSubmissionFormatter, 12 | } 13 | 14 | 15 | def format_problems(dataset: pd.DataFrame, lang: str): 16 | """ 17 | Convert problems to functions with their descriptions as docstrings 18 | Adds columns: 'signature', 'prompt' 19 | """ 20 | formatter: SubmissionFormatter = FORMATTERS.get(lang) 21 | dataset = dataset.copy() 22 | for ind, row in dataset.iterrows(): 23 | formatted_problem = formatter.to_humaneval(row[f"{lang}_snippet"]) 24 | prompt = formatter.add_docstring(formatted_problem, row["description"]) 25 | signature = formatter.extract_signature(formatted_problem) 26 | dataset.at[ind, "signature"] = signature 27 | dataset.at[ind, "prompt"] = prompt 28 | return dataset 29 | 30 | 31 | def to_jsonl(dataset: pd.DataFrame, path: str): 32 | """ 33 | Save the dataset to a jsonl file 34 | """ 35 | print(f"Writing dataset to {path}") 36 | lines = [] 37 | for ind, row in dataset.iterrows(): 38 | task_id = row["question_slug"] 39 | test_cases = "\n".join(row.get("test_cases", [])) 40 | solution = row.get("solution", "") 41 | prompt = row["prompt"] 42 | signature = row["signature"] 43 | docstring = 
row["description"] 44 | tags = row["tags"] 45 | 46 | line = { 47 | "task_id": task_id, 48 | "prompt": prompt, 49 | "canonical_solution": solution, 50 | "test": test_cases, 51 | "signature": signature, 52 | "docstring": docstring, 53 | "tags": tags, 54 | } 55 | 56 | lines.append(line) 57 | 58 | lines_to_jsonl(lines, path) 59 | -------------------------------------------------------------------------------- /lbt/datasets_adapter/utils/transform_code.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, Dataset 2 | import json 3 | 4 | 5 | ## ------------------------------------------------------------------------------------------------- 6 | ## Dataset Transformation 7 | ## ------------------------------------------------------------------------------------------------- 8 | def SynthesisDatasetTrans(code_dataset): 9 | Synthesis_Prompt = "Write a python function '{}' to solve the following problem: {} Please note: (1) Import the necessary Python packages, and If the function head requires additional packages, ensure they are imported before the function head; (2) Pay attention to Python's indentation format; (3) Absolutely refrain from outputting any unnecessary explanatory text after generating the code." 10 | 11 | Synthesis_list = [] 12 | for code_sample in code_dataset: 13 | index = code_sample["prompt"].find("def") 14 | code_sample_temp = code_sample["prompt"][index:].split('"""') 15 | if len(code_sample_temp) == 1: 16 | code_sample_temp = code_sample["prompt"][index:].split("'''") 17 | code_sample["question"] = Synthesis_Prompt.format( 18 | code_sample_temp[0].strip(), code_sample_temp[1].strip() 19 | ) 20 | 21 | # deal with test 22 | index = code_sample["test"].find("def") 23 | code_sample["test"] = code_sample["test"][index:] 24 | if "def" in code_sample["test"]: 25 | code_sample["test"] = code_sample["test"] + "\n\ncheck({})".format( 26 | code_sample["entry_point"] 27 | ) 28 | 29 | Synthesis_list.append(code_sample) 30 | 31 | return Synthesis_list 32 | 33 | 34 | def DebugDatasetTrans(code_dataset): 35 | Debug_Prompt = "{}\n\nFix bugs in {}." 36 | 37 | Debug_list = [] 38 | for code_sample in code_dataset: 39 | code_sample_temp = code_sample["prompt"].split('"""') 40 | if len(code_sample_temp) == 1: 41 | code_sample_temp = code_sample["prompt"].split("'''") 42 | 43 | # TODO: add some bugs here, we can use teacher to generate some buggy code. 44 | code_sample_temp = code_sample_temp[0] + code_sample["canonical_solution"] 45 | Debug_list.append( 46 | Debug_Prompt.format(code_sample_temp, code_sample["entry_point"]) 47 | ) 48 | return Debug_list 49 | 50 | 51 | def ExplainDatasetTrans(code_dataset): 52 | Explain_Prompt = "{}\n\nProvide a concise natural language description of the function using at most 500 characters." 53 | 54 | Explain_list = [] 55 | for code_sample in code_dataset: 56 | code_sample_temp = code_sample["prompt"].split('"""') 57 | if len(code_sample_temp) == 1: 58 | code_sample_temp = code_sample["prompt"].split("'''") 59 | code_sample_temp = code_sample_temp[0] + code_sample["canonical_solution"] 60 | Explain_list.append(Explain_Prompt.format(code_sample_temp)) 61 | return Explain_list 62 | 63 | 64 | if __name__ == "__main__": 65 | # We use this code to evaluate the coding ability of the instruction-tuned LLMs. 
66 | code_dataset = load_dataset("openai_humaneval", split="test") 67 | 68 | # generate the synthesis data list 69 | synthesis_data_list = SynthesisDatasetTrans(code_dataset) 70 | # dump jsonl 71 | with open( 72 | "./examples/datasets/humaneval.jsonl", "w" 73 | ) as file: 74 | for sample in synthesis_data_list: 75 | file.write(json.dumps(sample) + "\n") 76 | 77 | synthesis_data_list = Dataset.from_list(synthesis_data_list) 78 | synthesis_data_list.save_to_disk( 79 | "./datasets/datasets/humaneval" 80 | ) 81 | -------------------------------------------------------------------------------- /lbt/datasets_adapter/utils/utils_leetcode.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import requests 4 | from bs4 import BeautifulSoup 5 | import os 6 | from typing import List 7 | import leetcode 8 | import leetcode.auth 9 | from typing import Dict 10 | import string 11 | 12 | 13 | def lines_to_jsonl(lines: List[Dict], file_path: str): 14 | """ 15 | Convert a list of dicts to a jsonl file 16 | """ 17 | # Empty the current file 18 | open(file_path, "w").close() 19 | 20 | with open(file_path, "a") as file: 21 | for dict_data in lines: 22 | json_line = json.dumps(dict_data) 23 | file.write(json_line + os.linesep) 24 | 25 | 26 | def get_api_instance(leetcode_session, csrf_token): 27 | """ 28 | Get the leetcode api instance 29 | """ 30 | configuration = leetcode.Configuration() 31 | 32 | configuration.api_key["x-csrftoken"] = csrf_token 33 | configuration.api_key["csrftoken"] = csrf_token 34 | configuration.api_key["LEETCODE_SESSION"] = leetcode_session 35 | configuration.api_key["Referer"] = "https://leetcode.com" 36 | configuration.debug = False 37 | 38 | api_instance = leetcode.DefaultApi(leetcode.ApiClient(configuration)) 39 | 40 | return api_instance 41 | 42 | 43 | def get_question(url): 44 | """ 45 | Get the question page 46 | """ 47 | while True: 48 | res = requests.get(url) # type: ignore 49 | status = res.status_code 50 | if status == 200: 51 | return res 52 | elif status == 404: 53 | return None 54 | else: 55 | print(status) 56 | time.sleep(300) 57 | 58 | 59 | def title_slug(title): 60 | """ 61 | Format the title into a title slug 62 | """ 63 | return "-".join(title.lower().split()) 64 | 65 | 66 | def slug_to_title(question_slug: str) -> str: 67 | """Format a Leetcode question's slug as a title""" 68 | return string.capwords(question_slug.replace("-", " ")).strip() 69 | 70 | 71 | def format_integer(n): 72 | """Format the integer to have a length of 4 by padding with zeroes.""" 73 | return str(n).zfill(4)[:4] 74 | 75 | 76 | def get_code_snippets(url): 77 | """ 78 | Gets the code snippets for the given question url 79 | """ 80 | res = get_question(url) 81 | if res is None: 82 | return None 83 | soup = BeautifulSoup(res.content, "html.parser") 84 | script_tag = soup.find("script", {"type": "application/json"}) 85 | data = dict(json.loads(script_tag.string)) 86 | queries = data["props"]["pageProps"]["dehydratedState"]["queries"] 87 | query = [ 88 | i 89 | for i in queries 90 | if "question" in i["state"]["data"] 91 | and "codeSnippets" in i["state"]["data"]["question"] 92 | ][0] 93 | code_snippets = query["state"]["data"]["question"]["codeSnippets"] 94 | return code_snippets 95 | 96 | 97 | url = "https://leetcode.com/graphql/" 98 | 99 | payload = lambda slug: json.dumps( 100 | { 101 | "query": "\n query consolePanelConfig($titleSlug: String!) 
{\n question(titleSlug: $titleSlug) {\n exampleTestcaseList\n }\n}\n ", 102 | "variables": {"titleSlug": slug}, 103 | "operationName": "consolePanelConfig", 104 | } 105 | ) 106 | 107 | headers = { 108 | "authority": "leetcode.com", 109 | "accept": "*/*", 110 | "accept-language": "en-US,en;q=0.9", 111 | "authorization": "", 112 | "baggage": "sentry-environment=production,sentry-release=8f466f72,sentry-transaction=%2Fproblems%2F%5Bslug%5D%2F%5B%5B...tab%5D%5D,sentry-public_key=2a051f9838e2450fbdd5a77eb62cc83c,sentry-trace_id=897972800d1c46e5a5d499f12244a91b,sentry-sample_rate=0.004", 113 | "content-type": "application/json", 114 | "cookie": 'gr_user_id=35b498db-f28f-485f-8b44-417f8fba15ed; __stripe_mid=04d7a882-553c-499c-8866-bcf56aac8ef6ed918f; __atuvc=1%7C5; NEW_PROBLEMLIST_PAGE=1; csrftoken=9BiGVDJiJS7iFJKVYZ1CNMNulRAvYUdlezUlp1oYOrsR2zVsk9mZh1MD6C2d6twV; messages="9b526d67f2587ca52e83b4431db91f6bd6abdac1$[[\\"__json_message\\"\\0540\\05425\\054\\"You have signed out.\\"]\\054[\\"__json_message\\"\\0540\\05425\\054\\"Successfully signed in as beckles168.\\"]\\054[\\"__json_message\\"\\0540\\05425\\054\\"You have signed out.\\"]\\054[\\"__json_message\\"\\0540\\05425\\054\\"Successfully signed in as leetcodeexecutor.\\"]\\054[\\"__json_message\\"\\0540\\05425\\054\\"You have signed out.\\"]\\054[\\"__json_message\\"\\0540\\05425\\054\\"Successfully signed in as beckles168.\\"]]"; 87b5a3c3f1a55520_gr_last_sent_cs1=beckles168; _gid=GA1.2.2067840721.1681477917; LEETCODE_SESSION=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJfYXV0aF91c2VyX2lkIjoiOTIwNjcxMyIsIl9hdXRoX3VzZXJfYmFja2VuZCI6ImFsbGF1dGguYWNjb3VudC5hdXRoX2JhY2tlbmRzLkF1dGhlbnRpY2F0aW9uQmFja2VuZCIsIl9hdXRoX3VzZXJfaGFzaCI6IjU2MGIwZGIzMjVjOTcwNTk3OGFkZDI4MjY0MzM5NjU0NzVjZDhmMjYiLCJpZCI6OTIwNjcxMywiZW1haWwiOiJiZWNrbGVzMTY4QGdtYWlsLmNvbSIsInVzZXJuYW1lIjoiYmVja2xlczE2OCIsInVzZXJfc2x1ZyI6ImJlY2tsZXMxNjgiLCJhdmF0YXIiOiJodHRwczovL2Fzc2V0cy5sZWV0Y29kZS5jb20vdXNlcnMvYXZhdGFycy9hdmF0YXJfMTY4MDY1MjE2OC5wbmciLCJyZWZyZXNoZWRfYXQiOjE2ODE2NjIzMDAsImlwIjoiNzIuMTk1LjEzNC4zMSIsImlkZW50aXR5IjoiNzIzYzUxMjYzYzgwZjZiZTc5ZmEyMTE5MWVlMGIzODciLCJzZXNzaW9uX2lkIjozNzg5MzIzNH0.BJV_u27JVniHZ73kI76oTTkFGK4OHNJPpv-F58pZBUc; 87b5a3c3f1a55520_gr_session_id=73357b2f-2c35-49f4-8256-556aa503d604; 87b5a3c3f1a55520_gr_last_sent_sid_with_cs1=73357b2f-2c35-49f4-8256-556aa503d604; 87b5a3c3f1a55520_gr_session_id_73357b2f-2c35-49f4-8256-556aa503d604=true; _gat=1; 87b5a3c3f1a55520_gr_cs1=beckles168; _ga=GA1.1.1043183799.1675086637; __stripe_sid=d8eb8303-f932-4cfd-92ef-9ed80b781cae827bea; _ga_CDRWKZTDEX=GS1.1.1681662302.39.1.1681665678.0.0.0; _dd_s=rum=0&expire=1681666578197; LEETCODE_SESSION=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJfYXV0aF91c2VyX2lkIjoiOTIwNjcxMyIsIl9hdXRoX3VzZXJfYmFja2VuZCI6ImFsbGF1dGguYWNjb3VudC5hdXRoX2JhY2tlbmRzLkF1dGhlbnRpY2F0aW9uQmFja2VuZCIsIl9hdXRoX3VzZXJfaGFzaCI6IjU2MGIwZGIzMjVjOTcwNTk3OGFkZDI4MjY0MzM5NjU0NzVjZDhmMjYiLCJpZCI6OTIwNjcxMywiZW1haWwiOiJiZWNrbGVzMTY4QGdtYWlsLmNvbSIsInVzZXJuYW1lIjoiYmVja2xlczE2OCIsInVzZXJfc2x1ZyI6ImJlY2tsZXMxNjgiLCJhdmF0YXIiOiJodHRwczovL2Fzc2V0cy5sZWV0Y29kZS5jb20vdXNlcnMvYXZhdGFycy9hdmF0YXJfMTY4MDY1MjE2OC5wbmciLCJyZWZyZXNoZWRfYXQiOjE2ODE2NjIzMDAsImlwIjoiNTQuODYuNTAuMTM5IiwiaWRlbnRpdHkiOiI3MjNjNTEyNjNjODBmNmJlNzlmYTIxMTkxZWUwYjM4NyIsInNlc3Npb25faWQiOjM3ODkzMjM0fQ.DtQ8KCL7Qsua4Bp-vOMJfg4VJUjX4NSxhdNXs756x4M; csrftoken=9BiGVDJiJS7iFJKVYZ1CNMNulRAvYUdlezUlp1oYOrsR2zVsk9mZh1MD6C2d6twV', 115 | "origin": "https://leetcode.com", 116 | "random-uuid": "4922dfe3-8c3c-1d65-9b7a-ca84bfe9f756", 117 | "referer": 
"https://leetcode.com/problems/two-sum/", 118 | "sec-ch-ua": '"Chromium";v="112", "Google Chrome";v="112", "Not:A-Brand";v="99"', 119 | "sec-ch-ua-mobile": "?0", 120 | "sec-ch-ua-platform": '"macOS"', 121 | "sec-fetch-dest": "empty", 122 | "sec-fetch-mode": "cors", 123 | "sec-fetch-site": "same-origin", 124 | "sentry-trace": "897972800d1c46e5a5d499f12244a91b-a37933a4a1d212e3-0", 125 | "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36", 126 | "x-csrftoken": "9BiGVDJiJS7iFJKVYZ1CNMNulRAvYUdlezUlp1oYOrsR2zVsk9mZh1MD6C2d6twV", 127 | } 128 | 129 | 130 | def test_cases_from_slug(slug: str) -> List[str]: 131 | response = requests.post(url, headers=headers, data=payload(slug)) 132 | return dict(response.json())["data"]["question"]["exampleTestcaseList"] 133 | -------------------------------------------------------------------------------- /lbt/datasets_adapter/utils/utils_llm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List, Sequence 3 | import json 4 | import yaml 5 | from termcolor import colored 6 | 7 | from fastchat.model import get_conversation_template 8 | from fastchat.conversation import get_conv_template 9 | 10 | from lbt.base import Component 11 | 12 | 13 | class LanguageFunction: 14 | def __init__(self, config: Dict, **model_kwargs) -> None: 15 | self.chat_model = Component.init_from_cfg(config["gpt_model_cfgs"], "model") 16 | self.sample_cfg = config["gpt_model_cfgs"]["sample_cfg"] 17 | 18 | # read prompts 19 | function = dict(config["function"]) 20 | self.stub_items = function.get("stub_items", []) 21 | self.exam_template = function["exam_template"] 22 | 23 | def __call__(self, callback=False, **kwargs) -> Dict: 24 | """ 25 | Call the Agent Function with the given arguments. 26 | """ 27 | # add fschat 28 | try: 29 | # exact match conversation name 30 | conv = get_conv_template(self.chat_model.conv_template_type) 31 | except KeyError: 32 | # get through model adapter 33 | conv = get_conversation_template(self.chat_model.conv_template_type) 34 | # For base model, use conv_template type "raw" 35 | 36 | for t_item in self.stub_items: 37 | if t_item["role"] == "user": 38 | conv.append_message(conv.roles[0], t_item["content"]) 39 | elif t_item["role"] == "assistant": 40 | conv.append_message(conv.roles[1], t_item["content"]) 41 | else: 42 | raise ValueError(f"Invalid role: {t_item['role']}") 43 | 44 | # exam samples 45 | exam_user = self.exam_template.format(**kwargs) 46 | exam_assistant = None 47 | 48 | conv.append_message(conv.roles[0], exam_user) 49 | conv.append_message(conv.roles[1], None) 50 | 51 | # call the chat model for responses 52 | response = self.chat_model.text_generator( 53 | (conv, exam_assistant), return_full_text=False, **self.sample_cfg 54 | ) 55 | response_dict = {"response": response} 56 | return response_dict 57 | 58 | @classmethod 59 | def from_yaml(cls, filepath: str): 60 | """ 61 | Load an agent from a YAML file. 62 | 63 | Args: 64 | filepath (str): The path to the YAML file. 65 | 66 | Returns: 67 | Agent: The agent. 
68 | """ 69 | with open(filepath, "r", encoding="utf-8") as file: 70 | yaml_obj = yaml.safe_load(file) 71 | return cls(yaml_obj) 72 | -------------------------------------------------------------------------------- /lbt/exam_maker.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os.path as osp 3 | from abc import abstractmethod 4 | from typing import Union, Optional, List 5 | 6 | import numpy as np 7 | import torch 8 | from datasets import Dataset, load_from_disk 9 | from fastchat.model import get_conversation_template 10 | from fastchat.conversation import get_conv_template 11 | 12 | from lbt.base import Component 13 | 14 | 15 | ### ---- exam makers ---- 16 | class BaseExamMaker(Component): 17 | REGISTRY = "exam_maker" 18 | 19 | def _get_num_exam_items(self, teaching_items, total_num_exam): 20 | num_t = len(teaching_items) 21 | if total_num_exam % num_t != 0: 22 | if num_t > total_num_exam: 23 | # Top-1 neighbor for the front teaching_items 24 | num_list = [1] * total_num_exam + [0] * (num_t - total_num_exam) 25 | else: 26 | # Averagely allocate neighbor num quota to teaching items, 27 | # the first teaching item get extra quota 28 | num_other = total_num_exam // num_t 29 | num_first = total_num_exam - num_other * (num_t - 1) 30 | num_list = [num_first] + [num_other] * (num_t - 1) 31 | self.logger.warn( 32 | f"Want to fetch {total_num_exam} exam questions that are" 33 | f" similar to {num_t} teaching items. {total_num_exam} %" 34 | f" {num_t} != 0." 35 | ) 36 | else: 37 | num_list = [total_num_exam // num_t] * num_t 38 | return num_list 39 | 40 | @abstractmethod 41 | def make_exam_questions(self, teaching_items): 42 | # (1) choose from `exam_bank_dataset` 43 | # (2) use a strong `exam_proposal_model` to propose questions 44 | pass 45 | 46 | 47 | class FixedExamMaker(BaseExamMaker): 48 | NAME = "fixed" 49 | 50 | def __init__( 51 | self, 52 | exam_bank_dataset: Union[str, Dataset], 53 | selected_indexes: Optional[Union[str, List]] = None, 54 | ): 55 | super().__init__() 56 | 57 | if isinstance(exam_bank_dataset, str): 58 | assert osp.exists(exam_bank_dataset) 59 | self.exam_bank_dataset_path = exam_bank_dataset 60 | self.exam_bank_dataset = load_from_disk(self.exam_bank_dataset_path) 61 | else: 62 | assert isinstance(exam_bank_dataset, Dataset) 63 | self.exam_bank_dataset = exam_bank_dataset 64 | 65 | if selected_indexes is None: 66 | self.exam_selected_dataset = self.exam_bank_dataset 67 | else: 68 | if isinstance(selected_indexes, str): 69 | selected_indexes = eval(selected_indexes) 70 | else: 71 | assert isinstance(selected_indexes, (tuple, list)) 72 | self.exam_selected_dataset = self.exam_bank_dataset.select(selected_indexes) 73 | self.selected_indexes = selected_indexes 74 | 75 | def make_exam_questions(self, teaching_items): 76 | return self.exam_selected_dataset 77 | 78 | 79 | class QuesSimilarityExamMaker(FixedExamMaker): 80 | """ 81 | Ref: opencompass icl_topk_retriever 82 | """ 83 | 84 | NAME = "ques_similarity" 85 | 86 | def __init__( 87 | self, 88 | exam_bank_dataset: Union[str, Dataset], 89 | selected_indexes: Optional[Union[str, List]] = None, 90 | sentence_transformers_model_name: Optional[str] = "all-mpnet-base-v2", 91 | knn_pickle_path: str = "knn16.pkl", 92 | num_exam_questions: int = 1, 93 | num_repetitions: int = 8, 94 | ): 95 | super().__init__(exam_bank_dataset, selected_indexes) 96 | 97 | self.num_exam_questions = num_exam_questions 98 | self.num_repetitions = num_repetitions 99 | 100 | if 
knn_pickle_path: 101 | import pickle 102 | 103 | with open(knn_pickle_path, "rb") as f: 104 | self.knn_pkl = pickle.load(f) 105 | else: 106 | from sentence_transformers import SentenceTransformer 107 | 108 | self.knn_pkl = None 109 | 110 | self.model = SentenceTransformer(sentence_transformers_model_name) 111 | self.model = self.model.to("cuda") 112 | self.model.eval() 113 | self.emb_dim = self.model.get_sentence_embedding_dimension() 114 | 115 | self.emb_index = self._create_index() 116 | 117 | def _create_index(self): 118 | import faiss 119 | 120 | avail_exam_questions = self.exam_selected_dataset["question"] 121 | all_embs = [] 122 | for question in avail_exam_questions: 123 | emb = self._embed(question) 124 | all_embs.append(emb) 125 | self.all_embs = np.stack(all_embs).astype("float32") 126 | 127 | size = len(avail_exam_questions) 128 | index = faiss.IndexIDMap(faiss.IndexFlatIP(self.emb_dim)) 129 | id_list = np.array(list(range(size))) 130 | index.add_with_ids(self.all_embs, id_list) 131 | return index 132 | 133 | def _knn_search(self, question, knn_num): 134 | emb = self._embed(question) 135 | self.logger.info( 136 | f"Retrieving {knn_num}-NN exam question indexes for teaching question" 137 | f' """{question}""" ...' 138 | ) 139 | emb = np.expand_dims(emb, axis=0).astype("float32") 140 | near_ids = self.emb_index.search(emb, knn_num)[1][0].tolist() 141 | return near_ids 142 | 143 | def _embed(self, question): 144 | with torch.no_grad(): 145 | emb = self.model.encode(question, show_progress_bar=False) 146 | return emb 147 | 148 | def make_exam_questions(self, teaching_items): 149 | knn_num_list = self._get_num_exam_items(teaching_items, self.num_exam_questions) 150 | all_exam_question_ids = [] 151 | for t_item, knn_num in zip(teaching_items, knn_num_list): 152 | if knn_num == 0: 153 | continue 154 | 155 | if self.knn_pkl: 156 | for idx, question in self.knn_pkl[t_item["question"]][:knn_num]: 157 | assert self.exam_selected_dataset[idx]["question"] == question 158 | all_exam_question_ids.extend([idx] * self.num_repetitions) 159 | else: 160 | exam_question_ids = self._knn_search(t_item["question"], knn_num) 161 | all_exam_question_ids.extend(exam_question_ids * self.num_repetitions) 162 | return self.exam_selected_dataset.select(all_exam_question_ids) 163 | 164 | 165 | class FunctionalExamMaker(FixedExamMaker): 166 | NAME = "func" 167 | 168 | def __init__( 169 | self, 170 | exam_bank_dataset: Union[str, Dataset], 171 | selected_indexes: Optional[Union[str, List]] = None, 172 | num_exam_questions: int = 3, 173 | num_repetitions: int = 3, 174 | ): 175 | super().__init__(exam_bank_dataset, selected_indexes) 176 | 177 | self.num_exam_questions = num_exam_questions 178 | self.num_repetitions = num_repetitions 179 | 180 | def make_exam_questions(self, teaching_items): 181 | all_exam_question_ids = [] 182 | for t_item in teaching_items: 183 | exam_question_ids = [] 184 | for i, e_item in enumerate(self.exam_selected_dataset): 185 | if t_item["unique_id"] == e_item["unique_id"]: 186 | exam_question_ids.append(i) 187 | all_exam_question_ids.extend(exam_question_ids * self.num_repetitions) 188 | 189 | assert ( 190 | len(all_exam_question_ids) == self.num_exam_questions * self.num_repetitions 191 | ) 192 | return self.exam_selected_dataset.select(all_exam_question_ids) 193 | 194 | 195 | ### ---- exam prompters ---- 196 | class ExamPrompter(Component): 197 | REGISTRY = "exam_prompter" 198 | NAME = "basic" 199 | 200 | PROMPT_ROLE_SWITCH_STR = "[ROLESWITCHING assistant:]" 201 | 202 | def __init__( 
203 | self, 204 | demo_template, 205 | exam_template, 206 | debug_template=None, 207 | instruction="", 208 | use_multi_round_conv=False, 209 | stub_teaching_items=None, 210 | ): 211 | super().__init__() 212 | self.demo_template = demo_template 213 | self.exam_template = exam_template 214 | self.debug_template = debug_template 215 | self.instruction = instruction 216 | self.use_multi_round_conv = use_multi_round_conv 217 | self.stub_teaching_items = stub_teaching_items or [] 218 | 219 | if self.use_multi_round_conv: 220 | assert self.PROMPT_ROLE_SWITCH_STR in self.demo_template, ( 221 | "`use_multi_round_conv==True`: Using multiple conversation rounds to" 222 | " present the teaching demonstrations. Must specify the conversation" 223 | " switching point in `demo_template`." 224 | ) 225 | else: 226 | assert self.PROMPT_ROLE_SWITCH_STR not in self.demo_template 227 | 228 | def make_exam_prompt_fastchat( 229 | self, teaching_items, exam_item, conv_template_type, debug=False 230 | ) -> str: 231 | try: 232 | # exact match conversation name 233 | conv = get_conv_template(conv_template_type) 234 | except KeyError: 235 | # get through model adapter 236 | conv = get_conversation_template(conv_template_type) 237 | assert ( 238 | conv.name != "one_shot" 239 | ), f"`{conv_template_type}` not supported in `fastchat`." 240 | # For base model, use conv_template type "raw" 241 | 242 | if not debug: 243 | _exam_item = exam_item.copy() 244 | demo_items = self.stub_teaching_items + teaching_items 245 | 250 | if self.use_multi_round_conv: 251 | demo_template_user, demo_template_assistant = self.demo_template.split( 252 | self.PROMPT_ROLE_SWITCH_STR 253 | ) 254 | for t_item in demo_items: 255 | demo_user = demo_template_user.format(**t_item) 256 | demo_assistant = demo_template_assistant.format(**t_item) 257 | conv.append_message(conv.roles[0], demo_user) 258 | conv.append_message(conv.roles[1], demo_assistant) 259 | else: 260 | demo = "\n\n\n".join( 261 | [self.demo_template.format(**t_item) for t_item in (demo_items)] 262 | ) 263 | _exam_item["demo"] = demo 278 | 279 | if self.PROMPT_ROLE_SWITCH_STR in self.exam_template: 280 | # has partial answer 281 | exam_template_user, exam_template_assistant = self.exam_template.split( 282 | self.PROMPT_ROLE_SWITCH_STR 283 | ) 284 | exam_user = exam_template_user.format(**_exam_item) 285 | exam_assistant = exam_template_assistant.format(**_exam_item) 286 | else: 287 | exam_user = self.exam_template.format(**_exam_item) 288 | exam_assistant = None 289 | 290 | conv.append_message(conv.roles[0], self.instruction + exam_user) 291 | conv.append_message(conv.roles[1], None) 292 | return (conv, exam_assistant) 293 | 294 | def make_exam_prompt_chat_template(self, teaching_items, exam_item) -> str: 295 | conv = [{"role": "system", "content": "You are a helpful assistant."}] 296 | 297 | _exam_item = 
exam_item.copy() 298 | demo_items = self.stub_teaching_items + teaching_items 299 | 300 | if self.use_multi_round_conv: 301 | demo_template_user, demo_template_assistant = self.demo_template.split( 302 | self.PROMPT_ROLE_SWITCH_STR 303 | ) 304 | for t_item in demo_items: 305 | demo_user = demo_template_user.format(**t_item) 306 | demo_assistant = demo_template_assistant.format(**t_item) 307 | conv.append({"role": "user", "content": demo_user}) 308 | conv.append({"role": "assistant", "content": demo_assistant}) 309 | else: 310 | demo = "\n\n\n".join( 311 | [self.demo_template.format(**t_item) for t_item in (demo_items)] 312 | ) 313 | _exam_item["demo"] = demo 314 | 315 | if self.PROMPT_ROLE_SWITCH_STR in self.exam_template: 316 | # has partial answer 317 | exam_template_user, exam_template_assistant = self.exam_template.split( 318 | self.PROMPT_ROLE_SWITCH_STR 319 | ) 320 | exam_user = exam_template_user.format(**_exam_item) 321 | exam_assistant = exam_template_assistant.format(**_exam_item) 322 | else: 323 | exam_user = self.exam_template.format(**_exam_item) 324 | exam_assistant = None 325 | 326 | conv.append({"role": "user", "content": self.instruction + exam_user}) 327 | return (conv, exam_assistant) 328 | 329 | 330 | if __name__ == "__main__": 331 | from pprint import pprint 332 | from termcolor import colored 333 | 334 | teaching_items = load_from_disk( 335 | "../NLP-playground/examples/rationale/data/math_solution_worstRationale_10" 336 | ) 337 | name_mapping = {"problem": "question", "solution": "rationale"} 338 | teaching_items = teaching_items.select_columns( 339 | ["solution", "answer", "problem"] 340 | ).rename_columns(name_mapping) 341 | teaching_items = teaching_items.to_list()[:2] 342 | print(colored("Teaching items:\n----", "green")) 343 | pprint(teaching_items) 344 | 345 | exam_bank_dataset = load_from_disk( 346 | "../NLP-playground/examples/rationale/data/math_1500" 347 | ) 348 | exam_item = exam_bank_dataset[0] 349 | exam_item = {new_n: exam_item[old_n] for old_n, new_n in name_mapping.items()} 350 | print(colored("Exam items:\n----", "green")) 351 | pprint(exam_item) 352 | 353 | # OpenCompass ICL template for Math 354 | single_conv_prompter = ExamPrompter( 355 | demo_template="""Question:\n{question}\n\nSolution:\n{rationale}\n\nFinal Answer:\nThe final answer is $${answer}$$.\n""", 356 | exam_template=( 357 | "{demo}\n\n\nQuestion:\n{question}\n\n[ROLESWITCHING assistant:]Solution:\n" 358 | ), 359 | use_multi_round_conv=False, 360 | ) 361 | multi_conv_prompter = ExamPrompter( 362 | demo_template="""Question:\n{question}\n\n[ROLESWITCHING assistant:]Solution:\n{rationale}\n\nFinal Answer:\nThe final answer is $${answer}$$.\n""", 363 | exam_template="Question:\n{question}\n\n[ROLESWITCHING assistant:]Solution:\n", 364 | use_multi_round_conv=True, 365 | ) 366 | 367 | for conv_template_type in ["llama-2", "qwen-7b-chat", "chatglm3"]: 368 | single_conv_prompt = single_conv_prompter.make_exam_prompt_fastchat( 369 | teaching_items, exam_item, conv_template_type 370 | ) 371 | print( 372 | colored( 373 | f"[{conv_template_type}] Single-round conversation prompt string\n----", 374 | "green", 375 | ) 376 | ) 377 | print(single_conv_prompt) 378 | 379 | multi_conv_prompt = multi_conv_prompter.make_exam_prompt_fastchat( 380 | teaching_items, exam_item, conv_template_type 381 | ) 382 | print( 383 | colored( 384 | f"[{conv_template_type}] Multi-round conversation prompt string\n----", 385 | "green", 386 | ) 387 | ) 388 | print(multi_conv_prompt) 389 | 
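# A self-contained usage sketch for ExamPrompter, assuming fastchat's built-in "raw"
# conversation template; the teaching/exam items are toy values, not drawn from any dataset:
#
#   prompter = ExamPrompter(
#       demo_template="Question:\n{question}\nSolution:\n{rationale}\nAnswer: {answer}",
#       exam_template="{demo}\n\nQuestion:\n{question}\n[ROLESWITCHING assistant:]Solution:\n",
#   )
#   conv, partial_answer = prompter.make_exam_prompt_fastchat(
#       teaching_items=[{"question": "1 + 1?", "rationale": "Add the two ones.", "answer": "2"}],
#       exam_item={"question": "2 + 2?"},
#       conv_template_type="raw",
#   )
#   # `partial_answer` is "Solution:\n"; the student model is asked to continue from it.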
-------------------------------------------------------------------------------- /lbt/exam_scorer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import abstractmethod 3 | from lbt.base import Component 4 | 5 | 6 | class BaseExamScorer(Component): 7 | REGISTRY = "exam_scorer" 8 | 9 | @abstractmethod 10 | def score_exam_result(self, exam_gt_item, exam_result_item): 11 | """ 12 | Return score. 13 | """ 14 | pass 15 | 16 | 17 | class ModelExamScorer(BaseExamScorer): 18 | NAME = "model_based" 19 | 20 | def __init__(self, model_type, model_cfg): 21 | super().__init__() 22 | 23 | self.model = Component.init_from_cfg( 24 | {"model_type": model_type, "model_cfg": model_cfg}, registry_name="model" 25 | ) 26 | 27 | def score_exam_result(self, exam_gt_item, exam_result_item): 28 | # TODO: (1) make judge prompt (prompt template); (2) text_generator; (3) parse score from results (parsing) 29 | pass 30 | -------------------------------------------------------------------------------- /lbt/models/__init__.py: -------------------------------------------------------------------------------- 1 | from lbt.models.base import * 2 | -------------------------------------------------------------------------------- /lbt/models/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import os 3 | from openai import OpenAI, AzureOpenAI 4 | import openai 5 | 6 | import torch 7 | from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer 8 | 9 | from tenacity import ( 10 | retry, 11 | retry_if_not_exception_type, 12 | stop_after_attempt, 13 | wait_random_exponential, 14 | ) 15 | 16 | 17 | from lbt.base import Component 18 | 19 | 20 | class BaseModel(Component): 21 | REGISTRY = "model" 22 | 23 | @property 24 | @abstractmethod 25 | def name(self): 26 | pass 27 | 28 | @property 29 | @abstractmethod 30 | def conv_template_type(self): 31 | pass 32 | 33 | @abstractmethod 34 | def text_generator( 35 | self, iterator, batch_size, num_return_sequences, **generate_kwargs 36 | ): 37 | """ 38 | Return a generator object that yields the generations to the prompts in `iterator` one by one 39 | """ 40 | 41 | 42 | class StubModel(BaseModel): 43 | NAME = "stub" 44 | 45 | @property 46 | def name(self): 47 | return "stub" 48 | 49 | @property 50 | def conv_template_type(self): 51 | return "raw" 52 | 53 | def text_generator( 54 | self, iterator, batch_size, num_return_sequences, **generate_kwargs 55 | ): 56 | for _ in iterator: 57 | yield [ 58 | {"generated_text": f"random answer {index}"} 59 | for index in range(num_return_sequences) 60 | ] 61 | 62 | 63 | class OpenAIModel(BaseModel): 64 | NAME = "openai" 65 | 66 | def __init__( 67 | self, 68 | model, 69 | name=None, 70 | api_key=None, 71 | fastchat=True, 72 | ): 73 | super().__init__() 74 | 75 | self._model = model 76 | self._api_key = api_key or os.environ.get("OPENAI_API_KEY") 77 | self._client = OpenAI(api_key=self._api_key) 78 | self._name = name or model 79 | self.fastchat = fastchat 80 | 81 | @property 82 | def conv_template_type(self): 83 | return "chatgpt" 84 | 85 | @property 86 | def name(self): 87 | return self._name 88 | 89 | def text_generator(self, iterator, return_full_text, **generate_kwargs): 90 | return self._request( 91 | iterator, return_full_text=return_full_text, **generate_kwargs 92 | ) 93 | 94 | @retry( 95 | retry=retry_if_not_exception_type((openai.BadRequestError, TypeError)), 96 | 
wait=wait_random_exponential(min=1, max=60), 97 | stop=stop_after_attempt(15), 98 | ) 99 | def _retry_wrapper(self, messages, **generate_kwargs): 100 | response = self._client.chat.completions.create( 101 | model=self._model, 102 | messages=messages, 103 | **generate_kwargs, 104 | ) 105 | return response 106 | 107 | def _request(self, conv_iterator, return_full_text, **generate_kwargs): 108 | for conv, partial_answer in conv_iterator: 109 | messages = [] 110 | messages.append({"role": "system", "content": conv.system_message}) 111 | if partial_answer is not None: 112 | conv.messages[-1][1] = partial_answer 113 | else: 114 | del conv.messages[-1] 115 | for message in conv.messages: 116 | messages.append({"role": message[0], "content": message[1]}) 117 | response = self._retry_wrapper(messages, **generate_kwargs) 118 | answers = [] 119 | for choice in response.choices: 120 | answer = choice.message.content 121 | if return_full_text: 122 | answer = partial_answer + answer 123 | answers.append({"generated_text": answer}) 124 | yield answers 125 | 126 | 127 | class AzureOpenAIModel(OpenAIModel): 128 | NAME = "azure_openai" 129 | 130 | def __init__( 131 | self, 132 | model, 133 | name=None, 134 | api_key=None, 135 | api_endpoint=None, 136 | api_version="2024-02-15-preview", 137 | fastchat=True, 138 | ): 139 | BaseModel.__init__(self) 140 | 141 | self._model = model 142 | self._api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY") 143 | self._api_endpoint = api_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT") 144 | self._api_version = api_version 145 | self._client = AzureOpenAI( 146 | api_key=self._api_key, 147 | api_version=self._api_version, 148 | azure_endpoint=self._api_endpoint, 149 | ) 150 | self._name = name or "azure_" + model 151 | self.fastchat = fastchat 152 | 153 | 154 | class HFModel(BaseModel): 155 | NAME = "huggingface" 156 | 157 | def __init__( 158 | self, 159 | path, 160 | pt_path=None, 161 | max_new_tokens=1024, 162 | conv_template_type=None, 163 | name=None, 164 | fastchat=False, 165 | ): 166 | super().__init__() 167 | 168 | self.path = path 169 | self._conv_template_type = conv_template_type or path 170 | self._name = name or path 171 | self.fastchat = fastchat 172 | 173 | self.tokenizer = AutoTokenizer.from_pretrained( 174 | path, 175 | trust_remote_code=True, 176 | padding_side="left", 177 | ) 178 | self.model = AutoModelForCausalLM.from_pretrained( 179 | path, 180 | device_map="auto", 181 | trust_remote_code=True, 182 | torch_dtype=torch.float16, 183 | ).eval() 184 | 185 | if pt_path: 186 | state_dict = torch.load(pt_path)["state"] 187 | self.model.load_state_dict(state_dict) 188 | print(f"loading pre-trained weights from {pt_path}") 189 | 190 | self.generator = pipeline( 191 | task="text-generation", 192 | tokenizer=self.tokenizer, 193 | model=self.model, 194 | device_map="auto", 195 | trust_remote_code=True, 196 | ) 197 | # set default max_new_tokens 198 | if self.generator.model.generation_config.max_new_tokens is None: 199 | self.generator.model.generation_config.max_new_tokens = max_new_tokens 200 | 201 | # set padding token 202 | if self.generator.tokenizer.pad_token_id is None: 203 | if self.generator.model.generation_config.pad_token_id is not None: 204 | self.generator.tokenizer.pad_token_id = ( 205 | self.generator.model.generation_config.pad_token_id 206 | ) 207 | else: 208 | eos_token_id = self.generator.model.generation_config.eos_token_id 209 | if isinstance(eos_token_id, (list, tuple)): 210 | eos_token_id = eos_token_id[0] 211 | 
self.generator.tokenizer.pad_token_id = eos_token_id 212 | 213 | @property 214 | def conv_template_type(self): 215 | return self._conv_template_type 216 | 217 | @property 218 | def name(self): 219 | return self._name 220 | 221 | def _transform_conv_iterator_to_prompt_iterator(self, conv_iterator): 222 | for conv, partial_answer in conv_iterator: 223 | if self.fastchat: 224 | prompt = conv.get_prompt() 225 | else: 226 | prompt = self.generator.tokenizer.apply_chat_template( 227 | conv, tokenize=False, add_generation_prompt=True 228 | ) 229 | 230 | if partial_answer is not None: 231 | prompt += partial_answer 232 | 233 | yield prompt 234 | 235 | def text_generator( 236 | self, iterator, batch_size, num_return_sequences, **generate_kwargs 237 | ): 238 | # TODO: parallel test 239 | if "llama-3" in self.path.lower(): 240 | terminators = [ 241 | self.generator.tokenizer.eos_token_id, 242 | self.generator.tokenizer.convert_tokens_to_ids("<|eot_id|>"), 243 | ] 244 | 245 | return self.generator( 246 | self._transform_conv_iterator_to_prompt_iterator(iterator), 247 | batch_size=batch_size, 248 | num_return_sequences=num_return_sequences, 249 | eos_token_id=terminators, 250 | **generate_kwargs, 251 | ) 252 | 253 | return self.generator( 254 | self._transform_conv_iterator_to_prompt_iterator(iterator), 255 | batch_size=batch_size, 256 | num_return_sequences=num_return_sequences, 257 | **generate_kwargs, 258 | ) 259 | 260 | -------------------------------------------------------------------------------- /lbt/qa_item.py: -------------------------------------------------------------------------------- 1 | class QAItem(dict): 2 | def __init__(self, question, rationale=None, answer=None, prompt=None, task_id=None): 3 | super().__init__() 4 | self["question"] = question 5 | self["rationale"] = rationale 6 | self["answer"] = answer 7 | self["prompt"] = prompt 8 | self["task_id"] = task_id 9 | 10 | def __getattr__(self, attr_name): 11 | if attr_name in self: 12 | return self[attr_name] 13 | raise AttributeError(attr_name) 14 | 15 | def __setattr__(self, attr_name, value): 16 | self[attr_name] = value 17 | -------------------------------------------------------------------------------- /lbt/test.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Any, Tuple 2 | import numpy as np 3 | 4 | from tqdm import tqdm 5 | from datasets import Dataset 6 | 7 | from lbt.qa_item import QAItem 8 | from lbt.models.base import BaseModel 9 | from lbt.exam_maker import ExamPrompter 10 | from lbt.exam_scorer import BaseExamScorer 11 | 12 | 13 | def aggregate_scores(scores: List[float]) -> float: 14 | scores = np.array(scores) 15 | return scores.mean() 16 | 17 | 18 | def test_single_student( 19 | student: BaseModel, 20 | exam_prompter: ExamPrompter, 21 | exam_scorer: BaseExamScorer, 22 | teaching_items: List[QAItem], 23 | exam_dataset: Dataset, 24 | sample_cfg: Dict[str, Any], 25 | debug: bool = False, 26 | ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]: 27 | def exam_prompt_generator(debug=False, rationales=False, answers=False): 28 | for exam_item in exam_dataset: 29 | if debug: 30 | exam_item["rationale"] = rationales[i] 31 | exam_item["answer"] = answers[i] 32 | if student.fastchat: 33 | yield exam_prompter.make_exam_prompt_fastchat( 34 | teaching_items, exam_item, student.conv_template_type 35 | ) 36 | else: 37 | yield exam_prompter.make_exam_prompt_chat_template( 38 | teaching_items, exam_item 39 | ) 40 | 
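# Gather, per exam question, the sampled rationales, the parsed answers, and the
# scorer-assigned scores below. How many rationales are drawn per question is
# controlled by `sample_cfg` (e.g. `num_return_sequences` for HuggingFace students,
# `n` for OpenAI-style students).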
41 | single_student_exam_rationales = [] 42 | single_student_exam_answers = [] 43 | single_student_exam_scores = [] 44 | 45 | for i, single_question_exam_rationales in enumerate( 46 | tqdm( 47 | student.text_generator( 48 | exam_prompt_generator(), 49 | return_full_text=False, 50 | **sample_cfg, 51 | ), 52 | total=len(exam_dataset), 53 | ) 54 | ): 55 | exam_gt_item = exam_dataset[i] 56 | 57 | single_question_exam_rationales = [ 58 | _["generated_text"] for _ in single_question_exam_rationales 59 | ] 60 | exam_result_items = [ 61 | QAItem(question=exam_gt_item["question"], rationale=rationale, task_id=exam_gt_item.get("task_id", None)) 62 | for rationale in single_question_exam_rationales 63 | ] 64 | 65 | scores = [ 66 | exam_scorer.score_exam_result(exam_gt_item, exam_result_item) 67 | for exam_result_item in exam_result_items 68 | ] 69 | single_question_exam_answers = [ 70 | exam_result_item["answer"] for exam_result_item in exam_result_items 71 | ] 72 | 73 | single_student_exam_rationales.append(single_question_exam_rationales) 74 | single_student_exam_answers.append(single_question_exam_answers) 75 | single_student_exam_scores.append(scores) 76 | 77 | # add a debug loop, check each question's answer 78 | if debug: 79 | single_student_exam_answers_debug = [] 80 | single_student_exam_scores_debug = [] 81 | for i, single_question_exam_rationales_debug in enumerate( 82 | tqdm( 83 | student.text_generator( 84 | exam_prompt_generator(debug=debug, rationales=single_student_exam_rationales, answers=single_student_exam_answers), 85 | return_full_text=False, 86 | **sample_cfg, 87 | ), 88 | total=len(exam_dataset), 89 | ) 90 | ): 91 | exam_gt_item = exam_dataset[i] 92 | 93 | # if the first answer carries no error marker, skip re-generation 94 | if '# ' not in single_student_exam_answers[i][0][:5]: 95 | single_student_exam_answers_debug.append(single_student_exam_answers[i]) 96 | single_student_exam_scores_debug.append(single_student_exam_scores[i]) 97 | continue 98 | 99 | single_question_exam_rationales_debug = [ 100 | _["generated_text"] for _ in single_question_exam_rationales_debug 101 | ] 102 | exam_result_items = [ 103 | QAItem(question=exam_gt_item["question"], rationale=rationale, task_id=exam_gt_item.get("task_id", None)) 104 | for rationale in single_question_exam_rationales_debug 105 | ] 106 | 107 | scores = [ 108 | exam_scorer.score_exam_result(exam_gt_item, exam_result_item) 109 | for exam_result_item in exam_result_items 110 | ] 111 | single_question_exam_answers = [ 112 | exam_result_item["answer"] for exam_result_item in exam_result_items 113 | ] 114 | 115 | single_student_exam_answers_debug.append(single_question_exam_answers) 116 | single_student_exam_scores_debug.append(scores) 117 | # rename the debug information 118 | single_student_exam_answers = single_student_exam_answers_debug 119 | single_student_exam_scores = single_student_exam_scores_debug 120 | 121 | return ( 122 | single_student_exam_rationales, 123 | single_student_exam_answers, 124 | single_student_exam_scores, 125 | ) 126 | -------------------------------------------------------------------------------- /lbt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # The `lbt.utils` sub-package is largely copied from https://github.com/walkerning/aw_nas 2 | import sys 3 | import inspect 4 | 5 | from lbt.utils.log import * 6 | from lbt.utils.registry import * 7 | 8 | 9 | def get_default_argspec(func): 10 | sig = inspect.signature(func) # pylint: disable=no-member 11 | return [ 12 | (n, param.default) 13 | for n, 
param in sig.parameters.items() 14 | if not param.default is param.empty 15 | ] 16 | 17 | 18 | def _add_text_prefix(text, prefix): 19 | lines = text.split("\n") 20 | return "\n".join([prefix + line if line else line for line in lines]) 21 | -------------------------------------------------------------------------------- /lbt/utils/log.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # pylint: disable=invalid-name 3 | 4 | import os 5 | import sys 6 | import logging 7 | 8 | __all__ = ["logger", "getLogger"] 9 | 10 | # by default, log level is logging.INFO 11 | LEVEL = "info" 12 | if "LBT_LOG_LEVEL" in os.environ: 13 | LEVEL = os.environ["LBT_LOG_LEVEL"] 14 | LEVEL = getattr(logging, LEVEL.upper()) 15 | 16 | LOG_FORMAT = "%(asctime)s %(name)-16s %(levelname)7s: %(message)s" 17 | 18 | logging.basicConfig( 19 | stream=sys.stdout, level=LEVEL, format=LOG_FORMAT, datefmt="%m/%d %I:%M:%S %p" 20 | ) 21 | 22 | logger = logging.getLogger() 23 | 24 | 25 | def addFile(self, filename): 26 | handler = logging.FileHandler(filename) 27 | handler.setFormatter(logging.Formatter(LOG_FORMAT)) 28 | self.addHandler(handler) 29 | 30 | 31 | # logger.__class__.addFile = addFile 32 | logging.Logger.addFile = addFile 33 | 34 | 35 | def getLogger(name): 36 | return logger.getChild(name) 37 | -------------------------------------------------------------------------------- /lbt/utils/registry.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """A simple registry meta class. 3 | """ 4 | 5 | import abc 6 | import collections 7 | 8 | from lbt.utils import getLogger 9 | 10 | __all__ = ["RegistryMeta", "RegistryError"] 11 | 12 | LOGGER = getLogger("registry") 13 | 14 | 15 | class RegistryError(Exception): 16 | pass 17 | 18 | 19 | def _default_dct_of_list(): 20 | return collections.defaultdict(list) 21 | 22 | 23 | class RegistryMeta(abc.ABCMeta): 24 | registry_dct = collections.defaultdict(dict) 25 | supported_rollout_dct = collections.defaultdict(_default_dct_of_list) 26 | 27 | def __init__(cls, name, bases, namespace): 28 | super(RegistryMeta, cls).__init__(name, bases, namespace) 29 | if hasattr(cls, "REGISTRY"): 30 | # register the class 31 | table = cls.REGISTRY 32 | abstract_methods = cls.__abstractmethods__ 33 | if not abstract_methods: 34 | entry = namespace.get("NAME", name.lower()) 35 | setattr(cls, "NAME", entry) 36 | RegistryMeta.registry_dct[table][entry] = cls 37 | LOGGER.debug( 38 | "Register class `%s` as entry `%s` in table `%s`.", 39 | name, 40 | entry, 41 | table, 42 | ) 43 | 44 | if cls.REGISTRY == "rollout": 45 | # allow new defined rollout class to declare which component can be reused 46 | if hasattr(cls, "supported_components"): 47 | for registry, type_ in cls.supported_components: 48 | RegistryMeta.supported_rollout_dct[registry][type_].append( 49 | entry 50 | ) 51 | else: 52 | if "NAME" in namespace: 53 | entry = namespace["NAME"] 54 | LOGGER.warning( 55 | ( 56 | "Can't register abstract class `%s` as entry `%s`" 57 | " in table `%s`, ignore. 
Abstract methods: %s" 58 | ), 59 | name, 60 | entry, 61 | table, 62 | ", ".join(abstract_methods), 63 | ) 64 | 65 | @classmethod 66 | def get_class(mcs, table, name): 67 | try: 68 | return mcs.all_classes(table)[name] 69 | except KeyError: 70 | raise RegistryError( 71 | "No registry item {} available in registry {}.".format(name, table) 72 | ) 73 | 74 | @classmethod 75 | def all_classes(mcs, table): 76 | try: 77 | return mcs.registry_dct[table] 78 | except KeyError: 79 | raise RegistryError("No registry table {} available.".format(table)) 80 | 81 | @classmethod 82 | def avail_tables(mcs): 83 | return mcs.registry_dct.keys() 84 | 85 | def all_classes_(cls): 86 | return RegistryMeta.all_classes(cls.REGISTRY) 87 | 88 | def get_class_(cls, name): 89 | return RegistryMeta.get_class(cls.REGISTRY, name) 90 | 91 | def registered_supported_rollouts_(cls): 92 | return RegistryMeta.supported_rollout_dct[cls.REGISTRY][cls.NAME] 93 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "lbt" 7 | version = "0.0.1" 8 | requires-python = ">=3.8" 9 | classifiers = [ 10 | "Programming Language :: Python :: 3", 11 | ] 12 | dependencies = [ 13 | "tqdm", 14 | "pandas", 15 | "datasets", 16 | "openai", "tenacity", # API model 17 | "fschat", "transformers", "accelerate", # open-source model (based on fastchat and huggingface Transformer) 18 | "termcolor", 19 | "PyYaml", 20 | "ipdb", 21 | "ipython", 22 | "faiss-gpu", # QuesSimilarityExamMaker 23 | "sentence-transformers", # QuesSimilarityExamMaker 24 | "flash-attn", 25 | "auto-gptq", "optimum", # Qwen-72B-Chat-Int4 26 | "openai==1.16.2", 27 | "gradio", 28 | "tomark", 29 | "Jinja2", 30 | "pyext", 31 | "ninja", 32 | "python-leetcode", 33 | "pydantic", 34 | "bs4", 35 | "html2text", 36 | "requests", 37 | "python-dotenv" 38 | ] 39 | 40 | [project.optional-dependencies] 41 | dev = ["pre-commit", "black"] 42 | 43 | [tool.setuptools.packages.find] 44 | exclude = ["docs", "dist*", "scripts*", "tests*", "data*", "results*"] 45 | 46 | [tool.wheel] 47 | exclude = ["docs", "dist*", "scripts*", "tests*", "data*", "results*"] 48 | -------------------------------------------------------------------------------- /scripts/code/prepare_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import pandas as pd 5 | from datasets import Dataset 6 | 7 | from lbt.datasets_adapter.utils.fetch_leetcode import fetch_dataset, fetch_solutions 8 | from lbt.datasets_adapter.utils.utils_leetcode import get_api_instance 9 | from lbt.datasets_adapter.utils.clean_leetcode import remove_class_dependent, remove_void, remove_class_impls, remove_examples 10 | from lbt.datasets_adapter.utils.format_leetcode import format_problems, to_jsonl 11 | 12 | parser = argparse.ArgumentParser(description="Configuration for building uncontaminated Leetcode Hard dataset") 13 | parser.add_argument('--langs', nargs='+', default=['python3'], help="List of languages.") 14 | parser.add_argument('--output_dir', type=str, default="./examples/leetcode", help="Directory to save the built dataset.") 15 | parser.add_argument('--extract_test_cases', action='store_true', help="If set, test cases will be extracted from problem descriptions using GPT.") 16 | 
parser.add_argument('--remove_examples', action='store_true', help="If set, examples will be removed. Cannot be used with --extract_test_cases.") 17 | parser.add_argument('--fetch_solutions', action='store_true', help="If set, solutions to problems will be fetched. Currently only supports lang=python3.") 18 | parser.add_argument('--topic', type=str, default='algorithms', choices=['algorithms']) 19 | parser.add_argument('--difficulty', type=int, default=3, choices=[1, 2, 3], help="Get data of certain difficulty. 1: Easy, 2: Medium, 3: Hard") 20 | 21 | args = parser.parse_args() 22 | 23 | if __name__ == "__main__": 24 | # Check LEETCODE environment variables 25 | try: 26 | leetcode_session = os.environ["LEETCODE_SESSION"] 27 | except KeyError: 28 | print("Environment variable LEETCODE_SESSION is not set. Please refer to README") 29 | exit(1) 30 | 31 | # Check OPENAI environment variables 32 | if args.extract_test_cases: 33 | try: 34 | os.environ["OPENAI_API_KEY"] 35 | import openai 36 | except (KeyError, ImportError): 37 | print("Extra dependencies and setup are required for test case extraction. Please refer to README") 38 | exit(1) 39 | if args.remove_examples: 40 | print("Cannot use --remove_examples with --extract_test_cases") 41 | exit(1) 42 | 43 | os.makedirs(args.output_dir, exist_ok=True) 44 | 45 | api_instance = get_api_instance(leetcode_session=leetcode_session, csrf_token=os.environ["CSRF_TOKEN"]) 46 | dataset = fetch_dataset(api_instance, topic=args.topic, difficulty=args.difficulty) 47 | 48 | # use pandas to save pandas format dataset 49 | # dataset.to_csv(os.path.join(args.output_dir, f'leetcode-{args.difficulty}-{args.topic}.csv'), index=False) 50 | 51 | filtered_dataset = \ 52 | remove_class_impls( 53 | remove_void( 54 | remove_class_dependent(dataset))).reset_index(drop=True) 55 | 56 | if args.remove_examples: 57 | filtered_dataset = remove_examples(filtered_dataset) 58 | 59 | print(f"Filtered out {len(dataset) - len(filtered_dataset)} problem(s)") 60 | 61 | for lang in args.langs: 62 | print(f"Formatting dataset for {lang}") 63 | formatted_dataset = format_problems(filtered_dataset, lang) 64 | if args.extract_test_cases: 65 | print(f"Extracting test cases for {lang}") 66 | from lbt.datasets_adapter.utils.add_test_cases import extract_test_cases 67 | formatted_dataset = extract_test_cases(formatted_dataset, lang) 68 | if args.fetch_solutions: 69 | print(f"Fetching solutions for {lang}") 70 | formatted_dataset = fetch_solutions(formatted_dataset, lang) 71 | 72 | # save into the huggingface datasets in the Humaneval format 73 | q_template = "Write a python code \n\"\"\"{}\"\"\"\n to solve the following problem: \n\n{} \n" 74 | for sample in formatted_dataset: 75 | # transform signature into class 76 | sample["signature"] = sample["signature"].replace("(", "(self, ") 77 | sample["signature"] = 'class Solution():\n def ' + sample["signature"] 78 | 79 | # transform test cases into class format 80 | sample["test"] = sample["test"].replace("assert ", "assert Solution().") 81 | 82 | # add new column for question and rationale 83 | sample["question"] = q_template.format(sample["signature"], sample["docstring"]) 84 | if "rationale" not in sample.keys(): 85 | sample["rationale"] = "" 86 | 87 | to_jsonl(formatted_dataset, os.path.join(args.output_dir, f'leetcode-{args.difficulty}-{lang}', 'dataset.jsonl')) 88 | 89 | t_dataset = Dataset.from_list(formatted_dataset) 90 | if "canonical_solution" in t_dataset.features: 91 | t_dataset = t_dataset.rename_columns({"canonical_solution": "answer"}) 92 | 
t_dataset.save_to_disk(os.path.join(args.output_dir, f'leetcode-{args.difficulty}-{lang}')) 93 | -------------------------------------------------------------------------------- /scripts/code/prepare_teaching_datasets.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | 5 | from datasets import load_from_disk, Dataset 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--input', type=str, required=True, help='Input baseline directory') 9 | parser.add_argument('--output', type=str, required=True, help='Output directory') 10 | parser.add_argument('--freq', type=int, default=8, help='freq times') 11 | args = parser.parse_args() 12 | 13 | if __name__ == "__main__": 14 | dataset_list = [] 15 | for m in range(args.freq): 16 | dataset_list.append(load_from_disk(os.path.join(args.input, str(m)))) 17 | 18 | question_num = len(dataset_list[0][0]["exam_questions"]) 19 | # make a dictionary to store the question items, indexed by the question num 20 | question_dict = {k: [] for k in range(question_num)} 21 | 22 | # search the dataset to find each rationale 23 | for dataset in dataset_list: 24 | for m in range(question_num): 25 | model_name = list(dataset[0]["exam_details"].keys())[0] 26 | 27 | teacher_item = {} 28 | teacher_item["task_id"] = dataset[0]["task_id"][m] 29 | teacher_item["answer"] = dataset[0]["exam_details"][model_name]['answers'][m][0] 30 | teacher_item["question"] = dataset[0]["exam_questions"][m] 31 | teacher_item["rationale"] = dataset[0]["exam_details"][model_name]['rationales'][m][0] 32 | if '/bs-' in args.input: 33 | teacher_item["tags"] = "Binary Search" 34 | elif '/dp-' in args.input: 35 | teacher_item["tags"] = "Dynamic Programming" 36 | elif 'code_contests' in args.input or "apps" in args.input: 37 | teacher_item["tags"] = "competition" 38 | else: 39 | print(args.input) 40 | raise NotImplementedError 41 | teacher_item["model_name"] = model_name 42 | teacher_item["score"] = dataset[0]["exam_details"][model_name]['scores'][m][0] 43 | 44 | # append into question_dict 45 | question_dict[m].append(teacher_item) 46 | 47 | 48 | # save the question_dict into a dataset and jsonl file 49 | for m in range(question_num): 50 | t_dataset = Dataset.from_list(question_dict[m]) 51 | t_dataset.save_to_disk(os.path.join(args.output, str(m))) 52 | 53 | with open(os.path.join(args.output, str(m), "dataset.jsonl"), "w") as f: 54 | for item in question_dict[m]: 55 | f.write(json.dumps(item) + "\n") -------------------------------------------------------------------------------- /scripts/code/search_rationale.py: -------------------------------------------------------------------------------- 1 | from lbt.datasets_adapter.leetcode_sub.types import LeetCodeSubmission, ProgrammingLanguage 2 | from lbt.datasets_adapter.leetcode_sub.environment import LeetCodeEnv 3 | 4 | import argparse 5 | import os 6 | import json 7 | import tqdm 8 | import signal 9 | import contextlib 10 | import time 11 | 12 | from datasets import load_from_disk, Dataset 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--input', type=str, help='Input directory') 16 | parser.add_argument('--is_submit', action='store_true', help='submit mode') 17 | parser.add_argument('--resume', action='store_true', help='resume submission') 18 | parser.add_argument('--num_sample', type=int, default=8, help='number of samples') 19 | args = parser.parse_args() 20 | 21 | class TimeoutException(Exception): 22 | pass 23 | 24 | 
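# NOTE (editor): sketch of the intended use of the guard below (it is used
# exactly this way inside eval_code); SIGALRM-based, hence POSIX-only:
#   with time_limit(40):
#       status, reward, done, submission_result = env.step(sub)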
@contextlib.contextmanager 25 | def time_limit(seconds: float): 26 | def signal_handler(signum, frame): 27 | raise TimeoutException("Timed out!")
28 | signal.signal(signal.SIGALRM, signal_handler) # install the handler before arming the timer 29 | signal.setitimer(signal.ITIMER_REAL, seconds) 30 | try: 31 | yield 32 | finally: 33 | signal.setitimer(signal.ITIMER_REAL, 0) 34 |
35 | def eval_code(is_submit, code, question_slug, visible_score): 36 | # if "AssertionError" in code: 37 | # return 0.0 38 | if is_submit: 39 | # define the submission question 40 | sub = LeetCodeSubmission(code=code.replace('\/', '/'), 41 | lang=ProgrammingLanguage.PYTHON3, 42 | question_slug=question_slug, 43 | timeout=20) 44 |
45 | # retry on transient exceptions 46 | try_num = 0 47 | while True: 48 | if try_num > 5: 49 | return -1.0 50 | try: 51 | try_num += 1 52 | # time.sleep(30) 53 | env = LeetCodeEnv() 54 | with time_limit(40): 55 | status, reward, done, submission_result = env.step(sub) 56 | print(status, reward, done, submission_result) 57 |
58 | # fraction of passed test cases 59 | total_correct = submission_result['total_correct'] 60 | total_testcases = submission_result['total_testcases'] 61 | return total_correct / total_testcases
62 | # exit immediately on keyboard interrupt 63 | except KeyboardInterrupt: 64 | print("******** Keyboard Interrupt ********") 65 | exit() 66 | except TimeoutException: 67 | print("******** Timeout ********") 68 | continue 69 | except Exception as e: 70 | print(f"******** {e} ********") 71 | continue 72 | else: 73 | return visible_score 74 |
75 | if __name__ == "__main__": 76 | # Check LEETCODE environment variables 77 | if args.is_submit: 78 | output_path = os.path.join(args.input, 'results-submit.jsonl') 79 | try: 80 | leetcode_session = os.environ["LEETCODE_SESSION"] 81 | except KeyError: 82 | print("Environment variable LEETCODE_SESSION is not set.
Please refer to README") 83 | exit(1) 84 | else: 85 | output_path = os.path.join(args.input, 'results-visible.jsonl') 86 |
87 | if not args.resume: 88 | if os.path.exists(output_path): 89 | os.remove(output_path) 90 | else: 91 | if os.path.exists(output_path): 92 | past_data = [] 93 | with open(output_path, 'r') as f: 94 | for line in f: 95 | past_data.append(json.loads(line)) 96 | past_name = len(past_data) // args.num_sample 97 | past_test_id = len(past_data) % args.num_sample 98 | else: 99 | past_name = 0 100 | past_test_id = 0 101 |
102 | # find all the file names in args.input 103 | file_names = os.listdir(args.input) 104 | # keep only the numbered dataset directories (drop the *.json/*.jsonl result files) 105 | file_names = [file_name for file_name in file_names if '.json' not in file_name] 106 | for dir_idx in range(len(file_names)): 107 | if args.resume and dir_idx < past_name: 108 | continue 109 | datasets = load_from_disk(os.path.join(args.input, str(dir_idx))) 110 | for test_id, dataset in enumerate(datasets): 111 | if args.resume and test_id < past_test_id and dir_idx == past_name: 112 | continue
113 | # get exam information 114 | s_model_name = list(dataset["exam_details"].keys())[0] 115 | s_task_ids = dataset['task_id'] 116 | s_answer = dataset["exam_details"][s_model_name]["answers"] 117 | s_visible_scores = dataset["exam_details"][s_model_name]["scores"] 118 |
119 | t_task_id, t_final_score = None, None # teacher info only exists in the oracle-1 ("pipeline-1") case below 120 | if 'pipeline-1' in args.input: 121 | # if oracle-1, get teacher information 122 | t_task_id = dataset["teaching_items"][0]["task_id"] 123 | t_answer = dataset["teaching_items"][0]["answer"] 124 | # check if scores in dictionary 125 | if "score" in dataset["teaching_items"][0]: 126 | t_visible_score = dataset["teaching_items"][0]["score"] 127 | else: 128 | t_visible_score = dataset["teaching_items"][0]["scores"] 129 | t_final_score = eval_code(args.is_submit, t_answer, t_task_id, t_visible_score) 130 |
131 | # refine the exam information 132 | rm_idx = s_task_ids.index(t_task_id) 133 | s_task_ids.pop(rm_idx) 134 | s_answer.pop(rm_idx) 135 | s_visible_scores.pop(rm_idx) 136 |
137 | s_final_scores = [] 138 | for i in tqdm.tqdm(range(len(s_task_ids))): 139 | s_final_scores.append(eval_code(args.is_submit, s_answer[i][0], s_task_ids[i], s_visible_scores[i][0])) 140 |
141 | # calculate the average score 142 | s_avg_score = sum(s_final_scores) / len(s_final_scores) 143 |
144 | # write the submit_scores in a json file 145 | with open(output_path, 'a') as f: 146 | item = {"t_task_id": t_task_id, "t_score": t_final_score, "exam_id": test_id, "s_score": s_final_scores, "s_avg_score": s_avg_score} 147 | f.write(json.dumps(item) + "\n")
-------------------------------------------------------------------------------- /scripts/code_exam.py: --------------------------------------------------------------------------------
1 | import os 2 | import copy 3 | import yaml 4 | import json 5 | import argparse 6 |
7 | from tqdm import tqdm 8 | from datasets import Dataset, load_from_disk 9 |
10 | from lbt.base import Component 11 | from lbt.test import test_single_student, aggregate_scores 12 | from lbt.utils.log import getLogger 13 | 14 |
15 | LOGGER = getLogger("exam") 16 | 17 |
18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("cfg_file", type=str, help="Path to the config file") 21 | parser.add_argument("--output-path", type=str, required=True) 22 | parser.add_argument( 23 | "--teaching-dataset-file", 24 | type=str, 25 | help="The items in this dataset file will be used as few-shot demonstrations.",
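# NOTE (editor): may only be omitted when the config sets teaching_plans:
# "no demo" (enforced by the assert below)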
26 | default=None, 27 | ) 28 | parser.add_argument("--exam-dataset-file", type=str, required=True) 29 | args = parser.parse_args() 30 | 31 | with open(args.cfg_file, "r") as rf: 32 | cfg = yaml.safe_load(rf) 33 | 34 | teaching_plans = cfg.get("teaching_plans", "every") 35 | assert isinstance(teaching_plans, (list, tuple)) or teaching_plans in { 36 | "every", 37 | "no demo", 38 | } 39 | 40 | # Load teaching datasets 41 | if teaching_plans != "no demo": 42 | # Initialize teaching datasets 43 | assert args.teaching_dataset_file is not None, ( 44 | 'Only when `teaching_plans == "no demo"`, `--teaching-dataset-file` can be' 45 | " omitted." 46 | ) 47 | teaching_dataset = load_from_disk(args.teaching_dataset_file) 48 | LOGGER.info( 49 | f"Loaded teaching dataset from {args.teaching_dataset_file}, fields:" 50 | f" {teaching_dataset.features}" 51 | ) 52 | # the columns are: question, answer, test 53 | if "answer" not in teaching_dataset.features: 54 | LOGGER.info( 55 | f"Cannot find `answer` field in {args.teaching_dataset_file}" 56 | ) 57 | raise NotImplementedError 58 | else: 59 | LOGGER.info( 60 | f"Use the `answer` field in {args.teaching_dataset_file} as the" 61 | " demonstration." 62 | ) 63 | # the columns are: question, rationale, answer 64 | 65 | # Load exam_dataset 66 | exam_dataset = load_from_disk(args.exam_dataset_file) 67 | 68 | if "answer" not in exam_dataset.features and "canonical_solution" in exam_dataset.features: 69 | exam_dataset = exam_dataset.rename_columns({"canonical_solution": "answer"}) 70 | 71 | if "question" not in exam_dataset.features and "prompt" in exam_dataset.features: 72 | exam_dataset = exam_dataset.rename_columns({"prompt": "question"}) 73 | 74 | LOGGER.info( 75 | f"Loaded exam dataset from {args.exam_dataset_file}, fields:" 76 | f" {exam_dataset.features}" 77 | ) 78 | if "rationale" not in exam_dataset.features: 79 | # LOGGER.info( 80 | # f"Use the `solution` field in {args.exam_dataset_file} as the GT to measure" 81 | # " scores." 82 | # ) 83 | # exam_dataset = exam_dataset.rename_columns({"solution": "rationale"}) 84 | LOGGER.info( 85 | f"We do not use `rationale` to measure" 86 | " scores." 87 | ) 88 | else: 89 | LOGGER.info( 90 | f"Use the `rationale` field in {args.exam_dataset_file} as the GT to" 91 | " measure scores." 
92 | ) 93 | # the columns are: question, rationale, answer 94 | 95 | # Unify teaching plan as a list 96 | if teaching_plans == "every": 97 | # Take `num_rows` exams, each with one row from the teaching dataset as the demonstration 98 | teaching_plans = [[index] for index in range(teaching_dataset.num_rows)] 99 | elif teaching_plans == "no demo": 100 | # Take 1 exam, with no demonstrations from the teaching dataset 101 | teaching_plans = [[]] 102 | else: 103 | # Take `len(teaching_plans)` exams, 104 | # each item in list is a list of indexes, which are the teaching-dataset indexes 105 | # that will be used as the demonstrations in one exam 106 | assert ( 107 | max([num for num in sum(teaching_plans, []) if isinstance(num, int)]) 108 | < teaching_dataset.num_rows 109 | ) # do a check 110 | 111 | # Initialize exam_maker, exam_prompter, exam_scorer, student_models 112 | exam_maker = Component.init_from_cfg( 113 | cfg, "exam_maker", exam_bank_dataset=exam_dataset 114 | ) 115 | exam_prompter = Component.init_from_cfg(cfg, "exam_prompter") 116 | exam_scorer = Component.init_from_cfg(cfg, "exam_scorer") 117 | student_pool = [ 118 | Component.init_from_cfg(s_m_cfg, "model") 119 | for s_m_cfg in cfg["student_model_cfgs"] 120 | ] 121 | student_sample_cfgs = [ 122 | s_m_cfg.get("sample_cfg", {}) for s_m_cfg in cfg["student_model_cfgs"] 123 | ] 124 | 125 | # Prepare output directory, dump the config 126 | os.makedirs(args.output_path, exist_ok=True) 127 | cfg["teaching_dataset_file"] = args.teaching_dataset_file 128 | cfg["exam_dataset_file"] = args.exam_dataset_file 129 | with open(os.path.join(args.output_path, "config.yaml"), "w") as wf: 130 | yaml.safe_dump(cfg, wf) 131 | 132 | # Loop: Iterate over the teaching plans 133 | # The output dataset has fields: teaching_items: List, exam_questions: List[str], 134 | # exam_gt_rationales List[str]: exam_gt_answers: List[str], 135 | # scores: Dict[str, float], exam_details: Dict[str, List] 136 | output_items = [] 137 | for teaching_plan in tqdm(teaching_plans): 138 | teaching_item_question_only = False 139 | 140 | if teaching_plan: 141 | if teaching_plan[0] == "question-only": 142 | teaching_item_question_only = True 143 | teaching_plan = teaching_plan[1:] 144 | teaching_items = [teaching_dataset[index] for index in teaching_plan] 145 | else: 146 | teaching_items = [] 147 | 148 | output_item = { 149 | "teaching_items": teaching_items, 150 | "exam_questions": [], 151 | "exam_gt_rationales": [], 152 | "exam_gt_answers": [], 153 | "task_id": [], # New in code datasets 154 | "exam_details": {student.name: None for student in student_pool}, 155 | "scores": {student.name: None for student in student_pool}, 156 | } 157 | 158 | # Decide the exam questions 159 | s_exam_dataset = exam_maker.make_exam_questions(teaching_items) 160 | # Record the exam questions and gt answers for this teaching question - rationale pair 161 | output_item["exam_questions"] = s_exam_dataset["question"] 162 | output_item["exam_gt_rationales"] = s_exam_dataset["rationale"] 163 | output_item["exam_gt_answers"] = s_exam_dataset["answer"] 164 | output_item["task_id"] = s_exam_dataset["task_id"] # New in code datasets 165 | 166 | 167 | # Loop: Evaluate each student 168 | for student, stu_sample_cfg in zip(student_pool, student_sample_cfgs): 169 | # Loop: Evaluate each question 170 | sample_cfg = copy.deepcopy( 171 | cfg.get("general_student_sample_cfg", {}) 172 | ) # general sample config 173 | sample_cfg.update(stu_sample_cfg) # update with per-student sample config 174 | ( 175 | 
single_student_rationales, 176 | single_student_answers, 177 | single_student_scores, 178 | ) = test_single_student( 179 | student=student, 180 | exam_prompter=exam_prompter, 181 | exam_scorer=exam_scorer, 182 | teaching_items=( 183 | teaching_items if not teaching_item_question_only else [] 184 | ), 185 | exam_dataset=s_exam_dataset, 186 | sample_cfg=sample_cfg, 187 | ) 188 | 189 | # judges & exam_rationales: a nested list of shape `num_exam_questions x num_exam_answer_per_question`, 190 | # where every item is a score or a string 191 | output_item["exam_details"][student.name] = { 192 | "rationales": single_student_rationales, 193 | "answers": single_student_answers, 194 | "scores": single_student_scores, 195 | } 196 | score = aggregate_scores(single_student_scores) 197 | output_item["scores"][student.name] = score 198 | 199 | output_items.append(output_item) 200 | 201 | # Save the results 202 | output_dataset = Dataset.from_list(output_items) 203 | LOGGER.info(f"Dumping results to {args.output_path} ...") 204 | output_dataset.save_to_disk(args.output_path) 205 | output_dataset.to_csv(os.path.join(args.output_path, "dataset.csv")) 206 | output_dataset.to_json(os.path.join(args.output_path, "dataset.json"), indent=2) 207 | -------------------------------------------------------------------------------- /scripts/exam.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import yaml 4 | import argparse 5 | 6 | from tqdm import tqdm 7 | from datasets import Dataset, load_from_disk 8 | 9 | from lbt.base import Component 10 | from lbt.test import test_single_student, aggregate_scores 11 | from lbt.utils.log import getLogger 12 | 13 | 14 | LOGGER = getLogger("exam") 15 | 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("cfg_file", type=str, help="Path to the config file") 20 | parser.add_argument("--output-path", type=str, required=True) 21 | parser.add_argument( 22 | "--teaching-dataset-file", 23 | type=str, 24 | help="The items in this dataset file will be used as few-shot demonstrations.", 25 | default=None, 26 | ) 27 | parser.add_argument("--exam-dataset-file", type=str, required=True) 28 | args = parser.parse_args() 29 | 30 | with open(args.cfg_file, "r") as rf: 31 | cfg = yaml.safe_load(rf) 32 | 33 | teaching_plans = cfg.get("teaching_plans", "every") 34 | assert isinstance(teaching_plans, (list, tuple)) or teaching_plans in { 35 | "every", 36 | "no demo", 37 | } 38 | 39 | # Load datasets 40 | if teaching_plans != "no demo": 41 | # Initialize teaching and exam datasets 42 | assert args.teaching_dataset_file is not None, ( 43 | 'Only when `teaching_plans == "no demo"`, `--teaching-dataset-file` can be' 44 | " omitted." 45 | ) 46 | teaching_dataset = load_from_disk(args.teaching_dataset_file) 47 | LOGGER.info( 48 | f"Loaded teaching dataset from {args.teaching_dataset_file}, fields:" 49 | f" {teaching_dataset.features}" 50 | ) 51 | # the columns are: question, rationale, solution, answer 52 | if "rationale" not in teaching_dataset.features: 53 | LOGGER.info( 54 | f"Use the `solution` field in {args.teaching_dataset_file} as the" 55 | " demonstration." 56 | ) 57 | teaching_dataset = teaching_dataset.rename_columns( 58 | {"solution": "rationale"} 59 | ) 60 | else: 61 | LOGGER.info( 62 | f"Use the `rationale` field in {args.teaching_dataset_file} as the" 63 | " demonstration." 
64 | ) 65 | # the columns are: question, rationale, answer 66 | 67 | exam_dataset = load_from_disk(args.exam_dataset_file) 68 | LOGGER.info( 69 | f"Loaded exam dataset from {args.exam_dataset_file}, fields:" 70 | f" {exam_dataset.features}" 71 | ) 72 | if "rationale" not in exam_dataset.features: 73 | LOGGER.info( 74 | f"Use the `solution` field in {args.exam_dataset_file} as the GT to measure" 75 | " scores." 76 | ) 77 | exam_dataset = exam_dataset.rename_columns({"solution": "rationale"}) 78 | else: 79 | LOGGER.info( 80 | f"Use the `rationale` field in {args.exam_dataset_file} as the GT to" 81 | " measure scores." 82 | ) 83 | # the columns are: question, rationale, answer 84 | 85 | # Unify teaching plan as a list 86 | if teaching_plans == "every": 87 | # Take `num_rows` exams, each with one row from the teaching dataset as the demonstration 88 | teaching_plans = [[index] for index in range(teaching_dataset.num_rows)] 89 | elif teaching_plans == "no demo": 90 | # Take 1 exam, with no demonstrations from the teaching dataset 91 | teaching_plans = [[]] 92 | else: 93 | # Take `len(teaching_plans)` exams, 94 | # each item in list is a list of indexes, which are the teaching-dataset indexes 95 | # that will be used as the demonstrations in one exam 96 | assert ( 97 | max([num for num in sum(teaching_plans, []) if isinstance(num, int)]) 98 | < teaching_dataset.num_rows 99 | ) # do a check 100 | 101 | # Initialize exam_maker, exam_prompter, exam_scorer, student_models 102 | exam_maker = Component.init_from_cfg( 103 | cfg, "exam_maker", exam_bank_dataset=exam_dataset 104 | ) 105 | exam_prompter = Component.init_from_cfg(cfg, "exam_prompter") 106 | exam_scorer = Component.init_from_cfg(cfg, "exam_scorer") 107 | student_pool = [ 108 | Component.init_from_cfg(s_m_cfg, "model") 109 | for s_m_cfg in cfg["student_model_cfgs"] 110 | ] 111 | student_sample_cfgs = [ 112 | s_m_cfg.get("sample_cfg", {}) for s_m_cfg in cfg["student_model_cfgs"] 113 | ] 114 | 115 | # Prepare output directory, dump the config 116 | os.makedirs(args.output_path, exist_ok=True) 117 | cfg["teaching_dataset_file"] = args.teaching_dataset_file 118 | cfg["exam_dataset_file"] = args.exam_dataset_file 119 | with open(os.path.join(args.output_path, "config.yaml"), "w") as wf: 120 | yaml.safe_dump(cfg, wf) 121 | 122 | # Loop: Iterate over the teaching plans 123 | # The output dataset has fields: teaching_items: List, exam_questions: List[str], 124 | # exam_gt_rationales List[str]: exam_gt_answers: List[str], 125 | # scores: Dict[str, float], exam_details: Dict[str, List] 126 | output_items = [] 127 | for teaching_plan in tqdm(teaching_plans): 128 | teaching_item_question_only = False 129 | 130 | if teaching_plan: 131 | if teaching_plan[0] == "question-only": 132 | teaching_item_question_only = True 133 | teaching_plan = teaching_plan[1:] 134 | teaching_items = [teaching_dataset[index] for index in teaching_plan] 135 | else: 136 | teaching_items = [] 137 | 138 | output_item = { 139 | "teaching_items": teaching_items, 140 | "exam_questions": [], 141 | "exam_gt_rationales": [], 142 | "exam_gt_answers": [], 143 | "exam_details": {student.name: None for student in student_pool}, 144 | "scores": {student.name: None for student in student_pool}, 145 | } 146 | 147 | # Decide the exam questions 148 | s_exam_dataset = exam_maker.make_exam_questions(teaching_items) 149 | # Record the exam questions and gt answers for this teaching question - rationale pair 150 | output_item["exam_questions"] = s_exam_dataset["question"] 151 | 
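# NOTE (editor): after the per-student loop below, output_item["scores"] maps
# each student name to its aggregated score, e.g. (hypothetical values):
#   {"llama-3-8b": 0.66, "mistral-7b": 0.33}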
output_item["exam_gt_rationales"] = s_exam_dataset["rationale"] 152 | output_item["exam_gt_answers"] = s_exam_dataset["answer"] 153 | 154 | # Loop: Evaluate each student 155 | for student, stu_sample_cfg in zip(student_pool, student_sample_cfgs): 156 | # Loop: Evaluate each question 157 | sample_cfg = copy.deepcopy( 158 | cfg.get("general_student_sample_cfg", {}) 159 | ) # general sample config 160 | sample_cfg.update(stu_sample_cfg) # update with per-student sample config 161 | ( 162 | single_student_rationales, 163 | single_student_answers, 164 | single_student_scores, 165 | ) = test_single_student( 166 | student=student, 167 | exam_prompter=exam_prompter, 168 | exam_scorer=exam_scorer, 169 | teaching_items=( 170 | teaching_items if not teaching_item_question_only else [] 171 | ), 172 | exam_dataset=s_exam_dataset, 173 | sample_cfg=sample_cfg, 174 | ) 175 | 176 | # judges & exam_rationales: a nested list of shape `num_exam_questions x num_exam_answer_per_question`, 177 | # where every item is a score or a string 178 | output_item["exam_details"][student.name] = { 179 | "rationales": single_student_rationales, 180 | "answers": single_student_answers, 181 | "scores": single_student_scores, 182 | } 183 | score = aggregate_scores(single_student_scores) 184 | output_item["scores"][student.name] = score 185 | 186 | output_items.append(output_item) 187 | 188 | # Save the results 189 | output_dataset = Dataset.from_list(output_items) 190 | LOGGER.info(f"Dumping results to {args.output_path} ...") 191 | output_dataset.save_to_disk(args.output_path) 192 | output_dataset.to_csv(os.path.join(args.output_path, "dataset.csv")) 193 | output_dataset.to_json(os.path.join(args.output_path, "dataset.json"), indent=2) 194 | -------------------------------------------------------------------------------- /scripts/math/prepare_datasets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from glob import glob 4 | 5 | from datasets import Dataset 6 | 7 | 8 | def last_boxed_only_string(string): 9 | idx = string.rfind("\\boxed") 10 | if idx < 0: 11 | idx = string.rfind("\\fbox") 12 | if idx < 0: 13 | return None 14 | 15 | i = idx 16 | right_brace_idx = None 17 | num_left_braces_open = 0 18 | while i < len(string): 19 | if string[i] == "{": 20 | num_left_braces_open += 1 21 | if string[i] == "}": 22 | num_left_braces_open -= 1 23 | if num_left_braces_open == 0: 24 | right_brace_idx = i 25 | break 26 | i += 1 27 | 28 | if right_brace_idx == None: 29 | retval = None 30 | else: 31 | retval = string[idx:right_brace_idx + 1] 32 | 33 | return retval[7:-1] 34 | 35 | def process_json_file(json_file, key_mapping = {"problem": "question", "type": "subject"}): 36 | with open(json_file) as f: 37 | problem = json.load(f) 38 | 39 | answer = last_boxed_only_string(problem["solution"]) 40 | unique_id = "/".join(json_file.split("/")[-3:]) 41 | problem["answer"] = answer 42 | problem["unique_id"] = unique_id 43 | problem["level"] = int(problem["level"][-1]) 44 | 45 | for old_key in key_mapping: 46 | problem = {key_mapping[old_key] if key == old_key else key: value for key, value in problem.items()} 47 | 48 | return problem 49 | 50 | def split_dataset(dataset, dataname, num_problems=10, num_rationales=256): 51 | num_splits = len(dataset) // num_problems 52 | splits_ids = set() 53 | 54 | for i in range(num_splits): 55 | dataset_split = [] 56 | start = i * num_problems 57 | if i != num_splits - 1: 58 | end = (i + 1) * num_problems 59 | else: 60 | end = 
len(dataset) 61 | 62 | for j in range(start, end): 63 | problem = dataset[j] 64 | dataset_split.extend([problem] * num_rationales) 65 | splits_ids.add(problem['unique_id']) 66 | 67 | Dataset.from_list(dataset_split).save_to_disk( 68 | f"./examples/datasets/math/math_splits/{dataname}_r{num_rationales}s{i}" 69 | ) 70 | 71 | assert len(splits_ids) == len(dataset) 72 | 73 | 74 | if __name__ == "__main__": 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument("--num_problems", type=int, default=10) 77 | parser.add_argument("--num_rationales", type=int, default=256) 78 | args = parser.parse_args() 79 | 80 | math500_dataset = [] 81 | math500_ids = set() 82 | with open("./examples/datasets/math/MATH/math500_splits/test.jsonl") as f: 83 | for line in f: 84 | problem = json.loads(line) 85 | problem = {"question" if key == "problem" else key: value for key, value in problem.items()} 86 | id = problem['unique_id'] 87 | math500_dataset.append(problem) 88 | math500_ids.add(id) 89 | 90 | math200_dataset = [] 91 | snapshots_dataset = [] 92 | math200_ids = set() 93 | snapshots_ids = set() 94 | snapshot1_json_files = glob(f"./examples/datasets/math/MATH/Oct-2023/test/*/*.json") 95 | for snapshot1_json_path in snapshot1_json_files: 96 | id = "/".join(snapshot1_json_path.split("/")[-3:]) 97 | snapshots_ids.add(id) 98 | 99 | math_json_path = snapshot1_json_path.replace("/Oct-2023/", "/data/") 100 | snapshot2_json_path = snapshot1_json_path.replace("/Oct-2023/", "/Nov-2023/") 101 | snapshot3_json_path = snapshot1_json_path.replace("/Oct-2023/", "/Dec-2023/") 102 | 103 | problem = process_json_file(math_json_path) 104 | snapshot1 = process_json_file(snapshot1_json_path) 105 | snapshot2 = process_json_file(snapshot2_json_path) 106 | snapshot3 = process_json_file(snapshot3_json_path) 107 | 108 | if id in math500_ids: 109 | math200_dataset.append(problem) 110 | snapshots_dataset.extend([snapshot1, snapshot2, snapshot3]) 111 | math200_ids.add(id) 112 | 113 | assert len(math200_ids) == 181 114 | assert len(math500_ids) == 500 115 | assert len(snapshots_ids) == 1745 116 | 117 | Dataset.from_list(math200_dataset).save_to_disk("./examples/datasets/math/math200") 118 | Dataset.from_list(math500_dataset).save_to_disk("./examples/datasets/math/math500") 119 | Dataset.from_list(snapshots_dataset).save_to_disk("./examples/datasets/math/snapshots") 120 | 121 | split_dataset(math200_dataset, "math200", args.num_problems, args.num_rationales) 122 | -------------------------------------------------------------------------------- /scripts/math/prepare_teaching_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from glob import glob 3 | 4 | from datasets import Dataset, load_from_disk 5 | 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--outputs", type=str, default="./outputs") 10 | parser.add_argument("--teacher_exp", type=str) 11 | parser.add_argument("--teacher_name", type=str) 12 | parser.add_argument("--dataname", type=str, default="math200") 13 | parser.add_argument("--num_rationales", type=int, default=256) 14 | args = parser.parse_args() 15 | 16 | dataset_paths = glob(f"./examples/datasets/math/math_splits/{args.dataname}_r{args.num_rationales}s*") 17 | 18 | for path in dataset_paths: 19 | split = path.split("/")[-1] 20 | dataset_split = load_from_disk(path) 21 | teacher_dataset = load_from_disk( 22 | f"{args.outputs}/{args.teacher_exp}/rationales/{split}" 23 | ) 24 | 25 | questions = 
teacher_dataset[0]["exam_questions"] 26 | rationales = teacher_dataset[0]['exam_details'][args.teacher_name]['rationales'] 27 | answers = teacher_dataset[0]['exam_details'][args.teacher_name]['answers'] 28 | 29 | assert len(dataset_split) == len(questions) 30 | 31 | teaching_dataset = [] 32 | for j in range(len(dataset_split)): 33 | assert dataset_split[j]["question"] == questions[j] 34 | 35 | problem = {} 36 | problem["question"] = dataset_split[j]["question"] 37 | index = rationales[j][0].find("[[Final Answer]]") 38 | problem["solution"] = rationales[j][0][:index].rstrip() 39 | problem["answer"] = answers[j][0] 40 | problem["subject"] = dataset_split[j]["subject"] 41 | problem["level"] = dataset_split[j]["level"] 42 | problem["unique_id"] = dataset_split[j]["unique_id"] 43 | teaching_dataset.append(problem) 44 | 45 | Dataset.from_list(teaching_dataset).save_to_disk(f"{args.outputs}/{args.teacher_exp}/teaching/{split}") 46 | -------------------------------------------------------------------------------- /scripts/math/search_rationale.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from glob import glob 3 | 4 | from tqdm import tqdm 5 | from datasets import load_from_disk 6 | 7 | from lbt.datasets_adapter.math_dataset import MathExamScorer 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--outputs", type=str, default="./outputs") 13 | parser.add_argument("--teacher_exp", type=str) 14 | parser.add_argument("--teacher_name", type=str) 15 | parser.add_argument("--student_exp", type=str) 16 | parser.add_argument("--student_name", type=str) 17 | parser.add_argument("--dataname", type=str, default="math200") 18 | parser.add_argument("--num_rationales", type=int, default=256) 19 | parser.add_argument("--num_exam_questions", type=int, default=3) 20 | parser.add_argument("--num_repetitions", type=int, default=3) 21 | args = parser.parse_args() 22 | 23 | print(args) 24 | 25 | scorer = MathExamScorer(False) 26 | exam_dataset = load_from_disk(f"./examples/datasets/math/snapshots") 27 | dataset_paths = glob(f"./examples/datasets/math/math_splits/{args.dataname}_r{args.num_rationales}s*") 28 | 29 | num_prev_problems = 0 30 | accuracy = [0] * 4 31 | 32 | for i in tqdm(range(len(dataset_paths))): 33 | path = f"./examples/datasets/math/math_splits/{args.dataname}_r{args.num_rationales}s{i}" 34 | split = path.split("/")[-1] 35 | dataset_split = load_from_disk(path) 36 | teacher_dataset = load_from_disk(f"{args.outputs}/{args.teacher_exp}/rationales/{split}") 37 | student_dataset = load_from_disk(f"{args.outputs}/{args.student_exp}/{args.teacher_exp}_exams/{split}") 38 | 39 | teacher_rationales = teacher_dataset[0]['exam_details'][args.teacher_name]['rationales'] 40 | teacher_answers = teacher_dataset[0]['exam_details'][args.teacher_name]['answers'] 41 | gt_rationales = teacher_dataset[0]['exam_gt_rationales'] 42 | gt_answers = teacher_dataset[0]['exam_gt_answers'] 43 | 44 | num_problems = len(dataset_split) // args.num_rationales 45 | for j in range(num_problems): 46 | all_answers = {} 47 | rationale_start = args.num_rationales * j 48 | for k in range(rationale_start, rationale_start + args.num_rationales): 49 | student_rationales = student_dataset[k]["exam_details"][args.student_name]["rationales"] 50 | answer = teacher_answers[k][0] 51 | 52 | lbt_score = 0 53 | for l in range(args.num_exam_questions * args.num_repetitions): 54 | question = student_dataset[k]["exam_questions"][l] 55 | 
snapshot_start = args.num_exam_questions * (num_prev_problems + j) 56 | snapshot_end = args.num_exam_questions * (num_prev_problems + j + 1) 57 | for index in range(snapshot_start, snapshot_end): 58 | if exam_dataset[index]["question"] == question: 59 | break 60 | assert exam_dataset[index]["question"] == question 61 | 62 | exam_gt_rationale = exam_dataset[index]['solution'] 63 | exam_gt_answer = exam_dataset[index]['answer'] 64 | student_rationale = student_rationales[l][0] 65 | gt = {"answer": f"[[Solution]]:\nLet's think step by step.\n\n{exam_gt_rationale}\n\n[[Final Answer]]:\n${exam_gt_answer}$\n"} 66 | exam = {"rationale": student_rationale} 67 | lbt_score += scorer.score_exam_result(gt, exam) 68 | 69 | if answer in all_answers: 70 | all_answers[answer][0] += 1 71 | all_answers[answer][1] = max(lbt_score, all_answers[answer][1]) 72 | all_answers[answer][2] += lbt_score 73 | all_answers[answer][3] = all_answers[answer][2] / all_answers[answer][0] 74 | else: 75 | all_answers[answer] = [1, lbt_score, lbt_score, lbt_score] 76 | 77 | for mode in ["MAJ", "MAX", "SUM", "AVG"]: 78 | if mode == "MAJ": 79 | dim = 0 80 | elif mode == "MAX": 81 | dim = 1 82 | elif mode == 'SUM': 83 | dim = 2 84 | elif mode == "AVG": 85 | dim = 3 86 | 87 | winner = sorted(all_answers.items(), key=lambda item: (item[1][dim], item[1][0]))[-1][0] 88 | 89 | for k in range(rationale_start, rationale_start + args.num_rationales): 90 | gt_rationale = gt_rationales[k] 91 | gt_answer = gt_answers[k] 92 | teacher_rationale = teacher_rationales[k][0] 93 | teacher_answer = teacher_answers[k][0] 94 | 95 | if teacher_answer == winner: 96 | gt = {"answer": f"[[Solution]]:\nLet's think step by step.\n\n{gt_rationale}\n\n[[Final Answer]]:\n${gt_answer}$\n"} 97 | exam = {"rationale": teacher_rationale} 98 | accuracy[dim] += scorer.score_exam_result(gt, exam) 99 | break 100 | 101 | num_prev_problems += num_problems 102 | 103 | for mode in ["MAJ", "MAX", "SUM", "AVG"]: 104 | if mode == "MAJ": 105 | dim = 0 106 | elif mode == "MAX": 107 | dim = 1 108 | elif mode == 'SUM': 109 | dim = 2 110 | elif mode == "AVG": 111 | dim = 3 112 | 113 | accuracy[dim] *= (100 / num_prev_problems) 114 | 115 | print(f"{mode}: {accuracy[dim]:.2f}") 116 | -------------------------------------------------------------------------------- /tests/test_math_extraction.py: -------------------------------------------------------------------------------- 1 | from datasets import load_from_disk 2 | 3 | from lbt.base import Component 4 | from lbt.exam_maker import ExamPrompter, FixedExamMaker 5 | from lbt.datasets_adapter.math_dataset import MathExamScorer 6 | 7 | # output_dataset = load_from_disk("results/try_filter/").select_columns( 8 | # ["question", "answer", "solution"] 9 | # ).rename_columns({"solution": "rationale"}) 10 | output_dataset = ( 11 | load_from_disk( 12 | "../NLP-playground/examples/rationale/data/math_solution_worstRationale_10/" 13 | ) 14 | .select_columns(["problem", "answer", "solution"]) 15 | .rename_columns({"problem": "question", "solution": "rationale"}) 16 | ) 17 | 18 | exam_dataset = load_from_disk("../NLP-playground/examples/rationale/data/math_1500/") 19 | # the columns are: problem, solution, solution 20 | exam_dataset = exam_dataset.select_columns( 21 | ["problem", "answer", "solution"] 22 | ).rename_columns({"problem": "question", "solution": "rationale"}) 23 | # exam_maker = Component.init_from_cfg(cfg, "exam_maker", exam_bank_dataset=exam_dataset) 24 | exam_maker = FixedExamMaker( 25 | exam_bank_dataset=exam_dataset, 
selected_indexes="range(0, 16)" 26 | ) 27 | exam_dataset = exam_maker.make_exam_questions(None) 28 |
29 | exam_prompter = ExamPrompter( 30 | demo_template="""Question:\n{question}\n\n[ROLESWITCHING assistant:]Solution:\n{rationale}\n\nFinal Answer:\n${answer}$\n""", 31 | exam_template="Question:\n{question}\n\n[ROLESWITCHING assistant:]Solution:\n", 32 | use_multi_round_conv=True, 33 | ) 34 |
35 | stub_teacher_items = [ 36 | { 37 | "question": "What is 10+8-4?", 38 | "rationale": "10+8=18, 18-4=14. So the answer value is 14.", 39 | "answer": "14", 40 | },
41 | { 42 | "question": r"What is the result of $\frac{6 \times 3}{2}$?", 43 | "rationale": ( 44 | r"$6 \times 3 = 18$, $18/2=9$. $9$ should be the result of $\frac{6 \times" 45 | r" 3}{2}$." 46 | ), 47 | "answer": "9", 48 | }, 49 | ] 50 | teach_index = 0
51 | teaching_items = stub_teacher_items + [output_dataset[teach_index]] 52 | exam_index = 3 53 | exam_item = exam_dataset[exam_index] 54 | conv_template_type = "Qwen/Qwen-14B-Chat" 55 | prompt = exam_prompter.make_exam_prompt(teaching_items, exam_item, conv_template_type) 56 |
57 | exam_scorer = MathExamScorer() 58 |
-------------------------------------------------------------------------------- /tests/test_ques_similarity_exam_maker.py: --------------------------------------------------------------------------------
1 | from datasets import load_from_disk 2 |
3 | from lbt.base import Component 4 | from lbt.exam_maker import QuesSimilarityExamMaker 5 |
6 | teaching_dataset = ( 7 | load_from_disk( 8 | "../NLP-playground/examples/rationale/data/math_solution_worstRationale_10/" 9 | ) 10 | .select_columns(["problem", "answer", "solution"]) 11 | .rename_columns({"problem": "question", "solution": "rationale"}) 12 | ) 13 |
14 | exam_dataset = load_from_disk("../NLP-playground/examples/rationale/data/math_1500/") 15 | # the columns are: problem, answer, solution 16 | exam_dataset = exam_dataset.select_columns( 17 | ["problem", "answer", "solution"] 18 | ).rename_columns({"problem": "question", "solution": "rationale"}) 19 |
20 | exam_maker = QuesSimilarityExamMaker( 21 | exam_bank_dataset=exam_dataset, 22 | selected_indexes="range(0, 16)", 23 | num_exam_questions=4, 24 | ) 25 |
26 | s_exam_dataset_1t = exam_maker.make_exam_questions([teaching_dataset[0]]) 27 | print(s_exam_dataset_1t["question"]) 28 | s_exam_dataset_2t = exam_maker.make_exam_questions(teaching_dataset.to_list()[:2]) 29 | print(s_exam_dataset_2t["question"]) 30 | s_exam_dataset_3t = exam_maker.make_exam_questions(teaching_dataset.to_list()[3:6]) 31 | print(s_exam_dataset_3t["question"]) 32 | --------------------------------------------------------------------------------