├── stgen ├── __init__.py ├── utils.py ├── README.md ├── main.py ├── contract_generator.py └── st_generator.py ├── coffe ├── __init__.py ├── config.py ├── sandbox.py ├── dataset.py ├── sanitize.py ├── main.py └── code_execution.py ├── datasets ├── mbpp │ ├── data-00000-of-00001.arrow │ ├── state.json │ └── dataset_info.json ├── codeparrot_apps │ ├── data-00000-of-00001.arrow │ ├── state.json │ └── dataset_info.json ├── openai_humaneval │ ├── data-00000-of-00001.arrow │ ├── state.json │ └── dataset_info.json └── deepmind_code_contests │ ├── data-00000-of-00001.arrow │ ├── state.json │ └── dataset_info.json ├── requirements.txt ├── Dockerfile ├── setup.py ├── .gitignore ├── README.md ├── LICENSE └── perf.json /stgen/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /coffe/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" -------------------------------------------------------------------------------- /datasets/mbpp/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnnyPeng18/Coffe/HEAD/datasets/mbpp/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /datasets/codeparrot_apps/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnnyPeng18/Coffe/HEAD/datasets/codeparrot_apps/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /datasets/openai_humaneval/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnnyPeng18/Coffe/HEAD/datasets/openai_humaneval/data-00000-of-00001.arrow 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cirron==0.3 2 | transformers==4.41.0 3 | docker==7.1.0 4 | datasets==2.19.1 5 | coverage 6 | openai 7 | timeout_decorator 8 | scipy 9 | termcolor -------------------------------------------------------------------------------- /datasets/deepmind_code_contests/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnnyPeng18/Coffe/HEAD/datasets/deepmind_code_contests/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /datasets/mbpp/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "172cf84c63ce282b", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": "test" 13 | } -------------------------------------------------------------------------------- /datasets/codeparrot_apps/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "271731ea54339871", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": null 13 | } -------------------------------------------------------------------------------- /datasets/openai_humaneval/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "788c49cc88f12c77", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | 
"_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": "test" 13 | } -------------------------------------------------------------------------------- /datasets/deepmind_code_contests/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "e494af664c3ac846", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": null 13 | } -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10-slim 2 | 3 | # Set up proxy to access Internet if necessary 4 | #ENV http_proxy "" 5 | #ENV https_proxy "" 6 | 7 | RUN apt-get update && apt-get install -y git 8 | 9 | RUN apt-get install -y gcc g++ linux-perf 10 | 11 | RUN pip install --upgrade pip 12 | 13 | COPY . /Coffe 14 | 15 | RUN cd /Coffe && pip install . 16 | 17 | RUN cd .. 

# Benchmark registry used across COFFE to locate and interpret each dataset.
# Keys are benchmark names as used on the command line (HuggingFace ids for
# the public benchmarks); values describe the on-disk dataset stored under
# datasets/<path>:
#   code_keyword:     dataset column holding the reference solution(s)
#                     (e.g. "canonical_solution" for HumanEval, "code" for MBPP)
#   testcase_keyword: column holding the correctness test cases; a list means
#                     several columns are combined (deepmind/code_contests)
#   add_list:         NOTE(review): presumably True when the solutions column
#                     stores a list of solutions rather than a single string
#                     (APPS / Code Contests both use a "solutions" list) —
#                     confirm against the loader in dataset.py
#   path:             sub-directory name under the datasets/ root
# "function" and "file" are COFFE's own curated levels and only carry a path.
benchmarks = {
    "openai_humaneval": {
        "code_keyword": "canonical_solution",
        "testcase_keyword": "testcases",
        "add_list": False,
        "path": "openai_humaneval"
    },
    "mbpp": {
        "code_keyword": "code",
        "testcase_keyword": "testcases",
        "add_list": False,
        "path": "mbpp"
    },
    "codeparrot/apps": {
        "code_keyword": "solutions",
        "testcase_keyword": "input_output",
        "add_list": True,
        "path": "codeparrot_apps"
    },
    "deepmind/code_contests": {
        "code_keyword": "solutions",
        "testcase_keyword": ["private_tests", "generated_tests"],
        "add_list": True,
        "path": "deepmind_code_contests"
    },
    "function": {
        "path": "function"
    },
    "file": {
        "path": "file"
    }
}
from setuptools import setup
from os import path, environ
from coffe import __version__


# Repository root; README.md and requirements.txt are read relative to it
# so the sdist builds correctly regardless of the current working directory.
_ROOT = path.abspath(path.dirname(__file__))

setup(
    name = "coffe",
    version = __version__,
    description = "Coffe: A Code Efficiency Benchmark for Code Generation",
    long_description = open(path.join(_ROOT, "README.md"), "r", encoding = "utf-8").read(),
    long_description_content_type = 'text/markdown',
    url = "https://github.com/JohnnyPeng18/Coffe",
    author = "Yun Peng",
    author_email = "normal@yunpeng.work",
    classifiers = [
        "Development Status :: 4 - Beta",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Topic :: Software Development :: Build Tools"
    ],
    keywords = ["python", "code generation", "large language model", "performance", "LLM evaluation"],
    packages = ["coffe", "stgen"],
    # BUG FIX: this keyword was misspelled "python_requries"; setuptools
    # silently ignores unknown keywords, so the >=3.9 interpreter
    # constraint was never enforced on install.
    python_requires = '>=3.9',
    install_requires = open(path.join(_ROOT, "requirements.txt"), "r", encoding = "utf-8").read().splitlines(),
    entry_points = {
        'console_scripts': [
            'coffe = coffe.main:main',
            'stgen = stgen.main:main'
        ]
    }
)
20 | 21 | **Phase I - Contract Generation:** 22 | 23 | You can use the following command to only generate contracts for code: 24 | ```shell 25 | stgen contract -l func -d Coffe/datasets/function/data.json -t Coffe/datasets/function/testcases.json -s Coffe/datasets/function/best_solutions.json -o 26 | ``` 27 | This command generates contracts for function-level code given the data files in the COFFE benchmark. Each contract is the original code containing extra `assert` statements, which could be used to guide the generation of stressful test cases. 28 | 29 | **Phase II - Test Case Generation:** 30 | 31 | You can use the following command to only generate stressful test cases given the previously generated contracts: 32 | ```shell 33 | stgen st -l func -d Coffe/datasets/function/data.json -t Coffe/datasets/function/testcases.json -s Coffe/datasets/function/best_solutions.json -c -o -n 5 34 | ``` 35 | This command generates five stressful test cases for function-level code given the data files and previously generated contract file ``. 
import docker
import os
from datetime import datetime

class SandBox(object):
    # Runs one shell command inside the pre-built "coffe" docker image with the
    # host working directory bind-mounted at /data, capturing logs and exit status.
    def __init__(self, workdir, perf_path):
        """Create a sandbox bound to a host directory.

        workdir: host directory shared with every container (mounted at /data,
            read-write); also where per-worker log/error files are written.
        perf_path: path to a seccomp-profile JSON passed verbatim to docker
            (presumably perf.json at the repo root, which allows perf_event
            access for CPU instruction counting — confirm at the call site).
        """
        self.workdir = workdir
        self.image_tag = "coffe"
        self.client = docker.from_env()
        self.perf_path = perf_path

    def _run(self, args):
        # Tuple-unpacking adapter: args == (command, worker_id, timeout).
        # NOTE(review): presumably used as a worker-pool map target — confirm
        # at the call site. Unlike run(), this discards the exit code.
        self.run(args[0], args[1], args[2])

    def run(self, command, worker_id, timeout):
        """Execute `command` in a fresh container and return its exit code.

        Returns the container's StatusCode on completion, or -1 when the
        container fails to start or times out. Side effects under
        self.workdir: writes CHECK_LOG_<worker_id> with the container's
        combined stdout/stderr, and appends to ERROR_<worker_id> on failure.
        """
        container_workdir = '/data'
        mount = docker.types.Mount(target = container_workdir, source = self.workdir, type = 'bind', read_only = False)

        # stdbuf disables stdio buffering so partial output survives a kill;
        # the in-container `timeout` terminates the command after `timeout` seconds.
        buf_prefix = 'stdbuf -i0 -o0 -e0'
        timeout_prefix = 'timeout {}'.format(timeout)
        command = " ".join([buf_prefix, timeout_prefix, command])
        exit_code = 0

        try:
            # Detached container: host networking plus the custom seccomp profile
            # read from perf_path (docker expects the profile's JSON text inline).
            container = self.client.containers.run(image=self.image_tag, command=['/bin/bash', '-c', command], detach=True, security_opt=["seccomp=" + open(self.perf_path, "r").read()], network_mode='host', mounts=[mount])
        except Exception as e:
            print(f"Worker {worker_id}: container running failed, reason: {e}")
            os.system(f'echo "[{datetime.now()}]Worker {worker_id} running failed." >> {self.workdir}/ERROR_{worker_id}')
            exit_code = -1
            return exit_code

        try:
            # Allow an extra 100s beyond the in-container `timeout` before the
            # docker-side wait itself gives up.
            exit_code = container.wait(timeout = timeout + 100, condition = 'not-running')['StatusCode']
        except Exception as e:
            print(f'Worker {worker_id}: Container time out, killed.')
            try:
                if container.status == 'running':
                    container.kill()
            except Exception as e:
                print(e)
            os.system(f'echo "[{datetime.now()}]Worker {worker_id} timeout" >> {self.workdir}/ERROR_{worker_id}')
            exit_code = -1
        finally:
            try:
                # Always persist the container's combined output, then remove the
                # container (force=True also covers the timed-out/killed case).
                log = container.logs(stdout = True, stderr = True).decode(encoding = 'utf-8', errors = 'ignore').strip()
                with open(os.path.join(self.workdir, f'CHECK_LOG_{worker_id}'), 'w', encoding = 'utf-8') as lf:
                    lf.write(log)
                container.remove(v=True, force = True)
            except Exception as e:
                print(e)
                os.system(f'echo "[{datetime.now()}]Worker {worker_id} logerror" >> {self.workdir}/ERROR_{worker_id}')
        return exit_code
| "num_bytes": 6717, 23 | "checksum": null 24 | } 25 | }, 26 | "download_size": 115422, 27 | "features": { 28 | "source_file": { 29 | "dtype": "string", 30 | "_type": "Value" 31 | }, 32 | "task_id": { 33 | "dtype": "int32", 34 | "_type": "Value" 35 | }, 36 | "prompt": { 37 | "dtype": "string", 38 | "_type": "Value" 39 | }, 40 | "code": { 41 | "dtype": "string", 42 | "_type": "Value" 43 | }, 44 | "test_imports": { 45 | "feature": { 46 | "dtype": "string", 47 | "_type": "Value" 48 | }, 49 | "_type": "Sequence" 50 | }, 51 | "test_list": { 52 | "feature": { 53 | "dtype": "string", 54 | "_type": "Value" 55 | }, 56 | "_type": "Sequence" 57 | } 58 | }, 59 | "homepage": "", 60 | "license": "", 61 | "size_in_bytes": 335107, 62 | "splits": { 63 | "train": { 64 | "name": "train", 65 | "num_bytes": 63468, 66 | "num_examples": 120, 67 | "dataset_name": "mbpp" 68 | }, 69 | "test": { 70 | "name": "test", 71 | "num_bytes": 132753, 72 | "num_examples": 257, 73 | "dataset_name": "mbpp" 74 | }, 75 | "validation": { 76 | "name": "validation", 77 | "num_bytes": 20056, 78 | "num_examples": 43, 79 | "dataset_name": "mbpp" 80 | }, 81 | "prompt": { 82 | "name": "prompt", 83 | "num_bytes": 3408, 84 | "num_examples": 7, 85 | "dataset_name": "mbpp" 86 | } 87 | }, 88 | "version": { 89 | "version_str": "0.0.0", 90 | "major": 0, 91 | "minor": 0, 92 | "patch": 0 93 | } 94 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these 
files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # COFFE 2 | 3 | COFFE is a Python benchmark for evaluating the time efficiency of LLM-generated code. It is released by the FSE'25 paper "[COFFE: A Code Efficiency Benchmark for Code Generation](https://arxiv.org/abs/2502.02827)". You can also refer to the [project webpage](https://www.yunpeng.site/projects/coffe/) for more details. 4 | 5 | ## Data 6 | 7 | COFFE is designed for evaluating both function-level code and file-level code. It contains selected instances from HumanEval, MBPP, APPS and Code Contests. COFFE keeps the original test cases in these benchmarks as *correctness test cases* and adds new test cases designed for time efficiency evaluation as *stressful test cases*. 8 | 9 | **Statistics:** 10 | 11 | |Category|#Instance|#Solution/Instance|#Correctness/Instance | #Stressful/Instance| 12 | |----|----|----|----|----| 13 | |Function-Level|398|1.00|5.72|4.99| 14 | |File-Level|358|66.93|43.68|4.95| 15 | 16 | **Data Files:** 17 | 18 | All instances in COFFE are in `Coffe/datasets`, where `Coffe/datasets/function` contains all function-level instances and `Coffe/datasets/file` contains all file-level instances. In each repo: 19 | - `best_solutions.json` contains the best ground truth solution COFFE uses to calculate efficient@1 and speedup. 20 | - `stressful_testcases.json` contains all stressful test cases COFFE adds. 21 | - `solutions.json` contains all ground truth solutions in original benchmarks. 22 | - `testcases.json` contains all correctness test cases in original benchmarks. 23 | 24 | ## Installation 25 | 26 | **Requirements:** 27 | - Linux Machine 28 | - Docker 29 | - Python>=3.10 30 | 31 | We suggest you create a virtural environment before installing COFFE. 32 | 33 | 1. 
To use COFFE, please clone this repo in your current workspace `workspace/` and execute: 34 | ```bash 35 | cd Coffe && pip install . 36 | ``` 37 | 2. COFFE comes with a docker image, to install it: 38 | ```bash 39 | docker build . -t coffe 40 | ``` 41 | Note that if your network requires proxy, please modify the Dockerfile in `Coffe/` to indicate it, otherwise the docker image building process could fail. 42 | 43 | 3. Go back to your workspace and initialize COFFE: 44 | ```bash 45 | cd .. && coffe 46 | ``` 47 | If your installation succeeds, you could see the statistics of COFFE. 48 | 49 | If your installation fails with the reason `Your OS does not support measuring CPU instruction counts`, this may be because of a permission error. In this case, check the default permission level in your system: 50 | 51 | ```bash 52 | cat /proc/sys/kernel/perf_event_paranoid 53 | ``` 54 | 55 | If the output is greater than `2`, then you do not have permission to measure CPU instruction counts by default. 56 | 57 | To enable this measurement, try to set the `/proc/sys/kernel/perf_event_paranoid` to `2`, `1`, `0,` or `-1`. The smaller it is, the larger the permission you have. 58 | 59 | ``` 60 | echo 1 | sudo tee /proc/sys/kernel/perf_event_paranoid 61 | ``` 62 | 63 | This will temporarily allow you to access the CPU instruction count and will expire after you restart the system. 
64 | 65 | If you want to permanently allow the measurement (This may induce security issues!): 66 | 67 | Edit the `/etc/sysctl.conf` by adding the following line: 68 | 69 | ``` 70 | kernel.perf_event_paranoid= -1 71 | ``` 72 | 73 | Then reload the configuration: 74 | 75 | ```bash 76 | sysctl -p /etc/sysctl.conf 77 | ``` 78 | 79 | ## Usage 80 | 81 | ### Pipeline 82 | 83 | When you prepare the predictions from LLMs, COFFE provides a pipeline to calculate the efficient@1 and speedup defined in the paper: 84 | ```bash 85 | coffe pipe 86 | -p ,..., 87 | -f 88 | -n 89 | ``` 90 | This command has four phases: 91 | 1. santize the predictions. 92 | 2. select the correct predictions based on correctness test cases. 93 | 3. evaluate the GPU instruction count based on stressful test cases. 94 | 4. calculate the final metrics. 95 | 96 | For example: 97 | ```bash 98 | coffe pipe function Coffe/examples/function -p Coffe/examples/function/GPT-4o.json -f efficient_at_1 -n 8 99 | ``` 100 | This command evaluates the predictions from GPT-4o on the function-level instances of COFFE. If you want to evaluate other LLMs, please prepare a `JSON` file with the same format as `Coffe/examples/function/GPT-4o.json`. 101 | 102 | **Prediction File Format:** 103 | 104 | In the `JSON` file, the key is the prompt used to query the LLM for the results, you could get the prompts in `datasets/function/prompts.json` and `datasets/file/prompts.json`. The value contains two objects, the first is a list contains the raw outputs from LLMs and the second is an indicator for the whether the raw output is valid. 105 | 106 | **Note:** 107 | 108 | In default, COFFE will run all predictions in docker. However, if you could not successfully install the docker or want to run the predictions on the host machine, you can add the `-x` option. 109 | 110 | 111 | ### Single Evaluation 112 | 113 | The `pipe` command provides an entire pipeline for calculating the final metrics. 
This pipeline could also be completed by executing the following four single evaluation commands. 114 | 115 | 1. Sanitize the predictions 116 | ```bash 117 | coffe eval 118 | -p ,..., 119 | -m compilable_rate 120 | ``` 121 | This commands output a file ending with `SOLUTIONS.json` that contains the predictions without syntax errors. 122 | 123 | 2. Select correct predictions 124 | ```bash 125 | coffe eval 126 | -p ,..., 127 | -m correctness 128 | -n 129 | ``` 130 | This commands accept prediction files ending with `SOLUTIONS.json` and output a file ending with `PASSED_SOLUTIONS.json` that contains the predictions pass all correctness solutions. 131 | 132 | **Note:** 133 | This command will combine all correct solutions and ground truth solutions together into files `_all_indexes.json` (used in step 4) and `_all_PASSED_SOLUTIONS.json` for the next step. 134 | 135 | 3. Evaluate the GPU instruction count 136 | ```bash 137 | coffe eval 138 | -p 139 | -m instr_count 140 | -n 141 | ``` 142 | This command will evaluate the GPU instruction count each prediction consumes and output a file ending with `STRESSFUL_INSTRUCTION.json`. 143 | 144 | **Note:** 145 | This command could only accept one single prediction file ending with `PASSED_SOLUTIONS.json`. 146 | 147 | 4. Calculating the efficient@1/speedup 148 | ```bash 149 | coffe eval 150 | -p , 151 | -m instr_count 152 | -f 153 | ``` 154 | This command calculate the efficient@1 or speedup. 155 | 156 | **Note:** 157 | This command requires the index file and the instruction file as COFFE compares the performance of predictions with grouth truth solutions to calculate the metrics. 158 | 159 | ## STGen 160 | 161 | For details about the stressful test case generation approach STGen, please see `stgen/`. 
def info(args):
    """Default sub-command: print a short banner describing STGen.

    `args` is accepted (and ignored) so the function matches the argparse
    `func` callback signature shared by the other sub-commands.
    """
    print("Welcome to use STGen!")
    print("STGen is part of the Coffe benchmark and enables stressful test case generation.")
    print("For more details, please see https://github.com/JohnnyPeng18/Coffe/stgen.")
    print("To see the options of STGen, please use -h option.")

