├── .dockerignore ├── .gitignore ├── LICENSE.txt ├── README.md ├── pyproject.toml ├── src └── geniz │ ├── __init__.py │ └── example │ ├── geniz │ ├── __init__.py │ ├── auto_import.py │ ├── coder.py │ ├── data_collector.py │ ├── debugger.py │ ├── llm.py │ ├── prompts │ │ ├── programmer_prompt.yaml │ │ ├── programmer_prompt_simple.yaml │ │ ├── test_designer_prompt.yaml │ │ └── test_designer_prompt_simple.yaml │ ├── python_code.py │ ├── round_info.py │ ├── test_designer.py │ └── util.py │ ├── input.py │ ├── keys.json.example │ └── webserver.py └── static ├── demo_1.gif ├── geniz_diagram.png └── screenshot_0.png /.dockerignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | keys.cfg 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # Mac files 135 | *.DS_Store 136 | 137 | # iPython Notebooks 138 | *.ipynb 139 | 140 | # Evaluation folders 141 | results/ 142 | testbed/ 143 | temp/ 144 | 145 | # Ignore all YAML files in data/ 146 | data/*/ic-* 147 | data/*/single-issues 148 | 149 | # Fine tuning data 150 | fine_tune/*.ipynb 151 | fine_tune/subtasks/*.jsonl 152 | temp*.jsonl 153 | 154 | # Inspector 155 | inspector/*.json 156 | 157 | # Ignore all files in the private folder 158 | private/ 159 | 160 | ### Website 161 | 162 | # dependencies 163 | website/frontend/node_modules 164 | website/frontend/package-lock.json 165 | website/frontend/.pnp 166 | *.pnp.js 167 | 168 | # testing 169 | website/frontend/coverage 170 | 171 | # production 172 | website/frontend/build 173 | 174 | # misc 175 | *.env.local 176 | *.env.development.local 177 | *.env.test.local 178 | *.env.production.local 179 | .api_key 180 | *npm-debug.log* 181 | *yarn-debug.log* 182 | *yarn-error.log* 183 | 184 | 185 | # demo yamls (for editing) 186 | *.demo.yaml 187 | 188 | # trajectory files 189 | trajectories/** 190 | !trajectories/demonstrations/** 191 | 192 | .vscode/** 193 | 194 | # PyCharm 195 | .idea/ 196 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 
| *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Custom 132 | candidate_* 133 | input_* 134 | RoundInfo.json 135 | keys.json 136 | locked_tests.json -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2024 Sudocode Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Geniz 2 | 3 | 4 |

5 | [Discord] 6 |

