├── LICENSE ├── figures ├── framework.pdf └── framework.png ├── readme.md ├── requirements.txt └── src ├── benchmarks ├── compile-error-hard.csv.gz ├── compile-error.csv.gz ├── gpt35_reproduce.jsonl.gz ├── gpt35_reproduce_machine.jsonl.gz ├── humaneval-py-hardest50.jsonl ├── humaneval-py.jsonl ├── humaneval-rs-hardest50.jsonl ├── humaneval-rs-sorted.jsonl ├── humaneval-rs.jsonl ├── leetcode-hard-py.jsonl ├── mbpp-py.jsonl ├── mbpp-rs.jsonl ├── rtllm.jsonl.gz ├── simulate-hard.csv.gz ├── verilogeval-machine.jsonl.gz ├── verilogeval-manual.jsonl.gz ├── verilogeval-simulate-hard.jsonl ├── verilogeval-simulate.jsonl ├── verilogeval-simulate2.jsonl ├── verilogeval-syntax-hard.jsonl └── verilogeval-syntax.jsonl ├── executors ├── __init__.py ├── executor_types.py ├── executor_utils.py ├── factory.py ├── leet_executor.py ├── py_executor.py ├── rs_executor.py └── vg_executor.py ├── generators ├── __init__.py ├── agents │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── langchain_callback.cpython-38.pyc │ │ ├── langchain_tools.cpython-38.pyc │ │ ├── openai_function.cpython-38.pyc │ │ ├── plan_execute.cpython-38.pyc │ │ ├── react.cpython-38.pyc │ │ ├── react_json_single_input_parser.cpython-38.pyc │ │ ├── rtlfixer.cpython-38.pyc │ │ └── utils.cpython-38.pyc │ ├── langchain_callback.py │ ├── langchain_tools.py │ ├── openai_function.py │ ├── plan_execute.py │ ├── react.py │ ├── react_json_single_input_parser.py │ ├── rtlfixer.py │ └── utils.py ├── factory.py ├── generator_types.py ├── generator_utils.py ├── model.py ├── parse.py ├── py_generate.py ├── rs_generate.py └── verilog_generate.py ├── main.py ├── scripts ├── run_example.sh ├── run_oneshot_fix_compile.sh ├── run_react_fix_compile.sh └── run_react_fix_simulate.sh ├── task ├── __init__.py ├── oneshot_fix_compile.py ├── react_fix_compile.py ├── react_fix_compile_rtllm.py └── react_fix_simulate.py ├── utils.py └── visualize.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/figures/framework.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/figures/framework.pdf
--------------------------------------------------------------------------------
/figures/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/figures/framework.png
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # RTLFixer: Automatically Fixing RTL Syntax Errors with Large Language Models
2 | 
3 | This repo holds the code and benchmark datasets for [RTLFixer: Automatically Fixing RTL Syntax Errors with Large Language Models](https://arxiv.org/abs/2311.16543).
4 | 
5 | ![RTLFixer Framework](./figures/framework.png)
6 | 
7 | 
8 | 
9 | We release the VerilogEval-Syntax and VerilogEval-Simulate benchmarks [here](./src/benchmarks).
10 | 
11 | 
12 | ### Quick Start
13 | 
14 | 
15 | #### Installation
16 | 
17 | To get started:
18 | 
19 | 1. Clone this repo:
20 | ```bash
21 | $ git clone https://github.com/NVlabs/RTLFixer
22 | ```
23 | 
24 | 2. Install the module dependencies into your environment:
25 | ```bash
26 | $ pip install -r requirements.txt
27 | ```
28 | 
29 | 3. Install vcdvcd and verilog-eval:
30 | ```bash
31 | $ pip install vcdvcd
32 | $ git clone https://github.com/NVlabs/verilog-eval
33 | $ pip install -e verilog-eval
34 | 
35 | ```
36 | 
37 | 4. Set the `OPENAI_API_KEY` environment variable to your OpenAI API key:
38 | ```bash
39 | $ export OPENAI_API_KEY=
40 | ```
41 | 
42 | 
43 | #### Run
44 | Run the provided scripts:
45 | ```bash
46 | $ cd src
47 | $ ./scripts/run_oneshot_fix_compile.sh
48 | ```
49 | 
50 | Or invoke the entry point from the command line:
51 | ```bash
52 | $ python main.py \
53 |     --run_name "test_oneshot_compile" \
54 |     --root_dir "exp" \
55 |     --dataset_path ./benchmarks/verilogeval-syntax-hard.jsonl \
56 |     --task "oneshot_fix_compile" \
57 |     --agent_feedback "rag" \
58 |     --language "verilog" \
59 |     --model "gpt-3.5-turbo-16k-0613" \
60 |     --pass_at_k "1" \
61 |     --num_samples '1' \
62 |     --compiler 'quartus' \
63 |     --verbose
64 | ```
65 | 
66 | 
67 | #### Tasks
68 | 
69 | - `oneshot_fix_compile`
70 | 
71 | - `react_fix_compile`
72 | 
73 | - `react_fix_simulate`
74 | 
75 | 
76 | #### Agent Feedback
77 | 
78 | Each task lets you specify the reflexion strategy used by the agents via the `--agent_feedback` flag. The available strategies, which are defined in an `Enum`, include (a sketch follows the list):
79 | 
80 | - `nofeedback` - The agent is given only binary feedback (pass or fail).
81 | 
82 | - `feedback` - The agent is given the compiler logs as feedback.
83 | 
84 | - `rag` - The agent is given retrieved human guidance as feedback.
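A minimal sketch of how these strategy names could be modeled as an `Enum` (the class and member names here are illustrative assumptions, not the repo's actual definition):

```python
from enum import Enum


class AgentFeedback(Enum):
    """Feedback handed to the repair agent after each compile/simulate attempt."""
    NOFEEDBACK = "nofeedback"  # binary pass/fail signal only
    FEEDBACK = "feedback"      # raw compiler logs
    RAG = "rag"                # human guidance retrieved for the observed error
```

The string values match what the `--agent_feedback` flag accepts on the command line.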
85 | 
86 | 
87 | #### Compilers
88 | 
89 | - `iverilog` - https://github.com/steveicarus/iverilog
90 | 
91 | - `modelsim` - https://www.intel.com/content/www/us/en/software-kit/750368/modelsim-intel-fpgas-standard-edition-software-version-18-1.html
92 | 
93 | - `vcs` - https://www.synopsys.com/support/licensing-installation-computeplatforms/installation.html
94 | 
95 | - `quartus` - https://www.intel.com/content/www/us/en/software-kit/660907/intel-quartus-prime-lite-edition-design-software-version-20-1-1-for-windows.html
96 | 
97 | 
98 | 
99 | ### Other Notes
100 | 
101 | For any questions, contact [yundat@nvidia.com](mailto:yundat@nvidia.com).
102 | 
103 | ### Cite
104 | 
105 | ```bibtex
106 | @misc{tsai2024rtlfixer,
107 |       title={RTLFixer: Automatically Fixing RTL Syntax Errors with Large Language Models},
108 |       author={Yun-Da Tsai and Mingjie Liu and Haoxing Ren},
109 |       year={2024},
110 |       eprint={2311.16543},
111 |       archivePrefix={arXiv},
112 |       primaryClass={cs.AR}
113 | }
114 | ```
115 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | jsonlines==3.1.0
2 | openai==0.27.0
3 | datasets==2.7.0
4 | tenacity==8.1.0
5 | astunparse==1.6.3
6 | langchain==0.0.342
7 | pydantic==1.10.11
8 | langchain-experimental==0.0.27
--------------------------------------------------------------------------------
/src/benchmarks/compile-error-hard.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/benchmarks/compile-error-hard.csv.gz
--------------------------------------------------------------------------------
/src/benchmarks/compile-error.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/benchmarks/compile-error.csv.gz
--------------------------------------------------------------------------------
/src/benchmarks/gpt35_reproduce.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/benchmarks/gpt35_reproduce.jsonl.gz
--------------------------------------------------------------------------------
/src/benchmarks/gpt35_reproduce_machine.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/benchmarks/gpt35_reproduce_machine.jsonl.gz
--------------------------------------------------------------------------------
/src/benchmarks/rtllm.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/benchmarks/rtllm.jsonl.gz
--------------------------------------------------------------------------------
/src/benchmarks/simulate-hard.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/benchmarks/simulate-hard.csv.gz
--------------------------------------------------------------------------------
/src/benchmarks/verilogeval-machine.jsonl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/benchmarks/verilogeval-machine.jsonl.gz -------------------------------------------------------------------------------- /src/benchmarks/verilogeval-manual.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/benchmarks/verilogeval-manual.jsonl.gz -------------------------------------------------------------------------------- /src/executors/__init__.py: -------------------------------------------------------------------------------- 1 | from .py_executor import PyExecutor 2 | from .rs_executor import RsExecutor 3 | from .factory import executor_factory 4 | -------------------------------------------------------------------------------- /src/executors/executor_types.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, List, Tuple 2 | from abc import ABC, abstractmethod 3 | 4 | class ExecuteResult(NamedTuple): 5 | is_passing: bool 6 | feedback: str 7 | state: Tuple[bool] 8 | 9 | class Executor(ABC): 10 | @abstractmethod 11 | def execute(self, func: str, tests: List[str], timeout: int = 5) -> ExecuteResult: 12 | ... 13 | 14 | @abstractmethod 15 | def evaluate(self, name: str, func: str, test: str, timeout: int = 5) -> bool: 16 | ... 17 | 18 | # class Executor: 19 | # def execute(self, func: str, tests: List[str], timeout: int = 5) -> ExecuteResult: 20 | # raise NotImplementedError 21 | 22 | # def evaluate(self, name: str, func: str, test: str, timeout: int = 5) -> bool: 23 | # raise NotImplementedError 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /src/executors/executor_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def timeout_handler(_, __): 5 | raise TimeoutError() 6 | 7 | import os, json 8 | def to_jsonl(dict_data, file_path): 9 | with open(file_path, 'a') as file: 10 | json_line = json.dumps(dict_data) 11 | file.write(json_line + os.linesep) 12 | 13 | from threading import Thread 14 | class PropagatingThread(Thread): 15 | def run(self): 16 | self.exc = None 17 | try: 18 | if hasattr(self, '_Thread__target'): 19 | # Thread uses name mangling prior to Python 3. 
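                # (`_Thread__target` is the Python 2 name-mangled form of the
                # Thread class's private `__target` attribute; Python 3 stores the
                # same callable under the un-mangled `_target` used in the else branch.)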
20 | self.ret = self._Thread__target(*self._Thread__args, **self._Thread__kwargs) 21 | else: 22 | self.ret = self._target(*self._args, **self._kwargs) 23 | except BaseException as e: 24 | self.exc = e 25 | 26 | def join(self, timeout=None): 27 | super(PropagatingThread, self).join(timeout) 28 | if self.exc: 29 | raise self.exc 30 | return self.ret 31 | 32 | 33 | def simple_syntax_fixer(code_completion: str, problem: dict = {}): 34 | 35 | try: 36 | # parse code block 37 | code_completion = parse_markdown_code_block(code_completion) 38 | 39 | # add endmodule 40 | if code_completion and 'endmodule' not in [i for i in code_completion.strip().split('\n') if i][-1]: 41 | code_completion += "\nendmodule" 42 | # elif not code_completion.strip().endswith('endmodule'): 43 | # code_completion += "\nendmodule" 44 | 45 | # remove lines 46 | impl = [] 47 | for i in code_completion.split('\n'): 48 | i = i.strip() 49 | if not i: 50 | continue 51 | # elif i.startswith('//'): 52 | # continue 53 | elif 'timescale' in i: 54 | continue 55 | impl.append(i.strip()) 56 | code_completion = "\n".join(impl) 57 | 58 | 59 | 60 | # add module header 61 | if 'module top_module' not in code_completion: 62 | code_completion = problem['prompt'] + '\n' + code_completion 63 | 64 | # force correct module header 65 | elif re.sub('\s+', '', problem['prompt']) not in re.sub('\s+', '', code_completion) and 'top_module' in code_completion and 'full_module' not in code_completion: 66 | st = None 67 | ed = None 68 | for e, line in enumerate(code_completion.split('\n')): 69 | 70 | if 'top_module' in line: 71 | st = e 72 | elif ');' in line: 73 | ed = e+1 74 | break 75 | code_completion = "\n".join(code_completion.split('\n')[:st] + problem['prompt'].split('\n') + code_completion.split('\n')[ed:]) 76 | 77 | 78 | 79 | # if 'clk' not in problem['prompt'] and 'clk' in code_completion: 80 | # code_completion = code_completion.replace('posedge clk', '*') 81 | 82 | except Exception: 83 | import traceback 84 | print(traceback.format_exc()) 85 | 86 | return code_completion 87 | 88 | 89 | def parse_markdown_code_block(text: str, ext: str = 'verilog'): 90 | try: 91 | cleaned_output = text.strip() 92 | if f"```{ext}" in cleaned_output: 93 | _, cleaned_output = cleaned_output.split(f"```{ext}") 94 | if "```" in cleaned_output: 95 | cleaned_output, _ = cleaned_output.split("```") 96 | if cleaned_output.startswith(f"```{ext}"): 97 | cleaned_output = cleaned_output[len(f"```{ext}"):] 98 | if cleaned_output.startswith("```"): 99 | cleaned_output = cleaned_output[len("```"):] 100 | if cleaned_output.endswith("```"): 101 | cleaned_output = cleaned_output[: -len("```")] 102 | return cleaned_output.strip() 103 | except Exception: 104 | return text 105 | 106 | 107 | def function_with_timeout(func, args, timeout): 108 | result_container = [] 109 | 110 | def wrapper(): 111 | result_container.append(func(*args)) 112 | 113 | thread = PropagatingThread(target=wrapper) 114 | thread.start() 115 | thread.join(timeout) 116 | 117 | if thread.is_alive(): 118 | raise TimeoutError() 119 | else: 120 | return result_container[0] 121 | 122 | # Py tests 123 | 124 | # if __name__ == "__main__": 125 | # formatter = PySubmissionFormatter() 126 | # leetcode_1 = 'class Solution:\n def solveSudoku(self, board: List[List[str]]) -> None:\n """\n Do not return anything, modify board in-place instead.\n """\n ' 127 | # humaneval_1 = 'def solveSudoku(self, board: List[List[str]]) -> None:\n """\n Do not return anything, modify board in-place instead.\n """\n' 128 | 129 | # assert 
leetcode_1 == formatter.to_leetcode(humaneval_1) 130 | # assert humaneval_1 == formatter.to_humaneval(leetcode_1) 131 | 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /src/executors/factory.py: -------------------------------------------------------------------------------- 1 | from .py_executor import PyExecutor 2 | from .rs_executor import RsExecutor 3 | from .executor_types import Executor 4 | from .leet_executor import LeetExecutor 5 | from .vg_executor import VerilogExecutor 6 | 7 | def executor_factory(lang: str, is_leet: bool = False) -> Executor: 8 | if lang == "py" or lang == "python": 9 | if is_leet: 10 | print("Using LeetCode Python executor") 11 | from .leetcode_env.leetcode_env.leetcode_types import ProgrammingLanguage 12 | from .leetcode_env.leetcode_env.utils import PySubmissionFormatter, RsSubmissionFormatter 13 | return LeetExecutor(ProgrammingLanguage.PYTHON3, 14 | PyExecutor(), 15 | PySubmissionFormatter) 16 | else: 17 | return PyExecutor() 18 | elif lang == "rs" or lang == "rust": 19 | if is_leet: 20 | from .leetcode_env.leetcode_env.leetcode_types import ProgrammingLanguage 21 | from .leetcode_env.leetcode_env.utils import PySubmissionFormatter, RsSubmissionFormatter 22 | return LeetExecutor(ProgrammingLanguage.RUST, 23 | RsExecutor(), 24 | RsSubmissionFormatter) 25 | else: 26 | return RsExecutor() 27 | elif lang == "vg" or lang == "verilog": 28 | return VerilogExecutor() 29 | else: 30 | raise ValueError(f"Invalid language for executor: {lang}") 31 | -------------------------------------------------------------------------------- /src/executors/leet_executor.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import List 4 | 5 | from .executor_types import ExecuteResult, Executor 6 | from .executor_utils import to_jsonl 7 | from datetime import datetime 8 | 9 | class LeetExecutor(Executor): 10 | def __init__(self, lang, executor: Executor, formatter): 11 | from .leetcode_env.leetcode_env.utils import SubmissionFormatter 12 | from .leetcode_env.leetcode_env.leetcode_types import ProgrammingLanguage 13 | from .leetcode_env.leetcode_env.environment import LeetCodeEnv 14 | assert isinstance(formatter, SubmissionFormatter) 15 | assert isinstance(lang, ProgrammingLanguage) 16 | self.lang = lang 17 | self.executor = executor 18 | self.formatter = formatter 19 | self.env = LeetCodeEnv() 20 | self.name = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 21 | 22 | def execute(self, func: str, tests: List[str], timeout: int = 5) -> ExecuteResult: 23 | return self.executor.execute(func, tests, timeout) 24 | 25 | def evaluate(self, name: str, func: str, test: str, timeout: int = 5) -> bool: 26 | from .leetcode_env.leetcode_env.leetcode_types import LeetCodeSubmission 27 | from .leetcode_env.leetcode_env.utils import id_from_slug 28 | print(f'Timeout is {timeout} seconds') 29 | try: 30 | leetcode_formatted_func = self.formatter.to_leetcode(func) 31 | except Exception as e: 32 | print(f'Error formatting function to leetcode: {e}') 33 | return False 34 | print('----------------- LEETCODE SUBMISSION ------------------') 35 | print(leetcode_formatted_func) 36 | print('--------------------------------------------------------') 37 | submission = LeetCodeSubmission( 38 | code=leetcode_formatted_func, 39 | lang=self.lang, 40 | question_id=id_from_slug(name, self.env.api_instance), 41 | question_slug=name, 42 | timeout=timeout 43 | ) 44 | 45 | 
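        # env.step follows the gym convention; it presumably returns a
        # (status, reward, done, info) tuple, with `reward` serving as the
        # pass/fail signal that evaluate() ultimately returns below.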
status, reward, _, info = self.env.step(submission)
46 | 
47 |         print('------------------- LEETCODE RESULT --------------------')
48 |         print(status)
49 |         print('--------------------------------------------------------')
50 | 
51 |         to_jsonl({
52 |             'name': name,
53 |             'status': status,
54 |             'reward': reward,
55 |             'info': info
56 |         }, self.name)
57 | 
58 |         return reward
--------------------------------------------------------------------------------
/src/executors/py_executor.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import signal
3 | import astunparse
4 | 
5 | from .executor_utils import function_with_timeout
6 | 
7 | from typing import List
8 | from .executor_types import ExecuteResult, Executor
9 | 
10 | class PyExecutor(Executor):
11 |     def execute(self, func: str, tests: List[str], timeout: int = 5) -> ExecuteResult:
12 |         # Combine function code and assert statement
13 |         imports = 'from typing import *'
14 |         func_test_list = [f'{imports}\n{func}\n{test}' for test in tests]
15 | 
16 |         # Run the tests and collect the results
17 |         success_tests = []
18 |         failed_tests = []
19 |         is_passing = True
20 |         num_tests = len(func_test_list)
21 |         for i in range(num_tests):
22 |             try:
23 | 
24 |                 function_with_timeout(exec, (func_test_list[i], globals()), timeout)
25 | 
26 |                 success_tests += [tests[i]]
27 |             except Exception:
28 |                 output = get_output(func, tests[i], timeout=timeout)
29 |                 failed_tests += [f"{tests[i]} # output: {output}"]
30 |                 is_passing = False
31 | 
32 |         state = []
33 |         for test in tests:
34 |             if test in success_tests:
35 |                 state += [True]
36 |             else:
37 |                 state += [False]
38 | 
39 |         state = tuple(state)
40 | 
41 |         feedback = "Tests passed:"
42 |         for test in success_tests:
43 |             feedback += f"\n{test}"
44 |         feedback += "\n\nTests failed:"
45 |         for test in failed_tests:
46 |             feedback += f"\n{test}"
47 | 
48 |         return ExecuteResult(is_passing, feedback, state)
49 | 
50 |     def evaluate(self, name: str, func: str, test: str, timeout: int = 5) -> bool:
51 |         """
52 |         Evaluates the implementation on Human-Eval Python.
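        (Concretely: `func` is the candidate solution source, `test` is the
        HumanEval-style `check` function, and the two are executed together with
        a trailing `check(<entry point>)` call under a timeout.)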
53 | 54 | probably should be written in a dataset-agnostic way but not now 55 | """ 56 | code = f"""{func} 57 | 58 | {test} 59 | 60 | check({name}) 61 | """ 62 | try: 63 | 64 | function_with_timeout(exec, (code, globals()), timeout) 65 | 66 | return True 67 | except Exception: 68 | return False 69 | 70 | def get_call_str(assert_statement: str) -> str: 71 | ast_parsed = ast.parse(assert_statement) 72 | try: 73 | call_str = ast_parsed.body[0].test.left # type: ignore 74 | except: 75 | call_str = ast_parsed.body[0].test # type: ignore 76 | 77 | return astunparse.unparse(call_str).strip() 78 | 79 | def get_output(func: str, assert_statement: str, timeout: int = 5) -> str: 80 | try: 81 | exec(f"from typing import *\n{func}", globals()) 82 | func_call = get_call_str(assert_statement) 83 | output = function_with_timeout(eval, (func_call, globals()), timeout) 84 | return output 85 | except TimeoutError: 86 | return "TIMEOUT" 87 | except Exception as e: 88 | return str(e) 89 | 90 | if __name__ == "__main__": 91 | pass 92 | # Test the function 93 | func = "def add(a, b):\n while True:\n x = 1\n return a + b" 94 | tests = ["assert add(1, 2) == 3", "assert add(1, 2) == 4"] 95 | print(PyExecutor().execute(func, tests, timeout=1)) 96 | -------------------------------------------------------------------------------- /src/executors/rs_executor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import signal 3 | import subprocess 4 | import json 5 | 6 | from .executor_utils import timeout_handler 7 | from .executor_types import ExecuteResult, Executor 8 | 9 | from typing import List, Tuple, Optional 10 | 11 | 12 | cargo_harness_dir = os.path.join(os.path.dirname( 13 | os.path.realpath(__file__)), "cargo_harness") 14 | 15 | 16 | def create_temp_project() -> Tuple[str, str]: 17 | # get pid of the process 18 | pid = os.getpid() 19 | # get random number 20 | rand = os.urandom(8).hex() 21 | # create a temp directory 22 | temp_dir = f"/tmp/cargo_harness-{pid}-{rand}" 23 | # delete the temp directory if it exists 24 | if os.path.exists(temp_dir): 25 | os.system(f"rm -rf {temp_dir}") 26 | os.mkdir(temp_dir) 27 | # move the cargo harness into the temp directory 28 | os.system(f"cp -r {cargo_harness_dir}/* {temp_dir}") 29 | main_path = os.path.join(temp_dir, "src", "main.rs") 30 | return temp_dir, main_path 31 | 32 | 33 | def write_to_file(path: str, code: str): 34 | prelude = "fn main() {\n" 35 | postlude = "\n}" 36 | code = prelude + indent_code(code) + postlude 37 | # delete the file if it exists 38 | if os.path.exists(path): 39 | os.remove(path) 40 | # write the code to the file 41 | with open(path, "w") as f: 42 | f.write(code) 43 | 44 | 45 | def write_to_file_toplevel(path: str, code: str): 46 | # delete the file if it exists 47 | if os.path.exists(path): 48 | os.remove(path) 49 | # write the code to the file 50 | with open(path, "w") as f: 51 | f.write(code) 52 | 53 | 54 | def run_with_timeout(cmd: str, tmp_cargo_path: str, timeout: int = 5, print_debug: bool = False) -> Optional[Tuple[str, str]]: 55 | """ 56 | Runs the given command with a timeout. Produces a tuple of stdout and stderr. 57 | If the command times out, returns None. 
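    Note: the timeout is enforced via SIGALRM, so this only works in the main
    thread of the main interpreter.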
58 | """ 59 | # set up the timeout handler 60 | signal.signal(signal.SIGALRM, timeout_handler) 61 | signal.alarm(timeout) 62 | 63 | # run the command 64 | p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, 65 | stderr=subprocess.PIPE, cwd=tmp_cargo_path) 66 | try: 67 | out, err = p.communicate() 68 | # reset the timeout handler 69 | signal.alarm(0) 70 | except TimeoutError: 71 | p.kill() 72 | return None 73 | 74 | # decode the output 75 | out = out.decode("utf-8") 76 | err = err.decode("utf-8") 77 | if print_debug: 78 | print("## RUN OUTPUTS ##") 79 | print("STDOUT:") 80 | print(out) 81 | print("STDERR:") 82 | print(err, flush=True) 83 | 84 | return out, err 85 | 86 | 87 | class RsExecutor(Executor): 88 | def execute(self, func: str, tests: List[str], timeout: int = 5) -> ExecuteResult: 89 | # Combine function code and assert statement 90 | func_test_list = [f'{func}\n{test}' for test in tests] 91 | 92 | tmp_dir, temp_file = create_temp_project() 93 | 94 | # run cargo check --message-format=json 95 | write_to_file(temp_file, func) 96 | res = run_with_timeout( 97 | "cargo check --message-format=json", tmp_dir, timeout=timeout) 98 | assert res is not None, "Timeout in cargo check, wow" 99 | 100 | errs = grab_compile_errs(res[0]) # (check returns stdin) 101 | if len(errs) > 0: 102 | # cleanup the temp directory 103 | os.system(f"rm -rf {tmp_dir}") 104 | state = tuple([False] * len(tests)) 105 | 106 | err_str = "" 107 | for err in errs: 108 | err_str += f"\n{err}" 109 | 110 | return ExecuteResult(False, err_str, state) 111 | 112 | # Run the tests and collect the results 113 | tests_res: List[Tuple[bool, str]] = [] 114 | num_tests = len(func_test_list) 115 | for i in range(num_tests): 116 | """ 117 | # use some sort of timeout limit to handle infinite loops 118 | if pass, add to success tests 119 | if fail, add to failed tests with the log from the compiler 120 | """ 121 | write_to_file(temp_file, func_test_list[i]) 122 | 123 | # run cargo run 124 | res = run_with_timeout("cargo run", tmp_dir, timeout=timeout) 125 | if res is None: 126 | tests_res.append((False, "Timeout")) 127 | continue 128 | 129 | # check if we have any failed tests 130 | errs = grab_runtime_errs(res[1]) 131 | if len(errs) > 0: 132 | tests_res.append((False, str(errs[0]))) 133 | continue 134 | 135 | # if we get here, the test passed 136 | tests_res.append((True, "")) 137 | 138 | # cleanup the temp directory 139 | os.system(f"rm -rf {tmp_dir}") 140 | 141 | passed_str = "" 142 | failed_str = "" 143 | state = [] 144 | for i, (passed, output) in enumerate(tests_res): 145 | test = tests[i] 146 | if passed: 147 | passed_str += f"\n{test}" 148 | else: 149 | failed_str += f"\n{test} // output: {output}" 150 | state.append(passed) 151 | 152 | feedback = "Tested passed:" 153 | feedback += passed_str 154 | feedback += "\n\nTests failed:" 155 | feedback += failed_str 156 | 157 | is_passing = len(failed_str) == 0 158 | 159 | return ExecuteResult(is_passing, feedback, tuple(state)) 160 | 161 | def evaluate(self, name: str, func: str, test: str, timeout: int = 5) -> bool: 162 | """ 163 | Evaluates the implementation on Human-Eval Rust (MultiPL-E generated, 164 | 165 | Federico Cassano, John Gouwar, Daniel Nguyen, Sydney Nguyen, Luna Phipps-Costin, Donald Pinckney, Ming-Ho Yee, Yangtian Zi, Carolyn Jane Anderson, Molly Q Feldman, Arjun Guha, Michael Greenberg, Abhinav Jangda ). 
166 |         If you use this function please cite:
167 |         @misc{cassano2022multiple,
168 |           title={MultiPL-E: A Scalable and Extensible Approach to Benchmarking Neural Code Generation},
169 |           author={Federico Cassano and John Gouwar and Daniel Nguyen and Sydney Nguyen and Luna Phipps-Costin and Donald Pinckney and Ming-Ho Yee and Yangtian Zi and Carolyn Jane Anderson and Molly Q Feldman and Arjun Guha and Michael Greenberg and Abhinav Jangda},
170 |           year={2022},
171 |           eprint={2208.08227},
172 |           archivePrefix={arXiv},
173 |           primaryClass={cs.LG}
174 |         })
175 | 
176 |         TODO: do it actually
177 |         """
178 |         tmp_dir, tmp_path = create_temp_project()
179 |         print(f"Evaluating\n{func + test}", flush=True)
180 |         write_to_file_toplevel(tmp_path, func + test)
181 | 
182 |         res = run_with_timeout(
183 |             "cargo check --message-format=json", tmp_dir, timeout=timeout, print_debug=True)
184 |         assert res is not None, "Timeout in cargo check, wow"
185 | 
186 |         errs = grab_compile_errs(res[0])  # (cargo check emits its JSON diagnostics on stdout)
187 |         if len(errs) > 0:
188 |             # cleanup the temp directory
189 |             os.system(f"rm -rf {tmp_dir}")
190 |             print("Compile errors. Failed eval", flush=True)
191 |             return False
192 | 
193 |         # compile and run the binary
194 |         res = run_with_timeout("cargo run", tmp_dir,
195 |                                timeout=timeout, print_debug=True)
196 |         os.system(f"rm -rf {tmp_dir}")
197 | 
198 |         if res is None:
199 |             print("Timeout. Failed eval", flush=True)
200 |             return False
201 |         else:
202 |             errs = grab_runtime_errs(res[1])
203 |             if len(errs) > 0:
204 |                 print("Runtime errors. Failed eval", flush=True)
205 |                 return False
206 | 
207 |         print("Passed eval", flush=True)
208 |         return len(errs) == 0
209 | 
210 | 
211 | assert_no_panic = r"""
212 | macro_rules! assert_eq_nopanic {
213 |     ($left:expr, $right:expr) => {
214 |         std::panic::catch_unwind(|| {
215 |             assert_eq!($left, $right);
216 |         }).unwrap_or_else(|_| {});
217 |     };
218 |     () => {};
219 | }
220 | """
221 | 
222 | 
223 | def transform_asserts(code: str) -> str:
224 |     """
225 |     Transform all asserts into assert_eq_nopanic! asserts, inserting the macro
226 |     definition at the top of the code.
227 |     """
228 |     code = code.replace("assert_eq!", "assert_eq_nopanic!")
229 |     return assert_no_panic + code
230 | 
231 | 
232 | def revert_asserts(code: str) -> str:
233 |     """
234 |     Revert all assert_eq_nopanic! asserts back into assert_eq! asserts.
235 |     """
236 |     normal = code.replace("assert_eq_nopanic!", "assert_eq!")
237 |     # remove the macro definition
238 |     return normal[len(assert_no_panic):]
239 | 
240 | 
241 | def indent_code(code: str, spaces: int = 4) -> str:
242 |     """
243 |     Indent the code by the given number of spaces.
244 |     """
245 |     return "\n".join([" " * spaces + line for line in code.splitlines()])
246 | 
247 | 
248 | class CompileErr:
249 |     def __init__(self, rendered):
250 |         self.rendered = rendered
251 | 
252 |     def __str__(self):
253 |         return self.rendered
254 | 
255 |     def __repr__(self):
256 |         return "{" + str(self) + "}"
257 | 
258 | 
259 | class RuntimeErr:
260 |     def __init__(self, left, right, line, column, panic_reason):
261 |         # right and left are only used for assert_eq!
errors 262 | self.left = left 263 | self.right = right 264 | # NOTE: currently not using the below 265 | self.line = line 266 | self.column = column 267 | self.panic_reason = panic_reason 268 | 269 | def __str__(self): 270 | if self.left is not None and self.right is not None: 271 | return f"assertion failed: {self.left} == {self.right}" 272 | else: 273 | return self.panic_reason 274 | 275 | def __repr__(self): 276 | return "{" + str(self) + "}" 277 | 278 | 279 | # assumes that the input is the stdout of cargo check --message-format=json 280 | # returns a list of compile errors as CompileErr objects 281 | def grab_compile_errs(inp: str) -> List[CompileErr]: 282 | # we get a stream of json objects, so we need to parse them one by one 283 | objs = [] 284 | for line in inp.splitlines(): 285 | if line == "": 286 | continue 287 | o = json.loads(line) 288 | if o is not None and o["reason"] == "compiler-message" and \ 289 | o["message"]["level"] == "error" and \ 290 | o["message"]["spans"] != []: 291 | rendered = o["message"]["rendered"] 292 | objs.append(CompileErr(rendered)) 293 | 294 | return objs 295 | 296 | # assumes that the given input is the stderr of cargo run. 297 | # returns a list of failed assertions as RuntimeErr objects 298 | 299 | 300 | def grab_runtime_errs(inp: str) -> List[RuntimeErr]: 301 | failed_asserts = [] 302 | split = inp.splitlines() 303 | curr_left = None 304 | panic_reason = None 305 | for line in split: 306 | if "fatal runtime" in line: 307 | # we have a panic 308 | panic_idx = line.index("fatal runtime") 309 | panic_reason = line[panic_idx + len("fatal runtime") + 1:] 310 | elif "panicked at" in line: 311 | panic_idx = line.index("panicked at") 312 | # strip source line if it exists 313 | if "src/main.rs" in line: 314 | line = line[:line.index("src/main.rs")] 315 | panic_reason = line[panic_idx + len("panicked at") + 1:] 316 | elif "left:" in line: 317 | split = line.split("`") 318 | if len(split) < 2: 319 | continue 320 | curr_left = split[1] 321 | elif "right:" in line: 322 | split = line.split("`") 323 | if len(split) < 2: 324 | continue 325 | curr_right = split[1] 326 | # get the line and column number 327 | fileinto = line.split(",")[-1] 328 | line = int(fileinto.split(":")[1]) 329 | column = int(fileinto.split(":")[2]) 330 | failed_asserts.append(RuntimeErr( 331 | curr_left, curr_right, line, column, panic_reason)) 332 | curr_left = None 333 | panic_reason = None 334 | 335 | if panic_reason is not None: 336 | failed_asserts.append(RuntimeErr(None, None, None, None, panic_reason)) 337 | 338 | return failed_asserts 339 | 340 | 341 | if __name__ == "__main__": 342 | test_runtime = r""" 343 | Finished dev [unoptimized + debuginfo] target(s) in 0.00s 344 | Running `target/debug/testing` 345 | thread 'main' panicked at 'assertion failed: `(left == right)` 346 | left: `1`, 347 | right: `2`', src/main.rs:11:5 348 | note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace 349 | thread 'main' panicked at 'assertion failed: `(left == right)` 350 | left: `3`, 351 | right: `2`', src/main.rs:12:5 352 | thread 'main' panicked at 'assertion failed: `(left == right)` 353 | left: `[5, -3, -4]`, 354 | right: `[-4, -3, 5]`', src/main.rs:24:5 355 | note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace 356 | thread 'main' panicked at 'assertion failed: `(left == right)` 357 | left: `"hello"`, 358 | right: `"hola"`', src/main.rs:24:5 359 | note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace 360 | """ 361 | 362 | 
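    # The captured cargo outputs above (runtime panics) and below (compile-time
    # JSON diagnostics) let grab_runtime_errs and grab_compile_errs be
    # smoke-tested without invoking a Rust toolchain.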
# test input 363 | test_compiletime = r""" 364 | {"reason":"compiler-message","package_id":"testing 0.1.0 (path+file:///home/elleven/Downloads/testing)","manifest_path":"/home/elleven/Downloads/testing/Cargo.toml","target":{"kind":["bin"],"crate_types":["bin"],"name":"testing","src_path":"/home/elleven/Downloads/testing/src/main.rs","edition":"2021","doc":true,"doctest":false,"test":true},"message":{"rendered":"error[E0282]: type annotations needed\n --> src/main.rs:2:9\n |\n2 | let sakfsdfjfndslv;\n | ^^^^^^^^^^^^^^\n |\nhelp: consider giving `sakfsdfjfndslv` an explicit type\n |\n2 | let sakfsdfjfndslv: _;\n | +++\n\n","children":[{"children":[],"code":null,"level":"help","message":"consider giving `sakfsdfjfndslv` an explicit type","rendered":null,"spans":[{"byte_end":34,"byte_start":34,"column_end":23,"column_start":23,"expansion":null,"file_name":"src/main.rs","is_primary":true,"label":null,"line_end":2,"line_start":2,"suggested_replacement":": _","suggestion_applicability":"HasPlaceholders","text":[{"highlight_end":23,"highlight_start":23,"text":" let sakfsdfjfndslv;"}]}]}],"code":{"code":"E0282","explanation":"The compiler could not infer a type and asked for a type annotation.\n\nErroneous code example:\n\n```compile_fail,E0282\nlet x = \"hello\".chars().rev().collect();\n```\n\nThis error indicates that type inference did not result in one unique possible\ntype, and extra information is required. In most cases this can be provided\nby adding a type annotation. Sometimes you need to specify a generic type\nparameter manually.\n\nA common example is the `collect` method on `Iterator`. It has a generic type\nparameter with a `FromIterator` bound, which for a `char` iterator is\nimplemented by `Vec` and `String` among others. Consider the following snippet\nthat reverses the characters of a string:\n\nIn the first code example, the compiler cannot infer what the type of `x` should\nbe: `Vec` and `String` are both suitable candidates. To specify which type\nto use, you can use a type annotation on `x`:\n\n```\nlet x: Vec = \"hello\".chars().rev().collect();\n```\n\nIt is not necessary to annotate the full type. Once the ambiguity is resolved,\nthe compiler can infer the rest:\n\n```\nlet x: Vec<_> = \"hello\".chars().rev().collect();\n```\n\nAnother way to provide the compiler with enough information, is to specify the\ngeneric type parameter:\n\n```\nlet x = \"hello\".chars().rev().collect::>();\n```\n\nAgain, you need not specify the full type if the compiler can infer it:\n\n```\nlet x = \"hello\".chars().rev().collect::>();\n```\n\nApart from a method or function with a generic type parameter, this error can\noccur when a type parameter of a struct or trait cannot be inferred. In that\ncase it is not always possible to use a type annotation, because all candidates\nhave the same return type. For instance:\n\n```compile_fail,E0282\nstruct Foo {\n num: T,\n}\n\nimpl Foo {\n fn bar() -> i32 {\n 0\n }\n\n fn baz() {\n let number = Foo::bar();\n }\n}\n```\n\nThis will fail because the compiler does not know which instance of `Foo` to\ncall `bar` on. 
Change `Foo::bar()` to `Foo::::bar()` to resolve the error.\n"},"level":"error","message":"type annotations needed","spans":[{"byte_end":34,"byte_start":20,"column_end":23,"column_start":9,"expansion":null,"file_name":"src/main.rs","is_primary":true,"label":null,"line_end":2,"line_start":2,"suggested_replacement":null,"suggestion_applicability":null,"text":[{"highlight_end":23,"highlight_start":9,"text":" let sakfsdfjfndslv;"}]}]}} 365 | {"reason":"compiler-message","package_id":"testing 0.1.0 (path+file:///home/elleven/Downloads/testing)","manifest_path":"/home/elleven/Downloads/testing/Cargo.toml","target":{"kind":["bin"],"crate_types":["bin"],"name":"testing","src_path":"/home/elleven/Downloads/testing/src/main.rs","edition":"2021","doc":true,"doctest":false,"test":true},"message":{"rendered":"error: aborting due to previous error\n\n","children":[],"code":null,"level":"error","message":"aborting due to previous error","spans":[]}} 366 | {"reason":"compiler-message","package_id":"testing 0.1.0 (path+file:///home/elleven/Downloads/testing)","manifest_path":"/home/elleven/Downloads/testing/Cargo.toml","target":{"kind":["bin"],"crate_types":["bin"],"name":"testing","src_path":"/home/elleven/Downloads/testing/src/main.rs","edition":"2021","doc":true,"doctest":false,"test":true},"message":{"rendered":"For more information about this error, try `rustc --explain E0282`.\n","children":[],"code":null,"level":"failure-note","message":"For more information about this error, try `rustc --explain E0282`.","spans":[]}} 367 | {"reason":"build-finished","success":false} 368 | """ 369 | 370 | assert(len(grab_compile_errs(test_compiletime)) == 1) 371 | print(grab_runtime_errs(test_runtime)) 372 | assert(len(grab_runtime_errs(test_runtime)) == 4) 373 | -------------------------------------------------------------------------------- /src/executors/vg_executor.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import signal 3 | import astunparse 4 | from typing import List 5 | from collections import defaultdict 6 | # from verilog_eval.execution import check_correctness 7 | 8 | from .executor_utils import function_with_timeout, simple_syntax_fixer 9 | from .executor_types import ExecuteResult, Executor 10 | 11 | 12 | 13 | class VerilogExecutor(Executor): 14 | 15 | def __init__(self): 16 | self.failed_reasons = defaultdict(list) 17 | 18 | def execute(self, func: str, tests: List[str], timeout: int = 5) -> ExecuteResult: 19 | 20 | # Run the tests and collect the results 21 | success_tests = [] 22 | failed_tests = [] 23 | is_passing = True 24 | num_tests = len(tests) 25 | for i in range(num_tests): 26 | try: 27 | success_tests += [tests[i]] 28 | except Exception: 29 | output = get_output(func, tests[i], timeout=timeout) 30 | failed_tests += [f"{tests[i]} # output: {output}"] 31 | is_passing = False 32 | 33 | state = [] 34 | for test in tests: 35 | if test in success_tests: 36 | state += [True] 37 | else: 38 | state += [False] 39 | 40 | state = tuple(state) 41 | 42 | feedback = "Tested passed:" 43 | for test in success_tests: 44 | feedback += f"\n{test}" 45 | feedback += "\n\nTests failed:" 46 | for test in failed_tests: 47 | feedback += f"\n{test}" 48 | 49 | return ExecuteResult(is_passing, feedback, state) 50 | 51 | def evaluate(self, problem: dict, completion: str, test: str, timeout: int = 5, compile_only: bool = False) -> bool: 52 | """ 53 | Evaluates the implementation on VerilogEval. 
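        (The completion is first normalized by simple_syntax_fixer, then handed
        to check_correctness, which compiles it and, unless compile_only is set,
        also simulates it against the problem's testbench.)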
54 | probably should be written in a dataset-agnostic way but not now
55 |         """
56 |         completion = simple_syntax_fixer(completion, problem)
57 |         result = check_correctness(problem, completion, timeout, compile_only=compile_only)
58 |         return result
59 | 
60 | 
61 | if __name__ == "__main__":
62 |     pass
63 |     # Test the function
64 |     func = ""
65 |     tests = []
66 |     print(VerilogExecutor().execute(func, tests, timeout=10))
67 | 
68 | 
69 | 
70 | 
71 | from typing import Optional, Callable, Dict, Union
72 | import ast
73 | import contextlib
74 | import faulthandler
75 | import io
76 | import os
77 | import multiprocessing
78 | import platform
79 | import signal
80 | import tempfile
81 | 
82 | import subprocess
83 | import re
84 | from threading import Timer
85 | from verilog_eval.execution import time_limit, swallow_io, create_tempdir, TimeoutException, WriteOnlyStringIO, reliability_guard
86 | 
87 | 
88 | def check_correctness(problem: Dict, completion: str, timeout: float,
89 |                       completion_id: Optional[int] = None, unit_test_length: Optional[int] = None, compile_only: Union[bool, str] = False) -> Dict:
90 |     """
91 |     Evaluates the functional correctness of a completion by running the test
92 |     suite provided in the problem. `compile_only` may be False (compile and simulate), True, or a compiler name ('quartus', 'vcs', 'modelsim').
93 |     :param completion_id: an optional completion ID so we can match
94 |     the results later even if execution finishes asynchronously.
95 |     """
96 | 
97 |     def unsafe_execute():
98 | 
99 |         with create_tempdir():
100 | 
101 |             # These system calls are needed when cleaning up tempdir.
102 |             import os
103 |             import shutil
104 |             rmtree = shutil.rmtree
105 |             rmdir = os.rmdir
106 |             chdir = os.chdir
107 | 
108 |             reliability_guard()
109 | 
110 | 
111 |             result['passed'] = True
112 |             result['compiler_log'] = ""
113 |             result['test_output'] = ""
114 |             result['verilog_test'] = problem["test"] + "\n" + completion
115 |             result['completion'] = completion
116 |             if "wave.vcd" in result:
117 |                 result.pop("wave.vcd")
118 | 
119 | 
120 |             compile_func = iverilog_compile
121 |             if compile_only == "quartus":
122 |                 compile_func = quartus_compile
123 |             elif compile_only == "vcs":
124 |                 compile_func = vcs_compile
125 |             elif compile_only == "modelsim":
126 |                 compile_func = modelsim_compile
127 | 
128 |             out = compile_func(completion, problem['task_id'])
129 |             result['compiler_log'] = out
130 | 
131 |             if not verilog_compile_is_correct(out):
132 |                 result['passed'] = False
133 | 
134 |             result['compile_only'] = compile_only
135 |             if not compile_only and result['passed']:
136 | 
137 |                 iverilog_compile(completion, problem['task_id'], problem['test'])
138 | 
139 |                 # simulate
140 |                 out, err = execute("vvp -n test.vvp", 20)
141 |                 result['test_output'] = f"{out}\n{err}"
142 |                 match = re.search(r'Mismatches: ([0-9]*) in ([0-9]*) samples', out)
143 |                 if match:
144 |                     cor, tot = [int(i) for i in match.groups()]
145 |                     if cor != 0:
146 |                         result['passed'] = False
147 |                 else:
148 |                     result['passed'] = False
149 | 
150 |             if os.path.exists('wave.vcd'):
151 |                 result["wave.vcd"] = open("wave.vcd", "r").read()
152 | 
153 |             # Needed for cleaning up.
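            # reliability_guard() (from verilog_eval) presumably stubs out
            # destructive filesystem calls such as shutil.rmtree while the
            # untrusted completion runs; restoring the saved originals below lets
            # the create_tempdir() context manager delete its directory again.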
154 |             shutil.rmtree = rmtree
155 |             os.rmdir = rmdir
156 |             os.chdir = chdir
157 | 
158 |     manager = multiprocessing.Manager()
159 |     result = manager.dict()
160 | 
161 |     p = multiprocessing.Process(target=unsafe_execute)
162 |     p.start()
163 |     p.join(timeout=timeout + 1)
164 |     if p.is_alive():
165 |         p.kill()
166 | 
167 |     if not result:
168 |         result = {
169 |             'passed': False,
170 |             'compiler_log': "timed out",
171 |             'test_output': "timed out",
172 |         }
173 | 
174 |     return dict(
175 |         task_id=problem["task_id"],
176 |         passed=result["passed"],
177 |         feedback=dict(result),
178 |         completion_id=completion_id,
179 |     )
180 | 
181 | 
182 | def iverilog_compile(verilog_test: str, task_id: str, test: str = ""):
183 | 
184 |     extra_cmd = ""
185 |     if test:
186 |         verilog_test = f"{test}\n{verilog_test}"
187 |         extra_cmd = "-s tb"
188 | 
189 |     with open(f"{task_id}.sv", 'w') as f:
190 |         f.write(verilog_test)
191 |     out, err = execute(
192 |         f"iverilog -Wall -Winfloop -Wno-timescale -g2012 {extra_cmd} -o test.vvp {task_id}.sv",
193 |         10
194 |     )
195 |     return err
196 | 
197 | 
198 | def execute(cmd: str, timeout: int):
199 |     try:
200 |         with swallow_io():
201 |             with time_limit(timeout):
202 |                 p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
203 |                 timer = Timer(timeout, p.kill)
204 |                 try:
205 |                     timer.start()
206 |                     out, err = p.communicate()
207 |                 finally:
208 |                     timer.cancel()
209 | 
210 |         out, err = out.decode("utf-8"), err.decode("utf-8")
211 |     except TimeoutException:
212 |         out = err = "timed out"
213 |     except BaseException as e:
214 |         out = err = f"failed: {e}"
215 |     finally:
216 |         return out, err
217 | 
218 | 
219 | def verilog_compile_is_correct(log: str):
220 |     log = log.lower()
221 | 
222 |     # include warning
223 |     # if 'warning' in log:
224 |     #     return False
225 | 
226 |     if 'error' in log or 'give up' in log:
227 |         return False
228 | 
229 |     return True
230 | 
231 | 
232 | def quartus_compile(verilog_test: str, task_id: str):
233 |     with open(f"{task_id}.sv", 'w') as f:
234 |         f.write(verilog_test)
235 | 
236 |     # grep '^Error|^Warning'
237 |     with open('top_module.qsf', 'w') as f:
238 |         f.write(f"set_global_assignment -name SYSTEMVERILOG_FILE {task_id}.sv")
239 |     out, err = execute(
240 |         f"/root/verilog-LLM/Quartus/y/quartus/bin/quartus_map --effort=fast --parallel=4 --read_settings_files=on --write_settings_files=off top_module -c top_module | grep '^Error'",
241 |         30
242 |     )
243 | 
244 |     tmp = []
245 |     for i in out.strip().split('\n'):
246 |         if 'Error (' in i:
247 |             trunc = i.find('Check for and fix')
248 |             if trunc > 0:
249 |                 i = i[:trunc]
250 |             tmp.append(i)
251 |     out = '\n'.join(tmp)
252 | 
253 |     return out
254 | 
255 | 
256 | def vcs_compile(verilog_test: str, task_id: str):
257 |     with open(f"{task_id}.sv", 'w') as f:
258 |         f.write(verilog_test)
259 |     out, err = execute(
260 |         f"vcs -j8 +v2k -sverilog -q -full64 {task_id}.sv",
261 |         20
262 |     )
263 |     return out
264 | 
265 | 
266 | def modelsim_compile(verilog_test: str, task_id: str):
267 |     with open(f"{task_id}.sv", 'w') as f:
268 |         f.write(verilog_test)
269 |     out, err = execute(
270 |         f"/root/intelFPGA/16.1/modelsim_ase/linuxaloem/vlog -sv -quiet {task_id}.sv",
271 |         10
272 |     )
273 |     print(out)
274 |     return out
--------------------------------------------------------------------------------
/src/generators/__init__.py:
--------------------------------------------------------------------------------
1 | from .py_generate import PyGenerator
2 | from .rs_generate import RsGenerator
3 | from .factory import generator_factory, model_factory,
agent_factory 4 | from .model import ModelBase, GPT4, GPT35 5 | from .agents import ReAct 6 | -------------------------------------------------------------------------------- /src/generators/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .react import ReAct 2 | from .plan_execute import PlanAndExecute 3 | from .openai_function import OpenAIFunc 4 | from .rtlfixer import RTLFixer 5 | 6 | 7 | __all__ = [ 8 | "ReAct", 9 | "OpenAIFunc", 10 | "PlanAndExecute", 11 | "RTLFixer" 12 | ] 13 | -------------------------------------------------------------------------------- /src/generators/agents/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/generators/agents/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/generators/agents/__pycache__/langchain_callback.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/generators/agents/__pycache__/langchain_callback.cpython-38.pyc -------------------------------------------------------------------------------- /src/generators/agents/__pycache__/langchain_tools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/generators/agents/__pycache__/langchain_tools.cpython-38.pyc -------------------------------------------------------------------------------- /src/generators/agents/__pycache__/openai_function.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/generators/agents/__pycache__/openai_function.cpython-38.pyc -------------------------------------------------------------------------------- /src/generators/agents/__pycache__/plan_execute.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/generators/agents/__pycache__/plan_execute.cpython-38.pyc -------------------------------------------------------------------------------- /src/generators/agents/__pycache__/react.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/generators/agents/__pycache__/react.cpython-38.pyc -------------------------------------------------------------------------------- /src/generators/agents/__pycache__/react_json_single_input_parser.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/generators/agents/__pycache__/react_json_single_input_parser.cpython-38.pyc -------------------------------------------------------------------------------- /src/generators/agents/__pycache__/rtlfixer.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/generators/agents/__pycache__/rtlfixer.cpython-38.pyc -------------------------------------------------------------------------------- /src/generators/agents/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RTLFixer/24ceebd9176d59bf302e935a800fcb46a10c634c/src/generators/agents/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /src/generators/agents/langchain_callback.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Union, Type, Any 2 | from langchain.schema import ( 3 | AgentAction, 4 | AgentFinish, 5 | BaseMessage, 6 | LLMResult 7 | ) 8 | from langchain.callbacks import StdOutCallbackHandler 9 | 10 | 11 | class MyCallbackHandler(StdOutCallbackHandler): 12 | 13 | def __init__(self, agent, max_iter: int = 20): 14 | super().__init__() 15 | self.agent = agent 16 | self.max_iter = max_iter 17 | self.history = [] 18 | 19 | def log_history(function): 20 | def wrap_function(self, *args, **kwargs): 21 | log = (function.__name__, args, kwargs) 22 | 23 | # print("================================================================") 24 | # print(log) 25 | # print("================================================================") 26 | # import pdb 27 | # pdb.set_trace() 28 | 29 | self.history.append(log) 30 | return function(self, *args, **kwargs) 31 | return wrap_function 32 | 33 | @log_history 34 | def on_llm_start( 35 | self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any 36 | ) -> Any: 37 | """Run when LLM starts running.""" 38 | 39 | @log_history 40 | def on_chat_model_start( 41 | self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], **kwargs: Any 42 | ) -> Any: 43 | """Run when Chat Model starts running.""" 44 | 45 | @log_history 46 | def on_llm_new_token(self, token: str, **kwargs: Any) -> Any: 47 | """Run on new LLM token. 
Only available when streaming is enabled.""" 48 | 49 | @log_history 50 | def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any: 51 | """Run when LLM ends running.""" 52 | 53 | @log_history 54 | def on_llm_error( 55 | self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any 56 | ) -> Any: 57 | """Run when LLM errors.""" 58 | 59 | @log_history 60 | def on_chain_start( 61 | self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any 62 | ) -> Any: 63 | """Run when chain starts running.""" 64 | self.history = [] 65 | 66 | @log_history 67 | def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any: 68 | """Run when chain ends running.""" 69 | 70 | def on_chain_error( 71 | self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any 72 | ) -> Any: 73 | """Run when chain errors.""" 74 | 75 | @log_history 76 | def on_tool_start( 77 | self, serialized: Dict[str, Any], input_str: str, **kwargs: Any 78 | ) -> Any: 79 | """Run when tool starts running.""" 80 | 81 | @log_history 82 | def on_tool_end(self, output: str, **kwargs: Any) -> Any: 83 | """Run when tool ends running.""" 84 | 85 | @log_history 86 | def on_tool_error( 87 | self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any 88 | ) -> Any: 89 | """Run when tool errors.""" 90 | 91 | @log_history 92 | def on_text(self, text: str, **kwargs: Any) -> Any: 93 | """Run on arbitrary text.""" 94 | 95 | @log_history 96 | def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: 97 | """Run on agent action.""" 98 | 99 | @log_history 100 | def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: 101 | """Run on agent end.""" 102 | -------------------------------------------------------------------------------- /src/generators/agents/openai_function.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | from typing import Dict, List, Union, Type, Any 4 | import langchain 5 | from copy import deepcopy 6 | from pydantic import BaseModel 7 | 8 | # from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser, ReActJsonSingleInputOutputParser 9 | from langchain.chat_models import ChatOpenAI, AzureChatOpenAI 10 | from langchain.agents import tool, load_tools, OpenAIFunctionsAgent, AgentExecutor, OpenAIMultiFunctionsAgent, initialize_agent, AgentType, AgentOutputParser 11 | from langchain.prompts import MessagesPlaceholder 12 | from langchain.memory import ConversationBufferMemory 13 | from langchain.callbacks import HumanApprovalCallbackHandler 14 | from langchain.schema import ( 15 | AIMessage, 16 | HumanMessage, 17 | SystemMessage, 18 | ) 19 | from uuid import uuid4 20 | 21 | 22 | from .langchain_callback import MyCallbackHandler 23 | from .langchain_tools import VerilogToolkit 24 | from .react_json_single_input_parser import RobustReActJsonSingleInputOutputParser 25 | from ..model import GPTChat, Message 26 | langchain.debug = True 27 | 28 | 29 | def parse_markdown_code_block(text: str, ext: str = 'verilog'): 30 | try: 31 | cleaned_output = text.strip() 32 | if f"```{ext}" in cleaned_output: 33 | _, cleaned_output = cleaned_output.split(f"```{ext}") 34 | if "```" in cleaned_output: 35 | cleaned_output, _ = cleaned_output.split("```") 36 | if cleaned_output.startswith(f"```{ext}"): 37 | cleaned_output = cleaned_output[len(f"```{ext}"):] 38 | if cleaned_output.startswith("```"): 39 | cleaned_output = cleaned_output[len("```"):] 40 | if cleaned_output.endswith("```"): 41 | cleaned_output = cleaned_output[: 
-len("```")]
        return cleaned_output.strip()
    except Exception:
        return


import os
# Credentials for the optional Google search tool (GOOGLE_CSE_ID and
# GOOGLE_API_KEY) must be supplied through the environment; API keys should
# never be hardcoded in source.
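# Example (sketch) of the extraction performed by parse_markdown_code_block:
# >>> parse_markdown_code_block("```verilog\nmodule t;\nendmodule\n```")
# 'module t;\nendmodule'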
class OpenAIFunc(GPTChat):

    def __init__(
        self,
        model_name: str,
        exe,
        max_iters: int,
        tools: list = [],
        system_prompt: str = None,
        temperature: float = 0.2,
        memory_key: str = None
    ):
        super().__init__(model_name)
        self.toolkit = VerilogToolkit(exe)
        if 'azure' in model_name:
            import openai
            # Azure endpoint, key, and API version come from the environment
            # (OPENAI_API_BASE, OPENAI_API_KEY, OPENAI_API_VERSION).
            openai.api_type = "azure"
            openai.api_base = os.environ["OPENAI_API_BASE"]
            openai.api_version = os.environ.get("OPENAI_API_VERSION", "2023-07-01-preview")
            openai.api_key = os.environ["OPENAI_API_KEY"]
            os.environ['OPENAI_API_TYPE'] = "azure"

            self.llm = AzureChatOpenAI(
                deployment_name=os.environ.get("AZURE_DEPLOYMENT_NAME", "Morris-16k-for-sum"),
                model_name="gpt-35-turbo-16k",
                temperature=temperature,
                max_tokens=2048,
                top_p=1.0
            )
        else:
            self.llm = ChatOpenAI(model=model_name, temperature=temperature, max_tokens=2048, top_p=1.0)

        self.tools = tools if tools else self.toolkit.tools
        self.agent_executor = initialize_agent(
            self.tools,
            self.llm,
            memory=ConversationBufferMemory(memory_key=memory_key, return_messages=True, input_key='input', output_key='output') if memory_key else None,
            agent=AgentType.OPENAI_FUNCTIONS,
            verbose=True,
            callback_manager=None,
            max_iterations=max_iters,
            handle_parsing_errors=self._handle_error,
            return_intermediate_steps=True,
            early_stopping_method="generate",
            agent_kwargs={
                'system_message': SystemMessage(content=system_prompt) if system_prompt else None,
                'extra_prompt_messages': [MessagesPlaceholder(variable_name=memory_key)] if memory_key else None,
            },
        )
        self.uuid = uuid4().hex
        self.callbacks = [
            # HumanApprovalCallbackHandler(),
            MyCallbackHandler(self.agent_executor)  # must be the last callback
        ]
        self.history = []
        self.agent_history = []
        self.error_logs = []

    def _handle_error(self, error: str) -> str:
        # Record the parsing failure, then tell the agent how to recover.
        try:
            with open(f'error_log/{self.uuid}.parse_error_log', 'a') as f:
                print(error, file=f)
            self.error_logs.append(error)
        except Exception:
            pass
        return "Function arguments is not in valid json format. I should fix it and try again."

    def postprocess(self, text: str):
        tmp = parse_markdown_code_block(text, 'verilog')
        if tmp:
            return tmp
        return text

    def serialize(self, history):
        try:
            from langchain.load.serializable import Serializable
            from pydantic import BaseModel
            from uuid import UUID

            def todict(obj, classkey=None):
                if isinstance(obj, dict):
                    return {k: todict(v, classkey) for k, v in obj.items()}
                elif isinstance(obj, Serializable):
                    return obj.to_json()
                elif isinstance(obj, BaseModel):
                    return obj.json()
                elif isinstance(obj, UUID):
                    return obj.hex
                elif hasattr(obj, "_ast"):
                    return todict(obj._ast())
                elif hasattr(obj, "__iter__") and not isinstance(obj, str):
                    return [todict(v, classkey) for v in obj]
                elif hasattr(obj, "__dict__"):
                    data = dict([(key, todict(value, classkey))
                                 for key, value in obj.__dict__.items()
                                 if not callable(value) and not key.startswith('_')])
                    if classkey is not None and hasattr(obj, "__class__"):
                        data[classkey] = obj.__class__.__name__
                    return data
                else:
                    return obj
            return todict(history)
        except Exception:
            import traceback
            print(traceback.format_exc())
            import pickle
            import base64
            return base64.b64encode(pickle.dumps(history)).decode()

    def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
        completion = {}
        try:
            messages = self.adapt(messages)
            completion = self.agent_executor(messages, callbacks=self.callbacks)
        except Exception as e:
            err_log = str(e)
            if 'maximum context length' in str(e):
                pass
            elif 'Could not parse' in str(e):
                pass
            else:
                import traceback
                err_log = traceback.format_exc()
                with open(f'error_log/{self.uuid}.chat_error_log', 'a') as f:
                    print(err_log, file=f)
            self.callbacks[-1].history.append(
                ('Exception', err_log)
            )
            self.error_logs.append(err_log)
        finally:
            self.intermediate_steps = completion.get('intermediate_steps', [])
            self.agent_history.append(self.intermediate_steps)
            completion = self.postprocess(completion.get('output', ""))
        return completion

    def adapt(self, messages: List[Message]):
        output = []
        for i in messages:
            if i.role == "user":
                output.append(HumanMessage(content=i.content))
            elif i.role == "assistant":
                output.append(AIMessage(content=i.content))
            elif i.role == "system":
                output.append(SystemMessage(content=i.content))
        return output
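# Usage sketch: `exe` is assumed to be a Verilog executor from src/executors
# and OPENAI_API_KEY must be set; generate_chat drives the function-calling
# agent and returns the post-processed Verilog completion.
#   agent = OpenAIFunc("gpt-3.5-turbo-16k-0613", exe=exe, max_iters=10)
#   code = agent.generate_chat([Message(role="user", content=problem_prompt)])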
if __name__ == "__main__":
    # Minimal smoke test; `exe=None` is a placeholder for a Verilog executor
    # and will not support actual compilation.
    agent = OpenAIFunc("gpt-3.5-turbo-16k-0613", exe=None, max_iters=5)
    messages = [
        HumanMessage(content="I love programming."),
        AIMessage(content="I love programming too."),
        HumanMessage(content="What is the weather in LA and SF?"),
    ]
    print(agent.agent_executor.run(messages))
--------------------------------------------------------------------------------
/src/generators/agents/plan_execute.py:
--------------------------------------------------------------------------------
import re
import json
from typing import Dict, List, Union, Type, Any
import langchain
from copy import deepcopy
from pydantic import BaseModel

from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.agents import tool, load_tools, OpenAIFunctionsAgent, AgentExecutor, OpenAIMultiFunctionsAgent, initialize_agent, AgentType, AgentOutputParser
from langchain.prompts import MessagesPlaceholder
from langchain.memory import ConversationBufferMemory
from langchain.callbacks import HumanApprovalCallbackHandler
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage,
)
from uuid import uuid4
# Alias the library class so the local PlanAndExecute wrapper below does not
# shadow it.
from langchain_experimental.plan_and_execute import PlanAndExecute as LCPlanAndExecute, load_agent_executor, load_chat_planner


from .langchain_callback import MyCallbackHandler
from .langchain_tools import VerilogToolkit
from .react_json_single_input_parser import RobustReActJsonSingleInputOutputParser
from ..model import GPTChat, Message
from .utils import parse_markdown_code_block
langchain.debug = True


import os
# Credentials for the optional Google search tool (GOOGLE_CSE_ID and
# GOOGLE_API_KEY) must be supplied through the environment.
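# Plan-and-execute split (sketch of the langchain_experimental API assumed
# here): the planner LLM drafts a step list and the executor agent runs each
# step with the Verilog toolkit.
#   planner = load_chat_planner(llm)
#   executor = load_agent_executor(llm, toolkit.tools, verbose=True)
#   agent = LCPlanAndExecute(planner=planner, executor=executor)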
class PlanAndExecute(GPTChat):

    def __init__(
        self,
        model_name: str,
        exe,
        max_iters: int,
        toolset: list = [],
        system_prompt: str = None,
        temperature: float = 0.2,
        memory_key: str = None
    ):
        super().__init__(model_name)
        if 'azure' in model_name:
            import openai
            # Azure endpoint, key, and API version come from the environment.
            openai.api_type = "azure"
            openai.api_base = os.environ["OPENAI_API_BASE"]
            openai.api_version = os.environ.get("OPENAI_API_VERSION", "2023-07-01-preview")
            openai.api_key = os.environ["OPENAI_API_KEY"]
            os.environ['OPENAI_API_TYPE'] = "azure"

            self.llm = AzureChatOpenAI(
                deployment_name=os.environ.get("AZURE_DEPLOYMENT_NAME", "Morris-16k-for-sum"),
                model_name="gpt-35-turbo-16k",
                temperature=temperature,
                max_tokens=2048,
                top_p=1.0
            )
        else:
            self.llm = ChatOpenAI(model=model_name, temperature=temperature, max_tokens=2048, top_p=1.0)

        self.toolkit = VerilogToolkit(exe, toolset=toolset)

        model = self.llm
        planner = load_chat_planner(model)
        executor = load_agent_executor(model, self.toolkit.tools, verbose=True)
        # Use the aliased library class; a bare `PlanAndExecute` here would
        # resolve to this wrapper class and recurse.
        self.agent_executor = LCPlanAndExecute(planner=planner, executor=executor, verbose=True)

        self.uuid = uuid4().hex
        self.callbacks = [
            # HumanApprovalCallbackHandler(),
            MyCallbackHandler(self.agent_executor)  # must be the last callback
        ]

    def _handle_error(self, error: str) -> str:
        # Record the parsing failure, then tell the agent how to recover.
        try:
            with open(f'error_log/{self.uuid}.parse_error_log', 'a') as f:
                print(error, file=f)
        except Exception:
            pass
        return "Function arguments is not in valid json format. I should fix it and try again."

    def postprocess(self, text: str):
        tmp = parse_markdown_code_block(text, 'verilog')
        if tmp:
            return tmp
        return text

    def serialize(self, history):
        try:
            from langchain.load.serializable import Serializable
            from pydantic import BaseModel
            from uuid import UUID

            def todict(obj, classkey=None):
                if isinstance(obj, dict):
                    return {k: todict(v, classkey) for k, v in obj.items()}
                elif isinstance(obj, Serializable):
                    return obj.to_json()
                elif isinstance(obj, BaseModel):
                    return obj.json()
                elif isinstance(obj, UUID):
                    return obj.hex
                elif hasattr(obj, "_ast"):
                    return todict(obj._ast())
                elif hasattr(obj, "__iter__") and not isinstance(obj, str):
                    return [todict(v, classkey) for v in obj]
                elif hasattr(obj, "__dict__"):
                    data = dict([(key, todict(value, classkey))
                                 for key, value in obj.__dict__.items()
                                 if not callable(value) and not key.startswith('_')])
                    if classkey is not None and hasattr(obj, "__class__"):
                        data[classkey] = obj.__class__.__name__
                    return data
                else:
                    return obj
            return todict(history)
        except Exception:
            import traceback
            print(traceback.format_exc())
            import pickle
            import base64
            return base64.b64encode(pickle.dumps(history)).decode()

    def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
        completion = {}
        try:
            messages = self.adapt(messages)
            completion = self.agent_executor(messages, callbacks=self.callbacks)
        except Exception as e:
            err_log = str(e)
            if 'maximum context length' in str(e):
                pass
            elif 'Could not parse' in str(e):
                pass
            else:
                import traceback
                err_log = traceback.format_exc()
                with open(f'error_log/{self.uuid}.chat_error_log', 'a') as f:
                    print(err_log, file=f)
            self.callbacks[-1].history.append(
                ('Exception', err_log)
            )
        finally:
            self.intermediate_steps = self.serialize(completion.get('intermediate_steps', []))
            self.agent_history = self.serialize(self.callbacks[-1].history)
            completion = self.postprocess(completion.get('output', ""))
        return completion

    def adapt(self, messages: List[Message]):
        output = []
        for i in messages:
            if i.role == "user":
output.append(HumanMessage(content=i.content)) 195 | elif i.role == "assistant": 196 | output.append(AIMessage(content=i.content)) 197 | elif i.role == "system": 198 | output.append(SystemMessage(content=i.content)) 199 | return output 200 | 201 | 202 | if __name__ == "__main__": 203 | react = ReAct("Your name is react.") 204 | messages = [ 205 | HumanMessage(content="I love programming."), 206 | AIMessage(content="I love programming too."), 207 | HumanMessage(content="What is the weather in LA and SF?"), 208 | ] 209 | print(react.agent_executor.run(messages)) 210 | -------------------------------------------------------------------------------- /src/generators/agents/react.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | from typing import Dict, List, Union, Type, Any 4 | import langchain 5 | from copy import deepcopy 6 | from pydantic import BaseModel 7 | 8 | # from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser, ReActJsonSingleInputOutputParser 9 | from langchain.chat_models import ChatOpenAI, AzureChatOpenAI 10 | from langchain.agents import tool, load_tools, OpenAIFunctionsAgent, AgentExecutor, OpenAIMultiFunctionsAgent, initialize_agent, AgentType, AgentOutputParser 11 | from langchain.prompts import MessagesPlaceholder 12 | from langchain.memory import ConversationBufferMemory 13 | from langchain.callbacks import HumanApprovalCallbackHandler 14 | from langchain.schema import ( 15 | AIMessage, 16 | HumanMessage, 17 | SystemMessage, 18 | ) 19 | from uuid import uuid4 20 | 21 | 22 | from .langchain_callback import MyCallbackHandler 23 | from .langchain_tools import VerilogToolkit 24 | from .react_json_single_input_parser import RobustReActJsonSingleInputOutputParser 25 | from ..model import GPTChat, Message 26 | langchain.debug = True 27 | 28 | 29 | def parse_markdown_code_block(text: str, ext: str = 'verilog'): 30 | try: 31 | cleaned_output = text.strip() 32 | if f"```{ext}" in cleaned_output: 33 | _, cleaned_output = cleaned_output.split(f"```{ext}") 34 | if "```" in cleaned_output: 35 | cleaned_output, _ = cleaned_output.split("```") 36 | if cleaned_output.startswith(f"```{ext}"): 37 | cleaned_output = cleaned_output[len(f"```{ext}"):] 38 | if cleaned_output.startswith("```"): 39 | cleaned_output = cleaned_output[len("```"):] 40 | if cleaned_output.endswith("```"): 41 | cleaned_output = cleaned_output[: -len("```")] 42 | return cleaned_output.strip() 43 | except Exception: 44 | return 45 | 46 | 47 | 48 | class ReAct(GPTChat): 49 | 50 | def __init__( 51 | self, 52 | model_name: str, 53 | exe, 54 | max_iters: int, 55 | toolset: list = [], 56 | system_prompt: str = None, 57 | temperature: int = 0.2, 58 | memory_key: str = None 59 | ): 60 | super().__init__(model_name) 61 | self.llm = ChatOpenAI(model=model_name, temperature=temperature, max_tokens=2048, top_p=1.0) 62 | self.toolkit = VerilogToolkit(exe, toolset=toolset) 63 | self.agent_executor = initialize_agent( 64 | self.toolkit.tools, 65 | self.llm, 66 | memory = ConversationBufferMemory(memory_key=memory_key, return_messages=True, input_key='input', output_key='output') if memory_key else None, 67 | agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, 68 | verbose=True, 69 | callback_manager=None, 70 | max_iterations=max_iters, 71 | # max_execution_time=25, 72 | handle_parsing_errors=self._handle_error, 73 | return_intermediate_steps=True, 74 | early_stopping_method="generate", 75 | agent_kwargs={ 76 | 'system_message': 
SystemMessage(content=system_prompt) if system_prompt else None, 77 | 'extra_prompt_messages': [MessagesPlaceholder(variable_name=memory_key)] if memory_key else None, 78 | 'output_parser': RobustReActJsonSingleInputOutputParser(), 79 | # 'input_variables': None, 80 | # "prefix": None, 81 | # 'format_instructions': None, 82 | # "output_parser": self.output_parser, 83 | }, 84 | ) 85 | self.uuid = uuid4().hex 86 | self.callbacks = [ 87 | # HumanApprovalCallbackHandler(), 88 | MyCallbackHandler(self.agent_executor) # must be last one 89 | ] 90 | 91 | def _handle_error(self, error: str): 92 | try: 93 | with open(f'error_log/{self.uuid}.parse_error_log', 'a') as f: 94 | print(error, file=f) 95 | except Exception: 96 | return "Function arguments is not in valid json format. I should fix it and try again." 97 | 98 | def postprocess(self, text: str): 99 | tmp = parse_markdown_code_block(text, 'verilog') 100 | if tmp: 101 | return tmp 102 | return text 103 | 104 | def serialize(self, history): 105 | try: 106 | from langchain.load.serializable import Serializable 107 | from pydantic import BaseModel 108 | from uuid import UUID 109 | 110 | def todict(obj, classkey=None): 111 | if isinstance(obj, dict): 112 | data = {} 113 | for (k, v) in obj.items(): 114 | data[k] = todict(v, classkey) 115 | return data 116 | elif isinstance(obj, Serializable): 117 | return obj.to_json() 118 | elif isinstance(obj, BaseModel): 119 | return obj.json() 120 | elif isinstance(obj, UUID): 121 | return obj.hex 122 | elif hasattr(obj, "_ast"): 123 | return todict(obj._ast()) 124 | elif hasattr(obj, "__iter__") and not isinstance(obj, str): 125 | return [todict(v, classkey) for v in obj] 126 | elif hasattr(obj, "__dict__"): 127 | data = dict([(key, todict(value, classkey)) 128 | for key, value in obj.__dict__.items() 129 | if not callable(value) and not key.startswith('_')]) 130 | if classkey is not None and hasattr(obj, "__class__"): 131 | data[classkey] = obj.__class__.__name__ 132 | return data 133 | else: 134 | return obj 135 | return todict(history) 136 | except Exception: 137 | import traceback 138 | print(traceback.format_exc()) 139 | import pickle 140 | import base64 141 | return base64.b64encode(pickle.dumps(history)).decode() 142 | 143 | def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]: 144 | completion = {} 145 | try: 146 | messages = self.adapt(messages) 147 | completion = self.agent_executor(messages, callbacks=self.callbacks) 148 | except Exception as e: 149 | err_log = str(e) 150 | if 'maximum context length' in str(e): 151 | pass 152 | elif 'Could not parse' in str(e): 153 | pass 154 | else: 155 | import traceback 156 | err_log = traceback.format_exc() 157 | with open(f'error_log/{self.uuid}.chat_error_log', 'a') as f: 158 | print(err_log, file=f) 159 | self.callbacks[-1].history.append( 160 | ('Exception', err_log) 161 | ) 162 | finally: 163 | self.intermediate_steps = self.serialize(completion.get('intermediate_steps', [])) 164 | self.agent_history = self.serialize(self.callbacks[-1].history) 165 | completion = self.postprocess(completion.get('output', "")) 166 | return completion 167 | 168 | def adapt(self, messages: List[Message]): 169 | output = [] 170 | for i in messages: 171 | if i.role == "user": 172 | output.append(HumanMessage(content=i.content)) 173 | elif i.role == "assistant": 174 | output.append(AIMessage(content=i.content)) 175 | elif i.role == "system": 176 | 
output.append(SystemMessage(content=i.content))
        return output


if __name__ == "__main__":
    # Minimal smoke test; requires OPENAI_API_KEY, and `exe=None` is a
    # placeholder for a Verilog executor so no compilation is available.
    react = ReAct("gpt-3.5-turbo-16k-0613", exe=None, max_iters=5, system_prompt="Your name is react.")
    messages = [
        HumanMessage(content="I love programming."),
        AIMessage(content="I love programming too."),
        HumanMessage(content="What is the weather in LA and SF?"),
    ]
    print(react.agent_executor.run(messages))
--------------------------------------------------------------------------------
/src/generators/agents/react_json_single_input_parser.py:
--------------------------------------------------------------------------------
import json
import re
from uuid import uuid4
from typing import Union

from langchain.agents.agent import AgentOutputParser
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS
from langchain.schema import AgentAction, AgentFinish, OutputParserException

FINAL_ANSWER_ACTION = "Final Answer:"


class RobustReActOutputParser(AgentOutputParser):
    """Output parser for the ReAct agent."""

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:

        action_prefix = "Action:"
        if action_prefix not in text:
            # No explicit action found: treat the whole text as a final answer.
            return AgentFinish({"output": text}, text)

        st = text.find(action_prefix) + len(action_prefix)
        action_str = text[st:]

        # Parse out the action and the directive.
        re_matches = re.search(r"(.*?)\[(.*?)\]", action_str)
        if re_matches is None:
            raise OutputParserException(
                f"Could not parse action directive: {action_str}"
            )
        action, action_input = re_matches.group(1), re_matches.group(2)
        if action == "Finish":
            return AgentFinish({"output": action_input}, text)
        else:
            return AgentAction(action, action_input, text)


class RobustReActJsonSingleInputOutputParser(AgentOutputParser):
    """Parses ReAct-style LLM calls that have a single tool input in json format.

    Expects output to be in one of two formats.

    If the output signals that an action should be taken,
    it should be in the format below. This will result in an AgentAction
    being returned.

    ```
    Thought: agent thought here
    Action:
    ```
    {
        "action": "search",
        "action_input": "what is the temperature in SF"
    }
    ```
    ```

    If the output signals that a final answer should be given,
    it should be in the format below. This will result in an AgentFinish
    being returned.

    ```
    Thought: agent thought here
    Final Answer: The temperature is 100 degrees
    ```

    """

    pattern = re.compile(r"^.*?`{3}(?:json)?\n(.*?)`{3}.*?$", re.DOTALL)
    """Regex pattern to parse the output."""

    parse_fail_record = []

    def get_format_instructions(self) -> str:
        return FORMAT_INSTRUCTIONS

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        includes_answer = FINAL_ANSWER_ACTION in text
        try:
            found = self.pattern.search(text)
            if not found:
                # Fast fail to parse Final Answer.
84 | raise ValueError("action not found") 85 | action = found.group(1) 86 | response = json.loads(action.strip().replace('\n', '')) 87 | includes_action = "action" in response 88 | if includes_answer and includes_action: 89 | raise OutputParserException( 90 | "Parsing LLM output produced a final answer " 91 | f"and a parse-able action: {text}" 92 | ) 93 | return AgentAction( 94 | response["action"], response.get("action_input", {}), text 95 | ) 96 | 97 | except Exception: 98 | if not includes_answer: 99 | try: 100 | return RobustReActOutputParser().parse(text) 101 | except Exception: 102 | self.parse_fail_record.append(text) 103 | raise OutputParserException(f"Could not parse LLM output: {text}") 104 | output = text.split(FINAL_ANSWER_ACTION)[-1].strip() 105 | return AgentFinish({"output": output}, text) 106 | 107 | @property 108 | def _type(self) -> str: 109 | return "robust-react-json-single-input" 110 | 111 | -------------------------------------------------------------------------------- /src/generators/agents/rtlfixer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | from typing import Dict, List, Union, Type, Any 5 | import langchain 6 | from copy import deepcopy 7 | from pydantic import BaseModel 8 | 9 | from langchain.chat_models import ChatOpenAI, AzureChatOpenAI 10 | from langchain.agents import tool, load_tools, OpenAIFunctionsAgent, AgentExecutor, OpenAIMultiFunctionsAgent, initialize_agent, AgentType, AgentOutputParser 11 | from langchain.prompts import MessagesPlaceholder 12 | from langchain.memory import ConversationBufferMemory 13 | from langchain.callbacks import HumanApprovalCallbackHandler 14 | from langchain.schema import ( 15 | AIMessage, 16 | HumanMessage, 17 | SystemMessage, 18 | ) 19 | from uuid import uuid4 20 | 21 | 22 | from .langchain_callback import MyCallbackHandler 23 | from .langchain_tools import VerilogToolkit 24 | from .react_json_single_input_parser import RobustReActJsonSingleInputOutputParser 25 | from ..model import GPTChat, Message 26 | from .utils import parse_markdown_code_block 27 | langchain.debug = True 28 | 29 | 30 | 31 | class CodeAgent(GPTChat): 32 | 33 | def __init__( 34 | self, 35 | model_name: str, 36 | exe, 37 | max_iters: int, 38 | toolset: list = None, 39 | system_prompt: str = None, 40 | temperature: int = 0.4, 41 | memory_key: str = None, 42 | method: str = "", 43 | compiler: str = "" 44 | ): 45 | super().__init__(model_name) 46 | self.llm = ChatOpenAI(model=model_name, temperature=temperature, max_tokens=2048, top_p=1.0) 47 | # memory_key = "memory" 48 | self.toolkit = VerilogToolkit(exe, toolset=toolset, llm=self.llm, method=method, compilername=compiler) 49 | self.agent_executor = initialize_agent( 50 | self.toolkit.tools, 51 | self.llm, 52 | memory = ConversationBufferMemory(memory_key=memory_key, return_messages=True, input_key='input', output_key='output') if memory_key else None, 53 | agent=AgentType.OPENAI_FUNCTIONS, 54 | # agent=AgentType.OPENAI_MULTI_FUNCTIONS, 55 | verbose=False, 56 | callback_manager=None, 57 | max_iterations=max_iters, 58 | handle_parsing_errors=self._handle_error, 59 | return_intermediate_steps=True, 60 | # early_stopping_method="generate", 61 | agent_kwargs={ 62 | 'system_message': SystemMessage(content=system_prompt) if system_prompt else None, 63 | 'extra_prompt_messages': [MessagesPlaceholder(variable_name=memory_key)] if memory_key else None, 64 | }, 65 | ) 66 | self.uuid = uuid4().hex 67 | self.callbacks = [ 68 | # 
HumanApprovalCallbackHandler(),
            MyCallbackHandler(self.agent_executor)  # must be the last callback
        ]
        self.agent_history = []
        self.error_logs = []
        self.max_iters = max_iters

    def reset_logs(self):
        self.agent_history = []
        self.error_logs = []

    def _handle_error(self, error: str) -> str:
        # Record the parsing failure, then tell the agent how to recover.
        try:
            self.error_logs.append(error)
        except Exception:
            pass
        return "Function arguments is not in valid json format. I should fix it and try again."

    def postprocess(self, text: str):
        tmp = parse_markdown_code_block(text, 'verilog')
        if tmp:
            return tmp
        return text

    def serialize(self, history):
        try:
            from langchain.load.serializable import Serializable
            from pydantic import BaseModel
            from uuid import UUID

            def todict(obj, classkey=None):
                if isinstance(obj, dict):
                    return {k: todict(v, classkey) for k, v in obj.items()}
                elif isinstance(obj, Serializable):
                    return obj.to_json()
                elif isinstance(obj, BaseModel):
                    return obj.json()
                elif isinstance(obj, UUID):
                    return obj.hex
                elif hasattr(obj, "_ast"):
                    return todict(obj._ast())
                elif hasattr(obj, "__iter__") and not isinstance(obj, str):
                    return [todict(v, classkey) for v in obj]
                elif hasattr(obj, "__dict__"):
                    data = dict([(key, todict(value, classkey))
                                 for key, value in obj.__dict__.items()
                                 if not callable(value) and not key.startswith('_')])
                    if classkey is not None and hasattr(obj, "__class__"):
                        data[classkey] = obj.__class__.__name__
                    return data
                else:
                    return obj
            return todict(history)
        except Exception:
            import traceback
            self.error_logs.append(traceback.format_exc())
            import pickle
            import base64
            return base64.b64encode(pickle.dumps(history)).decode()

    def self_verify(self, output: str):
        prompt = f"""
```
{output}
```
Is the logic implemented in the code?
If yes, answer YES.
If not, answer NO.
"""
        verify_result = self.toolkit.verify._run(prompt, "The logic is not yet implemented.")
        intermediate_step = {
            'self-verify': {
                'question': prompt,
                'result': verify_result,
            }
        }
        return verify_result, intermediate_step
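    # Control-flow note for generate_chat below (a sketch of the intended
    # loop): up to three fresh samples are attempted; within each attempt the
    # agent steps through agent_executor.iter(...) while watching the
    # `verilog_compiler` tool, where a "Success" observation ends the episode
    # and a give-up message salvages the last attempted completion; a final
    # compile plus self_verify then gates acceptance of the output.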
    def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
        completion = {'output': "", 'intermediate_steps': []}
        try:
            messages = self.adapt(messages)
            random_trials_if_compiler_cannot_pass = 3

            for _ in range(random_trials_if_compiler_cannot_pass):
                try:
                    for step in self.agent_executor.iter(messages):
                        if output := step.get("intermediate_steps"):
                            action, value = output[-1]
                            if action.tool == "verilog_compiler" and "Success" in value:
                                # The agent invoked the compiler and compilation
                                # succeeded; stop stepping.
                                break
                        if output := step.get("intermediate_step"):
                            completion['intermediate_steps'] += output
                            action, value = output[-1]
                            if action.tool == "verilog_compiler" and "give" in value:
                                # The compiler tool gave up; keep the last
                                # attempted completion as the candidate output.
                                completion['output'] = action.tool_input['code_completion']
                                break

                    if completion['output']:
                        compiler_log = self.toolkit.tools[0].run(completion['output'])
                        if "The code has no compile error." in compiler_log:
                            verify_result, intermediate_step = self.self_verify(completion['output'])
                            completion['intermediate_steps'].append(intermediate_step)
                            if "pass" in verify_result.lower():
                                break
                            else:
                                messages.append(AIMessage(content=completion['output']))
                                messages.append(HumanMessage(content="The logic is not implemented yet. Complete it instead of comments."))

                except Exception:
                    # The syntax error could not be fixed in this trial; log it
                    # and generate a new sample.
                    import traceback
                    self.error_logs.append(traceback.format_exc())

        except Exception as e:
            err_log = str(e)
            if 'maximum context length' in str(e):
                pass
            elif 'Could not parse' in str(e):
                pass
            else:
                import traceback
                err_log = traceback.format_exc()
                with open(f'error_log/{self.uuid}.chat_error_log', 'a') as f:
                    print(err_log, file=f)
            self.callbacks[-1].history.append(
                ('Exception', err_log)
            )
            self.error_logs.append(err_log)
        finally:
            self.intermediate_steps = completion.get('intermediate_steps', [])
            self.agent_history.append(self.intermediate_steps)
            completion = self.postprocess(completion.get('output', ""))
        return completion

    def adapt(self, messages: List[Message]):
        output = []
        for i in messages:
            if i.role == "user":
                output.append(HumanMessage(content=i.content))
            elif i.role == "assistant":
                output.append(AIMessage(content=i.content))
            elif i.role == "system":
                output.append(SystemMessage(content=i.content))
        return output
class FixAgent(GPTChat):

    def __init__(
        self,
        model_name: str,
        exe,
        max_iters: int,
        toolset: list = None,
        system_prompt: str = None,
        temperature: float = 0.4,
        memory_key: str = None,
        method: str = ""
    ):
        super().__init__(model_name)
        if 'azure' in model_name:
            import openai
            # Azure endpoint, key, and API version come from the environment
            # (OPENAI_API_BASE, OPENAI_API_KEY, OPENAI_API_VERSION).
            openai.api_type = "azure"
            openai.api_base = os.environ["OPENAI_API_BASE"]
            openai.api_version = os.environ.get("OPENAI_API_VERSION", "2023-07-01-preview")
            openai.api_key = os.environ["OPENAI_API_KEY"]
            os.environ['OPENAI_API_TYPE'] = "azure"

            self.llm = AzureChatOpenAI(
                deployment_name=os.environ.get("AZURE_DEPLOYMENT_NAME", "Morris-16k-for-sum"),
                model_name="gpt-35-turbo-16k",
                temperature=temperature,
                max_tokens=2048,
                top_p=1.0
            )
        else:
            self.llm = ChatOpenAI(model=model_name, temperature=temperature, max_tokens=2048, top_p=1.0)

        self.toolkit = VerilogToolkit(exe, toolset=toolset, llm=self.llm, method=method)
        self.agent_executor = initialize_agent(
            self.toolkit.tools,
            self.llm,
            memory=ConversationBufferMemory(memory_key=memory_key, return_messages=True, input_key='input', output_key='output') if memory_key else None,
            agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
            verbose=False,
            callback_manager=None,
            max_iterations=max_iters,
            handle_parsing_errors=self._handle_error,
            return_intermediate_steps=True,
            early_stopping_method="generate",
            agent_kwargs={
                'system_message': SystemMessage(content=system_prompt) if system_prompt else None,
                'extra_prompt_messages': [MessagesPlaceholder(variable_name=memory_key)] if memory_key else None,
            },
        )
        self.uuid = uuid4().hex
        self.callbacks = [
            # HumanApprovalCallbackHandler(),
            MyCallbackHandler(self.agent_executor)  # must be the last callback
        ]
        self.agent_history = []
        self.error_logs = []
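    # Design note: FixAgent drives its tools through the structured-chat
    # ReAct agent type, so it does not depend on OpenAI function calling the
    # way CodeAgent above does.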
    def reset_logs(self):
        self.agent_history = []
        self.error_logs = []

    def _handle_error(self, error: str) -> str:
        # Record the parsing failure, then tell the agent how to recover.
        try:
            self.error_logs.append(error)
        except Exception:
            pass
        return "Function arguments is not in valid json format. I should fix it and try again."

    def postprocess(self, text: str):
        tmp = parse_markdown_code_block(text, 'verilog')
        if tmp:
            return tmp
        return text

    def serialize(self, history):
        try:
            from langchain.load.serializable import Serializable
            from pydantic import BaseModel
            from uuid import UUID

            def todict(obj, classkey=None):
                if isinstance(obj, dict):
                    return {k: todict(v, classkey) for k, v in obj.items()}
                elif isinstance(obj, Serializable):
                    return obj.to_json()
                elif isinstance(obj, BaseModel):
                    return obj.json()
                elif isinstance(obj, UUID):
                    return obj.hex
                elif hasattr(obj, "_ast"):
                    return todict(obj._ast())
                elif hasattr(obj, "__iter__") and not isinstance(obj, str):
                    return [todict(v, classkey) for v in obj]
                elif hasattr(obj, "__dict__"):
                    data = dict([(key, todict(value, classkey))
                                 for key, value in obj.__dict__.items()
                                 if not callable(value) and not key.startswith('_')])
                    if classkey is not None and hasattr(obj, "__class__"):
                        data[classkey] = obj.__class__.__name__
                    return data
                else:
                    return obj
            return todict(history)
        except Exception:
            import traceback
            self.error_logs.append(traceback.format_exc())
            import pickle
            import base64
            return base64.b64encode(pickle.dumps(history)).decode()

    def self_verify(self, output: str):
        prompt = f"""
```
{output}
```
Is the above answer complete?
If yes, answer YES.
If not, answer NO.
"""
        verify_result = self.toolkit.verify._run(prompt, "Continue to debug using tools and must follow the action format instruction.")
        intermediate_step = {
            'self-verify': {
                'question': prompt,
                'result': verify_result,
            }
        }
        return verify_result, intermediate_step
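    # generate_chat below retries the full agent episode up to three times and
    # keeps the first run that yields intermediate steps; parser and runtime
    # failures are accumulated in self.error_logs instead of being raised.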
    def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
        completion = {}
        completions = {'output': "", 'intermediate_steps': []}
        try:
            messages = self.adapt(messages)

            for _ in range(3):
                try:
                    completion = self.agent_executor(messages, callbacks=self.callbacks)
                    completions['intermediate_steps'] = completion['intermediate_steps']
                    break
                except Exception:
                    import traceback
                    self.error_logs.append(traceback.format_exc())

        except Exception as e:
            err_log = str(e)
            if 'maximum context length' in str(e):
                pass
            elif 'Could not parse' in str(e):
                pass
            else:
                import traceback
                err_log = traceback.format_exc()
                with open(f'error_log/{self.uuid}.chat_error_log', 'a') as f:
                    print(err_log, file=f)
            self.callbacks[-1].history.append(
                ('Exception', err_log)
            )
            self.error_logs.append(err_log)
        finally:
            self.intermediate_steps = completions.get('intermediate_steps', [])
            self.agent_history.append(self.intermediate_steps)
            completion = completion.get('output', "")
        return completion

    def adapt(self, messages: List[Message]):
        output = []
        for i in messages:
            if i.role == "user":
                output.append(HumanMessage(content=i.content))
            elif i.role == "assistant":
output.append(AIMessage(content=i.content)) 466 | elif i.role == "system": 467 | output.append(SystemMessage(content=i.content)) 468 | return output 469 | 470 | 471 | def RTLFixer(agent_name: str, model_name: str, **kwargs): 472 | if "code" in agent_name: 473 | return CodeAgent(model_name, **kwargs) 474 | elif "fix" in agent_name: 475 | return FixAgent(model_name, **kwargs) 476 | -------------------------------------------------------------------------------- /src/generators/agents/utils.py: -------------------------------------------------------------------------------- 1 | def parse_markdown_code_block(text: str, ext: str = 'verilog'): 2 | try: 3 | cleaned_output = text.strip() 4 | if f"```{ext}" in cleaned_output: 5 | _, cleaned_output = cleaned_output.split(f"```{ext}") 6 | if "```" in cleaned_output: 7 | cleaned_output, _ = cleaned_output.split("```") 8 | if cleaned_output.startswith(f"```{ext}"): 9 | cleaned_output = cleaned_output[len(f"```{ext}"):] 10 | if cleaned_output.startswith("```"): 11 | cleaned_output = cleaned_output[len("```"):] 12 | if cleaned_output.endswith("```"): 13 | cleaned_output = cleaned_output[: -len("```")] 14 | return cleaned_output.strip() 15 | except Exception: 16 | return 17 | -------------------------------------------------------------------------------- /src/generators/factory.py: -------------------------------------------------------------------------------- 1 | from .py_generate import PyGenerator 2 | from .rs_generate import RsGenerator 3 | from .verilog_generate import VerilogGenerator 4 | from .generator_types import Generator 5 | from .model import CodeLlama, ModelBase, GPT4, GPT35, StarChat, GPTDavinci, PhindCodeLlama, CodeLlama2, CodeLlama3, CodeGen 6 | from .agents import ReAct, PlanAndExecute, OpenAIFunc, RTLFixer 7 | 8 | 9 | def generator_factory(lang: str) -> Generator: 10 | if lang == "py" or lang == "python": 11 | return PyGenerator() 12 | elif lang == "rs" or lang == "rust": 13 | return RsGenerator() 14 | elif lang == "vg" or lang == "verilog": 15 | return VerilogGenerator() 16 | else: 17 | raise ValueError(f"Invalid language for generator: {lang}") 18 | 19 | 20 | def agent_factory(agent_name: str, model_name: str, **kwargs): 21 | if agent_name == "cot": 22 | return model_factory(model_name) 23 | elif agent_name == "react": 24 | return ReAct(model_name, **kwargs) 25 | elif "rtlfixer" in agent_name: 26 | return RTLFixer(agent_name, model_name, **kwargs) 27 | elif agent_name == "openaifunc": 28 | return OpenAIFunc(model_name, **kwargs) 29 | elif agent_name == "planexec": 30 | return PlanAndExecute(model_name, **kwargs) 31 | else: 32 | raise ValueError(f"Invalid agent name: {agent_name}") 33 | 34 | 35 | def model_factory(model_name: str) -> ModelBase: 36 | if 'gpt-4' in model_name: 37 | return GPT4(model_name) 38 | elif "gpt-3.5" in model_name: 39 | return GPT35(model_name) 40 | elif "starchat" in model_name: 41 | return StarChat() 42 | elif "codegen" in model_name: 43 | return CodeGen() 44 | elif model_name.startswith("codellama"): 45 | # if it has `-` in the name, version was specified 46 | kwargs = {} 47 | if "B-" in model_name: 48 | kwargs["version"] = model_name.split("-")[-1] 49 | elif "-" in model_name: 50 | kwargs["version"] = model_name.split("-")[1] 51 | 52 | if "Phind" in model_name: 53 | MODEL = PhindCodeLlama 54 | elif model_name.startswith("codellama2"): 55 | MODEL = CodeLlama2 56 | elif model_name.startswith("codellama3"): 57 | MODEL = CodeLlama3 58 | else: 59 | MODEL = CodeLlama 60 | 61 | return MODEL(**kwargs) 62 | elif 
model_name.startswith("text-davinci"): 63 | return GPTDavinci(model_name) 64 | else: 65 | raise ValueError(f"Invalid model name: {model_name}") 66 | -------------------------------------------------------------------------------- /src/generators/generator_types.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | from abc import abstractmethod, ABC 3 | 4 | from generators.model import ModelBase 5 | 6 | 7 | class Generator: 8 | @abstractmethod 9 | def self_reflection(self, func: str, feedback: str, model: ModelBase) -> str: 10 | ... 11 | 12 | @abstractmethod 13 | def func_impl( 14 | self, 15 | func_sig: str, 16 | model: ModelBase, 17 | strategy: str, 18 | prev_func_impl: Optional[str] = None, 19 | feedback: Optional[str] = None, 20 | self_reflection: Optional[str] = None, 21 | num_comps: int = 1, 22 | temperature: float = 0.0, 23 | ) -> Union[str, List[str]]: 24 | ... 25 | 26 | @abstractmethod 27 | def internal_tests( 28 | self, 29 | func_sig: str, 30 | model: ModelBase, 31 | max_num_tests: int = 5 32 | ) -> List[str]: 33 | ... 34 | -------------------------------------------------------------------------------- /src/generators/generator_utils.py: -------------------------------------------------------------------------------- 1 | from generators.model import ModelBase, Message 2 | import random 3 | 4 | from typing import Union, List, Optional, Callable 5 | 6 | 7 | def func_impl_prepare_prompt( 8 | func_sig: dict, 9 | model: ModelBase, 10 | strategy: str, 11 | prev_func_impl, 12 | feedback, 13 | self_reflection, 14 | num_comps, 15 | temperature, 16 | reflexion_chat_instruction: str, 17 | reflexion_few_shot: str, 18 | simple_chat_instruction: str, 19 | reflexion_completion_instruction: str, 20 | simple_completion_instruction: str, 21 | code_block_instruction: str, 22 | parse_code_block: Callable[[str], str], 23 | add_code_block: Callable[[str], str], 24 | simple_few_shot: str = "", 25 | lang: str = None, 26 | ) -> Union[str, List[str]]: 27 | prompt = "" 28 | org_func_sig = func_sig 29 | if isinstance(func_sig, dict): 30 | prompt = func_sig['prompt'] 31 | if lang == "verilog": 32 | func_sig = f"{func_sig['detail_description']}\nImplement the above description in the following module.\n{func_sig['prompt']}" 33 | else: 34 | func_sig = func_sig['prompt'] 35 | 36 | # if strategy != "reflexion" and strategy != "simple": 37 | # raise ValueError( 38 | # f"Invalid strategy: given `{strategy}` but expected one of `reflexion` or `simple`") 39 | if strategy == "reflexion" and (prev_func_impl is None or feedback is None or self_reflection is None): 40 | raise ValueError( 41 | f"Invalid arguments: given `strategy=reflexion` but `prev_func_impl`, `feedback`, or `self_reflection` is None") 42 | 43 | if model.is_chat: 44 | if strategy == "reflexion": 45 | message = f"{reflexion_few_shot}\n[previous impl]:\n{add_code_block(prev_func_impl)}\n\n[unit test results from previous impl]:\n{feedback}\n\n[reflection on previous impl]:\n{self_reflection}\n\n[improved impl]:\n{func_sig}" 46 | prompt = f"{reflexion_chat_instruction}\n{code_block_instruction}" 47 | print_messages(prompt, message) 48 | messages = [ 49 | Message( 50 | role="system", 51 | content=prompt, 52 | ), 53 | Message( 54 | role="user", # TODO: check this 55 | content=reflexion_few_shot, 56 | ), 57 | Message( 58 | role="assistant", 59 | content=add_code_block(prev_func_impl), 60 | ), 61 | Message( 62 | role="user", 63 | content=f"[unit test results from previous 
impl]:\n{feedback}\n\n[reflection on previous impl]:", 64 | ), 65 | Message( 66 | role="assistant", 67 | content=self_reflection, 68 | ), 69 | Message( 70 | role="user", 71 | content=f"[improved impl]:\n{func_sig}", 72 | ), 73 | ] 74 | elif "rtlfixer" in strategy: 75 | if not prev_func_impl: 76 | system_prompt = f"{simple_chat_instruction}\n{code_block_instruction}\nI want you to run the compiler to ensure the correctness of the syntax before answering.\n" 77 | print_messages(system_prompt, func_sig) 78 | messages = [ 79 | Message( 80 | role="system", 81 | content=system_prompt, 82 | ), 83 | ] + [ 84 | Message( 85 | role=role, 86 | content=content, 87 | ) 88 | for role, content in simple_few_shot 89 | ] + [ 90 | Message( 91 | role="user", 92 | content=f"{func_sig}", 93 | ), 94 | ] 95 | else: 96 | feedback = feedback[-1]['test_output'] if feedback[-1]['test_output'] else feedback[-1]['compiler_log'] 97 | # system_prompt = f"{simple_chat_instruction}\n{code_block_instruction}\nRun the compiler to ensure the correctness of the syntax.\n" 98 | system_prompt = f"{simple_chat_instruction}\n{code_block_instruction}\n" 99 | 100 | # message = f"{reflexion_few_shot}\n[previous implementation]:\n{prev_func_impl}\n\n[simulation results from previous implementation]:\n{feedback}\n\n[reflection on previous implementation]:\n{self_reflection}\n\n[improved implementation]:\n" 101 | # print_messages(system_prompt, message) 102 | # messages = [ 103 | # Message( 104 | # role="system", 105 | # content=system_prompt, 106 | # ), 107 | # Message( 108 | # role="user", 109 | # content=reflexion_few_shot, 110 | # ), 111 | # Message( 112 | # role='user', 113 | # content=f"[problem description]:\n{org_func_sig['detail_description']}\n[previous implementation]:" 114 | # ), 115 | # Message( 116 | # role="assistant", 117 | # content=prev_func_impl, 118 | # ), 119 | # Message( 120 | # role="user", 121 | # content=f"[simulation results from previous implementation]:\n{feedback}\n\n[reflection on previous implementation]:", 122 | # ), 123 | # Message( 124 | # role="assistant", 125 | # content=self_reflection, 126 | # ), 127 | # Message( 128 | # role="user", 129 | # content=f"[improved implementation]:\n", 130 | # ), 131 | # ] 132 | 133 | 134 | reflexion = self_reflection[-1] if self_reflection else "" 135 | messages = [ 136 | Message( 137 | role="system", 138 | content=system_prompt, 139 | ), 140 | Message( 141 | role='user', 142 | content=f"{func_sig}" 143 | ), 144 | Message( 145 | role="assistant", 146 | content=prev_func_impl, 147 | ), 148 | Message( 149 | role="user", 150 | content=f"{feedback}\n{reflexion}\n{func_sig}", 151 | ), 152 | ] 153 | print_messages(system_prompt, "\n".join([i.content for i in messages])) 154 | else: 155 | if lang == "verilog": 156 | system_prompt = f"{simple_chat_instruction}\n{code_block_instruction}\n" 157 | print_messages(system_prompt, func_sig) 158 | messages = [ 159 | Message( 160 | role="system", 161 | content=system_prompt, 162 | ), 163 | ] + [ 164 | Message( 165 | role=role, 166 | content=content, 167 | ) 168 | for role, content in simple_few_shot 169 | ] + [ 170 | Message( 171 | role="user", 172 | content=f"{func_sig}", 173 | ), 174 | ] 175 | 176 | # system_prompt = f"{simple_chat_instruction}" 177 | # print_messages(system_prompt, func_sig) 178 | # messages = [ 179 | # Message( 180 | # role="system", 181 | # content=system_prompt, 182 | # ), 183 | # Message( 184 | # role="user", 185 | # content=f"{code_block_instruction}\n// {func_sig}", 186 | # ), 187 | # ] 188 | 189 | # 
system_prompt = f"{simple_chat_instruction}"

        else:
            system_prompt = f"{simple_chat_instruction}\n{code_block_instruction}"
            print_messages(system_prompt, func_sig)
            messages = [
                Message(
                    role="system",
                    content=f"{simple_chat_instruction}\n{code_block_instruction}",
                ),
                Message(
                    role="user",
                    content=func_sig,
                ),
            ]

        return messages
    else:
        if strategy == "reflexion":
            prompt = f"{reflexion_completion_instruction}\n{add_code_block(prev_func_impl)}\n\nunit tests:\n{feedback}\n\nhint:\n{self_reflection}\n\n# improved implementation\n{func_sig}\n{code_block_instruction}"
        else:
            prompt = f"{simple_completion_instruction}\n{func_sig}\n{code_block_instruction}"
        return prompt


def generic_generate_func_impl(
    func_sig: dict,
    model: ModelBase,
    strategy: str,
    prev_func_impl,
    feedback,
    self_reflection,
    num_comps,
    temperature,
    reflexion_chat_instruction: str,
    reflexion_few_shot: str,
    simple_chat_instruction: str,
    reflexion_completion_instruction: str,
    simple_completion_instruction: str,
    code_block_instruction: str,
    parse_code_block: Callable[[str], str],
    add_code_block: Callable[[str], str],
    simple_few_shot: str = "",
    lang: str = None,
) -> Union[str, List[str]]:

    prompt = ""
    if isinstance(func_sig, dict):
        prompt = func_sig['prompt']
        if lang == "verilog":
            func_sig = f"{func_sig['detail_description']}\n{func_sig['prompt']}"
        else:
            func_sig = func_sig['prompt']

    if strategy != "reflexion" and strategy != "simple":
        raise ValueError(
            f"Invalid strategy: given `{strategy}` but expected one of `reflexion` or `simple`")
    if strategy == "reflexion" and (prev_func_impl is None or feedback is None or self_reflection is None):
        raise ValueError(
            "Invalid arguments: given `strategy=reflexion` but `prev_func_impl`, `feedback`, or `self_reflection` is None")

    prompt = func_impl_prepare_prompt(
        func_sig,
        model,
        strategy,
        prev_func_impl,
        feedback,
        self_reflection,
        num_comps,
        temperature,
        reflexion_chat_instruction,
        reflexion_few_shot,
        simple_chat_instruction,
        reflexion_completion_instruction,
        simple_completion_instruction,
        code_block_instruction,
        parse_code_block,
        add_code_block,
        simple_few_shot,
        lang,
    )

    if model.is_chat:
        func_bodies = model.generate_chat(messages=prompt, num_comps=num_comps, temperature=temperature)
    else:
        func_bodies = model.generate(prompt, num_comps=num_comps, temperature=temperature)

    if num_comps == 1:
        assert isinstance(func_bodies, str)
        func_body_str = parse_code_block(func_bodies)
        # Only Verilog completions need the closing `endmodule` appended.
        if lang == "verilog" and 'endmodule' not in func_body_str:
            func_body_str += '\nendmodule'
        print_generated_func_body(func_body_str)
        return func_body_str
    else:
        func_bodies = 
[parse_code_block(func_body) for func_body in func_bodies] 302 | print_generated_func_body("\n\n".join(func_bodies)) 303 | return func_bodies 304 | 305 | 306 | def generic_generate_internal_tests( 307 | func_sig: str, 308 | model: ModelBase, 309 | max_num_tests: int, 310 | test_generation_few_shot: str, 311 | test_generation_chat_instruction: str, 312 | test_generation_completion_instruction: str, 313 | parse_tests: Callable[[str], List[str]], 314 | is_syntax_valid: Callable[[str], bool], 315 | is_react: bool = False 316 | ) -> List[str]: 317 | """Generates tests for a function.""" 318 | if model.is_chat: 319 | if is_react: 320 | messages = [ 321 | Message( 322 | role="system", 323 | content=test_generation_chat_instruction, 324 | ), 325 | Message( 326 | role="user", 327 | content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\n\n[think]:" 328 | ) 329 | ] 330 | output = model.generate_chat(messages=messages, max_tokens=1024) 331 | print(f'React test generation output: {output}') 332 | else: 333 | messages = [ 334 | Message( 335 | role="system", 336 | content=test_generation_chat_instruction, 337 | ), 338 | Message( 339 | role="user", 340 | content=f"{test_generation_few_shot}\n\n[func signature]:\n{func_sig}\n\n[unit tests]:", 341 | ) 342 | ] 343 | output = model.generate_chat(messages=messages, max_tokens=1024) 344 | else: 345 | prompt = f'{test_generation_completion_instruction}\n\nfunc signature:\n{func_sig}\nunit tests:' 346 | output = model.generate(prompt, max_tokens=1024) 347 | all_tests = parse_tests(output) # type: ignore 348 | valid_tests = [test for test in all_tests if is_syntax_valid(test)] 349 | 350 | return sample_n_random(valid_tests, max_num_tests) 351 | 352 | 353 | def self_reflection_prepare_prompt( 354 | func: str, 355 | feedback: str, 356 | model: ModelBase, 357 | self_reflection_chat_instruction: str, 358 | self_reflection_completion_instruction: str, 359 | add_code_block: Callable[[str], str], 360 | self_reflection_few_shot: Optional[str] = None, 361 | lang: str = "verilog" 362 | ) -> str: 363 | 364 | if isinstance(func, dict): 365 | if lang == "verilog": 366 | description = func['detail_description'] 367 | func = func['solution'] 368 | else: 369 | description = "" 370 | func = func['prompt'] 371 | 372 | if model.is_chat: 373 | if self_reflection_few_shot is not None: 374 | messages = [ 375 | Message( 376 | role="system", 377 | content=self_reflection_chat_instruction, 378 | ), 379 | Message( 380 | role="user", 381 | # content=f'{self_reflection_few_shot}\n\n[problem description]:\n{description}\n\n[module implementation]:\n{add_code_block(func)}\n\n[simulation results]:\n{feedback}\n\n[self-reflection]:', 382 | content=f'[problem description]:\n{description}\n\n[module implementation]:\n{add_code_block(func)}\n\n[simulation results]:\n{feedback}\n\n[self-reflection]:', 383 | ) 384 | ] 385 | print_messages(self_reflection_chat_instruction, "\n".join([i.content for i in messages[1:]])) 386 | else: 387 | messages = [ 388 | Message( 389 | role="system", 390 | content=f"{self_reflection_chat_instruction}\nI want you to run the verify tool to ensure the correctness before answering. 
If you believe the implementation matches the expected behavior, fix the code according to the waveform and ignore the description.",
391 |                 ),
392 |                 Message(
393 |                     role="user",
394 |                     content=f'[problem description]:\n{description}\n\n[function impl]:\n{add_code_block(func)}\n\n[simulation results]:\n{feedback}\n\n[self-reflection]:',
395 |                 )
396 |             ]
397 |     else:
398 |         messages = f'{self_reflection_completion_instruction}\n{add_code_block(func)}\n\n{feedback}\n\nExplanation:'
399 |     return messages
400 | 
401 | 
402 | def generic_generate_self_reflection(
403 |     func: str,
404 |     feedback: str,
405 |     model: ModelBase,
406 |     self_reflection_chat_instruction: str,
407 |     self_reflection_completion_instruction: str,
408 |     add_code_block: Callable[[str], str],
409 |     self_reflection_few_shot: Optional[str] = None,
410 | ) -> str:
411 | 
412 |     messages = self_reflection_prepare_prompt(
413 |         func,
414 |         feedback,
415 |         model,
416 |         self_reflection_chat_instruction,
417 |         self_reflection_completion_instruction,
418 |         add_code_block,
419 |         self_reflection_few_shot,
420 |     )
421 |     if model.is_chat:
422 |         reflection = model.generate_chat(messages=messages)
423 |     else:
424 |         reflection = model.generate(messages)
425 |     return reflection  # type: ignore
426 | 
427 | 
428 | def sample_n_random(items: List[str], n: int) -> List[str]:
429 |     """Sample min(n, len(items)) random items from a list"""
430 |     assert n >= 0
431 |     if n >= len(items):
432 |         return items
433 |     return random.sample(items, n)
434 | 
435 | def print_messages(system_message_text: str, user_message_text: str) -> None:
436 |     print(f"""----------------------- SYSTEM MESSAGE -----------------------
437 | {system_message_text}
438 | ----------------------------------------------
439 | ----------------------- USER MESSAGE -----------------------
440 | {user_message_text}
441 | ----------------------------------------------
442 | """, flush=True)
443 | 
444 | def print_generated_func_body(func_body_str: str) -> None:
445 |     print(f"""--------------------- GENERATED FUNC BODY ---------------------
446 | {func_body_str}
447 | ------------------------------------------""")
448 | 
--------------------------------------------------------------------------------
/src/generators/parse.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Optional
3 | 
4 | 
5 | def parse_code_block(string: str, lang: str) -> Optional[str]:
6 | 
7 |     if lang == "verilog":
8 |         return string
9 | 
10 |     code_pattern = fr"```{lang}\n(.*?)\n```"
11 |     match = re.search(code_pattern, string, re.DOTALL)
12 | 
13 |     if match:
14 |         return match.group(1)
15 | 
16 |     generic_code_pattern = r"```\n(.*?)\n```"
17 |     match = re.search(generic_code_pattern, string, re.DOTALL)
18 | 
19 |     if match:
20 |         return match.group(1)
21 | 
22 |     return parse_first_func(string, lang)
23 | 
24 | 
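A quick, hypothetical round-trip check of `parse_code_block` above and `add_code_block` defined further down in this file (illustrative only, not part of the repo; it mirrors the `__main__` demo at the bottom of the file and assumes `src/` is on `PYTHONPATH`):

```python
from generators.parse import parse_code_block, add_code_block

body = "def add(a, b):\n    return a + b"
wrapped = add_code_block(body, "python")      # wraps body in a fenced, language-tagged block
assert parse_code_block(wrapped, "python") == body

# Verilog responses are returned unchanged -- no fence parsing is attempted.
assert parse_code_block("assign out = a;", "verilog") == "assign out = a;"
```
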
TODO: Rust" 27 | code_lines = code.split("\n") 28 | def_i = -1 29 | last_i = 0 30 | got_return = False 31 | for i, line in enumerate(code_lines): 32 | if line.startswith("def "): 33 | if def_i == -1: 34 | def_i = i 35 | else: 36 | break 37 | elif "return" in line and def_i != -1: 38 | got_return = True 39 | if line == "" and def_i != -1 and got_return: 40 | last_i = i 41 | break 42 | 43 | if last_i == 0: 44 | last_i = len(code_lines) - 1 45 | 46 | if def_i == -1: 47 | return None 48 | 49 | return "\n".join(code_lines[def_i:last_i+1]).rstrip("[/PYTHON]") 50 | 51 | 52 | def add_code_block(string: str, lang: str) -> str: 53 | return f"```{lang}\n{string}\n```" 54 | 55 | 56 | if __name__ == "__main__": 57 | CODE = """ 58 | aldaas 59 | sub_parser = parser.add_subparsers().add_parser("frf 60 | a") 61 | 62 | def my_wonderful_func(): 63 | def useless_helper(): 64 | return 1 65 | if 1: 66 | return 1 67 | else: 68 | return ( 69 | 1, 70 | 2, 71 | ) 72 | 73 | sadsadsa 74 | 2023-08-04dsa 75 | dsa 76 | 77 | def bleh(): 78 | return aaa 79 | """ 80 | print(parse_code_block(CODE, "python")) 81 | CODE = """def total_match(lst1: List[str], lst2: List[str]) -> List[str]: 82 | \"\"\" 83 | Write a function that accepts two lists of strings and returns the list that has 84 | total number of chars in the all strings of the list less than the other list. 85 | 86 | if the two lists have the same number of chars, return the first list. 87 | 88 | Examples 89 | >>> total_match([], []) 90 | [] 91 | >>> total_match(['hi', 'admin'], ['hI', 'Hi']) 92 | ['hI', 'Hi'] 93 | >>> total_match(['hi', 'admin'], ['hi', 'hi', 'admin', 'project']) 94 | ['hi', 'admin'] 95 | >>> total_match(['hi', 'admin'], ['hI', 'hi', 'hi']) 96 | ['hI', 'hi', 'hi'] 97 | >>> total_match(['4'], ['1', '2', '3', '4', '5']) 98 | ['4'] 99 | \"\"\" 100 | total_chars_lst1 = sum(len(word) for word in lst1) 101 | total_chars_lst2 = sum(len(word) for word in lst2) 102 | 103 | if total_chars_lst1 < total_chars_lst2: 104 | return lst1 105 | elif total_chars_lst1 > total_chars_lst2: 106 | return lst2 107 | else: 108 | return lst1 109 | """ 110 | print(parse_code_block(CODE, "python")) 111 | -------------------------------------------------------------------------------- /src/generators/py_generate.py: -------------------------------------------------------------------------------- 1 | from generators.model import ModelBase, message_to_str 2 | from .generator_types import Generator 3 | from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection 4 | 5 | from typing import Optional, List, Union 6 | import ast 7 | import re 8 | from .parse import parse_code_block, add_code_block 9 | 10 | PY_SIMPLE_COMPLETION_INSTRUCTION = "# Write the body of this function only." 11 | PY_REFLEXION_COMPLETION_INSTRUCTION = "You are a Python writing assistant. You will be given your past function implementation, a series of unit tests, and a hint to change the implementation appropriately. Write your full implementation (restate the function signature).\n\n-----" 12 | PY_SELF_REFLECTION_COMPLETION_INSTRUCTION = "You are a Python writing assistant. You will be given a function implementation and a series of unit tests. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as a hint when you try again later. 
Only provide the few sentence description in your answer, not the implementation.\n\n-----" 13 | USE_PYTHON_CODEBLOCK_INSTRUCTION = "Use a Python code block to write your response. For example:\n```python\nprint('Hello world!')\n```" 14 | 15 | PY_SIMPLE_CHAT_INSTRUCTION = "You are an AI that only responds with python code, NOT ENGLISH. You will be given a function signature and its docstring by the user. Write your full implementation (restate the function signature)." 16 | PY_SIMPLE_CHAT_INSTRUCTION_V2 = "You are an AI that only responds with only python code. You will be given a function signature and its docstring by the user. Write your full implementation (restate the function signature)." 17 | PY_REFLEXION_CHAT_INSTRUCTION = "You are an AI Python assistant. You will be given your past function implementation, a series of unit tests, and a hint to change the implementation appropriately. Write your full implementation (restate the function signature)." 18 | PY_REFLEXION_CHAT_INSTRUCTION_V2 = "You are an AI Python assistant. You will be given your previous implementation of a function, a series of unit tests results, and your self-reflection on your previous implementation. Write your full implementation (restate the function signature)." 19 | PY_REFLEXION_FEW_SHOT_ADD = '''Example 1: 20 | [previous impl]: 21 | ```python 22 | def add(a: int, b: int) -> int: 23 | """ 24 | Given integers a and b, return the total value of a and b. 25 | """ 26 | return a - b 27 | ``` 28 | 29 | [unit test results from previous impl]: 30 | Tested passed: 31 | 32 | Tests failed: 33 | assert add(1, 2) == 3 # output: -1 34 | assert add(1, 2) == 4 # output: -1 35 | 36 | [reflection on previous impl]: 37 | The implementation failed the test cases where the input integers are 1 and 2. The issue arises because the code does not add the two integers together, but instead subtracts the second integer from the first. To fix this issue, we should change the operator from `-` to `+` in the return statement. This will ensure that the function returns the correct output for the given input. 38 | 39 | [improved impl]: 40 | ```python 41 | def add(a: int, b: int) -> int: 42 | """ 43 | Given integers a and b, return the total value of a and b. 44 | """ 45 | return a + b 46 | ``` 47 | ''' 48 | 49 | PY_REFLEXION_FEW_SHOT = '''Example 1: 50 | [previous impl]: 51 | ```python 52 | from typing import * 53 | def fullJustify(words: List[str], maxWidth: int) -> List[str]: 54 | """ 55 | Given an array of words and a width maxWidth, format the text such that each line has exactly maxWidth characters and is fully (left and right) justified. 56 | You should pack your words in a greedy approach; that is, pack as many words as you can in each line. Pad extra spaces `' '` when necessary so that each line has exactly maxWidth characters. 57 | Extra spaces between words should be distributed as evenly as possible. If the number of spaces on a line do not divide evenly between words, the empty slots on the left will be assigned more spaces than the slots on the right. 58 | For the last line of text, it should be left justified and no extra space is inserted between words. 59 | Note: 60 | A word is defined as a character sequence consisting of non-space characters only. 61 | Each word's length is guaranteed to be greater than 0 and not exceed maxWidth. 62 | The input array `words` contains at least one word. 
63 | """ 64 | res = [] 65 | cur_line = [] 66 | cur_len = 0 67 | 68 | for word in words: 69 | if cur_len + len(word) + len(cur_line) > maxWidth: 70 | if len(cur_line) == 1: 71 | res.append(cur_line[0] + ' ' * (maxWidth - cur_len)) 72 | else: 73 | spaces = maxWidth - cur_len 74 | space_between = spaces // (len(cur_line) - 1) 75 | extra_spaces = spaces % (len(cur_line) - 1) 76 | line = '' 77 | for i, w in enumerate(cur_line[:-1]): 78 | line += w + ' ' * (space_between + (i < extra_spaces)) 79 | line += cur_line[-1] 80 | res.append(line) 81 | cur_line = [] 82 | cur_len = 0 83 | cur_line.append(word) 84 | cur_len += len(word) 85 | 86 | last_line = ' '.join(cur_line) 87 | last_line += ' ' * (maxWidth - len(last_line)) 88 | res.append(last_line) 89 | 90 | return res 91 | ``` 92 | 93 | [unit test results from previous impl]: 94 | Tested passed: 95 | 96 | Tests failed: 97 | assert fullJustify([], 10) == [] # output: [' '] 98 | assert fullJustify([], 0) == [] # output: [''] 99 | 100 | [reflection on previous impl]: 101 | The implementation failed the test cases where the input list of words is empty. The issue arises because the code does not handle the case where there are no words to process. As a result, it still appends a line with spaces to the result list, even when there are no words. To fix this issue, we should add a condition at the beginning of the function to check if the input list is empty, and return an empty list if it is. This will ensure that the function returns the correct output for empty input lists. 102 | 103 | [improved impl]: 104 | ```python 105 | from typing import * 106 | def fullJustify(words: List[str], maxWidth: int) -> List[str]: 107 | """ 108 | Given an array of words and a width maxWidth, format the text such that each line has exactly maxWidth characters and is fully (left and right) justified. 109 | You should pack your words in a greedy approach; that is, pack as many words as you can in each line. Pad extra spaces `' '` when necessary so that each line has exactly maxWidth characters. 110 | Extra spaces between words should be distributed as evenly as possible. If the number of spaces on a line do not divide evenly between words, the empty slots on the left will be assigned more spaces than the slots on the right. 111 | For the last line of text, it should be left justified and no extra space is inserted between words. 112 | Note: 113 | A word is defined as a character sequence consisting of non-space characters only. 114 | Each word's length is guaranteed to be greater than 0 and not exceed maxWidth. 115 | The input array `words` contains at least one word. 
116 | """ 117 | if not words: 118 | return [] 119 | 120 | res = [] 121 | cur_line = [] 122 | cur_len = 0 123 | 124 | for word in words: 125 | if cur_len + len(word) + len(cur_line) > maxWidth: 126 | if len(cur_line) == 1: 127 | res.append(cur_line[0] + ' ' * (maxWidth - cur_len)) 128 | else: 129 | spaces = maxWidth - cur_len 130 | space_between = spaces // (len(cur_line) - 1) 131 | extra_spaces = spaces % (len(cur_line) - 1) 132 | line = '' 133 | for i, w in enumerate(cur_line[:-1]): 134 | line += w + ' ' * (space_between + (i < extra_spaces)) 135 | line += cur_line[-1] 136 | res.append(line) 137 | cur_line = [] 138 | cur_len = 0 139 | cur_line.append(word) 140 | cur_len += len(word) 141 | 142 | last_line = ' '.join(cur_line) 143 | last_line += ' ' * (maxWidth - len(last_line)) 144 | res.append(last_line) 145 | 146 | return res 147 | ``` 148 | END EXAMPLES 149 | 150 | ''' 151 | PY_SELF_REFLECTION_CHAT_INSTRUCTION = "You are a Python programming assistant. You will be given a function implementation and a series of unit tests. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as a hint when you try again later. Only provide the few sentence description in your answer, not the implementation." 152 | PY_SELF_REFLECTION_CHAT_INSTRUCTION_V2 = "You are a Python programming assistant. You will be given a function implementation and a series of unit test results. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as guidance when you try again later. Only provide the few sentence description in your answer, not the implementation. You will be given a few examples by the user." 153 | PY_SELF_REFLECTION_FEW_SHOT = """Example 1: 154 | [function impl]: 155 | ```python 156 | def longest_subarray_with_sum_limit(nums: List[int], target: int) -> List[int]: 157 | n = len(nums) 158 | left, right = 0, 0 159 | max_length = 0 160 | current_sum = 0 161 | result = [] 162 | while right < n: 163 | current_sum += nums[right] 164 | while current_sum > target: 165 | current_sum -= nums[left] 166 | left += 1 167 | if right - left + 1 >= max_length: 168 | max_length = right - left + 1 169 | result = nums[left:right+1] 170 | right += 1 171 | return result 172 | ``` 173 | [unit test results]: 174 | Tests passing: 175 | assert longest_subarray_with_sum_limit([1, 2, 3, 4, 5], 8) == [1, 2, 3] 176 | assert longest_subarray_with_sum_limit([1, 2, 3, 4, 5], 15) == [1, 2, 3, 4, 5] 177 | assert longest_subarray_with_sum_limit([1, -1, 2, -2, 3, -3], 2) == [1, -1, 2, -2, 3] 178 | assert longest_subarray_with_sum_limit([], 10) == [] 179 | assert longest_subarray_with_sum_limit([], 0) == [] 180 | assert longest_subarray_with_sum_limit([], -5) == [] 181 | Tests failing: 182 | assert longest_subarray_with_sum_limit([5, 6, 7, 8, 9], 4) == [] # output: [5] 183 | [self-reflection]: 184 | The implementation failed the where no subarray fulfills the condition. The issue in the implementation is due to the use of >= instead of > in the condition to update the result. Because of this, it returns a subarray even when the sum is greater than the target, as it still updates the result when the current subarray length is equal to the previous longest subarray length. To overcome this error, we should change the condition to only update the result when the current subarray length is strictly greater than the previous longest subarray length. 
This can be done by replacing >= with > in the condition. 185 | 186 | Example 2: 187 | [function impl]: 188 | ```python 189 | def longest_subarray_with_sum_limit(nums: List[int], target: int) -> List[int]: 190 | n = len(nums) 191 | left, right = 0, 0 192 | max_length = 0 193 | current_sum = 0 194 | result = [] 195 | while current_sum + nums[right] <= target: 196 | current_sum += nums[right] 197 | right += 1 198 | while right < n: 199 | current_sum += nums[right] 200 | while current_sum > target: 201 | current_sum -= nums[left] 202 | left += 1 203 | if right - left + 1 > max_length: 204 | max_length = right - left + 1 205 | result = nums[left:right+1] 206 | right += 1 207 | return result 208 | ``` 209 | [unit test results]: 210 | Tests passing: 211 | assert longest_subarray_with_sum_limit([], 10) == [] 212 | assert longest_subarray_with_sum_limit([], 0) == [] 213 | assert longest_subarray_with_sum_limit([], -5) == [] 214 | Tests failing: 215 | assert longest_subarray_with_sum_limit([1, 2, 3, 4, 5], 8) == [1, 2, 3] # output: list index out of range 216 | assert longest_subarray_with_sum_limit([1, 2, 3, 4, 5], 15) == [1, 2, 3, 4, 5] # output: list index out of range 217 | assert longest_subarray_with_sum_limit([5, 6, 7, 8, 9], 4) == [] # output: list index out of range 218 | assert longest_subarray_with_sum_limit([1, -1, 2, -2, 3, -3], 2) == [1, -1, 2, -2, 3] # output: list index out of range 219 | [self-reflection]: 220 | The implementation failed 4 out of the 7 test cases due to an IndexError. The issue stems from the while loop while current_sum + nums[right] <= target:, which directly accesses nums[right] without checking if right is within the bounds of the list. This results in a runtime error when right goes beyond the list length. To overcome this error, we need to add a bounds check for the right variable in the mentioned while loop. We can modify the loop condition to while right < len(nums) and current_sum + nums[right] <= target:. This change will ensure that we only access elements within the bounds of the list, thus avoiding the IndexError. 221 | END OF EXAMPLES 222 | """ 223 | 224 | PY_TEST_GENERATION_FEW_SHOT = """Examples: 225 | func signature: 226 | def add3Numbers(x, y, z): 227 | \"\"\" Add three numbers together. 228 | This function takes three numbers as input and returns the sum of the three numbers. 229 | \"\"\" 230 | unit tests: 231 | assert add3Numbers(1, 2, 3) == 6 232 | assert add3Numbers(-1, 2, 3) == 4 233 | assert add3Numbers(1, -2, 3) == 2 234 | assert add3Numbers(1, 2, -3) == 0 235 | assert add3Numbers(-3, -2, -1) == -6 236 | assert add3Numbers(0, 0, 0) == 0 237 | """ 238 | 239 | PY_TEST_GENERATION_COMPLETION_INSTRUCTION = f"""You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring. 
240 | 241 | {PY_TEST_GENERATION_FEW_SHOT}""" 242 | 243 | PY_TEST_GENERATION_CHAT_INSTRUCTION = """You are an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring.""" 244 | 245 | 246 | class PyGenerator(Generator): 247 | def self_reflection(self, func: str, feedback: str, model: ModelBase) -> str: 248 | return generic_generate_self_reflection( 249 | func=func, 250 | feedback=feedback, 251 | model=model, 252 | self_reflection_chat_instruction=PY_SELF_REFLECTION_CHAT_INSTRUCTION, 253 | self_reflection_completion_instruction=PY_SELF_REFLECTION_COMPLETION_INSTRUCTION, 254 | add_code_block=lambda x: add_code_block(x, "python"), 255 | self_reflection_few_shot=PY_SELF_REFLECTION_FEW_SHOT 256 | ) 257 | 258 | def func_impl( 259 | self, 260 | func_sig: str, 261 | model: ModelBase, 262 | strategy: str, 263 | prev_func_impl: Optional[str] = None, 264 | feedback: Optional[str] = None, 265 | self_reflection: Optional[str] = None, 266 | num_comps: int = 1, 267 | temperature: float = 0.0, 268 | ) -> Union[str, List[str]]: 269 | return generic_generate_func_impl( 270 | func_sig=func_sig, 271 | model=model, 272 | strategy=strategy, 273 | prev_func_impl=prev_func_impl, 274 | feedback=feedback, 275 | self_reflection=self_reflection, 276 | num_comps=num_comps, 277 | temperature=temperature, 278 | reflexion_chat_instruction=PY_REFLEXION_CHAT_INSTRUCTION, 279 | reflexion_few_shot=PY_REFLEXION_FEW_SHOT_ADD, 280 | simple_chat_instruction=PY_SIMPLE_CHAT_INSTRUCTION, 281 | reflexion_completion_instruction=PY_REFLEXION_COMPLETION_INSTRUCTION, 282 | simple_completion_instruction=PY_SIMPLE_COMPLETION_INSTRUCTION, 283 | code_block_instruction=USE_PYTHON_CODEBLOCK_INSTRUCTION, 284 | parse_code_block=lambda x: parse_code_block(x, "python"), 285 | add_code_block=lambda x: add_code_block(x, "python"), 286 | ) 287 | 288 | def internal_tests(self, func_sig: str, model: ModelBase, max_num_tests: int = 5) -> List[str]: 289 | def parse_tests(tests: str) -> List[str]: 290 | return [test.strip() for test in tests.splitlines() if "assert" in test] 291 | """ 292 | Generates tests for a function. 
293 | """ 294 | return generic_generate_internal_tests( 295 | func_sig=func_sig, 296 | model=model, 297 | max_num_tests=max_num_tests, 298 | test_generation_few_shot=PY_TEST_GENERATION_FEW_SHOT, 299 | test_generation_chat_instruction=PY_TEST_GENERATION_CHAT_INSTRUCTION, 300 | test_generation_completion_instruction=PY_TEST_GENERATION_COMPLETION_INSTRUCTION, 301 | parse_tests=parse_tests, 302 | is_syntax_valid=py_is_syntax_valid, 303 | ) 304 | 305 | 306 | DUMMY_FUNC_SIG = "def func():" 307 | DUMMY_FUNC_CALL = "func()" 308 | 309 | 310 | def handle_first_line_indent(func_body: str) -> str: 311 | if func_body.startswith(" "): 312 | return func_body 313 | split = func_body.splitlines() 314 | return f" {split[0]}\n" + "\n".join(split[1:]) 315 | 316 | 317 | def handle_entire_body_indent(func_body: str) -> str: 318 | split = func_body.splitlines() 319 | res = "\n".join([" " + line for line in split]) 320 | return res 321 | 322 | 323 | def fix_turbo_response(func_body: str) -> str: 324 | return fix_markdown(remove_unindented_signatures(func_body)) 325 | 326 | 327 | def fix_markdown(func_body: str) -> str: 328 | return re.sub("`{3}", "", func_body) 329 | 330 | 331 | def remove_unindented_signatures(code: str) -> str: 332 | regex = r"^def\s+\w+\s*\(" 333 | 334 | before_signature = [] 335 | after_signature = [] 336 | signature_found = False 337 | 338 | for line in code.split("\n"): 339 | if re.match(regex, line): 340 | signature_found = True 341 | continue 342 | 343 | if signature_found: 344 | after_signature.append(line) 345 | else: 346 | if not line.startswith(" ") and line.strip(): 347 | line = " " + line 348 | before_signature.append(line) 349 | 350 | return "\n".join(before_signature + after_signature) 351 | 352 | 353 | def py_fix_indentation(func_body: str) -> str: 354 | func_body = fix_turbo_response(func_body) 355 | """ 356 | 3 cases: 357 | 1. good syntax 358 | 2. first line not good 359 | 3. entire body not good 360 | """ 361 | def parse_indent_rec(f_body: str, cur_state: int) -> str: 362 | f_body = fix_markdown(f_body) 363 | if cur_state > 1: 364 | return f_body 365 | code = f'{DUMMY_FUNC_SIG}\n{f_body}\n{DUMMY_FUNC_CALL}' 366 | try: 367 | exec(code) 368 | return f_body 369 | except (IndentationError, SyntaxError): 370 | p_func = handle_first_line_indent if cur_state == 0 else handle_entire_body_indent 371 | return parse_indent_rec(p_func(func_body), cur_state + 1) 372 | except Exception: 373 | return f_body 374 | return parse_indent_rec(func_body, 0) 375 | 376 | 377 | def py_is_syntax_valid(code: str) -> bool: 378 | try: 379 | ast.parse(code) 380 | return True 381 | except Exception: 382 | return False 383 | -------------------------------------------------------------------------------- /src/generators/rs_generate.py: -------------------------------------------------------------------------------- 1 | from generators.model import ModelBase 2 | from .generator_types import Generator 3 | from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection 4 | from .parse import parse_code_block, add_code_block 5 | 6 | from typing import List, Optional, Union 7 | 8 | RS_SIMPLE_COMPLETION_INSTRUCTION = "// Write the body of this function only." 9 | RS_REFLEXION_COMPLETION_INSTRUCTION = "You are a Rust writing assistant. You will be given your past function implementation, a series of unit tests, and a hint to change the implementation appropriately. 
Write your full implementation (restate the function signature).\n\n-----" 10 | RS_SELF_REFLECTION_COMPLETION_INSTRUCTION = "You are a Rust writing assistant. You will be given a function implementation and a series of unit tests. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as a hint when you try again later. Only provide the few sentence description in your answer, not the implementation.\n\n-----" 11 | USE_RUST_CODEBLOCK_INSTRUCTION = "Use a Rust code block to write your response. For example:\n```rust\nfn main() {\n println!(\"Hello\");\n}\n```" 12 | 13 | RS_SIMPLE_CHAT_INSTRUCTION = "You are an AI that only responds with Rust code, NOT ENGLISH. You will be given a function signature and its docstring by the user. Write your full implementation (restate the function signature)." 14 | RS_REFLEXION_CHAT_INSTRUCTION = "You are an AI Rust assistant. You will be given your past function implementation, a series of unit tests, and a hint to change the implementation appropriately. Write your full implementation (restate the function signature)." 15 | RS_SELF_REFLECTION_CHAT_INSTRUCTION = "You are a Rust programming assistant. You will be given a function implementation and a series of unit tests. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as a hint when you try again later. Only provide the few sentence description in your answer, not the implementation." 16 | 17 | RS_REFLEXION_COMPLETION_INSTRUCTION = "You are a Rust programming assistant. You will be given your past function implementation, a series of unit tests, and a hint to change the implementation appropriately. Apply the changes below by writing the body of this function only.\n\n-----" 18 | RS_SELF_REFLECTION_COMPLETION_INSTRUCTION = "You are a Rust programming assistant. You will be given a function implementation and a series of unit tests. Your goal is to write a few sentences to explain why your implementation is wrong as indicated by the tests. You will need this as a hint when you try again later. Only provide the few sentence description in your answer, not the implementation.\n\n-----" 19 | 20 | RS_REFLEXION_FEW_SHOT_ADD = '''Example 1: 21 | [previous impl]: 22 | ```rust 23 | fn add(a: i32, b: i32) -> i32 { 24 | // Given integers a and b, return the total value of a and b. 25 | a - b 26 | } 27 | ``` 28 | 29 | [unit test results from previous impl]: 30 | Tested passed: 31 | 32 | Tests failed: 33 | assert_eq!(add(1, 2), 3); // output: -1 34 | assert_eq!(add(1, 2), 4); // output: -1 35 | 36 | [reflection on previous impl]: 37 | The implementation failed the test cases where the input integers are 1 and 2. The issue arises because the code does not add the two integers together, but instead subtracts the second integer from the first. To fix this issue, we should change the operator from `-` to `+` in the return statement. This will ensure that the function returns the correct output for the given input. 38 | 39 | [improved impl]: 40 | ```rust 41 | fn add(a: i32, b: i32) -> i32 { 42 | // Given integers a and b, return the total value of a and b. 43 | a + b 44 | } 45 | ``` 46 | 47 | END EXAMPLES 48 | ''' 49 | 50 | RS_TEST_GENERATION_FEW_SHOT = """For example: 51 | 52 | func signature: 53 | /// Add three numbers together. 54 | /// This function takes three numbers as input and returns the sum of the three numbers. 
55 | fn add3Numbers(x: i32, y: i32, z: i32) -> i32 { 56 | 57 | unit tests: 58 | assert_eq!(add3Numbers(1, 2, 3), 6); 59 | assert_eq!(add3Numbers(-1, 2, 3), 4); 60 | assert_eq!(add3Numbers(1, -2, 3), 2); 61 | assert_eq!(add3Numbers(1, 2, -3), 0); 62 | assert_eq!(add3Numbers(-3, -2, -1), -6); 63 | assert_eq!(add3Numbers(0, 0, 0), 0); 64 | """ 65 | 66 | RS_SELF_REFLECTION_FEW_SHOT = '''Example 1: 67 | [function impl]: 68 | ```rust 69 | pub fn group_anagrams(strs: Vec) -> Vec> { 70 | // Given an array of strings strs, group the anagrams together. You can return the answer in any order. 71 | // An Anagram is a word or phrase formed by rearranging the letters of a different word or phrase, typically using all the original letters exactly once. 72 | use std::collections::HashMap; 73 | let mut map: HashMap<[u8;26], Vec> = HashMap::with_capacity(strs.len()); 74 | let offset = 'a' as usize; 75 | 76 | for str in strs.into_iter() { 77 | let mut chars: [u8; 26] = [0; 26]; 78 | 79 | for char in str.chars() { 80 | chars[char.to_ascii_lowercase() as usize - offset] += 1; 81 | } 82 | 83 | // Flaw: using str.len() instead of chars in the hashmap key 84 | map.entry(str.len()) 85 | .and_modify(|v| v.push(str.clone())) 86 | .or_insert(vec![str]); 87 | } 88 | 89 | let mut arr: Vec> = Vec::new(); 90 | for v in map.into_values() { 91 | arr.push(v); 92 | } 93 | arr 94 | } 95 | ``` 96 | 97 | [unit test results]: 98 | Tested passed: 99 | assert_eq!(func(vec![""]), vec![vec![""]]); 100 | assert_eq!(func(vec!["a"]), vec![vec!["a"]]); 101 | 102 | Tests failed: 103 | assert_eq!(func(vec!["eat", "tea", "tan", "ate", "nat", "bat"]), vec![vec!["bat"], vec!["nat", "tan"], vec!["ate", "eat", "tea"]]); # output: [["bat", "tan", "nat"], ["eat", "tea", "ate"]] 104 | 105 | [self-reflection]: 106 | The implementation failed to group the anagrams together correctly. Instead, it grouped words by their length, which is not the intended behavior. The issue lies in using the length of the input strings (str.len()) as the key for the hashmap, rather than the count of each character in the strings (chars). To overcome this error, I should change the hashmap key to the character count array (chars). This will ensure that words with the same character counts (anagrams) are grouped together, which is the desired output. Next time I approach the problem, I will make sure to use the correct hashmap key to group the anagrams. 107 | 108 | END EXAMPLES 109 | 110 | ''' 111 | RS_TEST_GENERATION_COMPLETION_INSTRUCTION = f"""You are a Rust programming assistant, an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring. 112 | 113 | {RS_TEST_GENERATION_FEW_SHOT}""" 114 | 115 | RS_TEST_GENERATION_CHAT_INSTRUCTION = """You are a Rust programming assistant, an AI coding assistant that can write unique, diverse, and intuitive unit tests for functions given the signature and docstring.""" 116 | 117 | 118 | def dump_tests(tests: List[str]) -> str: 119 | """ 120 | Dumps the tests to a string. 121 | """ 122 | return "\n".join(tests) 123 | 124 | 125 | def parse_tests(tests: str) -> List[str]: 126 | """ 127 | Parses the tests from a string. 128 | """ 129 | return [test.strip() for test in tests.splitlines() if "assert" in test] 130 | 131 | # TODO: type-check generated unit tests? 
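Acting on the TODO above, a hypothetical type/syntax gate for generated Rust tests could shell out to `rustc` (a sketch only, not part of the repo: it assumes `rustc` is on `PATH`, and a real version would need to compile the asserts together with the function under test rather than alone). Something in this spirit could replace the `is_syntax_valid=(lambda x: True)` placeholder used in `internal_tests` below:

```python
import subprocess
import tempfile

def rs_is_test_syntax_valid(test: str) -> bool:
    """Best-effort check that a generated assert line at least compiles."""
    # Wrap the bare assert in a main() so rustc sees a complete crate.
    src = "fn main() {\n    " + test + "\n}\n"
    with tempfile.NamedTemporaryFile("w", suffix=".rs", delete=False) as f:
        f.write(src)
        path = f.name
    # --emit=metadata type-checks without generating a binary.
    result = subprocess.run(
        ["rustc", "--emit=metadata", "--crate-name", "testcheck", path],
        capture_output=True, timeout=30,
    )
    return result.returncode == 0
```
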
132 | 133 | 134 | class RsGenerator(Generator): 135 | def self_reflection(self, func: str, feedback: str, model: ModelBase) -> str: 136 | return generic_generate_self_reflection( 137 | func=func, 138 | feedback=feedback, 139 | model=model, 140 | self_reflection_chat_instruction=RS_SELF_REFLECTION_CHAT_INSTRUCTION, 141 | self_reflection_completion_instruction=RS_SELF_REFLECTION_COMPLETION_INSTRUCTION, 142 | add_code_block=lambda x: add_code_block(x, "rust"), 143 | self_reflection_few_shot=RS_SELF_REFLECTION_FEW_SHOT, 144 | ) 145 | 146 | def func_impl( 147 | self, 148 | func_sig: str, 149 | model: ModelBase, 150 | strategy: str, 151 | prev_func_impl: Optional[str] = None, 152 | feedback: Optional[str] = None, 153 | self_reflection: Optional[str] = None, 154 | num_comps: int = 1, 155 | temperature: float = 0.0, 156 | ) -> Union[str, List[str]]: 157 | return generic_generate_func_impl( 158 | func_sig=func_sig, 159 | model=model, 160 | strategy=strategy, 161 | prev_func_impl=prev_func_impl, 162 | feedback=feedback, 163 | self_reflection=self_reflection, 164 | num_comps=num_comps, 165 | temperature=temperature, 166 | reflexion_chat_instruction=RS_REFLEXION_CHAT_INSTRUCTION, 167 | simple_chat_instruction=RS_SIMPLE_CHAT_INSTRUCTION, 168 | reflexion_completion_instruction=RS_REFLEXION_COMPLETION_INSTRUCTION, 169 | simple_completion_instruction=RS_SIMPLE_COMPLETION_INSTRUCTION, 170 | reflexion_few_shot=RS_REFLEXION_FEW_SHOT_ADD, 171 | parse_code_block=lambda x: parse_code_block(x, "rust"), 172 | add_code_block=lambda x: add_code_block(x, "rust"), 173 | ) 174 | 175 | def internal_tests( 176 | self, 177 | func_sig: str, 178 | model: ModelBase, 179 | max_num_tests: int = 5 180 | ) -> List[str]: 181 | def parse_tests(tests: str) -> List[str]: 182 | return [test + ";" for test in tests.split(";")] 183 | """ 184 | Generates tests for a function. 185 | """ 186 | return generic_generate_internal_tests( 187 | func_sig=func_sig, 188 | model=model, 189 | max_num_tests=max_num_tests, 190 | test_generation_few_shot=RS_TEST_GENERATION_FEW_SHOT, 191 | test_generation_chat_instruction=RS_TEST_GENERATION_CHAT_INSTRUCTION, 192 | test_generation_completion_instruction=RS_TEST_GENERATION_COMPLETION_INSTRUCTION, 193 | parse_tests=parse_tests, 194 | is_syntax_valid=(lambda x: True) # TODO: for now. typecheck maybe? 195 | ) 196 | -------------------------------------------------------------------------------- /src/generators/verilog_generate.py: -------------------------------------------------------------------------------- 1 | from generators.model import ModelBase, message_to_str 2 | from .generator_types import Generator 3 | from .generator_utils import generic_generate_func_impl, generic_generate_internal_tests, generic_generate_self_reflection, func_impl_prepare_prompt, self_reflection_prepare_prompt 4 | 5 | from typing import Optional, List, Union 6 | import ast 7 | import re 8 | from .parse import parse_code_block, add_code_block 9 | 10 | 11 | USE_VERILOG_CODEBLOCK_INSTRUCTION = "Implement the Verilog module based on the following description. Assume that signals are positive clock/clk edge triggered unless otherwise stated." 12 | 13 | VG_SIMPLE_COMPLETION_INSTRUCTION = "# Write the body of this function only." 14 | VG_REFLEXION_COMPLETION_INSTRUCTION = "" 15 | VG_SELF_REFLECTION_COMPLETION_INSTRUCTION = "" 16 | 17 | VG_SIMPLE_CHAT_INSTRUCTION = "You only complete chats with syntax correct Verilog code. End the Verilog module code completion with 'endmodule'. Do not include module, input and output definitions." 
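`VG_SIMPLE_FEW_SHOT` below stores the few-shot turns as plain `(role, content)` pairs. The code that folds them into a chat prompt sits outside this excerpt, but the conversion amounts to a helper like this hypothetical one (`Message` is the same dataclass used throughout `generators/model.py`):

```python
from generators.model import Message

def few_shot_to_messages(few_shot) -> list:
    """Expand (role, content) pairs into chat Messages, ready to splice
    between the system instruction and the actual problem prompt."""
    return [Message(role=role, content=content) for role, content in few_shot]

# e.g.:
# messages = [Message(role="system", content=VG_SIMPLE_CHAT_INSTRUCTION),
#             *few_shot_to_messages(VG_SIMPLE_FEW_SHOT),
#             Message(role="user", content=problem_description)]
```
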
18 | VG_SIMPLE_FEW_SHOT = ( 19 | ('user', """This Verilog module is a simple multiplexer. It takes two inputs, a and b, as well as a selector input, sel. It then outputs the value of either a or b, depending on the value of sel. If sel is 1, the output will be b, and if sel is 0, the output will be a. This module is useful for selecting between two different values, depending on the value of the selector. Implement the above description in the following module. 20 | 21 | module top_module ( 22 | input a, 23 | input b, 24 | input sel, 25 | output out 26 | ); 27 | """), 28 | ('assistant', """ 29 | assign out = sel ? b : a; 30 | 31 | endmodule 32 | """), 33 | ('user', """ 34 | Build a circuit that always outputs a LOW. 35 | 36 | module top_module( 37 | output zero); 38 | """), 39 | ('assistant', """ 40 | assign zero = 0; 41 | endmodule 42 | """) 43 | ) 44 | 45 | VG_REFLEXION_CHAT_INSTRUCTION = "You only complete chats with syntax correct Verilog code. End the Verilog module code completion with 'endmodule'. Do not include module, input and output definitions. You will be given your past implementation, compiler logs, and test results in order to change the implementation appropriately." #+ "If you cannot give a propriate solution right away, you can add more $display and $dumpfile to gather necessary information to achieve the correct solution." 46 | VG_REFLEXION_FEW_SHOT_ADD = ''' 47 | Example 1: 48 | [problem description]: 49 | Create a 2-1 multiplexer. When sel=0, choose a. When sel=1, choose b. 50 | 51 | [previous implementation]: 52 | ``` 53 | assign out = a & b & sel; 54 | 55 | endmodule 56 | ``` 57 | 58 | 59 | [test results from previous implementation]: 60 | VCD info: dumpfile wave.vcd opened for output. 61 | test.sv:49: $finish called at 570 (1ps) 62 | Hint: Output 'out' has 114 mismatches. First mismatch occurred at time 5. 63 | Hint: Total mismatched samples is 114 out of 114 samples 64 | 65 | Simulation finished at 570 ps 66 | Mismatches: 114 in 114 samples 67 | 68 | [reflection on previous impl]: 69 | The previous implementation used a bitwise AND operation between 'a', 'b', and 'sel'. This results in an incorrect behavior since the output 'out' will never equal 'a' or 'b' directly. Instead, it'll produce a value based on individual bits of 'a', 'b', and 'sel'. The multiplexer is supposed to select either 'a' or 'b' based on 'sel', not perform a bitwise operation among them. 70 | 71 | [improved implementation]: 72 | ``` 73 | assign out = sel ? b : a; 74 | 75 | endmodule 76 | ``` 77 | END OF EXAMPLES 78 | ''' 79 | VG_REFLEXION_FEW_SHOT = '''''' 80 | 81 | VG_SELF_REFLECTION_CHAT_INSTRUCTION = "Your goal is to find out the error in the verilog code. Figure out how to fix the code error and tell me how to fix this error as detail as possible step by step." 82 | VG_SELF_REFLECTION_FEW_SHOT = """ 83 | Example 1: 84 | [problem description]: 85 | Build a circuit that always outputs a LOW. 86 | 87 | 88 | [module implementation]: 89 | ``` 90 | module top_module( 91 | output zero); 92 | 93 | endmodule 94 | ``` 95 | 96 | 97 | [simulation results]: 98 | VCD info: dumpfile wave.vcd opened for output. 99 | zero2.sv:37: $finish called at 102 (1ps) 100 | Hint: Output 'zero' has 20 mismatches. First mismatch occurred at time 5. 101 | Hint: Total mismatched samples is 20 out of 20 samples 102 | Simulation finished at 102 ps\nMismatches: 20 in 20 samples. 
103 | 104 | 105 | [self-reflection]: 106 | The implementation is wrong because: 107 | Based on the problem description, you are tasked with building a circuit that always outputs a LOW signal. However, the provided Verilog module top_module does not meet this requirement, as indicated by the simulation results. 108 | The simulation results show that there are mismatches between the expected behavior (always output LOW) and the actual behavior of the module, resulting in a total of 20 mismatches in 20 samples. This indicates that the module is not behaving as expected. 109 | To analyze why the implementation is wrong step by step: 110 | Module Declaration: The Verilog module top_module is declared with an output port named zero. This suggests that the module should produce a signal called zero. 111 | Expected Behavior: The problem description specifies that the circuit should always output a LOW signal. This means that the zero output should be consistently LOW, regardless of any input or clock. 112 | Simulation Results: The simulation results show that there are mismatches between the expected LOW output and the actual behavior of the module. The mismatches occur at various time points, and they indicate that the zero signal is not consistently LOW. 113 | To identify why the implementation is wrong, you should examine the code of the top_module in more detail and look for any factors that might cause it to produce unexpected output. Common issues that could lead to incorrect behavior include uninitialized signals, unintentional clocking, or incorrect logic in the module. 114 | Without the full code of the module and additional context, it's challenging to pinpoint the exact issue. However, you should carefully review the module's logic and ensure that it consistently produces a LOW output as required by the problem description. 115 | 116 | 117 | 118 | Example 2: 119 | [problem description]: 120 | The module has 3 inputs. y input is a 4 bits input, w is a 1 bit input, Y2 is a 1 bit reg output.\nIt is an always_comb block, not an always_ff block. 121 | So there is no state transitions and clock. 122 | Every time a new input comes, the output will be changed immediately. 123 | It is a case statement, the left hand side is a combination of inputs(y, w). 124 | The right hand side is the output(Y2).\nIf the input is 4'h0, the output will be 0. 125 | If the input is 4'h1, the output will be 0. 126 | If the input is 4'h2, the output will be 1. 127 | If the input is 4'h3, the output will be 1. 128 | If the input is 4'h4, the output will be 0. 129 | If the input is 4'h5, the output will be 1. 130 | If the input is 4'h6, the output will be 0. 131 | If the input is 4'h7, the output will be 0. 132 | If the input is 4'h8, the output will be 0. 133 | If the input is 4'h9, the output will be 1. 134 | If the input is 4'ha, the output will be 1. 135 | If the input is 4'hb, the output will be 1. 
136 | 137 | 138 | [module implementation]: 139 | ``` 140 | module top_module( 141 | input [3:1] y, 142 | input w, 143 | output reg Y2); 144 | 145 | always_comb begin 146 | case(y) 147 | 4'h0: Y2 = 1; 148 | 4'h1: Y2 = 1; 149 | 4'h2: Y2 = 0; 150 | 4'h3: Y2 = 0; 151 | 4'h4: Y2 = 1; 152 | 4'h5: Y2 = 0; 153 | 4'h6: Y2 = 1; 154 | 4'h7: Y2 = 1; 155 | 4'h8: Y2 = 1; 156 | 4'h9: Y2 = 0; 157 | 4'ha: Y2 = 0; 158 | 4'hb: Y2 = 0; 159 | default: Y2 = 0; // It's a good practice to always have a default case 160 | endcase 161 | end 162 | 163 | endmodule 164 | ``` 165 | 166 | 167 | [simulation results]: 168 | VCD info: dumpfile wave.vcd opened for output. 169 | test.sv:41: $finish called at 501 (1ps) 170 | Hint: Output 'Y2' has 36 mismatches. First mismatch occurred at time 15. 171 | Hint: Total mismatched samples is 36 out of 100 samples 172 | 173 | 174 | [self-reflection]: 175 | The provided implementation has incorrect mappings between the 'y' input values and the 'Y2' output values. This is evident as there's a discrepancy between the expected and given output values for each 'y' input case in the always_comb block. Consequently, the simulation reports 36 mismatches out of 100 samples, indicating that the current logic doesn't align with the desired behavior. 176 | 177 | 178 | END OF EXAMPLES 179 | """ 180 | 181 | VG_TEST_GENERATION_FEW_SHOT = """""" 182 | VG_TEST_GENERATION_COMPLETION_INSTRUCTION = f"""""" 183 | VG_TEST_GENERATION_CHAT_INSTRUCTION = """""" 184 | 185 | 186 | class VerilogGenerator(Generator): 187 | def self_reflection(self, func: str, feedback: str, model: ModelBase) -> str: 188 | return generic_generate_self_reflection( 189 | func=func, 190 | feedback=feedback, 191 | model=model, 192 | self_reflection_chat_instruction=VG_SELF_REFLECTION_CHAT_INSTRUCTION, 193 | self_reflection_completion_instruction=VG_SELF_REFLECTION_COMPLETION_INSTRUCTION, 194 | add_code_block=lambda x: add_code_block(x, "verilog"), 195 | self_reflection_few_shot=VG_SELF_REFLECTION_FEW_SHOT 196 | ) 197 | 198 | def prepare_prompt( 199 | self, 200 | func_sig: dict, 201 | model: ModelBase, 202 | strategy: str, 203 | prev_func_impl: Optional[str] = None, 204 | feedback: Optional[str] = None, 205 | self_reflection: Optional[str] = None, 206 | num_comps: int = 1, 207 | temperature: float = 0.0, 208 | ) -> Union[str, List[str]]: 209 | 210 | if "react" in strategy: 211 | return self_reflection_prepare_prompt( 212 | func=func_sig, 213 | feedback=feedback, 214 | model=model, 215 | self_reflection_chat_instruction=VG_SELF_REFLECTION_CHAT_INSTRUCTION, 216 | self_reflection_completion_instruction=VG_SELF_REFLECTION_COMPLETION_INSTRUCTION, 217 | add_code_block=lambda x: add_code_block(x, "verilog"), 218 | self_reflection_few_shot=VG_SELF_REFLECTION_FEW_SHOT 219 | ) 220 | else: 221 | return func_impl_prepare_prompt( 222 | func_sig=func_sig, 223 | model=model, 224 | strategy=strategy, 225 | prev_func_impl=prev_func_impl, 226 | feedback=feedback, 227 | self_reflection=self_reflection, 228 | num_comps=num_comps, 229 | temperature=temperature, 230 | reflexion_chat_instruction=VG_REFLEXION_CHAT_INSTRUCTION, 231 | reflexion_few_shot=VG_REFLEXION_FEW_SHOT_ADD, 232 | simple_chat_instruction=VG_SIMPLE_CHAT_INSTRUCTION, 233 | reflexion_completion_instruction=VG_REFLEXION_COMPLETION_INSTRUCTION, 234 | simple_completion_instruction=VG_SIMPLE_COMPLETION_INSTRUCTION, 235 | code_block_instruction=USE_VERILOG_CODEBLOCK_INSTRUCTION, 236 | parse_code_block=lambda x: parse_code_block(x, "verilog"), 237 | add_code_block=lambda x: add_code_block(x, 
"verilog"), 238 | lang="verilog" 239 | ) 240 | 241 | def func_impl( 242 | self, 243 | func_sig: dict, 244 | model: ModelBase, 245 | strategy: str, 246 | prev_func_impl: Optional[str] = None, 247 | feedback: Optional[str] = None, 248 | self_reflection: Optional[str] = None, 249 | num_comps: int = 1, 250 | temperature: float = 0.0, 251 | ) -> Union[str, List[str]]: 252 | return generic_generate_func_impl( 253 | func_sig=func_sig, 254 | model=model, 255 | strategy=strategy, 256 | prev_func_impl=prev_func_impl, 257 | feedback=feedback, 258 | self_reflection=self_reflection, 259 | num_comps=num_comps, 260 | temperature=temperature, 261 | reflexion_chat_instruction=VG_REFLEXION_CHAT_INSTRUCTION, 262 | reflexion_few_shot=VG_REFLEXION_FEW_SHOT_ADD, 263 | simple_chat_instruction=VG_SIMPLE_CHAT_INSTRUCTION, 264 | reflexion_completion_instruction=VG_REFLEXION_COMPLETION_INSTRUCTION, 265 | simple_completion_instruction=VG_SIMPLE_COMPLETION_INSTRUCTION, 266 | code_block_instruction=USE_VERILOG_CODEBLOCK_INSTRUCTION, 267 | parse_code_block=lambda x: parse_code_block(x, "verilog"), 268 | add_code_block=lambda x: add_code_block(x, "verilog"), 269 | lang="verilog" 270 | ) 271 | 272 | def internal_tests(self, func_sig: dict, model: ModelBase, max_num_tests: int = 5) -> List[str]: 273 | def parse_tests(tests: str) -> List[str]: 274 | return [test.strip() for test in tests.splitlines() if "assert" in test] 275 | """ 276 | Generates tests for a function. 277 | """ 278 | return func_sig['test'] 279 | 280 | return generic_generate_internal_tests( 281 | func_sig=func_sig, 282 | model=model, 283 | max_num_tests=max_num_tests, 284 | test_generation_few_shot=VG_TEST_GENERATION_FEW_SHOT, 285 | test_generation_chat_instruction=VG_TEST_GENERATION_CHAT_INSTRUCTION, 286 | test_generation_completion_instruction=VG_TEST_GENERATION_COMPLETION_INSTRUCTION, 287 | parse_tests=parse_tests, 288 | is_syntax_valid=vg_is_syntax_valid, 289 | ) 290 | 291 | 292 | DUMMY_FUNC_SIG = "def func():" 293 | DUMMY_FUNC_CALL = "func()" 294 | 295 | 296 | def handle_first_line_indent(func_body: str) -> str: 297 | if func_body.startswith(" "): 298 | return func_body 299 | split = func_body.splitlines() 300 | return f" {split[0]}\n" + "\n".join(split[1:]) 301 | 302 | 303 | def handle_entire_body_indent(func_body: str) -> str: 304 | split = func_body.splitlines() 305 | res = "\n".join([" " + line for line in split]) 306 | return res 307 | 308 | 309 | def fix_turbo_response(func_body: str) -> str: 310 | return fix_markdown(remove_unindented_signatures(func_body)) 311 | 312 | 313 | def fix_markdown(func_body: str) -> str: 314 | return re.sub("`{3}", "", func_body) 315 | 316 | 317 | def remove_unindented_signatures(code: str) -> str: 318 | regex = r"^def\s+\w+\s*\(" 319 | 320 | before_signature = [] 321 | after_signature = [] 322 | signature_found = False 323 | 324 | for line in code.split("\n"): 325 | if re.match(regex, line): 326 | signature_found = True 327 | continue 328 | 329 | if signature_found: 330 | after_signature.append(line) 331 | else: 332 | if not line.startswith(" ") and line.strip(): 333 | line = " " + line 334 | before_signature.append(line) 335 | 336 | return "\n".join(before_signature + after_signature) 337 | 338 | 339 | def py_fix_indentation(func_body: str) -> str: 340 | func_body = fix_turbo_response(func_body) 341 | """ 342 | 3 cases: 343 | 1. good syntax 344 | 2. first line not good 345 | 3. 
entire body not good
346 |     """
347 |     def parse_indent_rec(f_body: str, cur_state: int) -> str:
348 |         f_body = fix_markdown(f_body)
349 |         if cur_state > 1:
350 |             return f_body
351 |         code = f'{DUMMY_FUNC_SIG}\n{f_body}\n{DUMMY_FUNC_CALL}'
352 |         try:
353 |             exec(code)
354 |             return f_body
355 |         except (IndentationError, SyntaxError):
356 |             p_func = handle_first_line_indent if cur_state == 0 else handle_entire_body_indent
357 |             return parse_indent_rec(p_func(func_body), cur_state + 1)
358 |         except Exception:
359 |             return f_body
360 |     return parse_indent_rec(func_body, 0)
361 | 
362 | 
363 | def vg_is_syntax_valid(code: str, timeout: int = 20) -> bool:
364 |     try:
365 |         import subprocess
366 |         import tempfile
367 |         from threading import Timer
368 |         # Write the candidate module to a temporary .sv file so iverilog has a real path to compile.
369 |         with tempfile.NamedTemporaryFile("w", suffix=".sv", delete=False) as f:
370 |             f.write(code)
371 |             src_path = f.name
372 |         cmd = f"iverilog -Wall -Winfloop -Wno-timescale -g2012 -s tb -o test.vvp {src_path}"
373 |         """
374 |         adding timeout options for Popen. something breaks if not using timeout. seems to be working for now.
375 |         not really sure if it's the best/correct way. let me know if anyone has a better solution.
376 |         https://stackoverflow.com/questions/1191374/using-module-subprocess-with-timeout
377 |         """
378 |         p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
379 |         timer = Timer(timeout, p.kill)
380 |         try:
381 |             timer.start()
382 |             out, err = p.communicate()
383 |         finally:
384 |             timer.cancel()
385 |         out, err = out.decode("utf-8"), err.decode("utf-8")
386 |         # iverilog prints nothing on a clean compile; any output means warnings or errors.
387 |         if len(out) > 0 or len(err) > 0:
388 |             return False
389 |         return True
390 |     except Exception:
391 |         return False
392 | 
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from task import (
4 |     react_fix_compile,
5 |     react_fix_simulate,
6 |     oneshot_fix_compile,
7 | )
8 | 
9 | from utils import read_jsonl, read_jsonl_gz
10 | 
11 | 
12 | def get_args():
13 |     parser = argparse.ArgumentParser()
14 |     parser.add_argument("--run_name", type=str, help="The name of the run")
15 |     parser.add_argument("--root_dir", type=str,
16 |                         help="The root logging directory", default="root")
17 |     parser.add_argument("--prefix", type=str,
18 |                         help="log file prefix", default="")
19 |     parser.add_argument("--dataset_path", type=str,
20 |                         help="The path to the benchmark dataset", default="root")
21 |     parser.add_argument("--task", type=str,
22 |                         help="Task: `react_fix_compile`, `react_fix_simulate`, `oneshot_fix_compile`")
23 |     parser.add_argument("--agent_feedback", type=str,
24 |                         help="Agent: `nofeedback`, `feedback`, `rag`", default="cot")
25 |     parser.add_argument("--compiler", type=str,
26 |                         help="Compiler: `iverilog`, `modelsim`, `vcs`, `quartus`", default="iverilog")
27 |     parser.add_argument("--language", type=str, help="`verilog` or `py` or `rs`")
28 |     parser.add_argument(
29 |         "--model", type=str, help="OpenAI models only for now. 
For best results, use GPT-4") 30 | parser.add_argument("--pass_at_k", type=int, 31 | help="Pass@k metric", default=1) 32 | parser.add_argument("--max_iters", type=int, 33 | help="The maximum number of self-improvement iterations", default=10) 34 | parser.add_argument("--max_budgets", type=int, 35 | help="The maximum number of simulation budgets", default=10) 36 | parser.add_argument("--num_samples", type=int, 37 | help="The maximum number of sample generation", default=10) 38 | parser.add_argument("--verbose", action='store_true', 39 | help="To print live logs") 40 | args = parser.parse_args() 41 | return args 42 | 43 | 44 | def task_factory(task: str): 45 | def kwargs_wrapper_gen(func, delete_keys=[]): 46 | def kwargs_wrapper(**kwargs): 47 | for key in delete_keys: 48 | if key in kwargs: 49 | del kwargs[key] 50 | return func(**kwargs) 51 | return kwargs_wrapper 52 | 53 | if task == "react_fix_compile": 54 | return kwargs_wrapper_gen(react_fix_compile, delete_keys=[]) 55 | elif task == "react_fix_simulate": 56 | return kwargs_wrapper_gen(react_fix_simulate, delete_keys=["compiler"]) 57 | elif task == "oneshot_fix_compile": 58 | return kwargs_wrapper_gen(oneshot_fix_compile, delete_keys=["max_budgets", 'max_iters']) 59 | else: 60 | raise ValueError(f"Task `{task}` is not supported") 61 | 62 | 63 | def main(args): 64 | # check if the root dir exists and create it if not 65 | if not os.path.exists(args.root_dir): 66 | os.makedirs(args.root_dir) 67 | 68 | # get the dataset name 69 | dataset_name = os.path.basename(args.dataset_path).replace("jsonl", "") 70 | 71 | # check if log path already exists 72 | log_dir = os.path.join(args.root_dir, args.run_name) 73 | log_path = os.path.join( 74 | log_dir, f"{args.prefix}_{dataset_name}_{args.task}_iter_{args.max_iters}_num_sample_{args.num_samples}_{args.agent_feedback}_{args.model}_pass_at_k_{args.pass_at_k}_{args.language}_{args.compiler}.jsonl") 75 | if not os.path.exists(log_dir): 76 | os.makedirs(log_dir) 77 | 78 | # check if the task is valid 79 | run_task = task_factory(args.task) 80 | 81 | # print starting message 82 | if args.verbose: 83 | print(f""" 84 | Starting run with the following parameters: 85 | task: {args.task} 86 | agent_feedback: {args.agent_feedback} 87 | num_samples: {args.num_samples} 88 | pass@k: {args.pass_at_k} 89 | """) 90 | else: 91 | print(f"Logs will be saved in `{log_dir}`") 92 | 93 | # load the dataset 94 | print(f'Loading the dataset...') 95 | if args.dataset_path.endswith(".jsonl"): 96 | dataset = read_jsonl(args.dataset_path) 97 | elif args.dataset_path.endswith(".jsonl.gz"): 98 | dataset = read_jsonl_gz(args.dataset_path) 99 | else: 100 | raise ValueError( 101 | f"Dataset path `{args.dataset_path}` is not supported") 102 | 103 | print(f"Loaded {len(dataset)} examples") 104 | # start the run 105 | # evaluate with pass@k 106 | run_task( 107 | dataset=dataset, 108 | model_name=args.model, 109 | agent_feedback=args.agent_feedback, 110 | language=args.language, 111 | max_iters=args.max_iters, 112 | max_budgets=args.max_budgets, 113 | pass_at_k=args.pass_at_k, 114 | log_path=log_path, 115 | verbose=args.verbose, 116 | num_samples=args.num_samples, 117 | compiler=args.compiler, 118 | ) 119 | 120 | print(f"Done! 
Check out the logs in `{log_path}`")
121 | 
122 |     if args.language == "verilog":
123 |         from verilog_eval.execution import clean_up_simulation
124 |         clean_up_simulation()
125 | 
126 | 
127 | if __name__ == "__main__":
128 |     args = get_args()
129 |     main(args)
130 | 
--------------------------------------------------------------------------------
/src/scripts/run_example.sh:
--------------------------------------------------------------------------------
1 | # Example run of the one-shot compile-repair task. Flags:
2 | #   --root_dir        root directory for the experiment run
3 | #   --dataset_path    path to the benchmark dataset
4 | #   --task            task to run [oneshot_fix_compile, react_fix_compile, react_fix_simulate]
5 | #   --agent_feedback  feedback strategy for the agent [nofeedback, feedback, rag]
6 | #   --num_samples     number of samples for each problem instance
7 | #   --compiler        compiler to use [iverilog, modelsim, vcs, quartus]
8 | python main.py \
9 |     --run_name "test_oneshot_compile" \
10 |     --root_dir "exp" \
11 |     --dataset_path ./benchmarks/verilogeval-syntax-hard.jsonl \
12 |     --task "oneshot_fix_compile" \
13 |     --agent_feedback "rag" \
14 |     --language "verilog" \
15 |     --model "gpt-3.5-turbo-16k-0613" \
16 |     --pass_at_k "1" \
17 |     --num_samples '1' \
18 |     --compiler 'quartus' \
19 |     --verbose
--------------------------------------------------------------------------------
/src/scripts/run_oneshot_fix_compile.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 |     --run_name "test_oneshot_compile" \
3 |     --root_dir "exp" \
4 |     --dataset_path ./benchmarks/verilogeval-syntax-hard.jsonl \
5 |     --task "oneshot_fix_compile" \
6 |     --agent_feedback "rag" \
7 |     --language "verilog" \
8 |     --model "gpt-3.5-turbo-16k-0613" \
9 |     --pass_at_k "1" \
10 |     --num_samples '1' \
11 |     --compiler 'quartus' \
12 |     --verbose
13 | 
--------------------------------------------------------------------------------
/src/scripts/run_react_fix_compile.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 |     --run_name "test_react_fix_compile" \
3 |     --root_dir "exp" \
4 |     --dataset_path ./benchmarks/verilogeval-syntax-hard.jsonl \
5 |     --task "react_fix_compile" \
6 |     --agent_feedback "rag" \
7 |     --language "verilog" \
8 |     --model "gpt-3.5-turbo-16k-0613" \
9 |     --pass_at_k "1" \
10 |     --max_iters "15" \
11 |     --max_budgets "5" \
12 |     --num_samples '1' \
13 |     --compiler 'quartus' \
14 |     --verbose
15 | 
--------------------------------------------------------------------------------
/src/scripts/run_react_fix_simulate.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 |     --run_name "test_react_simulate" \
3 |     --root_dir "exp" \
4 |     --dataset_path ./benchmarks/verilogeval-simulate-hard.jsonl \
5 |     --task "react_fix_simulate" \
6 |     --agent_feedback "rag" \
7 |     --language "verilog" \
8 |     --model "gpt-3.5-turbo-16k-0613" \
9 |     --pass_at_k "1" \
10 |     --max_iters "15" \
11 |     --max_budgets "30" \
12 |     --num_samples '5' \
13 |     --verbose
14 | 
--------------------------------------------------------------------------------
/src/task/__init__.py:
--------------------------------------------------------------------------------
1 | from .react_fix_compile import react_fix_compile
2 | from .react_fix_simulate import react_fix_simulate
3 | from .oneshot_fix_compile import oneshot_fix_compile
4 | 
5 | 
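`oneshot_fix_compile.py` below tallies `num_correct` over `num_samples` per problem and scores runs with `estimate_pass_at_k` from verilog-eval. For reference, a minimal sketch of the standard unbiased pass@k estimator from the HumanEval paper (assuming verilog-eval follows the same definition):

```python
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k: 1 - C(n - c, k) / C(n, k), for c correct out of n samples."""
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

# e.g. 1 successful fix out of 10 samples, scored at k=1:
print(pass_at_k(10, 1, 1))  # 0.1
```
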
3 | from executors import executor_factory
4 | from generators import generator_factory, model_factory, agent_factory
5 | from verilog_eval.evaluation import estimate_pass_at_k
6 | from typing import List
7 | from copy import deepcopy
8 | from generators.model import Message
9 | from generators.agents.langchain_tools import RAGTool, LocalizeTool
10 | from executors.executor_utils import simple_syntax_fixer
11 | from tqdm import tqdm
12 | 
13 | 
14 | def oneshot_fix_compile(
15 |     dataset: List[dict],
16 |     model_name: str,
17 |     language: str,
18 |     pass_at_k: int,
19 |     log_path: str,
20 |     verbose: bool,
21 |     is_leetcode: bool = False,
22 |     agent_feedback: str = None,
23 |     num_samples: int = 10,
24 |     compiler: str = ""
25 | ) -> None:
26 |     exe = executor_factory(language, is_leet=is_leetcode)
27 |     gen = generator_factory(language)
28 |     model = model_factory(model_name)
29 | 
30 | 
31 |     print_v = make_printv(verbose)
32 |     num_correct_list = []
33 |     num_samples_list = []
34 | 
35 |     for i, item in tqdm(enumerate_resume(dataset, log_path), total=len(dataset)):
36 | 
37 |         num_correct = 0
38 |         num_budgets = 0
39 |         org_item = deepcopy(item)
40 | 
41 |         for j in tqdm(range(num_samples)):
42 | 
43 |             # init
44 |             is_solved = False
45 |             item = org_item
46 |             func_impl = simple_syntax_fixer(item['solution'], item)
47 | 
48 |             exec_result = exe.evaluate(
49 |                 item if language == "verilog" else item['entry_point'],
50 |                 func_impl,
51 |                 item["test"],
52 |                 timeout=20,
53 |                 # compile_only=compiler
54 |             )
55 |             compile_log = exec_result['feedback']['compiler_log']
56 |             feedback = exec_result['feedback']
57 |             msg = ""
58 |             numbered_impl = ""
59 |             rag = ""
60 |             output = ""
61 | 
62 |             if not exec_result['passed']:
63 | 
64 |                 code_prompt = gen.prepare_prompt(
65 |                     item if language == "verilog" else item['prompt'],
66 |                     model,
67 |                     "simple",
68 |                     func_impl,
69 |                 )
70 | 
71 |                 code_prompt.append(Message(
72 |                     role="assistant",
73 |                     content=func_impl,
74 |                 ))
75 | 
76 | 
77 |                 if 'rag' in agent_feedback:
78 |                     rag = RAGTool(compiler)
79 |                     rag = rag._run(compile_log)
80 | 
81 | 
82 |                 if "nofeedback" in agent_feedback or ('dangling input' in compile_log and not rag):
83 |                     msg = "There is an error in the code that caused a compile error."
84 |                 else:
85 |                     compile_log = compile_log.replace('I give up.', '')
86 |                     msg = f"{numbered_impl}compile result: {compile_log}\n{rag}\n"
87 | 
88 |                 code_prompt.append(Message(
89 |                     role="user",
90 |                     content=f"{msg}\nFix the error and give me the correct code.",
91 |                 ))
92 | 
93 |                 output = model.generate_chat(code_prompt)
94 |                 func_impl = simple_syntax_fixer(output, item)
95 | 
96 |             # simulate
97 |             assert isinstance(func_impl, str)
98 |             exec_result = exe.evaluate(
99 |                 item if language == "verilog" else item['entry_point'],
100 |                 func_impl,
101 |                 item["test"],
102 |                 timeout=20,
103 |                 compile_only=compiler
104 |             )
105 |             if exec_result['passed']:
106 |                 is_solved = True
107 |                 num_correct += 1
108 | 
109 |             item['output'] = output
110 |             item['reflexion'] = msg
111 |             item["solution"] = func_impl
112 |             item["feedbacks"] = exec_result['feedback']['compiler_log']
113 |             item["test_output"] = exec_result['feedback']['test_output']
114 |             item['completion'] = exec_result['feedback']['completion']
115 |             item["is_solved"] = is_solved
116 |             write_jsonl(log_path, [item], append=True)
117 | 
118 |         num_correct_list.append(num_correct)
119 |         num_samples_list.append(num_samples)
120 | 
121 |     pass_k_list = estimate_pass_at_k(num_samples_list, num_correct_list, pass_at_k)
122 |     print(f"pass@{pass_at_k}: {np.mean(pass_k_list)} {pass_k_list}")
123 | 
--------------------------------------------------------------------------------
/src/task/react_fix_compile.py:
--------------------------------------------------------------------------------
1 | from utils import enumerate_resume, make_printv, write_jsonl
2 | from executors import executor_factory
3 | from generators import generator_factory, model_factory, agent_factory
4 | from generators.model import Message
5 | from verilog_eval.evaluation import estimate_pass_at_k
6 | from typing import List
7 | from copy import deepcopy
8 | import numpy as np
9 | from executors.executor_utils import simple_syntax_fixer
10 | 
11 | 
12 | def react_fix_compile(
13 |     dataset: List[dict],
14 |     model_name: str,
15 |     language: str,
16 |     max_iters: int,
17 |     max_budgets: int,
18 |     pass_at_k: int,
19 |     log_path: str,
20 |     verbose: bool,
21 |     is_leetcode: bool = False,
22 |     agent_feedback: str = None,
23 |     num_samples: int = 10,
24 |     compiler: str = ""
25 | ) -> None:
26 |     exe = executor_factory(language, is_leet=is_leetcode)
27 |     gen = generator_factory(language)
28 |     model = model_factory(model_name)
29 |     method = agent_feedback
30 | 
31 |     code_agent = agent_factory('rtlfixer_code', model_name, exe=exe, max_iters=10, toolset=['compiler', 'rag'], method=method, compiler=compiler)
32 |     fix_agent = agent_factory('rtlfixer_fix', model_name, exe=exe, max_iters=20, toolset=['examine'])
33 | 
34 | 
35 |     print_v = make_printv(verbose)
36 |     num_correct_list = []
37 |     num_samples_list = []
38 | 
39 |     for i, item in enumerate_resume(dataset, log_path):
40 | 
41 |         num_correct = 0
42 |         num_budgets = 0
43 |         org_item = deepcopy(item)
44 | 
45 |         for j in range(num_samples):
46 | 
47 |             # init
48 |             is_solved = False
49 |             item = org_item
50 |             func_impl = None
51 | 
52 |             reflections = []
53 |             implementations = []
54 |             test_feedback = []
55 |             history = []
56 | 
57 |             code_agent.toolkit.initialize(item)
58 |             fix_agent.toolkit.initialize(item)
59 |             code_agent.reset_logs()
60 |             fix_agent.reset_logs()
61 | 
62 |             for _ in range(1):
63 | 
64 |                 func_impl = simple_syntax_fixer(item['solution'], item)
65 |                 exec_result = exe.evaluate(
66 |                     item if language == "verilog" else item['entry_point'],
67 |                     func_impl,
68 |                     item["test"],
69 |                     timeout=20,
70 |                     compile_only=compiler
71 |                 )
72 |                 compile_log = exec_result['feedback']['compiler_log']
73 |                 feedback = exec_result['feedback']
74 |                 is_solved = exec_result['passed']
75 |                 implementations.append(func_impl)
76 | 
77 |                 if not exec_result['passed']:
78 | 
79 |                     # code generation
80 |                     code_prompt = gen.prepare_prompt(
81 |                         item if language == "verilog" else item['prompt'],
82 |                         model,
83 |                         "react_code",
84 |                         func_impl,
85 |                         test_feedback,
86 |                         reflections,
87 |                     )
88 | 
89 |                     code_prompt.append(Message(
90 |                         role="assistant",
91 |                         content=func_impl,
92 |                     ))
93 |                     code_prompt.append(Message(
94 |                         role="user",
95 |                         content="Run the compiler and fix the error if encountered.",
96 |                     ))
97 |                     func_impl = code_agent.generate_chat(code_prompt)
98 |                     func_impl = simple_syntax_fixer(func_impl, item)
99 | 
100 | 
101 |                     if func_impl not in implementations:
102 |                         history.append(code_agent.intermediate_steps)
103 |                         implementations.append(func_impl)
104 | 
105 |                 # simulate
106 |                 assert isinstance(func_impl, str)
107 |                 exec_result = exe.evaluate(
108 |                     item if language == "verilog" else item['entry_point'],
109 |                     func_impl,
110 |                     item["test"],
111 |                     timeout=20,
112 |                     compile_only=compiler
113 |                 )
114 |                 if 'wave.vcd' in exec_result:
115 |                     exec_result.pop('wave.vcd')
116 |                 if 'test_output' in exec_result:
117 |                     exec_result.pop('test_output')
118 | 
119 |                 test_feedback.append(exec_result['feedback'])
120 |                 code_agent.toolkit.num_simulat += 1
121 | 
122 |                 if exec_result['passed']:
123 |                     is_solved = True
124 |                     num_correct += 1
125 | 
126 |             item["solution"] = func_impl
127 |             item["feedbacks"] = test_feedback
128 |             item["is_solved"] = is_solved
129 |             item["num_simulat"] = code_agent.toolkit.num_simulat
130 |             item["num_compile"] = code_agent.toolkit.num_compile
131 |             item['agent_history'] = code_agent.serialize(history)
132 |             item['error_logs'] = code_agent.serialize({
133 |                 'code': code_agent.error_logs,
134 |                 'fix': fix_agent.error_logs
135 |             })
136 |             write_jsonl(log_path, [item], append=True)
137 |             num_budgets += code_agent.toolkit.num_simulat
138 | 
139 |         num_correct_list.append(num_correct)
140 |         num_samples_list.append(num_budgets)
141 | 
142 |     pass_k_list = estimate_pass_at_k(num_samples_list, num_correct_list, pass_at_k)
143 |     print(f"pass@{pass_at_k}: {np.mean(pass_k_list)} {pass_k_list}")
144 | 
--------------------------------------------------------------------------------
/src/task/react_fix_compile_rtllm.py:
--------------------------------------------------------------------------------
1 | from utils import enumerate_resume, make_printv, write_jsonl
2 | from executors import executor_factory
3 | from generators import generator_factory, model_factory, agent_factory
4 | from generators.model import Message
5 | from verilog_eval.evaluation import estimate_pass_at_k
6 | from typing import List
7 | from copy import deepcopy
8 | from executors.executor_utils import simple_syntax_fixer
9 | 
10 | 
11 | def react_fix_compile(
12 |     dataset: List[dict],
13 |     model_name: str,
14 |     language: str,
15 |     max_iters: int,
16 |     max_budgets: int,
17 |     pass_at_k: int,
18 |     log_path: str,
19 |     verbose: bool,
20 |     is_leetcode: bool = False,
21 |     agent_feedback: str = None,
22 |     num_samples: int = 10,
23 |     compiler: str = ""
24 | ) -> None:
25 |     exe = executor_factory(language, is_leet=is_leetcode)
26 |     gen = generator_factory(language)
27 |     model = model_factory(model_name)
28 |     method = agent_feedback
29 |     code_agent = agent_factory('rtlfixer_code', model_name, exe=exe, max_iters=10, toolset=['compiler'], method=method, compiler=compiler)
30 |     fix_agent = agent_factory('rtlfixer_fix', model_name, exe=exe, max_iters=20, toolset=['examine'])
31 | 
32 | 
33 |     print_v = make_printv(verbose)
34 |     num_correct_list = []
35 |     num_samples_list = []
36 | 
37 |     for i, item in enumerate_resume(dataset, log_path):
38 | 
39 | 
40 |         num_correct = 0
41 |         num_budgets = 0
42 |         org_item = deepcopy(item)
43 | 
44 |         for j in range(num_samples):
45 | 
46 |             # init
47 |             is_solved = False
48 |             item = org_item
49 |             func_impl = None
50 | 
51 |             reflections = []
52 |             implementations = []
53 |             test_feedback = []
54 |             history = []
55 | 
56 |             code_agent.toolkit.initialize(item)
57 |             fix_agent.toolkit.initialize(item)
58 |             code_agent.reset_logs()
59 |             fix_agent.reset_logs()
60 | 
61 |             # while code_agent.toolkit.num_simulat < max_budgets:
62 |             for _ in range(1):
63 | 
64 |                 func_impl = simple_syntax_fixer(item['solution'], item)
65 |                 exec_result = exe.evaluate(
66 |                     item if language == "verilog" else item['entry_point'],
67 |                     func_impl,
68 |                     item["test"],
69 |                     timeout=20,
70 |                     # compile_only=compiler
71 |                 )
72 |                 compile_log = exec_result['feedback']['compiler_log']
73 |                 feedback = exec_result['feedback']
74 | 
75 |                 if not exec_result['passed']:
76 | 
77 |                     # code generation
78 |                     code_prompt = gen.prepare_prompt(
79 |                         item if language == "verilog" else item['prompt'],
80 |                         model,
81 |                         "dual_agent_code",
82 |                         func_impl,
83 |                         test_feedback,
84 |                         reflections,
85 |                     )
86 | 
87 |                     code_prompt.append(Message(
88 |                         role="assistant",
89 |                         content=func_impl,
90 |                     ))
91 |                     code_prompt.append(Message(
92 |                         role="user",
93 |                         content="Run the compiler and fix the error if encountered.",
94 |                     ))
95 | 
96 |                     func_impl = code_agent.generate_chat(code_prompt)
97 |                     func_impl = simple_syntax_fixer(func_impl, item)
98 | 
99 | 
100 |                 # simulate
101 |                 assert isinstance(func_impl, str)
102 |                 exec_result = exe.evaluate(
103 |                     item if language == "verilog" else item['entry_point'],
104 |                     func_impl,
105 |                     item["test"],
106 |                     timeout=20,
107 |                     # compile_only=compiler
108 |                 )
109 |                 if 'wave.vcd' in exec_result:
110 |                     exec_result.pop('wave.vcd')
111 |                 if 'test_output' in exec_result:
112 |                     exec_result.pop('test_output')
113 | 
114 |                 if exec_result['passed']:
115 |                     is_solved = True
116 |                     num_correct += 1
117 | 
118 |             item["solution"] = func_impl
119 |             item["feedback"] = exec_result['feedback']
120 |             item["is_solved"] = is_solved
121 |             item["num_simulat"] = code_agent.toolkit.num_simulat
122 |             item["num_compile"] = code_agent.toolkit.num_compile
123 |             item['agent_history'] = code_agent.serialize(history)
124 |             item['error_logs'] = code_agent.serialize({
125 |                 'code': code_agent.error_logs,
126 |                 'fix': fix_agent.error_logs
127 |             })
128 |             write_jsonl(log_path, [item], append=True)
129 |             num_budgets += code_agent.toolkit.num_simulat
130 | 
131 |         num_correct_list.append(num_correct)
132 |         num_samples_list.append(num_budgets)
133 | 
134 |     print(f"pass@{pass_at_k}: {estimate_pass_at_k(num_samples_list, num_correct_list, pass_at_k)}")
135 | 
--------------------------------------------------------------------------------
/src/task/react_fix_simulate.py:
--------------------------------------------------------------------------------
1 | from utils import enumerate_resume, make_printv, write_jsonl
2 | from executors import executor_factory
3 | from generators import generator_factory, model_factory, agent_factory
4 | from verilog_eval.evaluation import estimate_pass_at_k
5 | from typing import List, Tuple
6 | from copy import deepcopy
7 | import numpy as np
8 | from executors.executor_utils import simple_syntax_fixer
9 | 
10 | 
11 | def parse_mismatches(log: str) -> int:
12 |     for line in log.split('\n'):
13 |         if 'Mismatches' in line:
14 |             st = len("Mismatches:")
15 |             ed = line.find("in")
16 |             return int(line[st: ed].strip())
17 |     return 999  # sentinel: treat an unparsable log as a very bad score
18 | 
19 | 
20 | def restart_strategy(implementations: list, test_feedback: list, func_impl: str, patient: int = 3) -> Tuple[bool, str, str]:
21 |     """Decide whether to restart from scratch; returns (restart, reason, impl to continue from)."""
22 |     assert len(implementations) == len(test_feedback)
23 |     if len(implementations) <= 1:
24 |         return False, "First", func_impl
25 | 
26 |     if 'give up' in test_feedback[-1]['compiler_log']:
27 |         return True, "compiler give up", ""
28 | 
29 |     miss_match_list = [parse_mismatches(i['test_output']) for i in test_feedback]
30 | 
31 |     # early stop
32 |     # if len(implementations) >= patient:
33 |     #     # no improvement in `patient` rounds (equal ascending orders)
34 |     #     if miss_match_list[-patient:] == np.sort(miss_match_list[-patient:]).tolist():
35 |     #         return True, "patient", ""
36 | 
37 |     last_miss_match = miss_match_list[-1]
38 |     best_idx = np.argmin(miss_match_list[:-1])  # best of the *previous* attempts
39 |     best_miss_match = miss_match_list[best_idx]
40 |     best_impl = implementations[best_idx]
41 | 
42 |     if last_miss_match < best_miss_match:
43 |         # improved: keep going from the latest implementation
44 |         return False, "improved", func_impl
45 |     elif last_miss_match > best_miss_match:
46 |         # got worse: roll back to the best implementation so far
47 |         return False, "worsen", best_impl
48 |     else:
49 |         return False, "no_change", func_impl
50 | 
51 | 
52 | 
53 | def react_fix_simulate(
54 |     dataset: List[dict],
55 |     model_name: str,
56 |     language: str,
57 |     max_iters: int,
58 |     max_budgets: int,
59 |     pass_at_k: int,
60 |     log_path: str,
61 |     verbose: bool,
62 |     is_leetcode: bool = False,
63 |     agent_feedback: str = None,
64 |     num_samples: int = 10
65 | ) -> None:
66 |     exe = executor_factory(language, is_leet=is_leetcode)
67 |     gen = generator_factory(language)
68 |     model = model_factory(model_name)
69 |     nofeedback = 'nofeedback' in agent_feedback
70 |     code_agent = agent_factory('rtlfixer_code', model_name, exe=exe, max_iters=20, toolset=['compiler'])
71 |     fix_agent = agent_factory('rtlfixer_fix', model_name, exe=exe, max_iters=20, toolset=[])
72 | 
73 | 
74 |     print_v = make_printv(verbose)
75 |     num_correct_list = []
76 |     num_samples_list = []
77 | 
78 |     for i, item in enumerate_resume(dataset, log_path):
79 | 
80 |         num_correct = 0
81 |         num_budgets = 0
82 |         org_item = deepcopy(item)
83 | 
84 |         code_agent.toolkit.initialize(item)
85 |         fix_agent.toolkit.initialize(item)
86 | 
87 |         # for j in range(num_samples):
88 |         while code_agent.toolkit.num_simulat < max_budgets:
89 | 
90 |             # init
91 |             is_solved = False
92 |             item = org_item
93 |             func_impl = None
94 | 
95 |             reflections = []
96 |             implementations = []
97 |             test_feedback = []
98 |             history = []
99 | 
100 |             code_agent.reset_logs()
101 |             fix_agent.reset_logs()
102 | 
103 |             func_impl = item['solution']
104 |             n_compile_error = 0
105 |             n_simulate_error = 0
106 | 
107 |             while code_agent.toolkit.num_simulat < max_budgets:
108 | 
109 |                 func_impl = simple_syntax_fixer(func_impl, item)
110 | 
111 |                 if func_impl not in implementations:
112 | 
113 |                     # simulate
114 |                     assert isinstance(func_impl, str)
115 |                     exec_result = exe.evaluate(
116 |                         item if language == "verilog" else item['entry_point'],
117 |                         func_impl,
118 |                         item["test"],
119 |                         timeout=20
120 |                     )
121 |                     if 'wave.vcd' in exec_result:
122 |                         exec_result.pop('wave.vcd')
123 |                     if 'test_output' in exec_result:
124 |                         exec_result.pop('test_output')
125 | 
126 | 
127 |                     code_agent.toolkit.num_simulat += 1
128 |                     if hasattr(code_agent, 'intermediate_steps'):
129 |                         history.append(code_agent.intermediate_steps)
130 |                     implementations.append(func_impl)
131 |                     test_feedback.append(exec_result['feedback'])
132 |                     if exec_result['passed']:
133 |                         is_solved = True
134 |                         num_correct += 1
135 |                         break
136 |                     elif exec_result['feedback']['compiler_log']:
137 |                         n_compile_error += 1
138 |                     else:
139 |                         n_simulate_error += 1
140 | 
141 |                 # track impl and performance
142 |                 restart, reason, func_impl = restart_strategy(implementations, test_feedback, func_impl)
143 |                 if restart:
144 |                     history.append({
145 |                         'restart': {
146 |                             'reason': reason
147 |                         }
148 |                     })
149 |                     func_impl = ""
150 | 
151 |                 if func_impl:  # did not restart
152 |                     # prepare for examine tool
153 |                     item['feedback'] = exec_result['feedback']
154 |                     item['solution'] = func_impl
155 | 
156 |                     # fix
157 |                     if nofeedback:
158 |                         err_msg = "There is an error during the simulation. Please check the logic and fix the code."
159 |                     else:
160 |                         err_msg = exec_result['feedback']['test_output'] + f"\nWaveform:\n{fix_agent.toolkit.examine._run('waveform')}\n"
161 | 
162 |                     fix_prompt = gen.prepare_prompt(
163 |                         item if language == "verilog" else item['prompt'],
164 |                         model,
165 |                         "react_fix",
166 |                         func_impl,
167 |                         err_msg,
168 |                     )
169 |                     reflexion = fix_agent.generate_chat(fix_prompt)
170 |                     history.append(fix_agent.intermediate_steps)
171 |                     reflections.append(reflexion)
172 | 
173 | 
174 |                 # code generation
175 |                 org_temperature = code_agent.llm.temperature
176 |                 while True:
177 |                     code_prompt = gen.prepare_prompt(
178 |                         item if language == "verilog" else item['prompt'],
179 |                         model,
180 |                         "react_code",
181 |                         func_impl,
182 |                         test_feedback,
183 |                         reflections,
184 |                     )
185 |                     func_impl = code_agent.generate_chat(code_prompt)
186 |                     if func_impl is None:
187 |                         func_impl = implementations[-1]
188 |                     if func_impl not in implementations:
189 |                         code_agent.llm.temperature = org_temperature
190 |                         break
191 |                     else:
192 |                         # retry with higher temperature until a new implementation appears
193 |                         if code_agent.llm.temperature > 1.0:
194 |                             func_impl = ""
195 |                         else:
196 |                             code_agent.llm.temperature += 0.1
197 | 
198 | 
199 |             item["solution"] = func_impl
200 |             item["reflections"] = reflections
201 |             item["implementations"] = implementations
202 |             item["feedbacks"] = test_feedback
203 |             item["is_solved"] = is_solved
204 |             item["num_simulat"] = code_agent.toolkit.num_simulat
205 |             item["num_compile"] = code_agent.toolkit.num_compile
206 |             item["n_compile_error"] = n_compile_error
207 |             item["n_simulate_error"] = n_simulate_error
208 |             item['agent_history'] = code_agent.serialize(history)
209 |             item['error_logs'] = code_agent.serialize({
210 |                 'code': code_agent.error_logs,
211 |                 'fix': fix_agent.error_logs
212 |             })
213 |             write_jsonl(log_path, [item], append=True)
214 |             num_budgets += code_agent.toolkit.num_simulat
215 | 
216 |         num_correct_list.append(num_correct)
217 |         num_samples_list.append(num_budgets)
218 | 
219 |     pass_k_list = estimate_pass_at_k(num_samples_list, num_correct_list, pass_at_k)
220 |     print(f"pass@{pass_at_k}: {np.mean(pass_k_list)} {pass_k_list}")
221 | 
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gzip
3 | import json
4 | import openai
5 | import jsonlines
6 | 
7 | from typing import List
8 | 
9 | openai.api_key = os.getenv("OPENAI_API_KEY")
10 | 
11 | 
12 | def make_printv(verbose: bool):
13 |     def print_v(*args, **kwargs):
14 |         if verbose:
15 |             kwargs["flush"] = True
16 |             print(*args, **kwargs)
17 |     return print_v
18 | 
19 | 
20 | def read_jsonl(path: str) -> List[dict]:
21 |     if not os.path.exists(path):
22 |         raise FileNotFoundError(f"File `{path}` does not exist.")
23 |     elif not path.endswith(".jsonl"):
24 |         raise ValueError(f"File `{path}` is not a jsonl file.")
25 |     items = []
26 |     with jsonlines.open(path) as reader:
27 |         for item in reader:
28 |             items += [item]
29 |     return items
30 | 
31 | 
32 | def write_jsonl(path: str, data: List[dict], append: bool = False):
33 |     with jsonlines.open(path, mode='a' if append else 'w') as writer:
34 |         for item in data:
35 |             writer.write(item)
36 | 
37 | 
38 | def read_jsonl_gz(path: str) -> List[dict]:
39 |     if not path.endswith(".jsonl.gz"):
40 |         raise ValueError(f"File `{path}` is not a jsonl.gz file.")
41 |     with gzip.open(path, "rt") as f:
42 |         data = [json.loads(line) for line in f]
43 |     return data
44 | 
45 | 
46 | # Generator that yields (index, item) pairs from the dataset.
47 | # If results_path exists, items whose task_id already appears in the
48 | # results file are skipped, so interrupted runs can be resumed.
49 | def enumerate_resume(dataset, results_path):
50 |     if not os.path.exists(results_path):
51 |         for i, item in enumerate(dataset):
52 |             yield i, item
53 |     else:
54 |         # collect the task_ids that already have results
55 |         done_ids = set()
56 |         with jsonlines.open(results_path) as reader:
57 |             for item in reader:
58 |                 done_ids.add(item['task_id'])
59 | 
60 |         for i, item in enumerate(dataset):
61 |             # skip items that have been processed before
62 |             if item['task_id'] in done_ids:
63 |                 continue
64 |             yield i, item
65 | 
66 | 
67 | def resume_success_count(dataset) -> int:
68 |     count = 0
69 |     for item in dataset:
70 |         if "is_solved" in item and item["is_solved"]:
71 |             count += 1
72 |     return count
73 | 
74 | 
75 | def verilog_compile_has_error(log: str):
76 |     log = log.lower()
77 |     if 'error' in log or 'give up' in log:
78 |         return True
79 |     return False
--------------------------------------------------------------------------------
/src/visualize.py:
--------------------------------------------------------------------------------
1 | from vcdvcd import VCDVCD, binary_string_to_hex, StreamParserCallbacks
2 | import math
3 | import io
4 | import pandas as pd
5 | from typing import List
6 | 
7 | 
8 | class CustomCallback(StreamParserCallbacks):
9 |     def __init__(self, printIds={}, lines=20, offset=0):
10 |         self._printIdx = printIds
11 |         self._references_to_widths = {}
12 |         self.lines = lines  # was hard-coded to 20, which silently ignored the `lines` argument
13 |         self.counter = 0
14 |         self.offset = offset
15 | 
16 |     def enddefinitions(
17 |         self,
18 |         vcd,
19 |         signals,
20 |         cur_sig_vals
21 |     ):
22 |         vcd.io = io.StringIO()
23 |         self._printIdx = self._printIdx if self._printIdx else {i: i.split('.')[-1] for i in vcd.signals}
24 | 
25 |         if signals:
26 |             self._print_dumps_refs = signals
27 |         else:
28 |             self._print_dumps_refs = sorted(vcd.data[i].references[0] for i in cur_sig_vals.keys())
29 | 
30 |         for i, ref in enumerate(self._print_dumps_refs, 1):
31 |             identifier_code = vcd.references_to_ids[ref]
32 |             size = int(vcd.data[identifier_code].size)
33 |             width = max(size // 4, int(math.floor(math.log10(i))) + 1)
34 |             self._references_to_widths[ref] = width
35 | 
36 |         to_print = '// {0:<16}'.format('time')
37 |         for ref in vcd.signals:
38 |             string = '{0:>{1}s}'.format(self._printIdx[ref], self._references_to_widths.get(ref, 1))
39 |             to_print += '{0:<16}'.format(string)
40 | 
41 |         print(to_print, file=vcd.io)
42 | 
43 |     def time(
44 |         self,
45 |         vcd,
46 |         time,
47 |         cur_sig_vals
48 |     ):
49 |         self.counter += 1
50 | 
51 |         # only print rows inside the [offset, offset + lines] window
52 |         if self.counter > self.offset + self.lines or self.counter < self.offset:
53 |             return
54 | 
55 |         if vcd.signal_changed:
56 |             ss = []
57 |             ss.append('// {0:<16}'.format(str(time) + 'ns'))
58 |             for ref in self._printIdx:
59 |                 identifier_code = vcd.references_to_ids[ref]
60 |                 value = cur_sig_vals[identifier_code]
61 |                 string = '{0:>{1}s}'.format(
62 |                     binary_string_to_hex(value),
63 |                     self._references_to_widths.get(ref, 1))
64 |                 ss.append('{0:<16}'.format(string))
65 |             print(''.join(ss), file=vcd.io)
66 | 
67 | 
68 | def tabular_via_callback(vcd_path, offset: int, mismatch_columns: List[str], window_size: int = 5):
69 |     vcd = VCDVCD(vcd_path, callbacks=CustomCallback(offset=offset, lines=window_size), store_tvs=False, only_sigs=False)
70 |     tabular_text = vcd.io.getvalue()
71 |     return tabular_text
72 | 
73 | 
74 | def tabular_via_dataframe(vcd_path, offset: int, mismatch_columns: List[str], window_size: int = 5):
75 |     import numpy as np
76 | 
77 |     vcd = VCDVCD(vcd_path)
78 |     n_row = vcd.endtime + 1
79 |     n_col = len(vcd.signals)
80 |     matrix = np.full((n_row, n_col), np.nan, dtype=float)
81 |     for e, ref in enumerate(vcd.signals):
82 |         symbol = vcd.references_to_ids[ref]
83 |         for ts, signal in vcd.data[symbol].tv:
84 |             try:
85 |                 # non-numeric values (e.g. 'x'/'z') are mapped to the sentinel -999
86 |                 matrix[ts, e] = int(signal) if signal.isdigit() else -999
87 |             except Exception:
88 |                 matrix[ts, e] = -999
89 | 
90 |     df = pd.DataFrame(matrix, columns=[i.split(".")[-1] for i in vcd.signals]).dropna(subset='clk')
91 |     df = df.ffill()
92 | 
93 |     mismatch_columns = [i for i in df.columns if any(j in i for j in mismatch_columns)]
94 |     first_row = df.loc[0: 1][mismatch_columns]
95 |     tail_rows = df.loc[1: offset + 1][mismatch_columns].drop_duplicates(keep='first')
96 |     df = pd.concat([first_row, tail_rows])[-window_size:]
97 |     # after astype(str) the sentinel is the string '-999', so compare against the string
98 |     df = df.astype(int).astype(str).applymap(lambda x: binary_string_to_hex(x) if x != '-999' else 'x')
99 |     df.index.names = ['time(ns)']
100 |     return df.to_string(header=True, index=True)
101 | 
--------------------------------------------------------------------------------