def check_environ():
    """Raise ValueError unless both $API_KEY and $BASE_URL are set.

    Both variables are required before any request is made to a remote
    LLM service (they configure the OpenAI-compatible client).
    """
    key = os.environ.get('API_KEY')
    base_url = os.environ.get('BASE_URL')
    # PEP 8 (E711): compare to None with `is`, not `==`. Behavior is the
    # same — only a missing variable triggers the error.
    if key is None or base_url is None:
        raise ValueError("Please set environment variable $API_KEY and $BASE_URL before making requests to remote LLM services.")

def _resolve_output(output_path, default_name):
    """Resolve the user-supplied output path to a concrete file path.

    If `output_path` is an existing directory, place a file named
    `default_name` inside it; otherwise treat `output_path` itself as the
    target file. Shared by the contract/st sub-commands, which previously
    duplicated this logic inline.
    """
    if os.path.isdir(output_path):
        return os.path.join(output_path, default_name)
    return output_path

def contract(args):
    """Sub-command: Phase I — generate contracts (assert-augmented code)."""
    check_environ()
    if args.level == "func":
        output_file = _resolve_output(args.output_path, "func_contracts.json")
        gen_func_contracts(args.data_file, args.test_file, args.solution_file, output_file, verbose = args.verbose)
    elif args.level == "file":
        output_file = _resolve_output(args.output_path, "file_contracts.json")
        gen_file_contracts(args.data_file, args.test_file, args.solution_file, output_file, verbose = args.verbose)
    else:
        raise ValueError(f"Unrecognized option: {args.level}!")
    print(f"Generated contracts stored into {output_file}")

def st(args):
    """Sub-command: Phase II — generate stressful test cases from contracts."""
    check_environ()
    if args.level == "func":
        output_file = _resolve_output(args.output_path, "func_stressful_testcases.json")
        gen_func_sts(args.data_file, args.test_file, args.contract_file, output_file, verbose = args.verbose, num = args.num)
    elif args.level == "file":
        # File-level generation additionally needs the ground-truth solutions.
        if not args.solution_file:
            raise ValueError("Path to solution file must be indicated with option -s when generating file-level stressful test cases!")
        output_file = _resolve_output(args.output_path, "file_stressful_testcases.json")
        gen_file_sts(args.data_file, args.test_file, args.solution_file, args.contract_file, output_file, verbose = args.verbose, num = args.num)
    else:
        raise ValueError(f"Unrecognized option: {args.level}!")
    print(f"Generated Stressful Test Cases stored into {output_file}")
