├── .gitignore ├── LICENSE ├── README.md ├── core ├── __init__.py ├── evaluation.py └── prompts.py ├── eval_codet5.py ├── eval_llama.py ├── eval_mpt.py ├── eval_mpt_large.py ├── eval_opencode.py ├── eval_openllama.py ├── eval_replit.py ├── eval_replit_glaive.py ├── eval_replit_instruct.py ├── eval_starcoder.py ├── eval_wizard.py ├── eval_xgen.py ├── human-eval ├── LICENSE ├── README.md ├── data │ └── HumanEval.jsonl.gz ├── human_eval │ ├── __init__.py │ ├── data.py │ ├── evaluate_functional_correctness.py │ ├── evaluation.py │ └── execution.py ├── requirements.txt └── setup.py ├── process_eval.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | *.jsonl 162 | scratch.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Anton Bacaj 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # code-eval 2 | 3 | ## What 4 | 5 | This is a repo I use to run human-eval on code models; adjust as needed. Some scripts were adapted from the WizardCoder repo ([process_eval.py](https://github.com/nlpxucan/WizardLM/blob/main/WizardCoder/src/process_humaneval.py)). The evaluation code is duplicated across several files, mostly to handle edge cases around model tokenization and loading (it will be cleaned up). 6 | 7 | ## Results 8 | 9 | The table is sorted by pass@1 score. 
10 | 11 | | model | size | pass@1 | pass@10 | screenshot | 12 | | ----------------------------------------------------------------------------------------------------- | ---- | ------- | ------- | ------------------------------------------------------------------------------------------------------------------ | 13 | | [sahil2801/replit-code-instruct-glaive](https://huggingface.co/sahil2801/replit-code-instruct-glaive) | 3B | 63.5% | 67% | ![instruct-glaive](https://github.com/abacaj/code-eval/assets/7272343/6fd7527d-0dc4-4b48-8a57-ad0373074bc5) | 14 | | [WizardCoder-15B-V1.0](https://huggingface.co/WizardLM/WizardCoder-15B-V1.0) | 15B | 57% | 68.9% | ![wizardcoder](https://github.com/abacaj/code-eval/assets/7272343/0b941ff8-b474-4236-bbc0-89d925bbd34e) | 15 | | [bigcode/starcoder](https://huggingface.co/bigcode/starcoder) | 15B | 34.6% | 48.7% | ![starcoder](https://github.com/abacaj/code-eval/assets/7272343/eb5df978-f56b-4557-a433-8b8fa863a059) | 16 | | [openchat/opencoderplus](https://huggingface.co/openchat/opencoderplus) | 15B | 27.3% | 43.9% | ![opencoder](https://github.com/abacaj/code-eval/assets/7272343/1fa9f5ef-941b-4ea8-981e-c3f258c03fee) | 17 | | [teknium/Replit-v1-CodeInstruct-3B](https://huggingface.co/teknium/Replit-v1-CodeInstruct-3B) | 3B | 25.8% | 42.6% | ![replit-codeinstruct-v1](https://github.com/abacaj/code-eval/assets/7272343/4fca98d8-2c22-43ce-9639-e998ecb4fedc) | 18 | | [teknium/Replit-v2-CodeInstruct-3B](https://huggingface.co/teknium/Replit-v2-CodeInstruct-3B) | 3B | 21.5% | 31% | ![replit-codeinstruct-v2](https://github.com/abacaj/code-eval/assets/7272343/655aaa1d-0715-4fcd-b9ba-a22b5fddb215) | 19 | | [replit-code-v1-3b](https://huggingface.co/replit/replit-code-v1-3b) | 3B | 17.1% | 29.8% | ![replit-code-v1](https://github.com/abacaj/code-eval/assets/7272343/6b387aa8-db60-4f04-b458-35b010b1145c) | 20 | | [mpt-7b](https://huggingface.co/mosaicml/mpt-7b) | 7B | 15.9% | 23.7% | ![mpt-7b](https://github.com/abacaj/code-eval/assets/7272343/16965905-a368-4254-aeab-5e44126eba84) | 21 | | [xgen-7b-8k-base](https://huggingface.co/Salesforce/xgen-7b-8k-base) | 7B | 14.9% | 22.5% | ![xgen-7b-8k-base](https://github.com/abacaj/code-eval/assets/7272343/995c84a9-ee69-43bf-8502-a74eba1d927a) | 22 | | [openllama-7b-v2](https://huggingface.co/openlm-research/open_llama_7b) | 7B | 14% | 23.1% | ![openllama-7b-v2](https://github.com/abacaj/code-eval/assets/7272343/e38f08a0-ae74-4c51-b3a7-638781477e1b) | 23 | | [llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b-hf) | 7B | 13.1% | 21.9% | ![llama-2-7b](https://github.com/abacaj/code-eval/assets/7272343/cc86cc7c-beac-4993-9ca3-d91a48a790e4) | 24 | | [llama-7b](https://huggingface.co/huggyllama/llama-7b) | 7B | 12.1% | 18.9% | ![llama-7b](https://github.com/abacaj/code-eval/assets/7272343/605a3c4e-0b2b-4c10-a185-f2a4d34ec10d) | 25 | | [mpt-30b](https://huggingface.co/mosaicml/mpt-30b) | 30B | pending | pending | pending | 26 | 27 | ## FAQ 28 | 29 | > Why is there a discrepancy between some of these scores and the officially reported numbers? 30 | 31 | Because the exact prompt and post-processing the official evaluations used on this benchmark are generally not published, so they have to be approximated. The goal here is to reproduce those numbers as closely as possible, and in many cases the results come very close to the published figures. 32 | 33 | All of the scores here were run independently of any published numbers and are reproducible by cloning the repo and following the setup. 34 | 35 | > Why do some models have a filter_code post-generation step? 36 | 37 | Base models often keep generating and repeat their output, which breaks the benchmark scores. Instruct models tend to emit an end-of-sequence token and don't have this problem, so you won't see this step for them. 
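A quick illustration of what that step does, using the helpers from `core/evaluation.py` (the completion string below is made up for demonstration purposes):

```python
from core import filter_code, fix_indents

# A base model will often keep generating after the first function body,
# e.g. restating the prompt or emitting unrelated extra functions.
raw_completion = (
    "\n    return a + b\n"
    "\n"
    "def add(a, b):\n"
    "    return a + b\n"
)

# filter_code strips leading newlines and keeps everything up to the first
# blank line, i.e. only the body that completes the prompted function.
print(repr(filter_code(raw_completion)))  # -> '    return a + b'

# fix_indents converts tab indentation back to spaces, for models whose
# prompts are tab-indented via run_eval's format_tabs option.
print(repr(fix_indents("\treturn a + b")))
```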
38 | 39 | ## Setup 40 | 41 | Create a Python environment 42 | 43 | ```sh 44 | python -m venv env && source env/bin/activate 45 | ``` 46 | 47 | Install dependencies 48 | 49 | ```sh 50 | pip install -r requirements.txt 51 | ``` 52 | 53 | Run the eval script 54 | 55 | ```sh 56 | # replace script file name for various models: 57 | # eval_wizard.py 58 | # eval_opencode.py 59 | # eval_mpt.py 60 | # eval_starcoder.py 61 | # eval_replit.py 62 | # eval_replit_glaive.py 63 | # eval_replit_instruct.py 64 | 65 | python eval_wizard.py 66 | ``` 67 | 68 | Process the jsonl file to extract code samples from the model completions. 69 | 70 | **Note**: Only wizard & opencoder require this step; they return markdown output with the code embedded. 71 | 72 | ```sh 73 | # replace args for various models: 74 | # --path results/wizard --out_path results/wizard/processed.jsonl 75 | # --path results/opencode --out_path results/opencode/processed.jsonl 76 | 77 | python process_eval.py --path results/wizard --out_path results/wizard/processed.jsonl --add_prompt 78 | ``` 79 | 80 | Then score the results 81 | 82 | ```sh 83 | # replace args for various models: 84 | # results/wizard/processed.jsonl 85 | # results/starcoder/eval.jsonl 86 | # results/mpt/eval.jsonl 87 | # results/opencode/processed.jsonl 88 | # results/replit_instruct/eval.jsonl 89 | # results/replit_glaive/eval.jsonl 90 | # results/replit/eval.jsonl 91 | 92 | evaluate_functional_correctness results/wizard/processed.jsonl 93 | ``` 94 |
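All of the `eval_*.py` scripts follow the same pattern: define a `generate_batch_completion` function and hand it to `run_eval` from `core`. As a rough sketch of wiring up another decoder-only model (the model name and results directory below are placeholders, and the generation settings mirror the existing scripts):

```python
import os

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from core import filter_code, fix_indents, run_eval

MODEL_NAME = "your-org/your-model"  # placeholder


@torch.inference_mode()
def generate_batch_completion(
    model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt, batch_size
) -> list[str]:
    # Repeat the prompt batch_size times so one call yields all samples for a task.
    inputs = tokenizer([prompt] * batch_size, return_tensors="pt").to(model.device)
    input_ids_cutoff = inputs.input_ids.size(dim=1)

    generated_ids = model.generate(
        **inputs,
        use_cache=True,
        max_new_tokens=512,
        temperature=0.2,
        top_p=0.95,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,  # in case the model has no pad token
    )

    # Drop the prompt tokens and keep only the first generated function.
    completions = tokenizer.batch_decode(
        [ids[input_ids_cutoff:] for ids in generated_ids],
        skip_special_tokens=True,
    )
    return [filter_code(fix_indents(c)) for c in completions]


if __name__ == "__main__":
    out_path = "results/your_model/eval.jsonl"  # placeholder
    os.makedirs("results/your_model", exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = (
        AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
        .eval()
        .to("cuda")
    )

    # 10 samples per task, format_tabs=True as in the llama/starcoder scripts.
    run_eval(model, tokenizer, 10, out_path, generate_batch_completion, True)
```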
-------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluation import run_eval, fix_indents, filter_code, split_batch 2 | from .prompts import instruct_prompt, standard_prompt, replit_glaive_prompt 3 | -------------------------------------------------------------------------------- /core/evaluation.py: -------------------------------------------------------------------------------- 1 | from human_eval.data import write_jsonl, read_problems 2 | from transformers import ( 3 | PreTrainedModel, 4 | PreTrainedTokenizer, 5 | ) 6 | from tqdm import tqdm 7 | import itertools 8 | import typing 9 | 10 | BatchGenerator = typing.Callable[ 11 | [PreTrainedModel, PreTrainedTokenizer, str, int], list[str] 12 | ] 13 | 14 | 15 | # reference: https://github.com/declare-lab/instruct-eval/blob/main/human_eval/main.py#L35 16 | def filter_code(completion: str) -> str: 17 | # The model tends to overgenerate; only keep the first function 18 | completion = completion.lstrip("\n") 19 | return completion.split("\n\n")[0] 20 | 21 | 22 | def fix_indents(text: str) -> str: 23 | return text.replace("\t", "    ") 24 | 25 | 26 | def split_batch(samples: list[str], size=4): 27 | mini_batches = [] 28 | 29 | for i in range(0, len(samples), size): 30 | mini_batches.append(samples[i : i + size]) 31 | 32 | return mini_batches 33 | 34 | 35 | def run_eval( 36 | model: PreTrainedModel, 37 | tokenizer: PreTrainedTokenizer, 38 | num_samples_per_task: int, 39 | out_path: str, 40 | generate_batch_completion: BatchGenerator, 41 | format_tabs: bool = False, 42 | ): 43 | problems = read_problems() 44 | # problems = dict(itertools.islice(problems.items(), 20)) 45 | samples = [] 46 | pbar = tqdm(total=len(problems) * num_samples_per_task) 47 | 48 | for task_id in problems: 49 | if format_tabs: 50 | prompt = problems[task_id]["prompt"].replace("    ", "\t") 51 | else: 52 | prompt = problems[task_id]["prompt"] 53 | 54 | batch_completions = generate_batch_completion( 55 | model, tokenizer, prompt, num_samples_per_task 56 | ) 57 | 58 | for sample in batch_completions: 59 | result = dict( 60 | task_id=task_id, 61 | completion=sample, 62 | ) 63 | 64 | samples += [result] 65 | 66 | pbar.update(num_samples_per_task) 67 | 68 | write_jsonl(out_path, samples) 69 | -------------------------------------------------------------------------------- /core/prompts.py: -------------------------------------------------------------------------------- 1 | def instruct_prompt(prompt: str) -> str: 2 | return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nComplete the following Python code without any tests or explanation\n{prompt}\n\n### Response:""" 3 | 4 | 5 | def standard_prompt(prompt: str) -> str: 6 | return f"""Complete the following Python code without any tests or explanation\n{prompt}""" 7 | 8 | 9 | def write_prompt(prompt: str) -> str: 10 | return f"""Write a python program to complete the following code:\n{prompt}""" 11 | 12 | 13 | def replit_glaive_prompt(prompt: str) -> str: 14 | return f"""Below is an instruction that describes a task, paired with an input that provides further context.\n Write a response that appropriately completes the request.\n\n ### Instruction:\nWrite a program to perform the given task.\n\n Input:\n{prompt}\n\n### Response:""" 15 | -------------------------------------------------------------------------------- /eval_codet5.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoTokenizer, 3 | AutoModelForSeq2SeqLM, 4 | PreTrainedModel, 5 | PreTrainedTokenizer, 6 | ) 7 | from core import run_eval, filter_code, fix_indents, standard_prompt 8 | import os 9 | import torch 10 | 11 | # TODO: move to python-dotenv 12 | # add hugging face access token here 13 | TOKEN = "" 14 | 15 | 16 | @torch.inference_mode() 17 | def generate_batch_completion( 18 | model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt, batch_size 19 | ) -> list[str]: 20 | prompt_input = standard_prompt(prompt) 21 | input_batch = [prompt_input for _ in range(batch_size)] 22 | inputs = tokenizer(input_batch, return_tensors="pt").to(model.device) 23 | 24 | generated_ids = model.generate( 25 | **inputs, 26 | use_cache=True, 27 | max_new_tokens=512, 28 | temperature=0.2, 29 | top_p=0.95, 30 | do_sample=True, 31 | eos_token_id=tokenizer.eos_token_id, 32 | pad_token_id=tokenizer.pad_token_id, 33 | ) 34 | 35 | batch_completions = tokenizer.batch_decode( 36 | generated_ids, 37 | skip_special_tokens=True, 38 | ) 39 | 40 | return [filter_code(fix_indents(completion)) for completion in batch_completions] 41 | 42 | 43 | if __name__ == "__main__": 44 | # adjust for n = 10 etc 45 | num_samples_per_task = 10 46 | out_path = "results/codet5p_770/eval.jsonl" 47 | os.makedirs("results/codet5p_770", exist_ok=True) 48 | 49 | tokenizer = AutoTokenizer.from_pretrained( 50 | "Salesforce/codet5p-770m", 51 | trust_remote_code=True, 52 | use_auth_token=TOKEN, 53 | ) 54 | 55 | model = torch.compile( 56 | AutoModelForSeq2SeqLM.from_pretrained( 57 | "Salesforce/codet5p-770m", 58 | torch_dtype=torch.bfloat16, 59 | trust_remote_code=True, 60 | use_auth_token=TOKEN, 61 | ) 62 | .eval() 63 | .to("cuda:5") 64 | ) 65 | 66 | run_eval( 67 | model, 68 | tokenizer, 69
num_samples_per_task, 70 | out_path, 71 | generate_batch_completion, 72 | True, 73 | ) 74 | -------------------------------------------------------------------------------- /eval_llama.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | LlamaTokenizer, 3 | LlamaForCausalLM, 4 | PreTrainedModel, 5 | PreTrainedTokenizer, 6 | ) 7 | from core import filter_code, run_eval, fix_indents 8 | import os 9 | import torch 10 | 11 | # TODO: move to python-dotenv 12 | # add hugging face access token here 13 | TOKEN = "" 14 | 15 | 16 | @torch.inference_mode() 17 | def generate_batch_completion( 18 | model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt, batch_size 19 | ) -> list[str]: 20 | input_batch = [prompt for _ in range(batch_size)] 21 | inputs = tokenizer(input_batch, return_tensors="pt").to(model.device) 22 | input_ids_cutoff = inputs.input_ids.size(dim=1) 23 | 24 | generated_ids = model.generate( 25 | **inputs, 26 | use_cache=True, 27 | max_new_tokens=512, 28 | temperature=0.2, 29 | top_p=0.95, 30 | do_sample=True, 31 | eos_token_id=tokenizer.eos_token_id, 32 | pad_token_id=tokenizer.eos_token_id, # model has no pad token 33 | ) 34 | 35 | batch_completions = tokenizer.batch_decode( 36 | [ids[input_ids_cutoff:] for ids in generated_ids], 37 | skip_special_tokens=True, 38 | ) 39 | 40 | return [filter_code(fix_indents(completion)) for completion in batch_completions] 41 | 42 | 43 | if __name__ == "__main__": 44 | # adjust for n = 10 etc 45 | num_samples_per_task = 10 46 | out_path = "results/llama/eval.jsonl" 47 | os.makedirs("results/llama", exist_ok=True) 48 | 49 | tokenizer = LlamaTokenizer.from_pretrained( 50 | "huggyllama/llama-7b", 51 | ) 52 | 53 | model = torch.compile( 54 | LlamaForCausalLM.from_pretrained( 55 | "huggyllama/llama-7b", 56 | torch_dtype=torch.bfloat16, 57 | ) 58 | .eval() 59 | .to("cuda") 60 | ) 61 | 62 | run_eval( 63 | model, 64 | tokenizer, 65 | num_samples_per_task, 66 | out_path, 67 | generate_batch_completion, 68 | True, 69 | ) 70 | -------------------------------------------------------------------------------- /eval_mpt.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoTokenizer, 3 | AutoModelForCausalLM, 4 | PreTrainedModel, 5 | PreTrainedTokenizer, 6 | ) 7 | from core import filter_code, run_eval 8 | import os 9 | import torch 10 | 11 | # TODO: move to python-dotenv 12 | # add hugging face access token here 13 | TOKEN = "" 14 | 15 | 16 | @torch.inference_mode() 17 | def generate_batch_completion( 18 | model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt, batch_size 19 | ) -> list[str]: 20 | input_batch = [prompt for _ in range(batch_size)] 21 | inputs = tokenizer(input_batch, return_tensors="pt").to(model.device) 22 | input_ids_cutoff = inputs.input_ids.size(dim=1) 23 | 24 | generated_ids = model.generate( 25 | **inputs, 26 | use_cache=True, 27 | max_new_tokens=512, 28 | temperature=0.2, 29 | top_p=0.95, 30 | do_sample=True, 31 | eos_token_id=tokenizer.eos_token_id, 32 | pad_token_id=tokenizer.eos_token_id, # model has no pad token 33 | ) 34 | 35 | batch_completions = tokenizer.batch_decode( 36 | [ids[input_ids_cutoff:] for ids in generated_ids], 37 | skip_special_tokens=True, 38 | ) 39 | 40 | return [filter_code(completion) for completion in batch_completions] 41 | 42 | 43 | if __name__ == "__main__": 44 | # adjust for n = 10 etc 45 | num_samples_per_task = 10 46 | out_path = "results/mpt/eval.jsonl" 47 | 
os.makedirs("results/mpt", exist_ok=True) 48 | 49 | tokenizer = AutoTokenizer.from_pretrained( 50 | "mosaicml/mpt-7b", 51 | trust_remote_code=True, 52 | use_auth_token=TOKEN, 53 | ) 54 | 55 | model = torch.compile( 56 | AutoModelForCausalLM.from_pretrained( 57 | "mosaicml/mpt-7b", 58 | torch_dtype=torch.bfloat16, 59 | trust_remote_code=True, 60 | use_auth_token=TOKEN, 61 | init_device="cuda", 62 | ).eval() 63 | ) 64 | 65 | run_eval( 66 | model, 67 | tokenizer, 68 | num_samples_per_task, 69 | out_path, 70 | generate_batch_completion, 71 | ) 72 | -------------------------------------------------------------------------------- /eval_mpt_large.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoTokenizer, 3 | AutoModelForCausalLM, 4 | PreTrainedModel, 5 | PreTrainedTokenizer, 6 | ) 7 | from core import filter_code, run_eval, split_batch 8 | import os 9 | import torch 10 | 11 | # TODO: move to python-dotenv 12 | # add hugging face access token here 13 | TOKEN = "" 14 | 15 | 16 | @torch.inference_mode() 17 | def generate_batch_completion( 18 | model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt, batch_size 19 | ) -> list[str]: 20 | input_batch = [prompt for _ in range(batch_size)] 21 | mini_batch = split_batch(input_batch, 2) 22 | batch_completions = [] 23 | 24 | for batch in mini_batch: 25 | inputs = tokenizer(batch, return_tensors="pt").to(model.device) 26 | input_ids_cutoff = inputs.input_ids.size(dim=1) 27 | 28 | generated_ids = model.generate( 29 | **inputs, 30 | use_cache=True, 31 | max_new_tokens=512, 32 | temperature=0.2, 33 | top_p=0.95, 34 | do_sample=True, 35 | eos_token_id=tokenizer.eos_token_id, 36 | pad_token_id=tokenizer.eos_token_id, # model has no pad token 37 | ) 38 | 39 | batch_completions += tokenizer.batch_decode( 40 | [ids[input_ids_cutoff:] for ids in generated_ids], 41 | skip_special_tokens=True, 42 | ) 43 | 44 | return [filter_code(completion) for completion in batch_completions] 45 | 46 | 47 | if __name__ == "__main__": 48 | # adjust for n = 10 etc 49 | num_samples_per_task = 10 50 | out_path = "results/mpt_large/eval.jsonl" 51 | os.makedirs("results/mpt_large", exist_ok=True) 52 | 53 | tokenizer = AutoTokenizer.from_pretrained( 54 | "mosaicml/mpt-30b", 55 | trust_remote_code=True, 56 | use_auth_token=TOKEN, 57 | ) 58 | 59 | model = torch.compile( 60 | AutoModelForCausalLM.from_pretrained( 61 | "mosaicml/mpt-30b", 62 | torch_dtype=torch.bfloat16, 63 | trust_remote_code=True, 64 | use_auth_token=TOKEN, 65 | device_map="auto", 66 | max_memory={ 67 | 0: "20GiB", 68 | 1: "20GiB", 69 | 2: "20GiB", 70 | }, 71 | ).eval() 72 | ) 73 | 74 | run_eval( 75 | model, 76 | tokenizer, 77 | num_samples_per_task, 78 | out_path, 79 | generate_batch_completion, 80 | ) 81 | -------------------------------------------------------------------------------- /eval_opencode.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoTokenizer, 3 | GPTBigCodeForCausalLM, 4 | PreTrainedModel, 5 | PreTrainedTokenizer, 6 | ) 7 | from core import run_eval, standard_prompt 8 | import os 9 | import torch 10 | 11 | # TODO: move to python-dotenv 12 | # add hugging face access token here 13 | TOKEN = "" 14 | 15 | 16 | @torch.inference_mode() 17 | def generate_batch_completion( 18 | model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt: str, batch_size: int 19 | ) -> list[str]: 20 | batch_input = [tokenize_opencode(tokenizer, prompt) for _ in 
range(batch_size)] 21 | inputs = convert_to_tensors(batch_input, model.device) 22 | input_ids_cutoff = inputs["input_ids"].size(dim=1) 23 | 24 | generated_ids = model.generate( 25 | **inputs, 26 | use_cache=True, 27 | max_new_tokens=512, 28 | temperature=0.2, 29 | top_p=0.95, 30 | do_sample=True, 31 | eos_token_id=tokenizer.eos_token_id, 32 | pad_token_id=tokenizer.eos_token_id, # model has no pad token 33 | ) 34 | 35 | batch_completions = tokenizer.batch_decode( 36 | [ids[input_ids_cutoff:] for ids in generated_ids], 37 | skip_special_tokens=True, 38 | ) 39 | 40 | return batch_completions 41 | 42 | 43 | def tokenize_opencode(tokenizer: PreTrainedTokenizer, prompt: str): 44 | input_ids = [] 45 | attention_mask = [] 46 | 47 | # verbose, but follows what is shown in the readme 48 | user = tokenizer("User:") 49 | prompt_text = tokenizer(standard_prompt(prompt)) 50 | eot_token = tokenizer("<|end_of_turn|>") 51 | assistant = tokenizer("Assistant:") 52 | 53 | # verbose, but follows what is shown in the readme 54 | input_ids += user.input_ids 55 | input_ids += prompt_text.input_ids 56 | input_ids += eot_token.input_ids 57 | input_ids += assistant.input_ids 58 | 59 | # verbose, but follows what is shown in the readme 60 | attention_mask += user.attention_mask 61 | attention_mask += prompt_text.attention_mask 62 | attention_mask += eot_token.attention_mask 63 | attention_mask += assistant.attention_mask 64 | 65 | return { 66 | "input_ids": input_ids, 67 | "attention_mask": attention_mask, 68 | } 69 | 70 | 71 | def convert_to_tensors(opencode_tokens: list[dict], device: torch.device): 72 | input_ids = [tokens["input_ids"] for tokens in opencode_tokens] 73 | attention_mask = [tokens["attention_mask"] for tokens in opencode_tokens] 74 | 75 | return { 76 | "input_ids": torch.tensor(input_ids).to(device), 77 | "attention_mask": torch.tensor(attention_mask).to(device), 78 | } 79 | 80 | 81 | if __name__ == "__main__": 82 | # adjust for n = 10 etc 83 | num_samples_per_task = 10 84 | out_path = "results/opencode/eval.jsonl" 85 | os.makedirs("results/opencode", exist_ok=True) 86 | 87 | tokenizer = AutoTokenizer.from_pretrained( 88 | "openchat/opencoderplus", 89 | use_auth_token=TOKEN, 90 | ) 91 | 92 | model = torch.compile( 93 | GPTBigCodeForCausalLM.from_pretrained( 94 | "openchat/opencoderplus", 95 | device_map="auto", 96 | torch_dtype=torch.bfloat16, 97 | max_memory={ 98 | 0: "18GiB", 99 | 1: "18GiB", 100 | }, 101 | use_auth_token=TOKEN, 102 | ).eval() 103 | ) 104 | 105 | run_eval( 106 | model, 107 | tokenizer, 108 | num_samples_per_task, 109 | out_path, 110 | generate_batch_completion, 111 | True, 112 | ) 113 | -------------------------------------------------------------------------------- /eval_openllama.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | LlamaTokenizer, 3 | LlamaForCausalLM, 4 | PreTrainedModel, 5 | PreTrainedTokenizer, 6 | ) 7 | from core import filter_code, run_eval, fix_indents 8 | import os 9 | import torch 10 | 11 | # TODO: move to python-dotenv 12 | # add hugging face access token here 13 | TOKEN = "" 14 | 15 | 16 | @torch.inference_mode() 17 | def generate_batch_completion( 18 | model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt, batch_size 19 | ) -> list[str]: 20 | input_batch = [prompt for _ in range(batch_size)] 21 | inputs = tokenizer(input_batch, return_tensors="pt").to(model.device) 22 | input_ids_cutoff = inputs.input_ids.size(dim=1) 23 | 24 | generated_ids = model.generate( 25 | **inputs, 26 
| use_cache=True, 27 | max_new_tokens=512, 28 | temperature=0.2, 29 | top_p=0.95, 30 | do_sample=True, 31 | eos_token_id=tokenizer.eos_token_id, 32 | pad_token_id=tokenizer.eos_token_id, # model has no pad token 33 | ) 34 | 35 | batch_completions = tokenizer.batch_decode( 36 | [ids[input_ids_cutoff:] for ids in generated_ids], 37 | skip_special_tokens=True, 38 | ) 39 | 40 | return [filter_code(fix_indents(completion)) for completion in batch_completions] 41 | 42 | 43 | if __name__ == "__main__": 44 | # adjust for n = 10 etc 45 | num_samples_per_task = 10 46 | out_path = "results/openllama/eval.jsonl" 47 | os.makedirs("results/openllama", exist_ok=True) 48 | 49 | tokenizer = LlamaTokenizer.from_pretrained( 50 | "openlm-research/open_llama_7b_v2", 51 | ) 52 | 53 | model = torch.compile( 54 | LlamaForCausalLM.from_pretrained( 55 | "openlm-research/open_llama_7b_v2", 56 | torch_dtype=torch.bfloat16, 57 | ) 58 | .eval() 59 | .to("cuda") 60 | ) 61 | 62 | run_eval( 63 | model, 64 | tokenizer, 65 | num_samples_per_task, 66 | out_path, 67 | generate_batch_completion, 68 | True, 69 | ) 70 | -------------------------------------------------------------------------------- /eval_replit.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoTokenizer, 3 | AutoModelForCausalLM, 4 | PreTrainedModel, 5 | PreTrainedTokenizer, 6 | ) 7 | from core import run_eval, filter_code 8 | import os 9 | import torch 10 | 11 | # TODO: move to python-dotenv 12 | # add hugging face access token here 13 | TOKEN = "" 14 | 15 | 16 | @torch.inference_mode() 17 | def generate_batch_completion( 18 | model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt, batch_size 19 | ) -> list[str]: 20 | input_batch = [prompt for _ in range(batch_size)] 21 | inputs = tokenizer(input_batch, return_tensors="pt").to(model.device) 22 | input_ids_cutoff = inputs.input_ids.size(dim=1) 23 | 24 | generated_ids = model.generate( 25 | **inputs, 26 | use_cache=True, 27 | max_new_tokens=512, 28 | temperature=0.2, 29 | top_p=0.95, 30 | do_sample=True, 31 | eos_token_id=tokenizer.eos_token_id, 32 | pad_token_id=tokenizer.pad_token_id, 33 | ) 34 | 35 | batch_completions = tokenizer.batch_decode( 36 | [ids[input_ids_cutoff:] for ids in generated_ids], 37 | skip_special_tokens=True, 38 | clean_up_tokenization_spaces=False, 39 | ) 40 | 41 | return [filter_code(completion) for completion in batch_completions] 42 | 43 | 44 | if __name__ == "__main__": 45 | # adjust for n = 10 etc 46 | num_samples_per_task = 10 47 | out_path = "results/replit/eval.jsonl" 48 | os.makedirs("results/replit", exist_ok=True) 49 | 50 | tokenizer = AutoTokenizer.from_pretrained( 51 | "replit/replit-code-v1-3b", 52 | trust_remote_code=True, 53 | use_auth_token=TOKEN, 54 | ) 55 | 56 | model = torch.compile( 57 | AutoModelForCausalLM.from_pretrained( 58 | "replit/replit-code-v1-3b", 59 | torch_dtype=torch.bfloat16, 60 | trust_remote_code=True, 61 | use_auth_token=TOKEN, 62 | init_device="cuda", 63 | ).eval() 64 | ) 65 | 66 | run_eval( 67 | model, 68 | tokenizer, 69 | num_samples_per_task, 70 | out_path, 71 | generate_batch_completion, 72 | ) 73 | -------------------------------------------------------------------------------- /eval_replit_glaive.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoTokenizer, 3 | AutoModelForCausalLM, 4 | PreTrainedTokenizer, 5 | PreTrainedModel, 6 | ) 7 | from core import run_eval, replit_glaive_prompt 8 | 
import os 9 | import torch 10 | 11 | # TODO: move to python-dotenv 12 | # add hugging face access token here 13 | TOKEN = "" 14 | 15 | 16 | @torch.inference_mode() 17 | def generate_batch_completion( 18 | model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt: str, batch_size: int 19 | ) -> list[str]: 20 | prompt_input = replit_glaive_prompt(prompt) 21 | input_batch = [prompt_input for _ in range(batch_size)] 22 | inputs = tokenizer(input_batch, return_tensors="pt").to(model.device) 23 | input_ids_cutoff = inputs.input_ids.size(dim=1) 24 | 25 | generated_ids = model.generate( 26 | **inputs, 27 | use_cache=True, 28 | max_new_tokens=512, 29 | temperature=0.2, 30 | top_p=0.95, 31 | do_sample=True, 32 | eos_token_id=tokenizer.eos_token_id, 33 | pad_token_id=tokenizer.pad_token_id, 34 | ) 35 | 36 | batch_completions = tokenizer.batch_decode( 37 | [ids[input_ids_cutoff:] for ids in generated_ids], 38 | skip_special_tokens=True, 39 | clean_up_tokenization_spaces=False, 40 | ) 41 | 42 | return batch_completions 43 | 44 | 45 | if __name__ == "__main__": 46 | # adjust for n = 10 etc 47 | num_samples_per_task = 10 48 | out_path = "results/replit_glaive/eval.jsonl" 49 | os.makedirs("results/replit_glaive", exist_ok=True) 50 | 51 | tokenizer = AutoTokenizer.from_pretrained( 52 | "sahil2801/replit-code-instruct-glaive", 53 | trust_remote_code=True, 54 | use_auth_token=TOKEN, 55 | ) 56 | 57 | model = torch.compile( 58 | AutoModelForCausalLM.from_pretrained( 59 | "sahil2801/replit-code-instruct-glaive", 60 | torch_dtype=torch.bfloat16, 61 | trust_remote_code=True, 62 | use_auth_token=TOKEN, 63 | init_device="cuda", 64 | ).eval() 65 | ) 66 | 67 | run_eval( 68 | model, tokenizer, num_samples_per_task, out_path, generate_batch_completion 69 | ) 70 | -------------------------------------------------------------------------------- /eval_replit_instruct.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoTokenizer, 3 | AutoModelForCausalLM, 4 | PreTrainedTokenizer, 5 | PreTrainedModel, 6 | ) 7 | from core import run_eval, instruct_prompt 8 | import os 9 | import torch 10 | 11 | # TODO: move to python-dotenv 12 | # add hugging face access token here 13 | TOKEN = "" 14 | 15 | 16 | @torch.inference_mode() 17 | def generate_batch_completion( 18 | model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt: str, batch_size: int 19 | ) -> list[str]: 20 | prompt_input = instruct_prompt(prompt) 21 | input_batch = [prompt_input for _ in range(batch_size)] 22 | inputs = tokenizer(input_batch, return_tensors="pt").to(model.device) 23 | input_ids_cutoff = inputs.input_ids.size(dim=1) 24 | 25 | generated_ids = model.generate( 26 | **inputs, 27 | use_cache=True, 28 | max_new_tokens=512, 29 | temperature=0.2, 30 | top_p=0.95, 31 | do_sample=True, 32 | eos_token_id=tokenizer.eos_token_id, 33 | pad_token_id=tokenizer.pad_token_id, 34 | ) 35 | 36 | batch_completions = tokenizer.batch_decode( 37 | [ids[input_ids_cutoff:] for ids in generated_ids], 38 | skip_special_tokens=True, 39 | clean_up_tokenization_spaces=False, 40 | ) 41 | 42 | return batch_completions 43 | 44 | 45 | if __name__ == "__main__": 46 | # adjust for n = 10 etc 47 | num_samples_per_task = 10 48 | out_path = "results/replit_instruct/eval.jsonl" 49 | os.makedirs("results/replit_instruct", exist_ok=True) 50 | 51 | tokenizer = AutoTokenizer.from_pretrained( 52 | "teknium/Replit-v1-CodeInstruct-3B", 53 | trust_remote_code=True, 54 | use_auth_token=TOKEN, 55 | ) 56 | 57 | model = 
torch.compile( 58 | AutoModelForCausalLM.from_pretrained( 59 | "teknium/Replit-v1-CodeInstruct-3B", 60 | torch_dtype=torch.bfloat16, 61 | trust_remote_code=True, 62 | use_auth_token=TOKEN, 63 | init_device="cuda", 64 | ).eval() 65 | ) 66 | 67 | run_eval( 68 | model, tokenizer, num_samples_per_task, out_path, generate_batch_completion 69 | ) 70 | -------------------------------------------------------------------------------- /eval_starcoder.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoTokenizer, 3 | GPTBigCodeForCausalLM, 4 | PreTrainedTokenizer, 5 | PreTrainedModel, 6 | ) 7 | from core import run_eval, filter_code, fix_indents 8 | import os 9 | import torch 10 | 11 | # TODO: move to python-dotenv 12 | # add hugging face access token here 13 | TOKEN = "" 14 | 15 | 16 | @torch.inference_mode() 17 | def generate_batch_completion( 18 | model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt: str, batch_size: int 19 | ) -> list[str]: 20 | input_batch = [prompt for _ in range(batch_size)] 21 | inputs = tokenizer(input_batch, return_tensors="pt").to(model.device) 22 | input_ids_cutoff = inputs.input_ids.size(dim=1) 23 | 24 | generated_ids = model.generate( 25 | **inputs, 26 | use_cache=True, 27 | max_new_tokens=512, 28 | temperature=0.2, 29 | top_p=0.95, 30 | do_sample=True, 31 | eos_token_id=tokenizer.eos_token_id, 32 | pad_token_id=tokenizer.eos_token_id, # model has no pad token 33 | ) 34 | 35 | batch_completions = tokenizer.batch_decode( 36 | [ids[input_ids_cutoff:] for ids in generated_ids], 37 | skip_special_tokens=True, 38 | ) 39 | 40 | # fix_indents is required to fix the tab character that is generated from starcoder model 41 | return [filter_code(fix_indents(completion)) for completion in batch_completions] 42 | 43 | 44 | if __name__ == "__main__": 45 | # adjust for n = 10 etc 46 | num_samples_per_task = 10 47 | out_path = "results/starcoder/eval.jsonl" 48 | os.makedirs("results/starcoder", exist_ok=True) 49 | 50 | tokenizer = AutoTokenizer.from_pretrained( 51 | "bigcode/starcoder", 52 | trust_remote_code=True, 53 | use_auth_token=TOKEN, 54 | ) 55 | 56 | model = torch.compile( 57 | GPTBigCodeForCausalLM.from_pretrained( 58 | "bigcode/starcoder", 59 | device_map="auto", 60 | torch_dtype=torch.bfloat16, 61 | trust_remote_code=True, 62 | max_memory={ 63 | 0: "18GiB", 64 | 1: "18GiB", 65 | }, 66 | use_auth_token=TOKEN, 67 | ).eval() 68 | ) 69 | 70 | run_eval( 71 | model, 72 | tokenizer, 73 | num_samples_per_task, 74 | out_path, 75 | generate_batch_completion, 76 | True, 77 | ) 78 | -------------------------------------------------------------------------------- /eval_wizard.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoTokenizer, 3 | GPTBigCodeForCausalLM, 4 | PreTrainedTokenizer, 5 | PreTrainedModel, 6 | ) 7 | from core import run_eval, instruct_prompt 8 | import os 9 | import torch 10 | 11 | # TODO: move to python-dotenv 12 | # add hugging face access token here 13 | TOKEN = "" 14 | 15 | 16 | @torch.inference_mode() 17 | def generate_batch_completion( 18 | model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt: str, batch_size: int 19 | ) -> list[str]: 20 | prompt_input = instruct_prompt(prompt) 21 | input_batch = [prompt_input for _ in range(batch_size)] 22 | inputs = tokenizer(input_batch, return_tensors="pt").to(model.device) 23 | input_ids_cutoff = inputs.input_ids.size(dim=1) 24 | 25 | generated_ids = 
model.generate( 26 | **inputs, 27 | use_cache=True, 28 | max_new_tokens=512, 29 | temperature=0.2, 30 | top_p=0.95, 31 | do_sample=True, 32 | eos_token_id=tokenizer.eos_token_id, 33 | pad_token_id=tokenizer.pad_token_id, 34 | ) 35 | 36 | batch_completions = tokenizer.batch_decode( 37 | [ids[input_ids_cutoff:] for ids in generated_ids], 38 | skip_special_tokens=True, 39 | ) 40 | 41 | return batch_completions 42 | 43 | 44 | if __name__ == "__main__": 45 | # adjust for n = 10 etc 46 | num_samples_per_task = 10 47 | out_path = "results/wizard/eval.jsonl" 48 | os.makedirs("results/wizard", exist_ok=True) 49 | 50 | tokenizer = AutoTokenizer.from_pretrained( 51 | "WizardLM/WizardCoder-15B-V1.0", 52 | use_auth_token=TOKEN, 53 | ) 54 | 55 | model = torch.compile( 56 | GPTBigCodeForCausalLM.from_pretrained( 57 | "WizardLM/WizardCoder-15B-V1.0", 58 | device_map="auto", 59 | torch_dtype=torch.bfloat16, 60 | max_memory={ 61 | 0: "18GiB", 62 | 1: "18GiB", 63 | }, 64 | use_auth_token=TOKEN, 65 | ).eval() 66 | ) 67 | 68 | run_eval( 69 | model, 70 | tokenizer, 71 | num_samples_per_task, 72 | out_path, 73 | generate_batch_completion, 74 | True, 75 | ) 76 | -------------------------------------------------------------------------------- /eval_xgen.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoTokenizer, 3 | AutoModelForCausalLM, 4 | PreTrainedModel, 5 | PreTrainedTokenizer, 6 | ) 7 | from core import filter_code, run_eval 8 | import os 9 | import torch 10 | 11 | # TODO: move to python-dotenv 12 | # add hugging face access token here 13 | TOKEN = "" 14 | 15 | 16 | @torch.inference_mode() 17 | def generate_batch_completion( 18 | model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt, batch_size 19 | ) -> list[str]: 20 | input_batch = [prompt for _ in range(batch_size)] 21 | inputs = tokenizer(input_batch, return_tensors="pt").to(model.device) 22 | input_ids_cutoff = inputs.input_ids.size(dim=1) 23 | 24 | generated_ids = model.generate( 25 | **inputs, 26 | use_cache=True, 27 | max_new_tokens=512, 28 | temperature=0.2, 29 | top_p=0.95, 30 | do_sample=True, 31 | eos_token_id=tokenizer.eos_token_id, 32 | pad_token_id=tokenizer.eos_token_id, # model has no pad token 33 | ) 34 | 35 | batch_completions = tokenizer.batch_decode( 36 | [ids[input_ids_cutoff:] for ids in generated_ids], 37 | skip_special_tokens=True, 38 | ) 39 | 40 | return [filter_code(completion) for completion in batch_completions] 41 | 42 | 43 | if __name__ == "__main__": 44 | # adjust for n = 10 etc 45 | num_samples_per_task = 10 46 | out_path = "results/xgen/eval.jsonl" 47 | os.makedirs("results/xgen", exist_ok=True) 48 | 49 | tokenizer = AutoTokenizer.from_pretrained( 50 | "Salesforce/xgen-7b-8k-base", 51 | trust_remote_code=True, 52 | use_auth_token=TOKEN, 53 | ) 54 | 55 | model = torch.compile( 56 | AutoModelForCausalLM.from_pretrained( 57 | "Salesforce/xgen-7b-8k-base", 58 | torch_dtype=torch.bfloat16, 59 | trust_remote_code=True, 60 | use_auth_token=TOKEN, 61 | ) 62 | .eval() 63 | .to("cuda") 64 | ) 65 | 66 | run_eval( 67 | model, 68 | tokenizer, 69 | num_samples_per_task, 70 | out_path, 71 | generate_batch_completion, 72 | ) 73 | -------------------------------------------------------------------------------- /human-eval/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) OpenAI (https://openai.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a 
copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /human-eval/README.md: -------------------------------------------------------------------------------- 1 | # HumanEval: Hand-Written Evaluation Set 2 | 3 | This is an evaluation harness for the HumanEval problem solving dataset 4 | described in the paper "[Evaluating Large Language Models Trained on 5 | Code](https://arxiv.org/abs/2107.03374)". 6 | 7 | ## Installation 8 | 9 | Make sure to use python 3.7 or later: 10 | ``` 11 | $ conda create -n codex python=3.7 12 | $ conda activate codex 13 | ``` 14 | 15 | Check out and install this repository: 16 | ``` 17 | $ git clone https://github.com/openai/human-eval 18 | $ pip install -e human-eval 19 | ``` 20 | 21 | ## Usage 22 | 23 | **This program exists to run untrusted model-generated code. Users are strongly 24 | encouraged not to do so outside of a robust security sandbox. The [execution 25 | call](https://github.com/openai/human-eval/blob/master/human_eval/execution.py#L48-L58) 26 | in `execution.py` is deliberately commented out to ensure users read this 27 | disclaimer before running code in a potentially unsafe manner. See the comment in 28 | `execution.py` for more information and instructions.** 29 | 30 | After following the above instructions to enable execution, generate samples 31 | and save them in the following JSON Lines (jsonl) format, where each sample is 32 | formatted into a single line like so: 33 | ``` 34 | {"task_id": "Corresponding HumanEval task ID", "completion": "Completion only without the prompt"} 35 | ``` 36 | We provide `example_problem.jsonl` and `example_solutions.jsonl` under `data` 37 | to illustrate the format and help with debugging. 38 | 39 | Here is nearly functional example code (you just have to provide 40 | `generate_one_completion` to make it work) that saves generated completions to 41 | `samples.jsonl`. 42 | ``` 43 | from human_eval.data import write_jsonl, read_problems 44 | 45 | problems = read_problems() 46 | 47 | num_samples_per_task = 200 48 | samples = [ 49 | dict(task_id=task_id, completion=generate_one_completion(problems[task_id]["prompt"])) 50 | for task_id in problems 51 | for _ in range(num_samples_per_task) 52 | ] 53 | write_jsonl("samples.jsonl", samples) 54 | ``` 55 | 56 | To evaluate the samples, run 57 | ``` 58 | $ evaluate_functional_correctness samples.jsonl 59 | Reading samples... 60 | 32800it [00:01, 23787.50it/s] 61 | Running test suites... 
62 | 100%|...| 32800/32800 [16:11<00:00, 33.76it/s] 63 | Writing results to samples.jsonl_results.jsonl... 64 | 100%|...| 32800/32800 [00:00<00:00, 42876.84it/s] 65 | {'pass@1': ..., 'pass@10': ..., 'pass@100': ...} 66 | ``` 67 | This script provides more fine-grained information in a new file ending in 68 | `_results.jsonl`. Each row now contains whether the completion 69 | `passed` along with the execution `result` which is one of "passed", "timed 70 | out", or "failed". 71 | 72 | As a quick sanity-check, the example samples should yield 0.5 pass@1. 73 | ``` 74 | $ evaluate_functional_correctness data/example_samples.jsonl --problem_file=data/example_problem.jsonl 75 | Reading samples... 76 | 6it [00:00, 3397.11it/s] 77 | Running example suites... 78 | 100%|...| 6/6 [00:03<00:00, 1.96it/s] 79 | Writing results to data/example_samples.jsonl_results.jsonl... 80 | 100%|...| 6/6 [00:00<00:00, 6148.50it/s] 81 | {'pass@1': 0.4999999999999999} 82 | ``` 83 | 84 | Because there is no unbiased way of estimating pass@k when there are fewer 85 | samples than k, the script does not evaluate pass@k for these cases. To 86 | evaluate with other k values, pass `--k=`. For 87 | other options, see 88 | ``` 89 | $ evaluate_functional_correctness --help 90 | ``` 91 | However, we recommend that you use the default values for the rest. 92 | 93 | ## Known Issues 94 | 95 | While evaluation uses very little memory, you might see the following error 96 | message when the system is running out of RAM. Since this may cause some 97 | correct programs to fail, we recommend that you free some memory and try again. 98 | ``` 99 | malloc: can't allocate region 100 | ``` 101 | 102 | ## Citation 103 | 104 | Please cite using the following bibtex entry: 105 | 106 | ``` 107 | @article{chen2021codex, 108 | title={Evaluating Large Language Models Trained on Code}, 109 | author={Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. 
Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba}, 110 | year={2021}, 111 | eprint={2107.03374}, 112 | archivePrefix={arXiv}, 113 | primaryClass={cs.LG} 114 | } 115 | ``` 116 | -------------------------------------------------------------------------------- /human-eval/data/HumanEval.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abacaj/code-eval/602c49d9fdb17495c6730bba3e92a315e9cbfd54/human-eval/data/HumanEval.jsonl.gz -------------------------------------------------------------------------------- /human-eval/human_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abacaj/code-eval/602c49d9fdb17495c6730bba3e92a315e9cbfd54/human-eval/human_eval/__init__.py -------------------------------------------------------------------------------- /human-eval/human_eval/data.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Dict 2 | import gzip 3 | import json 4 | import os 5 | 6 | 7 | ROOT = os.path.dirname(os.path.abspath(__file__)) 8 | HUMAN_EVAL = os.path.join(ROOT, "..", "data", "HumanEval.jsonl.gz") 9 | 10 | 11 | def read_problems(evalset_file: str = HUMAN_EVAL) -> Dict[str, Dict]: 12 | return {task["task_id"]: task for task in stream_jsonl(evalset_file)} 13 | 14 | 15 | def stream_jsonl(filename: str) -> Iterable[Dict]: 16 | """ 17 | Parses each jsonl line and yields it as a dictionary 18 | """ 19 | if filename.endswith(".gz"): 20 | with open(filename, "rb") as gzfp: 21 | with gzip.open(gzfp, 'rt') as fp: 22 | for line in fp: 23 | if any(not x.isspace() for x in line): 24 | yield json.loads(line) 25 | else: 26 | with open(filename, "r") as fp: 27 | for line in fp: 28 | if any(not x.isspace() for x in line): 29 | yield json.loads(line) 30 | 31 | 32 | def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False): 33 | """ 34 | Writes an iterable of dictionaries to jsonl 35 | """ 36 | if append: 37 | mode = 'ab' 38 | else: 39 | mode = 'wb' 40 | filename = os.path.expanduser(filename) 41 | if filename.endswith(".gz"): 42 | with open(filename, mode) as fp: 43 | with gzip.GzipFile(fileobj=fp, mode='wb') as gzfp: 44 | for x in data: 45 | gzfp.write((json.dumps(x) + "\n").encode('utf-8')) 46 | else: 47 | with open(filename, mode) as fp: 48 | for x in data: 49 | fp.write((json.dumps(x) + "\n").encode('utf-8')) 50 | -------------------------------------------------------------------------------- /human-eval/human_eval/evaluate_functional_correctness.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import sys 3 | 4 | from human_eval.data import HUMAN_EVAL 5 | from human_eval.evaluation import evaluate_functional_correctness 6 | 7 | 8 | def entry_point( 9 | sample_file: str, 10 | k: str = "1,10,100", 11 | n_workers: int = 4, 12 | timeout: float = 3.0, 13 | problem_file: str = HUMAN_EVAL, 14 | ): 15 | """ 16 | Evaluates the functional correctness of generated samples, and writes 17 | results to f"{sample_file}_results.jsonl.gz" 18 | """ 19 | k = list(map(int, k.split(","))) 20 | results = evaluate_functional_correctness(sample_file, k, n_workers, timeout, problem_file) 21 | print(results) 22 | 23 | 24 
| def main(): 25 | fire.Fire(entry_point) 26 | 27 | 28 | sys.exit(main()) 29 | -------------------------------------------------------------------------------- /human-eval/human_eval/evaluation.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, Counter 2 | from concurrent.futures import ThreadPoolExecutor, as_completed 3 | from typing import List, Union, Iterable, Dict 4 | import itertools 5 | 6 | import numpy as np 7 | import tqdm 8 | 9 | from human_eval.data import HUMAN_EVAL, read_problems, stream_jsonl, write_jsonl 10 | from human_eval.execution import check_correctness 11 | 12 | 13 | def estimate_pass_at_k( 14 | num_samples: Union[int, List[int], np.ndarray], 15 | num_correct: Union[List[int], np.ndarray], 16 | k: int, 17 | ) -> np.ndarray: 18 | """ 19 | Estimates pass@k of each problem and returns them in an array. 20 | """ 21 | 22 | def estimator(n: int, c: int, k: int) -> float: 23 | """ 24 | Calculates 1 - comb(n - c, k) / comb(n, k). 25 | """ 26 | if n - c < k: 27 | return 1.0 28 | return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) 29 | 30 | if isinstance(num_samples, int): 31 | num_samples_it = itertools.repeat(num_samples, len(num_correct)) 32 | else: 33 | assert len(num_samples) == len(num_correct) 34 | num_samples_it = iter(num_samples) 35 | 36 | return np.array( 37 | [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)] 38 | ) 39 | 40 | 41 | def evaluate_functional_correctness( 42 | sample_file: str, 43 | k: List[int] = [1, 10, 100], 44 | n_workers: int = 4, 45 | timeout: float = 3.0, 46 | problem_file: str = HUMAN_EVAL, 47 | ): 48 | """ 49 | Evaluates the functional correctness of generated samples, and writes 50 | results to f"{sample_file}_results.jsonl.gz" 51 | """ 52 | 53 | problems = read_problems(problem_file) 54 | 55 | # Check the generated samples against test suites. 56 | with ThreadPoolExecutor(max_workers=n_workers) as executor: 57 | futures = [] 58 | completion_id = Counter() 59 | n_samples = 0 60 | results = defaultdict(list) 61 | 62 | print("Reading samples...") 63 | for sample in tqdm.tqdm(stream_jsonl(sample_file)): 64 | task_id = sample["task_id"] 65 | completion = sample["completion"] 66 | args = (problems[task_id], completion, timeout, completion_id[task_id]) 67 | future = executor.submit(check_correctness, *args) 68 | futures.append(future) 69 | completion_id[task_id] += 1 70 | n_samples += 1 71 | 72 | assert len(completion_id) == len(problems), "Some problems are not attempted." 73 | 74 | print("Running test suites...") 75 | for future in tqdm.tqdm(as_completed(futures), total=len(futures)): 76 | result = future.result() 77 | results[result["task_id"]].append((result["completion_id"], result)) 78 | 79 | # Calculate pass@k. 
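# For each task, n = number of generated samples and c = number that passed.
        # pass@k is the unbiased estimator from the Codex/HumanEval paper:
        #   pass@k = 1 - C(n - c, k) / C(n, k)
        # i.e. the probability that at least one of k samples drawn without
        # replacement from the n generations passes the tests. estimate_pass_at_k()
        # above evaluates that ratio as a running product to avoid large binomial
        # coefficients, and the reported score is the mean over all tasks.
        # Example: n = 10, c = 3, k = 1 gives 1 - C(7, 1) / C(10, 1) = 1 - 0.7 = 0.3 (= c/n).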
80 | total, correct = [], [] 81 | for result in results.values(): 82 | result.sort() 83 | passed = [r[1]["passed"] for r in result] 84 | total.append(len(passed)) 85 | correct.append(sum(passed)) 86 | total = np.array(total) 87 | correct = np.array(correct) 88 | 89 | ks = k 90 | pass_at_k = { 91 | f"pass@{k}": estimate_pass_at_k(total, correct, k).mean() 92 | for k in ks 93 | if (total >= k).all() 94 | } 95 | 96 | # Finally, save the results in one file: 97 | def combine_results(): 98 | for sample in stream_jsonl(sample_file): 99 | task_id = sample["task_id"] 100 | result = results[task_id].pop(0) 101 | sample["result"] = result[1]["result"] 102 | sample["passed"] = result[1]["passed"] 103 | yield sample 104 | 105 | out_file = sample_file + "_results.jsonl" 106 | print(f"Writing results to {out_file}...") 107 | write_jsonl(out_file, tqdm.tqdm(combine_results(), total=n_samples)) 108 | 109 | return pass_at_k 110 | -------------------------------------------------------------------------------- /human-eval/human_eval/execution.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable, Dict 2 | import ast 3 | import contextlib 4 | import faulthandler 5 | import io 6 | import os 7 | import multiprocessing 8 | import platform 9 | import signal 10 | import tempfile 11 | 12 | 13 | def check_correctness(problem: Dict, completion: str, timeout: float, 14 | completion_id: Optional[int] = None) -> Dict: 15 | """ 16 | Evaluates the functional correctness of a completion by running the test 17 | suite provided in the problem. 18 | 19 | :param completion_id: an optional completion ID so we can match 20 | the results later even if execution finishes asynchronously. 21 | """ 22 | 23 | def unsafe_execute(): 24 | 25 | with create_tempdir(): 26 | 27 | # These system calls are needed when cleaning up tempdir. 28 | import os 29 | import shutil 30 | rmtree = shutil.rmtree 31 | rmdir = os.rmdir 32 | chdir = os.chdir 33 | 34 | # Disable functionalities that can make destructive changes to the test. 35 | reliability_guard() 36 | 37 | # Construct the check program and run it. 38 | print(completion) 39 | check_program = ( 40 | problem["prompt"] + completion + "\n" + 41 | problem["test"] + "\n" + 42 | f"check({problem['entry_point']})" 43 | ) 44 | 45 | try: 46 | exec_globals = {} 47 | with swallow_io(): 48 | with time_limit(timeout): 49 | exec(check_program, exec_globals) 50 | result.append("passed") 51 | except TimeoutException: 52 | result.append("timed out") 53 | except BaseException as e: 54 | result.append(f"failed: {e}") 55 | 56 | # Needed for cleaning up. 
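# reliability_guard() above set shutil.rmtree, os.rmdir and os.chdir to None
            # so the untrusted completion cannot delete files or change directories;
            # restoring the references saved earlier lets create_tempdir()'s
            # TemporaryDirectory cleanup and chdir() exit run normally afterwards.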
57 |             shutil.rmtree = rmtree
58 |             os.rmdir = rmdir
59 |             os.chdir = chdir
60 | 
61 |     manager = multiprocessing.Manager()
62 |     result = manager.list()
63 | 
64 |     p = multiprocessing.Process(target=unsafe_execute)
65 |     p.start()
66 |     p.join(timeout=timeout + 1)
67 |     if p.is_alive():
68 |         p.kill()
69 | 
70 |     if not result:
71 |         result.append("timed out")
72 | 
73 |     return dict(
74 |         task_id=problem["task_id"],
75 |         passed=result[0] == "passed",
76 |         result=result[0],
77 |         completion_id=completion_id,
78 |     )
79 | 
80 | 
81 | @contextlib.contextmanager
82 | def time_limit(seconds: float):
83 |     def signal_handler(signum, frame):
84 |         raise TimeoutException("Timed out!")
85 |     signal.setitimer(signal.ITIMER_REAL, seconds)
86 |     signal.signal(signal.SIGALRM, signal_handler)
87 |     try:
88 |         yield
89 |     finally:
90 |         signal.setitimer(signal.ITIMER_REAL, 0)
91 | 
92 | 
93 | @contextlib.contextmanager
94 | def swallow_io():
95 |     stream = WriteOnlyStringIO()
96 |     with contextlib.redirect_stdout(stream):
97 |         with contextlib.redirect_stderr(stream):
98 |             with redirect_stdin(stream):
99 |                 yield
100 | 
101 | 
102 | @contextlib.contextmanager
103 | def create_tempdir():
104 |     with tempfile.TemporaryDirectory() as dirname:
105 |         with chdir(dirname):
106 |             yield dirname
107 | 
108 | 
109 | class TimeoutException(Exception):
110 |     pass
111 | 
112 | 
113 | class WriteOnlyStringIO(io.StringIO):
114 |     """ StringIO that throws an exception when it's read from """
115 | 
116 |     def read(self, *args, **kwargs):
117 |         raise IOError
118 | 
119 |     def readline(self, *args, **kwargs):
120 |         raise IOError
121 | 
122 |     def readlines(self, *args, **kwargs):
123 |         raise IOError
124 | 
125 |     def readable(self, *args, **kwargs):
126 |         """ Returns True if the IO object can be read. """
127 |         return False
128 | 
129 | 
130 | class redirect_stdin(contextlib._RedirectStream):  # type: ignore
131 |     _stream = 'stdin'
132 | 
133 | 
134 | @contextlib.contextmanager
135 | def chdir(root):
136 |     if root == ".":
137 |         yield
138 |         return
139 |     cwd = os.getcwd()
140 |     os.chdir(root)
141 |     try:
142 |         yield
143 |     except BaseException as exc:
144 |         raise exc
145 |     finally:
146 |         os.chdir(cwd)
147 | 
148 | 
149 | def reliability_guard(maximum_memory_bytes: Optional[int] = None):
150 |     """
151 |     This disables various destructive functions and prevents the generated code
152 |     from interfering with the test (e.g. fork bomb, killing other processes,
153 |     removing filesystem files, etc.)
154 | 
155 |     WARNING
156 |     This function is NOT a security sandbox. Untrusted code, including model-
157 |     generated code, should not be blindly executed outside of one. See the
158 |     Codex paper for more information about OpenAI's code sandbox, and proceed
159 |     with caution.
160 | """ 161 | 162 | if maximum_memory_bytes is not None: 163 | import resource 164 | resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)) 165 | resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)) 166 | if not platform.uname().system == 'Darwin': 167 | resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)) 168 | 169 | faulthandler.disable() 170 | 171 | import builtins 172 | builtins.exit = None 173 | builtins.quit = None 174 | 175 | import os 176 | os.environ['OMP_NUM_THREADS'] = '1' 177 | 178 | os.kill = None 179 | os.system = None 180 | os.putenv = None 181 | os.remove = None 182 | os.removedirs = None 183 | os.rmdir = None 184 | os.fchdir = None 185 | os.setuid = None 186 | os.fork = None 187 | os.forkpty = None 188 | os.killpg = None 189 | os.rename = None 190 | os.renames = None 191 | os.truncate = None 192 | os.replace = None 193 | os.unlink = None 194 | os.fchmod = None 195 | os.fchown = None 196 | os.chmod = None 197 | os.chown = None 198 | os.chroot = None 199 | os.fchdir = None 200 | os.lchflags = None 201 | os.lchmod = None 202 | os.lchown = None 203 | os.getcwd = None 204 | os.chdir = None 205 | 206 | import shutil 207 | shutil.rmtree = None 208 | shutil.move = None 209 | shutil.chown = None 210 | 211 | import subprocess 212 | subprocess.Popen = None # type: ignore 213 | 214 | __builtins__['help'] = None 215 | 216 | import sys 217 | sys.modules['ipdb'] = None 218 | sys.modules['joblib'] = None 219 | sys.modules['resource'] = None 220 | sys.modules['psutil'] = None 221 | sys.modules['tkinter'] = None 222 | -------------------------------------------------------------------------------- /human-eval/requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | fire 3 | numpy 4 | -------------------------------------------------------------------------------- /human-eval/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | 7 | setup( 8 | name="human-eval", 9 | py_modules=["human-eval"], 10 | version="1.0", 11 | description="", 12 | author="OpenAI", 13 | packages=find_packages(), 14 | install_requires=[ 15 | str(r) 16 | for r in pkg_resources.parse_requirements( 17 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 18 | ) 19 | ], 20 | entry_points={ 21 | "console_scripts": [ 22 | "evaluate_functional_correctness = human_eval.evaluate_functional_correctness", 23 | ] 24 | } 25 | ) 26 | -------------------------------------------------------------------------------- /process_eval.py: -------------------------------------------------------------------------------- 1 | from human_eval.data import read_problems, write_jsonl, stream_jsonl 2 | import glob 3 | from tqdm import tqdm 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | 8 | # Inputs 9 | parser.add_argument("--path", type=str, help="") 10 | parser.add_argument("--out_path", type=str, help="") 11 | parser.add_argument("--add_prompt", action="store_true", help="") 12 | 13 | args = parser.parse_args() 14 | 15 | 16 | files = sorted(glob.glob(args.path + "/*.jsonl")) 17 | print("{} files in {}".format(len(files), args.path)) 18 | 19 | problems = read_problems() 20 | 21 | output = [] 22 | a = 0 23 | for code_file in tqdm(files, total=len(files)): 24 | codes = [c for c in stream_jsonl(code_file)] 25 | if args.add_prompt: 
26 |         for code in codes:
27 |             task_id = code["task_id"]
28 |             prompt = problems[task_id]["prompt"]
29 |             completion = code["completion"]
30 |             completion = completion.replace("\r", "")
31 |             if "```python" in completion:
32 |                 def_line = completion.index("```python")
33 |                 completion = completion[def_line:].strip()
34 |                 completion = completion.replace("```python", "")
35 |                 # print(completion)
36 |                 try:
37 |                     next_line = completion.index("```")
38 |                     completion = completion[:next_line].strip()
39 |                 except:
40 |                     a += 1
41 |                     print(completion)
42 |                     print("================\n")
43 |                 # print(completion)
44 |             if '__name__ == "__main__"' in completion:
45 |                 next_line = completion.index('if __name__ == "__main__":')
46 |                 completion = completion[:next_line].strip()
47 |                 # print(completion)
48 | 
49 |             if "# Example usage" in completion:
50 |                 # print(completion)
51 |                 next_line = completion.index("# Example usage")
52 |                 completion = completion[:next_line].strip()
53 | 
54 |             code["completion"] = completion
55 | 
56 |     output += codes
57 | 
58 | print("save to {}".format(args.out_path))
59 | write_jsonl(args.out_path, output)
60 | print(a)
61 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -e ./human-eval
2 | transformers==4.30.2
3 | torch==2.0.1
4 | accelerate==0.20.3
5 | sentencepiece==0.1.99
6 | einops==0.6.1
7 | tiktoken==0.4.0
--------------------------------------------------------------------------------
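As a reference, the snippet below is a minimal sketch (not part of the repo) of how pass@k is derived from per-problem sample counts by the evaluation code above. It assumes the bundled human-eval package has been installed with `pip install -e ./human-eval`, and the counts are invented purely for illustration; in practice they come from generating completions with one of the eval_*.py scripts, cleaning the output with process_eval.py, and scoring the cleaned .jsonl with evaluate_functional_correctness.

# Toy illustration of human_eval.evaluation.estimate_pass_at_k;
# the sample/correct counts below are hypothetical, not real results.
import numpy as np
from human_eval.evaluation import estimate_pass_at_k

total = np.array([20, 20, 20])    # hypothetical: 20 completions generated per task
correct = np.array([2, 0, 20])    # hypothetical: completions that passed the tests

for k in (1, 10):
    print(f"pass@{k}: {estimate_pass_at_k(total, correct, k).mean():.4f}")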