7 | 8 | 9 | Code generation can be highly sensitive to prompting and initial conditions. Geniz explores an evolutionary approach to code generation, where multiple candidate solutions converge on a solution over multiple generations. 10 | 11 | Given an input description of a coding task, the system generates, tests, and ranks a multitude of solutions to find the ones that perform the best. 12 | 13 | ![diagram](static/geniz_diagram.png) 14 | 15 | Geniz includes an interactive codegen app that allows humans to select candidate solutions and evolve them until they converge on a correct solution. The human serves as a flexible fitness function in selecting applicable tests and candidates to survive to the next generation. 16 | 17 | [![Demo Video](static/demo_1.gif)](https://www.youtube.com/watch?v=S_vB7qQ3qs4) 18 | 19 | Geniz combines recently developed code-gen approaches like [AgentCoder](https://arxiv.org/abs/2312.13010), [Reflexion](https://arxiv.org/abs/2303.11366), and [Self-Consistency](https://arxiv.org/abs/2203.11171). The approach particularly excels with smaller models. We foresee this approach as an economical option for achieving high-quality results with lower hardware requirements. 20 | 21 | Benchmark of Geniz without human intervention on the HumanEval dataset. 22 | | Model | Baseline | Geniz (without human) | 23 | | ------------------------------- | -------- | --------------------- | 24 | | OpenCodeInterpreter-1.3B | 48.7%* | 72.0% (+45%) | 25 | | Llama-3-70b-instruct | 81.7% | 85.9% (+5%) | 26 | | Phi-3-mini-128k-instruct (3.8B) | 57.9% | 74.1% (+28%) | 27 | 28 | Note: * denotes our reproduction. 29 | 30 | 31 | ### Features 32 | * **Code solution generation**: Geniz generates a variety of potential solutions in parallel. 33 | * **Test generation and execution**: After generation, the system tests each candidate against all test cases and ranks them based on test evaluation and output consistency. 
34 | * **Candidate evolution**: New solutions are generated based on the previous generation of candidates. 35 | * **Interactive human feedback**: Users can tweak solutions, modify tests, and select candidates to steer the evolution of better candidate solutions. 36 | 37 | ### Revolutionizing Coding 38 | Geniz aims to change the way people tackle coding challenges from platforms like LeetCode, TopCoder, and USACO. Rather than spending hours coding and debugging, users can focus on defining the problem's requirements and let the LLMs handle the implementation details. 39 | 40 | ### Harnessing Local Code-Gen LLMs 41 | Geniz represents the first practical application of local/smaller LLMs, enabling them to solve complex coding problems effectively. By running the LLMs locally, Geniz ensures privacy and allows users to benefit from the latest advancements in language models without relying on external services. 42 | 43 | ## Getting Started 44 | 45 | To get started with Geniz, follow these steps: 46 | 47 | * Clone the repository: `git clone https://github.com/sudocode-ai/geniz.git` 48 | * Install the dependencies: `cd geniz && pip install -e .` 49 | * (Optional) Prepare `keys.json` (see `keys.json.example`), or enter the keys in the UI. 50 | * Run the webapp: `cd src/geniz/example && python webserver.py` 51 | 52 | 53 | ## Contributing 54 | 55 | We welcome contributions from the community! If you'd like to contribute to Geniz, please follow our contributing guidelines. 56 | Contact persons: [Ning Ren](https://www.linkedin.com/in/renning22/), [Alex Ngai](https://www.linkedin.com/in/alexngai/) and [Randy Song](https://www.linkedin.com/in/randy-song/). 57 | 58 | ## License 59 | 60 | Geniz is released under the MIT License. 61 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | # This is the name of your project. 
The first time you publish this 3 | # package, this name will be registered for you. It will determine how 4 | # users can install this project, e.g.: 5 | # 6 | # $ pip install geniz 7 | # 8 | # And where it will live on PyPI: https://pypi.org/project/sampleproject/ 9 | # 10 | # There are some restrictions on what makes a valid project name 11 | # specification here: 12 | # https://packaging.python.org/specifications/core-metadata/#name 13 | name = "geniz" # Required 14 | 15 | # Versions should comply with PEP 440: 16 | # https://www.python.org/dev/peps/pep-0440/ 17 | # 18 | # For a discussion on single-sourcing the version, see 19 | # https://packaging.python.org/guides/single-sourcing-package-version/ 20 | version = "0.0.1" # Required 21 | 22 | # This is a one-line description or tagline of what your project does. This 23 | # corresponds to the "Summary" metadata field: 24 | # https://packaging.python.org/specifications/core-metadata/#summary 25 | description = "A simple Python library" # Optional 26 | 27 | # This is an optional longer description of your project that represents 28 | # the body of text which users will see when they visit PyPI. 29 | # 30 | # Often, this is the same as your README, so you can just read it in from 31 | # that file directly (as we have already done above) 32 | # 33 | # This field corresponds to the "Description" metadata field: 34 | # https://packaging.python.org/specifications/core-metadata/#description-optional 35 | readme = "README.md" # Optional 36 | 37 | # Specify which Python versions you support. In contrast to the 38 | # 'Programming Language' classifiers above, 'pip install' will check this 39 | # and refuse to install the project if the version does not match. 
See 40 | # https://packaging.python.org/guides/distributing-packages-using-setuptools/#python-requires 41 | requires-python = ">=3.7" 42 | 43 | # This is either text indicating the license for the distribution, or a file 44 | # that contains the license 45 | # https://packaging.python.org/en/latest/specifications/core-metadata/#license 46 | license = {file = "LICENSE.txt"} 47 | 48 | # This field adds keywords for your project which will appear on the 49 | # project page. What does your project relate to? 50 | # 51 | # Note that this is a list of additional keywords, separated 52 | # by commas, to be used to assist searching for the distribution in a 53 | # larger catalog. 54 | keywords = ["setuptools", "development"] # Optional 55 | 56 | # This should be your name or the name of the organization who originally 57 | # authored the project, and a valid email address corresponding to the name 58 | # listed. 59 | authors = [ 60 | {name = "Ning Ren", email = "renning22@gmail.com" } # Optional 61 | ] 62 | 63 | # This should be your name or the names of the organization who currently 64 | # maintains the project, and a valid email address corresponding to the name 65 | # listed. 66 | maintainers = [ 67 | {name = "Ning Ren", email = "renning22@gmail.com" } # Optional 68 | ] 69 | 70 | # Classifiers help users find your project by categorizing it. 71 | # 72 | # For a list of valid classifiers, see https://pypi.org/classifiers/ 73 | classifiers = [ # Optional 74 | # How mature is this project? Common values are 75 | # 3 - Alpha 76 | # 4 - Beta 77 | # 5 - Production/Stable 78 | "Development Status :: 3 - Alpha", 79 | 80 | # Indicate who your project is intended for 81 | "Intended Audience :: Developers", 82 | "Topic :: Software Development :: Build Tools", 83 | 84 | # Pick your license as you wish 85 | "License :: OSI Approved :: MIT License", 86 | 87 | # Specify the Python versions you support here. In particular, ensure 88 | # that you indicate you support Python 3. 
These classifiers are *not* 89 | # checked by "pip install". See instead "python_requires" below. 90 | "Programming Language :: Python :: 3", 91 | "Programming Language :: Python :: 3.7", 92 | "Programming Language :: Python :: 3.8", 93 | "Programming Language :: Python :: 3.9", 94 | "Programming Language :: Python :: 3.10", 95 | "Programming Language :: Python :: 3.11", 96 | "Programming Language :: Python :: 3 :: Only", 97 | ] 98 | 99 | # This field lists other packages that your project depends on to run. 100 | # Any package you put here will be installed by pip when your project is 101 | # installed, so they must be valid existing projects. 102 | # 103 | # For an analysis of this field vs pip's requirements files see: 104 | # https://packaging.python.org/discussions/install-requires-vs-requirements/ 105 | dependencies = [ # Optional 106 | "docker", 107 | "gradio", 108 | "isort", 109 | "litellm", 110 | "pydantic", 111 | "ray[default]", 112 | ] 113 | 114 | # List additional groups of dependencies here (e.g. development 115 | # dependencies). Users will be able to install these using the "extras" 116 | # syntax, for example: 117 | # 118 | # $ pip install geniz[dev] 119 | # 120 | # Similar to `dependencies` above, these must be valid existing 121 | # projects. 122 | [project.optional-dependencies] # Optional 123 | dev = ["check-manifest"] 124 | test = ["coverage"] 125 | 126 | # List URLs that are relevant to your project 127 | # 128 | # This field corresponds to the "Project-URL" and "Home-Page" metadata fields: 129 | # https://packaging.python.org/specifications/core-metadata/#project-url-multiple-use 130 | # https://packaging.python.org/specifications/core-metadata/#home-page-optional 131 | # 132 | # Examples listed include a pattern for specifying where the package tracks 133 | # issues, where the source is hosted, where to say thanks to the package 134 | # maintainers, and where to support the project financially. 
The key is 135 | # what's used to render the link text on PyPI. 136 | [project.urls] # Optional 137 | "Homepage" = "https://github.com/sudocode-ai/autogenesis" 138 | "Bug Reports" = "https://github.com/sudocode-ai/autogenesis/issues" 139 | "Funding" = "https://github.com/sudocode-ai/autogenesis" 140 | "Say Thanks!" = "https://github.com/sudocode-ai/autogenesis" 141 | "Source" = "https://github.com/sudocode-ai/autogenesis" 142 | 143 | # The following would provide a command line executable called `sample` 144 | # which executes the function `main` from this package when invoked. 145 | [project.scripts] # Optional 146 | geniz = "geniz:main" 147 | 148 | # This is configuration specific to the `setuptools` build backend. 149 | # If you are using a different build backend, you will need to change this. 150 | [tool.setuptools] 151 | # If there are data files included in your packages that need to be 152 | # installed, specify them here. 153 | package-data = {"sample" = ["*.dat"]} 154 | 155 | [build-system] 156 | # These are the assumed default build requirements from pip: 157 | # https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support 158 | requires = ["setuptools>=67.0.0", "wheel"] 159 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /src/geniz/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sudocode-ai/geniz/0a38c18cdffa0be6d079f2564dd8e8c92c5007ad/src/geniz/__init__.py -------------------------------------------------------------------------------- /src/geniz/example/geniz/__init__.py: -------------------------------------------------------------------------------- 1 | from .auto_import import auto_import 2 | from .coder import CodeAgent 3 | from .debugger import DebugAgent 4 | from .round_info import get_round_info 5 | -------------------------------------------------------------------------------- 
def auto_import():
    """Import every top-level ``*.py`` module in this package's parent directory.

    The directory scanned is ``<dir of this file>/..``. Files whose *name*
    starts with ``__`` and ``webserver.py`` are skipped. Each remaining file is
    imported by module name (file name without ``.py``); this presumably relies
    on that directory being importable (e.g. it is the working directory) --
    TODO confirm against the launcher.

    Import failures are deliberately non-fatal: the error is printed and also
    written to ``<file name>.error`` in the current working directory.
    """
    pattern = join(dirname(__file__), "..", "*.py")

    python_files_to_import = []
    for path in glob.glob(pattern):
        file_name = basename(path)
        # glob returns paths that include the directory part of the pattern,
        # so the exclusions must be tested against the base name; testing the
        # full path (as before) never matched and let webserver.py through.
        if not isfile(path):
            continue
        if file_name.startswith('__') or file_name == 'webserver.py':
            continue
        python_files_to_import.append(file_name)

    for py_file in python_files_to_import:
        module_name = py_file.removesuffix('.py')
        try:
            importlib.import_module(module_name)
        except Exception as e:
            # Best effort: one broken candidate must not stop the others.
            print(f'ignore {module_name} ({py_file}), error:\n{e}')
            with open(f'{py_file}.error', 'w') as error_file:
                error_file.write(str(e))
class CodeAgentState(BaseModel, PersistStateToFile):
    """Registry record for one decorated function (the genesis or a candidate).

    Bundles the callable with the source text and file it came from, and
    offers the file, execution and LLM operations that act on that candidate.
    """

    function_obj: Callable  # the wrapped (undecorated) function object
    source_code: str  # source text of the function, decorators included
    source_filename: str  # base name of the file that defines the function
    function_name: str  # name of the function

    @property
    def id(self) -> str:
        """Return the registry key, formatted ``<function_name>(<source_filename>)``."""
        return f'{self.function_name}({self.source_filename})'

    @property
    def clean_source_code(self) -> str:
        """Return the source code with the geniz agent decorators removed."""
        cleaned = self.source_code
        # Order matters: strip the call form before the bare form, otherwise a
        # dangling "()" would be left behind.
        for decorator in ('@geniz.CodeAgent()', '@geniz.CodeAgent',
                          '@geniz.DebugAgent()', '@geniz.DebugAgent'):
            cleaned = cleaned.replace(decorator, '')
        return cleaned

    @property
    def module_name(self) -> str:
        """Return the source file name without its ``.py`` suffix."""
        return self.source_filename.removesuffix('.py')

    def load_previous_history(self) -> List[ChatMessage]:
        """Load the pickled chat history stored next to the source file.

        Returns:
            The unpickled content of ``<source_filename>.history.pickle``, or
            an empty list when that file does not exist.
        """
        history_filename = Path(f'{self.source_filename}.history.pickle')
        if history_filename.exists():
            # NOTE(review): unpickling executes arbitrary code from the file.
            # This is only safe while the history file is produced locally by
            # query_llm_and_save_candidate; never load one from elsewhere.
            return pickle.loads(history_filename.read_bytes())
        return []

    def debug_str(self):
        """Return a short identifier for log output (same as ``id``)."""
        return self.id

    def debug_oracle_pass(self):
        """Eval if passing against the dataset ground-truth.

        CAUTION: only print in logging to help development. Use this information to code-gen is CHEATING.

        Returns:
            True when ``eval_output.check`` accepts the function; False when
            ``eval_output`` cannot be imported or the check fails in any way.
        """
        try:
            import eval_output
            eval_output.check(self.function_obj)
        except (Exception, SystemExit):
            # Any failure counts as "did not pass", including generated code
            # that tries to exit. KeyboardInterrupt is intentionally left to
            # propagate so the user can still stop the process.
            return False
        return True

    def call_and_record_data_remotely(self, *args, **kwargs):
        """Schedule ``function_obj`` through the ``_execute_remote`` ray task.

        Args:
            *args: Positional arguments forwarded to the function.
            **kwargs: Keyword arguments forwarded to the function.

        Returns:
            Whatever ``_execute_remote.remote`` returns (callers hand it to
            ``ray.wait`` / ``ray.get``).
        """
        # Keyword arguments must be forwarded with ``**``; unpacking with a
        # single ``*`` would pass the dict *keys* as extra positional values.
        return _execute_remote.remote(self.function_obj, self.id, *args, **kwargs)

    def is_genesis(self) -> bool:
        """Return True when this record was defined in the input file."""
        return self.source_filename == _INPUT

    def delete(self):
        """Delete this candidate's files from the current working directory.

        Every non-directory entry whose name starts with ``source_filename``
        is removed (the source file plus sidecars such as ``.history.pickle``).
        Failures are logged and do not raise.
        """
        logging.info(f'=== Delete {self.id}')
        prefix = self.source_filename
        if not prefix:
            # An empty prefix would match every file in the directory.
            return
        # Walk the directory instead of shelling out to ``rm <name>*``: no
        # shell quoting issues and it works on platforms without ``rm``.
        for entry in Path('.').iterdir():
            if not entry.name.startswith(prefix) or entry.is_dir():
                continue
            try:
                entry.unlink()
            except OSError as e:
                logging.warning(f'Failed to delete {entry}: {e}')

    def submit(self):
        """Write this candidate as the final answer into the output file.

        The generated module star-imports the candidate's module, repeats the
        decorator-free source and binds ``final_answer`` to the function.
        """
        final_source_code = f'''# Submitted from {self.id}
import os
import sys
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np

from {self.module_name} import *

{self.clean_source_code}

final_answer = {self.function_name}
'''
        with open(_OUTPUT, 'w') as f:
            f.write(final_source_code)

    def query_llm_and_save_candidate(self, message, system_message=PROGRAMMER_PROMPT.format()) -> Optional[List[str]]:
        """Ask the LLM for new candidates and write each valid one to disk.

        Args:
            message: User message sent to the LLM.
            system_message: System prompt; defaults to the programmer prompt.

        Returns:
            The list of candidate file names written (empty when no reply
            contained a valid code block).
        """
        previous_history = self.load_previous_history()
        new_id = str(alphanumeric_uuid()[-5:])
        output_filename = f'candidate_{new_id}.py'
        replys, histories = query_llm(
            message, system_message=system_message, previous_history=previous_history, filename=output_filename)

        wrote_filenames = []
        for i, reply in enumerate(replys):
            candidate = PythonCode.extract_main_program_code_block_from_full_reply(
                reply, function_name=self.function_name)
            if not candidate.valid():
                continue
            final_code = f'''# Mutation from '{self.source_filename}'
import os
import sys
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np

import geniz

{candidate.code_text}
'''
            new_candidate_filename = f'candidate_{new_id}_{i}.py'
            with open(new_candidate_filename, 'w') as f:
                f.write(final_code)
            logging.info(f'Wrote to {new_candidate_filename}')
            isort.file(new_candidate_filename)
            # Keep the conversation that produced this candidate next to it so
            # a later mutation can continue from it (see load_previous_history).
            new_history_filename = f'{new_candidate_filename}.history.pickle'
            with open(new_history_filename, 'wb') as f:
                pickle.dump(histories[i], f)
            logging.info(f'Wrote to {new_history_filename}')
            wrote_filenames.append(new_candidate_filename)
        return wrote_filenames

    def generate_test_case(self) -> bool:
        """Create a test file for this function via ``create_test_file``."""
        return create_test_file(self.source_filename,
                                self.function_name, self.clean_source_code)

    def generate_seed_candidate(self) -> bool:
        """Ask the LLM to implement the function from its clean source.

        Returns:
            True when at least one candidate file was written.
        """
        message = f'''
```python
{self.clean_source_code}
```

Generate '{self.function_name}'.
'''
        return bool(self.query_llm_and_save_candidate(message))

    def generate_candidate_by_execution_result(self, execution_results: str) -> bool:
        """Ask the LLM for new candidates, using execution results as the message.

        Returns:
            True when at least one candidate file was written.
        """
        return bool(self.query_llm_and_save_candidate(execution_results))
def load_locked_tests():
    """Read the locked tests from ``locked_tests.json`` in the working directory.

    Returns:
        The decoded JSON content, or an empty dict when the file is absent.
    """
    locked_tests_path = 'locked_tests.json'
    if os.path.exists(locked_tests_path):
        with open(locked_tests_path, 'r') as locked_file:
            return json.load(locked_file)
    return {}
def _group_datapoints_by_input(candidate_to_input_output):
    """Index every recorded datapoint by its stringified input.

    Returns a dict mapping input string to a list of
    ``(candidate_id, datapoint, output_str)`` tuples, in iteration order.
    """
    grouped = defaultdict(list)
    for candidate_id, input_to_datapoint_dict in candidate_to_input_output.items():
        for datapoint in input_to_datapoint_dict.values():
            try:
                input_str, output_str = datapoint_to_input_output_str(
                    datapoint)
            except Exception:
                # A datapoint that cannot be rendered is skipped, not fatal.
                continue
            grouped[input_str].append((candidate_id, datapoint, output_str))
    return grouped


def get_test_and_candidate_info():
    """Re-run all tests and summarise the results per test and per candidate.

    Refreshes the collected data, then builds:

    * ``test_info``: one entry per observed input with its outputs, the
      records that produced each output, and whether the input is locked.
      Locked tests come first.
    * ``candidate_info``: one entry per candidate with its stats score and
      the locked tests it passed or failed, best candidates first (by passed
      locked tests, then stats score).

    Returns:
        A dict with the keys ``'test_info'``, ``'candidate_info'`` and
        ``'locked_tests'``.
    """
    refresh_all_data()
    genesis = get_genesis()
    function_name = genesis.function_name

    locked_tests = load_locked_tests()

    candidate_to_passed_locked_tests = defaultdict(dict)
    candidate_to_failed_locked_tests = defaultdict(dict)
    test_info = []
    # Stringify each datapoint once up front instead of once per test input.
    datapoints_by_input = _group_datapoints_by_input(
        get_candidate_input_output())
    test_dist = get_test_dist()
    for test_input, dist_with_output in test_dist.items():
        dist = [x[0] for x in dist_with_output]
        outputs = [x[1] for x in dist_with_output]
        # The first entry is treated as the most frequent output.
        most_frequent_output = outputs[0]
        print(
            f'=== {shorten_answer(test_input)} -> {most_frequent_output}, entropy: {entropy_list(dist):.2f}, dist: {shorten_list(dist)}')

        # output string -> records of the candidates that produced it.
        outputs_info = defaultdict(list)
        for candidate_id, datapoint, output_str in datapoints_by_input.get(test_input, ()):
            call_str = make_function_call_statement_str(
                datapoint.input, datapoint.output, function_name, limit=1000)
            outputs_info[output_str].append({
                'datapoint': datapoint,
                'call_str': call_str,
                'candidate_id': candidate_id,
            })

        # A locked expectation overrides the majority vote, but only while
        # some candidate still produces that output.
        default_output_str = most_frequent_output
        locked_output = locked_tests.get(test_input, None)
        locked = False
        if (locked_output is not None) and (locked_output in outputs):
            default_output_str = locked_output
            locked = True

        this_test_id = alphanumeric_uuid()
        if locked:
            for output, records in outputs_info.items():
                if output == default_output_str:
                    target = candidate_to_passed_locked_tests
                else:
                    target = candidate_to_failed_locked_tests
                for record in records:
                    target[record['candidate_id']][this_test_id] = record['call_str']

        test_info.append({
            'id': this_test_id,
            'input': test_input,
            'default_output_str': default_output_str,
            # NOTE(review): raises IndexError if no recorded datapoint yields
            # default_output_str for this input -- confirm that cannot happen.
            'default_call_str': outputs_info[default_output_str][0]['call_str'],
            'outputs': outputs,  # in rank order.
            'outputs_info': outputs_info,
            'locked': locked,
        })

    all_candidates = get_all_agents()
    all_candidate_ids = [c.id for c in all_candidates]

    candidate_to_stats_scores = calculate_candidates_stats_scores(
        all_candidate_ids)
    candidate_info = [{
        'candidate_id': candidate_id,
        'candidate': find_agent(candidate_id),
        'stats_score': stats_score,
        'tests_score': len(candidate_to_passed_locked_tests[candidate_id]),
        'passed_tests': candidate_to_passed_locked_tests[candidate_id],
        'failed_tests': candidate_to_failed_locked_tests[candidate_id],
    } for candidate_id, stats_score in candidate_to_stats_scores.items()]

    test_info = sorted(test_info, key=lambda x: x['locked'], reverse=True)
    candidate_info = sorted(candidate_info, key=lambda x: (
        x['tests_score'], x['stats_score']), reverse=True)
    return {
        'test_info': test_info,
        'candidate_info': candidate_info,
        'locked_tests': locked_tests
    }
def datapoint_to_input_output_str(datapoint):
    """Return the input and output of *datapoint* as a pair of strings.

    The input is passed through ``str()``. An output that is an ``Exception``
    instance is rendered as ``Exception(<TypeName>: <message>)``; any other
    output is passed through ``str()``.
    """
    input_text = str(datapoint.input)
    result = datapoint.output
    if not isinstance(result, Exception):
        return input_text, str(result)
    return input_text, f'Exception({type(result).__name__}: {result})'
def reduce_to_most_frequent_answer() -> Dict[Tuple, str]:
    """Map every key of ``DATA_DIST`` to its most common output, shortened.

    "Most common" is the output whose entry in ``DATA_DIST`` holds the most
    candidates. Keys with no recorded output are left out. On a tie, the
    output that comes first in iteration order wins.
    """
    winners = {}
    for input_key, candidates_by_output in DATA_DIST.items():
        if candidates_by_output:
            top_output = max(
                candidates_by_output,
                key=lambda out: len(candidates_by_output[out]))
            winners[input_key] = shorten_answer(top_output)
    return winners
@contextmanager
def data_collection_mode(test_case):
    """Run the ``with`` body in data collection mode and merge what it recorded.

    A ``DataCollector`` for *test_case* is installed as the current collector
    and yielded. Any ``Exception`` raised by the body is printed and
    suppressed. On exit the collector is merged into the global distribution
    and collection mode is left.

    Raises:
        ReenterException: if collection mode is already active. The active
            collector is left untouched in that case.
    """
    collector = DataCollector(test_case=test_case)
    # Enter before the try block. If entering is rejected, the cleanup below
    # must not run: it would reset the mode owned by the outer collector, and
    # the suppressed error would surface as "generator didn't yield".
    _enter_collection_mode(collector)
    try:
        yield collector
    except Exception as e:
        print(f' suppress exception: {e}')
        traceback.print_exc()
    finally:
        try:
            collector.merge()
        finally:
            # Always leave the mode, even if merging fails.
            _leave_collection_mode()
def query_llm(prompt: str, *, system_message: str = '', previous_history: Optional[List[ChatMessage]] = None, filename: Optional[str] = None):
    """Send *prompt* to the configured chat model and return its replies.

    The model, endpoint, key and number of completions are read from the
    ``MODEL``, ``API_BASE``, ``API_KEY`` and ``BATCH_INFERENCE_N`` environment
    variables on every call.

    Args:
        prompt: Text of the new user message.
        system_message: Optional system message placed before the history.
        previous_history: Earlier messages to send before *prompt*.
        filename: When given, the prompt and the replies are logged to
            ``<filename>.prompt.txt``.

    Returns:
        ``(replys, histories)``: the reply texts, and for each reply the full
        message list (history + prompt + that reply).
    """
    MODEL = os.getenv('MODEL', 'openai/Phi-3-mini-128k-instruct-a100')
    API_BASE = os.getenv('API_BASE', 'https://geniz.ai/v1')
    API_KEY = os.getenv('API_KEY', '')
    n = int(os.getenv('BATCH_INFERENCE_N', '1'))

    history: List[ChatMessage] = (
        [] if previous_history is None else previous_history)
    messages: List[ChatMessage] = (
        history + [ChatMessage(role='user', content=prompt)])
    if filename is not None:
        # Explicit UTF-8: prompts and replies may contain characters that
        # the platform's default encoding cannot represent.
        with open(f'{filename}.prompt.txt', 'w', encoding='utf-8') as f:
            f.write(f'{MODEL}\n\n{system_message}\n\n')
            for i, message in enumerate(messages):
                f.write(f'=== {i}: {message.role} ===\n')
                f.write(message.content)
                f.write('\n')

    final_messages = [msg.to_dict() for msg in messages]
    if system_message:
        final_messages = [
            {"role": "system", "content": system_message}] + final_messages

    # NOTE(review): a single completion is requested by omitting `n` rather
    # than sending n=1 -- presumably for backends that reject the parameter;
    # confirm.
    if n == 1:
        n = None
    response = litellm.completion(
        model=MODEL,
        api_key=API_KEY,
        api_base=API_BASE,
        messages=final_messages,
        temperature=0.9,
        num_retries=1,
        n=n,
    )
    # A completion may carry no text content; treat it as an empty reply so
    # that building ChatMessage and writing the log below cannot fail on None.
    replys = [choice.message.content or '' for choice in response.choices]
    histories = [
        messages + [ChatMessage(role='assistant', content=reply)]
        for reply in replys
    ]
    if filename is not None:
        with open(f'{filename}.prompt.txt', 'a', encoding='utf-8') as f:
            for i, reply in enumerate(replys):
                f.write(f'\n=== Reply {i} ===\n')
                f.write(reply)
    return replys, histories
**Understand and Clarify**: Make sure you understand the task. 32 | 2. **Algorithm/Method Selection**: Decide on the most efficient way. 33 | 3. **Pseudocode Creation**: Write down the steps you will follow in pseudocode. 34 | 4. **Code Generation**: Translate your pseudocode into executable Python code. -------------------------------------------------------------------------------- /src/geniz/example/geniz/prompts/programmer_prompt_simple.yaml: -------------------------------------------------------------------------------- 1 | _type: prompt 2 | input_variables: 3 | - objective 4 | template: | 5 | **Role**: You are a software programmer. 6 | 7 | **Instructions**: 8 | 1. **Understand and Clarify**: Make sure you understand the task. 9 | 2. **Algorithm/Method Selection**: Decide on the most efficient way. 10 | 3. **Pseudocode Creation**: Write down the steps you will follow in pseudocode. 11 | 4. **Code Generation**: Translate your pseudocode into executable Python code. -------------------------------------------------------------------------------- /src/geniz/example/geniz/prompts/test_designer_prompt.yaml: -------------------------------------------------------------------------------- 1 | _type: prompt 2 | input_variables: 3 | - function_name 4 | - input_code_snippet 5 | template: | 6 | **Role**: As a tester, your task is to create comprehensive test cases for the incomplete `{function_name}` function. These test cases should encompass Basic, Edge to ensure the code's robustness, reliability. 7 | 8 | **Input Code Snippet**: 9 | ```python 10 | {input_code_snippet} 11 | ``` 12 | 13 | **1. Basic Test Cases**: 14 | - **Objective**: To verify the fundamental functionality of the `{function_name}` function under normal 15 | conditions. 16 | 17 | **2. Edge Test Cases**: 18 | - **Objective**: To evaluate the function's behavior under extreme or unusual conditions. 
19 | 20 | **Instructions**: 21 | - Implement a comprehensive set of test cases following the guidelines above. 22 | - Ensure each test case is well-documented with comments explaining the scenario it covers. 23 | - Pay special attention to edge cases as they often reveal hidden bugs. 24 | - Do not use any framework or testing library like pytest, implement as plain python with name == "__main__" entrance. 25 | -------------------------------------------------------------------------------- /src/geniz/example/geniz/prompts/test_designer_prompt_simple.yaml: -------------------------------------------------------------------------------- 1 | _type: prompt 2 | input_variables: 3 | - function_name 4 | - input_code_snippet 5 | template: | 6 | **Role**: As a tester, your task is to create comprehensive test cases for the incomplete `{function_name}` function. These test cases should encompass Basic, Edge to ensure the code's robustness, reliability. 7 | 8 | **Input Code Snippet**: 9 | ```python 10 | {input_code_snippet} 11 | ``` 12 | 13 | **Instructions**: 14 | - The test case must be as short as possible. 15 | - Do not use `pytest` or `unittest` frameworks. 16 | - Implement as plain python functions. 
def find_longest_code_block(text, function_name: Optional[str] = None, require_func_def = False):
    """Return the longest parsable Python code block found in *text*.

    *text* is split on the opening fence "```python". Each piece, cut at its
    first closing fence, is parsed, passed through ``maybe_add_decorators``
    and unparsed again. Pieces that fail any of these steps are ignored.

    Args:
        text: Full reply text, possibly containing several fenced blocks.
        function_name: Forwarded to ``maybe_add_decorators``.
        require_func_def: When true, pieces for which
            ``maybe_add_decorators`` returns a falsy value are ignored.

    Returns:
        The regenerated source of the longest accepted piece.

    Raises:
        ValueError: if no piece is accepted.
    """
    code_blocks = []
    # The piece before the first fence is tried too; plain prose normally
    # fails to parse and is dropped.
    for starter in text.split('```python'):
        code_block = starter.split('```')[0]
        try:
            # Valid only if we are able to parse it.
            # We assume only functions are generated for now.
            tree = ast.parse(code_block)
            has_func_def = maybe_add_decorators(tree, function_name)
            if require_func_def and not has_func_def:
                continue
            code_block = ast.unparse(tree)
        except Exception:
            # Not valid Python (or not processable): skip this piece, but do
            # not swallow KeyboardInterrupt / SystemExit.
            continue
        code_blocks.append(code_block)
    if not code_blocks:
        raise ValueError('no parsable python code block found in text')
    # Several blocks may be generated; always choose the longest.
    return max(code_blocks, key=len)
class RoundInfo(BaseModel, PersistStateToFile):
    """State of the current round: creation time, round number and session id."""

    # NOTE(review): these defaults are ordinary expressions, so the timestamp
    # and the UUID are computed once, when this class body is executed, and
    # not per instance -- confirm this is intended before relying on
    # `created` or `session_id` of a freshly constructed RoundInfo.
    created: datetime = datetime.now()
    round: int = 1
    session_id: str = str(uuid.uuid4())

    def dead(self):
        """Return True once more than 25 minutes have passed since ``created``."""
        age = datetime.now() - self.created
        return age > timedelta(minutes=25)
PythonCode 9 | from .util import alphanumeric_uuid, load_prompt 10 | 11 | TEST_DESIGNER_PROMPT = load_prompt( 12 | os.path.join( 13 | os.path.dirname(os.path.realpath(__file__)), 14 | "prompts/test_designer_prompt_simple.yaml", 15 | ) 16 | ) 17 | 18 | 19 | def create_test_file(source_filename, function_name, source_code) -> bool: 20 | source_error_filename = f'{source_filename}.error' 21 | module_name = source_filename.removesuffix('.py') 22 | new_id = str(alphanumeric_uuid()[-5:]) 23 | output_filename = f'{module_name}_{new_id}_test.py' 24 | 25 | # Skip the submitted result. 26 | if source_filename == 'output.py': 27 | return False 28 | # Skip if original program already contains error 29 | # Maybe add another fixer here. 30 | if os.path.exists(source_error_filename): 31 | os.system(f'rm {source_error_filename}*') 32 | return False 33 | if os.path.exists(output_filename): 34 | return False 35 | 36 | logging.info( 37 | f'TestDesigner: generating {output_filename} for {function_name} in {source_filename}...') 38 | 39 | system_message = TEST_DESIGNER_PROMPT.format( 40 | function_name=function_name, input_code_snippet=source_code) 41 | message = f'''Please generate test cases for following code, and add "from {module_name} import {function_name}" in top and assume {function_name} is already defined in {module_name}. 42 | ```python 43 | {source_code} 44 | ``` 45 | 46 | Test case is like: 47 | ```python 48 | def test_{function_name}(): 49 | assert xxxx 50 | assert xxxx 51 | ... 
BASE62 = string.ascii_uppercase + string.ascii_lowercase + string.digits


def base62_encode(num) -> str:
    """Encode a non-negative integer with the ``BASE62`` alphabet.

    Digits are emitted most significant first; ``0`` encodes to ``BASE62[0]``.

    Raises:
        ValueError: if *num* is negative. Floor division of a negative number
            never reaches zero, so the conversion loop would not terminate.
    """
    if num < 0:
        raise ValueError(
            f'base62_encode() requires a non-negative integer, got {num}')
    if num == 0:
        return BASE62[0]
    digits = []
    while num:
        num, rem = divmod(num, 62)
        digits.append(BASE62[rem])
    # Remainders come out least significant first.
    return ''.join(reversed(digits))
def make_function_call_statement_str(input, output, function_name, limit=200) -> str:
    """Render one observed call as ``name(arg, ...) -> result``.

    String arguments and a string result are wrapped in double quotes. An
    ``Exception`` result is rendered as ``Exception("<message>")``. Everything
    else goes through ``str()``.

    Args:
        input: Iterable of positional arguments of the call.
        output: Value (or exception) the call produced.
        function_name: Name shown before the argument list.
        limit: When truthy, the text is shortened with ``shorten_answer``
            using this limit; when falsy, it is returned in full.

    Returns:
        The rendered text, or ``''`` when *input* has no arguments.
    """
    rendered_args = [
        f'"{arg}"' if isinstance(arg, str) else str(arg) for arg in input
    ]
    if not rendered_args:
        return ''

    if isinstance(output, str):
        rendered_output = f'"{output}"'
    elif isinstance(output, Exception):
        rendered_output = f'Exception("{str(output)}")'
    else:
        rendered_output = str(output)

    call_text = f'{function_name}({", ".join(rendered_args)}) -> {rendered_output}'
    return shorten_answer(call_text, limit=limit) if limit else call_text
config else "f-string" 148 | ), 149 | ) 150 | -------------------------------------------------------------------------------- /src/geniz/example/input.py: -------------------------------------------------------------------------------- 1 | import geniz 2 | from typing import List 3 | 4 | @geniz.CodeAgent() 5 | def findMedianSortedArrays(nums1: List[int], nums2: List[int]) -> float: 6 | """ 7 | Given two sorted arrays nums1 and nums2 of size m and n respectively, return the median of the two sorted arrays. 8 | The overall run time complexity should be O(log (m+n)). 9 | 10 | 11 | Example 1: 12 | Input: nums1 = [1,3], nums2 = [2] 13 | Output: 2.00000 14 | Explanation: merged array = [1,2,3] and median is 2. 15 | 16 | Example 2: 17 | Input: nums1 = [1,2], nums2 = [3,4] 18 | Output: 2.50000 19 | Explanation: merged array = [1,2,3,4] and median is (2 + 3) / 2 = 2.5. 20 | """ -------------------------------------------------------------------------------- /src/geniz/example/keys.json.example: -------------------------------------------------------------------------------- 1 | { 2 | "MODEL": "openai/Phi-3-mini-128k-instruct", 3 | "API_BASE": "https://shale.live/v1", 4 | "API_KEY": "trapile.ai" 5 | } -------------------------------------------------------------------------------- /src/geniz/example/webserver.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import logging 4 | import os 5 | from functools import partial 6 | 7 | import gradio as gr 8 | import ray 9 | 10 | from geniz.coder import ( 11 | generate_code, 12 | generate_test, 13 | get_genesis, 14 | get_test_and_candidate_info, 15 | save_locked_tests, 16 | ) 17 | 18 | logging.basicConfig( 19 | format='%(asctime)s %(levelname)-8s %(message)s', 20 | level=logging.INFO, 21 | datefmt='%Y-%m-%d %H:%M:%S') 22 | 23 | os.environ['RAY_IGNORE_UNHANDLED_ERRORS'] = '1' 24 | 25 | if os.path.exists('keys.json'): 26 | with open('keys.json', 'r') as f: 27 | cfg 
= json.load(f) 28 | for i in cfg: 29 | os.environ[i] = cfg[i] 30 | 31 | 32 | if gr.NO_RELOAD: 33 | ray.init(include_dashboard=False, ignore_reinit_error=True) 34 | 35 | 36 | def get_original_problem_prompt(): 37 | genesis = get_genesis() 38 | return genesis.clean_source_code 39 | 40 | 41 | initial_app_state = get_test_and_candidate_info() 42 | 43 | 44 | _CSS = ''' 45 | .test_case_container { 46 | border-width: medium; 47 | } 48 | 49 | .test_case_locked { 50 | border-color: green; 51 | } 52 | ''' 53 | 54 | with gr.Blocks(css=_CSS, title='Geniz') as demo: 55 | app_state = gr.State(initial_app_state) 56 | 57 | def click_run_all_tests(): 58 | return get_test_and_candidate_info() 59 | 60 | def click_gen_code(): 61 | generate_code() 62 | return get_test_and_candidate_info() 63 | 64 | def click_gen_test(): 65 | generate_test() 66 | return get_test_and_candidate_info() 67 | 68 | with gr.Row(): 69 | with gr.Accordion(label='LLM settings', open=True): 70 | with gr.Row(): 71 | api_base_box = gr.Textbox( 72 | os.getenv('API_BASE', 'https://geniz.ai/v1'), label='API_BASE', interactive=True) 73 | api_key_box = gr.Textbox( 74 | os.getenv('API_KEY', ''), label='API_KEY', type='password', interactive=True) 75 | model_box = gr.Textbox( 76 | os.getenv('MODEL', 'openai/Phi-3-mini-128k-instruct-a100'), label='MODEL', interactive=True) 77 | batch_inference_n = gr.Dropdown( 78 | [1, 3, 5, 10], value=1, label='Batch Inference', interactive=True) 79 | 80 | def change_api_base(input): 81 | os.environ['API_BASE'] = input 82 | api_base_box.input(change_api_base, inputs=[api_base_box]) 83 | 84 | def change_api_key(input): 85 | os.environ['API_KEY'] = input 86 | api_key_box.input(change_api_key, inputs=[api_key_box]) 87 | 88 | def change_model(input): 89 | os.environ['MODEL'] = input 90 | model_box.input(change_model, inputs=[model_box]) 91 | 92 | def change_batch_inference(input): 93 | os.environ['BATCH_INFERENCE_N'] = str(input) 94 | batch_inference_n.input( 95 | change_batch_inference, 
inputs=[batch_inference_n]) 96 | 97 | def update_problem_prompt(input): 98 | genesis = get_genesis() 99 | genesis.source_code = input 100 | 101 | with gr.Row(): 102 | with gr.Accordion(label='Problem Description', open=True): 103 | prompt_editor = gr.Code( 104 | value=get_original_problem_prompt, 105 | language='python', 106 | show_label=False, 107 | interactive=True, 108 | ) 109 | prompt_editor.change(lambda problem: update_problem_prompt( 110 | problem), inputs=[prompt_editor]) 111 | 112 | gr.Markdown("---") 113 | with gr.Row(): 114 | gen_code_button = gr.Button("Generate Code") 115 | gen_test_button = gr.Button("Generate Test") 116 | run_all_tests_button = gr.Button("Run All Tests") 117 | 118 | @gr.render(inputs=[app_state]) 119 | def render_app(this_app_state): 120 | with gr.Row(equal_height=True): 121 | with gr.Column(): 122 | all_candidate_info = this_app_state['candidate_info'] 123 | for i, candidate_info in enumerate(all_candidate_info): 124 | candidate_id = candidate_info['candidate_id'] 125 | candidate = candidate_info['candidate'] 126 | tests_score = candidate_info['tests_score'] 127 | stars = '⭐' * tests_score 128 | with gr.Accordion(label=f'{candidate_id} {stars}', 129 | open=False): 130 | with gr.Row(): 131 | code_editor = gr.Code( 132 | value=candidate.clean_source_code, 133 | language='python', 134 | interactive=True, 135 | show_label=False) 136 | for test_id, test_call_str in candidate_info['passed_tests'].items(): 137 | with gr.Row(): 138 | gr.Text(f'✅ {test_call_str}', show_label=False) 139 | for test_id, test_call_str in candidate_info['failed_tests'].items(): 140 | with gr.Row(): 141 | gr.Text(f'❌ {test_call_str}', show_label=False) 142 | with gr.Row(): 143 | delete_button = gr.Button('Delete', scale=0) 144 | 145 | def click_delete_button(this_candidate_info, this_app_state): 146 | this_candidate_id = this_candidate_info['candidate_id'] 147 | this_candidate = this_candidate_info['candidate'] 148 | this_candidate.delete() 149 | 
all_candidate_info = this_app_state['candidate_info'] 150 | this_app_state['candidate_info'] = [ 151 | c for c in all_candidate_info if c['candidate_id'] != this_candidate_id] 152 | return this_app_state 153 | delete_button.click(partial(click_delete_button, copy.copy(candidate_info)), 154 | inputs=[app_state], outputs=app_state) 155 | 156 | fix_button = gr.Button('Fix', scale=0) 157 | 158 | def click_fix_button(this_candidate_info): 159 | this_candidate = this_candidate_info['candidate'] 160 | passed_tests_strs = '\n'.join( 161 | [call_str for _, call_str in this_candidate_info['passed_tests'].items()]) 162 | failed_tests_strs = '\n'.join( 163 | [call_str for _, call_str in this_candidate_info['failed_tests'].items()]) 164 | execution_result = f'''Here are correct/wrong test cases after real execution of the program. 165 | 166 | ### Correct input output cases ### 167 | ``` 168 | {passed_tests_strs} 169 | ``` 170 | 171 | ### Wrong input output cases ### 172 | ``` 173 | {failed_tests_strs} 174 | ``` 175 | 176 | Please fix wrong cases and regenerate the program. 
177 | ''' 178 | this_candidate.generate_candidate_by_execution_result( 179 | execution_result) 180 | return get_test_and_candidate_info() 181 | fix_button.click(partial(click_fix_button, copy.copy(candidate_info)), 182 | inputs=None, outputs=app_state) 183 | 184 | with gr.Column(): 185 | for info in this_app_state['test_info']: 186 | default_output_str = info['default_output_str'] 187 | default_call_str = info['default_call_str'] 188 | locked = info['locked'] 189 | elem_id = info['id'] 190 | elem_classes = ['test_case_container'] 191 | if locked: 192 | elem_classes.append('test_case_locked') 193 | with gr.Group(elem_id=elem_id, elem_classes=elem_classes): 194 | with gr.Row(): 195 | test_box = gr.Textbox( 196 | default_call_str, show_label=False, interactive=True) 197 | with gr.Row(): 198 | output_options = info['outputs'] 199 | output_radio_group = gr.Radio( 200 | choices=output_options, 201 | value=default_output_str, 202 | container=False, 203 | interactive=True, 204 | label='Output Values') 205 | lock_checkbox = gr.Checkbox(label='Lock', value=locked) 206 | 207 | def output_radio_group_trigger(this_info, selected_output, this_app_state): 208 | locked_tests = this_app_state['locked_tests'] 209 | output_info = this_info['outputs_info'].get( 210 | selected_output, None) 211 | if output_info is None or len(output_info) == 0: 212 | return this_info['default_call_str'], this_app_state 213 | if this_info['input'] in locked_tests: 214 | if selected_output != locked_tests[this_info['input']]: 215 | locked_tests[this_info['input'] 216 | ] = selected_output 217 | save_locked_tests(locked_tests) 218 | return output_info[0]['call_str'], this_app_state 219 | 220 | output_radio_group.change( 221 | partial(output_radio_group_trigger, 222 | copy.copy(info)), 223 | inputs=[output_radio_group, app_state], 224 | outputs=[test_box, app_state]) 225 | 226 | def lock_checkbox_trigger(this_info, true_or_false, selected_output): 227 | locked_tests = this_app_state['locked_tests'] 228 | 
if true_or_false is True: 229 | locked_tests[this_info['input'] 230 | ] = selected_output 231 | save_locked_tests(locked_tests) 232 | else: 233 | locked_tests.pop(this_info['input'], None) 234 | save_locked_tests(locked_tests) 235 | 236 | lock_checkbox.change( 237 | partial(lock_checkbox_trigger, copy.copy(info)), 238 | inputs=[lock_checkbox, output_radio_group], 239 | outputs=None, 240 | js='''(x, y) => { 241 | var element = document.getElementById("''' + str(elem_id) + '''"); 242 | if (x) { 243 | element.classList.add("test_case_locked"); 244 | } else { 245 | element.classList.remove("test_case_locked"); 246 | } 247 | return [x, y]; 248 | } 249 | ''') 250 | 251 | gen_code_button.click(click_gen_code, inputs=None, outputs=app_state) 252 | gen_test_button.click(click_gen_test, inputs=None, outputs=app_state) 253 | run_all_tests_button.click( 254 | click_run_all_tests, inputs=None, outputs=app_state) 255 | # code_editor.change(code_editor_change, code_editor, None) 256 | # gr.on(triggers=None, fn=click_run_all_tests, inputs=[], every=2) 257 | # dep = demo.load(click_run_all_tests, inputs=[], outputs=[test_info_state], every=2) 258 | 259 | if __name__ == "__main__": 260 | demo.launch(debug=True) 261 | -------------------------------------------------------------------------------- /static/demo_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sudocode-ai/geniz/0a38c18cdffa0be6d079f2564dd8e8c92c5007ad/static/demo_1.gif -------------------------------------------------------------------------------- /static/geniz_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sudocode-ai/geniz/0a38c18cdffa0be6d079f2564dd8e8c92c5007ad/static/geniz_diagram.png -------------------------------------------------------------------------------- /static/screenshot_0.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sudocode-ai/geniz/0a38c18cdffa0be6d079f2564dd8e8c92c5007ad/static/screenshot_0.png --------------------------------------------------------------------------------