2: Generating Stressful Test Cases...", "green")) 79 | gen_func_sts(args.data_file, args.test_file, contract_file, st_file, num = args.num) 80 | print(f"Generated Stressful Test Cases stored into {st_file}") 81 | print(colored("Done!", "green")) 82 | print(colored("Pipeline Finished!", "green")) 83 | elif args.level == "file": 84 | if not args.solution_file: 85 | raise ValueError("Path to solution file must be indicated with option -s when generating file-level stressful test cases!") 86 | if os.path.isdir(args.output_path): 87 | contract_file = os.path.join(args.output_path, "file_contracts.json") 88 | st_file = os.path.join(args.output_path, "file_stressful_testcases.json") 89 | else: 90 | contract_file = os.path.join(os.path.dirname(args.output_path), "file_contracts.json") 91 | st_file = args.output_path 92 | print(colored("+++++++++++Step 1: Generating Contracts..", "green")) 93 | gen_file_contracts(args.data_file, args.test_file, args.solution_file, contract_file) 94 | print(f"Generated contracts stored into {contract_file}") 95 | print(colored("Done!", "green")) 96 | print(colored("+++++++++++Step 2: Generating Stressful Test Cases...", "green")) 97 | gen_file_sts(args.data_file, args.test_file, args.solution_file, contract_file, st_file, num = args.num) 98 | print(f"Generated Stressful Test Cases stored into {st_file}") 99 | print(colored("Done!", "green")) 100 | print(colored("Pipeline Finished!", "green")) 101 | else: 102 | raise ValueError(f"Unrecognized option: {args.level}!") 103 | 104 | 105 | 106 | def main(): 107 | arg_parser = argparse.ArgumentParser() 108 | sub_parsers = arg_parser.add_subparsers(dest='cmd') 109 | arg_parser.set_defaults(func = info) 110 | 111 | contract_parser = sub_parsers.add_parser('contract') 112 | contract_parser.add_argument('-o', '--output_path', required=True, type=str, help='Path to the output contract file') 113 | contract_parser.add_argument('-l', '--level', required = False, default="func", type=str, help="Should be 
func for function-level or file for file-level stressful test cases.") 114 | contract_parser.add_argument('-d', '--data_file', required=True, type=str, help='Path to the data file') 115 | contract_parser.add_argument('-t', '--test_file', required=True, type=str, help='Path to the correctness test case file') 116 | contract_parser.add_argument('-s', '--solution_file', required=True, type=str, help='Path to the ground truth solution file') 117 | contract_parser.add_argument('-v', '--verbose', required=False, default=False, action = "store_true", help='Display debug information') 118 | contract_parser.set_defaults(func = contract) 119 | 120 | st_parser = sub_parsers.add_parser('st') 121 | st_parser.add_argument('-o', '--output_path', required=True, type=str, help='Path to the output contract file') 122 | st_parser.add_argument('-l', '--level', required = False, default="func", type=str, help="Should be func for function-level or file for file-level stressful test cases.") 123 | st_parser.add_argument('-d', '--data_file', required=True, type=str, help='Path to the data file') 124 | st_parser.add_argument('-t', '--test_file', required=True, type=str, help='Path to the correctness test case file') 125 | st_parser.add_argument('-s', '--solution_file', required=False, type=str, help='Path to the ground truth solution file') 126 | st_parser.add_argument('-c', '--contract_file', required=True, type=str, help='Path to the contract file') 127 | st_parser.add_argument('-n', '--num', required=False, type=int, default=5, help='Number of stressful test cases to generate for each instance') 128 | st_parser.add_argument('-v', '--verbose', required=False, default=False, action = "store_true", help='Display debug information') 129 | st_parser.set_defaults(func = st) 130 | 131 | pipe_parser = sub_parsers.add_parser('pipe') 132 | pipe_parser.add_argument('-o', '--output_path', required=True, type=str, help='Path to the output contract file') 133 | pipe_parser.add_argument('-l', 
'--level', required = False, default="func", type=str, help="Should be func for function-level or file for file-level stressful test cases.") 134 | pipe_parser.add_argument('-d', '--data_file', required=True, type=str, help='Path to the data file') 135 | pipe_parser.add_argument('-t', '--test_file', required=True, type=str, help='Path to the correctness test case file') 136 | pipe_parser.add_argument('-s', '--solution_file', required=False, type=str, help='Path to the ground truth solution file') 137 | pipe_parser.add_argument('-n', '--num', required=False, type=int, default=5, help='Number of stressful test cases to generate for each instance') 138 | pipe_parser.set_defaults(func = pipe) 139 | 140 | args = arg_parser.parse_args() 141 | args.func(args) 142 | -------------------------------------------------------------------------------- /stgen/contract_generator.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import random 3 | from typing import List, Dict, Any, Tuple, Union 4 | import time 5 | import os 6 | import json 7 | import gc 8 | from termcolor import colored 9 | from tqdm import tqdm 10 | 11 | 12 | from stgen.utils import make_request 13 | from coffe.code_execution import untrusted_check 14 | from coffe.sanitize import sanitize 15 | 16 | 17 | class FuncContractGenerator: 18 | def __init__(self, prompt: str, testcases: list, code: str, entry_point: str, verbose = False): 19 | self.prompt = prompt 20 | self.testcases = testcases 21 | self.code = code 22 | self.entry_point = entry_point 23 | self.iteration = 20 24 | self.target_num = 4 25 | self.current_num = 0 26 | self.verbose = verbose 27 | self.history = [self.code] 28 | 29 | self.instructions = [ 30 | "Generate a new, different contract (assertion) for the function. The assertion should check input validity or enforce a function invariant. 
class FuncContractGenerator:
    """Insert contracts (assert statements) into a function-level solution via an LLM.

    Starting from the original solution, the generator repeatedly asks the LLM for
    one new assertion, inserts it right after the function header, and keeps the
    new version only when all correctness test cases still pass.
    """

    def __init__(self, prompt: str, testcases: list, code: str, entry_point: str, verbose = False):
        self.prompt = prompt
        self.testcases = testcases
        self.code = code
        self.entry_point = entry_point
        self.iteration = 20      # maximum number of LLM attempts
        self.target_num = 4      # desired number of inserted contracts
        self.current_num = 0     # contracts successfully inserted so far
        self.verbose = verbose
        self.history = [self.code]   # history[i] = accepted code after i insertions

        self.instructions = [
            "Generate a new, different contract (assertion) for the function. The assertion should check input validity or enforce a function invariant. Only return the new assert statement, nothing else.",
            "Create a unique contract (assertion) for the function that wasn't used before. Focus on input validation or function preconditions. Only provide the new assert statement.",
        ]

    def get_last_version(self):
        """Return the most recent accepted version of the code."""
        return self.history[-1]

    def update(self, new_code: str):
        """Record new_code as the latest accepted version."""
        self.current_num += 1
        self.history.append(new_code)
        if self.verbose:
            print(colored(f"Update {self.current_num} version", "blue"))

    def insert_contract_into_code(self, contract: str) -> str:
        """Insert one contract as the first statement of the target function."""
        lines = self.get_last_version().split("\n")
        # Locate the header of the target function (raises StopIteration if absent).
        func_start = next(i for i, line in enumerate(lines) if line.startswith(f"def {self.entry_point}"))
        indent = len(lines[func_start]) - len(lines[func_start].lstrip())
        lines.insert(func_start + 1, " " * (indent + 4) + contract)
        return "\n".join(lines)

    def generate_and_insert_contract(self) -> bool:
        """Ask the LLM for one new contract; keep it when all test cases still pass.

        Returns:
            True when a contract was generated, inserted, and validated.
        """
        instruction = random.choice(self.instructions)
        prompt = f"{instruction}\n\n"
        prompt += f"Task description:\n```\n{self.prompt}\n```\n"
        prompt += f"Current function code:\n```python\n{self.get_last_version()}\n```\n"

        # Show up to three *distinct* example test cases. random.sample draws
        # without replacement; random.choices could repeat the same case.
        demo_testcases = random.sample(self.testcases, k=min(3, len(self.testcases)))
        prompt += f"Example testcases:\n```python\n{demo_testcases}\n```\n"

        if len(self.history) > 1:
            existing_contracts = [line.strip() for line in self.history[-1].split("\n") if line.strip().startswith("assert")]
            if existing_contracts:
                prompt += f"Existing contracts (do not repeat these):\n```python\n{existing_contracts}\n```\n"

        prompt += "Generate a new, unique contract (assert statement) for this function. Return only the assert statement, nothing else."

        if self.verbose:
            print(prompt)
        ret = make_request(prompt)
        if self.verbose:
            print(ret)

        try:
            new_contract = sanitize(ret[0], "", codegen=False, global_code=True).strip()
            if len(new_contract) == 0:
                if self.verbose:
                    print(colored("Contract insertion failed", "red"))
                return False
            if not new_contract.startswith("assert"):
                new_contract = f"assert {new_contract}"
        except Exception as e:
            print(f"Error when sanitizing the contract: {e}")
            return False

        if self.verbose:
            print(colored("Prompt: ", "green"))
            print(prompt)
            print(colored("Generated contract: ", "green"))
            print(new_contract)

        new_code = self.insert_contract_into_code(new_contract)
        # The checker expects the entry point to be named `solution`.
        execute_code = new_code.replace(f"def {self.entry_point}", "def solution")

        time_limits = [100 for _ in self.testcases]
        stat, results = untrusted_check(
            io=False,
            code=execute_code,
            testcases=self.testcases,
            atol=None,
            ref_time=time_limits,
            fast_check=True,
            check=True,
            generator=False,
            gt_time_limit_factor=1.0
        )

        # Accept only when every correctness test case still passes.
        if len(results) > 0 and all(result["status"] == 1 for result in results):
            self.update(new_code)
            if self.verbose:
                print(colored("Contract insertion successful", "green"))
            return True
        if self.verbose:
            print(colored("Contract insertion failed", "red"))
        return False

    def gen(self):
        """Run up to `iteration` attempts or until `target_num` contracts are inserted."""
        while self.iteration > 0 and self.current_num < self.target_num:
            self.iteration -= 1
            self.generate_and_insert_contract()
        return self.history[-1]
class FileContractGenerator:
    """Insert contracts (assert statements) into a file-level solution via an LLM.

    Unlike FuncContractGenerator, the LLM rewrites the whole file; the new
    version is accepted only when every correctness test case still passes.
    """

    def __init__(self, prompt: str, testcases: list, code: str, io: bool, verbose = False):
        self.prompt = prompt
        self.testcases = testcases
        self.code = code
        self.io = io                 # whether the solution is checked via stdin/stdout
        self.iteration = 20          # maximum number of LLM attempts
        self.target_num = 4          # desired number of inserted contracts (was misspelled `tartget_num`)
        self.current_num = 0         # contracts successfully inserted so far
        self.history = [self.code]   # history[i] = accepted code after i insertions
        self.verbose = verbose

        self.instructions = [
            "Please insert the constracts of inputs (using assertion) to the given code based on task description. Please consider the internal logical constraints and the data constrains of the inputs while generating this assertion. Just insert one more assert based on current version. Just reply edited code without any other text.\n",
            "Please insert the constracts of inputs (using assertion) to fullfill the data type and constrains to the given code based on task description. Insert on contracts per time. Just insert one more assert based on current version. Just reply edited code without any other text.\n",
        ]

    def get_last_version(self):
        """Return the most recent accepted version of the code."""
        return self.history[-1]

    def update(self, new_code: str):
        """Record new_code as the latest accepted version."""
        self.current_num += 1
        self.history.append(new_code)
        if self.verbose:
            print(colored(f"Update {self.current_num} version", "blue"))

    def get_update_pairs(self):
        """Return consecutive (before, after) code pairs, or None when nothing was updated."""
        if len(self.history) < 2:
            return None
        return [(self.history[i], self.history[i + 1]) for i in range(len(self.history) - 1)]

    def check_correctness(self, results: List) -> Tuple[bool, Any]:
        """Return (True, ' ') when every result passed, else (False, failing input).

        Note: the original annotation claimed `-> bool`, but a 2-tuple is returned.
        """
        if len(results) == 0:
            return False, "Empty results"
        for result in results:
            if result["status"] != 1:
                return False, result["input"]
        return True, ' '

    def contract_update(self):
        """Ask the LLM for a new version with one more contract; keep it when tests pass."""
        instruction = random.choice(self.instructions)
        # Bug fix: `instruction` is already a single string chosen above; the
        # previous `random.choice(instruction)` picked one *character* out of it.
        prompt = f"\n{instruction}\n"

        prompt += f"Here is the task description of the code that we want to insert contracts:\n```\n{self.prompt}\n```"
        prompt += f"Here is the code of that we want to insert contracts:\n```\n{self.get_last_version()}\n```"
        demo_testcases = random.choices(self.testcases, k=random.randint(1, min(len(self.testcases), 10)))
        prompt += f"Here is the example inputs and outputs of the code:\n```\n{demo_testcases}\n```"

        prompt += instruction

        ret = make_request(
            prompt,
        )

        try:
            ret = sanitize(ret[0], "", codegen=False, global_code=True)
        except Exception:
            return False

        if self.verbose:
            print(colored("Prompt: ", "green"))
            print(prompt)
            print(colored("Response: ", "green"))
            print(ret)

        time_limits = [100 for t in self.testcases]
        stat, results = untrusted_check(
            io=self.io,
            code=ret,
            testcases=self.testcases,
            atol=None,
            ref_time=time_limits,
            fast_check=True,
            check=True,
            generator=False,
            gt_time_limit_factor=1.0
        )
        correctness, failing_input = self.check_correctness(results)
        if correctness:
            self.update(ret)
            if self.verbose:
                print(colored("Execution success", "green"))
            return True
        if self.verbose:
            print(colored("Execution failed", "red"))
            print(colored(f"Input: {failing_input}", "red"))
        return False

    def gen(self):
        """Run up to `iteration` attempts or until `target_num` contracts are inserted."""
        while self.iteration > 0 and self.current_num < self.target_num:
            self.iteration -= 1
            self.contract_update()
        return self.history[-1]
def gen_func_contracts(data_file, testcase_file, solution_file, output_file, verbose = False):
    """Generate function-level contracts for every instance in data_file.

    Only instances that have both test cases and a ground-truth solution are
    processed; the resulting {prompt: contracted_code} map is written to
    output_file as pretty-printed JSON.
    """
    # Context managers close the files (the previous json.load(open(...)) leaked handles).
    with open(data_file, "r") as f:
        data = json.load(f)
    with open(testcase_file, "r") as f:
        testcases = json.load(f)
    with open(solution_file, "r") as f:
        solutions = json.load(f)

    contracts = {}
    for d in tqdm(data):
        if d["final_prompt"] not in testcases or d["final_prompt"] not in solutions:
            continue
        entry_point = d["entry_point"]
        # The stored best solution names its entry point `solution`; restore the
        # original function name before inserting contracts.
        solution = solutions[d["final_prompt"]][0].replace("def solution", f"def {entry_point}")
        generator = FuncContractGenerator(d["prompt"], testcases[d["final_prompt"]], solution, entry_point, verbose = verbose)
        contracts[d["final_prompt"]] = generator.gen()

    with open(output_file, "w", encoding = "utf-8") as f:
        f.write(json.dumps(contracts, sort_keys=True, indent=4, separators=(',', ': ')))


def gen_file_contracts(data_file, testcase_file, solution_file, output_file, verbose = False):
    """Generate file-level contracts; see gen_func_contracts for the output format."""
    with open(data_file, "r") as f:
        data = json.load(f)
    with open(testcase_file, "r") as f:
        testcases = json.load(f)
    with open(solution_file, "r") as f:
        solutions = json.load(f)

    contracts = {}
    for d in tqdm(data):
        if d["final_prompt"] not in testcases or d["final_prompt"] not in solutions:
            continue
        solution, io = solutions[d["final_prompt"]]
        prompt = d["final_prompt"].replace("You are an expert Python programmer, and here is your task:\n", "").replace("\n```python", "")
        generator = FileContractGenerator(prompt, testcases[d["final_prompt"]], solution, io, verbose = verbose)
        contracts[d["final_prompt"]] = generator.gen()

    with open(output_file, "w", encoding = "utf-8") as f:
        f.write(json.dumps(contracts, sort_keys=True, indent=4, separators=(',', ': ')))


if __name__ == '__main__':
    gen_func_contracts("Coffe/datasets/function/data.json", "Coffe/datasets/function/testcases.json", "Coffe/datasets/function/best_solutions.json", "function_contracts.json")
    #gen_file_contracts("Coffe/datasets/file/data.json", "Coffe/datasets/file/testcases.json", "Coffe/datasets/file/best_solutions.json", "file_contracts.json")
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /coffe/dataset.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, load_from_disk 2 | import argparse 3 | import os 4 | import json 5 | import ast 6 | import sys 7 | import time 8 | 9 | from coffe.sanitize import CodeVisitor, CodeProcessor 10 | from coffe.config import benchmarks 11 | 12 | class Dataset(object): 13 | def __init__(self, name, data_path = None, testfile_path = None, full = True, train_description_path = None, tmp_dir = "./tmp") -> None: 14 | self.name = name 15 | if self.name not in benchmarks: 16 | raise ValueError(f"Cannot find the benchmark {self.name} in Coffe. 
    # NOTE(review): the enclosing `class Dataset` header and the first half of
    # __init__ sit above this chunk. Methods below assume __init__ has set
    # self.name, self.dataset, self.testcases, self.data_path and self.index.
    # NOTE(review): for name == "openai_humaneval" the __init__ tail references
    # `testfile_path` without assigning it here — presumably it is set in the
    # unseen part of __init__; verify, otherwise this is a NameError.

    def reset_index(self):
        """Rewind the iteration cursor used by next()."""
        self.index = -1

    def length(self):
        # Number of instances in the underlying dataset.
        return len(self.dataset)

    def next(self):
        """Advance the cursor and return (instance, finish).

        `finish` becomes True when the returned instance is the last one,
        so the caller's `while not finish` loop still processes it.
        """
        self.index += 1
        finish = False
        if self.index == len(self.dataset) - 1:
            finish = True

        instance = {}
        if self.name in ["openai_humaneval", "mbpp"] and len(self.testcases) > 0:
            # Copy the raw record and attach the matching EvalPlus test cases,
            # keyed by the task id with its "HumanEval/" prefix stripped.
            for key in self.dataset[self.index]:
                instance[key] = self.dataset[self.index][key]
            if str(instance["task_id"]).replace("HumanEval/", "") in self.testcases:
                instance["testcases"] = self.testcases[str(instance["task_id"]).replace("HumanEval/", "")]
                if "entry_point" in instance["testcases"]:
                    instance["entry_point"] = instance["testcases"]["entry_point"]
            else:
                instance["testcases"] = None
        else:
            instance = self.dataset[self.index]

        return instance, finish

    def get_function_signature(self, instance):
        """Best-effort extraction of "name(args)" from the ground-truth code.

        Returns None when no signature can be found or the benchmark has none.
        """
        if self.name == "mbpp":
            code = instance["code"]
            lines = code.splitlines()
            if "entry_point" in instance:
                for line in lines:
                    if line.startswith("def") and instance["entry_point"] in line:
                        return line.replace("def", "").replace(":", "").strip()
            else:
                # No entry point known: fall back to the last top-level
                # function found by the AST visitor.
                visitor = CodeVisitor(code)
                visitor.run()
                for line in lines:
                    if line.startswith("def") and visitor.funcs[-1] in line:
                        return line.replace("def", "").replace(":", "").strip()
        elif self.name in ["codeparrot/apps"]:
            code = instance["starter_code"]
            lines = code.splitlines()
            for line in lines:
                if line.startswith("def"):
                    # Drop any "-> ..." return annotation from the signature.
                    return line.replace("def", "").replace(":", "").split("->")[0].strip()
        return None

    def get_prompt(self, instance):
        """Build the per-benchmark LLM prompt for one instance."""
        if self.name in ["function", "file"]:
            return instance["final_prompt"]
        prompt = ""
        if self.name == "openai_humaneval":
            prompt += instance["prompt"]
        if self.name in ["codeparrot/apps"]:
            prompt += "You are an expert Python programmer, and here is your task:\n"
            prompt += instance["problem"]
            if len(instance["starter_code"]) > 0:
                prompt += "\nPlease write a Python function {} for the task.\n```python".format(self.get_function_signature(instance))
            else:
                prompt += "\nDo not give explanations, only give the Python code.\nPython Solution:\n```python\n"
        if self.name == "mbpp":
            signature = self.get_function_signature(instance)
            prompt += "You are an expert Python programmer, and here is your task: {} Please write a Python function {} for the task.\n```python".format(instance["prompt"], signature if signature else "")
        if self.name == "deepmind/code_contests":
            prompt += "You are an expert Python programmer, and here is your task:\n"
            prompt += instance["description"]
            prompt += "\nDo not give explanations, only give the Python code.\nPython Solution:\n```python\n"

        prompt = prompt.strip()

        return prompt

    def get_chat(self, instance):
        """Wrap the prompt as a single-turn chat message list."""
        return [{"role": "user", "content": self.get_prompt(instance)}]

    def get_prompt_for_current_instance(self):
        return self.get_prompt(self.dataset[self.index])

    def get_all_prompts(self, model = None, context_length = None):
        """Return (prompts, overlong_prompts).

        When a model and a context length are supplied, prompts whose token
        length reaches the limit are moved into `overlong_prompts`.
        """
        self.reset_index()
        prompts = []
        finish = False
        while(not finish):
            instance, finish = self.next()
            prompt = self.get_prompt(instance)
            prompts.append(prompt)

        new_prompts = []
        overlong_prompts = []

        if model != None and context_length != None:
            for p in prompts:
                if model.get_prompt_length(p) >= context_length:
                    overlong_prompts.append(p)
                else:
                    new_prompts.append(p)
        else:
            new_prompts += prompts

        self.reset_index()

        return new_prompts, overlong_prompts

    def get_prompt2instance(self):
        """Map each prompt string back to its dataset instance."""
        self.reset_index()
        finish = False
        prompt2instance = {}
        while(not finish):
            instance, finish = self.next()
            prompt = self.get_prompt(instance)
            prompt2instance[prompt] = instance

        self.reset_index()

        return prompt2instance

    def save_prompt2id(self, file_path = None):
        """Write a prompt -> benchmark id mapping as JSON.

        Falls back to <data_path>/prompt2id.json when no path is given; an
        existing file is renamed with a timestamp rather than overwritten.
        """
        self.reset_index()
        finish = False
        prompt2id = {}
        while(not finish):
            instance, finish = self.next()
            prompt = self.get_prompt(instance)
            if self.name == "codeparrot/apps":
                prompt2id[prompt] = instance["problem_id"]
            elif self.name == "deepmind/code_contests":
                prompt2id[prompt] = instance["name"]
            elif self.name in ["mbpp", "openai_humaneval"]:
                prompt2id[prompt] = instance["task_id"]

        self.reset_index()

        if file_path != None:
            if os.path.exists(file_path):
                # NOTE(review): os.system("mv ...") is shell-dependent;
                # os.rename/shutil.move would be portable — consider changing.
                new_filepath = "{}_{}.json".format(file_path.replace(".json", ""), time.time())
                print("Warning! {} already exists and renamed to {} to avoid overwriting.".format(file_path, new_filepath))
                os.system("mv {} {}".format(file_path, new_filepath))
            with open(file_path, "w") as f:
                f.write(json.dumps(prompt2id, sort_keys=True, indent=4, separators=(',', ': ')))
        elif self.data_path != None:
            file_path = os.path.join(self.data_path, "prompt2id.json")
            if os.path.exists(file_path):
                new_filepath = "{}_{}.json".format(file_path.replace(".json", ""), time.time())
                print("Warning! {} already exists and renamed to {} to avoid overwriting.".format(file_path, new_filepath))
                os.system("mv {} {}".format(file_path, new_filepath))
            with open(file_path, "w") as f:
                f.write(json.dumps(prompt2id, sort_keys=True, indent=4, separators=(',', ': ')))
        else:
            raise ValueError("Cannot find the file path for prompt2id.")

    def load_testcases(self, file_path = None):
        """Load prompt -> test cases; defaults to <data_path>/testcases.json."""
        # NOTE(review): these loaders never close the file object; consider
        # `with open(...)` throughout.
        if file_path != None:
            self.prompt2testcase = json.load(open(file_path, "r"))
        elif self.data_path != None:
            self.prompt2testcase = json.load(open(os.path.join(self.data_path, "testcases.json"), "r"))
        else:
            raise ValueError("Cannot find the file path for test cases.")

    def load_stressful_testcases(self, file_path = None):
        """Load prompt -> stressful test cases; defaults to <data_path>/stressful_testcases.json."""
        if file_path != None:
            self.prompt2stressful = json.load(open(file_path, "r"))
        elif self.data_path != None:
            self.prompt2stressful = json.load(open(os.path.join(self.data_path, "stressful_testcases.json"), "r"))
        else:
            raise ValueError("Cannot find the file path for stressful test cases.")

    def load_groundtruths(self, file_path = None):
        """Load ground-truth solutions and io flags; defaults to <data_path>/solutions.json."""
        if file_path != None:
            data = json.load(open(file_path, "r"))
            self.prompt2groundtruth = data["prompt2groundtruth"]
            self.prompt2io = data["prompt2io"]
        elif self.data_path != None:
            data = json.load(open(os.path.join(self.data_path, "solutions.json"), "r"))
            self.prompt2groundtruth = data["prompt2groundtruth"]
            self.prompt2io = data["prompt2io"]
        else:
            raise ValueError("Cannot find the file path for ground truth solutions.")

    def load_best_groundtruths(self, file_path = None):
        """Load the fastest solution per prompt; defaults to <data_path>/best_solutions.json."""
        if file_path != None:
            data = json.load(open(file_path, "r"))
            self.prompt2bestgroundtruth = data
        elif self.data_path != None:
            data = json.load(open(os.path.join(self.data_path, "best_solutions.json"), "r"))
            self.prompt2bestgroundtruth = data
        else:
            raise ValueError("Cannot find the file path for best ground truth solutions.")

    def print_info(self):
        """Print coverage statistics (solutions / test cases / stressful cases).

        NOTE(review): the averages divide by (total - empty); if every
        instance were empty this would raise ZeroDivisionError.
        """
        self.load_testcases()
        self.load_groundtruths()
        self.load_stressful_testcases()

        prompts, overlong_prompts = self.get_all_prompts()
        empty_solution = 0
        empty_testcase = 0
        empty_stressful = 0
        total_solution = 0
        total_testcase = 0
        total_stressful = 0
        total_num = len(prompts + overlong_prompts)

        for prompt in (prompts + overlong_prompts):
            if prompt not in self.prompt2groundtruth or len(self.prompt2groundtruth[prompt]) == 0:
                empty_solution += 1
            else:
                total_solution += len(self.prompt2groundtruth[prompt])
            # An instance only counts as having test cases if it also has
            # at least one ground-truth solution.
            if prompt not in self.prompt2testcase or len(self.prompt2testcase[prompt]) == 0 or prompt not in self.prompt2groundtruth or len(self.prompt2groundtruth[prompt]) == 0:
                empty_testcase += 1
            else:
                total_testcase += len(self.prompt2testcase[prompt])
            if prompt not in self.prompt2stressful or len(self.prompt2stressful[prompt]) == 0 or prompt not in self.prompt2groundtruth or len(self.prompt2groundtruth[prompt]) == 0:
                empty_stressful += 1
            else:
                total_stressful += len(self.prompt2stressful[prompt])

        print("Dataset Name: {}\nTotal instance: {}\nInstance with solutions: {}\nInstance with testcases: {}\nInstance with stressful testcases: {}\nAverage solutions: {} ({}/{})\nAverage testcases: {} ({}/{})\nAverage stressful testcases: {} ({}/{})".format(self.name, total_num, total_num - empty_solution, total_num - empty_testcase, total_num - empty_stressful, total_solution / (total_num - empty_solution), total_solution, total_num - empty_solution, total_testcase / (total_num - empty_testcase), total_testcase, total_num - empty_testcase, total_stressful / (total_num - empty_stressful), total_stressful, total_num - empty_stressful))
        print("=" * 20)


# ---- coffe/sanitize.py ----
import ast
import re
import traceback
from typing import List, Optional
class CodeVisitor(ast.NodeVisitor):
    """Collect structural facts about a code snippet.

    After run(): `funcs` holds module-level function names, `func_names`
    holds "<func>@<class>" for every function seen, `classes` the class
    names, plus flags for input() usage, class presence, whether the module
    body is def/class only, and whether every function lives in a class.
    """

    def __init__(self, code):
        self.funcs = []
        self.func_names = []
        self.classes = []
        self.code = code
        self.classname = "global"  # enclosing class while walking; "global" at top level
        self.has_class = False
        self.has_input = False
        self.only_func = False
        self.all_func_in_class = False

    def visit_ClassDef(self, node):
        self.classname = node.name
        self.has_class = True
        self.classes.append(node.name)
        self.generic_visit(node)
        self.classname = "global"

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name) and node.func.id == "input":
            self.has_input = True
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        self.func_names.append("{}@{}".format(node.name, self.classname))
        if self.classname == "global":
            self.funcs.append(node.name)
        self.generic_visit(node)

    def run(self):
        """Parse self.code and populate all collected attributes."""
        self.root = ast.parse(self.code)
        self.only_func = True
        for statement in self.root.body:
            if not isinstance(statement, ast.FunctionDef) and not isinstance(statement, ast.ClassDef):
                self.only_func = False
                break
        self.visit(self.root)

        if len(self.classes) > 0 and len(self.func_names) > 0:
            self.all_func_in_class = True
            for f in self.func_names:
                if f.split("@")[-1] == "global":
                    self.all_func_in_class = False
                    break


class PlaceHolder(ast.NodeTransformer):
    """Insert `pass` into any statement body left empty by prior transforms.

    CommentRemover can strip a docstring that was a node's entire body;
    an empty body cannot be unparsed, so it is padded here. The original
    implementation repeated the identical method ten times; the shared
    `_pad` is aliased to every relevant visit_* name instead (behavior
    unchanged: only `body` is padded, never `orelse`/`finalbody`).
    """

    def _pad(self, node):
        if len(node.body) == 0:
            node.body.append(ast.Pass())
            return node
        else:
            # Non-empty body: recurse, as the original per-node methods did.
            self.generic_visit(node)
            return node

    visit_FunctionDef = _pad
    visit_AsyncFunctionDef = _pad
    visit_ClassDef = _pad
    visit_If = _pad
    visit_For = _pad
    visit_While = _pad
    visit_AsyncFor = _pad
    visit_With = _pad
    visit_AsyncWith = _pad
    visit_Try = _pad
    visit_TryStar = _pad

    def run(self, root):
        """Pad `root` in place and return it."""
        self.visit(root)
        return root


class CommentRemover(ast.NodeTransformer):
    """Remove imports and bare string expressions (docstrings) from code.

    run() returns the unparsed, cleaned source; bodies emptied by the
    removal are re-padded with `pass` via PlaceHolder.
    """

    def __init__(self, code):
        self.root = ast.parse(code)

    def visit_Import(self, node):
        return None

    def visit_ImportFrom(self, node):
        return None

    def visit_Expr(self, node):
        self.generic_visit(node)
        # A bare string constant used as a statement is a docstring/comment.
        if type(node.value) == ast.Constant and isinstance(node.value.value, str):
            return None
        else:
            return node

    def run(self):
        self.visit(self.root)
        placeholder = PlaceHolder()
        self.root = placeholder.run(self.root)
        ast.fix_missing_locations(self.root)
        new_code = ast.unparse(self.root)
        return new_code
class CodeProcessor(ast.NodeTransformer):
    """Normalize a candidate solution so that its entry function is named
    `solution` (or, for all-in-class code, wrap the class in a `solution`
    function), so downstream execution can call a uniform entry point."""

    def __init__(self, code, entry_point = None, force_rename = False):
        self.code = code
        self.entry_point = entry_point
        self.classname = "global"   # enclosing class name while walking
        self.mode = "funcname"      # "funcname": rename pass; "input": input() replacement pass
        self.ori_name = None        # original name of the renamed function
        self.force_rename = force_rename

    def visit_ClassDef(self, node):
        self.classname = node.name
        self.generic_visit(node)
        self.classname = "global"

    def visit_Name(self, node):
        # Propagate the rename to references (e.g. recursion) of the
        # original function name.
        if self.mode == "funcname" and self.ori_name != None and node.id == self.ori_name:
            node.id = "solution"

        return node

    def visit_FunctionDef(self, node):
        #rename the first function generated as LLMs tend to generate extra useless code in the end of response
        if not self.entry_point and self.mode == "funcname" and node.name == self.visitor.funcs[-1] and self.classname == "global":
            self.ori_name = node.name
            node.name = "solution"
        elif self.entry_point and self.mode == "funcname" and node.name == self.entry_point and self.classname == "global":
            self.ori_name = node.name
            node.name = "solution"

        self.generic_visit(node)

        return node

    def visit_Call(self, node):
        # In "input" mode, replace input() calls with the name `inputs`.
        if self.mode == "input" and isinstance(node.func, ast.Name) and node.func.id == "input":
            new_node = ast.Name(id = "inputs")
            ast.fix_missing_locations(new_node)
            return new_node
        else:
            self.generic_visit(node)

        return node

    def run(self, no_modify = False):
        """Return (new_code, needs_global_handling).

        (-1, False) on any failure; (code, True) when the snippet is
        top-level script code that cannot be wrapped as a function.
        """
        try:
            remover = CommentRemover(self.code)
            new_code = remover.run().strip()
            if len(new_code) == 0:
                return -1, False
            self.visitor = CodeVisitor(new_code)
            self.visitor.run()
            self.root = ast.parse(self.code)
            if self.visitor.all_func_in_class:
                if no_modify:
                    return ast.unparse(self.root), False
                # Append `def solution(*args): s = <Class>(); s.<method>(*args)`
                # delegating to the last method of the last class.
                self.classname = self.visitor.classes[-1]
                self.funcname = None
                for func_name in self.visitor.func_names:
                    if func_name.split("@")[-1] == self.classname:
                        self.funcname = func_name.split("@")[0]
                args = ast.arguments(posonlyargs = [], args = [], vararg = ast.arg(arg = "args"), kwonlyargs = [], kw_defaults = [], kwarg = None, defaults = [])
                init_statement = ast.Assign(targets = [ast.Name(id = "s", ctx = ast.Store)], value = ast.Call(func = ast.Name(id = self.classname, ctx = ast.Load), args = [], keywords = []), type_comment = None)
                ast.fix_missing_locations(init_statement)
                call_statement = ast.Expr(value = ast.Call(func = ast.Attribute(value = ast.Name(id = "s", ctx = ast.Load), attr = self.funcname, ctx = ast.Store), args = [ast.Starred(value = ast.Name(id = "args", ctx = ast.Load))], keywords = []))
                ast.fix_missing_locations(call_statement)
                statements = [init_statement, call_statement]
                # NOTE(review): `type_params` exists only on Python 3.12+
                # ast.FunctionDef — confirm the supported interpreter range.
                new_node = ast.FunctionDef(name = "solution", args = args, body = statements, decorator_list =[], returns = None, type_comment = None, type_params = [])
                ast.fix_missing_locations(new_node)
                self.root.body = self.root.body + [new_node]
                return ast.unparse(self.root), False
            elif (len(self.visitor.funcs) > 0 and self.visitor.only_func) or self.force_rename:
                if no_modify:
                    return ast.unparse(self.root), False
                self.mode = "funcname"
                self.visit(self.root)
                return ast.unparse(self.root), False
            else:
                # Script-style code: hand back unchanged, flag for caller.
                return self.code, True
        except Exception as e:
            #print(e)
            #traceback.print_exc()
            return -1, False


def syntax_check(code, verbose=False):
    """Return True iff `code` parses as Python."""
    try:
        ast.parse(code)
        return True
    except (SyntaxError, MemoryError):
        if verbose:
            traceback.print_exc()
        return False


def remove_unindented_lines(code, protect_before, execeptions, trim_tails):
    """Drop top-level (unindented) lines that follow `protect_before`.

    Lines starting with any prefix in `execeptions` are kept; once a line
    starts with any prefix in `trim_tails`, that line and everything after
    it is removed.
    """
    lines = code.splitlines()
    cut_idx = []
    cut_enabled = False
    for i, line in enumerate(lines):
        # Cutting only starts after the protected header line is seen.
        if not cut_enabled and line.startswith(protect_before):
            cut_enabled = True
            continue
        if line.strip() == "":
            continue
        if any(line.startswith(e) for e in execeptions):
            continue

        lspace = len(line) - len(line.lstrip())
        if lspace == 0:
            cut_idx.append(i)

        if any(line.rstrip().startswith(t) for t in trim_tails):
            # cut off everything behind
            cut_idx.extend(list(range(i, len(lines))))
            break

    return "\n".join([line for i, line in enumerate(lines) if i not in cut_idx])
def to_four_space_indents(old_code):
    """Convert 3-space indentation to 4-space (other widths untouched)."""
    new_code = ""
    for line in old_code.splitlines():
        lspace = len(line) - len(line.lstrip())
        if lspace == 3:
            new_code += " "
        new_code += line + "\n"
    return new_code


def remove_space_for_codegen(old_code):
    """Normalize CodeGen output: fix odd 2-off indents and collapse
    runs of spaces inside each line to single spaces."""
    new_code = ""
    for line in old_code.splitlines():
        if len(line.strip()) == 0:
            new_code += "\n"
            continue
        lspace = len(line) - len(line.lstrip())
        if lspace % 4 == 2:
            new_code += " "*(lspace - 2)
        elif lspace%4 == 0:
            new_code += " "*lspace
        items = line.lstrip().split(" ")
        for index, i in enumerate(items):
            new_code += i.replace(" ", "")
            if index < len(items) - 1:
                new_code += " "

        new_code += "\n"

    return new_code


def sanitize(
    old_code: str,
    entry_point: str,
    rm_prefix_lines: Optional[str] = None,
    eofs: List = None,
    codegen: bool = False,
    global_code: bool = False,
    chat: bool = False
):
    """Extract a clean Python solution from raw LLM output.

    `global_code=True` returns the first fenced code block as-is; otherwise
    the function isolates the `def <entry_point>` implementation, repairs
    indentation and trims trailing junk.
    """
    new_code = old_code.replace("\r\n", "\n").replace("\\_", "_").replace("if __name__", "if 1 or __name__")
    if "<|end_header_id|>" in new_code:
        new_code = new_code.split("<|end_header_id|>")[-1]
    if codegen:
        new_code = remove_space_for_codegen(new_code)
    if global_code and "```" in new_code:
        # Script-style answer: just pull out the fenced block.
        if len(new_code.split("```python\n")) > 1:
            new_code = new_code.split("```python\n")[1]
        elif len(new_code.split("```")) > 2:
            new_code = new_code.split("```")[1]
        new_code = new_code.split("```")[0]

        return new_code.strip()

    if new_code.endswith("```"):
        new_code = new_code[:-3]
    if rm_prefix_lines is not None:
        # BUGFIX: filter the normalized `new_code`, not the raw `old_code`
        # (the old form silently discarded the replacements made above).
        new_code = "\n".join(
            [
                line
                for line in new_code.splitlines()
                if not line.startswith(rm_prefix_lines)
            ]
        )

    new_code = "\n" + new_code
    def_left = "def " + entry_point

    # basic handling of chat output
    new_code = new_code.replace("\n```python\n", "\n```\n")
    if def_left in new_code:
        # Keep the fenced chunk that actually defines the entry point.
        for chunk in new_code.split("\n```\n"):
            if def_left in chunk:
                new_code = chunk
                break
    else:
        new_code = new_code.split("```")[0]

    if codegen:
        for chunk in new_code.split("\"\"\""):
            if def_left in chunk:
                new_code = chunk
                break

    # BUGFIX: escape the entry point so regex metacharacters in it cannot
    # corrupt the split.
    chunks = [chunk for chunk in re.split(re.escape(def_left) + r"\s*\(", new_code)]
    # TODO: having return does not mean this is complete
    bodies = [chunk for chunk in chunks[1:] if " return " in chunk.split("\ndef")[0]]
    def_left = def_left + "("
    new_code = def_left + def_left.join(bodies) if len(bodies) > 0 else ""  # fn + impl
    new_code = to_four_space_indents(new_code)

    for eof in eofs or []:
        new_code = new_code.split(eof)[0]

    # remove lines starting from the first unindented line after def_left
    new_code = remove_unindented_lines(
        new_code,
        protect_before=def_left,
        execeptions=["def ", "import ", "from "],
        trim_tails=['"""', "if", "print"],
    )
    new_code = chunks[0] + new_code

    # cut all functions that are not syntactically correct && not the entry point
    parts = new_code.split("\ndef ")
    includes = [parts[0]]
    for fn in new_code.split("\ndef ")[1:]:
        if (
            fn.strip().startswith(entry_point + " ")
            or fn.strip().startswith(entry_point + "(")
            or syntax_check("\ndef " + fn)
        ):
            includes.append(fn)
    new_code = "\ndef ".join(includes)
    return new_code.strip()
class FuncSTGen:
    """Generate stressful (performance-probing) test inputs for a single
    function by prompting an LLM and validating each candidate against the
    contract via untrusted_check."""

    def __init__(self, inputs: List, entry_point: str, contract: str, verbose: bool = False):
        self.contract = contract
        self.entry_point = entry_point
        self.seed_pool: List[Any] = copy.deepcopy(inputs)
        self.new_inputs = []
        # Hashes of stringified seeds, used to de-duplicate generated inputs.
        self.seed_hash = set([hash(str(x)) for x in self.seed_pool])
        self.instruction_messages = [
            "Please generate stressful test cases for given function. Please generate complex, time-consuming inputs to test the function. Please generate difficult inputs to test the function, as diverse as possible, as complex as possible.",
            "Please generate stressful test cases for given function. Please generate time-consuming inputs to test the function. Please generate difficult inputs to test the function, as diverse as possible, as complex as possible.",
            "Please generate stressful test cases for given function. Please generate difficult, time-consuming inputs to test the function. Please generate difficult inputs to test the function, as diverse as possible, as complex as possible.",
        ]
        self.iteration = 20  # LLM-call budget; overwritten in generate()
        self.seed_num = 5    # example seeds shown to the LLM per request
        self.verbose = verbose

    def input_seed_selection(self) -> List:
        """Sample up to seed_num example inputs from the seed pool."""
        return random.sample(self.seed_pool, k=min(len(self.seed_pool), self.seed_num))

    def parse_output(self, output):
        """Strip markdown fences from the first LLM response and parse JSON.

        NOTE(review): replace("json", '') removes the literal substring
        "json" anywhere in the payload, not only in the fence — confirm this
        cannot corrupt generated inputs that contain the word.
        """
        output = output[0]
        output = output.replace("```", '')
        output = output.replace("json", '')
        if self.verbose:
            print("Original output:", output)
        output = json.loads(output)
        return output

    def check_correctness(self, results: List) -> tuple:
        """Return (ok, reason): ok iff every result has status == 1."""
        if len(results) == 0:
            return False, "Empty results"
        for result in results:
            if result["status"] != 1:
                print(result["status_reason"])
                return False, result["input"]
        return True, ' '

    def generate_one(self, selected_inputs: List) -> List:
        """Ask the LLM for new inputs; keep only those that eval() cleanly."""
        message = f"We want to conduct stress test. Here is a function that we want to conduct the stress test:\n```\n{self.contract}\n```"

        str_inputs = "\n".join(str(input) for input in selected_inputs)

        message += f"\nThese are some example inputs used to test the function:\n```\n{str_inputs}\n```"
        message += "Learn the format of the example input, but do not learn their style, generate test cases as diverse as possible, as complex as possible. Do not generate any test cases of a scale exceeding 10**4. Generate less than 10 cases per time. Please consider the internal contraints of the function. \n"
        message += f"\n{random.choice(self.instruction_messages)}"
        message += """
Return format should be pure json format which can be handled by json.loads() function in python, for some cases, you can use python expression to replace extreme long list or dict, but be sure to let it be a string variable, instead of a direct list, because we will handle it using json.loads().:
Return: {
"""
        for input in selected_inputs:
            message += "    "
            message += f"\"input{selected_inputs.index(input)+1}\": \"{str(input)}\",\n"
        message +="""
    ...
}
"""
        message += "\n```json"
        if self.verbose:
            print(message)
            print("")
        ret = make_request(message)
        ret_json = self.parse_output(ret)
        inputs = []
        for key, value in ret_json.items():
            inputs.append(value)
        new_inputs = []
        for i in range(len(inputs)):
            try:
                # NOTE(review): eval() on LLM output is arbitrary code
                # execution; acceptable only because execution is sandboxed
                # downstream — confirm.
                raw = eval(inputs[i])  # eval the input
                del raw
                new_inputs.append(inputs[i])
            except:
                continue
        return new_inputs

    def generate(self, num: int):
        """Collect up to `num` validated stressful inputs (budget: 3*num LLM calls)."""
        self.iteration = num * 3
        while len(self.new_inputs) < num and self.iteration >= 0:
            if self.verbose:
                print(colored("Got new inputs:" + str(len(self.new_inputs)) + " needs:" + str(num), "green"))
            seeds = self.input_seed_selection()
            try:
                gc.collect()
                new_inputs = self.generate_one(seeds)
                for new_input in new_inputs:
                    if hash(str(new_input)) not in self.seed_hash:
                        # Rename the entry point so untrusted_check can call
                        # the uniform `solution` symbol.
                        execute_code = self.contract.replace(f"def {self.entry_point}", "def solution")
                        time_limits = [50]
                        testcase = {"input": eval(new_input), "output": None}
                        stat, results = untrusted_check(
                            io=False,
                            code=execute_code,
                            testcases=[testcase],
                            atol=None,
                            ref_time=time_limits,
                            fast_check=False,
                            check=False,
                            generator=False,
                            gt_time_limit_factor=1.0
                        )
                        correctness, reason = self.check_correctness(results)
                        if correctness:
                            if self.verbose:
                                print("new_input passed:", new_input)
                                print(' ')
                            self.seed_pool.append(new_input)
                            self.seed_hash.add(hash(str(new_input)))
                            self.new_inputs.append(new_input)
                        elif self.verbose:
                            print("new_input failed:", new_input)
                            print(f"reason: {reason}")
                self.iteration -= 1
            except Exception as e:
                if self.verbose:
                    print(e)
                self.iteration -= 1
        return self.new_inputs[:num]
class FileSTGen:
    """Generate stressful test-input *generators* for script-style (stdin)
    problems: the LLM writes a generate_input() function which is validated
    by running it through untrusted_check."""

    def __init__(self, description:str, io: bool, inputs: List, contract: str, verbose: bool = False):
        # BUGFIX: was `self.io = io,` — the trailing comma made io a 1-tuple,
        # so untrusted_check(io=self.io, ...) always received a truthy value
        # even when io=False.
        self.io = io
        self.description = description
        self.inputs = inputs
        self.seed_pool: List[Any] = copy.deepcopy(inputs)
        self.new_inputs = []
        # Hashes of stringified seeds, used to de-duplicate generated inputs.
        self.seed_hash = set([hash(str(x)) for x in self.seed_pool])
        self.contract = contract
        self.instruction_messages = [
            "Please generate stressful test cases for given code. Please generate complex, time-consuming inputs to test the function. Please generate difficult inputs to test the function, as diverse as possible, as complex as possible.",
            "Please generate stressful test cases for given code. Please generate time-consuming inputs to test the function. Please generate difficult inputs to test the function, as diverse as possible, as complex as possible.",
            "Please generate stressful test cases for given code. Please generate difficult, time-consuming inputs to test the function. Please generate difficult inputs to test the function, as diverse as possible, as complex as possible.",
        ]
        self.iteration = 50  # LLM-call budget; overwritten in generate()
        self.seed_num = 10   # example seeds shown to the LLM per request
        self.verbose = verbose

    def input_seed_selection(self) -> List:
        """Sample up to seed_num example inputs from the seed pool."""
        return random.sample(self.seed_pool, k=min(len(self.seed_pool), self.seed_num))

    def check_correctness(self, results: List) -> tuple:
        """Return (ok, reason): ok iff every result has status == 1."""
        if len(results) == 0:
            return False, "Empty results"
        for result in results:
            if result["status"] != 1:
                print(result["status_reason"])
                return False, result["input"]
        return True, ' '

    def generate_one(self, selected_inputs: List):
        """Ask the LLM for a generate_input() function; return its sanitized
        source, or False on overlong prompts / sanitization failure."""
        str_inputs = "\n".join(str(input) for input in selected_inputs)
        message = f"We want to conduct stress test. Here is code that we want to conduct the stress test:\n```\n{self.contract}\n```"
        message += f"\n{random.choice(self.instruction_messages)}\n"
        message += f"Here is the description of the given code:\n```\n{self.description}\n```"
        message = message.replace("You are an expert Python programmer, and here is your task:\n", "").replace("```python", "")
        message += f"\nThese are some example inputs used to test the function:\n```\n{str_inputs}\n```"
        message += "Learn the format of the example input, but do not learn their style, generate test cases as diverse as possible, as complex as possible. Do not generate any test cases of a scale exceeding the contrain described in the problem description. Generate less than 10 cases per time. Please consider the internal contrain of the function. \n"
        message += "Please consider the assertions in the code as the constrains of generated test cases. \n"
        message += f"\n{random.choice(self.instruction_messages)}\n"
        message += """
**INSTRUCTION**
please write an input generator function generate_input() for this code (DO NOT generate outputs). No need to include the example usage, just the function.
The input generator should take no parameters and return one single test input.
The generated input should meet the input constraints of the problem description.
The generated input must be stressful test input that can distinguish the time efficiency of different programs.
Please reply with ONLY the code without any other content.
Wrap your input generator with ```.

You can use the python library random if necessary,
here are some examples of how to use the library, which may be helpful:
random.randint(1, 10)
random.randrange(1,100,2)
random.uniform(1.1,5.4)
random.random()
"""
        if self.verbose:
            print(message)
            print("")
        if len(message) > 450000:
            # Prompt too long for the backing model.
            return False
        ret = make_request(message)
        try:
            ret_code = sanitize(ret[0], "", codegen=False, global_code=True)
        except Exception as e:
            return False
        if self.verbose:
            print("ret_code:", ret_code)
        return ret_code

    def generate(self, num: int):
        """Collect up to `num` validated generator programs (budget: 3*num LLM calls)."""
        self.iteration = num * 3
        while len(self.new_inputs) < num and self.iteration >= 0:
            if self.verbose:
                print(colored("Got new inputs:" + str(len(self.new_inputs)) + " needs:" + str(num), "green"))
            seeds = self.input_seed_selection()
            try:
                gc.collect()
                new_inputs = [self.generate_one(seeds)]
                if new_inputs[0] == False:
                    if self.verbose:
                        print("Overlong prompt.")
                    self.iteration -= 2
                    continue
                results = []
                time_limits = [50 for t in new_inputs]
                stat, results = untrusted_check(
                    io=self.io,
                    code=self.contract,
                    testcases=new_inputs,
                    atol=None,
                    ref_time=time_limits,
                    check = False,
                    generator=True,
                    gt_time_limit_factor=1.0
                )

                correctness, reason = self.check_correctness(results)

                if correctness:
                    if self.verbose:
                        print("new_input passed:", new_inputs[0])
                    self.seed_pool.append(new_inputs[0])
                    self.seed_hash.add(hash(new_inputs[0]))
                    self.new_inputs.append(new_inputs[0])
                elif self.verbose:
                    print("new_input failed:", new_inputs[0])
                    print("reason:", reason)
                self.iteration -= 1
            except Exception as e:
                self.iteration -= 1
                if self.verbose:
                    print(colored("Error in executing new inputs", "red"))
                    print(f"Error: {e}")
                self.iteration -= 1
                continue
        return self.new_inputs[:num]


def gen_func_sts(data_file, testcase_file, contract_file, output_file, verbose = False, num = 5):
    """Produce stressful test inputs for every function-level instance and
    write them to output_file as prompt -> [{"input", "output"}] JSON."""
    data = json.load(open(data_file, "r"))
    testcases = json.load(open(testcase_file, "r"))
    contracts = json.load(open(contract_file, "r"))

    stressful_testcases = {}

    for d in tqdm(data):
        # Skip instances missing either test cases or a contract.
        if d["final_prompt"] not in testcases or d["final_prompt"] not in contracts:
            continue
        entry_point = d["entry_point"]
        test_inputs = [testcase["input"] for testcase in testcases[d["final_prompt"]]]
        generator = FuncSTGen(test_inputs, entry_point, contracts[d["final_prompt"]], verbose = verbose)
        st_inputs = generator.generate(num)
        stressful_testcases[d["final_prompt"]] = []
        for st_input in st_inputs:
            stressful_testcases[d["final_prompt"]].append({"input": st_input, "output": None})

    with open(output_file, "w", encoding = "utf-8") as f:
        f.write(json.dumps(stressful_testcases, sort_keys=True, indent=4, separators=(',', ': ')))


def gen_file_sts(data_file, testcase_file, solution_file, contract_file, output_file, verbose = False, num = 5):
    """Produce stressful input generators for every file-level instance and
    write them to output_file as prompt -> [generator_code] JSON."""
    data = json.load(open(data_file, "r"))
    testcases = json.load(open(testcase_file, "r"))
    solutions = json.load(open(solution_file, "r"))
    contracts = json.load(open(contract_file, "r"))

    stressful_testcases = {}

    for d in tqdm(data):
        if d["final_prompt"] not in testcases or d["final_prompt"] not in contracts or d["final_prompt"] not in solutions:
            continue
        test_inputs = [testcase["input"] for testcase in testcases[d["final_prompt"]]]
        solution, io = solutions[d["final_prompt"]]
        # Strip the prompt boilerplate back down to the raw task description.
        prompt = d["final_prompt"].replace("You are an expert Python programmer, and here is your task:\n", "").replace("\n```python", "")
        generator = FileSTGen(prompt, io, test_inputs, contracts[d["final_prompt"]], verbose = verbose)
        st_inputs = generator.generate(num)
        stressful_testcases[d["final_prompt"]] = st_inputs

    with open(output_file, "w", encoding = "utf-8") as f:
        f.write(json.dumps(stressful_testcases, sort_keys=True, indent=4, separators=(',', ': ')))
def gen_file_sts(data_file, testcase_file, solution_file, contract_file, output_file, verbose = False, num = 5):
    """Generate stressful test inputs for file-level (stdin/stdout) benchmarks.

    For every benchmark entry whose prompt has seed test cases, a contract,
    and a best solution, a FileSTGen instance produces up to ``num``
    stressful inputs from the seed inputs.

    Args:
        data_file: JSON file with benchmark entries ("final_prompt").
        testcase_file: JSON file mapping prompts to seed test cases.
        solution_file: JSON file mapping prompts to (solution, io_format) pairs.
        contract_file: JSON file mapping prompts to input contracts.
        output_file: Destination JSON file for the generated inputs.
        verbose: Forwarded to FileSTGen to enable progress logging.
        num: Maximum number of stressful inputs generated per prompt.
    """
    # Close input files deterministically instead of json.load(open(...)).
    with open(data_file, "r") as f:
        data = json.load(f)
    with open(testcase_file, "r") as f:
        testcases = json.load(f)
    with open(solution_file, "r") as f:
        solutions = json.load(f)
    with open(contract_file, "r") as f:
        contracts = json.load(f)

    stressful_testcases = {}

    for d in tqdm(data):
        # Skip entries missing any of: seed test cases, contract, solution.
        if d["final_prompt"] not in testcases or d["final_prompt"] not in contracts or d["final_prompt"] not in solutions:
            continue
        test_inputs = [testcase["input"] for testcase in testcases[d["final_prompt"]]]
        # Each solutions entry is a (solution, io_format) pair; only the io
        # format is needed here.  Renamed from `io` to avoid shadowing the
        # stdlib module name.
        _solution, io_format = solutions[d["final_prompt"]]
        # Strip the instruction boilerplate so FileSTGen sees the raw task.
        prompt = d["final_prompt"].replace("You are an expert Python programmer, and here is your task:\n", "").replace("\n```python", "")
        generator = FileSTGen(prompt, io_format, test_inputs, contracts[d["final_prompt"]], verbose = verbose)
        st_inputs = generator.generate(num)
        stressful_testcases[d["final_prompt"]] = st_inputs

    with open(output_file, "w", encoding = "utf-8") as f:
        f.write(json.dumps(stressful_testcases, sort_keys=True, indent=4, separators=(',', ': ')))
def check_init():
    """Load and validate the Coffe configuration written by ``init``.

    Returns:
        Tuple of (dataset_path, workdir, perf_path) from coffe_init.json.

    Raises:
        ValueError: if the config file is missing or lacks a required key.
    """
    if not os.path.exists("coffe_init.json"):
        raise ValueError("You must initialize Coffe first!")
    # Close the handle deterministically instead of json.load(open(...)).
    with open("coffe_init.json", "r") as f:
        data = json.load(f)
    # Validate all three keys so a stale/partial config raises the intended
    # "corrupted" error instead of a bare KeyError on "perf_path" below.
    if "dataset" not in data or "workdir" not in data or "perf_path" not in data:
        raise ValueError("Coffe configuration corrupted, please initialize coffe again.")
    if data["workdir"] != os.getcwd():
        # Bug fix: the original passed (data["workdir"], os.getcwd()) in
        # swapped order, reporting the configured dir as the current dir.
        print("Your current dir is {}, but your working dir of coffe is set to {}.".format(os.getcwd(), data["workdir"]))
        print("This may cause potential errors, consider initialize coffe again.")
        exit()
    return data["dataset"], data["workdir"], data["perf_path"]
def check_input_file(filename, metric):
    """Validate that a prediction file's name matches the requested metric.

    Correctness evaluation consumes ``*SOLUTIONS.json`` files, while time /
    instruction-count measurement consumes ``*PASSED_SOLUTIONS.json`` files
    (solutions that already passed the correctness check).

    Args:
        filename: Path to the prediction file.
        metric: One of the evaluation metrics, e.g. "correctness", "time",
            "instr_count".

    Raises:
        ValueError: when the filename does not follow the naming convention
            required by the metric.
    """
    if metric == "correctness" and not filename.endswith("SOLUTIONS.json"):
        # Message grammar fixed: "must ends with" -> "must end with".
        raise ValueError("The filename of the prediction file must end with SOLUTIONS.json to evaluate correctness.")
    if metric in ["time", "instr_count"] and not filename.endswith("PASSED_SOLUTIONS.json"):
        raise ValueError("The filename of the prediction file must end with PASSED_SOLUTIONS.json to evaluate time efficiency.")
prediction files...") 111 | prediction_files = args.prediction.split(",") 112 | ori_pred = args.prediction 113 | 114 | for pred in prediction_files: 115 | args.prediction = pred 116 | check_input_file(pred, args.metric) 117 | sub_command = command.replace(ori_pred, pred) 118 | evaluate(args, sub_command) 119 | if args.metric == "correctness" and (not args.single_worker or args.host_machine): 120 | res = json.load(open(pred.replace("_SOLUTIONS.json", "_PASSED_SOLUTIONS.json"), "r")) 121 | for prompt in results: 122 | results[prompt] += res[prompt] 123 | indexes[prompt] += [pred.replace("_SOLUTIONS.json", "")] * len(res[prompt]) 124 | 125 | if args.metric == "correctness" and (not args.single_worker or args.host_machine): 126 | dataset_name = args.dataset.replace("/", "_") 127 | solution_file = os.path.join(os.path.dirname(pred), f"{dataset_name}_all_PASSED_SOLUTIONS.json") 128 | index_file = os.path.join(os.path.dirname(pred), f"{dataset_name}_all_indexes.json") 129 | with open(solution_file, "w", encoding = "utf-8") as f: 130 | f.write(json.dumps(results, sort_keys=True, indent=4, separators=(',', ': '))) 131 | with open(index_file, "w", encoding = "utf-8") as f: 132 | f.write(json.dumps(indexes, sort_keys=True, indent=4, separators=(',', ': '))) 133 | print(f"Correctness evaluation complete, all predictions and groundtruths have been saved to {solution_file} and {index_file}.") 134 | print("Please use the above files to do performance measurement.") 135 | 136 | elif "," in args.prediction: 137 | raise ValueError("You could only give multiple prediction files when evaluating compilable rate and correctness!") 138 | else: 139 | check_input_file(args.prediction, args.metric) 140 | evaluate(args, command) 141 | if args.metric == "correctness" and (not args.single_worker or args.host_machine): 142 | res = json.load(open(args.prediction.replace("_SOLUTIONS.json", "_PASSED_SOLUTIONS.json"), "r")) 143 | for prompt in results: 144 | results[prompt] += res[prompt] 145 | 
indexes[prompt] += [args.prediction.replace("_SOLUTIONS.json", "")] * len(res[prompt]) 146 | dataset_name = args.dataset.replace("/", "_") 147 | solution_file = os.path.join(os.path.dirname(args.prediction), f"{dataset_name}_all_PASSED_SOLUTIONS.json") 148 | index_file = os.path.join(os.path.dirname(args.prediction), f"{dataset_name}_all_indexes.json") 149 | with open(solution_file, "w", encoding = "utf-8") as f: 150 | f.write(json.dumps(results, sort_keys=True, indent=4, separators=(',', ': '))) 151 | with open(index_file, "w", encoding = "utf-8") as f: 152 | f.write(json.dumps(indexes, sort_keys=True, indent=4, separators=(',', ': '))) 153 | print(f"Correctness evaluation complete, all predictions and groundtruths have been saved to {solution_file} and {index_file}.") 154 | print("Please use the above files to do performance measurement.") 155 | 156 | def pipe(args): 157 | args.dataset_path, args.work_dir, args.perf_path = check_init() 158 | args.stressful = True 159 | args.single_worker = False 160 | args.index = -1 161 | args.subset = "" 162 | args.output_testcase = False 163 | 164 | args.checked_init = True 165 | 166 | if args.final_metric not in ["speedup", "efficient_at_1"]: 167 | raise ValueError("The final metric could only be speedup or efficient_at_1 in pipeline mode.") 168 | 169 | final_metric = args.final_metric if args.final_metric else "" 170 | 171 | args.final_metric = None 172 | 173 | ori_prediction = args.prediction 174 | 175 | 176 | print(colored("+++++++++++Step 1: Checking Syntax Errors...", "green")) 177 | command = 'coffe eval ' + ' '.join(sys.argv[2:]).replace("-f ", "").replace(final_metric, "") 178 | command += " -m compilable_rate" 179 | args.metric = "compilable_rate" 180 | print(f"Executing Command: {command}...") 181 | args.command = command 182 | eval(args) 183 | print(colored("Done!", "green")) 184 | 185 | print(colored("+++++++++++Step 2: Checking Correctness...", "green")) 186 | command = 'coffe eval ' + ' 
'.join(sys.argv[2:]).replace(args.metric, "correctness").replace("-f ", "").replace(final_metric, "") 187 | command = command.replace(ori_prediction, ori_prediction.replace(".json", "_SOLUTIONS.json")) 188 | command += " -m correctness" 189 | args.prediction = ori_prediction.replace(".json", "_SOLUTIONS.json") 190 | args.metric = "correctness" 191 | print(f"Executing Command: {command}...") 192 | args.command = command 193 | eval(args) 194 | print(colored("Done!", "green")) 195 | args.stressful = True 196 | 197 | print(colored("+++++++++++Step 3: Measuring GPU Instruction Count...", "green")) 198 | if "," in ori_prediction: 199 | dirname = os.path.dirname(ori_prediction.split(",")[-1]) 200 | else: 201 | dirname = os.path.dirname(ori_prediction) 202 | dataset_name = args.dataset.replace("/", "_") 203 | args.prediction = os.path.join(dirname, f"{dataset_name}_all_PASSED_SOLUTIONS.json") 204 | command = 'coffe eval ' + ' '.join(sys.argv[2:]).replace(args.metric, "instr_count").replace("-f ", "").replace(final_metric, "") 205 | command = command.replace(ori_prediction, args.prediction) 206 | command += " -m instr_count" 207 | 208 | if args.dataset in ["codeparrot/apps", "deepmind/code_contests", "file"] and "generator" not in args.extra_options: 209 | args.extra_options = "generator" 210 | command += " -e generator" 211 | 212 | args.metric = "instr_count" 213 | print(f"Executing Command: {command}...") 214 | args.command = command 215 | eval(args) 216 | print(colored("Done!", "green")) 217 | 218 | print(colored("Measurement Finished. 
CPU instruction count results stored into {}".format(args.prediction.replace("_PASSED_SOLUTIONS.json", "_STRESSFUL_INSTRUCTION.json")), "green")) 219 | 220 | print(colored("+++++++++++Step 4: Calculating Metrics...", "green")) 221 | args.final_metric = final_metric 222 | args.single_worker = False 223 | command = 'coffe eval ' + ' '.join(sys.argv[2:]) 224 | command = command.replace(ori_prediction, args.prediction.replace("_PASSED_SOLUTIONS.json", "_indexes.json") + "," + args.prediction.replace("_PASSED_SOLUTIONS.json", "_STRESSFUL_INSTRUCTION.json")) 225 | command += " -m instr_count" 226 | args.prediction = args.prediction.replace("_PASSED_SOLUTIONS.json", "_indexes.json") + "," + args.prediction.replace("_PASSED_SOLUTIONS.json", "_STRESSFUL_INSTRUCTION.json") 227 | args.metric = "instr_count" 228 | print(f"Executing Command: {command}...") 229 | args.command = command 230 | eval(args) 231 | print(colored("Metrics result written into file: {}".format(os.path.join(args.output_path, f"{args.final_metric}_results.json")), "green")) 232 | 233 | print(colored("Pipeline Finished!", "green")) 234 | 235 | def main(): 236 | arg_parser = argparse.ArgumentParser() 237 | sub_parsers = arg_parser.add_subparsers(dest='cmd') 238 | arg_parser.set_defaults(func = info) 239 | 240 | init_parser = sub_parsers.add_parser('init') 241 | init_parser.add_argument('-d', '--dataset', required = False, default= os.path.join("Coffe", "datasets"), type=str, help = "Path to the COFFE benchmark location") 242 | init_parser.add_argument('-w', '--workdir', required = False, default=os.getcwd(), type=str, help = "The working directory for dockers and results") 243 | init_parser.add_argument('-p', '--perf', required = False, default=os.path.join("Coffe", "perf.json"), type=str, help = "Path to the COFFE perf.json config file") 244 | init_parser.set_defaults(func = init) 245 | 246 | 247 | evaluate_parser = sub_parsers.add_parser('eval') 248 | evaluate_parser.add_argument('dataset', help = 
"Benchmark name") 249 | evaluate_parser.add_argument('output_path', help = "Path to the output directory") 250 | evaluate_parser.add_argument('-p', '--prediction', required = False, type = str, help = "Path to the prediction file") 251 | evaluate_parser.add_argument('-i', '--index', required = False, default = -1, type = int, help = "The index of workers in parallel processing, should NOT be manually set") 252 | evaluate_parser.add_argument('-n', '--parallel_num', required = False, default = 0, type = int, help = "The number of workers in parallel processing") 253 | evaluate_parser.add_argument('-s', '--subset', required = False, default = "", type = str, help = "Path to the file of a subset of indexes") 254 | evaluate_parser.add_argument('-r', '--stressful', required = False, default = True, action = "store_true", help = "Enable stressful test cases") 255 | evaluate_parser.add_argument('-t', '--output_testcase', required = False, default = False, action = "store_true", help = "Output the most expensive test case for time/instr_count measurement or the first failed case for correctness check") 256 | evaluate_parser.add_argument('-m', '--metric', required = False, default = "correctness", type = str, help = "The metric to be evaluated, can be compilable_rate, correctness, time, or instr_count for code solutions and testcase_compilable_rate, accuracy, testcase_time, testcase_instr_count, coverage for test cases, and testcase_solution_time, testcase_solution_instr_count for test cases on predictions") 257 | evaluate_parser.add_argument('-e', '--extra_options', required = False, default = "", type = str, help = "Extra options for the evaluation") 258 | evaluate_parser.add_argument('-w', '--single_worker', required = False, default = False, action = "store_true", help = "Running as a single internal worker instead of calling docker containers. 
DANGEROUS!") 259 | evaluate_parser.add_argument('-x', '--host_machine', required = False, default = False, action = "store_true", help = "Running code on host machine instead of calling docker containers. DANGEROUS!") 260 | evaluate_parser.add_argument('-f', '--final_metric', required = False, type = str, help = "The final metric calculated based on the measurement results, can be correlation, rsd, pass_k, line_coverage, branch_coverage, max, avg, testcase_compilable_rate, accuracy, efficient_at_1") 261 | evaluate_parser.set_defaults(func = eval) 262 | 263 | pipeline_parser = sub_parsers.add_parser('pipe') 264 | pipeline_parser.add_argument('dataset', help = "Benchmark name") 265 | pipeline_parser.add_argument('output_path', help = "Path to the output directory") 266 | pipeline_parser.add_argument('-p', '--prediction', required = False, type = str, help = "Path to the prediction file") 267 | pipeline_parser.add_argument('-n', '--parallel_num', required = False, default = 0, type = int, help = "The number of workers in parallel processing") 268 | pipeline_parser.add_argument('-e', '--extra_options', required = False, default = "", type = str, help = "Extra options for the evaluation") 269 | pipeline_parser.add_argument('-x', '--host_machine', required = False, default = False, action = "store_true", help = "Running code on host machine instead of calling docker containers. 
DANGEROUS!") 270 | pipeline_parser.add_argument('-f', '--final_metric', required = False, type = str, help = "The final metric calculated based on the measurement results, can be speedup or efficient_at_1.") 271 | pipeline_parser.set_defaults(func = pipe) 272 | 273 | 274 | args = arg_parser.parse_args() 275 | args.func(args) -------------------------------------------------------------------------------- /perf.json: -------------------------------------------------------------------------------- 1 | { 2 | "defaultAction": "SCMP_ACT_ERRNO", 3 | "defaultErrnoRet": 1, 4 | "archMap": [ 5 | { 6 | "architecture": "SCMP_ARCH_X86_64", 7 | "subArchitectures": [ 8 | "SCMP_ARCH_X86", 9 | "SCMP_ARCH_X32" 10 | ] 11 | }, 12 | { 13 | "architecture": "SCMP_ARCH_AARCH64", 14 | "subArchitectures": [ 15 | "SCMP_ARCH_ARM" 16 | ] 17 | }, 18 | { 19 | "architecture": "SCMP_ARCH_MIPS64", 20 | "subArchitectures": [ 21 | "SCMP_ARCH_MIPS", 22 | "SCMP_ARCH_MIPS64N32" 23 | ] 24 | }, 25 | { 26 | "architecture": "SCMP_ARCH_MIPS64N32", 27 | "subArchitectures": [ 28 | "SCMP_ARCH_MIPS", 29 | "SCMP_ARCH_MIPS64" 30 | ] 31 | }, 32 | { 33 | "architecture": "SCMP_ARCH_MIPSEL64", 34 | "subArchitectures": [ 35 | "SCMP_ARCH_MIPSEL", 36 | "SCMP_ARCH_MIPSEL64N32" 37 | ] 38 | }, 39 | { 40 | "architecture": "SCMP_ARCH_MIPSEL64N32", 41 | "subArchitectures": [ 42 | "SCMP_ARCH_MIPSEL", 43 | "SCMP_ARCH_MIPSEL64" 44 | ] 45 | }, 46 | { 47 | "architecture": "SCMP_ARCH_S390X", 48 | "subArchitectures": [ 49 | "SCMP_ARCH_S390" 50 | ] 51 | }, 52 | { 53 | "architecture": "SCMP_ARCH_RISCV64", 54 | "subArchitectures": null 55 | } 56 | ], 57 | "syscalls": [ 58 | { 59 | "names": [ 60 | "accept", 61 | "accept4", 62 | "access", 63 | "adjtimex", 64 | "alarm", 65 | "bind", 66 | "brk", 67 | "cachestat", 68 | "capget", 69 | "capset", 70 | "chdir", 71 | "chmod", 72 | "chown", 73 | "chown32", 74 | "clock_adjtime", 75 | "clock_adjtime64", 76 | "clock_getres", 77 | "clock_getres_time64", 78 | "clock_gettime", 79 | "clock_gettime64", 80 
| "clock_nanosleep", 81 | "clock_nanosleep_time64", 82 | "close", 83 | "close_range", 84 | "connect", 85 | "copy_file_range", 86 | "creat", 87 | "dup", 88 | "dup2", 89 | "dup3", 90 | "epoll_create", 91 | "epoll_create1", 92 | "epoll_ctl", 93 | "epoll_ctl_old", 94 | "epoll_pwait", 95 | "epoll_pwait2", 96 | "epoll_wait", 97 | "epoll_wait_old", 98 | "eventfd", 99 | "eventfd2", 100 | "execve", 101 | "execveat", 102 | "exit", 103 | "exit_group", 104 | "faccessat", 105 | "faccessat2", 106 | "fadvise64", 107 | "fadvise64_64", 108 | "fallocate", 109 | "fanotify_mark", 110 | "fchdir", 111 | "fchmod", 112 | "fchmodat", 113 | "fchmodat2", 114 | "fchown", 115 | "fchown32", 116 | "fchownat", 117 | "fcntl", 118 | "fcntl64", 119 | "fdatasync", 120 | "fgetxattr", 121 | "flistxattr", 122 | "flock", 123 | "fork", 124 | "fremovexattr", 125 | "fsetxattr", 126 | "fstat", 127 | "fstat64", 128 | "fstatat64", 129 | "fstatfs", 130 | "fstatfs64", 131 | "fsync", 132 | "ftruncate", 133 | "ftruncate64", 134 | "futex", 135 | "futex_requeue", 136 | "futex_time64", 137 | "futex_wait", 138 | "futex_waitv", 139 | "futex_wake", 140 | "futimesat", 141 | "getcpu", 142 | "getcwd", 143 | "getdents", 144 | "getdents64", 145 | "getegid", 146 | "getegid32", 147 | "geteuid", 148 | "geteuid32", 149 | "getgid", 150 | "getgid32", 151 | "getgroups", 152 | "getgroups32", 153 | "getitimer", 154 | "getpeername", 155 | "getpgid", 156 | "getpgrp", 157 | "getpid", 158 | "getppid", 159 | "getpriority", 160 | "getrandom", 161 | "getresgid", 162 | "getresgid32", 163 | "getresuid", 164 | "getresuid32", 165 | "getrlimit", 166 | "get_robust_list", 167 | "getrusage", 168 | "getsid", 169 | "getsockname", 170 | "getsockopt", 171 | "get_thread_area", 172 | "gettid", 173 | "gettimeofday", 174 | "getuid", 175 | "getuid32", 176 | "getxattr", 177 | "inotify_add_watch", 178 | "inotify_init", 179 | "inotify_init1", 180 | "inotify_rm_watch", 181 | "io_cancel", 182 | "ioctl", 183 | "io_destroy", 184 | "io_getevents", 185 | 
"io_pgetevents", 186 | "io_pgetevents_time64", 187 | "ioprio_get", 188 | "ioprio_set", 189 | "io_setup", 190 | "io_submit", 191 | "ipc", 192 | "kill", 193 | "landlock_add_rule", 194 | "landlock_create_ruleset", 195 | "landlock_restrict_self", 196 | "lchown", 197 | "lchown32", 198 | "lgetxattr", 199 | "link", 200 | "linkat", 201 | "listen", 202 | "listxattr", 203 | "llistxattr", 204 | "_llseek", 205 | "lremovexattr", 206 | "lseek", 207 | "lsetxattr", 208 | "lstat", 209 | "lstat64", 210 | "madvise", 211 | "map_shadow_stack", 212 | "membarrier", 213 | "memfd_create", 214 | "memfd_secret", 215 | "mincore", 216 | "mkdir", 217 | "mkdirat", 218 | "mknod", 219 | "mknodat", 220 | "mlock", 221 | "mlock2", 222 | "mlockall", 223 | "mmap", 224 | "mmap2", 225 | "mprotect", 226 | "mq_getsetattr", 227 | "mq_notify", 228 | "mq_open", 229 | "mq_timedreceive", 230 | "mq_timedreceive_time64", 231 | "mq_timedsend", 232 | "mq_timedsend_time64", 233 | "mq_unlink", 234 | "mremap", 235 | "msgctl", 236 | "msgget", 237 | "msgrcv", 238 | "msgsnd", 239 | "msync", 240 | "munlock", 241 | "munlockall", 242 | "munmap", 243 | "name_to_handle_at", 244 | "nanosleep", 245 | "newfstatat", 246 | "_newselect", 247 | "open", 248 | "openat", 249 | "openat2", 250 | "perf_event_open", 251 | "pause", 252 | "pidfd_open", 253 | "pidfd_send_signal", 254 | "pipe", 255 | "pipe2", 256 | "pkey_alloc", 257 | "pkey_free", 258 | "pkey_mprotect", 259 | "poll", 260 | "ppoll", 261 | "ppoll_time64", 262 | "prctl", 263 | "pread64", 264 | "preadv", 265 | "preadv2", 266 | "prlimit64", 267 | "process_mrelease", 268 | "pselect6", 269 | "pselect6_time64", 270 | "pwrite64", 271 | "pwritev", 272 | "pwritev2", 273 | "read", 274 | "readahead", 275 | "readlink", 276 | "readlinkat", 277 | "readv", 278 | "recv", 279 | "recvfrom", 280 | "recvmmsg", 281 | "recvmmsg_time64", 282 | "recvmsg", 283 | "remap_file_pages", 284 | "removexattr", 285 | "rename", 286 | "renameat", 287 | "renameat2", 288 | "restart_syscall", 289 | "rmdir", 290 | 
"rseq", 291 | "rt_sigaction", 292 | "rt_sigpending", 293 | "rt_sigprocmask", 294 | "rt_sigqueueinfo", 295 | "rt_sigreturn", 296 | "rt_sigsuspend", 297 | "rt_sigtimedwait", 298 | "rt_sigtimedwait_time64", 299 | "rt_tgsigqueueinfo", 300 | "sched_getaffinity", 301 | "sched_getattr", 302 | "sched_getparam", 303 | "sched_get_priority_max", 304 | "sched_get_priority_min", 305 | "sched_getscheduler", 306 | "sched_rr_get_interval", 307 | "sched_rr_get_interval_time64", 308 | "sched_setaffinity", 309 | "sched_setattr", 310 | "sched_setparam", 311 | "sched_setscheduler", 312 | "sched_yield", 313 | "seccomp", 314 | "select", 315 | "semctl", 316 | "semget", 317 | "semop", 318 | "semtimedop", 319 | "semtimedop_time64", 320 | "send", 321 | "sendfile", 322 | "sendfile64", 323 | "sendmmsg", 324 | "sendmsg", 325 | "sendto", 326 | "setfsgid", 327 | "setfsgid32", 328 | "setfsuid", 329 | "setfsuid32", 330 | "setgid", 331 | "setgid32", 332 | "setgroups", 333 | "setgroups32", 334 | "setitimer", 335 | "setpgid", 336 | "setpriority", 337 | "setregid", 338 | "setregid32", 339 | "setresgid", 340 | "setresgid32", 341 | "setresuid", 342 | "setresuid32", 343 | "setreuid", 344 | "setreuid32", 345 | "setrlimit", 346 | "set_robust_list", 347 | "setsid", 348 | "setsockopt", 349 | "set_thread_area", 350 | "set_tid_address", 351 | "setuid", 352 | "setuid32", 353 | "setxattr", 354 | "shmat", 355 | "shmctl", 356 | "shmdt", 357 | "shmget", 358 | "shutdown", 359 | "sigaltstack", 360 | "signalfd", 361 | "signalfd4", 362 | "sigprocmask", 363 | "sigreturn", 364 | "socketcall", 365 | "socketpair", 366 | "splice", 367 | "stat", 368 | "stat64", 369 | "statfs", 370 | "statfs64", 371 | "statx", 372 | "symlink", 373 | "symlinkat", 374 | "sync", 375 | "sync_file_range", 376 | "syncfs", 377 | "sysinfo", 378 | "tee", 379 | "tgkill", 380 | "time", 381 | "timer_create", 382 | "timer_delete", 383 | "timer_getoverrun", 384 | "timer_gettime", 385 | "timer_gettime64", 386 | "timer_settime", 387 | "timer_settime64", 388 | 
"timerfd_create", 389 | "timerfd_gettime", 390 | "timerfd_gettime64", 391 | "timerfd_settime", 392 | "timerfd_settime64", 393 | "times", 394 | "tkill", 395 | "truncate", 396 | "truncate64", 397 | "ugetrlimit", 398 | "umask", 399 | "uname", 400 | "unlink", 401 | "unlinkat", 402 | "utime", 403 | "utimensat", 404 | "utimensat_time64", 405 | "utimes", 406 | "vfork", 407 | "vmsplice", 408 | "wait4", 409 | "waitid", 410 | "waitpid", 411 | "write", 412 | "writev" 413 | ], 414 | "action": "SCMP_ACT_ALLOW" 415 | }, 416 | { 417 | "names": [ 418 | "process_vm_readv", 419 | "process_vm_writev", 420 | "ptrace" 421 | ], 422 | "action": "SCMP_ACT_ALLOW", 423 | "includes": { 424 | "minKernel": "4.8" 425 | } 426 | }, 427 | { 428 | "names": [ 429 | "socket" 430 | ], 431 | "action": "SCMP_ACT_ALLOW", 432 | "args": [ 433 | { 434 | "index": 0, 435 | "value": 40, 436 | "op": "SCMP_CMP_NE" 437 | } 438 | ] 439 | }, 440 | { 441 | "names": [ 442 | "personality" 443 | ], 444 | "action": "SCMP_ACT_ALLOW", 445 | "args": [ 446 | { 447 | "index": 0, 448 | "value": 0, 449 | "op": "SCMP_CMP_EQ" 450 | } 451 | ] 452 | }, 453 | { 454 | "names": [ 455 | "personality" 456 | ], 457 | "action": "SCMP_ACT_ALLOW", 458 | "args": [ 459 | { 460 | "index": 0, 461 | "value": 8, 462 | "op": "SCMP_CMP_EQ" 463 | } 464 | ] 465 | }, 466 | { 467 | "names": [ 468 | "personality" 469 | ], 470 | "action": "SCMP_ACT_ALLOW", 471 | "args": [ 472 | { 473 | "index": 0, 474 | "value": 131072, 475 | "op": "SCMP_CMP_EQ" 476 | } 477 | ] 478 | }, 479 | { 480 | "names": [ 481 | "personality" 482 | ], 483 | "action": "SCMP_ACT_ALLOW", 484 | "args": [ 485 | { 486 | "index": 0, 487 | "value": 131080, 488 | "op": "SCMP_CMP_EQ" 489 | } 490 | ] 491 | }, 492 | { 493 | "names": [ 494 | "personality" 495 | ], 496 | "action": "SCMP_ACT_ALLOW", 497 | "args": [ 498 | { 499 | "index": 0, 500 | "value": 4294967295, 501 | "op": "SCMP_CMP_EQ" 502 | } 503 | ] 504 | }, 505 | { 506 | "names": [ 507 | "sync_file_range2", 508 | "swapcontext" 509 | ], 
510 | "action": "SCMP_ACT_ALLOW", 511 | "includes": { 512 | "arches": [ 513 | "ppc64le" 514 | ] 515 | } 516 | }, 517 | { 518 | "names": [ 519 | "arm_fadvise64_64", 520 | "arm_sync_file_range", 521 | "sync_file_range2", 522 | "breakpoint", 523 | "cacheflush", 524 | "set_tls" 525 | ], 526 | "action": "SCMP_ACT_ALLOW", 527 | "includes": { 528 | "arches": [ 529 | "arm", 530 | "arm64" 531 | ] 532 | } 533 | }, 534 | { 535 | "names": [ 536 | "arch_prctl" 537 | ], 538 | "action": "SCMP_ACT_ALLOW", 539 | "includes": { 540 | "arches": [ 541 | "amd64", 542 | "x32" 543 | ] 544 | } 545 | }, 546 | { 547 | "names": [ 548 | "modify_ldt" 549 | ], 550 | "action": "SCMP_ACT_ALLOW", 551 | "includes": { 552 | "arches": [ 553 | "amd64", 554 | "x32", 555 | "x86" 556 | ] 557 | } 558 | }, 559 | { 560 | "names": [ 561 | "s390_pci_mmio_read", 562 | "s390_pci_mmio_write", 563 | "s390_runtime_instr" 564 | ], 565 | "action": "SCMP_ACT_ALLOW", 566 | "includes": { 567 | "arches": [ 568 | "s390", 569 | "s390x" 570 | ] 571 | } 572 | }, 573 | { 574 | "names": [ 575 | "riscv_flush_icache" 576 | ], 577 | "action": "SCMP_ACT_ALLOW", 578 | "includes": { 579 | "arches": [ 580 | "riscv64" 581 | ] 582 | } 583 | }, 584 | { 585 | "names": [ 586 | "open_by_handle_at" 587 | ], 588 | "action": "SCMP_ACT_ALLOW", 589 | "includes": { 590 | "caps": [ 591 | "CAP_DAC_READ_SEARCH" 592 | ] 593 | } 594 | }, 595 | { 596 | "names": [ 597 | "bpf", 598 | "clone", 599 | "clone3", 600 | "fanotify_init", 601 | "fsconfig", 602 | "fsmount", 603 | "fsopen", 604 | "fspick", 605 | "lookup_dcookie", 606 | "mount", 607 | "mount_setattr", 608 | "move_mount", 609 | "open_tree", 610 | "perf_event_open", 611 | "quotactl", 612 | "quotactl_fd", 613 | "setdomainname", 614 | "sethostname", 615 | "setns", 616 | "syslog", 617 | "umount", 618 | "umount2", 619 | "unshare" 620 | ], 621 | "action": "SCMP_ACT_ALLOW", 622 | "includes": { 623 | "caps": [ 624 | "CAP_SYS_ADMIN" 625 | ] 626 | } 627 | }, 628 | { 629 | "names": [ 630 | "clone" 631 | ], 
632 | "action": "SCMP_ACT_ALLOW", 633 | "args": [ 634 | { 635 | "index": 0, 636 | "value": 2114060288, 637 | "op": "SCMP_CMP_MASKED_EQ" 638 | } 639 | ], 640 | "excludes": { 641 | "caps": [ 642 | "CAP_SYS_ADMIN" 643 | ], 644 | "arches": [ 645 | "s390", 646 | "s390x" 647 | ] 648 | } 649 | }, 650 | { 651 | "names": [ 652 | "clone" 653 | ], 654 | "action": "SCMP_ACT_ALLOW", 655 | "args": [ 656 | { 657 | "index": 1, 658 | "value": 2114060288, 659 | "op": "SCMP_CMP_MASKED_EQ" 660 | } 661 | ], 662 | "comment": "s390 parameter ordering for clone is different", 663 | "includes": { 664 | "arches": [ 665 | "s390", 666 | "s390x" 667 | ] 668 | }, 669 | "excludes": { 670 | "caps": [ 671 | "CAP_SYS_ADMIN" 672 | ] 673 | } 674 | }, 675 | { 676 | "names": [ 677 | "clone3" 678 | ], 679 | "action": "SCMP_ACT_ERRNO", 680 | "errnoRet": 38, 681 | "excludes": { 682 | "caps": [ 683 | "CAP_SYS_ADMIN" 684 | ] 685 | } 686 | }, 687 | { 688 | "names": [ 689 | "reboot" 690 | ], 691 | "action": "SCMP_ACT_ALLOW", 692 | "includes": { 693 | "caps": [ 694 | "CAP_SYS_BOOT" 695 | ] 696 | } 697 | }, 698 | { 699 | "names": [ 700 | "chroot" 701 | ], 702 | "action": "SCMP_ACT_ALLOW", 703 | "includes": { 704 | "caps": [ 705 | "CAP_SYS_CHROOT" 706 | ] 707 | } 708 | }, 709 | { 710 | "names": [ 711 | "delete_module", 712 | "init_module", 713 | "finit_module" 714 | ], 715 | "action": "SCMP_ACT_ALLOW", 716 | "includes": { 717 | "caps": [ 718 | "CAP_SYS_MODULE" 719 | ] 720 | } 721 | }, 722 | { 723 | "names": [ 724 | "acct" 725 | ], 726 | "action": "SCMP_ACT_ALLOW", 727 | "includes": { 728 | "caps": [ 729 | "CAP_SYS_PACCT" 730 | ] 731 | } 732 | }, 733 | { 734 | "names": [ 735 | "kcmp", 736 | "pidfd_getfd", 737 | "process_madvise", 738 | "process_vm_readv", 739 | "process_vm_writev", 740 | "ptrace" 741 | ], 742 | "action": "SCMP_ACT_ALLOW", 743 | "includes": { 744 | "caps": [ 745 | "CAP_SYS_PTRACE" 746 | ] 747 | } 748 | }, 749 | { 750 | "names": [ 751 | "iopl", 752 | "ioperm" 753 | ], 754 | "action": 
"SCMP_ACT_ALLOW", 755 | "includes": { 756 | "caps": [ 757 | "CAP_SYS_RAWIO" 758 | ] 759 | } 760 | }, 761 | { 762 | "names": [ 763 | "settimeofday", 764 | "stime", 765 | "clock_settime", 766 | "clock_settime64" 767 | ], 768 | "action": "SCMP_ACT_ALLOW", 769 | "includes": { 770 | "caps": [ 771 | "CAP_SYS_TIME" 772 | ] 773 | } 774 | }, 775 | { 776 | "names": [ 777 | "vhangup" 778 | ], 779 | "action": "SCMP_ACT_ALLOW", 780 | "includes": { 781 | "caps": [ 782 | "CAP_SYS_TTY_CONFIG" 783 | ] 784 | } 785 | }, 786 | { 787 | "names": [ 788 | "get_mempolicy", 789 | "mbind", 790 | "set_mempolicy", 791 | "set_mempolicy_home_node" 792 | ], 793 | "action": "SCMP_ACT_ALLOW", 794 | "includes": { 795 | "caps": [ 796 | "CAP_SYS_NICE" 797 | ] 798 | } 799 | }, 800 | { 801 | "names": [ 802 | "syslog" 803 | ], 804 | "action": "SCMP_ACT_ALLOW", 805 | "includes": { 806 | "caps": [ 807 | "CAP_SYSLOG" 808 | ] 809 | } 810 | }, 811 | { 812 | "names": [ 813 | "bpf" 814 | ], 815 | "action": "SCMP_ACT_ALLOW", 816 | "includes": { 817 | "caps": [ 818 | "CAP_BPF" 819 | ] 820 | } 821 | }, 822 | { 823 | "names": [ 824 | "perf_event_open" 825 | ], 826 | "action": "SCMP_ACT_ALLOW", 827 | "includes": { 828 | "caps": [ 829 | "CAP_PERFMON" 830 | ] 831 | } 832 | } 833 | ] 834 | } -------------------------------------------------------------------------------- /coffe/code_execution.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import faulthandler 3 | import io 4 | import os 5 | import platform 6 | import signal 7 | import tempfile 8 | from typing import Optional 9 | 10 | import itertools 11 | import multiprocessing 12 | import os 13 | import time 14 | from multiprocessing import Array, Value 15 | from typing import Any, Dict, List, Tuple, Union 16 | 17 | import numpy as np 18 | 19 | from io import StringIO 20 | import sys 21 | 22 | from unittest.mock import patch, mock_open 23 | 24 | import traceback 25 | 26 | from cirron import Collector 27 | 
import random

