├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── assets ├── images │ └── ac_overview.png └── pdf │ └── acecoder_v1.pdf ├── data ├── README.md ├── __init__.py ├── acecode_89k │ └── generate_main_dataset.py ├── acecode_pair_300k │ ├── __init__.py │ ├── convert_dataset_for_llama_factory.py │ ├── create_rm_dataset.py │ ├── create_rm_dataset.sh │ └── generate_main_pair_dataset.py ├── inference │ ├── ComputeAccuracy.py │ ├── Constants.py │ ├── EvaluateInferencedCode.py │ ├── GetDatasets.py │ ├── InferenceModels.py │ ├── Utility.py │ ├── __init__.py │ ├── native_inference.py │ ├── post_process_functions.py │ └── vllm_inference.py ├── setup.py ├── setup.sh ├── training_dataset │ ├── __init__.py │ ├── bigcode_python_fns │ │ ├── __init__.py │ │ ├── dataset.py │ │ ├── generate_test_cases.py │ │ └── preprocess.py │ ├── consolidate_dataset.py │ ├── constants.py │ ├── create_test_case_and_prompt.py │ ├── evaluate_inferenced_code.py │ ├── evaluate_inferenced_code.sh │ ├── evol │ │ ├── __init__.py │ │ ├── evol_dataset.py │ │ ├── generate_test_cases.py │ │ └── preprocess_evol.py │ ├── inference_generated_prompts.py │ ├── inference_generated_prompts.sh │ ├── oss │ │ ├── __init__.py │ │ ├── generate_test_cases.py │ │ ├── oss_dataset.py │ │ └── preprocess_oss.py │ └── util.py └── utility │ ├── __init__.py │ └── utility.py ├── examples └── run_acecoderm.py ├── setup.py ├── src └── acecoder │ ├── __init__.py │ ├── eval_test_cases.py │ ├── evalplus_eval.py │ └── rm_utils.py └── train ├── train_rl └── README.md └── train_rm ├── README.md └── configs ├── ds_z3_config.json ├── train_qwen_coder_ins_2.5_32b.yaml └── train_qwen_coder_ins_2.5_7b.yaml /.gitignore: -------------------------------------------------------------------------------- 1 | # Dataset Curation Cache Data 2 | data/inferenced output 3 | data/training_dataset/bigcode_python_fns/data 4 | data/training_dataset/evol/data 5 | data/training_dataset/oss/data 6 | data/cache 7 | data/*.jsonl 8 | data/generated_datasets 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/#use-with-ide 119 | .pdm.toml 120 | 121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # pytype static type analyzer 159 | .pytype/ 160 | 161 | # Cython debug symbols 162 | cython_debug/ 163 | 164 | # PyCharm 165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 167 | # and can be added to the global gitignore or merged into this file. For a more nuclear 168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
169 | #.idea/ 170 | .DS_Store 171 | 172 | /bak 173 | /test* 174 | /data 175 | /.vscode 176 | /*.json 177 | /*.ipynb 178 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "train/train_rl/OpenRLHF"] 2 | path = train/train_rl/OpenRLHF 3 | url = https://github.com/jdf-prog/OpenRLHF.git 4 | [submodule "train/train_rm/LLaMA-Factory"] 5 | path = train/train_rm/LLaMA-Factory 6 | url = https://github.com/hiyouga/LLaMA-Factory.git 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 wenhu chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🂡 AceCoder 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 15 |
16 | 17 | 18 | 19 | Authors: 20 | Huaye Zeng, 21 | Dongfu Jiang, 22 | HaoZhe Wang, 23 | Ping Nie, 24 | Xiaotong Chen, 25 | Wenhu Chen  @ 26 | TIGER-Lab   27 | 28 | 29 | ## 🔥News 30 | 31 | - [2025/2/3] We release the [AceCoder Paper](https://arxiv.org/abs/2502.01718), along with the [🤗 Models and Datasets](https://huggingface.co/collections/TIGER-Lab/acecoder-67a16011a6c7d65cad529eba) on Hugging Face. 32 | 33 | 34 | ## Overview 35 | ![./assets/images/ac_overview.png](./assets/images/ac_overview.png) 36 | 37 |
Abstract 38 | 39 | - We introduce AceCoder, the first work to propose a fully automated pipeline for synthesizing large-scale reliable tests used for reward model training and reinforcement learning in the coding scenario. To do this, we curated the dataset [AceCode-87K](https://huggingface.co/datasets/TIGER-Lab/AceCode-87K), where we start from a seed code dataset and prompt powerful LLMs to "imagine" proper test cases for each coding question, then filter out the noisy ones. 40 | 41 | - We trained two reward models, [AceCodeRM-7B](https://huggingface.co/TIGER-Lab/AceCodeRM-7B) and [AceCodeRM-32B](https://huggingface.co/TIGER-Lab/AceCodeRM-32B), on the constructed [preference pairs](https://huggingface.co/datasets/TIGER-Lab/AceCodePair-300K). Best-of-N sampling results on HumanEval(+), MBPP(+), BigCodeBench, and LiveCodeBench (V4) show consistent improvement. 42 | 43 | - We perform RL training from three policy models: Qwen2.5-7B-Instruct, Qwen2.5-Coder-7B-Base, and Qwen2.5-Coder-7B-Instruct. Two types of reward can be used: the trained reward model AceCodeRM-7B, and a rule-based reward, i.e., the binary pass rate over the test cases in the dataset. Additionally, we experiment with RL directly from the base model, in the spirit of DeepSeek-R1. Results show that RL directly from the base Qwen2.5-Coder model can yield a **25%** improvement on HumanEval-plus and **6%** on MBPP-plus within just **80** optimization steps. 44 | 45 | - To our knowledge, this is the first work to propose a fully automated pipeline for synthesizing large-scale reliable tests used for reward model training and reinforcement learning in the coding scenario. We believe AceCode-87K will unlock the potential of RL training for code generation models and help the community push the boundaries of LLMs' coding abilities further. 46 | 47 |
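To make the rule-based reward mentioned above concrete, here is a minimal sketch (not code from this repository) of the binary pass-rate reward: a candidate program receives a reward of 1 only when it passes every synthesized test case, and 0 otherwise.

```python
from typing import List

def rule_based_reward(test_results: List[bool]) -> float:
    """Binary pass-rate reward: 1.0 only when every synthesized test passes."""
    return 1.0 if test_results and all(test_results) else 0.0

print(rule_based_reward([True, True, True]))   # 1.0
print(rule_based_reward([True, False, True]))  # 0.0
```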
48 | 49 | ## 📚Dataset 50 | - [AceCode-87K](https://huggingface.co/datasets/TIGER-Lab/AceCode-87K): The first large-scale coding dataset with an average of 16 test cases per prompt, synthesized by GPT-4o-mini. 51 | - [AceCodePair-300K](https://huggingface.co/datasets/TIGER-Lab/AceCodePair-300K): Preference pairs constructed from AceCode-87K for training the reward model. 52 | - AceCode-87K-hard: you can sample 25% of the hard examples by following the commands [here](https://github.com/TIGER-AI-Lab/AceCoder/tree/main/train/train_rl#data-preparation) 53 | 54 | ## 🤗Model 55 | 56 | ### AceCodeRM (Reward Model) 57 | - [AceCodeRM-7B](https://huggingface.co/TIGER-Lab/AceCodeRM-7B): A reward model trained on AceCodePair-300K from Qwen2.5-Coder-7B-Instruct 58 | - [AceCodeRM-32B](https://huggingface.co/TIGER-Lab/AceCodeRM-32B): A reward model trained on AceCodePair-300K from Qwen2.5-Coder-32B-Instruct 59 | 60 | ### AceCoder (RL Model) 61 | | Initial Policy Model | Reward Type | Training dataset | Final RL Model | 62 | |:---------------------:|:-----------:|:----------------:|:--------------:| 63 | | Qwen2.5-7B-Instruct | AceCodeRM-7B | AceCode-87K-hard (22k) | [TIGER-Lab/AceCoder-Qwen2.5-7B-Ins-RM](https://huggingface.co/TIGER-Lab/AceCoder-Qwen2.5-7B-Ins-RM) | 64 | | Qwen2.5-7B-Instruct | Rule | AceCode-87K-hard (22k) | [TIGER-Lab/AceCoder-Qwen2.5-7B-Ins-Rule](https://huggingface.co/TIGER-Lab/AceCoder-Qwen2.5-7B-Ins-Rule) | 65 | | Qwen2.5-Coder-7B-Instruct | AceCodeRM-7B | AceCode-87K-hard (22k) | [TIGER-Lab/AceCoder-Qwen2.5-Coder-7B-Ins-RM](https://huggingface.co/TIGER-Lab/AceCoder-Qwen2.5-Coder-7B-Ins-RM) | 66 | | Qwen2.5-Coder-7B-Instruct | Rule | AceCode-87K-hard (22k) | [TIGER-Lab/AceCoder-Qwen2.5-Coder-7B-Ins-Rule](https://huggingface.co/TIGER-Lab/AceCoder-Qwen2.5-Coder-7B-Ins-Rule) | 67 | | Qwen2.5-Coder-7B | AceCodeRM-7B | AceCode-87K-hard (22k) | [TIGER-Lab/AceCoder-Qwen2.5-Coder-7B-Base-RM](https://huggingface.co/TIGER-Lab/AceCoder-Qwen2.5-Coder-7B-Base-RM) | 68 | | Qwen2.5-Coder-7B | Rule | AceCode-87K-hard (22k) | [TIGER-Lab/AceCoder-Qwen2.5-Coder-7B-Base-Rule](https://huggingface.co/TIGER-Lab/AceCoder-Qwen2.5-Coder-7B-Base-Rule) | 69 | 70 | ## 📈 Performance 71 | See our [website](https://tiger-ai-lab.github.io/AceCoder/) or [paper](https://arxiv.org/abs/2502.01718) for a detailed performance report. 72 | 73 | ## 🚀Quick Start 74 | 75 | ```bash 76 | git submodule init 77 | git submodule update 78 | ``` 79 | 80 | ### Use AceCodeRM 81 | First install acecoder as a package: 82 | ```bash 83 | pip install git+https://github.com/TIGER-AI-Lab/AceCoder.git 84 | ``` 85 | Then see [examples/run_acecoderm.py](examples/run_acecoderm.py) for how to use AceCodeRM; running `python examples/run_acecoderm.py` will run the example. (A hedged usage sketch also appears after the Evaluation section below.) 86 | 87 | ### Training Reward Model 88 | See [train/train_rm/README.md](train/train_rm/README.md) for detailed instructions. 89 | 90 | ### Training RL Model 91 | See [train/train_rl/README.md](train/train_rl/README.md) for detailed instructions. 92 | 93 | ### Evaluation 94 | We use [Evalplus](https://github.com/evalplus/evalplus), [bigcodebench](https://github.com/bigcode-project/bigcodebench), and [LiveCodeBench](https://github.com/LiveCodeBench/LiveCodeBench) to evaluate HumanEval(+), MBPP(+), BigCodeBench, and LiveCodeBench (V4), respectively.
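The canonical reward-model example is [examples/run_acecoderm.py](examples/run_acecoderm.py). The snippet below is only a hedged sketch of Best-of-N scoring with AceCodeRM-7B: it assumes the model can be loaded via `AutoModelForSequenceClassification` with `trust_remote_code=True` and that the reward head returns a single scalar logit per sequence. Check the example script for the exact API before relying on this.

```python
# Hedged sketch -- see examples/run_acecoderm.py for the repository's canonical usage.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_id = "TIGER-Lab/AceCodeRM-7B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)

question = "Write a function that returns the sum of two integers."
candidates = [
    "def add(a, b):\n    return a + b",  # likely to pass the tests
    "def add(a, b):\n    return a - b",  # likely to fail the tests
]

scores = []
for candidate in candidates:
    chat = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": candidate},
    ]
    input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # Assumes a single-logit reward head; adapt if the model exposes more labels.
        scores.append(model(input_ids).logits.squeeze().item())

best = candidates[scores.index(max(scores))]  # Best-of-N selection by reward score
print(best)
```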
95 | 96 | ## Citation 97 | If you find this work helpful, please consider citing: 98 | ```bibtex 99 | @article{AceCoder, 100 | title={AceCoder: Acing Coder RL via Automated Test-Case Synthesis}, 101 | author={Zeng, Huaye and Jiang, Dongfu and Wang, Haozhe and Nie, Ping and Chen, Xiaotong and Chen, Wenhu}, 102 | journal={ArXiv}, 103 | year={2025}, 104 | volume={2502.01718} 105 | } 106 | ``` 107 | -------------------------------------------------------------------------------- /assets/images/ac_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/AceCoder/e37ef2aea2e3710be6b06e9a91dd92c98321df96/assets/images/ac_overview.png -------------------------------------------------------------------------------- /assets/pdf/acecoder_v1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/AceCoder/e37ef2aea2e3710be6b06e9a91dd92c98321df96/assets/pdf/acecoder_v1.pdf -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # AceCode (Data Repository) 2 | Welcome to the data directory for the AceCode project. In this folder, you can find the scripts / code used to recreate AceCode-87K and AceCodePair-300K. 3 | **IMPORTANT: All instructions in this folder assume your terminal is in the current folder (AceCoder/data/); please run ```cd data``` if it is not. We also use conda to manage our environment, so make sure you activate the correct environment:** 4 | 5 | ```bash 6 | conda init 7 | conda activate acecoder_data 8 | ``` 9 | 10 | ## Installation 11 | I assume you have **CUDA 12.1** and **conda** installed. With those, run the following command: 12 | 13 | ```bash 14 | source setup.sh 15 | ``` 16 | 17 | You will be prompted with some y/n options when installing packages; type y and press Enter each time. 18 | 19 | Note: if you have CUDA 11.8, then you need to: 20 | 1. remove "vllm", "torch", and "xformers" from setup.py 21 | 2. uncomment the following code in setup.sh 22 | 23 | ```bash 24 | export VLLM_VERSION=0.2.6 25 | export PYTHON_VERSION=311 26 | pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl 27 | 28 | pip uninstall torch -y 29 | pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cu118 30 | 31 | pip uninstall xformers -y 32 | pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118 33 | ``` 34 | 35 | ## Dataset Curation 36 | Follow these steps closely to create AceCode-87K and AceCodePair-300K.
37 | 38 | ### Download datasets from hub 39 | We will download the following datasets from huggingface and cache them locally: 40 | - Bigcode Python Functions: [bigcode/stack-dedup-python-fns](https://huggingface.co/datasets/bigcode/stack-dedup-python-fns) 41 | - OSS: [ise-uiuc/Magicoder-OSS-Instruct-75K-Instruction-Response](https://huggingface.co/datasets/ise-uiuc/Magicoder-OSS-Instruct-75K-Instruction-Response) 42 | - Evol Instruct: [ise-uiuc/Magicoder-Evol-Instruct-110K](https://huggingface.co/datasets/ise-uiuc/Magicoder-Evol-Instruct-110K) 43 | 44 | ```bash 45 | python training_dataset/bigcode_python_fns/preprocess.py 46 | python training_dataset/evol/preprocess_evol.py 47 | python training_dataset/oss/preprocess_oss.py 48 | ``` 49 | 50 | ### Use GPT-4o-mini to convert seed code data into LeetCode-like questions and test cases 51 | First add the following environment variables to your shell (you can also add them to ~/.bashrc): 52 | ```bash 53 | export OPENAI_API_KEYS="sk-your-openai-api-key" 54 | export OPENAI_API_TYPE="OpenAI" 55 | ``` 56 | 57 | If you just want to sample a few questions to try it out, run (costs less than $1 USD): 58 | ```bash 59 | python training_dataset/bigcode_python_fns/generate_test_cases.py --ct=50 60 | python training_dataset/evol/generate_test_cases.py --ct=50 61 | python training_dataset/oss/generate_test_cases.py --ct=50 62 | ``` 63 | 64 | If you want to fully recreate our dataset, run (this will cost you around $300 USD): 65 | ```bash 66 | python training_dataset/bigcode_python_fns/generate_test_cases.py --ct=50000 67 | python training_dataset/evol/generate_test_cases.py --ct=-1 68 | python training_dataset/oss/generate_test_cases.py --ct=-1 69 | ``` 70 | 71 | ### Creating inferences for the generated LeetCode-like prompts 72 | Run the following to create inferences for the generated LeetCode-like prompts. This process is GPU heavy, and you may want to set CUDA_VISIBLE_DEVICES if you do not wish to run the process on all of your GPUs. 73 | ```bash 74 | source training_dataset/inference_generated_prompts.sh 75 | ``` 76 | 77 | ### Evaluate the inferenced code 78 | **Note: this may drain your CPU resources and may also make unpredictable changes to your file system since we are executing generated code. 
You may want to run it in a docker for your safety.** 79 | 80 | Run the following to compute the accuracies for the generated code: 81 | ```bash 82 | source training_dataset/evaluate_inferenced_code.sh 83 | ``` 84 | 85 | ### Consolidate the dataset 86 | Run: 87 | ```bash 88 | python data/training_dataset/consolidate_dataset.py 89 | ``` 90 | 91 | ### Creating AceCode-98K 92 | run: 93 | ```bash 94 | python acecode_87K/generate_main_dataset.py 95 | ``` 96 | 97 | ### Creating AceCodePair-300K 98 | run: 99 | ```bash 100 | source acecode_pair_300k/create_rm_dataset.sh 101 | ``` -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/AceCoder/e37ef2aea2e3710be6b06e9a91dd92c98321df96/data/__init__.py -------------------------------------------------------------------------------- /data/acecode_89k/generate_main_dataset.py: -------------------------------------------------------------------------------- 1 | from training_dataset.constants import DATASET_LST 2 | from utility.utility import load_jsonl, save_jsonl 3 | 4 | # Use this script to generate the "AceCode-87K dataset, which contains the questions, tests, inferences, etc." 5 | 6 | 7 | def generate_entries(oracle_model: str, save_path: str): 8 | out = [] 9 | for dataset_name in DATASET_LST: 10 | jsonl_file_name = ( 11 | f"training_dataset/{dataset_name}/data/v3_{oracle_model}.jsonl" 12 | ) 13 | data = load_jsonl(jsonl_file_name) 14 | for entry in data: 15 | id = entry["id"] 16 | prompt = entry["prompt"] 17 | tests = entry["tests"] 18 | inferences = entry["inferences"] 19 | if prompt is None or len(tests) < 5: 20 | continue 21 | inferences_out = [] 22 | for i, (inf, acc, inf_model) in enumerate(inferences): 23 | inferences_out.append( 24 | { 25 | "model_name": inf_model, 26 | "completion_id": i, 27 | "completion": inf, 28 | "pass_rate": acc, 29 | } 30 | ) 31 | out.append( 32 | { 33 | "id": id, 34 | "source": dataset_name, 35 | "question": prompt, 36 | "test_cases": tests, 37 | "inferences": inferences_out, 38 | "context_messages": [{"content": prompt, "role": "user"}], 39 | } 40 | ) 41 | 42 | save_jsonl(save_path, out) 43 | 44 | 45 | if __name__ == "__main__": 46 | oracle_model = "qwen_coder_2.5_32b_greedy" 47 | save_path = "AceCode-87K.jsonl" 48 | generate_entries(oracle_model=oracle_model, save_path=save_path) 49 | -------------------------------------------------------------------------------- /data/acecode_pair_300k/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/AceCoder/e37ef2aea2e3710be6b06e9a91dd92c98321df96/data/acecode_pair_300k/__init__.py -------------------------------------------------------------------------------- /data/acecode_pair_300k/convert_dataset_for_llama_factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import List 4 | 5 | from training_dataset.constants import DATASET_LST, MODEL_LST 6 | from utility.utility import load_jsonl 7 | 8 | template_input_str = """Below is a Python script with a self-contained function that solves the problem and passes corresponding tests: 9 | ```python 10 | """ 11 | 12 | 13 | def remove_assert_statment(input_str: str): 14 | lst = input_str.splitlines() 15 | lst = [i for i in lst if not i.strip().startswith("assert")] 16 | return "\n".join(lst) 
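# Added note (not part of the original file): remove_assert_statment drops every line
# whose stripped form starts with "assert". For example,
#     remove_assert_statment("x = 1\nassert x == 1\nprint(x)")
# returns "x = 1\nprint(x)".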
17 | 18 | 19 | def get_mbpp_style_prompt(program1: str, program2: str, prompt: str, test: List[str]): 20 | """ 21 | Create a prompt with the following style: 22 | Write a function to find the shared elements from the given two lists. 23 | assert set(similar_elements((3, 4, 5, 6),(5, 7, 4, 10))) == set((4, 5)) 24 | """ 25 | mbpp_prompt = f'"""\n{prompt}\n{test[0].strip()}\n"""\n' 26 | mbpp_prompt = prompt_final_post_process(mbpp_prompt) 27 | output = { 28 | "instruction": mbpp_prompt, 29 | "input": "", 30 | "chosen": template_input_str + program1 + "\n```", 31 | "rejected": template_input_str + program2 + "\n```", 32 | } 33 | return output 34 | 35 | 36 | def compare_starting_code(prog1: str, prog2: str): 37 | """See if two starting code are kind of similar, by splitting by lines""" 38 | prog1 = [line.strip() for line in prog1.splitlines()] 39 | prog2 = [line.strip() for line in prog2.splitlines()] 40 | prog1 = [i for i in prog1 if len(i) > 0] 41 | prog2 = [i for i in prog2 if len(i) > 0] 42 | if len(prog1) != len(prog2): 43 | return False 44 | for i in prog1: 45 | if i not in prog2: 46 | return False 47 | return True 48 | 49 | 50 | def prompt_final_post_process(input_str: str) -> str: 51 | out = f"""Please provide a self-contained Python script that solves the following problem in a markdown code block: 52 | ``` 53 | {input_str} 54 | ``` 55 | """ 56 | return out 57 | 58 | 59 | def remove_start_header(input_str: str): 60 | START_HEADERS = ["<|start_header_id|>assistant<|end_header_id|>"] 61 | for i in START_HEADERS: 62 | if input_str.startswith(i): 63 | return input_str[len(i) :].strip() 64 | return input_str.strip() 65 | 66 | 67 | def convert_dataset( 68 | dataset_lst: List[str], 69 | model_name: str = "cross_model", 70 | return_size: int = 1, 71 | ): 72 | """Convert the dataset 73 | 74 | Parameter: 75 | model_name: the model name used to create the DPO dataset 76 | return_size: how many entries max can be generated from each question 77 | """ 78 | out = [] 79 | for dataset in dataset_lst: 80 | file_name = ( 81 | f"training_dataset/{dataset}/data/dpo_{model_name}_{return_size}.jsonl" 82 | ) 83 | lst = load_jsonl(file_name) 84 | for i in lst: 85 | prog1 = i["program_1"] 86 | prog2 = i["program_2"] 87 | prompt = get_mbpp_style_prompt( 88 | program1=remove_start_header(prog1), 89 | program2=remove_start_header(prog2), 90 | prompt=i["prompt"], 91 | test=i["tests"], 92 | ) 93 | out.append(prompt) 94 | 95 | # print(f"MBPP style ct: {mbpp_style_ct}, humaneval style ct: {human_eval_style_ct}") 96 | os.makedirs("generated_datasets", exist_ok=True) 97 | out_file_name = f"generated_datasets/dpo_{model_name}_{return_size}.json" 98 | with open(out_file_name, "w") as f: 99 | f.write(json.dumps(out, indent=4)) 100 | print(f"Finished Generating for {model_name}, {dataset_lst}") 101 | 102 | 103 | if __name__ == "__main__": 104 | for model in list(MODEL_LST.keys()): 105 | convert_dataset( 106 | dataset_lst=DATASET_LST, 107 | model_name=model, 108 | return_size="inf", 109 | scaled=False, 110 | ) 111 | -------------------------------------------------------------------------------- /data/acecode_pair_300k/create_rm_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple 2 | 3 | from training_dataset.constants import DATASET_LST, MODEL_LST 4 | from utility.utility import load_jsonl, save_jsonl 5 | 6 | 7 | def create_cross_model_dataset( 8 | dataset_name: str, return_size: str, oracle_model_name: str 9 | ): 10 | """Create DPO 
dataset but the data can come from either any model""" 11 | dataset = [] 12 | data = load_jsonl( 13 | f"training_dataset/{dataset_name}/data/v3_{oracle_model_name}.jsonl" 14 | ) 15 | 16 | generated_lst = [] # this list tracks the number of generated entries 17 | for i in range(len(data)): 18 | # tuple in the form: (code, accuracy, model name) 19 | if data[i]["prompt"] is None: 20 | # we skip because it's None 21 | test_cases = [] 22 | elif not data[i]["prompt"].isascii(): 23 | # we skip because it contains non-ascii code 24 | test_cases = [] 25 | else: 26 | test_cases = create_dataset_helper_1( 27 | data[i]["inferences"], 28 | data[i]["prompt"], 29 | data[i]["tests"], 30 | return_size=return_size, 31 | ) 32 | dataset += test_cases 33 | # print(f"question {i} generated {len(test_cases)} entries") 34 | generated_lst.append(len(test_cases)) 35 | no_entry_ct = len([i for i in generated_lst if i == 0]) 36 | print( 37 | f"{dataset_name} - cross model - {return_size} - generated {sum(generated_lst)} entries, average yield is {sum(generated_lst) / len(data)}. {no_entry_ct} entries (or {no_entry_ct / len(data) * 100}%) produced no entires at all." 38 | ) 39 | 40 | save_jsonl( 41 | f"training_dataset/{dataset_name}/data/dpo_cross_models_{return_size}.jsonl", 42 | dataset, 43 | ) 44 | return generated_lst 45 | 46 | 47 | def create_dataset_with_only_one_model( 48 | model_name: str, dataset_name: str, oracle_model_name: str, return_size: int = 1 49 | ): 50 | """Create DPO dataset but using questions from only 1 model""" 51 | dataset = [] 52 | data = load_jsonl( 53 | f"training_dataset/{dataset_name}/data/v3_{oracle_model_name}.jsonl" 54 | ) 55 | 56 | generated_lst = [] # this list tracks the number of generated entries 57 | for i in range(len(data)): 58 | if data[i]["prompt"] is None: 59 | # we skip because it's None 60 | test_cases = [] 61 | elif not data[i]["prompt"].isascii(): 62 | # we skip because it contains non-ascii code 63 | test_cases = [] 64 | else: 65 | test_cases = create_dataset_helper_1( 66 | data[i]["inferences"], 67 | data[i]["prompt"], 68 | data[i]["tests"], 69 | specific_model_name=model_name, 70 | return_size=return_size, 71 | ) 72 | dataset += test_cases 73 | # print(f"question {i} generated {len(test_cases)} entries") 74 | generated_lst.append(len(test_cases)) 75 | no_entry_ct = len([i for i in generated_lst if i == 0]) 76 | print( 77 | f"{dataset_name} - {model_name} - {return_size} - generated {sum(generated_lst)} entries, average yield is {sum(generated_lst) / len(data)}. {no_entry_ct} entries (or {no_entry_ct / len(data) * 100}%) produced no entires at all." 78 | ) 79 | 80 | save_jsonl( 81 | f"training_dataset/{dataset_name}/data/dpo_{model_name}_{return_size}.jsonl", 82 | dataset, 83 | ) 84 | return generated_lst 85 | 86 | 87 | def create_dataset_helper_1( 88 | inferences: List[Tuple], 89 | prompt: Any, 90 | tests: Any, 91 | specific_model_name: str = None, 92 | return_size: int = 1, 93 | ) -> List[Dict]: 94 | """Create a dataset for 1 prompt, this is a helper function for the overall create_dataset function. 
95 | 96 | Input: 97 | inferences: a list of tuple in the form: (code, accuracy, model name) 98 | prompt: the prompt for the question, will be appended to each test case 99 | tests: the tests for each question, will be appended to each test case 100 | specific_model_name: if you only want to generate test cases from 1 model 101 | 102 | output: 103 | A list of dictionary, each representing one question 104 | """ 105 | inferences_by_model = {} 106 | for program, acc, model in inferences: 107 | if specific_model_name is not None: 108 | if model != specific_model_name: 109 | continue 110 | if model in inferences_by_model: 111 | inferences_by_model[model].append((program, acc, model)) 112 | else: 113 | inferences_by_model[model] = [(program, acc, model)] 114 | 115 | model_lst = list(inferences_by_model.keys()) 116 | output = [] 117 | 118 | # we first generate test cases for each model 119 | for model in model_lst: 120 | output += create_dataset_helper_2( 121 | inferences_by_model[model], return_size=return_size 122 | ) 123 | 124 | for i in output: 125 | i.update({"prompt": prompt, "tests": tests}) 126 | 127 | return output 128 | 129 | 130 | def create_dataset_helper_2( 131 | inferences: List[Tuple], return_size: int = 3, require_different_model: bool = False 132 | ) -> List[Dict]: 133 | """Create a dataset for 1 prompt, this is a helper function for the overall create_dataset function. 134 | 135 | Input: 136 | tuple in the form: (code, accuracy, model name) 137 | return_size: how many entries do you want the program to return. If less than 1 then will return all entries 138 | require_different_model: should the return models be from different dataset 139 | 140 | output: 141 | A list of dictionary, each representing one question 142 | """ 143 | output = [] 144 | seen = set() 145 | inferences = sorted(inferences, key=lambda x: x[1], reverse=True) 146 | highest_acc = inferences[0][1] 147 | for j in range(len(inferences) - 1): 148 | for k in range(len(inferences) - 1, j, -1): 149 | prog1, acc1, model1 = inferences[j] 150 | prog2, acc2, model2 = inferences[k] 151 | if require_different_model and model1 == model2: 152 | continue 153 | if ( 154 | max(acc1, acc2) >= 0.8 155 | and abs(acc2 - acc1) >= 0.4 156 | and min(acc1, acc2) > 0 157 | ): 158 | ram = prog1 + prog2 159 | if ram not in seen: 160 | # we prevent duplicates 161 | entry = { 162 | "program_1": prog1, 163 | "program_2": prog2, 164 | "winner": 1, 165 | "accuracy_1": acc1, 166 | "accuracy_2": acc2, 167 | "accuracy_difference": abs(acc1 - acc2), 168 | "model_1": model1, 169 | "model_2": model2, 170 | } 171 | output.append(entry) 172 | seen.add(ram) 173 | if return_size != "inf": 174 | if len(output) >= return_size and return_size > 0: 175 | return output 176 | return output 177 | 178 | 179 | if __name__ == "__main__": 180 | oracle_model_name = "qwen_coder_2.5_32b_greedy" 181 | for dataset in DATASET_LST: 182 | for model in MODEL_LST.keys(): 183 | create_dataset_with_only_one_model( 184 | model_name=model, 185 | dataset_name=dataset, 186 | return_size="inf", 187 | oracle_model_name=oracle_model_name, 188 | ) 189 | -------------------------------------------------------------------------------- /data/acecode_pair_300k/create_rm_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python acecode_pair_300k/create_rm_dataset.py 4 | python acecode_pair_300k/convert_dataset_for_llama_factory.py 5 | python acecode_pair_300k/generate_main_pair_dataset.py 
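The pair-selection rule implemented in `create_dataset_helper_2` (in create_rm_dataset.py above) can be summarized in a small standalone sketch: a (chosen, rejected) pair is kept only when the stronger program passes at least 80% of the tests, the pass-rate gap between the two programs is at least 0.4, and the weaker program still passes at least one test; duplicate pairs are additionally filtered by the concatenation of the two programs. The helper below is illustrative only and is not part of the repository.

```python
def keep_pair(acc_chosen: float, acc_rejected: float) -> bool:
    """Mirror of the accuracy thresholds used in create_dataset_helper_2."""
    return (
        max(acc_chosen, acc_rejected) >= 0.8
        and abs(acc_chosen - acc_rejected) >= 0.4
        and min(acc_chosen, acc_rejected) > 0
    )

assert keep_pair(0.9, 0.2)        # kept: strong chosen program, clear margin
assert not keep_pair(0.9, 0.0)    # dropped: the weaker program passes no tests
assert not keep_pair(0.7, 0.2)    # dropped: neither program reaches the 0.8 bar
```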
-------------------------------------------------------------------------------- /data/acecode_pair_300k/generate_main_pair_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from utility.utility import save_jsonl 4 | 5 | # Use this script to generate the "AceCode-87K dataset, which contains the questions, tests, inferences, etc." 6 | 7 | 8 | def generate_entries(inf_model_name: str, save_path: str): 9 | json_file_name = f"generated_datasets/dpo_{inf_model_name}_inf.json" 10 | with open(json_file_name, "r") as f: 11 | data = json.load(f) 12 | 13 | for i in range(len(data)): 14 | data[i]["id"] = i 15 | 16 | save_jsonl(save_path, data) 17 | 18 | 19 | if __name__ == "__main__": 20 | inf_model_name = "qwen_coder_2.5" 21 | save_path = "AceCodePair-300K.jsonl" 22 | generate_entries(inf_model_name=inf_model_name, save_path=save_path) 23 | -------------------------------------------------------------------------------- /data/inference/ComputeAccuracy.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from collections import defaultdict 5 | from typing import Any, Dict, List 6 | 7 | import numpy as np 8 | 9 | 10 | def _load_processed_data( 11 | key: str, 12 | dataset_name: str, 13 | model_name: str, 14 | test_set_name: str = "default", 15 | ) -> Dict[int, List[Any]]: 16 | """get the saved processed, output is a dictionary where the key is the index and the value is the list of whatever field requested""" 17 | file_path = ( 18 | f"inferenced output/{dataset_name}/processed/{test_set_name}/{model_name}.jsonl" 19 | ) 20 | if not os.path.exists(file_path): 21 | raise Exception(f"No saved inference between {dataset_name} + {model_name}") 22 | 23 | with open(file_path, "r") as f: 24 | lst = f.readlines() 25 | lst = [json.loads(i) for i in lst] 26 | out = defaultdict(list) 27 | for i in lst: 28 | out[i["id"]].append(i[key]) 29 | return dict(out) 30 | 31 | 32 | def load_processed_model_accuracy( 33 | dataset_name: str, model_name: str, test_set_name: str = "default" 34 | ) -> Dict[int, List[str]]: 35 | """get the saved processed, output is a dictionary where the key is the index and the value is the list of accuracy""" 36 | return _load_processed_data( 37 | key="accuracy", 38 | dataset_name=dataset_name, 39 | model_name=model_name, 40 | test_set_name=test_set_name, 41 | ) 42 | 43 | 44 | def load_processed_model_tests_status( 45 | dataset_name: str, model_name: str, test_set_name: str = "default" 46 | ) -> Dict[int, List[str]]: 47 | """get the saved processed, output is a dictionary where the key is the index and the value is the list of accuracy""" 48 | return _load_processed_data( 49 | key="test_case_status", 50 | dataset_name=dataset_name, 51 | model_name=model_name, 52 | test_set_name=test_set_name, 53 | ) 54 | 55 | 56 | def get_oracle_test_case_status( 57 | dataset_name: str, model_name: str, test_set_name: str = "default" 58 | ) -> Dict[int, List[float]]: 59 | """For rach test case, if any one inference passed, we will consider that test case as passed""" 60 | data = load_processed_model_tests_status( 61 | dataset_name=dataset_name, model_name=model_name, test_set_name=test_set_name 62 | ) 63 | data = {k: np.max(np.array(v), axis=0).tolist() for k, v in data.items()} 64 | return data 65 | 66 | 67 | def get_oracle_accuracy( 68 | dataset_name: str, model_name: str, test_set_name: str = "default" 69 | ) -> float: 70 | """Compute the accuracy if you 
randomly select from the answer set. Note, prior to running 71 | this function you should have a jsonl file in f"inferenced output/{dataset_name}/processed/{model_name}.jsonl" 72 | where each line is a json object with the following 3 fields: id, accuracy, and response. 73 | 74 | Parameter: 75 | dataset_name 76 | model_name 77 | 78 | Return: 79 | the oracle accuracy 80 | """ 81 | accuracy_dict = load_processed_model_accuracy( 82 | dataset_name, model_name, test_set_name 83 | ) 84 | max_acc_lst = [max(accuracy_dict[i]) for i in accuracy_dict] 85 | return sum(max_acc_lst) / len(max_acc_lst) 86 | 87 | 88 | def get_random_select_accuracy( 89 | dataset_name: str, 90 | model_name: str, 91 | test_set_name: str = "default", 92 | sample_ct: int = 10, 93 | ) -> float: 94 | """Compute the accuracy if you randomly select from the answer set. Note, prior to running 95 | this function you should have a jsonl file in f"inferenced output/{dataset_name}/processed/{model_name}.jsonl" 96 | where each line is a json object with the following 3 fields: id, accuracy, and response. 97 | 98 | Parameter: 99 | dataset_name 100 | model_name 101 | sample_ct: an integer indicating how many time you would like the 102 | program to do random sampling. 103 | 104 | Return: 105 | the randomly selectred accuracy 106 | """ 107 | if sample_ct <= 0: 108 | raise Exception(f"sample_ct must be at least one, {sample_ct} provided") 109 | accuracy_dict = load_processed_model_accuracy( 110 | dataset_name, model_name, test_set_name=test_set_name 111 | ) 112 | max_acc_lst = [ 113 | [random.choice(accuracy_dict[idx]) for i in range(sample_ct)] 114 | for idx in accuracy_dict 115 | ] 116 | max_acc_lst = [sum(i) / len(i) for i in max_acc_lst] 117 | return sum(max_acc_lst) / len(max_acc_lst) 118 | 119 | 120 | def get_average_select_accuracy( 121 | dataset_name: str, 122 | model_name: str, 123 | test_set_name: str = "default", 124 | sample_ct: int = 10, 125 | ) -> float: 126 | """Compute the accuracy if you average from the answer set. Note, prior to running 127 | this function you should have a jsonl file in f"inferenced output/{dataset_name}/processed/{model_name}.jsonl" 128 | where each line is a json object with the following 3 fields: id, accuracy, and response. 129 | 130 | Parameter: 131 | dataset_name 132 | model_name 133 | sample_ct: an integer indicating how many time you would like the 134 | program to do random sampling. 135 | 136 | Return: 137 | the average accuracy 138 | """ 139 | if sample_ct <= 0: 140 | raise Exception(f"sample_ct must be at least one, {sample_ct} provided") 141 | accuracy_dict = load_processed_model_accuracy( 142 | dataset_name, model_name, test_set_name=test_set_name 143 | ) 144 | max_acc_lst = [ 145 | sum(accuracy_dict[idx]) / len(accuracy_dict[idx]) for idx in accuracy_dict 146 | ] 147 | return sum(max_acc_lst) / len(max_acc_lst) 148 | 149 | 150 | def get_greedy_accuracy( 151 | dataset_name: str, model_name: str, test_set_name: str = "default" 152 | ) -> float: 153 | """Compute the accuracy if you randomly select from the answer set. Note, prior to running 154 | this function you should have a jsonl file in f"inferenced output/{dataset_name}/processed/{model_name}.jsonl" 155 | where each line is a json object with the following 3 fields: id, accuracy, and response. 156 | Moreover, if you give a best of n model, then this function will return the accuracy of the first 157 | response for each question. 
158 | 159 | Parameter: 160 | dataset_name 161 | model_name 162 | 163 | Return: 164 | the one shot accuracy 165 | """ 166 | accuracy_dict = load_processed_model_accuracy( 167 | dataset_name, model_name, test_set_name=test_set_name 168 | ) 169 | max_acc_lst = [accuracy_dict[idx][0] for idx in accuracy_dict] 170 | return sum(max_acc_lst) / len(max_acc_lst) 171 | -------------------------------------------------------------------------------- /data/inference/Constants.py: -------------------------------------------------------------------------------- 1 | # Code models related constants 2 | MBPP = "mbpp" 3 | MBPP_LENGTH = 500 4 | 5 | MBPPPLUS = "mbppplus" 6 | MBPPPLUS_LENGTH = 378 7 | 8 | HUMANEVAL = "humaneval" 9 | HUMANEVAL_LENGTH = 164 10 | 11 | HUMANEVALPLUS = "humanevalplus" 12 | HUMANEVALPLUS_LENGTH = 164 13 | 14 | LIVECODEBENCH = "livecodebench" 15 | LIVECODEBENCH_LENGTH = 714 16 | 17 | STABLE_CODE = "stable_code" 18 | STARCODER2 = "starcoder2" 19 | CODELLAMA = "codellama" 20 | DEEPSEEK = "deepseek" 21 | LEETCODE = "leetcode" 22 | LLAMA = "llama" 23 | CODEQWEN = "code_qwen" 24 | WIZARDCODER = "wizardcoder" 25 | MISTRAL = "mistral" 26 | QWEN_CODER = "qwen_coder" 27 | NXCODE = "nxcode" 28 | 29 | # this is the generic sampling parameters that will be passed for different type of inferences 30 | GENERIC_SAMPLING_PARAMETERS = { 31 | "greedy": { 32 | "temperature": 0, 33 | }, 34 | "best_of_n_diverse_beam_search": { 35 | "num_beams": 16, 36 | "num_return_sequences": 8, 37 | "num_beam_groups": 8, 38 | "early_stopping": True, 39 | "diversity_penalty": 1.0, 40 | }, 41 | "best_of_n_beam_search": { 42 | "n": 16, 43 | "temperature": 0, 44 | "use_beam_search": True, 45 | "early_stopping": True, 46 | }, 47 | "best_of_n_top_p_sampling": { 48 | "n": 16, 49 | "top_p": 1.0, 50 | "temperature": 0.7, 51 | }, 52 | "best_of_n_top_k_sampling": { 53 | "n": 16, 54 | "top_k": 50, 55 | "temperature": 0.7, 56 | }, 57 | } 58 | 59 | MODEL_PATH = { 60 | "athena_v2": {"72b": "Nexusflow/Athene-V2-Chat"}, 61 | f"{DEEPSEEK}_coder": { 62 | "1.3b": "deepseek-ai/deepseek-coder-1.3b-instruct", 63 | "6.7b": "deepseek-ai/deepseek-coder-6.7b-instruct", 64 | "7b": "deepseek-ai/deepseek-coder-7b-instruct-v1.5", 65 | "33b": "deepseek-ai/deepseek-coder-33b-instruct", 66 | }, 67 | f"{DEEPSEEK}_coder_v2": {"16b": "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"}, 68 | f"{CODEQWEN}_v1.5": { 69 | "7b": "Qwen/CodeQwen1.5-7B-Chat", 70 | }, 71 | f"{QWEN_CODER}_2.5": { 72 | "7b": "Qwen/Qwen2.5-Coder-7B-Instruct", 73 | "32b": "Qwen/Qwen2.5-Coder-32B-Instruct", 74 | }, 75 | f"{NXCODE}_cq_orpo": { 76 | "7b": "NTQAI/Nxcode-CQ-7B-orpo", 77 | }, 78 | STARCODER2: { 79 | "3b": "bigcode/starcoder2-3b", 80 | "7b": "bigcode/starcoder2-7b", 81 | "15b": "bigcode/starcoder2-15b", 82 | }, 83 | f"{MISTRAL}_instruct_v3": { 84 | "7b": "mistralai/Mistral-7B-Instruct-v0.3", 85 | }, 86 | f"{CODELLAMA}": { 87 | "7b": "codellama/CodeLlama-7b-hf", 88 | "13b": "codellama/CodeLlama-13b-hf", 89 | "34b": "codellama/CodeLlama-34b-hf", 90 | "70b": "codellama/CodeLlama-70b-hf", 91 | }, 92 | f"{CODELLAMA}_python": { 93 | "7b": "codellama/CodeLlama-7b-Python-hf", 94 | "13b": "codellama/CodeLlama-13b-Python-hf", 95 | "34b": "codellama/CodeLlama-34b-Python-hf", 96 | "70b": "codellama/CodeLlama-70b-Python-hf", 97 | }, 98 | f"{CODELLAMA}_instruct": { 99 | "7b": "codellama/CodeLlama-7b-Instruct-hf", 100 | "13b": "codellama/CodeLlama-13b-Instruct-hf", 101 | "34b": "codellama/CodeLlama-34b-Instruct-hf", 102 | "70b": "codellama/CodeLlama-70b-Instruct-hf", 103 | }, 104 | 
f"{LLAMA}3_instruct": { 105 | "8b": "meta-llama/Meta-Llama-3-8B-Instruct", 106 | "70b": "meta-llama/Meta-Llama-3-70B-Instruct", 107 | }, 108 | f"{LLAMA}3.1_instruct": { 109 | "8b": "meta-llama/Llama-3.1-8B-Instruct", 110 | "70b": "meta-llama/Llama-3.1-70B-Instruct", 111 | }, 112 | } 113 | -------------------------------------------------------------------------------- /data/inference/EvaluateInferencedCode.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 | import os 4 | import shutil 5 | from ctypes import c_int, c_int64 6 | from typing import Callable, Dict, List, Optional 7 | 8 | from tqdm import tqdm 9 | 10 | from inference.Utility import (append_inference, get_saved_inference_index, 11 | load_saved_inference) 12 | 13 | # we save the current working directory and restore them later 14 | cwd = os.getcwd() 15 | cache_wd = cwd + "/cache" 16 | 17 | tmp_chmod = os.chmod 18 | tmp_fchmod = os.fchmod 19 | tmp_chdir = os.chdir 20 | tmp_rmdir = os.rmdir 21 | # tmp_open = open 22 | tmp_print = print 23 | tmp_rm_tree = shutil.rmtree 24 | tmp_unlink = os.unlink 25 | 26 | 27 | def process_one_model( 28 | model_name: str, 29 | dataset_name: str, 30 | tests: List[List[str]], 31 | test_set_name: str = "default", 32 | processing_func: Optional[Callable] = None, 33 | max_execution_time: float = 5.0, 34 | binary_grade: bool = False, 35 | fast_algo: bool = True, 36 | ) -> None: 37 | """Get the inferences in f"inferences output/{dataset_name}/{model_name}", calculate each entry's 38 | accuracy, then write to f"inferences output/{dataset_name}/processed/{test_set_name}/{model_name}" where: 39 | - each entry is appended with its accuracy 40 | - each entry is processed using processing_func 41 | We do this because 1) executing program can be costly, and we do not want to reexecute 42 | them each time. 2) For pairRM, we would like to present the whole program in an 43 | executable state, which might not be possible without post processing. 44 | 45 | Parameter: 46 | model_name: the model name, ex: "deepseek_coder_1.3b_greedy" 47 | dataset_name: the dataset name, ex: "mbpp" or "evol" 48 | tests: a list of a list of strings 49 | processing_func: the function which is used to process the output further before 50 | being evaluated. The default is to just return the whole string itself. 51 | max_execution_time: the maximum execution time given to each test case 52 | binary_grade: if set to True, will give a grade of 1 if all tests passed. However 53 | will give a grade of 0 if even one test failed. If set to false, the grade will 54 | be a floating number between 0 and 1. 55 | fast_algo: if set to True, then we will use a faster algorithm but less accurate (usually 56 | used for best of n). 
The slow version is much slower but more accurate (greedy) 57 | """ 58 | os.makedirs( 59 | f"inferenced output/{dataset_name}/processed/{test_set_name}/", exist_ok=True 60 | ) 61 | programs = load_saved_inference(dataset_name=dataset_name, model_name=model_name) 62 | 63 | processed_idx = get_saved_inference_index( 64 | dataset_name=f"{dataset_name}/processed/{test_set_name}", model_name=model_name 65 | ) 66 | programs = {i: programs[i] for i in programs if i >= processed_idx} 67 | if len(programs) == 0: 68 | return # we already processed this 69 | print( 70 | f"starting evaluating for {model_name} on {dataset_name} with starting index {processed_idx}" 71 | ) 72 | if processing_func: 73 | programs = { 74 | i: [processing_func(program) for program in programs[i]] for i in programs 75 | } 76 | 77 | output = [] 78 | for i in tqdm(programs.keys(), leave=False): 79 | # run the outputted program as it is 80 | for program_idx, program in enumerate(programs[i]): 81 | question_tests = tests[i] 82 | if len(question_tests) == 0: 83 | acc = 0 84 | test_case_status = [] 85 | elif fast_algo: 86 | test_case_status = get_successful_tests_fast( 87 | program=program, 88 | tests=question_tests, 89 | max_execution_time=max_execution_time, 90 | ) 91 | acc = sum(test_case_status) / len(test_case_status) 92 | else: 93 | test_case_status = get_successful_tests_slow( 94 | program=program, 95 | tests=question_tests, 96 | max_execution_time=max_execution_time, 97 | ) 98 | acc = sum(test_case_status) / len(test_case_status) 99 | if binary_grade and acc < 1: 100 | acc = 0 101 | output.append( 102 | json.dumps( 103 | { 104 | "id": i, 105 | "inference_id": program_idx, 106 | "response": program, 107 | "accuracy": acc, 108 | "test_case_status": test_case_status, 109 | } 110 | ) 111 | ) 112 | 113 | if (i + 1) % 25 == 0: 114 | append_inference( 115 | dataset_name=f"{dataset_name}/processed/{test_set_name}", 116 | model_name=model_name, 117 | lst=output, 118 | ) 119 | output = [] 120 | if len(output) > 0: 121 | append_inference( 122 | dataset_name=f"{dataset_name}/processed/{test_set_name}", 123 | model_name=model_name, 124 | lst=output, 125 | ) 126 | 127 | 128 | def process_one_model_after_remove_prompt( 129 | model_name: str, 130 | dataset_name: str, 131 | tests: List[List[str]], 132 | prompts: Dict[int, str], 133 | processing_func: Callable = (lambda x: x), 134 | ) -> None: 135 | """Get the inferences in f"inferences output/{dataset_name}/{model_name}", calculate each entry's 136 | accuracy, then write to f"inferences output/{dataset_name}/processed/{model_name}" where: 137 | - each entry is appended with its accuracy 138 | - each entry is processed using processing_func 139 | - if the prompt is at the beginning of the entry, then we would remove the prompt from the entry 140 | We do this because 1) executing program can be costly, and we do not want to reexecute 141 | them each time. 2) For pairRM, we would like to present the whole program in an 142 | executable state, which might not be possible without post processing. 143 | 144 | Parameter: 145 | model_name: the model name, ex: "deepseek_coder_1.3b_greedy" 146 | dataset_name: the dataset name, ex: "mbpp" or "evol" 147 | tests: a list of a list of strings 148 | prompts: a dictionary where the key is the id and the value is the prompt. We will remove the 149 | prompt from the inference first if it's present. 150 | processing_func: the function which is used to process the output further before 151 | being evaluated. 
The default is to just return the whole string itself. 152 | """ 153 | os.makedirs(f"inferenced output/{dataset_name}/processed/", exist_ok=True) 154 | programs = load_saved_inference(dataset_name=dataset_name, model_name=model_name) 155 | 156 | processed_idx = get_saved_inference_index( 157 | dataset_name=f"{dataset_name}/processed", model_name=model_name 158 | ) 159 | programs = {i: programs[i] for i in programs if i >= processed_idx} 160 | if len(programs) == 0: 161 | return # we already processed this 162 | 163 | for i in programs: 164 | tmp_lst = [] 165 | for program in programs[i]: 166 | if program.startswith(prompts[i]): 167 | program = program[len(prompts[i]) :] 168 | tmp_lst.append(processing_func(program)) 169 | programs[i] = tmp_lst 170 | 171 | output = [] 172 | for i in tqdm(programs, leave=True): 173 | # run the outputted program as it is 174 | for program in programs[i]: 175 | prog1 = program 176 | acc1 = get_successful_tests_fast(program=prog1, tests=tests[i]) 177 | output.append(json.dumps({"id": i, "response": prog1, "accuracy": acc1})) 178 | 179 | if (i + 1) % 25 == 0: 180 | append_inference( 181 | dataset_name=f"{dataset_name}/processed", 182 | model_name=model_name, 183 | lst=output, 184 | ) 185 | output = [] 186 | if len(output) > 0: 187 | append_inference( 188 | dataset_name=f"{dataset_name}/processed", model_name=model_name, lst=output 189 | ) 190 | 191 | 192 | # ------------------------------------------------------------- 193 | # The slow but accurate version 194 | # ------------------------------------------------------------- 195 | return_var = multiprocessing.Value(c_int, 0) 196 | 197 | 198 | def run_single_test_against_program_helper(func: str, test: str) -> int: 199 | """Return 1 if func finish running, 0 otherwise""" 200 | execution_context = {} 201 | execution_context.update({"__builtins__": __builtins__}) 202 | try: 203 | exec(func, execution_context) 204 | exec(test, execution_context) 205 | return_var.value = 1 206 | return 1 207 | except Exception as e: 208 | return_var.value = 0 209 | return 0 210 | 211 | 212 | # very unstable, seems to work, seeing there are no uses will be deprecated for now. 213 | def get_successful_tests_slow( 214 | program: str, tests: List[str], max_execution_time: float = 1.0 215 | ) -> List[int]: 216 | """Run a program against a list of tests, if the program exited successfully then we consider 217 | the test to be passed. Note that you SHOULD ONLY RUN THIS FUNCTION IN A VIRTUAL ENVIRONMENT 218 | as we do not gurantee the safety of the program provided. 
219 | 220 | Parameter: 221 | program: a string representation of the python program you want to run 222 | tests: a list of assert statements which are considered to be the test cases 223 | max_execution_time: the number of second each individual test can run before 224 | it is considered failed and terminated 225 | 226 | Return: 227 | a list of 0/1 indicating passed or not""" 228 | test_ct = len(tests) 229 | if test_ct == 0: 230 | return [] 231 | if not should_execute(program=program, tests=tests): 232 | return [0] * len(tests) 233 | 234 | reliability_guard() 235 | result = [] 236 | for test in tests: 237 | return_var.value = 0 238 | p = multiprocessing.Process( 239 | target=run_single_test_against_program_helper, args=(program, test) 240 | ) 241 | p.start() 242 | p.join(timeout=max_execution_time) 243 | if p.is_alive(): 244 | p.kill() 245 | result.append(return_var.value) 246 | 247 | partial_undo_reliability_guard() 248 | return result 249 | 250 | 251 | # ------------------------------------------------------------- 252 | # The fast but not accurate version 253 | # ------------------------------------------------------------- 254 | return_var_2 = multiprocessing.Value(c_int64, 0) 255 | 256 | 257 | def run_tests_against_program_helper_2(func: str, tests: List[str]) -> float: 258 | """Return 1 if func finish running, 0 otherwise""" 259 | execution_context = {} 260 | execution_context.update({"__builtins__": __builtins__}) 261 | try: 262 | # try running the function declaration first 263 | exec(func, execution_context) 264 | # tmp_print("Function Executed Correctly") 265 | except Exception as e: 266 | # if this fails then tests will not work 267 | # tmp_print(e) 268 | return_var_2.value = 0 269 | return 0 270 | for idx, test in enumerate(tests): 271 | try: 272 | # try the tests individually 273 | exec(test, execution_context) 274 | return_var_2.value += 2**idx 275 | # tmp_print("a test passed") 276 | except Exception as e: 277 | # tmp_print(e) 278 | pass 279 | # tmp_print(f"Return value: {return_var.value}") 280 | # tmp_print("---------------------------") 281 | return return_var_2.value 282 | 283 | 284 | # very unstable, seems to work, seeing there are no uses will be deprecated for now. 285 | def get_successful_tests_fast( 286 | program: str, tests: List[str], max_execution_time: float = 1.0 287 | ) -> List[int]: 288 | """Run a program against a list of tests, if the program exited successfully then we consider 289 | the test to be passed. Note that you SHOULD ONLY RUN THIS FUNCTION IN A VIRTUAL ENVIRONMENT 290 | as we do not gurantee the safety of the program provided. 
291 | 292 | Parameter: 293 | program: a string representation of the python program you want to run 294 | tests: a list of assert statements which are considered to be the test cases 295 | max_execution_time: the number of second each individual test can run before 296 | it is considered failed and terminated 297 | 298 | Return: 299 | a list of 0/1 indicating passed or not""" 300 | test_ct = len(tests) 301 | if test_ct == 0: 302 | return [] 303 | if not should_execute(program=program, tests=tests): 304 | return [0] * len(tests) 305 | 306 | reliability_guard() 307 | return_var_2.value = 0 308 | p = multiprocessing.Process( 309 | target=run_tests_against_program_helper_2, args=(program, tests) 310 | ) 311 | p.start() 312 | p.join(timeout=max_execution_time) 313 | if p.is_alive(): 314 | p.kill() 315 | 316 | partial_undo_reliability_guard() 317 | 318 | num = int(return_var_2.value) 319 | return [(num >> i) & 1 for i in range(len(tests))] 320 | 321 | 322 | # ------------------------------------------------------------- 323 | # Utility 324 | # ------------------------------------------------------------- 325 | 326 | 327 | def should_execute(program: str, tests: List[str]) -> bool: 328 | """Determine if we should try to execute this program at all for safety 329 | reasons.""" 330 | dangerous_commands = [ 331 | "threading", 332 | "multiprocess", 333 | "multiprocessing", 334 | "import os", 335 | "from os", 336 | "shutil", 337 | "import torch", 338 | "from torch", 339 | "import sklearn", 340 | "from sklearn", 341 | ] 342 | for comm in dangerous_commands: 343 | if comm in program: 344 | return False # assume the program fails 345 | return True 346 | 347 | 348 | # ------------------------------------------------------------- 349 | # For safety handling 350 | # ------------------------------------------------------------- 351 | 352 | 353 | def partial_undo_reliability_guard(): 354 | """Undo the chmod, fchmod, print and open operation""" 355 | import builtins 356 | 357 | os.chmod = tmp_chmod 358 | os.fchmod = tmp_fchmod 359 | os.chdir = tmp_chdir 360 | os.unlink = tmp_unlink 361 | os.rmdir = tmp_rmdir 362 | # shutil.rmtree = tmp_rmtree 363 | # builtins.open = tmp_open 364 | builtins.print = tmp_print 365 | 366 | # restore working directory 367 | os.chdir(cwd) 368 | # shutil.rmtree(cache_wd) 369 | shutil.rmtree = tmp_rm_tree 370 | 371 | 372 | def reliability_guard(maximum_memory_bytes: Optional[int] = None): 373 | """ 374 | This function is copied from https://github.com/openai/human-eval/blob/master/human_eval/execution.py. 375 | It disables various destructive functions and prevents the generated code 376 | from interfering with the test (e.g. fork bomb, killing other processes, 377 | removing filesystem files, etc.) 378 | 379 | WARNING 380 | This function is NOT a security sandbox. Untrusted code, including, model- 381 | generated code, should not be blindly executed outside of one. See the 382 | Codex paper for more information about OpenAI's code sandbox, and proceed 383 | with caution. 
384 | """ 385 | import faulthandler 386 | import platform 387 | 388 | if maximum_memory_bytes is not None: 389 | import resource 390 | 391 | resource.setrlimit( 392 | resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes) 393 | ) 394 | resource.setrlimit( 395 | resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes) 396 | ) 397 | if not platform.uname().system == "Darwin": 398 | resource.setrlimit( 399 | resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes) 400 | ) 401 | 402 | faulthandler.disable() 403 | 404 | import builtins 405 | 406 | builtins.exit = None 407 | builtins.quit = None 408 | # builtins.open = None 409 | builtins.print = lambda *args, **kwargs: None 410 | 411 | import os 412 | 413 | # we save the current working directory and restore them later 414 | os.makedirs(cache_wd, exist_ok=True) 415 | os.chdir(cache_wd) 416 | 417 | # os.environ["OMP_NUM_THREADS"] = "1" 418 | # os.kill = None 419 | os.system = None 420 | os.putenv = None 421 | os.remove = None 422 | os.removedirs = None 423 | os.rmdir = None 424 | os.fchdir = None 425 | os.setuid = None 426 | # os.fork = None 427 | os.forkpty = None 428 | os.killpg = None 429 | os.rename = None 430 | os.renames = None 431 | os.truncate = None 432 | os.replace = None 433 | os.unlink = None 434 | os.fchmod = None 435 | os.fchown = None 436 | os.chmod = None 437 | os.chown = None 438 | os.chroot = None 439 | os.fchdir = None 440 | os.lchflags = None 441 | os.lchmod = None 442 | os.lchown = None 443 | os.getcwd = None 444 | os.chdir = None 445 | 446 | import shutil 447 | 448 | shutil.rmtree = None 449 | shutil.move = None 450 | shutil.chown = None 451 | 452 | import subprocess 453 | 454 | subprocess.Popen = None # type: ignore 455 | 456 | # __builtins__['help'] = None 457 | 458 | import sys 459 | 460 | sys.modules["ipdb"] = None 461 | sys.modules["joblib"] = None 462 | sys.modules["resource"] = None 463 | sys.modules["psutil"] = None 464 | sys.modules["tkinter"] = None 465 | 466 | 467 | if __name__ == "__main__": 468 | # for testing purpose 469 | program = "a = 1" 470 | bad_test = "assert False" 471 | good_test = "assert True" 472 | time_out_test = f""" 473 | for i in range(9999999999999999999): 474 | for k in range(99999999999999999999): 475 | print("hello world") 476 | """ 477 | test_case_status = get_successful_tests_fast( 478 | program=program, 479 | tests=[ 480 | bad_test, 481 | good_test, 482 | bad_test, 483 | good_test, 484 | good_test, 485 | time_out_test, 486 | good_test, 487 | ], 488 | ) 489 | print(test_case_status) 490 | test_case_status = get_successful_tests_fast( 491 | program=program, 492 | tests=[ 493 | bad_test, 494 | bad_test, 495 | time_out_test, 496 | time_out_test, 497 | time_out_test, 498 | time_out_test, 499 | ], 500 | ) 501 | print(test_case_status) 502 | -------------------------------------------------------------------------------- /data/inference/GetDatasets.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import datasets 4 | 5 | 6 | def get_dataset_from_huggingface( 7 | hugging_face_dataset_path: str, 8 | hugging_face_dataset_name: Optional[str] = None, 9 | ) -> datasets.dataset_dict.DatasetDict: 10 | """Get a hugging face dataset object. 
We will store this file locally so that we do not have to redownload it. 11 | 12 | Parameter: 13 | hugging_face_dataset_path: a string that you copy from huggingface, ex: "gsm8k", "TIGER-Lab/MathInstruct" 14 | hugging_face_dataset_name: huggingface dataset name, for GSM8K we have "main" and "socratic" 15 | """ 16 | data = datasets.load_dataset(hugging_face_dataset_path, hugging_face_dataset_name) 17 | return data 18 | -------------------------------------------------------------------------------- /data/inference/InferenceModels.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | import torch 4 | 5 | from inference.Constants import GENERIC_SAMPLING_PARAMETERS 6 | from inference.native_inference import native_inference 7 | from inference.Utility import (get_huggingface_model_path, 8 | get_suggested_inference_batch_size) 9 | from inference.vllm_inference import vllm_inference 10 | 11 | 12 | def inference( 13 | model_name: str, 14 | model_type: str, 15 | dataset_name: str, 16 | sampling_method: str, 17 | prompts: List[str] = None, 18 | input_token_ids: List = None, 19 | chunk_size: int = -1, 20 | inference_method: str = "default", 21 | additional_sampling_params: Dict = {}, 22 | custom_model_path: Optional[str] = None, 23 | seperate_tokenizer_path: Optional[str] = None, 24 | ) -> None: 25 | """this function runs inference with a model on the given dataset, using vLLM by default. 26 | 27 | | Parameter: 28 | | model_name: the model name, ex: f"{MAMMOTH}_cot" 29 | | model_type: which specific model do you want to use, ex: "7b" 30 | | dataset_name: the name of the dataset, ex: f"{GSM8K}" 31 | | sampling_method: should be a sampling method available in GENERIC_SAMPLING_PARAMETERS 32 | in Constants.py. Currently supported: {greedy, best_of_n_diverse_beam_search, 33 | best_of_n_beam_search, best_of_n_top_p_sampling, best_of_n_top_k_sampling}.
34 | | prompts: the prompts that will be inferenced on, note that we will apply template on top of this prompt 35 | | local_model_name: a string indicating the model, this string will be used as a filename 36 | | chunk_size: the program will store every chunk_size inferenced prompts, if received <= 0, then will use the 37 | recommended chunk size stored in ModelConfigs 38 | | vllm: whether to use vllm or not 39 | | inference_method: currently support "default", "native", "vllm" and "accelerate" 40 | | additional_sampling_params: sampling parameter you would like to pass in in addition to the pre-configured one 41 | | custom_model_path: the path to the model if you have fine-tuned it 42 | | seperate_tokenizer_path: if there is a different path for tokenizer 43 | """ 44 | 45 | sampling_params = GENERIC_SAMPLING_PARAMETERS[sampling_method] 46 | sampling_params.update(additional_sampling_params) 47 | 48 | if custom_model_path is None: 49 | huggingface_model_path = get_huggingface_model_path( 50 | model_name=model_name, model_size=model_type 51 | ) 52 | local_model_name = f"{model_name}_{model_type}_{sampling_method}" 53 | trust_remote_code = True 54 | 55 | if chunk_size <= 0: 56 | chunk_size = get_suggested_inference_batch_size(model_size=model_type) 57 | else: 58 | # we are using custom model 59 | huggingface_model_path = custom_model_path 60 | local_model_name = f"{model_name}_{sampling_method}" 61 | trust_remote_code = False 62 | 63 | chunk_size = max(1, chunk_size) # cannot be less than 1 64 | 65 | if model_name == "deepseek_coder" and model_type == "6.7b": 66 | additional_vllm_param = {"max_model_len": 40000} 67 | else: 68 | additional_vllm_param = None 69 | 70 | if "best_of_n" in sampling_method: 71 | # we have to reduce the chunk size accordingly to prevent errors 72 | n = max(sampling_params.get("n", 0), sampling_params.get("num_beams", 0)) 73 | chunk_size = max(1, chunk_size // n) 74 | 75 | num_gpus = torch.cuda.device_count() 76 | chunk_size = chunk_size * num_gpus 77 | 78 | if "diverse_beam_search" in sampling_method: 79 | # vLLM does not support diverse beam search 80 | assert inference_method in {"default", "native"} 81 | inference_method = "native" 82 | 83 | if inference_method in {"default", "vllm"}: 84 | vllm_inference( 85 | huggingface_model_path=huggingface_model_path, 86 | dataset_name=dataset_name, 87 | prompts=prompts, 88 | input_token_ids=input_token_ids, 89 | local_model_name=local_model_name, 90 | chunk_size=max(1, chunk_size), 91 | sampling_params=sampling_params, 92 | trust_remote_code=trust_remote_code, 93 | seperate_tokenizer_path=seperate_tokenizer_path, 94 | additional_engine_param=additional_vllm_param, 95 | ) 96 | else: 97 | # have not write the support yet 98 | assert input_token_ids is None 99 | native_inference( 100 | huggingface_model_path=huggingface_model_path, 101 | dataset_name=dataset_name, 102 | prompts=prompts, 103 | local_model_name=local_model_name, 104 | chunk_size=max( 105 | 1, chunk_size 106 | ), # native inference is slow so we reduce chunk size 107 | sampling_params=sampling_params, 108 | trust_remote_code=trust_remote_code, 109 | seperate_tokenizer_path=seperate_tokenizer_path, 110 | ) 111 | -------------------------------------------------------------------------------- /data/inference/Utility.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Dict, Iterable, List, Tuple 4 | 5 | from inference.Constants import MODEL_PATH 6 | 7 | 8 | def 
append_inference(dataset_name: str, model_name: str, lst: List[str]) -> None: 9 | """append to output file""" 10 | file_path = f"inferenced output/{dataset_name}/{model_name}.jsonl" 11 | dir_path = f"inferenced output/{dataset_name}/" 12 | os.makedirs(dir_path, exist_ok=True) 13 | with open(file_path, "a") as f: 14 | for i in lst: 15 | f.write(i + "\n") 16 | 17 | 18 | def get_saved_inference_index(dataset_name: str, model_name: str) -> int: 19 | """Check if previous inference has been done with regard to dataset and model. If so, then it will return the next id which will need to be inferenced 20 | If not, it will create the inference file""" 21 | file_path = f"inferenced output/{dataset_name}/{model_name}.jsonl" 22 | try: 23 | if os.path.exists(file_path): 24 | with open(file_path, "r") as f: 25 | lst = f.readlines() 26 | if len(lst) > 0: 27 | dic = json.loads(lst[-1]) 28 | return dic["id"] + 1 29 | else: 30 | return 0 31 | else: 32 | dir_path = f"inferenced output/{dataset_name}/" 33 | os.makedirs(dir_path, exist_ok=True) 34 | open(file_path, "a").close() 35 | return 0 36 | except: 37 | print(dataset_name) 38 | print(model_name) 39 | raise Exception(f"Failed to read {file_path}") 40 | 41 | 42 | def load_saved_inference(dataset_name: str, model_name: str) -> Dict[int, List[str]]: 43 | """get the saved inference, output is a dictionary where the key is the index and the value is the response""" 44 | file_path = f"inferenced output/{dataset_name}/{model_name}.jsonl" 45 | if not os.path.exists(file_path): 46 | raise Exception(f"No saved inference between {dataset_name} + {model_name}") 47 | 48 | with open(file_path, "r") as f: 49 | lst = f.readlines() 50 | lst = [json.loads(i) for i in lst] 51 | out = {} 52 | for i in lst: 53 | if i["id"] in out: 54 | out[i["id"]].append(i["response"]) 55 | else: 56 | out[i["id"]] = [i["response"]] 57 | return out 58 | 59 | 60 | def load_processed_inference( 61 | dataset_name: str, model_name: str 62 | ) -> Dict[int, List[Tuple[str, float]]]: 63 | """get the processed inference with information such that output is a dictionary where: 64 | the key is the index and the value is a list of tuple in the following form: (code in string, accuracy) 65 | """ 66 | file_path = f"inferenced output/{dataset_name}/processed/{model_name}.jsonl" 67 | if not os.path.exists(file_path): 68 | raise Exception(f"No processed inference between {dataset_name} + {model_name}") 69 | 70 | with open(file_path, "r") as f: 71 | lst = f.readlines() 72 | 73 | lst = [json.loads(i) for i in lst] 74 | out = {} 75 | for i in lst: 76 | if i["id"] in out: 77 | out[i["id"]].append((i["response"], i["accuracy"])) 78 | else: 79 | out[i["id"]] = [(i["response"], i["accuracy"])] 80 | return out 81 | 82 | 83 | def print_inferenced_output( 84 | dataset_name: str, model_name: str, indices: int | Iterable[int] = range(10) 85 | ) -> None: 86 | """Print the inferenced output to the terminal, 87 | 88 | Parameter: 89 | dataset_name 90 | model_name 91 | indices: either an integer or a list of integers which you would like to print 92 | """ 93 | 94 | inferences = load_saved_inference(dataset_name, model_name) 95 | if type(indices) == int: 96 | for sentence in inferences[indices]: 97 | print(sentence) 98 | return 99 | 100 | for i in indices: 101 | for sentence in inferences[i]: 102 | print(f"Index {i}:") 103 | print(sentence) 104 | 105 | 106 | def get_huggingface_model_path(model_name: str, model_size: str) -> str: 107 | """Get the huggingface model path 108 | 109 | Parameter: 110 | model_name: a string such as 
"qwen_coder_2.5" 111 | model_size: a string such as "7b" 112 | """ 113 | 114 | if model_name not in MODEL_PATH: 115 | raise Exception(f"{model_name} not in MODEL Constants") 116 | if model_size not in MODEL_PATH[model_name]: 117 | raise Exception( 118 | f"{model_size} not found for {model_name}, available ones are: {list(MODEL_PATH[model_name].keys())}" 119 | ) 120 | return MODEL_PATH[model_name][model_size] 121 | 122 | 123 | def get_suggested_inference_batch_size(model_size: str | float) -> int: 124 | """Get the suggested inference batchsize 125 | 126 | Parameter: 127 | model_size: an float such as 7 representing 7B parameters or a string such as '7b' 128 | """ 129 | if type(model_size) == str: 130 | model_size = float(model_size[:-1]) 131 | if model_size <= 10: 132 | return 64 133 | elif model_size <= 40: 134 | return 16 135 | elif model_size <= 80: 136 | return 4 137 | else: 138 | return 2 139 | -------------------------------------------------------------------------------- /data/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/AceCoder/e37ef2aea2e3710be6b06e9a91dd92c98321df96/data/inference/__init__.py -------------------------------------------------------------------------------- /data/inference/native_inference.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Callable, Dict, List, Optional 3 | 4 | import torch 5 | from tqdm import tqdm 6 | from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig 7 | 8 | from inference.Utility import append_inference, get_saved_inference_index 9 | from utility.utility import MyTimer 10 | 11 | 12 | def native_inference( 13 | huggingface_model_path: str, 14 | dataset_name: str, 15 | prompts: List[str], 16 | local_model_name: str, 17 | chunk_size: int = 20, 18 | sampling_params: Dict = None, 19 | generation_config: GenerationConfig = None, 20 | trust_remote_code: bool = False, 21 | seperate_tokenizer_path: Optional[str] = None, 22 | ) -> None: 23 | """Use the native huggingface generate function to make inferences. Mostly exist because vLLM does not support 24 | diverse beam search. 25 | 26 | | Parameter: 27 | | huggingface_model_path: the model path copied from Hugging Face, ex: "deepseek-ai/deepseek-math-7b-base" 28 | | dataset_name: a string representing the dataset, this string will be part of a filename 29 | | prompts: a list of string, each is the prompt 30 | | local_model_name: a string indicating the model, this string will be used as a filename 31 | | post_process_func: a function that will be used to process the output answer 32 | | chunk_size: the program will store every chunk_size inferenced prompts, the default 200 works well 33 | with one shot model inference for 7B models. You should decrease this number if you are inferencing 34 | on larger models / doing best of n inferences. 
35 | | sampling_params: the sampling parameter passed for inferencing model, if given none then we will use 36 | {"max_tokens"=1024} 37 | | generation_config: a generation config object 38 | | trust_remote_code: the same as the one defined in huggingface transformers 39 | """ 40 | starting_idx = get_saved_inference_index(dataset_name, local_model_name) 41 | if starting_idx >= len(prompts): 42 | print( 43 | f"Inference for {dataset_name} using {local_model_name} is already complete" 44 | ) 45 | return 46 | print( 47 | f"starting inferencing for {dataset_name} using {local_model_name} with starting index {starting_idx}" 48 | ) 49 | timer = MyTimer() 50 | device = torch.device("cuda") 51 | if sampling_params is None: 52 | sampling_params = {"max_new_tokens": 1024} 53 | else: 54 | if "max_tokens" in sampling_params: 55 | tmp_val = sampling_params.pop("max_tokens") 56 | sampling_params["max_new_tokens"] = tmp_val 57 | elif "max_new_tokens" not in sampling_params: 58 | # default max new tokens 59 | sampling_params["max_new_tokens"] = 1024 60 | 61 | if "n" in sampling_params: 62 | # we need to replace the name for best_of_n sampling 63 | tmp_val = sampling_params.pop("n") 64 | sampling_params["num_beams"] = tmp_val 65 | sampling_params["num_return_sequences"] = tmp_val 66 | 67 | prompts = prompts[starting_idx:] 68 | model = AutoModelForCausalLM.from_pretrained( 69 | huggingface_model_path, 70 | torch_dtype=torch.bfloat16, 71 | device_map="auto", 72 | trust_remote_code=trust_remote_code, 73 | ) 74 | if generation_config is not None: 75 | model.generation_config = generation_config 76 | 77 | if seperate_tokenizer_path is None: 78 | seperate_tokenizer_path = huggingface_model_path 79 | 80 | tokenizer = AutoTokenizer.from_pretrained( 81 | seperate_tokenizer_path, trust_remote_code=trust_remote_code 82 | ) 83 | 84 | timer.print_runtime(" loading models and datasets") 85 | 86 | output_lst = [] 87 | with torch.no_grad(): 88 | for i, prompt in enumerate(tqdm(prompts)): 89 | encoding = tokenizer(prompt, return_tensors="pt") 90 | outputs = model.generate(**encoding.to(device), **sampling_params) 91 | 92 | response_strs = [ 93 | tokenizer.decode(output, skip_special_tokens=True) for output in outputs 94 | ] 95 | response_strs = [ 96 | i[(len(prompt)) :] if i.startswith(prompt) else i for i in response_strs 97 | ] 98 | output_lst += [ 99 | json.dumps({"id": i + starting_idx, "response": strr}) 100 | for strr in response_strs 101 | ] 102 | 103 | timer.print_runtime(f" inferenced {i} entries", reset_timer=False) 104 | if (i + 1) % chunk_size == 0: 105 | append_inference(dataset_name, local_model_name, output_lst) 106 | output_lst = [] 107 | if len(output_lst) > 0: 108 | append_inference(dataset_name, local_model_name, output_lst) 109 | -------------------------------------------------------------------------------- /data/inference/post_process_functions.py: -------------------------------------------------------------------------------- 1 | def eval_post_process(input_str: str) -> str: 2 | markdowns = ["python", "markdown\n```python", "markdown\n```", "markdown"] 3 | for markdown in markdowns: 4 | if input_str.startswith(markdown): 5 | input_str = input_str[len(markdown) :] 6 | break 7 | end_tokens = ["```", "import unittest", "from unittest", "if __name__", "assert"] 8 | indices = [input_str.find(i) for i in end_tokens] 9 | indices = [i for i in indices if i > 0] 10 | if len(indices) > 0: 11 | return input_str[: min(indices)] 12 | return input_str 13 | 14 | 15 | def deepseek_coder_post_process(input_str: str) 
-> str: 16 | left_idx = input_str.find("```python") 17 | if left_idx < 0: 18 | return deepseek_coder_post_process_2(input_str) 19 | out = input_str[left_idx + 9 :] 20 | right_idx = out.find("```") 21 | if right_idx < 0: 22 | return deepseek_coder_post_process_2(input_str) 23 | out = out[:right_idx] 24 | out = out.strip() 25 | return out 26 | 27 | 28 | def deepseek_coder_post_process_2(input_str: str) -> str: 29 | right_idx = input_str.find("[DONE]") 30 | if right_idx < 0: 31 | return deepseek_coder_post_process_3(input_str) 32 | out = input_str[:right_idx] 33 | out = out.strip() 34 | return out 35 | 36 | 37 | def deepseek_coder_post_process_3(input_str: str) -> str: 38 | right_idx = input_str.find("[END]") 39 | if right_idx < 0: 40 | return input_str 41 | out = input_str[:right_idx] 42 | out = out.strip() 43 | return out 44 | 45 | 46 | def codellama_post_process(input_str: str) -> str: 47 | return '\n """ ' + input_str + "\n return result" 48 | 49 | 50 | def codellama_instruct_post_process(input_str: str) -> str: 51 | left_idx = input_str.find("[PYTHON]") 52 | if left_idx < 0: 53 | return input_str 54 | out = input_str[left_idx + 8 :] 55 | right_idx = out.find("[/PYTHON]") 56 | if right_idx < 0: 57 | return input_str 58 | out = out[:right_idx] 59 | out = out.strip() 60 | return out 61 | 62 | 63 | def starcoder2_post_process(input_str: str) -> str: 64 | out = input_str[3:] 65 | left_idx = out.find('"""') 66 | if left_idx < 0: 67 | return input_str 68 | out = input_str[left_idx + 6 :] 69 | right_idx = out.find('"""') 70 | if right_idx < 0: 71 | return input_str 72 | out = out[:right_idx] 73 | out = out.strip() 74 | return out 75 | -------------------------------------------------------------------------------- /data/inference/vllm_inference.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Callable, Dict, List, Optional 3 | 4 | import torch 5 | from tqdm import tqdm 6 | from vllm import LLM, SamplingParams 7 | 8 | from inference.Utility import append_inference, get_saved_inference_index 9 | from utility.utility import MyTimer, chunking 10 | 11 | 12 | def vllm_inference( 13 | huggingface_model_path: str, 14 | dataset_name: str, 15 | local_model_name: str, 16 | prompts: List[str] = None, 17 | input_token_ids: List = None, 18 | chunk_size: int = 200, 19 | sampling_params: Dict = None, 20 | trust_remote_code: bool = False, 21 | seperate_tokenizer_path: Optional[str] = None, 22 | additional_engine_param: Optional[Dict] = None, 23 | ) -> None: 24 | """Make inference on provided prompts using vllm, will use chunking and save inferenced answer locally. 25 | 26 | | Parameter: 27 | | huggingface_model_path: the model path copied from Hugging Face, ex: "deepseek-ai/deepseek-math-7b-base" 28 | | dataset_name: a string representing the dataset, this string will be part of a filename 29 | | prompts: a list of string, each is the prompt 30 | | local_model_name: a string indicating the model, this string will be used as a filename 31 | | input_token_ids: alternatively, you can pass in a List of input_token_ids that have already 32 | been tokenized by a tokenizer. 33 | | post_process_func: a function that will be used to process the output answer 34 | | chunk_size: the program will store every chunk_size inferenced prompts, the default 200 works well 35 | with one shot model inference for 7B models. You should decrease this number if you are inferencing 36 | on larger models / doing best of n inferences. 
37 | | sampling_params: the sampling parameter passed for inferencing model, if given none then we will use 38 | {"max_tokens"=1024} 39 | | trust_remote_code: the same as the one defined in huggingface transformers 40 | | seperate_tokenizer_path: if you have a path that links to a tokenizer that you would like to use 41 | | additional_engine_param: additional parameter for the vLLM engine 42 | """ 43 | starting_idx = get_saved_inference_index(dataset_name, local_model_name) 44 | if prompts: 45 | input_length = len(prompts) 46 | elif input_token_ids: 47 | input_length = len(input_token_ids) 48 | else: 49 | raise Exception("Both prompts and input token ids are None") 50 | if starting_idx >= input_length: 51 | print( 52 | f"Inference for {dataset_name} using {local_model_name} is already complete" 53 | ) 54 | return 55 | print( 56 | f"starting vLLM inferencing for {dataset_name} using {local_model_name} with starting index {starting_idx}" 57 | ) 58 | timer = MyTimer() 59 | if prompts: 60 | prompts = prompts[starting_idx:] 61 | if input_token_ids: 62 | input_token_ids = input_token_ids[starting_idx:] 63 | if sampling_params is None: 64 | sampling_params = {"max_tokens": 1024} 65 | elif "max_new_tokens" in sampling_params: 66 | tmp_val = sampling_params.pop("max_new_tokens") 67 | sampling_params["max_tokens"] = tmp_val 68 | elif "max_tokens" not in sampling_params: 69 | # default max new tokens 70 | sampling_params["max_tokens"] = 1024 71 | 72 | sampling_params = SamplingParams(**sampling_params) 73 | 74 | if not additional_engine_param: 75 | llm = LLM( 76 | model=huggingface_model_path, 77 | tokenizer=seperate_tokenizer_path, 78 | tensor_parallel_size=torch.cuda.device_count(), 79 | trust_remote_code=trust_remote_code, 80 | swap_space=24, 81 | ) 82 | else: 83 | llm = LLM( 84 | model=huggingface_model_path, 85 | tokenizer=seperate_tokenizer_path, 86 | tensor_parallel_size=torch.cuda.device_count(), 87 | trust_remote_code=trust_remote_code, 88 | swap_space=24, 89 | **additional_engine_param, 90 | ) 91 | timer.print_runtime(f" getting model - {huggingface_model_path}") 92 | 93 | if prompts: 94 | text_prompt = True 95 | inputs = chunking(prompts, chunk_size) 96 | elif input_token_ids: 97 | text_prompt = False 98 | inputs = chunking(input_token_ids, chunk_size) 99 | 100 | for i, prompts in enumerate(tqdm(inputs, leave=True)): 101 | if text_prompt: 102 | outputs = llm.generate(prompts=prompts, sampling_params=sampling_params) 103 | else: 104 | outputs = llm.generate( 105 | prompt_token_ids=prompts, sampling_params=sampling_params 106 | ) 107 | answers = [] 108 | for idx, output in enumerate(outputs): 109 | for answer in output.outputs: 110 | answers.append( 111 | json.dumps( 112 | { 113 | "id": i * chunk_size + idx + starting_idx, 114 | "response": answer.text, 115 | } 116 | ) 117 | ) 118 | append_inference(dataset_name, local_model_name, answers) 119 | -------------------------------------------------------------------------------- /data/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | # you should run setup.sh instead of this one, setup.sh will call this package 4 | setup( 5 | name="acecoder_data", 6 | version="1.0.0", 7 | description="", 8 | author="Wyett (Huaye) Zeng", 9 | author_email="wyettzeng@gmail.com", 10 | packages=find_packages(), 11 | url="https://github.com/TIGER-AI-Lab/AceCoder", # github 12 | install_requires=[ 13 | # comment the following if you have CUDA 11.8 14 | "torch", 15 | "vllm", 
16 | "xformers", 17 | # Do not comment any of these: 18 | "accelerate", 19 | "datasets", 20 | "numpy", 21 | "fire", 22 | "tqdm", 23 | "transformers", 24 | "flash_attn", 25 | "tqdm", 26 | "datasets", 27 | "matplotlib", 28 | "seaborn", 29 | "rewardbench", 30 | "openpyxl", 31 | "scikit-learn", 32 | ], 33 | ) 34 | -------------------------------------------------------------------------------- /data/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Create conda environment 4 | conda create -n acecoder_data python=3.11 5 | conda init 6 | conda activate acecoder_data 7 | 8 | # uncomment the following if you have CUDA 11.8 9 | # export VLLM_VERSION=0.2.6 10 | # export PYTHON_VERSION=311 11 | # pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl 12 | 13 | # pip uninstall torch -y 14 | # pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cu118 15 | 16 | # pip uninstall xformers -y 17 | # pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118 18 | 19 | # install packages 20 | pip install torch # need to install this first for flash attn later 21 | pip install -e . 22 | 23 | ## Install easy-openai by Jiang Dong Fu 24 | pip install git+https://github.com/jdf-prog/easy-openai 25 | -------------------------------------------------------------------------------- /data/training_dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/AceCoder/e37ef2aea2e3710be6b06e9a91dd92c98321df96/data/training_dataset/__init__.py -------------------------------------------------------------------------------- /data/training_dataset/bigcode_python_fns/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/AceCoder/e37ef2aea2e3710be6b06e9a91dd92c98321df96/data/training_dataset/bigcode_python_fns/__init__.py -------------------------------------------------------------------------------- /data/training_dataset/bigcode_python_fns/dataset.py: -------------------------------------------------------------------------------- 1 | from inference.GetDatasets import get_dataset_from_huggingface 2 | 3 | 4 | def get_bigcode_python_fn_dataset(): 5 | data = get_dataset_from_huggingface("bigcode/stack-dedup-python-fns")["train"] 6 | return data 7 | -------------------------------------------------------------------------------- /data/training_dataset/bigcode_python_fns/generate_test_cases.py: -------------------------------------------------------------------------------- 1 | from fire import Fire 2 | from tqdm import tqdm 3 | 4 | from training_dataset.create_test_case_and_prompt import \ 5 | create_test_cases_using_gpt 6 | from training_dataset.util import parse_incomplete_json 7 | from utility.utility import append_jsonl, chunking, load_jsonl 8 | 9 | MAX_CHUNK_SIZE = 20 10 | 11 | 12 | def generate_bigcode_python_fns_test_case(ct: int = 50): 13 | """step 2 of the process, generate test cases using chat gpt""" 14 | # we check for last generated responses, so we do not waste openAI tokens 15 | jsonl_file_name = "training_dataset/bigcode_python_fns/data/v2.jsonl" 16 | start_idx = 0 17 | try: 18 | past_data = load_jsonl(jsonl_file_name) 19 | last_idx = past_data[-1]["id"] 20 | start_idx = last_idx + 1 21 | except: 22 | # we start at
0 as we never inferenced before 23 | pass 24 | 25 | if start_idx >= ct and ct > 0: 26 | return # we already finished inferencing 27 | 28 | # preparing the input 29 | data = load_jsonl("training_dataset/bigcode_python_fns/data/v1.jsonl") 30 | if ct > 0: 31 | data = data[:ct] 32 | data = data[start_idx:] 33 | data_chunks = chunking(data, MAX_CHUNK_SIZE) 34 | inferenced_ct = 0 35 | total_price = 0 36 | for chunk in tqdm(data_chunks, desc="Chunk Size"): 37 | programs = [i["program"] for i in chunk] 38 | responses, price = create_test_cases_using_gpt( 39 | programs=programs, 40 | use_cache=False, 41 | return_price=True, 42 | ) 43 | total_price += price 44 | tests = [] 45 | questions = [] 46 | for i in responses: 47 | obj = parse_incomplete_json(i) 48 | question = obj.get("question", "please ignore this question") 49 | test = obj.get("tests", ["assert False"]) 50 | questions.append(question) 51 | tests.append(test) 52 | for i in range(len(tests)): 53 | chunk[i]["tests"] = tests[i] 54 | chunk[i]["gpt_question"] = questions[i] 55 | 56 | append_jsonl(jsonl_file_name, chunk) 57 | inferenced_ct += len(chunk) 58 | print(f"Total Cost: {total_price}") 59 | 60 | 61 | if __name__ == "__main__": 62 | Fire(generate_bigcode_python_fns_test_case) 63 | -------------------------------------------------------------------------------- /data/training_dataset/bigcode_python_fns/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | from tqdm import tqdm 5 | 6 | from training_dataset.bigcode_python_fns.dataset import \ 7 | get_bigcode_python_fn_dataset 8 | from training_dataset.util import remove_print_statements_from_python_program 9 | from utility.utility import load_jsonl, save_jsonl 10 | 11 | 12 | def get_bigcode_python_fns_programs(use_cache: bool = True) -> List[str]: 13 | """Step 1 of the process, extract python programs and instructions from the dataset. 
We only keep programs in function or class form.""" 14 | file_name = "training_dataset/bigcode_python_fns/data/v1.jsonl" 15 | if os.path.exists(file_name) and use_cache: 16 | out = load_jsonl(file_name) 17 | return [i["program"] for i in out] 18 | os.makedirs("training_dataset/bigcode_python_fns/data/", exist_ok=True) 19 | data = get_bigcode_python_fn_dataset() 20 | out = [] 21 | idx = 0 22 | for i in tqdm(range(len(data))): 23 | program = data[i]["content"] 24 | if len(program) <= 100: 25 | # too short 26 | continue 27 | program = remove_print_statements_from_python_program(program) 28 | out.append({"id": idx, "program": program}) 29 | idx += 1 30 | save_jsonl(file_name, out) 31 | return [i["program"] for i in out] 32 | 33 | 34 | if __name__ == "__main__": 35 | get_bigcode_python_fns_programs() 36 | -------------------------------------------------------------------------------- /data/training_dataset/consolidate_dataset.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from types import NoneType 3 | 4 | from inference.ComputeAccuracy import get_oracle_test_case_status 5 | from training_dataset.constants import DATASET_LST, MODEL_LST 6 | from utility.utility import MyTimer, load_jsonl, save_jsonl 7 | 8 | 9 | def recursive_clean(obj): 10 | """Clean an object and remove all non-utf-8 characters""" 11 | if type(obj) in [int, float, bool, NoneType]: 12 | return obj 13 | if type(obj) == str: 14 | return obj.encode("utf-8", errors="replace").decode("utf-8") 15 | elif type(obj) == list: 16 | return [recursive_clean(i) for i in obj] 17 | elif type(obj) == dict: 18 | return {recursive_clean(k): recursive_clean(v) for k, v in obj.items()} 19 | else: 20 | raise Exception(f"Unknown object type: {type(obj)}: {str(obj)[:300]}") 21 | 22 | 23 | def consolidate_processed_data( 24 | dataset_name: str, 25 | ct: int = -1, 26 | oracle_model_name: str = "qwen_coder_2.5_32b_greedy", 27 | min_oracle_model_pass_case_requirement: int = 3, 28 | ): 29 | """Get accuracy of the provided solution from the dataset""" 30 | timer = MyTimer() 31 | # we now append each entry with all the inferenced answers and solution accuracy 32 | data = recursive_clean(load_jsonl(f"training_dataset/{dataset_name}/data/v2.jsonl")) 33 | test_case_status = get_oracle_test_case_status( 34 | dataset_name=dataset_name, model_name=oracle_model_name 35 | ) 36 | if ct > 0: 37 | data = data[:ct] 38 | timer.print_runtime(f"{dataset_name} loading data and oracle answer") 39 | 40 | # adding the inferenced code as "inferences" to the dataset 41 | inferenced_program = {} 42 | for short_model_name, full_model_name in MODEL_LST.items(): 43 | acc_file_path = f"inferenced output/{dataset_name}/processed/default/{full_model_name}.jsonl" 44 | acc_lst = recursive_clean(load_jsonl(acc_file_path)) 45 | ram_dict = defaultdict(list) 46 | for acc_row in acc_lst: 47 | ram_dict[acc_row["id"]].append(acc_row) 48 | inferenced_program[short_model_name] = ram_dict 49 | timer.print_runtime(f"{dataset_name} loading {short_model_name}'s inference") 50 | 51 | out = [] 52 | for i, row in enumerate(data): 53 | tmp_lst = [] 54 | ground_truth_test_case = test_case_status[i] 55 | if sum(ground_truth_test_case) <= min_oracle_model_pass_case_requirement: 56 | continue # less than the minimum requirement, so we just skipped 57 | for model_name in MODEL_LST: 58 | for roww in inferenced_program[model_name][i]: 59 | model_test_case = roww["test_case_status"] 60 | new_pass_rate = [ 61 | a * b for a, b in 
zip(ground_truth_test_case, model_test_case) 62 | ] 63 | acc = sum(new_pass_rate) / len(new_pass_rate) 64 | # tuple in the form: (code, accuracy, model name) 65 | tmp_lst.append((roww["response"], acc, model_name)) 66 | tests = [ 67 | test 68 | for test_idx, test in enumerate(row["tests"]) 69 | if ground_truth_test_case[test_idx] == 1 70 | ] 71 | to_be_add_entry = { 72 | "id": f"{dataset_name}_{row['id']}", 73 | "prompt": row["gpt_question"], 74 | "tests": tests, 75 | "inferences": tmp_lst, 76 | } 77 | out.append(to_be_add_entry) 78 | timer.print_runtime(f"{dataset_name} creating dataset") 79 | 80 | jsonl_file_name = ( 81 | f"training_dataset/{dataset_name}/data/v3_{oracle_model_name}.jsonl" 82 | ) 83 | save_jsonl(jsonl_file_name, out) 84 | timer.print_runtime(f"{dataset_name} saving v3 data") 85 | 86 | 87 | def consolidate_processed_data_without_oracle( 88 | dataset_name: str, 89 | min_test_case_requirement: int = 3, 90 | ct: int = -1, 91 | ): 92 | """Get accuracy of the provided solution from the dataset""" 93 | timer = MyTimer() 94 | # we now append each entry with all the inferenced answers and solution accuracy 95 | data = recursive_clean(load_jsonl(f"training_dataset/{dataset_name}/data/v2.jsonl")) 96 | if ct > 0: 97 | data = data[:ct] 98 | timer.print_runtime(f"{dataset_name} loading data and oracle answer") 99 | 100 | # adding the inferenced code as "inferences" to the dataset 101 | inferenced_program = {} 102 | for short_model_name, full_model_name in MODEL_LST.items(): 103 | acc_file_path = f"inferenced output/{dataset_name}/processed/default/{full_model_name}.jsonl" 104 | acc_lst = recursive_clean(load_jsonl(acc_file_path)) 105 | ram_dict = defaultdict(list) 106 | for acc_row in acc_lst: 107 | ram_dict[acc_row["id"]].append(acc_row) 108 | inferenced_program[short_model_name] = ram_dict 109 | timer.print_runtime(f"{dataset_name} loading {short_model_name}'s inference") 110 | 111 | out = [] 112 | for i, row in enumerate(data): 113 | tmp_lst = [] 114 | for model_name in MODEL_LST: 115 | for roww in inferenced_program[model_name][i]: 116 | model_test_case = roww["test_case_status"] 117 | if len(model_test_case) < min_test_case_requirement: 118 | continue 119 | acc = sum(model_test_case) / len(model_test_case) 120 | # tuple in the form: (code, accuracy, model name) 121 | tmp_lst.append((roww["response"], acc, model_name)) 122 | tests = row["tests"] 123 | to_be_add_entry = { 124 | "id": f"{dataset_name}_{row['id']}", 125 | "prompt": row["gpt_question"], 126 | "tests": tests, 127 | "inferences": tmp_lst, 128 | } 129 | out.append(to_be_add_entry) 130 | timer.print_runtime(f"{dataset_name} creating dataset") 131 | 132 | jsonl_file_name = f"training_dataset/{dataset_name}/data/v3_no_oracle.jsonl" 133 | save_jsonl(jsonl_file_name, out) 134 | timer.print_runtime(f"{dataset_name} saving v3 data") 135 | 136 | 137 | if __name__ == "__main__": 138 | for dataset_name in DATASET_LST: 139 | consolidate_processed_data( 140 | dataset_name=dataset_name, oracle_model_name="qwen_coder_2.5_32b_greedy" 141 | ) 142 | -------------------------------------------------------------------------------- /data/training_dataset/constants.py: -------------------------------------------------------------------------------- 1 | MODEL_LST = { 2 | "qwen_coder_2.5": "qwen_coder_2.5_7b_best_of_n_top_p_sampling", 3 | } 4 | 5 | DATASET_LST = ["oss", "evol", "bigcode_python_fns"] 6 | -------------------------------------------------------------------------------- /data/training_dataset/create_test_case_and_prompt.py: 
-------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from easy_openai import openai_completions 4 | 5 | PROMPT_TEMPLATE_RAW = """You are the latest and best bot aimed at transforming some code snippet into a leetcode style question. You will be provided with a prompt for writing code, along with a reference program that answers the question. Please complete the following for me: 6 | 1. Come up with a leetcode style question which consists of a well-defined problem. The generated question should meet the following criteria: 7 | a. The question is clear and understandable, with enough details to describe what the input and output are. 8 | b. The question should be solvable by only implementing 1 function instead of multiple functions or a class. Therefore, please avoid questions which require complicated pipelines. 9 | c. The question itself should not require any access to external resource or database. 10 | d. Feel free to use part of the original question if necessary. Moreover, please do not ask for runtime and space complexity analysis or any test cases in your response. 11 | 2. Based on the modified question that you generated in part 1, you need to create around 20 test cases for this modified question. Each test case should be independent assert clauses. The parameters and expected output of each test case should all be constants, **without accessing any external resources**. 12 | 13 | Here is the original question: 14 | {instruction} 15 | 16 | Here is the reference program that answers the question: 17 | ```python 18 | {program} 19 | ``` 20 | 21 | Now give your modified question and generated test cases in the following json format: 22 | {"question": ..., "tests":["assert ...", "assert ..."]}. 23 | """ 24 | 25 | PROMPT_TEMPLATE_NO_INSTRUCTION = """You are the latest and best bot aimed at transforming some code snippet into a leetcode style question. You will be provided with a reference program. Please complete the following for me: 26 | 1. Come up with a leetcode style question which consists of a well-defined problem. The generated question should meet the following criteria: 27 | a. The question is clear and understandable, with enough details to describe what the input and output are. 28 | b. The question should be solvable by only implementing 1 function instead of multiple functions or a class. Therefore, please avoid questions which require complicated pipelines. 29 | c. The question itself should not require any access to external resource or database. 30 | d. Feel free to use part of the original question if necessary. Moreover, please do not ask for runtime and space complexity analysis or any test cases in your response. 31 | 2. Based on the modified question that you generated in part 1, you need to create around 20 test cases for this modified question. Each test case should be independent assert clauses. The parameters and expected output of each test case should all be constants, **without accessing any external resources**. 32 | 33 | Here is the reference program: 34 | ```python 35 | {program} 36 | ``` 37 | 38 | Now give your modified question and generated test cases in the following json format: 39 | {"question": ..., "tests":["assert ...", "assert ..."]}. 
40 | """ 41 | 42 | 43 | def create_test_cases_using_gpt( 44 | programs: List[str], 45 | instructions: Optional[List[str]] = None, 46 | use_cache: bool = True, 47 | return_price: bool = False, 48 | ) -> None: 49 | """Use this program to create tests cases for raw (badly formatted) questions and source code. Ex: Evol""" 50 | if instructions: 51 | chatmls = [ 52 | [ 53 | { 54 | "role": "system", 55 | "content": "You are an AI assistant that helps people with python coding tasks.", 56 | }, 57 | { 58 | "role": "user", 59 | "content": PROMPT_TEMPLATE_RAW.replace("{program}", prompt).replace( 60 | "{instruction}", instruction 61 | ), 62 | }, 63 | ] 64 | for prompt, instruction in zip(programs, instructions) 65 | ] 66 | else: 67 | chatmls = [ 68 | [ 69 | { 70 | "role": "system", 71 | "content": "You are an AI assistant that helps people with python coding tasks.", 72 | }, 73 | { 74 | "role": "user", 75 | "content": PROMPT_TEMPLATE_NO_INSTRUCTION.replace( 76 | "{program}", prompt 77 | ), 78 | }, 79 | ] 80 | for prompt in programs 81 | ] 82 | output = openai_completions( 83 | chatmls, 84 | model_name="gpt-4o-mini", 85 | use_cache=use_cache, 86 | return_json=True, 87 | temperature=0.7, 88 | max_tokens=8192, 89 | ) 90 | if return_price: 91 | total_price = sum(output["price_per_example"]) 92 | return output["completions"], total_price 93 | else: 94 | return output["completions"] 95 | -------------------------------------------------------------------------------- /data/training_dataset/evaluate_inferenced_code.py: -------------------------------------------------------------------------------- 1 | import fire 2 | 3 | from inference.EvaluateInferencedCode import process_one_model 4 | from inference.post_process_functions import eval_post_process 5 | from utility.utility import load_jsonl 6 | 7 | 8 | def codeblock_post_process(input_str: str) -> str: 9 | left_idx = input_str.find("```python") 10 | if left_idx < 0: 11 | return codeblock_post_process_2(input_str) 12 | out = input_str[left_idx + 9 :] 13 | right_idx = out.find("```") 14 | if right_idx < 0: 15 | return codeblock_post_process_2(input_str) 16 | out = out[:right_idx] 17 | out = out.strip() 18 | return out 19 | 20 | 21 | def codeblock_post_process_2(input_str: str) -> str: 22 | left_idx = input_str.find("```") 23 | if left_idx < 0: 24 | return input_str 25 | out = input_str[left_idx + 3 :] 26 | right_idx = out.find("```") 27 | if right_idx < 0: 28 | return input_str 29 | out = out[:right_idx] 30 | out = out.strip() 31 | return out 32 | 33 | 34 | def evaluate_inferenced_code( 35 | model_name: str, 36 | dataset_name: str, 37 | model_type: str, 38 | sampling_method: str = "best_of_n_top_p_sampling", 39 | ): 40 | print( 41 | f"Starting evaluation for {dataset_name} - {model_name} {model_type} {sampling_method}" 42 | ) 43 | "Get the accuracy of the inferenced program" 44 | data = load_jsonl(f"training_dataset/{dataset_name}/data/v2.jsonl") 45 | tests = [ 46 | (i["tests"] if i["tests"] is not None and len(i["tests"]) > 0 else []) 47 | for i in data 48 | ] 49 | if sampling_method == "greedy": 50 | fast_algo = True 51 | max_execution_time = 0.8 52 | else: 53 | fast_algo = True 54 | max_execution_time = 0.2 55 | process_one_model( 56 | model_name=f"{model_name}_{model_type}_{sampling_method}", 57 | dataset_name=dataset_name, 58 | tests=tests, 59 | max_execution_time=max_execution_time, 60 | processing_func=eval_post_process, 61 | fast_algo=fast_algo, 62 | ) 63 | 64 | 65 | if __name__ == "__main__": 66 | fire.Fire(evaluate_inferenced_code) 67 | 
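# --- Added illustrative sketch (not part of the original pipeline) ---
# codeblock_post_process above strips a markdown ```python fence from a model reply and
# returns only the code; codeblock_post_process_2 is the fallback for a generic ``` fence.
# The helper name below is hypothetical and exists only to demonstrate the expected behavior.
def _demo_codeblock_post_process() -> None:
    reply = "Sure, here is a solution:\n```python\ndef add(a, b):\n    return a + b\n```"
    assert codeblock_post_process(reply) == "def add(a, b):\n    return a + b"
    # A reply without a ```python fence but with plain ``` fences falls through to the generic parser.
    plain = "```\nx = 1\n```"
    assert codeblock_post_process(plain) == "x = 1"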
-------------------------------------------------------------------------------- /data/training_dataset/evaluate_inferenced_code.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | datasets=( 4 | "bigcode_python_fns" 5 | "oss" 6 | "evol" 7 | ) 8 | 9 | for dataset_name in "${datasets[@]}" 10 | do 11 | python training_dataset/evaluate_inferenced_code.py --model_name="qwen_coder_2.5" --model_type="32b" --dataset_name=$dataset_name --sampling_method="greedy" & 12 | python training_dataset/evaluate_inferenced_code.py --model_name="qwen_coder_2.5" --model_type="7b" --dataset_name=$dataset_name --sampling_method="best_of_n_top_p_sampling"& 13 | 14 | done 15 | 16 | wait # wait for all process to finish 17 | echo "Finish evaluating inferenced code" -------------------------------------------------------------------------------- /data/training_dataset/evol/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/AceCoder/e37ef2aea2e3710be6b06e9a91dd92c98321df96/data/training_dataset/evol/__init__.py -------------------------------------------------------------------------------- /data/training_dataset/evol/evol_dataset.py: -------------------------------------------------------------------------------- 1 | from inference.GetDatasets import get_dataset_from_huggingface 2 | 3 | 4 | # https://huggingface.co/datasets/ise-uiuc/Magicoder-Evol-Instruct-110K?row=0 5 | def get_evol_dataset(): 6 | data = get_dataset_from_huggingface("ise-uiuc/Magicoder-Evol-Instruct-110K")[ 7 | "train" 8 | ] 9 | return data 10 | -------------------------------------------------------------------------------- /data/training_dataset/evol/generate_test_cases.py: -------------------------------------------------------------------------------- 1 | from fire import Fire 2 | from tqdm import tqdm 3 | 4 | from training_dataset.create_test_case_and_prompt import \ 5 | create_test_cases_using_gpt 6 | from training_dataset.util import parse_incomplete_json 7 | from utility.utility import append_jsonl, chunking, load_jsonl 8 | 9 | MAX_CHUNK_SIZE = 50 10 | 11 | 12 | def generate_evol_test_case(ct: int = 50): 13 | """step 2 of the process, generate test cases using chat gpt""" 14 | # we check for last generated responses, so we do not waste openAI tokens 15 | jsonl_file_name = "training_dataset/evol/data/v2.jsonl" 16 | start_idx = 0 17 | try: 18 | past_data = load_jsonl(jsonl_file_name) 19 | last_idx = past_data[-1]["id"] 20 | start_idx = last_idx + 1 21 | except: 22 | # we start at 0 as we never inferenced before 23 | pass 24 | 25 | if start_idx >= ct and ct > 0: 26 | return # we already finished inferencing 27 | 28 | # preparing the input 29 | data = load_jsonl("training_dataset/evol/data/v1.jsonl") 30 | if ct > 0: 31 | data = data[:ct] 32 | data = data[start_idx:] 33 | data_chunks = chunking(data, MAX_CHUNK_SIZE) 34 | inferenced_ct = 0 35 | total_price = 0 36 | for chunk in tqdm(data_chunks): 37 | programs = [i["program"] for i in chunk] 38 | instructions = [i["instruction"] for i in chunk] 39 | responses, price = create_test_cases_using_gpt( 40 | programs=programs, 41 | instructions=instructions, 42 | use_cache=False, 43 | return_price=True, 44 | ) 45 | total_price += price 46 | tests = [] 47 | questions = [] 48 | for i in responses: 49 | obj = parse_incomplete_json(i) 50 | question = obj.get("question", "please ignore this question") 51 | test = obj.get("tests", ["assert False"]) 52 | 
questions.append(question) 53 | tests.append(test) 54 | for i in range(len(tests)): 55 | chunk[i]["tests"] = tests[i] 56 | chunk[i]["gpt_question"] = questions[i] 57 | 58 | append_jsonl(jsonl_file_name, chunk) 59 | inferenced_ct += len(chunk) 60 | print(f"Total Cost: {total_price}") 61 | 62 | 63 | if __name__ == "__main__": 64 | Fire(generate_evol_test_case) 65 | -------------------------------------------------------------------------------- /data/training_dataset/evol/preprocess_evol.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | from tqdm import tqdm 5 | 6 | from training_dataset.evol.evol_dataset import get_evol_dataset 7 | from training_dataset.util import (get_python_code_from_string, 8 | remove_print_statements_from_python_program) 9 | from utility.utility import load_jsonl, save_jsonl 10 | 11 | 12 | def get_evol_programs(use_cache: bool = True) -> List[str]: 13 | """Step 1 of the process, extract python programs and instructions from the dataset. We only keep programs in function or class form.""" 14 | file_name = "training_dataset/evol/data/v1.jsonl" 15 | if os.path.exists(file_name) and use_cache: 16 | out = load_jsonl(file_name) 17 | return [i["program"] for i in out] 18 | os.makedirs("training_dataset/evol/data/", exist_ok=True) 19 | data = get_evol_dataset() 20 | out = [] 21 | idx = 0 22 | for i in tqdm(range(len(data))): 23 | program = get_python_code_from_string(data[i]["response"]) 24 | instruction = data[i]["instruction"] 25 | if len(program) == 0: 26 | # no python code found 27 | continue 28 | if "def " not in program and "class " not in program: 29 | continue 30 | program = remove_print_statements_from_python_program(program) 31 | out.append({"id": idx, "instruction": instruction, "program": program}) 32 | idx += 1 33 | save_jsonl(file_name, out) 34 | return [i["program"] for i in out] 35 | 36 | 37 | if __name__ == "__main__": 38 | get_evol_programs() 39 | -------------------------------------------------------------------------------- /data/training_dataset/inference_generated_prompts.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import fire 4 | from transformers import AutoTokenizer 5 | 6 | from inference.InferenceModels import inference 7 | from inference.Utility import get_huggingface_model_path 8 | from utility.utility import load_jsonl 9 | 10 | TEMPLATE_INPUT_PROMPT = """Below is a Python script with a self-contained function that solves the problem and passes corresponding tests: 11 | ```python""" 12 | 13 | 14 | def get_untokenized_prompt( 15 | questions: List[str], 16 | tests: List[List[str]], 17 | test_ct: int = 2, 18 | ) -> List[List[Dict[str, str]]]: 19 | """Prepare prompts for deepseek coder 20 | 21 | Arguments: 22 | questions: a list of questions being asked for the LLM 23 | tests: a list of tests to be included in the prompt 24 | test_ct: how many test do you want to reveal to the LLM? 
default value is -1 which means show all 25 | """ 26 | if test_ct >= 0: 27 | tests = [i[:test_ct] for i in tests] 28 | output = [] 29 | for question, test in zip(questions, tests): 30 | test_prompt = "\n".join(test) 31 | question_prompt = f'"""\n{question}\n{test_prompt}\n"""' 32 | instruction_prompt = f"""Please provide a self-contained Python script that solves the following problem in a markdown code block: 33 | ``` 34 | {question_prompt} 35 | ``` 36 | """ 37 | chat = [ 38 | {"role": "user", "content": instruction_prompt}, 39 | {"role": "assistant", "content": TEMPLATE_INPUT_PROMPT}, 40 | ] 41 | 42 | output.append(chat) 43 | 44 | return output 45 | 46 | 47 | def get_tokenized_prompt( 48 | model_path: str, 49 | questions: List[str], 50 | tests: List[List[str]], 51 | test_ct: int = 2, 52 | ) -> List[str]: 53 | """Prepare prompts for deepseek coder 54 | 55 | Arguments: 56 | questions: a list of questions being asked for the LLM 57 | tests: a list of tests to be included in the prompt 58 | test_ct: how many test do you want to reveal to the LLM? default value is -1 which means show all 59 | """ 60 | chats = get_untokenized_prompt(questions=questions, tests=tests, test_ct=test_ct) 61 | output = [] 62 | tokenizer = AutoTokenizer.from_pretrained(model_path) 63 | for chat in chats: 64 | try: 65 | token_ids = tokenizer.apply_chat_template( 66 | chat, 67 | continue_final_message=True, 68 | add_generation_prompt=False, 69 | )[ 70 | :-1 71 | ] # Remote EOS token 72 | output.append(token_ids) 73 | except Exception as e: 74 | print(e) 75 | print(f"error tokenizing: {chat}") 76 | output.append( 77 | tokenizer.apply_chat_template( 78 | [{"role": "user", "content": "please ignore this quesiton"}], 79 | continue_final_message=True, 80 | add_generation_prompt=False, 81 | ) 82 | ) 83 | 84 | return output 85 | 86 | 87 | def create_inference( 88 | model_name: str, 89 | dataset_name: str, 90 | model_type: str, 91 | ct: int = -1, 92 | sampling_method: str = "best_of_n_top_p_sampling", 93 | ) -> None: 94 | print( 95 | f"starting inference: model name: {model_name} {model_type}, dataset name: {dataset_name}" 96 | ) 97 | 98 | load_file_name = f"training_dataset/{dataset_name}/data/v2.jsonl" 99 | 100 | data = load_jsonl(load_file_name) 101 | if ct > 0: 102 | data = data[:ct] 103 | instructions = [ 104 | i["gpt_question"] if (i["gpt_question"] is not None) else "ignore this question" 105 | for i in data 106 | ] 107 | tests = [ 108 | i["tests"][:3] if (i["gpt_question"] is not None) else ["ignore this question"] 109 | for i in data 110 | ] 111 | huggingface_path = get_huggingface_model_path( 112 | model_name=model_name, model_size=model_type 113 | ) 114 | input_token_ids = get_tokenized_prompt( 115 | model_path=huggingface_path, questions=instructions, tests=tests 116 | ) 117 | 118 | inference( 119 | model_name=model_name, 120 | model_type=model_type, 121 | dataset_name=dataset_name, 122 | sampling_method=sampling_method, 123 | input_token_ids=input_token_ids, 124 | inference_method="vllm", 125 | ) 126 | 127 | 128 | if __name__ == "__main__": 129 | fire.Fire(create_inference) 130 | -------------------------------------------------------------------------------- /data/training_dataset/inference_generated_prompts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | datasets=( 4 | "bigcode_python_fns" 5 | "oss" 6 | "evol" 7 | ) 8 | 9 | for dataset_name in "${datasets[@]}" 10 | do 11 | # oracle model 12 | python training_dataset/inference_generated_prompts.py 
--dataset_name=$dataset_name --model_name="qwen_coder_2.5" --model_type="32b" --sampling_method="greedy" 13 | sleep 5 14 | python training_dataset/inference_generated_prompts.py --dataset_name=$dataset_name --model_name="qwen_coder_2.5" --model_type="7b" --sampling_method="best_of_n_top_p_sampling" 15 | done 16 | 17 | echo "Finish Inferencing" -------------------------------------------------------------------------------- /data/training_dataset/oss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/AceCoder/e37ef2aea2e3710be6b06e9a91dd92c98321df96/data/training_dataset/oss/__init__.py -------------------------------------------------------------------------------- /data/training_dataset/oss/generate_test_cases.py: -------------------------------------------------------------------------------- 1 | from fire import Fire 2 | from tqdm import tqdm 3 | 4 | from training_dataset.create_test_case_and_prompt import \ 5 | create_test_cases_using_gpt 6 | from training_dataset.util import parse_incomplete_json 7 | from utility.utility import append_jsonl, chunking, load_jsonl 8 | 9 | MAX_CHUNK_SIZE = 20 10 | 11 | 12 | def generate_oss_test_case(ct: int = 50): 13 | """step 2 of the process, generate test cases using chat gpt""" 14 | # we check for last generated responses, so we do not waste openAI tokens 15 | jsonl_file_name = "training_dataset/oss/data/v2.jsonl" 16 | start_idx = 0 17 | try: 18 | past_data = load_jsonl(jsonl_file_name) 19 | last_idx = past_data[-1]["id"] 20 | start_idx = last_idx + 1 21 | except: 22 | # we start at 0 as we never inferenced before 23 | pass 24 | 25 | if start_idx >= ct and ct > 0: 26 | return # we already finished inferencing 27 | 28 | # preparing the input 29 | data = load_jsonl("training_dataset/oss/data/v1.jsonl") 30 | if ct > 0: 31 | data = data[:ct] 32 | data = data[start_idx:] 33 | data_chunks = chunking(data, MAX_CHUNK_SIZE) 34 | inferenced_ct = 0 35 | total_price = 0 36 | for chunk in tqdm(data_chunks, desc="Chunk Size"): 37 | programs = [i["program"] for i in chunk] 38 | instructions = [i["instruction"] for i in chunk] 39 | responses, price = create_test_cases_using_gpt( 40 | programs=programs, 41 | instructions=instructions, 42 | use_cache=False, 43 | return_price=True, 44 | ) 45 | total_price += price 46 | tests = [] 47 | questions = [] 48 | for i in responses: 49 | obj = parse_incomplete_json(i) 50 | question = obj.get("question", "please ignore this question") 51 | test = obj.get("tests", ["assert False"]) 52 | questions.append(question) 53 | tests.append(test) 54 | for i in range(len(tests)): 55 | chunk[i]["tests"] = tests[i] 56 | chunk[i]["gpt_question"] = questions[i] 57 | 58 | append_jsonl(jsonl_file_name, chunk) 59 | inferenced_ct += len(chunk) 60 | print(f"Total Cost: {total_price}") 61 | 62 | 63 | if __name__ == "__main__": 64 | Fire(generate_oss_test_case) 65 | -------------------------------------------------------------------------------- /data/training_dataset/oss/oss_dataset.py: -------------------------------------------------------------------------------- 1 | from inference.GetDatasets import get_dataset_from_huggingface 2 | 3 | 4 | # https://huggingface.co/datasets/ise-uiuc/Magicoder-Evol-Instruct-110K?row=0 5 | def get_oss_dataset(): 6 | data = get_dataset_from_huggingface( 7 | "ise-uiuc/Magicoder-OSS-Instruct-75K-Instruction-Response" 8 | )["train"] 9 | return data 10 | 
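# --- Added illustrative sketch (not part of the original file) ---
# Downstream, preprocess_oss.py keeps only rows whose "lang" is python and reads the
# "instruction" and "response" fields, so each record is expected to look roughly like
# {"lang": "python", "instruction": "...", "response": "```python\n...\n```"}.
# The helper below is hypothetical and only illustrates that access pattern.
def _peek_oss_record() -> None:
    data = get_oss_dataset()
    row = data[0]
    print(row["lang"], row["instruction"][:80], row["response"][:80])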
-------------------------------------------------------------------------------- /data/training_dataset/oss/preprocess_oss.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | from tqdm import tqdm 5 | 6 | from training_dataset.oss.oss_dataset import get_oss_dataset 7 | from training_dataset.util import (get_python_code_from_string, 8 | remove_print_statements_from_python_program) 9 | from utility.utility import load_jsonl, save_jsonl 10 | 11 | 12 | def get_oss_programs(use_cache: bool = True) -> List[str]: 13 | """Step 1 of the process, extract python programs and instructions from the dataset. We only keep programs in function or class form.""" 14 | file_name = "training_dataset/oss/data/v1.jsonl" 15 | if os.path.exists(file_name) and use_cache: 16 | out = load_jsonl(file_name) 17 | return [i["program"] for i in out] 18 | os.makedirs("training_dataset/oss/data/", exist_ok=True) 19 | data = get_oss_dataset() 20 | out = [] 21 | idx = 0 22 | for i in tqdm(range(len(data))): 23 | if data[i]["lang"] != "python": 24 | continue # only do python for now 25 | program = get_python_code_from_string(data[i]["response"]) 26 | instruction = data[i]["instruction"] 27 | if len(program) == 0: 28 | # no python code found 29 | continue 30 | program = remove_print_statements_from_python_program(program) 31 | out.append({"id": idx, "instruction": instruction, "program": program}) 32 | idx += 1 33 | save_jsonl(file_name, out) 34 | return [i["program"] for i in out] 35 | 36 | 37 | if __name__ == "__main__": 38 | get_oss_programs() 39 | -------------------------------------------------------------------------------- /data/training_dataset/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any, Dict, List 4 | 5 | 6 | def load_jsonl(file_path: str) -> List[Dict[Any, Any]]: 7 | """load a .jsonl file. Return a List of dictionary, where each dictionary is a line in the file""" 8 | if not os.path.exists(file_path): 9 | raise Exception(f"{file_path} Does not exist!!!!") 10 | with open(file_path, "r") as f: 11 | lst = f.readlines() 12 | lst = [json.loads(i) for i in lst] 13 | return lst 14 | 15 | 16 | def get_python_code_from_string(input: str) -> str: 17 | """Basically find code wrapped in ```python ... ``` and return it. If none is found then will return the 18 | empty string""" 19 | left_index = input.find("```python") 20 | if left_index < 0: 21 | return "" 22 | input = input[left_index + 9 :] 23 | right_index = input.find("```") 24 | if right_index < 0: 25 | return "" 26 | input = input[:right_index] 27 | return input 28 | 29 | 30 | def parse_incomplete_json(input: str) -> Any: 31 | """A helper function that will: 32 | 1. try to parse the whole thing as json 33 | 2. try to find json object wrapped in ```json ... ``` and parse it 34 | 3. Try to see if the json is incomplete. if so then try to parse the incomplete json 35 | 36 | This will only work when we are missing ]} at the end, modify if you need it for other 37 | cases. 
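    Illustrative example (hypothetical truncated model output):
        parse_incomplete_json('```json\n{"question": "add two ints", "tests": ["assert add(1, 2) == 3"')
        # -> {"question": "add two ints", "tests": ["assert add(1, 2) == 3"]}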
38 | """ 39 | input = input.strip() 40 | left_idx = input.find("```json") 41 | if left_idx >= 0: 42 | input = input[left_idx + 7 :] 43 | right_idx = input.rfind("```") 44 | if right_idx >= 0: 45 | input = input[:right_idx] 46 | try: 47 | out = json.loads(input) 48 | return out 49 | except: 50 | pass 51 | 52 | # we now assume that the string is incomplete 53 | while len(input) > 0: 54 | try: 55 | data = json.loads(input + "]}") 56 | return data 57 | except json.decoder.JSONDecodeError: 58 | input = input[:-1] 59 | # we cannot parse this 60 | return {"question": None, "tests": None} 61 | 62 | 63 | def remove_print_statements_from_python_program(input: str) -> str: 64 | lst = input.splitlines() 65 | lst = [i for i in lst if not i.strip().startswith("print")] 66 | return "\n".join(lst) 67 | 68 | 69 | def print_data(file: str, idx: int = 0): 70 | data = load_jsonl(file) 71 | data = [row for row in data if row["id"] == idx][0] 72 | for key in data: 73 | print(f"----------------{key}:-------------------") 74 | if type(data[key]) == list: 75 | for i in data[key]: 76 | if type(i) == list: 77 | # we omit the original inferences for easier print statements 78 | for ii in i: 79 | print(ii) 80 | break 81 | else: 82 | print(i) 83 | print(f"Contained {len(data[key])} items-----") 84 | else: 85 | print(data[key]) 86 | 87 | 88 | if __name__ == "__main__": 89 | print_data("training_dataset/oss/data/v2.jsonl", 22) 90 | -------------------------------------------------------------------------------- /data/utility/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/AceCoder/e37ef2aea2e3710be6b06e9a91dd92c98321df96/data/utility/__init__.py -------------------------------------------------------------------------------- /data/utility/utility.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from typing import Any, Dict, List 5 | 6 | 7 | def chunking(lst: List[Any], n: int) -> List[List[Any]]: 8 | """Split a list into a list of list where each sublist is of size n""" 9 | if n <= 0: 10 | raise Exception(f"Are you fucking kidding me with n = {n}?") 11 | if len(lst) <= n: 12 | return [lst] 13 | return [lst[i : i + n] for i in range(0, len(lst), n)] 14 | 15 | 16 | def load_jsonl(file_path: str) -> List[Dict[Any, Any]]: 17 | """load a .jsonl file. Return a List of dictionary, where each dictionary is a line in the file""" 18 | if not os.path.exists(file_path): 19 | raise Exception(f"{file_path} Does not exist!!!!") 20 | with open(file_path, "r") as f: 21 | lst = f.readlines() 22 | output = [json.loads(i) for i in lst] 23 | return output 24 | 25 | 26 | def save_jsonl(file_path: str, content: List[Dict[Any, Any]]) -> None: 27 | """save a .jsonl file.""" 28 | with open(file_path, "w") as f: 29 | for i in content: 30 | f.write(json.dumps(i) + "\n") 31 | 32 | 33 | def append_jsonl(file_path: str, content: List[Dict[Any, Any]]) -> None: 34 | """append to a .jsonl file.""" 35 | with open(file_path, "a") as f: 36 | for i in content: 37 | f.write(json.dumps(i) + "\n") 38 | 39 | 40 | class MyTimer: 41 | """A simple timer class where you initialize it, then just call print_runtime everytime you want to time yourself""" 42 | 43 | def __init__(self) -> None: 44 | self.start = time.time() 45 | 46 | def print_runtime(self, message: str, reset_timer: bool = True) -> None: 47 | """Print the runtime, the output will be in the form of f"{message} took ..." 
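        Illustrative usage: timer = MyTimer(); ...; timer.print_runtime("Loading data")
        prints something like "Loading data took 2.5 seconds".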
48 | 49 | Parameter: 50 | message: a string indicating what you have done 51 | reset_timer: whether to reset timer so that next call to this function will show the time in between print_runtime 52 | """ 53 | runtime = time.time() - self.start 54 | minute = int(runtime / 60) 55 | seconds = runtime % 60 56 | if minute > 0: 57 | print(f"{message} took {minute} minutes {seconds} seconds") 58 | else: 59 | print(f"{message} took {seconds} seconds") 60 | 61 | if reset_timer: 62 | self.start = time.time() 63 | -------------------------------------------------------------------------------- /examples/run_acecoderm.py: -------------------------------------------------------------------------------- 1 | """pip install git+https://github.com/TIGER-AI-Lab/AceCoder""" 2 | from acecoder import Qwen2ForCausalRM 3 | from transformers import AutoTokenizer 4 | 5 | model_path = "TIGER-Lab/AceCodeRM-7B" 6 | model = Qwen2ForCausalRM.from_pretrained(model_path, device_map="auto") 7 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 8 | 9 | question = """\ 10 | Given an array of numbers, write a function runningSum that returns an array where each element at index i is the sum of all elements from index 0 to i (inclusive). 11 | For example: 12 | Input: nums = [1,2,3,4] 13 | Output: [1,3,6,10] 14 | """ 15 | 16 | program_with_3_errors = """\ 17 | def runningSum(nums): 18 | result = [] 19 | current_sum = 0 20 | for i in range(1, len(nums)): 21 | result.append(nums[i]) 22 | current_sum += nums[i] 23 | return result 24 | """ 25 | 26 | program_with_2_errors = """\ 27 | def runningSum(nums): 28 | result = [] 29 | current_sum = 0 30 | for i in range(0, len(nums)): 31 | result.append(nums[i]) 32 | current_sum += nums[i] 33 | return result 34 | """ 35 | 36 | program_with_1_errors = """\ 37 | def runningSum(nums): 38 | result = [] 39 | current_sum = 0 40 | for i in range(0, len(nums)): 41 | result.append(current_sum) 42 | current_sum += nums[i] 43 | return result 44 | """ 45 | program_correct = """\ 46 | def runningSum(nums): 47 | result = [] 48 | current_sum = 0 49 | for num in nums: 50 | current_sum += num 51 | result.append(current_sum) 52 | return result 53 | """ 54 | 55 | program_chats = [ 56 | [ 57 | { 58 | "content": question, 59 | "role": "user", 60 | }, 61 | { 62 | "role": "assistant", 63 | "content": program 64 | } 65 | ] for program in [program_with_3_errors, program_with_2_errors, program_with_1_errors, program_correct] 66 | ] 67 | 68 | input_tokens = tokenizer.apply_chat_template( 69 | program_chats, 70 | tokenize=True, 71 | return_dict=True, 72 | padding=True, 73 | return_tensors="pt", 74 | ).to(model.device) 75 | 76 | _, _, values = model( 77 | **input_tokens, 78 | output_hidden_states=True, 79 | return_dict=True, 80 | use_cache=False, 81 | ) 82 | masks = input_tokens["attention_mask"] 83 | rm_scores = values.gather( 84 | dim=-1, index=(masks.sum(dim=-1, keepdim=True) - 1) 85 | ) # find the last token (eos) in each sequence, a 86 | rm_scores = rm_scores.squeeze() 87 | 88 | print("RM Scores:", rm_scores) 89 | print("Score of program with 3 errors:", rm_scores[0].item()) 90 | print("Score of program with 2 errors:", rm_scores[1].item()) 91 | print("Score of program with 1 errors:", rm_scores[2].item()) 92 | print("Score of correct program:", rm_scores[3].item()) 93 | """ 94 | RM Scores: tensor([-20.5058, -1.7867, 0.4395, 23.0689], device='cuda:0', 95 | grad_fn=) 96 | Score of program with 3 errors: -20.505754470825195 97 | Score of program with 2 errors: -1.7866804599761963 98 | 
Score of program with 1 errors: 0.43949759006500244 99 | Score of correct program: 23.068859100341797 100 | """ -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='acecoder', 5 | version='0.0.1', 6 | description='Official Codes for of "ACECODER: Acing Coder RL via Automated Test-Case Synthesis"', 7 | long_description=open('README.md').read(), 8 | long_description_content_type='text/markdown', 9 | author='Dongfu Jiang', 10 | author_email='dongfu.jiang@uwaterloo.ca', 11 | package_dir={'': 'src'}, # Add this line 12 | packages=find_packages(where='src'), # Modify this line 13 | url='https://github.com/TIGER-AI-Lab/AceCoder', 14 | install_requires=[ 15 | "transformers", 16 | "torch", 17 | "datasets", 18 | "accelerate", 19 | ], 20 | ) 21 | 22 | 23 | 24 | # change it to pyproject.toml 25 | # [build-system] 26 | # python setup.py sdist bdist_wheel 27 | # twine upload dist/* -------------------------------------------------------------------------------- /src/acecoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .rm_utils import AceCodeRM 2 | from .eval_test_cases import evaluate as evaluate_test_cases 3 | -------------------------------------------------------------------------------- /src/acecoder/eval_test_cases.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 | import os 4 | import pickle 5 | import threading 6 | import time 7 | from collections import Counter, defaultdict 8 | from concurrent.futures import ProcessPoolExecutor, as_completed 9 | from datetime import datetime 10 | from typing import Any, Dict, List, Optional, Tuple, Union 11 | from warnings import warn 12 | 13 | import numpy as np 14 | from termcolor import cprint 15 | from tqdm import tqdm 16 | 17 | from evalplus.sanitize import sanitize, code_extract 18 | 19 | # from evalplus.config import * 20 | from .evalplus_eval import ( 21 | untrusted_check_assert, 22 | ) 23 | 24 | DEFAULT_GT_TIME_LIMIT_FACTOR = 4.0 25 | DEFAULT_MIN_TIME_LIMIT = 1.0 26 | 27 | # 1st item: the status 28 | # 2nd item (optional): the detailed pass/fail boolean for each input 29 | Result = Tuple[str, List[bool]] 30 | 31 | def check_correctness_assert( 32 | task_id: int, 33 | completion_id: int, 34 | entry_point: str, 35 | solution: str, 36 | assert_tests: List[str], 37 | dataset: str=None, 38 | base_only=False, 39 | fast_check=False, 40 | identifier=None, 41 | min_time_limit: float = DEFAULT_MIN_TIME_LIMIT, 42 | gt_time_limit_factor: float = DEFAULT_GT_TIME_LIMIT_FACTOR, 43 | extract_solution:bool=False, 44 | atol: int=1e-6, 45 | ) -> Dict[str, Result]: # {...}, "base" | "plus" -> (status, details) 46 | is_extracted = False 47 | if extract_solution: 48 | # base model may sometimes outputs too many "\n" and makes the code extraction too ** flow. 
49 | # so we skip them if the number of lines > 500 50 | if not len(solution.split("\n")) > 500: 51 | extracted_solution = code_extract(solution.encode('utf-8', 'ignore').decode('utf-8').replace('\x00', '')) 52 | # if entry_point in _solution: 53 | # solution = _solution 54 | is_extracted = True 55 | else: 56 | extracted_solution = solution 57 | ret = { 58 | "completion_id": completion_id, 59 | "task_id": task_id, 60 | "_identifier": identifier, 61 | "solution": extracted_solution, 62 | "n_tests": len(assert_tests), 63 | } 64 | eval_results = untrusted_check_assert( 65 | dataset, 66 | extracted_solution, 67 | entry_point, 68 | assert_tests, 69 | atol=atol, 70 | ref_time=[DEFAULT_MIN_TIME_LIMIT]*len(assert_tests), 71 | fast_check=fast_check, 72 | min_time_limit=min_time_limit, 73 | gt_time_limit_factor=gt_time_limit_factor, 74 | ) 75 | if eval_results["status"] == "syntax_error" and is_extracted: 76 | # try to use the original solution 77 | ret['solution'] = solution 78 | eval_results = untrusted_check_assert( 79 | dataset, 80 | solution, 81 | entry_point, 82 | assert_tests, 83 | atol=atol, 84 | ref_time=[DEFAULT_MIN_TIME_LIMIT]*len(assert_tests), 85 | fast_check=fast_check, 86 | min_time_limit=min_time_limit, 87 | gt_time_limit_factor=gt_time_limit_factor, 88 | ) 89 | ret["eval_results"] = eval_results 90 | 91 | return ret 92 | 93 | def get_entry_point_from_test_case(test_case: str) -> str: 94 | """ 95 | Get the entry point from the first test case. 96 | Args: 97 | test_case: a test case string, like "assert f(1) == 2" 98 | Returns: 99 | the entry point, like "f" 100 | """ 101 | start_idx = test_case.find("assert ") + len("assert ") 102 | end_idx = test_case.find("(") 103 | return test_case[start_idx:end_idx] 104 | 105 | def get_test_inputs_outputs_from_test_case(test_cases: List[str]) -> Tuple[List[str], List[str]]: 106 | """ 107 | Get the inputs and outputs from the test cases. 
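    Only single-expression asserts of the form `assert f(args) == expected` (or `... is expected`)
    are supported; both sides are recovered with `eval`, so this should only be used on trusted,
    generated test strings.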
108 | Args: 109 | test_cases: a list of test case strings 110 | Returns: 111 | a tuple of inputs and outputs 112 | """ 113 | inputs = [] 114 | outputs = [] 115 | for test_case in test_cases: 116 | input_start_idx = test_case.find("(") 117 | assert input_start_idx != -1, f"Cannot find '(' in {test_case}" 118 | output_start_idx = test_case.find("==") 119 | if output_start_idx == -1: 120 | output_start_idx = test_case.rfind("is") 121 | assert output_start_idx != -1, f"Cannot find '==' or 'is' in {test_case}" 122 | output_start_idx += 2 123 | input_end_idx = test_case[:output_start_idx].rfind(")") 124 | assert input_end_idx != -1, f"Cannot find ')' in {test_case}" 125 | test_input = test_case[input_start_idx+1:input_end_idx].strip() 126 | try: 127 | if test_input: 128 | test_input = eval(test_input) 129 | else: 130 | test_input = [] 131 | except: 132 | print(f"Cannot eval {test_input}") 133 | print(test_case) 134 | print(input_start_idx, input_end_idx) 135 | raise 136 | inputs.append(test_input) 137 | assert output_start_idx != -1, f"Cannot find '==' in {test_case}" 138 | output = eval(test_case[output_start_idx:].strip()) 139 | outputs.append(output) 140 | return inputs, outputs 141 | 142 | 143 | def evaluate( 144 | samples: Union[str, List[Dict[str, Any]]], 145 | dataset: str = None, 146 | base_only: bool = False, 147 | parallel: Optional[int] = None, 148 | i_just_wanna_run: bool = False, 149 | test_details: bool = True, 150 | min_time_limit: float = DEFAULT_MIN_TIME_LIMIT, 151 | gt_time_limit_factor: float = DEFAULT_GT_TIME_LIMIT_FACTOR, 152 | output_file: Optional[str] = None, 153 | n_workers: Optional[int] = None, 154 | extract_solution: bool = True, 155 | ): 156 | if not n_workers: 157 | n_workers = parallel or max(1, multiprocessing.cpu_count() // 2) 158 | 159 | if isinstance(samples, str) and os.path.exists(samples): 160 | result_path = samples.replace(".jsonl", ".eval_results.json") 161 | elif isinstance(samples, list): 162 | result_path = None 163 | 164 | if output_file is not None: 165 | result_path = output_file 166 | 167 | if result_path and os.path.isfile(result_path) and not i_just_wanna_run: 168 | print(f"Load from previous results from {result_path}") 169 | if result_path.endswith(".jsonl"): 170 | with open(result_path, "r") as f: 171 | all_samples_results = [json.loads(line) for line in f] 172 | else: 173 | with open(result_path, "r") as f: 174 | all_samples_results = json.load(f) 175 | 176 | else: 177 | if isinstance(samples, str) and os.path.exists(samples): 178 | if samples.endswith(".jsonl"): 179 | with open(samples, "r") as f: 180 | all_samples = [json.loads(line) for line in f] 181 | else: 182 | with open(samples, "r") as f: 183 | all_samples = json.load(f) 184 | else: 185 | all_samples = samples 186 | 187 | dataset_hash = None 188 | 189 | _identifier_list = [x['_identifier'] for x in all_samples] 190 | with ProcessPoolExecutor(max_workers=n_workers) as executor: 191 | futures = [] 192 | completion_id = Counter() 193 | n_samples = 0 194 | eval_results = defaultdict(list) # task_id -> 195 | remainings = set() 196 | 197 | for sample in tqdm(all_samples, desc="Submitting samples"): 198 | task_id = sample["task_id"] 199 | # test_inputs, expected_output = get_test_inputs_outputs_from_test_case(sample["tests"]) 200 | entry_point = get_entry_point_from_test_case(sample['tests'][0]) 201 | solution = sample["output"] 202 | remainings.add(sample["_identifier"]) 203 | args = ( 204 | task_id, 205 | completion_id[task_id], 206 | entry_point, 207 | solution, 208 | sample["tests"], 
209 | dataset, 210 | base_only, 211 | not test_details, # fast_check 212 | sample["_identifier"] if "_identifier" in sample else None, 213 | min_time_limit, 214 | gt_time_limit_factor, 215 | extract_solution, 216 | ) 217 | futures.append(executor.submit(check_correctness_assert, *args)) 218 | completion_id[task_id] += 1 219 | n_samples += 1 220 | 221 | assert n_samples == len(remainings), "Missing problems in unfinished" 222 | 223 | def stucking_checker(): 224 | while remainings: 225 | last_size = len(remainings) 226 | time.sleep(20) 227 | if last_size != len(remainings) or len(remainings) == 0: 228 | continue 229 | # Potential stucking 230 | warn("No samples had finished testing in the last 20s") 231 | warn(f"{len(remainings)} samples to be tested...") 232 | 233 | threading.Thread(target=stucking_checker).start() 234 | 235 | all_samples_results_identifier_map = {} 236 | for i, future in tqdm(enumerate(as_completed(futures)), total=n_samples): 237 | result = future.result() 238 | # except TimeoutError: 239 | # print(f"Timeout for {i}th sample") 240 | # result = { 241 | # "completion_id": i, 242 | # "task_id": task_id, 243 | # "_identifier": sample["_identifier"], 244 | # "solution": solution, 245 | # "n_tests": len(sample["tests"]), 246 | # "base": ["timeout", []] 247 | # } 248 | remainings.remove(result["_identifier"]) 249 | # result['pass_rate'] = result['eval_results']['pass_rate'] 250 | all_samples_results_identifier_map[result["_identifier"]] = result 251 | eval_results[result["task_id"]].append(result) 252 | 253 | all_samples_results = [all_samples_results_identifier_map[x] for x in _identifier_list] 254 | # save the results 255 | if result_path: 256 | if result_path.endswith(".jsonl"): 257 | with open(result_path, "w") as f: 258 | for result in all_samples_results: 259 | f.write(json.dumps(result) + "\n") 260 | else: 261 | with open(result_path, "w") as f: 262 | json.dump(all_samples_results, f, indent=4) 263 | print(f"Results saved to {result_path}") 264 | 265 | pass_rates = [x['eval_results']['pass_rate'] for x in all_samples_results] 266 | if __name__ == "__main__": 267 | print(f"Pass rate: {np.mean(pass_rates)}") 268 | else: 269 | return all_samples_results, pass_rates 270 | 271 | def main(): 272 | from fire import Fire 273 | 274 | Fire(evaluate) 275 | 276 | 277 | if __name__ == "__main__": 278 | main() 279 | -------------------------------------------------------------------------------- /src/acecoder/evalplus_eval.py: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # 3 | # Copyright (c) OpenAI (https://openai.com) 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in 13 | # all copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | # THE SOFTWARE. 22 | 23 | import itertools 24 | import multiprocessing 25 | import os 26 | import time 27 | from multiprocessing import Array, Value 28 | from typing import Any, Dict, List, Optional, Tuple, Union 29 | 30 | import numpy as np 31 | import psutil 32 | import threading 33 | 34 | from ctypes import c_char, create_string_buffer 35 | from evalplus.config import * 36 | from evalplus.eval.utils import ( 37 | create_tempdir, 38 | reliability_guard, 39 | swallow_io, 40 | time_limit, 41 | TimeoutException 42 | ) 43 | 44 | def compatible_eval_result(results: Dict) -> Dict: 45 | # compatibility 46 | for task_results in results["eval"].values(): 47 | # update the "files" field to "nfiles" 48 | if "files" in task_results and "nfiles" not in task_results: 49 | task_results["nfiles"] = len(task_results.pop("files")) 50 | return results 51 | 52 | # Example usage: 53 | def read_string(arr, index, str_length=256): 54 | start = index * str_length 55 | # Read until null terminator or end of string slot 56 | raw = arr[start:start + str_length] 57 | return raw.split(b'\x00')[0].decode() 58 | 59 | def write_string(arr, index, string, str_length=256): 60 | start = index * str_length 61 | buf = create_string_buffer(string[:str_length].encode(), str_length) 62 | arr[start:start + str_length] = buf.raw 63 | 64 | # unbiased estimator from https://github.com/openai/human-eval 65 | def estimate_pass_at_k( 66 | num_samples: Union[int, List[int], np.ndarray], 67 | num_correct: Union[List[int], np.ndarray], 68 | k: int, 69 | ) -> np.ndarray: 70 | """ 71 | Estimates pass@k of each problem and returns them in an array. 72 | """ 73 | 74 | def estimator(n: int, c: int, k: int) -> float: 75 | """ 76 | Calculates 1 - comb(n - c, k) / comb(n, k). 77 | """ 78 | if n - c < k: 79 | return 1.0 80 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) 81 | 82 | if isinstance(num_samples, int): 83 | num_samples_it = itertools.repeat(num_samples, len(num_correct)) 84 | else: 85 | assert len(num_samples) == len(num_correct) 86 | num_samples_it = iter(num_samples) 87 | 88 | return np.array( 89 | [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)] 90 | ) 91 | 92 | 93 | PASS = "pass" 94 | FAIL = "fail" 95 | TIMEOUT = "timeout" 96 | MISSING_DEPENDENCY = "missing_dependency" 97 | UNEXECUTED = "unexecuted" 98 | SYNTAX_ERROR = "syntax_error" 99 | 100 | _SUCCESS = 0 101 | _FAILED = 1 102 | _TIMEOUT = 2 103 | _UNKNOWN = 3 104 | _MISSING_DEPENDENCY = 4 105 | _UNEXECUTED = 5 106 | _SYNTAX_ERROR = 6 107 | 108 | _mapping = {_SUCCESS: PASS, _FAILED: FAIL, _TIMEOUT: TIMEOUT, _MISSING_DEPENDENCY: MISSING_DEPENDENCY, _UNEXECUTED: UNEXECUTED, _SYNTAX_ERROR: SYNTAX_ERROR, _UNKNOWN: None} 109 | ERROR_STR_LEN = 256 110 | 111 | 112 | def query_maximum_memory_bytes() -> Optional[int]: 113 | # Disable functionalities that can make destructive changes to the test. 
114 | # allow only 4GB memory usage 115 | maximum_memory_bytes = os.getenv( 116 | "EVALPLUS_MAX_MEMORY_BYTES", 4 * 1024 * 1024 * 1024 117 | ) 118 | maximum_memory_bytes = min(int(maximum_memory_bytes), psutil.virtual_memory().total) 119 | if maximum_memory_bytes == -1: 120 | return None 121 | return maximum_memory_bytes 122 | 123 | 124 | def is_floats(x) -> bool: 125 | # check if it is float; List[float]; Tuple[float] 126 | if isinstance(x, float): 127 | return True 128 | if isinstance(x, (list, tuple)) and x: 129 | return all(isinstance(i, float) for i in x) 130 | if isinstance(x, np.ndarray): 131 | return x.dtype == np.float64 or x.dtype == np.float32 132 | return False 133 | 134 | def unsafe_execute_assert( 135 | dataset: str, 136 | entry_point: str, 137 | code: str, 138 | assert_tests: List, 139 | time_limits, 140 | atol, 141 | fast_check, 142 | stat, # Value 143 | details, # Array 144 | code_error, 145 | tests_errors, # Array 146 | progress, # Value 147 | ): 148 | with create_tempdir(): 149 | # These system calls are needed when cleaning up tempdir. 150 | import os 151 | import shutil 152 | 153 | rmtree = shutil.rmtree 154 | rmdir = os.rmdir 155 | chdir = os.chdir 156 | reliability_guard(maximum_memory_bytes=query_maximum_memory_bytes()) 157 | exec_globals = {} 158 | try: 159 | with swallow_io(): 160 | exec(code, exec_globals) 161 | # fn = exec_globals[entry_point] 162 | 163 | for i, test_case in enumerate(assert_tests): 164 | with swallow_io(): 165 | try: 166 | with time_limit(time_limits[i]): 167 | exec(test_case, exec_globals) 168 | details[i] = _SUCCESS 169 | except ModuleNotFoundError as e: 170 | details[i] = _MISSING_DEPENDENCY 171 | write_string(tests_errors, i, str(e), ERROR_STR_LEN) 172 | except SyntaxError as e: 173 | details[i] = _SYNTAX_ERROR 174 | write_string(tests_errors, i, str(e), ERROR_STR_LEN) 175 | except TimeoutException as e: 176 | details[i] = _TIMEOUT 177 | write_string(tests_errors, i, str(e), ERROR_STR_LEN) 178 | except Exception as e: 179 | details[i] = _FAILED 180 | write_string(tests_errors, i, str(e), ERROR_STR_LEN) 181 | 182 | progress.value += 1 183 | if details[i] != _SUCCESS and fast_check: 184 | raise Exception("Fast check failed") 185 | 186 | stat.value = _SUCCESS 187 | except SyntaxError: 188 | stat.value = _SYNTAX_ERROR 189 | except ModuleNotFoundError: 190 | stat.value = _MISSING_DEPENDENCY 191 | except BaseException as e: 192 | # if module not found error, pring it for debug. 193 | # if "No module named" in str(e): 194 | # print(e) 195 | stat.value = _FAILED 196 | write_string(code_error, 0, str(e), ERROR_STR_LEN) 197 | # Needed for cleaning up. 
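        # (reliability_guard() above replaced these destructive functions with None;
        # restoring the saved references lets create_tempdir() remove the temp dir on exit.)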
198 | shutil.rmtree = rmtree 199 | os.rmdir = rmdir 200 | os.chdir = chdir 201 | 202 | def untrusted_check_assert( 203 | dataset: str, 204 | code: str, 205 | entry_point: str, 206 | assert_tests: List[str], 207 | atol, 208 | ref_time: List[float], 209 | fast_check: bool = False, 210 | min_time_limit: float = DEFAULT_MIN_TIME_LIMIT, 211 | gt_time_limit_factor: float = DEFAULT_GT_TIME_LIMIT_FACTOR, 212 | ) -> Tuple[str, np.ndarray]: 213 | time_limits = [max(min_time_limit, gt_time_limit_factor * t) for t in ref_time] 214 | timeout = min(os.getenv("EVALPLUS_TIMEOUT_PER_TASK", 15), sum(time_limits)) + 1 215 | if not fast_check: 216 | timeout += 1 # extra time for data collection 217 | 218 | # shared memory objects 219 | progress = Value("i", 0) 220 | stat = Value("i", _UNKNOWN) 221 | details = Array("b", [False for _ in range(len(assert_tests))]) 222 | # errors is a list of strings 223 | # Method 2: Or if you need to initialize with spaces 224 | tests_errors = Array(c_char, b" " * (len(assert_tests) * ERROR_STR_LEN)) 225 | code_error = Array(c_char, b" " * ERROR_STR_LEN) 226 | 227 | p = multiprocessing.Process( 228 | target=unsafe_execute_assert, 229 | args=( 230 | dataset, 231 | entry_point, 232 | code, 233 | assert_tests, 234 | time_limits, 235 | atol, 236 | fast_check, 237 | stat, 238 | details, 239 | code_error, 240 | tests_errors, 241 | progress, 242 | ), 243 | ) 244 | p.start() 245 | p.join(timeout=timeout + 1) 246 | if p.is_alive(): 247 | p.terminate() 248 | time.sleep(0.1) 249 | if p.is_alive(): 250 | p.kill() 251 | time.sleep(0.1) 252 | 253 | stat = _mapping[stat.value] 254 | details = details[: progress.value] + [_UNEXECUTED] * (len(assert_tests) - progress.value) 255 | 256 | tests_errors = [read_string(tests_errors, i, ERROR_STR_LEN) for i in range(len(assert_tests))] 257 | tests_errors = [x if x.strip() else None for x in tests_errors] 258 | code_error = read_string(code_error, 0, ERROR_STR_LEN) if code_error[0] != 0 else None 259 | code_error = code_error if code_error.strip() else None 260 | 261 | details = [{"pass": x == _SUCCESS, "reason": _mapping[x], "error_message": tests_errors[i], "time_limit": time_limits[i]} for i, x in enumerate(details)] 262 | pass_rate = sum([x["pass"] for x in details]) / len(details) if details else 0 263 | 264 | if not stat: 265 | stat = TIMEOUT 266 | 267 | if stat == PASS: 268 | if len(details) != len(assert_tests) or not all([x["pass"] for x in details]): 269 | stat = FAIL 270 | 271 | result = { 272 | "status": stat, 273 | "code_error": code_error, 274 | "details": details, 275 | "pass_rate": pass_rate 276 | } 277 | 278 | return result -------------------------------------------------------------------------------- /src/acecoder/rm_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import Qwen2ForCausalLM 4 | 5 | class ValueHead(nn.Module): 6 | r""" 7 | The ValueHead class implements a head for GPT2 that returns a scalar for each output token. 8 | """ 9 | 10 | def __init__(self, config, **kwargs): 11 | super().__init__() 12 | if not hasattr(config, "summary_dropout_prob"): 13 | summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1) 14 | else: 15 | summary_dropout_prob = config.summary_dropout_prob 16 | 17 | self.dropout = ( 18 | nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity() 19 | ) 20 | 21 | # some models such as OPT have a projection layer before the word embeddings - e.g. 
OPT-350m 22 | if hasattr(config, "hidden_size"): 23 | hidden_size = config.hidden_size 24 | if hasattr(config, "word_embed_proj_dim"): 25 | hidden_size = config.word_embed_proj_dim 26 | elif hasattr(config, "is_encoder_decoder"): 27 | if config.is_encoder_decoder and hasattr(config, "decoder"): 28 | if hasattr(config.decoder, "hidden_size"): 29 | hidden_size = config.decoder.hidden_size 30 | 31 | self.summary = nn.Linear(hidden_size, 1) 32 | 33 | self.flatten = nn.Flatten() 34 | 35 | def forward(self, hidden_states): 36 | output = self.dropout(hidden_states) 37 | 38 | # For now force upcast in fp32 if needed. Let's keep the 39 | # output in fp32 for numerical stability. 40 | if output.dtype != self.summary.weight.dtype: 41 | output = output.to(self.summary.weight.dtype) 42 | 43 | output = self.summary(output) 44 | return output 45 | 46 | 47 | class AceCodeRM(Qwen2ForCausalLM): 48 | def __init__(self, config): 49 | super().__init__(config) 50 | self.v_head = ValueHead(config) 51 | 52 | def forward( 53 | self, 54 | input_ids=None, 55 | past_key_values=None, 56 | attention_mask=None, 57 | return_past_key_values=False, 58 | **kwargs, 59 | ): 60 | r""" 61 | Applies a forward pass to the wrapped model and returns the logits of the value head. 62 | 63 | Args: 64 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 65 | Indices of input sequence tokens in the vocabulary. 66 | past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`): 67 | Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model 68 | (see `past_key_values` input) to speed up sequential decoding. 69 | attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`): 70 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 71 | - 1 for tokens that are **not masked**, 72 | - 0 for tokens that are **masked**. 73 | return_past_key_values (bool): A flag indicating if the computed hidden-states should be returned. 74 | kwargs (`dict`, `optional`): 75 | Additional keyword arguments, that are passed to the wrapped model. 
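        Returns:
            rm_scores (`torch.FloatTensor` of shape `(batch_size,)`):
                The value-head output taken at the last non-padding token of each sequence
                (squeezed, so a single sequence gives a 0-d tensor). If
                `return_past_key_values=True`, a tuple `(rm_scores, past_key_values)`
                is returned instead.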
76 | """ 77 | kwargs["output_hidden_states"] = ( 78 | True # this had already been set in the LORA / PEFT examples 79 | ) 80 | kwargs["past_key_values"] = past_key_values 81 | 82 | # if ( 83 | # self.is_peft_model 84 | # and 85 | # self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING" 86 | # ): 87 | # kwargs.pop("past_key_values") 88 | 89 | base_model_output = super().forward( 90 | input_ids=input_ids, 91 | attention_mask=attention_mask, 92 | **kwargs, 93 | ) 94 | 95 | last_hidden_state = base_model_output.hidden_states[-1] 96 | lm_logits = base_model_output.logits 97 | loss = base_model_output.loss 98 | 99 | if last_hidden_state.device != self.v_head.summary.weight.device: 100 | last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device) 101 | 102 | value = self.v_head(last_hidden_state).squeeze(-1) 103 | 104 | # force upcast in fp32 if logits are in half-precision 105 | if lm_logits.dtype != torch.float32: 106 | lm_logits = lm_logits.float() 107 | 108 | rm_scores = value.gather( 109 | dim=-1, index=(attention_mask.sum(dim=-1, keepdim=True) - 1) 110 | ) # find the last token (eos) in each sequence, a 111 | 112 | rm_scores = rm_scores.squeeze() 113 | 114 | if return_past_key_values: 115 | return (rm_scores, base_model_output.past_key_values) 116 | else: 117 | return rm_scores -------------------------------------------------------------------------------- /train/train_rl/README.md: -------------------------------------------------------------------------------- 1 | # Installtion 2 | You need to first update the submodule by running the following command: 3 | ```bash 4 | git submodule init 5 | git submodule update 6 | ``` 7 | 8 | Then, you can install the required packages for OpenRLHF with the following command: 9 | ```bash 10 | conda create -n 11 | cd OpenRLHF 12 | pip install -e .[vllm] 13 | pip install evalplus # requried for rule-based reward for code generation 14 | ``` 15 | 16 | ## Data Preparation 17 | - To get the AceCode-87K-hard that only keeps 25% of the examples that makes the RL training faster, run the following command: 18 | ```bash 19 | python scripts/get_hard_data.py --dataset_path "TIGER-Lab/AceCode-87K" --output_path "./data/acecode_87K/acecode_87K.json" --only_keep_hard_examples True 20 | ``` 21 | 22 | ## Reward model preparation 23 | Since [AceCodeRM-7B](https://huggingface.co/TIGER-Lab/AceCodeRM-7B) is trained with LlamaFactory, the format might be different from the OpenRLHF RM format, but it's generally the same. The only difference is that the Llamafactory enabled the `bias=True` for the final linear layer, while OpenRLHF uses `bias=False`. 24 | 25 | Two ways to use RM for RL training: 26 | - Directly set `reward_pretrain="TIGER-Lab/AceCodeRM-7B"` in the RL training script and set `value_head_prefix="summary"` in the training script. 27 | - Convert the RM to OpenRLHF format weights with the following command: 28 | ```bash 29 | python scripts/change_lf_rm_to_openrlhf_rm.py --lf_rm_model_path "TIGER-Lab/AceCodeRM-7B" --openrlhf_rm_model_path "./models/AceCodeRM-7B-openrlhf" --push_to_hub False 30 | ``` 31 | Then, set `reward_pretrain="./models/AceCodeRM-7B-openrlhf"` in the RL training script and set `value_head_prefix="score"` in the training script. 32 | 33 | (Note: the reason why we use LlamaFactory for training RM is historical reason. We have tried using OpenRLHF to train RM, and the performance is similar.) 
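Before launching RL, it can be worth sanity-checking that the reward model loads and produces one scalar score per completion. Below is a minimal sketch, not an official script: it uses the `AceCodeRM` wrapper from this repo's `src/acecoder/rm_utils.py`, the question/program strings are placeholders, and it assumes the released checkpoint's value-head weights load into that class (the repo's `examples/run_acecoderm.py` does the same kind of scoring).

```python
import torch
from transformers import AutoTokenizer

from acecoder import AceCodeRM  # pip install git+https://github.com/TIGER-AI-Lab/AceCoder

rm_path = "TIGER-Lab/AceCodeRM-7B"
rm = AceCodeRM.from_pretrained(rm_path, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(rm_path, trust_remote_code=True)

chat = [
    {"role": "user", "content": "Write a function add(a, b) that returns a + b."},
    {"role": "assistant", "content": "def add(a, b):\n    return a + b"},
]
inputs = tokenizer.apply_chat_template(
    [chat], tokenize=True, return_dict=True, padding=True, return_tensors="pt"
).to(rm.device)

# AceCodeRM.forward returns the value-head score taken at each sequence's last token.
score = rm(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
print("reward:", score.item())
```

If the scores look sensible (higher for correct programs than for buggy ones, as in `examples/run_acecoderm.py`), the same checkpoint can be plugged into the RL scripts via `reward_pretrain` as described above.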
34 | 35 | 36 | ### Training RL 37 | 38 | please `export WANDB_API_KEY=your_wandb_api_key` before running the following scripts. 39 | 40 | - with reward model 41 | ```bash 42 | bash scripts/train_reinforce_ray.sh # reinforcement++ 43 | # and change the following variables in the script 44 | # policy_pretrain="Your initial policy model" 45 | # reward_pretrain="TIGER-Lab/AceCodeRM-7B" 46 | # dataset_path="./data/acecode_87K/acecode_87K.json" 47 | # run_name="Your run name" 48 | ``` 49 | - with rule-based reward (binary pass rate) 50 | ```bash 51 | bash scripts/train_reinforce_ray_rule_rm.sh # reinforcement++ 52 | # and change the following variables in the script 53 | # policy_pretrain="Your initial policy model" 54 | # binary_reward=True 55 | # dataset_path="./data/acecode_87K/acecode_87K.json" 56 | # run_name="Your run name" 57 | ``` -------------------------------------------------------------------------------- /train/train_rm/README.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | We use LLama-Factory off the shelf for our reward model training. Therefore, to install the environment, please refer to their repository. At the time of writing this page, the following scripts work for us: 4 | ```bash 5 | conda create -n llamaFactory python=3.11 6 | conda init 7 | conda activate llamaFactory 8 | pip install -e ".[torch,metrics]" 9 | pip install deepspeed==0.15.4 10 | pip install -U "huggingface_hub[cli]" 11 | ``` 12 | 13 | ## Setup 14 | Please complete the following steps: 15 | 1. Move the 3 files under configs into the llamafactory directory after you have cloned it. 16 | 2. Add the following two entries to `LLaMA-Factory/data/dataset_info.json`: 17 | ```json 18 | "AceCodePair-300K": { 19 | "hf_hub_url": "TIGER-Lab/AceCodePair-300K", 20 | "ranking": true, 21 | "columns": { 22 | "prompt": "instruction", 23 | "query": "input", 24 | "chosen": "chosen", 25 | "rejected": "rejected" 26 | } 27 | }, 28 | "AceCodePair-QwenCoderIns32B": { 29 | "hf_hub_url": "TIGER-Lab/AceCodePair-QwenCoderIns32B", 30 | "ranking": true, 31 | "columns": { 32 | "prompt": "instruction", 33 | "query": "input", 34 | "chosen": "chosen", 35 | "rejected": "rejected" 36 | } 37 | } 38 | ``` 39 | 40 | ## Training 41 | 1. Change the `output_dir` field in the yaml files that you have copied for the desired model output path. 42 | 2. 
Run: 43 | ```bash 44 | llamafactory-cli train train_qwen_coder_ins_2.5_{7/32}b.yaml 45 | ``` 46 | -------------------------------------------------------------------------------- /train/train_rm/configs/ds_z3_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": true 17 | }, 18 | "zero_optimization": { 19 | "stage": 3, 20 | "overlap_comm": true, 21 | "contiguous_gradients": true, 22 | "sub_group_size": 1e9, 23 | "reduce_bucket_size": "auto", 24 | "stage3_prefetch_bucket_size": 15000000, 25 | "stage3_param_persistence_threshold": "auto", 26 | "stage3_max_live_parameters": 1e9, 27 | "stage3_max_reuse_distance": 1e9, 28 | "stage3_gather_16bit_weights_on_model_save": true 29 | } 30 | } -------------------------------------------------------------------------------- /train/train_rm/configs/train_qwen_coder_ins_2.5_32b.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: Qwen/Qwen2.5-Coder-32B-Instruct 3 | 4 | ### method 5 | stage: rm 6 | do_train: true 7 | finetuning_type: full 8 | deepspeed: ds_z3_config.json 9 | 10 | ### dataset 11 | dataset: AceCodePair-QwenCoderIns32B 12 | template: qwen 13 | cutoff_len: 1024 14 | overwrite_cache: true 15 | preprocessing_num_workers: 16 16 | 17 | ### output 18 | output_dir: # Replace this with the output directory 19 | logging_steps: 10 20 | save_steps: 300 21 | plot_loss: true 22 | overwrite_output_dir: true 23 | stage: rm 24 | 25 | ### train 26 | per_device_train_batch_size: 1 27 | gradient_accumulation_steps: 16 28 | learning_rate: 1.0e-5 29 | num_train_epochs: 1.0 30 | lr_scheduler_type: cosine 31 | warmup_ratio: 0.1 32 | bf16: true 33 | ddp_timeout: 480000000 34 | 35 | ### eval 36 | val_size: 0.005 37 | per_device_eval_batch_size: 1 38 | eval_strategy: steps 39 | eval_steps: 50 40 | -------------------------------------------------------------------------------- /train/train_rm/configs/train_qwen_coder_ins_2.5_7b.yaml: -------------------------------------------------------------------------------- 1 | ### model 2 | model_name_or_path: Qwen/Qwen2.5-Coder-7B-Instruct 3 | 4 | ### method 5 | stage: rm 6 | do_train: true 7 | finetuning_type: full 8 | deepspeed: ds_z3_config.json 9 | 10 | ### dataset 11 | dataset: AceCodePair-300K 12 | template: qwen 13 | cutoff_len: 2048 14 | overwrite_cache: true 15 | preprocessing_num_workers: 16 16 | 17 | ### output 18 | output_dir: # Replace this with the output directory 19 | logging_steps: 10 20 | save_steps: 300 21 | plot_loss: true 22 | overwrite_output_dir: true 23 | stage: rm 24 | 25 | ### train 26 | per_device_train_batch_size: 4 27 | gradient_accumulation_steps: 4 28 | learning_rate: 1.0e-5 29 | num_train_epochs: 1.0 30 | lr_scheduler_type: cosine 31 | warmup_ratio: 0.1 32 | bf16: true 33 | ddp_timeout: 480000000 34 | 35 | ### eval 36 | val_size: 0.005 37 | per_device_eval_batch_size: 4 38 | eval_strategy: steps 39 | eval_steps: 200 40 | --------------------------------------------------------------------------------