# Human-readable status labels returned by the untrusted_* wrappers.
_PASS = "pass"
_FAIL = "fail"
_TIMEOUT = "timeout"

# Integer status codes shared with child processes via multiprocessing.Value.
SUCCEED = 1
FAILED = -1
TIMEOUT = -2
UNKNOWN = -3

# UNKNOWN maps to None so callers can detect "child never reported" (treated as timeout).
_mapping = {SUCCEED: _PASS, FAILED: _FAIL, TIMEOUT: _TIMEOUT, UNKNOWN: None}

# Sentinel recorded when a time/instruction measurement fails.
INF = 9999999999999999


class Capturing(list):
    """Context manager that captures everything written to stdout.

    On exit the captured text is appended to the list (``self``), so
    ``out[0]`` holds the full stdout produced inside the ``with`` block.
    """

    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        # Make closing the StringIO a no-op so code under test cannot destroy
        # the buffer before we read it.  (The original ``lambda x: 1`` raised
        # TypeError when invoked as a bound zero-argument close().)
        self._stringio.close = lambda *args, **kwargs: None
        return self

    def __exit__(self, *args):
        self.append(self._stringio.getvalue())
        del self._stringio  # free up some memory
        sys.stdout = self._stdout


@contextlib.contextmanager
def swallow_io():
    """Silence stdout/stderr and make stdin raise IOError inside the block."""
    stream = WriteOnlyStringIO()
    with contextlib.redirect_stdout(stream):
        with contextlib.redirect_stderr(stream):
            with redirect_stdin(stream):
                yield


@contextlib.contextmanager
def output_io(out):
    """Redirect both stdout and stderr into the given stream ``out``."""
    with contextlib.redirect_stdout(out):
        with contextlib.redirect_stderr(out):
            yield


@contextlib.contextmanager
def time_limit(seconds: float):
    """Raise TimeoutException if the block runs longer than ``seconds``.

    Implemented with SIGALRM/setitimer, so it only works in the main thread
    of a (sub)process — which is how this module uses it.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    signal.setitimer(signal.ITIMER_REAL, seconds)
    signal.signal(signal.SIGALRM, signal_handler)
    try:
        yield
    finally:
        # Always disarm the timer, even if the block raised.
        signal.setitimer(signal.ITIMER_REAL, 0)


@contextlib.contextmanager
def create_tempdir():
    """Create a temporary directory, chdir into it, and clean up on exit."""
    with tempfile.TemporaryDirectory() as dirname:
        with chdir(dirname):
            yield dirname


@contextlib.contextmanager
def chdir(root):
    """Temporarily change the working directory to ``root`` ("." is a no-op)."""
    if root == ".":
        yield
        return
    cwd = os.getcwd()
    os.chdir(root)
    try:
        yield
    finally:
        os.chdir(cwd)


class TimeoutException(Exception):
    """Raised by time_limit() when the wrapped block exceeds its budget."""
    pass


class WriteOnlyStringIO(io.StringIO):
    """StringIO that throws an exception when it's read from."""

    def read(self, *args, **kwargs):
        raise IOError

    def readline(self, *args, **kwargs):
        raise IOError

    def readlines(self, *args, **kwargs):
        raise IOError

    def readable(self, *args, **kwargs):
        """Returns True if the IO object can be read."""
        return False


class redirect_stdin(contextlib._RedirectStream):  # type: ignore
    _stream = "stdin"


def reliability_guard(maximum_memory_bytes: Optional[int] = None):
    """
    This disables various destructive functions and prevents the generated code
    from interfering with the test (e.g. fork bomb, killing other processes,
    removing filesystem files, etc.)

    WARNING
    This function is NOT a security sandbox. Untrusted code, including, model-
    generated code, should not be blindly executed outside of one. See the
    Codex paper for more information about OpenAI's code sandbox, and proceed
    with caution.
    """

    if maximum_memory_bytes is not None:
        import resource

        resource.setrlimit(
            resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)
        )
        resource.setrlimit(
            resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)
        )
        # macOS does not allow lowering RLIMIT_STACK this way.
        if not platform.uname().system == "Darwin":
            resource.setrlimit(
                resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)
            )

    faulthandler.disable()

    import builtins

    builtins.exit = None
    builtins.quit = None

    import os

    os.environ["OMP_NUM_THREADS"] = "1"

    # Null out every os-level API that could kill processes, mutate the
    # filesystem, or escape the working directory.
    os.kill = None
    os.system = None
    os.putenv = None
    os.remove = None
    os.removedirs = None
    os.rmdir = None
    os.fchdir = None
    os.setuid = None
    os.fork = None
    os.forkpty = None
    os.killpg = None
    os.rename = None
    os.renames = None
    os.truncate = None
    os.replace = None
    os.unlink = None
    os.fchmod = None
    os.fchown = None
    os.chmod = None
    os.chown = None
    os.chroot = None
    os.lchflags = None
    os.lchmod = None
    os.lchown = None
    # os.getcwd is deliberately left usable (needed for cleanup paths).
    os.chdir = None
    builtins.open = None

    import shutil

    shutil.rmtree = None
    shutil.move = None
    shutil.chown = None

    import subprocess

    subprocess.Popen = None  # type: ignore

    import sys

    # Block modules that could be used to inspect/escape the sandbox.
    sys.modules["ipdb"] = None
    sys.modules["joblib"] = None
    sys.modules["resource"] = None
    sys.modules["psutil"] = None
    sys.modules["tkinter"] = None


def is_floats(x) -> bool:
    """Return True if ``x`` is a float, a list/tuple of floats, or a float ndarray."""
    if isinstance(x, float):
        return True
    if isinstance(x, (list, tuple)):
        return all(isinstance(i, float) for i in x)
    if isinstance(x, np.ndarray):
        return x.dtype == np.float64 or x.dtype == np.float32
    return False


def _poly(xs: list, x: float):
    """
    Evaluates polynomial with coefficients xs at point x.
    return xs[0] + xs[1] * x + xs[2] * x^2 + .... xs[n] * x^n
    """
    # The original used math.pow(), but this file never imports math, so the
    # call raised NameError.  Plain exponentiation is equivalent here.
    return sum(coeff * x ** i for i, coeff in enumerate(xs))


MBPP_OUTPUT_NOT_NONE_TASKS = ["check_str", "text_match_three", "text_starta_endb"]


def trasform_tuples_into_lists(lst):
    """Recursively turn tuples into lists (descending into dict values) so
    tuple-vs-list differences do not affect equality comparison.

    NOTE: the misspelled name is kept for backward compatibility.
    """
    if isinstance(lst, (list, tuple)):
        return [trasform_tuples_into_lists(item) for item in lst]
    elif isinstance(lst, dict):
        return {key: trasform_tuples_into_lists(value) for key, value in lst.items()}
    else:
        return lst


def is_equal(a, b, atol=0):
    """Loose equality used to compare model output against expected output.

    Tolerates tuple-vs-list differences, whitespace/newline/bracket
    differences between strings, and small float errors (absolute tolerance
    ``atol``; defaults to 1e-6 when both sides look like floats).
    """
    if a == b:
        return True
    # Ignore spaces, newlines and square brackets when comparing strings.
    if isinstance(a, str) and isinstance(b, str):
        def _normalize(s):
            return s.replace("\n", "").replace(" ", "").replace("[", "").replace("]", "")
        if _normalize(a) == _normalize(b):
            return True
    transformed_a = trasform_tuples_into_lists(a)
    transformed_b = trasform_tuples_into_lists(b)
    if transformed_a == transformed_b:
        return True
    # For float data, fall back to an absolute-tolerance comparison.
    if atol == 0 and is_floats(a) and is_floats(b):
        atol = 1e-6
    if atol != 0:
        try:
            np.testing.assert_allclose(b, a, atol=atol)
            return True
        except BaseException:
            return False

    return False


def is_all_equal(a, b, atol=0):
    """Check that two result lists have pairwise-equal "model_output" values."""
    if len(a) != len(b):
        return False
    # Forward atol (the original accepted it but silently ignored it).
    return all(
        is_equal(item["model_output"], b[i]["model_output"], atol)
        for i, item in enumerate(a)
    )


def check_success(results):
    """Return True only if every result dict carries status == SUCCEED."""
    for result in results:
        if "status" not in result or result["status"] != SUCCEED:
            return False
    return True


def check_output(testcase, out, atol):
    """Compare captured output with the expected testcase output.

    Returns True on match; raises ValueError on mismatch so the caller can
    record the reason in the result dict.
    """
    expected = testcase["output"][0]
    if isinstance(out, Capturing):
        out = "\n".join(out).strip()
    # Forward atol (the original accepted it but silently ignored it).
    if is_equal(out, expected, atol):
        return True

    raise ValueError("Output does not match the expected output in testcase! Expected: {} Got: {}".format(expected, out))


def run_stdin_code(code, exec_globals, inputs, measure_time=False):
    """Execute ``code`` with stdin (and builtins.open reads) mocked to ``inputs``.

    Returns the wall-clock duration in seconds when ``measure_time`` is True,
    otherwise None.
    """
    inputs_line_iterator = iter(inputs.split("\n"))

    @patch('builtins.open', mock_open(read_data=inputs))
    @patch('sys.stdin', StringIO(inputs))
    @patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator))
    @patch('sys.stdin.readlines', lambda *args: inputs.split("\n"))
    @patch('sys.stdin.read', lambda *args: inputs)
    def inner_call():
        if measure_time:
            start_time = time.time()
        exec(code, exec_globals)
        if measure_time:
            end_time = time.time()
            return end_time - start_time

    return inner_call()


def run_stdin_code_for_instr(code, exec_globals, inputs):
    """Execute ``code`` with stdin mocked and return the CPU instruction count
    measured by cirron's Collector."""
    inputs_line_iterator = iter(inputs.split("\n"))

    @patch('builtins.open', mock_open(read_data=inputs))
    @patch('sys.stdin', StringIO(inputs))
    @patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator))
    @patch('sys.stdin.readlines', lambda *args: inputs.split("\n"))
    @patch('sys.stdin.read', lambda *args: inputs)
    def inner_call():
        with Collector() as collector:
            exec(code, exec_globals)
        return collector.counters.instruction_count

    return inner_call()


def run_stdin_code_coverage(inputs):
    """Import the solution module ``temp`` with stdin mocked to ``inputs``,
    so the surrounding Coverage collector records its execution.

    NOTE(review): ``import temp`` is cached in sys.modules, so only the first
    call in a process actually executes the module — confirm callers rely on a
    fresh subprocess per measurement.
    """
    inputs_line_iterator = iter(inputs.split("\n"))

    @patch('builtins.open', mock_open(read_data=inputs))
    @patch('sys.stdin', StringIO(inputs))
    @patch('sys.stdin.readline', lambda *args: next(inputs_line_iterator))
    @patch('sys.stdin.readlines', lambda *args: inputs.split("\n"))
    @patch('sys.stdin.read', lambda *args: inputs)
    def inner_call():
        with Capturing() as out:
            import temp
    return inner_call()


def eval_stdin_input(inputs):
    """Re-serialize each stdin line: eval the Python literal, dump it as JSON."""
    import json  # not imported at module scope; needed here

    elements = inputs.split("\n")
    new_inputs = []
    for element in elements:
        # SECURITY: eval() on testcase text — only acceptable because
        # testcases in this pipeline are trusted inputs.
        new_inputs.append(json.dumps(eval(element)))

    return "\n".join(new_inputs)


def unsafe_testcase_execute(
    testcase,
    generator,
    stat
):
    '''
    Execute the expressions in test cases to verify compilable rate.
    This is an internal function and should not be called outside this file.
    '''
    sys.set_int_max_str_digits(1000000)
    sys.setrecursionlimit(1000000)
    with create_tempdir():
        # These system calls are needed when cleaning up tempdir.
        import os
        import shutil
        import json

        rmtree = shutil.rmtree
        rmdir = os.rmdir
        chdir = os.chdir
        # Disable functionalities that can make destructive changes to the test.
        # allow only 1GB memory usage
        maximum_memory_bytes = 1024 * 1024 * 1024
        reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
        random.seed(1024)
        try:
            if not generator:
                # Plain testcase: a single evaluable expression.
                eval(testcase)
            else:
                # Generator testcase: code defining generate_input().
                exec_globals = {}
                exec(testcase, exec_globals)
                fn = exec_globals["generate_input"]
                fn()
            stat.value = SUCCEED
        except BaseException:
            stat.value = FAILED
        # Restore the APIs reliability_guard disabled, for tempdir cleanup.
        shutil.rmtree = rmtree
        os.rmdir = rmdir
        os.chdir = chdir


def untrusted_testcase_check(
    testcase,
    generator=False
):
    '''
    The wrapper for the testcase correctness check, can be called outside.
    '''
    timeout = 10
    stat = Value("i", UNKNOWN)
    p = multiprocessing.Process(
        target=unsafe_testcase_execute,
        args=(
            testcase,
            generator,
            stat
        ),
    )
    p.start()
    p.join(timeout=timeout + 1)
    if p.is_alive():
        p.terminate()
        time.sleep(0.1)
    if p.is_alive():
        p.kill()
        time.sleep(0.1)

    stat = _mapping[stat.value]

    # UNKNOWN (child never reported) is treated as a timeout.
    if not stat:
        stat = _TIMEOUT

    return stat


def unsafe_execute(
    io: bool,
    code: str,
    testcases: List,
    time_limits,
    results,
    atol,
    fast_check,
    check,
    generator,
    stat,
    details,
    progress,
):
    '''
    Execute the code in the given testcases to verify the correctness.
    This is an internal function and should not be called outside this file.

    io=True: the code reads stdin and prints its answer.
    io=False: the code defines a function named "solution" that is called
    with the testcase input.  Per-testcase outcomes are appended to the
    shared ``results`` list; ``details``/``progress``/``stat`` are shared
    memory objects read by the parent process.
    '''
    sys.set_int_max_str_digits(1000000)
    sys.setrecursionlimit(1000000)
    with create_tempdir():
        # These system calls are needed when cleaning up tempdir.
        import os
        import shutil
        import json

        rmtree = shutil.rmtree
        rmdir = os.rmdir
        chdir = os.chdir
        # Disable functionalities that can make destructive changes to the test.
        # allow only 4GB memory usage
        maximum_memory_bytes = 4 * 1024 * 1024 * 1024
        # reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
        exec_globals = {}
        random.seed(1024)
        try:
            if io:
                # Force `if __name__ == "__main__"` guards to run under exec().
                code = code.replace("if __name__", "if 1 or __name__")
                for i, testcase in enumerate(testcases):
                    result = {}
                    try:
                        if isinstance(testcase, str):
                            if not generator:
                                exec_input = eval(testcase)["input"]
                            else:
                                testcase_exec_globals = {}
                                exec(testcase, testcase_exec_globals)
                                testcase_fn = testcase_exec_globals["generate_input"]
                                exec_input = testcase_fn()
                            testcase_input = exec_input
                        elif not isinstance(testcase["input"], str):
                            testcase_input = testcase["input"][0]
                        else:
                            testcase_input = testcase["input"]
                            testcase_input = eval(testcase_input)
                        with time_limit(time_limits[i]):
                            with Capturing() as out:
                                run_stdin_code(code, exec_globals, testcase_input)
                        out = out[0]
                        if not isinstance(testcase, str):
                            result["input"] = testcase["input"]
                            result["output"] = testcase["output"][0]
                        else:
                            result["input_output"] = testcase
                        result["global"] = True
                        result["model_output"] = out
                        if check:
                            check_output(testcase, out, atol)
                        result["status"] = SUCCEED
                        result["status_reason"] = None
                        results.append(result)
                    except BaseException as e:
                        if not isinstance(testcase, str):
                            result["input"] = testcase["input"]
                            result["output"] = testcase["output"][0]
                        else:
                            result["input_output"] = testcase
                        result["global"] = True
                        result["status"] = FAILED
                        result["status_reason"] = str(e)
                        results.append(result)
                        details[i] = False
                        progress.value += 1
                        if fast_check:
                            raise
                        continue
                    details[i] = True
                    progress.value += 1
                stat.value = SUCCEED
            else:
                with swallow_io():
                    exec(code, exec_globals)
                fn = exec_globals["solution"]
                for i, testcase in enumerate(testcases):
                    result = {}
                    if isinstance(testcase, str):
                        if not generator:
                            exec_input = eval(testcase)["input"]
                        else:
                            testcase_exec_globals = {}
                            exec(testcase, testcase_exec_globals)
                            testcase_fn = testcase_exec_globals["generate_input"]
                            exec_input = testcase_fn()
                        # Heuristic: a comma in the def line before ":" means
                        # solution() takes several args, so the generated input
                        # is already an argument tuple.
                        if "," in code.split(":")[0]:
                            testcase_input = exec_input
                        else:
                            testcase_input = [exec_input]
                    elif not isinstance(testcase["input"], str):
                        testcase_input = testcase["input"]
                    else:
                        testcase_input = testcase["input"]
                        testcase_input = eval(testcase_input)
                    try:
                        with time_limit(time_limits[i]):
                            out = fn(*testcase_input)
                        if not isinstance(testcase, str):
                            result["input"] = testcase["input"]
                            if isinstance(testcase["output"], list):
                                result["output"] = testcase["output"][0]
                            else:
                                result["output"] = testcase["output"]
                        else:
                            result["input_output"] = testcase
                        result["global"] = False
                        result["model_output"] = out
                        if check:
                            check_output(testcase, out, atol)
                        result["status"] = SUCCEED
                        result["status_reason"] = None
                        results.append(result)
                    except BaseException as e:
                        if not isinstance(testcase, str):
                            result["input"] = testcase["input"]
                            result["output"] = testcase["output"][0]
                        else:
                            result["input_output"] = testcase
                        result["global"] = False
                        result["status"] = FAILED
                        result["status_reason"] = str(e)
                        results.append(result)
                        details[i] = False
                        progress.value += 1
                        if fast_check:
                            raise
                        continue
                    details[i] = True
                    progress.value += 1
                stat.value = SUCCEED
        except BaseException:
            stat.value = FAILED
        # Needed for cleaning up.
        shutil.rmtree = rmtree
        os.rmdir = rmdir
        os.chdir = chdir


def untrusted_check(
    io: bool,
    code: str,
    testcases: list,
    atol,
    ref_time: List[float],
    fast_check: bool = False,
    check: bool = True,
    generator: bool = False,
    min_time_limit: float = 0.1,
    gt_time_limit_factor: float = 10.0
) -> Tuple[str, list]:
    '''
    The wrapper for the correctness check, can be called outside.

    Runs unsafe_execute() in a subprocess with a global timeout derived from
    the reference times; returns (_PASS/_FAIL/_TIMEOUT, per-testcase results).
    '''
    time_limits = [max(min_time_limit, gt_time_limit_factor * t) for t in ref_time]
    timeout = min(int(os.getenv("TIMEOUT_PER_TASK", 3600)), sum(time_limits)) + 2
    if not fast_check:
        timeout += 1  # extra time for data collection

    # Shared memory objects written by the child process.
    progress = Value("i", 0)
    stat = Value("i", UNKNOWN)
    details = Array("b", [False for _ in range(len(testcases))])

    manager = multiprocessing.Manager()
    results = manager.list()

    p = multiprocessing.Process(
        target=unsafe_execute,
        args=(
            io,
            code,
            testcases,
            time_limits,
            results,
            atol,
            fast_check,
            check,
            generator,
            # return values
            stat,
            details,
            progress,
        ),
    )
    p.start()
    p.join(timeout=timeout + 1)
    if p.is_alive():
        p.terminate()
        time.sleep(0.1)
    if p.is_alive():
        p.kill()
        time.sleep(0.1)

    stat = _mapping[stat.value]
    details = details[: progress.value]

    if not stat:
        stat = _TIMEOUT

    # A "pass" from the child only counts if every testcase was reached and passed.
    if stat == _PASS:
        if len(details) != len(testcases) or not all(details):
            stat = _FAIL
    return stat, results


def unsafe_runtime_execute(
    io: bool,
    code: str,
    testcase: Dict,
    time_lmt,
    timeout,
    generator,
    instr=False
):
    '''
    Execute the verified code to collect the runtime information, including execution time and CPU instruction count.
    Note that the input code is assumed to be correct.
    This is an internal function and should not be called outside this file.
    '''
    sys.set_int_max_str_digits(1000000)
    sys.setrecursionlimit(1000000)

    def unsafe_execute():
        with create_tempdir():
            # These system calls are needed when cleaning up tempdir.
            import os
            import shutil
            import json

            rmtree = shutil.rmtree
            rmdir = os.rmdir
            chdir = os.chdir
            # Disable functionalities that can make destructive changes to the test.
            # allow only 4GB memory usage
            maximum_memory_bytes = 4 * 1024 * 1024 * 1024
            reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
            random.seed(1024)
            try:
                if io:
                    new_code = code.replace("if __name__", "if 1 or __name__")
                    if isinstance(testcase, str):
                        if not generator:
                            exec_input = eval(testcase)["input"]
                        else:
                            testcase_exec_globals = {}
                            exec(testcase, testcase_exec_globals)
                            testcase_fn = testcase_exec_globals["generate_input"]
                            exec_input = testcase_fn()
                        testcase_input = exec_input
                    elif not isinstance(testcase["input"], str):
                        testcase_input = testcase["input"][0]
                    else:
                        testcase_input = testcase["input"]
                        testcase_input = eval(testcase_input)
                    exec_globals = {}
                    with time_limit(time_lmt):
                        with Capturing() as out:
                            if not instr:
                                duration = run_stdin_code(new_code, exec_globals, testcase_input, measure_time=True)
                            else:
                                duration = run_stdin_code_for_instr(new_code, exec_globals, testcase_input)
                    results.append(duration)
                else:
                    with swallow_io():
                        exec_globals = {}
                        if isinstance(testcase, str):
                            if not generator:
                                exec_input = eval(testcase)["input"]
                            else:
                                testcase_exec_globals = {}
                                exec(testcase, testcase_exec_globals)
                                testcase_fn = testcase_exec_globals["generate_input"]
                                exec_input = testcase_fn()
                            # Same multi-argument heuristic as unsafe_execute().
                            if "," in code.split(":")[0]:
                                testcase_input = exec_input
                            else:
                                testcase_input = [exec_input]
                        elif not isinstance(testcase["input"], str):
                            testcase_input = testcase["input"]
                        else:
                            testcase_input = testcase["input"]
                            testcase_input = eval(testcase_input)
                        exec(code, exec_globals)
                        fn = exec_globals["solution"]
                        with time_limit(time_lmt):
                            if not instr:
                                start_time = time.time()
                                fn(*testcase_input)
                                duration = time.time() - start_time
                            else:
                                with Collector() as collector:
                                    fn(*testcase_input)
                                duration = collector.counters.instruction_count
                        results.append(duration)
            except BaseException as e:
                # Best-effort: a failed measurement leaves `results` empty and
                # the caller raises RuntimeError.
                print(e)
            # Needed for cleaning up.
            shutil.rmtree = rmtree
            os.rmdir = rmdir
            os.chdir = chdir

    manager = multiprocessing.Manager()
    results = manager.list()

    p = multiprocessing.Process(
        target=unsafe_execute
    )
    p.start()
    p.join(timeout=timeout + 1)
    if p.is_alive():
        p.terminate()
        time.sleep(0.1)
    if p.is_alive():
        p.kill()
        time.sleep(0.1)

    if len(results) > 0:
        return results[0]
    else:
        raise RuntimeError("Time/Instruction measurement failed!")


def unsafe_coverage_execute(
    io: bool,
    code: str,
    testcases: List,
    time_lmts,
    timeout,
    generator=False
):
    '''
    Execute the verified code to collect the coverage of test cases.
    There will be only one coverage score for all input test cases.
    Note that the input code is assumed to be correct.
    This is an internal function and should not be called outside this file.
    '''
    sys.set_int_max_str_digits(1000000)
    sys.setrecursionlimit(1000000)

    def unsafe_execute():
        # Local import: coverage is only required for this measurement path,
        # so the module stays importable without it.
        from coverage import Coverage

        new_code = code.replace("if __name__", "if 1 or __name__")
        # NOTE(review): the rewritten code is written to gittemp.py, while the
        # run imports/removes temp.py — confirm the caller prepares temp.py.
        with open("gittemp.py", "w", encoding="utf-8") as f:
            f.write(new_code)
        with create_tempdir():
            # These system calls are needed when cleaning up tempdir.
            import os
            import shutil
            import json
            import builtins

            rmtree = shutil.rmtree
            rmdir = os.rmdir
            chdir = os.chdir
            remove = os.remove
            unlink = os.unlink
            openfile = builtins.open
            # Disable functionalities that can make destructive changes to the test.
            # allow only 4GB memory usage
            maximum_memory_bytes = 4 * 1024 * 1024 * 1024
            cov = Coverage()
            reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
            random.seed(1024)
            try:
                if io:
                    testcase_inputs = []
                    for testcase in testcases:
                        if isinstance(testcase, str):
                            if not generator:
                                exec_input = eval(testcase)["input"]
                            else:
                                testcase_exec_globals = {}
                                exec(testcase, testcase_exec_globals)
                                testcase_fn = testcase_exec_globals["generate_input"]
                                exec_input = testcase_fn()
                            testcase_input = exec_input
                        elif not isinstance(testcase["input"], str):
                            testcase_input = testcase["input"][0]
                        else:
                            testcase_input = testcase["input"]
                            testcase_input = eval(testcase_input)
                        testcase_inputs.append(testcase_input)
                    with cov.collect():
                        for i, testcase_input in enumerate(testcase_inputs):
                            with time_limit(time_lmts[i]):
                                run_stdin_code_coverage(testcase_input)
                else:
                    testcase_inputs = []
                    for testcase in testcases:
                        if isinstance(testcase, str):
                            if not generator:
                                exec_input = eval(testcase)["input"]
                            else:
                                testcase_exec_globals = {}
                                exec(testcase, testcase_exec_globals)
                                testcase_fn = testcase_exec_globals["generate_input"]
                                exec_input = testcase_fn()
                            if "," in code.split(":")[0]:
                                testcase_input = exec_input
                            else:
                                testcase_input = [exec_input]
                        elif not isinstance(testcase["input"], str):
                            testcase_input = testcase["input"]
                        else:
                            testcase_input = testcase["input"]
                            testcase_input = eval(testcase_input)
                        testcase_inputs.append(testcase_input)
                    from temp import solution
                    with swallow_io():
                        with cov.collect():
                            for i, testcase_input in enumerate(testcase_inputs):
                                with time_limit(time_lmts[i]):
                                    solution(*testcase_input)
            except BaseException:
                # Best-effort: partial coverage is still reported below.
                pass
            # Needed for cleaning up.
            import os
            shutil.rmtree = rmtree
            os.rmdir = rmdir
            os.chdir = chdir
            os.remove = remove
            os.unlink = unlink
            import builtins
            builtins.open = openfile
            cov.json_report(outfile="coverage.json", pretty_print=True)
            data = json.load(open("coverage.json", "r"))
            if io:
                # Whole-file coverage for stdin-style solutions.
                for filename in data["files"]:
                    if filename.endswith("temp.py"):
                        results.append(data["files"][filename])
                        break
            else:
                # Function-level coverage of solution() only.
                for filename in data["files"]:
                    if filename.endswith("temp.py"):
                        results.append(data["files"][filename]["functions"]["solution"])
                        break
            os.remove("coverage.json")
            os.remove("temp.py")

    manager = multiprocessing.Manager()
    results = manager.list()
    p = multiprocessing.Process(
        target=unsafe_execute
    )
    p.start()
    p.join(timeout=timeout + 1)
    if p.is_alive():
        p.terminate()
        time.sleep(0.1)
    if p.is_alive():
        p.kill()
        time.sleep(0.1)

    if len(results) > 0:
        return results[0]
    else:
        raise RuntimeError("Coverage measurement failed!")


def untrusted_runtime_measure(
    io: bool,
    code: str,
    testcases: list,
    ref_time: List[float],
    min_time_limit: float = 0.1,
    gt_time_limit_factor: float = 10.0,
    generator=False,
    std=False
) -> Union[List[float], Tuple[List[float], List[float]]]:
    '''
    The wrapper for the execution time collection, can be called outside.

    Per testcase: runs 12 times, drops the slowest and fastest run, and
    reports mean (and sample std when ``std`` is True) of the remaining 10.
    Failed measurements record INF mean and -1 std.
    '''
    time_limits = [max(min_time_limit, gt_time_limit_factor * t) for t in ref_time]
    timeout = min(int(os.getenv("TIMEOUT_PER_TASK", 3600)), sum(time_limits)) + 1

    execution_time_mean = []
    execution_time_std = []
    for i, testcase in enumerate(testcases):
        exec_time = []
        try:
            for _ in range(12):
                exec_time.append(unsafe_runtime_execute(io, code, testcase, time_limits[i], timeout, generator))
            # Trim the extremes to reduce measurement noise.
            exec_time.remove(max(exec_time))
            exec_time.remove(min(exec_time))
            execution_time_mean.append(np.mean(exec_time))
            execution_time_std.append(np.std(exec_time, ddof=1))
        except Exception:
            execution_time_mean.append(INF)
            execution_time_std.append(-1)

    if not std:
        return execution_time_mean
    else:
        return execution_time_mean, execution_time_std


def untrusted_detailed_runtime_measure(
    io: bool,
    code: str,
    testcases: list,
    ref_time: List[float],
    min_time_limit: float = 0.1,
    gt_time_limit_factor: float = 10.0,
    generator=False
):
    '''
    The wrapper for the execution time collection, can be called outside.

    Runs each testcase 100 times and reports per-testcase mean, sample std,
    max, min, and the raw samples.  NOTE: the try wraps the whole loop, so a
    failure truncates the returned lists at the failing testcase.
    '''
    time_limits = [max(min_time_limit, gt_time_limit_factor * t) for t in ref_time]
    timeout = min(int(os.getenv("TIMEOUT_PER_TASK", 3600)), sum(time_limits)) + 1

    execution_time = []
    execution_time_mean = []
    execution_time_std = []
    execution_time_max = []
    execution_time_min = []
    try:
        for i, testcase in enumerate(testcases):
            exec_time = []
            for _ in range(100):
                exec_time.append(unsafe_runtime_execute(io, code, testcase, time_limits[i], timeout, generator))
            execution_time_mean.append(np.mean(exec_time))
            execution_time_std.append(np.std(exec_time, ddof=1))
            execution_time_max.append(max(exec_time))
            execution_time_min.append(min(exec_time))
            execution_time.append(exec_time)
    except Exception:
        pass

    return execution_time_mean, execution_time_std, execution_time_max, execution_time_min, execution_time


def untrusted_instruction_measure(
    io: bool,
    code: str,
    testcases: list,
    ref_time: List[float],
    min_time_limit: float = 0.1,
    gt_time_limit_factor: float = 10.0,
    std=False,
    generator=False
) -> Union[List[float], Tuple[List[float], List[float]]]:
    '''
    The wrapper for the CPU instruction count collection, can be called outside.

    Same 12-run, trim-extremes protocol as untrusted_runtime_measure, but
    measures instruction counts via cirron instead of wall-clock time.
    '''
    time_limits = [max(min_time_limit, gt_time_limit_factor * t) for t in ref_time]
    timeout = min(int(os.getenv("TIMEOUT_PER_TASK", 3600)), sum(time_limits)) + 1

    instr_counts_mean = []
    instr_counts_std = []
    for i, testcase in enumerate(testcases):
        instr_count = []
        try:
            for _ in range(12):
                instr_count.append(unsafe_runtime_execute(io, code, testcase, time_limits[i], timeout, generator, instr=True))
            instr_count.remove(max(instr_count))
            instr_count.remove(min(instr_count))
            instr_counts_mean.append(np.mean(instr_count))
            instr_counts_std.append(np.std(instr_count, ddof=1))
        except Exception:
            instr_counts_mean.append(INF)
            instr_counts_std.append(-1)

    if not std:
        return instr_counts_mean
    else:
        return instr_counts_mean, instr_counts_std


def untrusted_coverage_measure(
    io: bool,
    code: str,
    testcases: list,
    ref_time: List[float],
    min_time_limit: float = 0.1,
    gt_time_limit_factor: float = 10.0,
    generator=False,
):
    '''
    The wrapper for test case coverage measurement, can be called outside.

    Returns the coverage data dict for the solution, or None on failure.
    '''
    time_limits = [max(min_time_limit, gt_time_limit_factor * t) for t in ref_time]
    timeout = min(int(os.getenv("TIMEOUT_PER_TASK", 3600)), sum(time_limits)) + 1

    # Make the cwd importable so the child process can `import temp`.
    sys.path.append(os.getcwd())

    try:
        coverage = unsafe_coverage_execute(io, code, testcases, time_limits, timeout, generator)
    except Exception:
        coverage = None

    return coverage