├── .github └── workflows │ └── main.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── action.yml ├── fig ├── .DS_Store ├── DAC-checkpoint.pdf ├── DAC-overview.pdf ├── DAC-overview.png ├── Sampling_ablation.pdf ├── Sampling_motivation.pdf ├── TBD_horizontal.pdf ├── TBD_vertical.pdf ├── fig2.pdf ├── fig3.pdf ├── fig3_1.pdf ├── fig4-1.pdf ├── fig4-2.pdf └── fig_5.pdf ├── pyproject.toml ├── src ├── mage │ ├── __init__.py │ ├── agent.py │ ├── bash_tools.py │ ├── benchmark_read_helper.py │ ├── converage │ │ ├── LLMGuidance.cpp │ │ ├── LLMGuidance.h │ │ ├── LLMGuidance4CodeCov.cpp │ │ ├── LLMGuidance4CodeCov.h │ │ └── RunGPT.py │ ├── gen_config.py │ ├── log_utils.py │ ├── prompts.py │ ├── rtl_editor.py │ ├── rtl_generator.py │ ├── sim_judge.py │ ├── sim_reviewer.py │ ├── tb_generator.py │ ├── token_counter.py │ └── utils.py └── sim │ ├── .gitignore │ ├── Makefile │ ├── Makefile_obj │ ├── input.vc │ ├── sim_golden.vvp │ └── top.sv ├── testbench_generate.ipynb └── tests ├── .gitignore ├── test_llm_chat.py ├── test_rtl_generator.py ├── test_single_agent.py └── test_top_agent.py /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | - uses: pre-commit/action@v3.0.1 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info/ 2 | *.vcd 3 | __pycache__/ 4 | key.cfg 5 | wave.vcd 6 | output*/ 7 | log*/ 8 | data/ 9 | !requirements.txt 10 | *.o 11 | *.d 12 | *.vvp 13 | dist/ 14 | build/ 15 | -------------------------------------------------------------------------------- /.gitmodules: 
-------------------------------------------------------------------------------- 1 | [submodule "verilog-eval"] 2 | path = verilog-eval 3 | url = https://github.com/NVlabs/verilog-eval 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: trailing-whitespace 6 | name: (Common) Remove trailing whitespaces 7 | - id: mixed-line-ending 8 | name: (Common) Fix mixed line ending 9 | args: [--fix=lf] 10 | - id: end-of-file-fixer 11 | name: (Common) Remove extra EOF newlines 12 | - id: check-merge-conflict 13 | name: (Common) Check for merge conflicts 14 | - id: requirements-txt-fixer 15 | name: (Common) Sort "requirements.txt" 16 | - id: check-added-large-files 17 | name: (Common) Prevent giant files from being committed 18 | - id: fix-encoding-pragma 19 | name: (Python) Remove encoding pragmas 20 | args: [--remove] 21 | - id: debug-statements 22 | name: (Python) Check for debugger imports 23 | - id: check-json 24 | name: (JSON) Check syntax 25 | - id: check-yaml 26 | name: (YAML) Check syntax 27 | - id: check-toml 28 | name: (TOML) Check syntax 29 | 30 | - repo: https://github.com/psf/black 31 | rev: 24.8.0 32 | hooks: 33 | - id: black 34 | args: [--line-length=88] 35 | 36 | - repo: https://github.com/PyCQA/isort 37 | rev: 5.13.2 38 | hooks: 39 | - id: isort 40 | args: ["--profile", "black", "--filter-files"] 41 | 42 | - repo: https://github.com/PyCQA/flake8 43 | rev: 7.1.1 44 | hooks: 45 | - id: flake8 46 | additional_dependencies: [flake8-bugbear] 47 | args: ["--max-line-length=88", "--extend-ignore=E203,W503,E501,F541"] 48 | 49 | # - repo: https://github.com/pre-commit/mirrors-mypy 50 | # rev: v1.3.0 51 | # hooks: 52 | # - id: mypy 53 | # args: [--ignore-missing-imports] 54 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Stable Lab @ UCSD 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MAGE: A Multi-Agent Engine for Automated RTL Code Generation 2 | 3 | You can learn more on our Arxiv Paper: https://arxiv.org/abs/2412.07822. 4 | MAGE is an open-source multi-agent LLM RTL code generator. 
5 | 6 | ![DAC-overview](fig/DAC-overview.png) 7 | 8 | ## Environment Set Up 9 | 10 | ### 1.> To install the repo itself: 11 | ``` 12 | git clone https://github.com/stable-lab/MAGE.git 13 | # To get submodules at the same time 14 | git clone --recursive https://github.com/stable-lab/MAGE.git 15 | cd MAGE 16 | 17 | # Install conda first if it's not on your machine like "apt install conda" 18 | # To confirm successful installation of conda, run "conda --version" 19 | # Continue after successfully installed conda 20 | conda create -n mage python=3.11 21 | conda activate mage 22 | 23 | # Install the repo as a package. 24 | # If want to editable install as developer, 25 | # please check development guide below. 26 | pip install . 27 | ``` 28 | 29 | ### 2.>To set api key: 30 | You can either: 31 | 1. Set "OPENAI_API_KEY", "ANTHROPIC_API_KEY" or other keys in your env variables 32 | 2. Create key.cfg file. The file should be in format of: 33 | 34 | ``` 35 | OPENAI_API_KEY= 'xxxxxxx' 36 | ANTHROPIC_API_KEY= 'xxxxxxx' 37 | VERTEX_SERVICE_ACCOUNT_PATH= 'xxxxxxx' 38 | VERTEX_REGION= 'xxxxxxx' 39 | ``` 40 | 41 | ### To install iverilog {.tabset} 42 | You'll need to install [ICARUS verilog](https://github.com/steveicarus/iverilog) 12.0 43 | For latest installation guide, please refer to [iverilog official guide](https://steveicarus.github.io/iverilog/usage/installation.html) 44 | 45 | #### Ubuntu (Local Compilation) 46 | ``` 47 | apt install -y autoconf gperf make gcc g++ bison flex 48 | ``` 49 | and 50 | ``` 51 | $ git clone https://github.com/steveicarus/iverilog.git && cd iverilog \ 52 | && git checkout v12-branch \ 53 | && sh ./autoconf.sh && ./configure && make -j4\ 54 | $ sudo make install 55 | ``` 56 | #### MacOS 57 | ``` 58 | brew install icarus-verilog 59 | ``` 60 | 61 | #### Version confirmation of iverilog 62 | Please confirm the iverilog version is v12 by running 63 | ``` 64 | iverilog -v 65 | ``` 66 | 67 | First line of output is expected to be: 68 | ``` 69 | Icarus 
Verilog version 12.0 (stable) (v12_0) 70 | ``` 71 | 72 | ### 3.> Verilator Installation 73 | 74 | ``` 75 | # By apt 76 | sudo apt install verilator 77 | 78 | # By Compilation 79 | git clone https://github.com/verilator/verilator 80 | cd verilator 81 | autoconf 82 | export VERILATOR_ROOT=`pwd` 83 | ./configure 84 | make -j4 85 | ``` 86 | 87 | ### 4.> Pyverilog Installation 88 | 89 | ``` 90 | # pre require 91 | pip3 install jinja2 ply 92 | 93 | git clone https://github.com/PyHDI/Pyverilog.git 94 | cd Pyverilog 95 | # must to user dir, or error because no root 96 | python3 setup.py install --user 97 | ``` 98 | 99 | ### 5.> To get benchmarks 100 | 101 | ``` 102 | [verilog-eval](https://github.com/NVlabs/verilog-eval) 103 | ``` 104 | 105 | ``` 106 | git submodule update --init --recursive 107 | ``` 108 | 109 | ## Run Guide 110 | ``` 111 | python tests/test_top_agent.py 112 | ``` 113 | 114 | Run arguments can be set in the file like: 115 | 116 | ``` 117 | args_dict = { 118 | "provider": "anthropic", 119 | "model": "claude-3-5-sonnet-20241022", 120 | # "model": "gpt-4o-2024-08-06", 121 | # "filter_instance": "^(Prob070_ece241_2013_q2|Prob151_review2015_fsm)$", 122 | "filter_instance": "^(Prob011_norgate)$", 123 | # "filter_instance": "^(.*)$", 124 | "type_benchmark": "verilog_eval_v2", 125 | "path_benchmark": "../verilog-eval", 126 | "run_identifier": "your_run_identifier", 127 | "n": 1, 128 | "temperature": 0.85, 129 | "top_p": 0.95, 130 | "max_token": 8192, 131 | "use_golden_tb_in_mage": True, 132 | "key_cfg_path": "key.cfg", 133 | } 134 | ``` 135 | Where each argument means: 136 | 1. provider: The api provider of the LLM model used. e.g. anthropic-->claude, openai-->gpt-4o 137 | 2. model: The LLM model used. Support for gpt-4o and claude has been verified. 138 | 3. filter_instance: A RegEx style instance name filter. 139 | 4. type_benchmark: Support running verilog_eval_v1 or verilog_eval_v2 140 | 5. path_benchmark: Where the benchmark repo is cloned 141 | 6. 
run_identifier: Unique name to distinguish different runs 142 | 7. n: Number of repeated runs to execute 143 | 8. temperature: Argument for LLM generation randomness. Usually between [0, 1] 144 | 9. top_p: Argument for LLM generation randomness. Usually between [0, 1] 145 | 10. max_token: Maximum number of tokens the model is allowed to generate in its output. 146 | 11. key_cfg_path: Path to your key.cfg file. Defaults to be under the MAGE repo root
https://raw.githubusercontent.com/stable-lab/MAGE/90c96366c518d32bd7b81703bdef3cc46eeaeaf9/fig/DAC-checkpoint.pdf -------------------------------------------------------------------------------- /fig/DAC-overview.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stable-lab/MAGE/90c96366c518d32bd7b81703bdef3cc46eeaeaf9/fig/DAC-overview.pdf -------------------------------------------------------------------------------- /fig/DAC-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stable-lab/MAGE/90c96366c518d32bd7b81703bdef3cc46eeaeaf9/fig/DAC-overview.png -------------------------------------------------------------------------------- /fig/Sampling_ablation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stable-lab/MAGE/90c96366c518d32bd7b81703bdef3cc46eeaeaf9/fig/Sampling_ablation.pdf -------------------------------------------------------------------------------- /fig/Sampling_motivation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stable-lab/MAGE/90c96366c518d32bd7b81703bdef3cc46eeaeaf9/fig/Sampling_motivation.pdf -------------------------------------------------------------------------------- /fig/TBD_horizontal.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stable-lab/MAGE/90c96366c518d32bd7b81703bdef3cc46eeaeaf9/fig/TBD_horizontal.pdf -------------------------------------------------------------------------------- /fig/TBD_vertical.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stable-lab/MAGE/90c96366c518d32bd7b81703bdef3cc46eeaeaf9/fig/TBD_vertical.pdf 
-------------------------------------------------------------------------------- /fig/fig2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stable-lab/MAGE/90c96366c518d32bd7b81703bdef3cc46eeaeaf9/fig/fig2.pdf -------------------------------------------------------------------------------- /fig/fig3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stable-lab/MAGE/90c96366c518d32bd7b81703bdef3cc46eeaeaf9/fig/fig3.pdf -------------------------------------------------------------------------------- /fig/fig3_1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stable-lab/MAGE/90c96366c518d32bd7b81703bdef3cc46eeaeaf9/fig/fig3_1.pdf -------------------------------------------------------------------------------- /fig/fig4-1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stable-lab/MAGE/90c96366c518d32bd7b81703bdef3cc46eeaeaf9/fig/fig4-1.pdf -------------------------------------------------------------------------------- /fig/fig4-2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stable-lab/MAGE/90c96366c518d32bd7b81703bdef3cc46eeaeaf9/fig/fig4-2.pdf -------------------------------------------------------------------------------- /fig/fig_5.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stable-lab/MAGE/90c96366c518d32bd7b81703bdef3cc46eeaeaf9/fig/fig_5.pdf -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=64.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | 
[project] 6 | name = "mage" 7 | version = "1.0.1" 8 | description = "MAGE: Open-source multi-agent LLM RTL code generator" 9 | readme = { file = "README.md", content-type = "text/markdown" } 10 | requires-python = ">=3.11" 11 | dependencies = [ 12 | "config", 13 | "fsspec[http]<=2024.9.0,>=2023.1.0", 14 | "httpx<1,>=0.23.0", 15 | "llama-index-core", 16 | "llama-index-llms-anthropic", 17 | "llama-index-llms-openai", 18 | "llama-index-llms-vertex", 19 | "pre-commit", 20 | "pydantic", 21 | "rich", 22 | "tiktoken" 23 | ] 24 | classifiers = [ 25 | "Development Status :: 3 - Alpha", 26 | "Intended Audience :: Developers", 27 | "License :: OSI Approved :: Apache Software License", 28 | "Operating System :: POSIX", 29 | "Programming Language :: Python :: 3" 30 | ] 31 | 32 | [tool.setuptools] 33 | include-package-data = true 34 | zip-safe = false 35 | packages = ["mage"] 36 | 37 | [tool.setuptools.package-dir] 38 | "" = "src" 39 | -------------------------------------------------------------------------------- /src/mage/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stable-lab/MAGE/90c96366c518d32bd7b81703bdef3cc46eeaeaf9/src/mage/__init__.py -------------------------------------------------------------------------------- /src/mage/agent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import traceback 5 | from typing import List, Tuple 6 | 7 | from llama_index.core.llms import LLM 8 | 9 | from .log_utils import get_logger, set_log_dir, switch_log_to_file, switch_log_to_stdout 10 | from .rtl_editor import RTLEditor 11 | from .rtl_generator import RTLGenerator 12 | from .sim_judge import SimJudge 13 | from .sim_reviewer import SimReviewer 14 | from .tb_generator import TBGenerator 15 | from .token_counter import TokenCounter, TokenCounterCached 16 | 17 | logger = get_logger(__name__) 18 | 19 | 20 | class TopAgent: 
21 | def __init__(self, llm: LLM): 22 | self.llm = llm 23 | self.token_counter = ( 24 | TokenCounterCached(llm) 25 | if TokenCounterCached.is_cache_enabled(llm) 26 | else TokenCounter(llm) 27 | ) 28 | self.sim_max_retry = 4 29 | self.rtl_max_candidates = 20 30 | self.rtl_selected_candidates = 2 31 | self.is_ablation = False 32 | self.redirect_log = False 33 | self.output_path = "./output" 34 | self.log_path = "./log" 35 | self.golden_tb_path: str | None = None 36 | self.golden_rtl_blackbox_path: str | None = None 37 | self.tb_gen: TBGenerator | None = None 38 | self.rtl_gen: RTLGenerator | None = None 39 | self.sim_reviewer: SimReviewer | None = None 40 | self.sim_judge: SimJudge | None = None 41 | self.rtl_edit: RTLEditor | None = None 42 | 43 | def set_output_path(self, output_path: str) -> None: 44 | self.output_path = output_path 45 | 46 | def set_log_path(self, log_path: str) -> None: 47 | self.log_path = log_path 48 | 49 | def set_ablation(self, is_ablation: bool) -> None: 50 | self.is_ablation = is_ablation 51 | 52 | def set_redirect_log(self, new_value: bool) -> None: 53 | self.redirect_log = new_value 54 | if self.redirect_log: 55 | switch_log_to_file() 56 | else: 57 | switch_log_to_stdout() 58 | 59 | def write_output(self, content: str, file_name: str) -> None: 60 | assert self.output_dir_per_run 61 | with open(f"{self.output_dir_per_run}/{file_name}", "w") as f: 62 | f.write(content) 63 | 64 | def run_instance(self, spec: str) -> Tuple[bool, str]: 65 | """ 66 | Run a single instance of the benchmark 67 | Return value: 68 | - is_pass: bool, whether the instance passes the golden testbench 69 | - rtl_code: str, the generated RTL code 70 | """ 71 | assert self.tb_gen 72 | assert self.rtl_gen 73 | assert self.sim_reviewer 74 | assert self.sim_judge 75 | assert self.rtl_edit 76 | 77 | self.tb_gen.reset() 78 | self.tb_gen.set_golden_tb_path(self.golden_tb_path) 79 | if not self.golden_tb_path: 80 | logger.info("No golden testbench provided") 81 | testbench, 
interface = self.tb_gen.chat(spec) 82 | logger.info("Initial tb:") 83 | logger.info(testbench) 84 | logger.info("Initial if:") 85 | logger.info(interface) 86 | self.write_output(testbench, "tb.sv") 87 | self.write_output(interface, "if.sv") 88 | self.rtl_gen.reset() 89 | logger.info(spec) 90 | 91 | is_syntax_pass, rtl_code = self.rtl_gen.chat( 92 | input_spec=spec, 93 | testbench=testbench, 94 | interface=interface, 95 | rtl_path=os.path.join(self.output_dir_per_run, "rtl.sv"), 96 | ) 97 | if not is_syntax_pass: 98 | return False, rtl_code 99 | self.write_output(rtl_code, "rtl.sv") 100 | logger.info("Initial rtl:") 101 | logger.info(rtl_code) 102 | 103 | tb_need_fix = True 104 | rtl_need_fix = True 105 | sim_log = "" 106 | for i in range(self.sim_max_retry): 107 | # run simulation judge, overwrite is_sim_pass 108 | is_sim_pass, sim_mismatch_cnt, sim_log = self.sim_reviewer.review() 109 | if is_sim_pass: 110 | tb_need_fix = False 111 | rtl_need_fix = False 112 | break 113 | self.sim_judge.reset() 114 | tb_need_fix = self.sim_judge.chat(spec, sim_log, rtl_code, testbench) 115 | if tb_need_fix: 116 | self.tb_gen.reset() 117 | if i == 0: 118 | self.tb_gen.gen_display_queue = False 119 | logger.info("Fallback from display queue to display moment") 120 | else: 121 | self.tb_gen.set_failed_trial(sim_log, rtl_code, testbench) 122 | 123 | testbench, _ = self.tb_gen.chat(spec) 124 | self.write_output(testbench, "tb.sv") 125 | logger.info("Revised tb:") 126 | logger.info(testbench) 127 | else: 128 | break 129 | 130 | assert not tb_need_fix, f"tb_need_fix should be False. sim_log: {sim_log}" 131 | 132 | candidates_info: List[Tuple[str, int, str]] = [] 133 | if rtl_need_fix: 134 | # Candidates Generation 135 | assert ( 136 | sim_mismatch_cnt > 0 137 | ), f"rtl_need_fix should be True only when sim_mismatch_cnt > 0. 
sim_log: {sim_log}" 138 | self.rtl_gen.reset() 139 | candidates = [ 140 | self.rtl_gen.chat( 141 | input_spec=spec, 142 | testbench=testbench, 143 | interface=interface, 144 | rtl_path=os.path.join(self.output_dir_per_run, "rtl.sv"), 145 | enable_cache=True, 146 | ) 147 | ] # Write Cache 148 | if self.rtl_max_candidates > 1: 149 | candidates += self.rtl_gen.gen_candidates( 150 | input_spec=spec, 151 | testbench=testbench, 152 | interface=interface, 153 | rtl_path=os.path.join(self.output_dir_per_run, "rtl.sv"), 154 | candidates_num=self.rtl_max_candidates - 1, 155 | enable_cache=True, 156 | ) 157 | for i in range(self.rtl_max_candidates): 158 | logger.info( 159 | f"Candidate generation: round {i + 1} / {self.rtl_max_candidates}" 160 | ) 161 | is_syntax_pass_candiate, rtl_code_candidate = candidates[i] 162 | if not is_syntax_pass_candiate: 163 | continue 164 | self.write_output(rtl_code_candidate, "rtl.sv") 165 | is_sim_pass_candidate, sim_mismatch_cnt_candidate, sim_log_candidate = ( 166 | self.sim_reviewer.review() 167 | ) 168 | if is_sim_pass_candidate: 169 | rtl_code = rtl_code_candidate 170 | sim_mismatch_cnt = sim_mismatch_cnt_candidate 171 | sim_log = sim_log_candidate 172 | rtl_need_fix = False 173 | break 174 | candidates_info.append( 175 | (rtl_code_candidate, sim_mismatch_cnt_candidate, sim_log_candidate) 176 | ) 177 | 178 | candidates_info.sort(key=lambda x: x[1]) 179 | candidates_info_unique_sign = set() 180 | candidates_info_unique = [] 181 | for candidate in candidates_info: 182 | if candidate[1] not in candidates_info_unique_sign: 183 | candidates_info_unique_sign.add(candidate[1]) 184 | candidates_info_unique.append(candidate) 185 | 186 | if rtl_need_fix: 187 | # Editor iteration 188 | for i in range(self.rtl_selected_candidates): 189 | logger.info( 190 | f"Selected candidate: round {i + 1} / {self.rtl_selected_candidates}" 191 | ) 192 | i = i % len(candidates_info_unique) 193 | rtl_code, sim_mismatch_cnt, sim_log = candidates_info_unique[i] 194 | 
with open(f"{self.output_dir_per_run}/rtl.sv", "w") as f: 195 | f.write(rtl_code) 196 | self.rtl_edit.reset() 197 | is_sim_pass, rtl_code = self.rtl_edit.chat( 198 | spec=spec, 199 | output_dir_per_run=self.output_dir_per_run, 200 | sim_failed_log=sim_log, 201 | sim_mismatch_cnt=sim_mismatch_cnt, 202 | ) 203 | if is_sim_pass: 204 | rtl_need_fix = False 205 | break 206 | 207 | if not is_sim_pass: # Run if keep failing before last try 208 | is_sim_pass, _, _ = self.sim_reviewer.review() 209 | 210 | return is_sim_pass, rtl_code 211 | 212 | def run_instance_ablation(self, spec: str) -> Tuple[bool, str]: 213 | """ 214 | Run a single instance of the benchmark in ablation mode 215 | Return value: 216 | - is_pass: bool, whether the instance passes the golden testbench 217 | - rtl_code: str, the generated RTL code 218 | """ 219 | assert self.rtl_gen 220 | 221 | self.rtl_gen.reset() 222 | logger.info(spec) 223 | # Current ablation: only run RTL generation with syntax check 224 | is_syntax_pass, rtl_code = self.rtl_gen.ablation_chat( 225 | input_spec=spec, rtl_path=os.path.join(self.output_dir_per_run, "rtl.sv") 226 | ) 227 | self.write_output(rtl_code, "rtl.sv") 228 | return is_syntax_pass, rtl_code 229 | 230 | def _run(self, spec: str) -> Tuple[bool, str]: 231 | try: 232 | if os.path.exists(f"{self.output_dir_per_run}/properly_finished.tag"): 233 | os.remove(f"{self.output_dir_per_run}/properly_finished.tag") 234 | self.token_counter.reset() 235 | self.sim_reviewer = SimReviewer( 236 | self.output_dir_per_run, 237 | self.golden_rtl_blackbox_path, 238 | ) 239 | self.rtl_gen = RTLGenerator(self.token_counter) 240 | self.tb_gen = TBGenerator(self.token_counter) 241 | self.sim_judge = SimJudge(self.token_counter) 242 | self.rtl_edit = RTLEditor( 243 | self.token_counter, sim_reviewer=self.sim_reviewer 244 | ) 245 | ret = ( 246 | self.run_instance(spec) 247 | if not self.is_ablation 248 | else self.run_instance_ablation(spec) 249 | ) 250 | self.token_counter.log_token_stats() 
251 | with open(f"{self.output_dir_per_run}/properly_finished.tag", "w") as f: 252 | f.write("1") 253 | except Exception: 254 | exc_info = sys.exc_info() 255 | traceback.print_exception(*exc_info) 256 | ret = False, f"Exception: {exc_info[1]}" 257 | return ret 258 | 259 | def run( 260 | self, 261 | benchmark_type_name: str, 262 | task_id: str, 263 | spec: str, 264 | golden_tb_path: str | None = None, 265 | golden_rtl_blackbox_path: str | None = None, 266 | ) -> Tuple[bool, str]: 267 | self.golden_tb_path = golden_tb_path 268 | self.golden_rtl_blackbox_path = golden_rtl_blackbox_path 269 | log_dir_per_run = f"{self.log_path}/{benchmark_type_name}_{task_id}" 270 | self.output_dir_per_run = f"{self.output_path}/{benchmark_type_name}_{task_id}" 271 | os.makedirs(self.output_path, exist_ok=True) 272 | os.makedirs(self.output_dir_per_run, exist_ok=True) 273 | set_log_dir(log_dir_per_run) 274 | if self.redirect_log: 275 | with open(f"{log_dir_per_run}/mage_rtl.log", "w") as f: 276 | sys.stdout = f 277 | sys.stderr = f 278 | result = self._run(spec) 279 | sys.stdout = sys.__stdout__ 280 | sys.stderr = sys.__stderr__ 281 | else: 282 | result = self._run(spec) 283 | # Redirect log contains format with rich text. 284 | # Provide a rich-free version for log parsing or less viewing. 
285 | if self.redirect_log: 286 | with open(f"{log_dir_per_run}/mage_rtl.log", "r") as f: 287 | content = f.read() 288 | content = re.sub(r"\[.*?m", "", content) 289 | with open(f"{log_dir_per_run}/mage_rtl_rich_free.log", "w") as f: 290 | f.write(content) 291 | return result 292 | -------------------------------------------------------------------------------- /src/mage/bash_tools.py: -------------------------------------------------------------------------------- 1 | import json 2 | from subprocess import PIPE, Popen, TimeoutExpired 3 | from typing import Tuple 4 | 5 | from pydantic import BaseModel 6 | 7 | from .log_utils import get_logger 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | class CommandResult(BaseModel): 13 | stdout: str 14 | stderr: str 15 | 16 | 17 | def run_bash_command(cmd: str, timeout: float | None = None) -> Tuple[bool, str]: 18 | logger.info(f"Running command: {cmd}") 19 | process = Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE, text=True) 20 | try: 21 | stdout, stderr = process.communicate( 22 | timeout=timeout 23 | ) # Set your desired timeout in seconds 24 | except TimeoutExpired: 25 | process.kill() 26 | err_msg = f"Timeout {timeout}s reached." 
27 | return ( 28 | False, 29 | json.dumps(CommandResult(stdout="", stderr=err_msg).model_dump(), indent=4), 30 | ) 31 | return ( 32 | process.returncode == 0, 33 | json.dumps(CommandResult(stdout=stdout, stderr=stderr).model_dump(), indent=4), 34 | ) 35 | -------------------------------------------------------------------------------- /src/mage/benchmark_read_helper.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | from enum import Enum 5 | from typing import Dict, Tuple 6 | 7 | from .log_utils import get_logger 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | class TypeBenchmark(Enum): 13 | VERILOG_EVAL_V1 = 1 14 | VERILOG_EVAL_V2 = 2 15 | 16 | 17 | class TypeBenchmarkFile(Enum): 18 | SPEC = 0 19 | TEST_PATH = 1 20 | GOLDEN_PATH = 2 21 | 22 | 23 | def load_json(filename): 24 | des_data = [] 25 | with open(filename, "r") as f: 26 | for line in f: 27 | data = json.loads(line) 28 | des_data.append(data) 29 | return des_data 30 | 31 | 32 | def get_benchmark_contents( 33 | benchmark_type: TypeBenchmark, 34 | file_type: TypeBenchmarkFile, 35 | benchmark_repo: str, 36 | filter_instance: str, 37 | ) -> Dict[str, str]: 38 | """ 39 | Get Dict of {problem_name: problem_content/testbench_content} for given benchmark 40 | """ 41 | if ( 42 | benchmark_type == TypeBenchmark.VERILOG_EVAL_V1 43 | or benchmark_type == TypeBenchmark.VERILOG_EVAL_V2 44 | ): 45 | folder = os.path.join( 46 | benchmark_repo, 47 | ( 48 | "dataset_code-complete-iccad2023" 49 | if benchmark_type == TypeBenchmark.VERILOG_EVAL_V1 50 | else "dataset_spec-to-rtl" 51 | ), 52 | ) 53 | files = os.listdir(folder) 54 | files.sort() 55 | re_str = r"$^" # dummy 56 | if file_type == TypeBenchmarkFile.SPEC: 57 | re_str = r"(.*)_prompt.txt" 58 | elif file_type == TypeBenchmarkFile.TEST_PATH: 59 | re_str = r"(.*)_test.sv" 60 | elif file_type == TypeBenchmarkFile.GOLDEN_PATH: 61 | re_str = r"(.*)_ref.sv" 62 | else: 63 | raise 
ValueError(f"Invalid file_type: {file_type}") 64 | 65 | def is_target(file_name: str) -> Tuple[str, str] | None: 66 | full_path = os.path.join(folder, file_name) 67 | if not os.path.isfile(full_path): 68 | return None 69 | m = re.match(re_str, file_name) 70 | if not m: 71 | return None 72 | if not re.match(filter_instance, m[1]): 73 | return None 74 | return (m[1], full_path) 75 | 76 | ret = {} 77 | 78 | for file in files: 79 | 80 | p = is_target(file) 81 | if not p: 82 | continue 83 | 84 | if file_type == TypeBenchmarkFile.SPEC: 85 | with open(p[1], "r") as f: 86 | ret[p[0]] = f.read() 87 | elif ( 88 | file_type == TypeBenchmarkFile.TEST_PATH 89 | or file_type == TypeBenchmarkFile.GOLDEN_PATH 90 | ): 91 | ret[p[0]] = p[1] 92 | 93 | return ret 94 | raise ValueError(f"Invalid benchmark_type: {benchmark_type}") 95 | -------------------------------------------------------------------------------- /src/mage/converage/LLMGuidance.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | 11 | #include "LLMGuidance.h" 12 | 13 | using namespace std; 14 | 15 | 16 | LLMGuidance::LLMGuidance(LLMGuidanceConfig config){ 17 | 18 | dut_path_ = config.dut_path; 19 | dut_desc_path_ = config.dut_path + ".des"; 20 | dut_inst_path_ = config.dut_path + ".inst"; 21 | cov_path_ = config.cov_path; 22 | history_dir_path_ = config.history_dir_path; 23 | gpt_output_path_ = config.gpt_output_path; 24 | gpt_input_path_ = config.gpt_input_path; 25 | default_temperature_ = config.temperature; 26 | iter_cnt_max_ = config.iter_cnt_max; 27 | cov_rpst_pattern_ = config.cov_rpst_pattern; 28 | use_dut_des_ = config.use_dut_des; 29 | use_dut_inst_ = config.use_dut_inst; 30 | 31 | 32 | cov_path_total_ = cov_path_ + ".total"; 33 | 34 | // build history directory 35 | string cmd = "mkdir " + history_dir_path_; 36 | system(cmd.c_str()); 37 | 38 | // extract input signals from 
dut 39 | genSignalPrompt(); 40 | 41 | // launch gpt robot and connect with the python process 42 | system("cd ../llm-guidance && bash setup.sh"); 43 | string cmd1 = "python ../llm-guidance/src/RunGPT.py "; 44 | cmd1 += " -i " + gpt_input_path_; 45 | cmd1 += " -o " + gpt_output_path_; 46 | cmd1 += " &"; 47 | system(cmd1.c_str()); 48 | 49 | pipe_in.open("../llm-guidance/g2v"); 50 | pipe_out.open("../llm-guidance/v2g"); 51 | 52 | // Test GPT Python 53 | pipe_out << "hello"<iter_cnt_max_) { 74 | return 0; 75 | } 76 | 77 | return 1; 78 | } 79 | 80 | vector LLMGuidance::getBitInput() { 81 | 82 | // select strategy 83 | string input_str = covStrategy1(); 84 | // string input_str = genInput4undirectedCov(iter_cnt_==1, ""); 85 | 86 | std::vector bitstream; 87 | 88 | // bool fmt_correct = checkGPTAnswerFormat(input_str); 89 | 90 | // if(!fmt_correct) { 91 | // // throw std::runtime_error("GPT return wrong format\n"); 92 | // cout<<"iter "< save prompt=>input=>coverage of this iteration to history directory 111 | writeHistory(); 112 | 113 | // 2> keep total coverage 114 | 115 | ifstream in_file(cov_path_total_); 116 | string cmd; 117 | if (!in_file.is_open()) { 118 | // The first iteration 119 | cmd = "verilator_coverage -write " + cov_path_total_ + " " + cov_path_; 120 | } else { 121 | // cmd: verilator_coverage -write coverage.dat.total coverage.dat.total coverage.dat 122 | // coverage.dat.total = coverage.dat.total + coverage.dat 123 | cmd = "verilator_coverage -write " + cov_path_total_ + " " + cov_path_total_ + " " + cov_path_; 124 | } 125 | system(cmd.c_str()); 126 | 127 | // 3> update coverage information 128 | pair cur_cov_pair = getCoverageNum(cov_path_total_); 129 | int cur_cov_num = cur_cov_pair.first; 130 | int total_cov_num = cur_cov_pair.second; 131 | 132 | if(cur_cov_num>cur_covered_num_) { 133 | covered_stop_iter_num_ = 0; 134 | cur_covered_num_ = cur_cov_num; 135 | } else { 136 | covered_stop_iter_num_++; 137 | } 138 | 139 | // print log 140 | ofstream 
outfile(history_dir_path_+"/cov.log", std::ios_base::app); 141 | if (!outfile) { 142 | throw std::runtime_error("Could not open file: " + cov_path_); 143 | } 144 | outfile<<"LLM: Iter = "<< iter_cnt_<<" Clk Cycle = "< string 158 | string LLMGuidance::sendMsg2GPT(string msg, float temperature, bool forget_flag) { 159 | 160 | // Set temprature 161 | pipe_out << "temperature" < tell gpt which signals this DUT has 215 | // 2> tell gpt what is the format of its generated input 216 | void LLMGuidance::genSignalPrompt() { 217 | 218 | // extract signals from pyverilog generated files 219 | string filePath = "../input-signals.txt"; 220 | std::ifstream inFile(filePath); 221 | std::vector> data; 222 | 223 | if (!inFile.is_open()) { 224 | throw std::runtime_error("Could not open file: " + filePath); 225 | } 226 | 227 | std::string line; 228 | while (std::getline(inFile, line)) { 229 | std::istringstream iss(line); 230 | std::string inputSignal; 231 | int width; 232 | 233 | if (!(iss >> inputSignal >> width)) { 234 | throw std::runtime_error("Error parsing line: " + line); 235 | } 236 | 237 | data.emplace_back(inputSignal, width); 238 | } 239 | input_signals_ = data; 240 | if(input_signals_.size()==0) { 241 | throw std::runtime_error("No input signals parsing! Wrong DUT! 
"); 242 | } 243 | 244 | // generate prompt of input signals 245 | string signals_prompt = "DUT has the following input signals\n"; 246 | 247 | std::stringstream ss; 248 | for (const auto& pair : data) { 249 | ss << "Input signal: " << pair.first << "; Width: " << pair.second << "\n"; 250 | } 251 | signals_prompt += ss.str(); 252 | 253 | input_signal_prompt_ = signals_prompt; 254 | 255 | // generate gpt answer format 256 | // this answer is gpt-generatd input for DUT 257 | string answer_prompt = "Please return the answer with the following format\n"; 258 | answer_prompt += "Each value for signal should in binary format, such as 011..., the binary width should equal to signal width\n"; 259 | 260 | std::stringstream ss1; 261 | ss1 << "clk=1:"; 262 | for (const auto& pair : data) { 263 | ss1 << pair.first << "=x;"; 264 | } 265 | ss1 <<"\n"; 266 | ss1 << "clk=2:"; 267 | for (const auto& pair : data) { 268 | ss1 << pair.first << "=x;"; 269 | } 270 | ss1 <<"\n"; 271 | ss1 << "clk=3:...\n "; 272 | answer_prompt += ss1.str(); 273 | 274 | answer_prompt += "You are only allowed to response strictly in the above format and DO NOT explain any other extra information\n"; 275 | 276 | answer_format_prompt_ = answer_prompt; 277 | 278 | 279 | } 280 | 281 | 282 | // Transform the answer string to bits 283 | vector LLMGuidance::transGPTAnswer2Bits(string answer) { 284 | istringstream iss(answer); 285 | string line; 286 | vector result; 287 | while (getline(iss, line)) { 288 | istringstream lineStream(line); 289 | string token; 290 | 291 | // Skip "clk=x:" 292 | getline(lineStream, token, '='); 293 | if (token != "clk") continue; 294 | getline(lineStream, token, ':'); 295 | 296 | // Parse signal lines 297 | for (const auto& signal : input_signals_) { 298 | getline(lineStream, token, '='); 299 | getline(lineStream, token, ';'); 300 | string binStr = token; 301 | 302 | for (char c : binStr) { 303 | result.push_back(c == '1'); 304 | } 305 | } 306 | 307 | clk_cycle_num_++; 308 | } 309 | 
310 | return result; 311 | } 312 | 313 | // Check GPT feedback answer obey our format rules 314 | bool LLMGuidance::checkGPTAnswerFormat(string answer) { 315 | istringstream iss(answer); 316 | string line; 317 | // int clkCount = 1; 318 | 319 | while (getline(iss, line)) { 320 | istringstream lineStream(line); 321 | string token; 322 | 323 | // Check clk 324 | getline(lineStream, token, '='); 325 | if (token != "clk") continue; 326 | 327 | getline(lineStream, token, ':'); 328 | // if (stoi(token) != clkCount) return false; 329 | // ++clkCount; 330 | 331 | // Check signal lines 332 | for (const auto& signal : input_signals_) { 333 | getline(lineStream, token, '='); 334 | if (token != signal.first) return false; 335 | 336 | getline(lineStream, token, ';'); 337 | string binStr = token; 338 | if (binStr.size() != signal.second) { 339 | return false; 340 | } 341 | for (char c : binStr) { 342 | if (c != '0' && c != '1') 343 | return false; 344 | } 345 | } 346 | } 347 | return true; 348 | } 349 | 350 | vector LLMGuidance::transJsonGPTAnswer2Bits(string answer) { 351 | 352 | // Find the start and end of the JSON substring 353 | std::string start_delimiter = R"([)"; 354 | std::string end_delimiter = R"(])"; 355 | std::size_t start_pos = answer.find(start_delimiter); 356 | std::size_t end_pos = answer.find(end_delimiter, start_pos); 357 | std::string json_str = answer.substr(start_pos, end_pos - start_pos + end_delimiter.length()); 358 | 359 | // cout< result; 363 | 364 | for (nlohmann::json::iterator it = jsonObj.begin(); it != jsonObj.end(); ++it) { 365 | // pop the input in json in clk order 366 | for(auto p: input_signals_) { 367 | bool flag = false; 368 | // pop the signals in 1 clk cycle 369 | for (nlohmann::json::iterator obj_it = it->begin(); obj_it != it->end(); ++obj_it) { 370 | string bin_str = obj_it.value(); 371 | if(obj_it.key() == p.first) { 372 | flag = true; 373 | if(bin_str[0]=='x') { 374 | // llm output is 'x' value 375 | for(int i=0;i(); // return a null 
vector to indicate a answer wrong format situation 382 | // throw std::runtime_error("Get GPT answer in wrong format: Width Mismatch"); 383 | } 384 | for (char c : bin_str) { 385 | result.push_back(c == '1'); 386 | } 387 | break; 388 | } 389 | } 390 | if(!flag) { 391 | return vector(); // return a null vector to indicate a answer wrong format situation 392 | // throw std::runtime_error("Get GPT answer in wrong format: Signal Miss"); 393 | } 394 | } 395 | clk_cycle_num_++; 396 | 397 | //std::cout << "------------------------\n"; 398 | } 399 | 400 | return result; 401 | } 402 | 403 | 404 | 405 | void LLMGuidance::writeHistory() { 406 | // write prompt to history 407 | string prompt_history_file = history_dir_path_ + "/prompt." + to_string(iter_cnt_); 408 | ofstream outfile(prompt_history_file); 409 | if (!outfile) { 410 | throw std::runtime_error("Could not open file: " + prompt_history_file); 411 | } 412 | for(string s:prompt_cur_iter_) { outfile << s; } 413 | outfile.close(); 414 | 415 | // write gpt answer to history 416 | string input_history_file = history_dir_path_ + "/answer." + to_string(iter_cnt_); 417 | ofstream outfile1(input_history_file); 418 | if (!outfile1) { 419 | throw std::runtime_error("Could not open file: " + input_history_file); 420 | } 421 | for(string s:answer_cur_iter_) {outfile1 << s;} 422 | outfile1.close(); 423 | 424 | // write coverage.dat to history 425 | ifstream in_file_cov(cov_path_); 426 | if (!in_file_cov.is_open()) { 427 | throw std::runtime_error("Could not open file: " + cov_path_); 428 | } else { 429 | // cmd: cp coverage.dat history/coverage.dat.x 430 | string cmd = "cp " + cov_path_ + " " + history_dir_path_ + "/coverage.dat." + to_string(iter_cnt_); 431 | system(cmd.c_str()); 432 | } 433 | } 434 | 435 | string LLMGuidance::getGptOutputJsonFormat() { 436 | string task = R"(Generate the input sequence in binary format, with the binary width matching the width of the respective signal. 
437 | If a DUT has a clk signal, a REQUEST1 signal and a REQUEST2 signal, the input sequence should be presented in the following JSON format:)"; 438 | string json_string = R"([ 439 | {"clk":"1","REQUEST1":"x","REQUEST2":"x"}, 440 | {"clk":"2","REQUEST1":"x","REQUEST2":"x"}, 441 | {"clk":"3","REQUEST1":"x","REQUEST2":"x"} 442 | ] 443 | )"; 444 | 445 | return task + json_string; 446 | } 447 | 448 | 449 | string LLMGuidance::covStrategy1() { 450 | 451 | string input_str = genInput4undirectedCov(iter_cnt_==1, default_temperature_ + covered_stop_iter_num_ * 0.1); 452 | 453 | return input_str; 454 | 455 | } 456 | -------------------------------------------------------------------------------- /src/mage/converage/LLMGuidance.h: -------------------------------------------------------------------------------- 1 | 2 | #pragma once 3 | 4 | #include 5 | #include 6 | #include "../../src-basic/Guidance.h" 7 | using namespace std; 8 | 9 | typedef struct LLMGuidanceConfig { 10 | string dut_path; 11 | string cov_path; 12 | string history_dir_path; 13 | string gpt_input_path; 14 | string gpt_output_path; 15 | int iter_cnt_max; 16 | float temperature; 17 | int cov_rpst_pattern; // 0 llm rpst; 1 num rpst; 2 original dat rpst 18 | bool use_dut_des; 19 | bool use_dut_inst; 20 | } LLMGuidanceConfig; 21 | 22 | class LLMGuidance : public Guidance { 23 | public: 24 | 25 | LLMGuidance(LLMGuidanceConfig config); 26 | ~LLMGuidance(); 27 | 28 | // If Guidance prepared for a new loop, return true 29 | int waitForInput(); 30 | 31 | // Get input from Guidance 32 | std::vector getBitInput(); 33 | 34 | // Send hardware simulation coverage as feedback to Guidance 35 | int sendCovFeedback(); 36 | 37 | protected: 38 | // send msg to gpt and return answer 39 | string sendMsg2GPT(string msg, float temperature, bool forget_flag); 40 | 41 | // We use string to represent coverage 42 | // In undirectedCov, we sent current coverage to gpt and hope to cover the left coverage 43 | // If this is the first time 
to generate input, we try to cover at most coverage 44 | virtual string genInput4undirectedCov(bool first_flag = false, float temperature = 0.5) = 0; 45 | 46 | // In directedCov, we sent our target to gpt and hope to cover these coverpoints 47 | virtual string genInput4directedCov(string obj_cov = "") = 0; 48 | 49 | // Transform original coverage.dat generate by verilator to gpt-readable coverage representation 50 | virtual string undirectedCovRpst() = 0; 51 | 52 | // Transform target separate coverage to gpt-readable coverage representation 53 | virtual string directedCovRpst() = 0; 54 | 55 | // Generate DUT input signals with pyverilog 56 | // 1> tell gpt which signals this DUT has 57 | // 2> tell gpt what is the format of its generated input 58 | void genSignalPrompt(); 59 | 60 | string getGptOutputJsonFormat(); 61 | 62 | // old format not json 63 | // clk:x; input1: x, input2:x ... 64 | // Check whether GPT feedback answer obey our format rules 65 | vector transGPTAnswer2Bits(string answer); 66 | // Transform the answer string to bits 67 | bool checkGPTAnswerFormat(string answer); 68 | 69 | // json format translator 70 | vector transJsonGPTAnswer2Bits(string answer); 71 | 72 | 73 | // Log 74 | void writeHistory(); 75 | 76 | // Test generation strategies 77 | string covStrategy1(); 78 | 79 | 80 | 81 | protected: 82 | // DUT Information 83 | // path of dut verilog file 84 | string dut_path_; 85 | // description of DUT 86 | string dut_desc_path_; 87 | // description of DUT 88 | string dut_inst_path_; 89 | // path that verilator write back coverage result, it is the cov data of one QA iteration 90 | string cov_path_; 91 | // path that save the total coverage from test beginning 92 | string cov_path_total_; 93 | // DUT input signals names and width 94 | vector> input_signals_; 95 | 96 | 97 | // GPT Model Information 98 | // path that python read prompt and write gpt feedback answer 99 | string gpt_output_path_; 100 | string gpt_input_path_; 101 | 102 | // llm 
parameter 103 | float default_temperature_; 104 | 105 | // Prompt Information 106 | // signal prompt to tell gpt which signals dut has 107 | string input_signal_prompt_; 108 | // signal prompt to tell gpt the generated input format 109 | string answer_format_prompt_; 110 | // Prompt Setting 111 | // coverage report sent to llm 112 | // 0 llm-readable cov rpst; 113 | // 1 verilator annotated cov rpst; 114 | // 2 verilator-provided original dat rpst 115 | int cov_rpst_pattern_; 116 | // whether use dut description 117 | bool use_dut_des_; 118 | // whether use manual instruction 119 | bool use_dut_inst_; 120 | 121 | // History Information 122 | // pair.first: input (in our format) 123 | // pair.second: true simulation cov (in origin coverage.dat format) 124 | vector> history_cov; 125 | string history_dir_path_; 126 | // record prompt gpt-generated input for this iteration 127 | vector answer_cur_iter_; 128 | vector prompt_cur_iter_; 129 | 130 | // Statistic Information 131 | int iter_cnt_ = 0; 132 | int iter_cnt_max_; 133 | int cur_covered_num_ = 0; 134 | int covered_stop_iter_num_ = 0; //record iteration times that total coverage cannot improve 135 | int clk_cycle_num_ = 0; 136 | 137 | private: 138 | // connect with gpt python 139 | std::ifstream pipe_in; 140 | std::ofstream pipe_out; 141 | }; 142 | -------------------------------------------------------------------------------- /src/mage/converage/LLMGuidance4CodeCov.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "LLMGuidance4CodeCov.h" 7 | 8 | using namespace std; 9 | 10 | LLMGuidance4CodeCov::LLMGuidance4CodeCov(LLMGuidanceConfig config) 11 | : LLMGuidance(config) { 12 | 13 | } 14 | 15 | string LLMGuidance4CodeCov::genInput4undirectedCov(bool first_flag, float temperature) { 16 | // cout<<"==========="<> number; 159 | getline(iss, rest); 160 | 161 | // line with cover points 162 | if(number[0]=='0' || 
import argparse
import os

from openai import OpenAI

# Prefer the standard environment variable; the literal "####" placeholder is
# kept only as a last-resort fallback so unconfigured behavior matches the
# previous revision (which hard-coded it).
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "####"))


def _read_command(pipe) -> str:
    """Read one line from *pipe* and strip a single trailing newline.

    NOTE(fix): the previous code used ``readline()[:-1]``, which silently
    drops the last character of a final line that is not newline-terminated.
    """
    return pipe.readline().rstrip("\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", type=str)
    parser.add_argument("-o", "--output", type=str)
    args = parser.parse_args()

    history = []
    temperature = 0.0

    # model = "gpt-4"
    model = "gpt-4-0125-preview"

    # g2v: answers written back to the C++ harness; v2g: commands from it.
    # Context managers guarantee the FIFOs are closed on every exit path
    # (the previous revision leaked both file objects).
    with open("../llm-guidance/g2v", "w") as pipe_out, open(
        "../llm-guidance/v2g", "r"
    ) as pipe_in:
        while True:
            cmd = _read_command(pipe_in)
            # "" means the writer closed the pipe (EOF); "exit" is an
            # explicit shutdown request. Both end the session.
            if cmd == "" or cmd == "exit":
                break

            if cmd == "hello":
                # Connection self-test handshake.
                pipe_out.write("hello_end\n")
                pipe_out.flush()

            elif cmd == "new":
                # Start a fresh conversation (forget the history).
                history = []
                pipe_out.write("new_end\n")
                pipe_out.flush()

            elif cmd == "temperature":
                # The next pipe line carries the new sampling temperature.
                temperature = float(_read_command(pipe_in))
                pipe_out.write("temperature_end\n")
                pipe_out.flush()

            elif cmd == "prompt":
                # The prompt text is staged in the agreed-upon input file
                # by the C++ side before it sends this command.
                with open(args.input, "r") as file:
                    prompt = file.read()

                history.append({"role": "user", "content": prompt})

                response = client.chat.completions.create(
                    model=model, messages=history, temperature=temperature
                )
                answer = response.choices[0].message.content

                # Hand the answer back through the output file.
                with open(args.output, "w") as f:
                    f.write(answer)

                history.append({"role": "assistant", "content": answer})

                pipe_out.write("prompt_end\n")
                pipe_out.flush()
from llama_index.llms.anthropic import Anthropic
from llama_index.llms.openai import OpenAI
from llama_index.llms.vertex import Vertex
from pydantic import BaseModel

from .log_utils import get_logger
from .utils import VertexAnthropicWithCredentials

logger = get_logger(__name__)


class Config:
    """Layered key lookup: cfg file > environment variables > fallbacks."""

    def __init__(self, file_path=None):
        self.file_path = file_path
        self.file_config = {}
        if self.file_path and os.path.isfile(self.file_path):
            self.file_config = config.Config(self.file_path)
        # Defaults used only when a key is absent everywhere else.
        self.fallback_config = {}
        self.fallback_config["OPENAI_API_BASE_URL"] = ""

    def __getitem__(self, index):
        # Values in key.cfg have priority over env variables.
        if index in self.file_config:
            return self.file_config[index]
        if index in os.environ:
            return os.environ[index]
        if index in self.fallback_config:
            return self.fallback_config[index]
        raise KeyError(
            f"Cannot find {index} in either cfg file '{self.file_path}' or env variables"
        )


def get_llm(**kwargs) -> LLM:
    """Build an LLM client for the requested provider and smoke-test it.

    Expected kwargs: ``cfg_path``, ``provider``, ``model``, ``max_token``.
    Raises ValueError for an unknown provider, FileNotFoundError for a
    missing service-account file, and a wrapped Exception when the client
    cannot be constructed or cannot complete a trivial request.
    """
    cfg = Config(kwargs["cfg_path"])
    provider: str = kwargs["provider"]
    provider = provider.lower()
    # NOTE(fix): the original compared the *raw* kwargs["provider"] in every
    # elif branch, so the .lower() normalization above only applied to the
    # "anthropic" branch and e.g. provider="OpenAI" was rejected as invalid.
    if provider == "anthropic":
        try:
            llm: LLM = Anthropic(
                model=kwargs["model"],
                api_key=cfg["ANTHROPIC_API_KEY"],
                max_tokens=kwargs["max_token"],
            )

        except Exception as e:
            raise Exception(f"gen_config: Failed to get {provider} LLM") from e
    elif provider == "openai":
        try:
            llm: LLM = OpenAI(
                model=kwargs["model"],
                api_key=cfg["OPENAI_API_KEY"],
                max_tokens=kwargs["max_token"],
            )

        except Exception as e:
            raise Exception(f"gen_config: Failed to get {provider} LLM") from e
    elif provider == "vertex":
        logger.warning(
            "Support of Vertex Gemini LLMs is still in experimental stage, use with caution"
        )
        service_account_path = os.path.expanduser(cfg["VERTEX_SERVICE_ACCOUNT_PATH"])
        if not os.path.exists(service_account_path):
            raise FileNotFoundError(
                f"Google Cloud Service Account file not found: {service_account_path}"
            )
        try:
            credentials = service_account.Credentials.from_service_account_file(
                service_account_path
            )
            llm: LLM = Vertex(
                model=kwargs["model"],
                project=credentials.project_id,
                credentials=credentials,
                max_tokens=kwargs["max_token"],
            )

        except Exception as e:
            raise Exception(f"gen_config: Failed to get {provider} LLM") from e
    elif provider == "vertexanthropic":
        service_account_path = os.path.expanduser(cfg["VERTEX_SERVICE_ACCOUNT_PATH"])
        if not os.path.exists(service_account_path):
            raise FileNotFoundError(
                f"Google Cloud Service Account file not found: {service_account_path}"
            )
        try:
            credentials = service_account.Credentials.from_service_account_file(
                service_account_path,
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
            llm: LLM = VertexAnthropicWithCredentials(
                model=kwargs["model"],
                project_id=credentials.project_id,
                credentials=credentials,
                region=cfg["VERTEX_REGION"],
                max_tokens=kwargs["max_token"],
            )

        except Exception as e:
            raise Exception(f"gen_config: Failed to get {provider} LLM") from e
    else:
        raise ValueError(f"gen_config: Invalid provider: {provider}")

    # Fail fast if the credentials/model cannot actually serve a request.
    try:
        _ = llm.complete("Say 'Hi'")
    except Exception as e:
        raise Exception(
            f"gen_config: Failed to complete LLM chat for {provider}"
        ) from e

    return llm


class ExperimentSetting(BaseModel):
    """
    Global setting for experiment
    """

    temperature: float = 0.85  # Chat temperature
    top_p: float = 0.95  # Chat top_p


# Single module-level instance shared by all callers.
global_exp_setting = ExperimentSetting()
get_exp_setting() -> ExperimentSetting: 132 | return global_exp_setting 133 | 134 | 135 | def set_exp_setting(temperature: float | None = None, top_p: float | None = None): 136 | if temperature is not None: 137 | global_exp_setting.temperature = temperature 138 | if top_p is not None: 139 | global_exp_setting.top_p = top_p 140 | return global_exp_setting 141 | -------------------------------------------------------------------------------- /src/mage/log_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Dict 4 | 5 | from rich.logging import RichHandler 6 | 7 | logging.basicConfig(level=logging.INFO) 8 | 9 | 10 | class LoggingManager: 11 | def __init__(self): 12 | self.loggers: Dict[str, logging.Logger] = {} 13 | self.current_log_dir = "" 14 | self.use_stdout = True 15 | self.rich_handler = RichHandler( 16 | show_time=bool(os.environ.get("LLM4RTL_LOG_TIME", False)), 17 | show_path=bool(os.environ.get("LLM4RTL_LOG_PATH", False)), 18 | ) 19 | self.rich_handler.setLevel(logging.DEBUG) 20 | 21 | def get_logger(self, name: str) -> logging.Logger: 22 | if name in self.loggers: 23 | return self.loggers[name] 24 | 25 | logger = logging.getLogger(name) 26 | logger.setLevel(logging.DEBUG) 27 | 28 | # Add the handler to the logger 29 | logger.addHandler(self.rich_handler) 30 | logger.propagate = False 31 | 32 | # Store the logger in our dictionary 33 | self.loggers[name] = logger 34 | 35 | return logger 36 | 37 | def set_log_dir(self, new_dir: str) -> None: 38 | if self.current_log_dir == new_dir: 39 | return 40 | self.current_log_dir = new_dir 41 | 42 | # Ensure the new directory exists 43 | os.makedirs(self.current_log_dir, exist_ok=True) 44 | 45 | if not self.use_stdout: 46 | self._update_handlers() 47 | 48 | def switch_to_file(self) -> None: 49 | if not self.use_stdout: 50 | return 51 | self.use_stdout = False 52 | if self.current_log_dir: 53 | self._update_handlers() 54 | 55 | 
def switch_to_stdout(self) -> None: 56 | if self.use_stdout: 57 | return 58 | self.use_stdout = True 59 | self._update_handlers() 60 | 61 | def _update_handlers(self) -> None: 62 | assert self.current_log_dir and os.path.isdir(self.current_log_dir) 63 | 64 | formatter = logging.Formatter( 65 | "[%(asctime)s - %(name)s - %(levelname)s] %(message)s" 66 | ) 67 | 68 | if self.use_stdout: 69 | for _, logger in self.loggers.items(): 70 | # Remove existing handlers 71 | for handler in logger.handlers[:]: 72 | logger.removeHandler(handler) 73 | logger.addHandler(self.rich_handler) 74 | return 75 | 76 | unified_log_file = os.path.join(self.current_log_dir, f"mage_rtl_total.log") 77 | if os.path.exists(unified_log_file): 78 | os.remove(unified_log_file) 79 | unified_file_handler = logging.FileHandler(unified_log_file) 80 | unified_file_handler.setLevel(logging.DEBUG) 81 | unified_file_handler.setFormatter(formatter) 82 | 83 | for name, logger in self.loggers.items(): 84 | # Remove existing handlers 85 | for handler in logger.handlers[:]: 86 | logger.removeHandler(handler) 87 | 88 | # Add new handler 89 | new_log_file = os.path.join(self.current_log_dir, f"{name}.log") 90 | if os.path.exists(new_log_file): 91 | os.remove(new_log_file) 92 | new_handler = logging.FileHandler(new_log_file) 93 | new_handler.setLevel(logging.DEBUG) 94 | new_handler.setFormatter(formatter) 95 | 96 | logger.addHandler(new_handler) 97 | logger.addHandler(unified_file_handler) 98 | 99 | 100 | # Global LoggingManager instance 101 | logging_manager = LoggingManager() 102 | 103 | 104 | # Convenience functions to match the original API 105 | def get_logger(name: str) -> logging.Logger: 106 | return logging_manager.get_logger(name) 107 | 108 | 109 | def set_log_dir(new_dir: str) -> None: 110 | logging_manager.set_log_dir(new_dir) 111 | 112 | 113 | def switch_log_to_file() -> None: 114 | logging_manager.switch_to_file() 115 | 116 | 117 | def switch_log_to_stdout() -> None: 118 | 
logging_manager.switch_to_stdout() 119 | -------------------------------------------------------------------------------- /src/mage/prompts.py: -------------------------------------------------------------------------------- 1 | RTL_4_SHOT_EXAMPLES = """ 2 | Here are some examples of RTL SystemVerilog code: 3 | Example 1: 4 | 5 | 6 | Implement the SystemVerilog module based on the following description. 7 | Assume that sigals are positive clock/clk triggered unless otherwise stated. 8 | 9 | The module should implement a XOR gate. 10 | 11 | 12 | module TopModule( 13 | input logic in0, 14 | input logic in1, 15 | output logic out 16 | ); 17 | 18 | assign out = in0 ^ in1; 19 | 20 | endmodule 21 | 22 | 23 | Example 2: 24 | 25 | 26 | Implement the SystemVerilog module based on the following description. 27 | Assume that sigals are positive clock/clk triggered unless otherwise stated. 28 | 29 | The module should implement an 8-bit registered incrementer. 30 | The 8-bit input is first registered and then incremented by one on the next cycle. 31 | The reset input is active high synchronous and should reset the output to zero. 32 | 33 | 34 | module TopModule( 35 | input logic clk, 36 | input logic reset, 37 | input logic [7:0] in_, 38 | output logic [7:0] out 39 | ); 40 | 41 | // Sequential logic 42 | logic [7:0] reg_out; 43 | always @( posedge clk ) begin 44 | if ( reset ) 45 | reg_out <= 0; 46 | else 47 | reg_out <= in_; 48 | end 49 | 50 | // Combinational logic 51 | logic [7:0] temp_wire; 52 | always @(*) begin 53 | temp_wire = reg_out + 1; 54 | end 55 | 56 | // Structural connections 57 | assign out = temp_wire; 58 | 59 | endmodule 60 | 61 | 62 | Example 3: 63 | 64 | 65 | Implement the SystemVerilog module based on the following description. 66 | Assume that sigals are positive clock/clk triggered unless otherwise stated. 67 | 68 | The module should implement an n-bit registered incrementer where the bitwidth is specified by the parameter nbits. 
69 | The n-bit input is first registered and then incremented by one on the next cycle. 70 | The reset input is active high synchronous and should reset the output to zero. 71 | 72 | 73 | module TopModule #( 74 | parameter nbits 75 | )( 76 | input logic clk, 77 | input logic reset, 78 | input logic [nbits-1:0] in_, 79 | output logic [nbits-1:0] out 80 | ); 81 | 82 | // Sequential logic 83 | logic [nbits-1:0] reg_out; 84 | always @( posedge clk ) begin 85 | if ( reset ) 86 | reg_out <= 0; 87 | else 88 | reg_out <= in_; 89 | end 90 | 91 | // Combinational logic 92 | logic [nbits-1:0] temp_wire; 93 | always @(*) begin 94 | temp_wire = reg_out + 1; 95 | end 96 | 97 | // Structural connections 98 | assign out = temp_wire; 99 | 100 | endmodule 101 | 102 | 103 | Example 4: 104 | 105 | 106 | Implement the SystemVerilog module based on the following description. 107 | Assume that sigals are positive clock/clk triggered unless otherwise stated. 108 | 109 | Build a finite-state machine that takes as input a serial bit stream, 110 | and outputs a one whenever the bit stream contains two consecutive one's. 111 | The output is one on the cycle _after_ there are two consecutive one's. 112 | The reset input is active high synchronous, 113 | and should reset the finite-state machine to an appropriate initial state. 114 | 115 | 116 | module TopModule( 117 | input logic clk, 118 | input logic reset, 119 | input logic in_, 120 | output logic out 121 | ); 122 | 123 | // State enum 124 | localparam STATE_A = 2'b00; 125 | localparam STATE_B = 2'b01; 126 | localparam STATE_C = 2'b10; 127 | 128 | // State register 129 | logic [1:0] state; 130 | logic [1:0] state_next; 131 | always @(posedge clk) begin 132 | if ( reset ) 133 | state <= STATE_A; 134 | else 135 | state <= state_next; 136 | end 137 | 138 | // Next state combinational logic 139 | always @(*) begin 140 | state_next = state; 141 | case ( state ) 142 | STATE_A: state_next = ( in_ ) ? 
STATE_B : STATE_A; 143 | STATE_B: state_next = ( in_ ) ? STATE_C : STATE_A; 144 | STATE_C: state_next = ( in_ ) ? STATE_C : STATE_A; 145 | endcase 146 | end 147 | 148 | // Output combinational logic 149 | always @(*) begin 150 | out = 1'b0; 151 | case ( state ) 152 | STATE_A: out = 1'b0; 153 | STATE_B: out = 1'b0; 154 | STATE_C: out = 1'b1; 155 | endcase 156 | end 157 | 158 | endmodule 159 | 160 | 161 | """ 162 | 163 | TB_4_SHOT_EXAMPLES = """ 164 | Here are some examples of SystemVerilog testbench code: 165 | Example 1: 166 | 167 | 168 | Implement the SystemVerilog module based on the following description. 169 | Assume that sigals are positive clock/clk triggered unless otherwise stated. 170 | 171 | The module should implement a XOR gate. 172 | 173 | 174 | module TopModule( 175 | input logic in0, 176 | input logic in1, 177 | output logic out 178 | ); 179 | 180 | 181 | module TopModule_tb(); 182 | // Signal declarations 183 | logic in0; 184 | logic in1; 185 | logic out; 186 | logic expected_out; 187 | int mismatch_count; 188 | 189 | // Instantiate the Device Under Test (DUT) 190 | TopModule dut ( 191 | .in0(in0), 192 | .in1(in1), 193 | .out(out) 194 | ); 195 | 196 | // Expected output calculation 197 | assign expected_out = in0 ^ in1; 198 | 199 | // Initialize signals 200 | initial begin 201 | // Initialize signals 202 | in0 = 0; 203 | in1 = 0; 204 | mismatch_count = 0; 205 | 206 | // Test all input combinations 207 | for (int i = 0; i < 4; i++) begin 208 | {in0, in1} = i; 209 | #10; // Wait for outputs to settle 210 | 211 | // Check for mismatches 212 | if (out !== expected_out) begin 213 | $display("Mismatch at time %0t:", $time); 214 | $display(" Inputs: in0=%b, in1=%b", in0, in1); 215 | $display(" Expected output: %b, Actual output: %b", expected_out, out); 216 | mismatch_count++; 217 | end else begin 218 | $display("Match at time %0t:", $time); 219 | $display(" Inputs: in0=%b, in1=%b", in0, in1); 220 | $display(" Output: %b", out); 221 | end 222 | end 223 | 
224 | // Display final simulation results 225 | #10; 226 | if (mismatch_count == 0) 227 | $display("SIMULATION PASSED"); 228 | else 229 | $display("SIMULATION FAILED - %0d mismatches detected", mismatch_count); 230 | 231 | $finish; 232 | end 233 | 234 | // Optional: Generate VCD file for waveform viewing 235 | initial begin 236 | $dumpfile("xor_test.vcd"); 237 | $dumpvars(0, TopModule_tb); 238 | end 239 | endmodule 240 | 241 | 242 | Example 2: 243 | 244 | 245 | Implement the SystemVerilog module based on the following description. 246 | Assume that sigals are positive clock/clk triggered unless otherwise stated. 247 | 248 | The module should implement an 8-bit registered incrementer. 249 | The 8-bit input is first registered and then incremented by one on the next cycle. 250 | The reset input is active high synchronous and should reset the output to zero. 251 | 252 | 253 | module TopModule( 254 | input logic clk, 255 | input logic reset, 256 | input logic [7:0] in_, 257 | output logic [7:0] out 258 | ); 259 | 260 | 261 | module TopModule_tb(); 262 | // Signal declarations 263 | logic clk; 264 | logic reset; 265 | logic [7:0] in_; 266 | logic [7:0] out; 267 | logic [7:0] expected_out; 268 | 269 | // Mismatch counter 270 | int mismatch_count; 271 | 272 | // Instantiate the DUT (Design Under Test) 273 | TopModule dut ( 274 | .clk(clk), 275 | .reset(reset), 276 | .in_(in_), 277 | .out(out) 278 | ); 279 | 280 | // Clock generation 281 | always begin 282 | clk = 0; 283 | #5; 284 | clk = 1; 285 | #5; 286 | end 287 | 288 | // Test stimulus 289 | initial begin 290 | // Initialize signals 291 | reset = 0; 292 | in_ = 8'h00; 293 | mismatch_count = 0; 294 | expected_out = 8'h00; 295 | 296 | // Reset check 297 | @(posedge clk); 298 | reset = 1; 299 | @(posedge clk); 300 | @(negedge clk); 301 | check_output(); 302 | 303 | reset = 0; 304 | 305 | // Test case 1: Normal increment operation 306 | for (int i = 0; i < 10; i++) begin 307 | in_ = $urandom_range(0, 255); 308 | @(posedge 
clk); // Wait for input to be registered 309 | expected_out = in_; // First cycle: input gets registered 310 | @(negedge clk); 311 | check_output(); 312 | 313 | @(posedge clk); // Wait for increment 314 | expected_out = in_ + 1; // Second cycle: registered value gets incremented 315 | @(negedge clk); 316 | check_output(); 317 | end 318 | 319 | // Test case 2: Overflow condition 320 | in_ = 8'hFF; 321 | @(posedge clk); 322 | expected_out = 8'hFF; 323 | @(negedge clk); 324 | check_output(); 325 | 326 | @(posedge clk); 327 | expected_out = 8'h00; // Should overflow to 0 328 | @(negedge clk); 329 | check_output(); 330 | 331 | // Test case 3: Reset during operation 332 | in_ = 8'h55; 333 | @(posedge clk); 334 | expected_out = 8'h55; 335 | @(negedge clk); 336 | check_output(); 337 | 338 | reset = 1; 339 | @(posedge clk); 340 | expected_out = 8'h00; // Should reset to 0 341 | @(negedge clk); 342 | check_output(); 343 | 344 | // End simulation 345 | if (mismatch_count == 0) 346 | $display("SIMULATION PASSED"); 347 | else 348 | $display("SIMULATION FAILED with %0d mismatches", mismatch_count); 349 | 350 | $finish; 351 | end 352 | 353 | // Task to check output and log mismatches 354 | task check_output(); 355 | if (out !== expected_out) begin 356 | $display("Time %0t: Mismatch detected!", $time); 357 | $display("Input = %h, Expected output = %h, Actual output = %h", 358 | in_, expected_out, out); 359 | mismatch_count++; 360 | end else begin 361 | $display("Time %0t: Match detected!", $time); 362 | $display("Input = %h, Output = %h", in_, out); 363 | end 364 | endtask 365 | 366 | endmodule 367 | 368 | 369 | Example 3: 370 | 371 | 372 | Implement the SystemVerilog module based on the following description. 373 | Assume that sigals are positive clock/clk triggered unless otherwise stated. 374 | 375 | The module should implement an n-bit registered incrementer where the bitwidth is specified by the parameter nbits. 
376 | The n-bit input is first registered and then incremented by one on the next cycle. 377 | The reset input is active high synchronous and should reset the output to zero. 378 | 379 | 380 | module TopModule #( 381 | parameter nbits 382 | )( 383 | input logic clk, 384 | input logic reset, 385 | input logic [nbits-1:0] in_, 386 | output logic [nbits-1:0] out 387 | ); 388 | 389 | 390 | `timescale 1ns/1ps 391 | 392 | module TopModule_tb(); 393 | 394 | // Parameters 395 | parameter nbits = 8; 396 | parameter CLK_PERIOD = 10; 397 | 398 | // Signals declaration 399 | logic clk; 400 | logic reset; 401 | logic [nbits-1:0] in_; 402 | logic [nbits-1:0] out; 403 | logic [nbits-1:0] expected_out; 404 | 405 | // Counter for mismatches 406 | int mismatch_count; 407 | 408 | // DUT instantiation 409 | TopModule #( 410 | .nbits(nbits) 411 | ) dut ( 412 | .clk(clk), 413 | .reset(reset), 414 | .in_(in_), 415 | .out(out) 416 | ); 417 | 418 | // Clock generation 419 | initial begin 420 | clk = 0; 421 | forever #(CLK_PERIOD/2) clk = ~clk; 422 | end 423 | 424 | // Test stimulus 425 | initial begin 426 | // Initialize signals 427 | reset = 1; 428 | in_ = 0; 429 | mismatch_count = 0; 430 | expected_out = 0; 431 | 432 | // Wait for 2 clock cycles in reset 433 | repeat(2) @(posedge clk); 434 | 435 | // Release reset 436 | reset = 0; 437 | 438 | // Test case 1: Regular increment 439 | for(int i = 0; i < 10; i++) begin 440 | in_ = $random; 441 | @(posedge clk); 442 | expected_out = in_; 443 | @(negedge clk); 444 | check_output(); 445 | @(posedge clk); 446 | expected_out = expected_out + 1; 447 | @(negedge clk); 448 | check_output(); 449 | end 450 | 451 | // Test case 2: Reset during operation 452 | in_ = 8'hAA; 453 | @(posedge clk); 454 | reset = 1; 455 | @(posedge clk); 456 | expected_out = 0; 457 | @(negedge clk); 458 | check_output(); 459 | 460 | // Test case 3: Overflow condition 461 | reset = 0; 462 | in_ = {nbits{1'b1}}; // All ones 463 | @(posedge clk); 464 | @(posedge clk); 465 | 
expected_out = 0; 466 | @(negedge clk); 467 | check_output(); 468 | 469 | // End simulation 470 | #(CLK_PERIOD); 471 | if(mismatch_count == 0) 472 | $display("SIMULATION PASSED"); 473 | else 474 | $display("SIMULATION FAILED with %0d mismatches", mismatch_count); 475 | 476 | $finish; 477 | end 478 | 479 | // Task to check output and log mismatches 480 | task check_output(); 481 | if (out !== expected_out) begin 482 | $display("Time %0t: Mismatch detected!", $time); 483 | $display("Input = %h, Expected output = %h, Actual output = %h", 484 | in_, expected_out, out); 485 | mismatch_count++; 486 | end else begin 487 | $display("Time %0t: Match detected!", $time); 488 | $display("Input = %h, Output = %h", in_, out); 489 | end 490 | endtask 491 | 492 | endmodule 493 | 494 | 495 | Example 4: 496 | 497 | 498 | Implement the SystemVerilog module based on the following description. 499 | Assume that sigals are positive clock/clk triggered unless otherwise stated. 500 | 501 | Build a finite-state machine that takes as input a serial bit stream, 502 | and outputs a one whenever the bit stream contains two consecutive one's. 503 | The output is one on the cycle _after_ there are two consecutive one's. 504 | The reset input is active high synchronous, 505 | and should reset the finite-state machine to an appropriate initial state. 
506 | 507 | 508 | module TopModule( 509 | input logic clk, 510 | input logic reset, 511 | input logic in_, 512 | output logic out 513 | ); 514 | 515 | 516 | module TopModule_tb(); 517 | // Signal declarations 518 | logic clk; 519 | logic reset; 520 | logic in_; 521 | logic out; 522 | logic expected_out; 523 | int mismatch_count; 524 | 525 | // Instantiate the DUT 526 | TopModule dut( 527 | .clk(clk), 528 | .reset(reset), 529 | .in_(in_), 530 | .out(out) 531 | ); 532 | 533 | // Clock generation 534 | initial begin 535 | clk = 0; 536 | forever #5 clk = ~clk; 537 | end 538 | 539 | // Test stimulus and checking 540 | initial begin 541 | // Initialize signals 542 | reset = 1; 543 | in_ = 0; 544 | mismatch_count = 0; 545 | expected_out = 0; 546 | 547 | // Wait for 2 clock cycles and release reset 548 | @(posedge clk); 549 | @(posedge clk); 550 | reset = 0; 551 | 552 | // Test case 1: No consecutive ones 553 | @(posedge clk); in_ = 0; expected_out = 0; 554 | @(posedge clk); in_ = 1; expected_out = 0; 555 | @(posedge clk); in_ = 0; expected_out = 0; 556 | @(posedge clk); in_ = 1; expected_out = 0; 557 | 558 | // Test case 2: Two consecutive ones 559 | @(posedge clk); in_ = 1; expected_out = 0; 560 | @(posedge clk); in_ = 1; expected_out = 0; 561 | @(posedge clk); in_ = 0; expected_out = 1; 562 | @(posedge clk); in_ = 0; expected_out = 0; 563 | 564 | // Test case 3: Three consecutive ones 565 | @(posedge clk); in_ = 1; expected_out = 0; 566 | @(posedge clk); in_ = 1; expected_out = 0; 567 | @(posedge clk); in_ = 1; expected_out = 1; 568 | @(posedge clk); in_ = 0; expected_out = 1; 569 | 570 | // Test case 4: Reset during operation 571 | @(posedge clk); in_ = 1; expected_out = 0; 572 | @(posedge clk); in_ = 1; expected_out = 0; 573 | @(posedge clk); reset = 1; in_ = 0; expected_out = 0; 574 | @(posedge clk); reset = 0; in_ = 0; expected_out = 0; 575 | 576 | // End simulation 577 | #20 $finish; 578 | end 579 | 580 | // Monitor changes and check outputs 581 | always @(negedge 
clk) begin 582 | if (out !== expected_out) begin 583 | $display("Mismatch at time %0t: input=%b, actual_output=%b, expected_output=%b", 584 | $time, in_, out, expected_out); 585 | mismatch_count++; 586 | end else begin 587 | $display("Match at time %0t: input=%b, output=%b", 588 | $time, in_, out); 589 | end 590 | end 591 | 592 | // Final check and display results 593 | final begin 594 | if (mismatch_count == 0) 595 | $display("SIMULATION PASSED"); 596 | else 597 | $display("SIMULATION FAILED: %0d mismatches found", mismatch_count); 598 | end 599 | 600 | endmodule 601 | 602 | 603 | """ 604 | 605 | FAILED_TRIAL_PROMPT = r""" 606 | There was a generation trial that failed simulation: 607 | 608 | {failed_sim_log} 609 | 610 | 611 | {previous_code} 612 | 613 | 614 | {previous_tb} 615 | 616 | """ 617 | 618 | ORDER_PROMPT = r""" 619 | Your response will be processed by a program, not human. 620 | So, please STRICTLY FOLLOW the output format given as XML tag content below to generate a VALID JSON OBJECT: 621 | 622 | {output_format} 623 | 624 | DO NOT include any other information in your response, like 'json', 'reasoning' or ''. 625 | """ 626 | -------------------------------------------------------------------------------- /src/mage/rtl_editor.py: -------------------------------------------------------------------------------- 1 | import json 2 | from inspect import signature 3 | from typing import Any, Dict, List, Tuple 4 | 5 | from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole 6 | from pydantic import BaseModel 7 | 8 | from .log_utils import get_logger 9 | from .prompts import ORDER_PROMPT 10 | from .sim_reviewer import SimReviewer, check_syntax 11 | from .token_counter import TokenCounter, TokenCounterCached 12 | 13 | logger = get_logger(__name__) 14 | 15 | SYSTEM_PROMPT = r""" 16 | You are an expert in RTL design. 
17 | Your job is to use actions to edit and simulate SystemVerilog rtl_code, 18 | To make sure the edited code has the functionality described in input_spec, and passes the simulation. 19 | The actions below are available: 20 | 21 | {actions} 22 | 23 | """ 24 | 25 | ACTION_PROMPT = r""" 26 | 27 | {command} 28 | {signature} 29 | {description} 30 | 31 | """ 32 | 33 | INIT_EDITION_PROMPT = r""" 34 | The information below is give to help your work: 35 | 1. The input_spec specifing the functionality of the RTL module; 36 | 2. The generated testbench, which is verified correct by another LLM judge; 37 | 3. The log generated by the simulation which is failed by the generated RTL and TB. 38 | 39 | {input_spec} 40 | 41 | 42 | {generated_tb} 43 | 44 | 45 | {sim_failed_log} 46 | 47 | 48 | [Hints]: 49 | For implementing kmap (Karnaugh map), you need to think and solve mismatches step by step. 50 | Find the inputs corresponding to mismatch in sim_failed_log, and set the output to correct value while maintaining other outputs. 51 | """ 52 | 53 | EXTRA_ORDER_PROMPT = r""" 54 | 1. Try to understand the input_spec, locate the suspicious range, and give reasoning steps in natural language to solve the mismatchs 1 by 1. 55 | In addition, try to give advice to avoid syntax error. 56 | 2. For sequencial logic, carefully examine whether the signal should change when "next_state" matches, or when "state" matches. 57 | 3. For combinational logic, if encountered error with complicated multiline singal: 58 | Try to comment each line of the signal to avoid missing any part. 59 | 4. Don't use state_t to define the parameter. Use `localparam` or Use 'reg' or 'logic' for signals as registers or Flip-Flops. 60 | 5. DO NOT close quotes at last line before the inline comment. It would break the json syntax. 61 | If quote is closed before the comment, just don't add the comment. 62 | 6. Do not try to modify the testbench. Only modify the RTL code. 63 | Also do not try to change or define RefModule. 
There is RefModule defined elsewhere. 64 | 7. Always try modify RTL code as long as simulation mismatch exists, even if you think the code is correct. 65 | SHOW RESPECT TO THE SIMULATION RESULT. 66 | 8. Declare all ports as logic; use wire or reg for signals inside the block. 67 | 9. Not all the sequential logic need to be reset to 0 when reset is asserted, 68 | but these without-reset logic should be initialized to a known value with an initial block instead of being X. 69 | 10. In sequence logic, if the expected output is asserted but the dut output is not, 70 | carefully examine whether the input signal should affect current output (with comb logic) or next-cycle output (with seq logic). 71 | 72 | The file content which is going to be edited is given below: 73 | 74 | {rtl_code} 75 | 76 | """ 77 | # The prompt above comes from: 78 | # @misc{ho2024verilogcoderautonomousverilogcoding, 79 | # title={VerilogCoder: Autonomous Verilog Coding Agents with Graph-based Planning and Abstract Syntax Tree (AST)-based Waveform Tracing Tool}, 80 | # author={Chia-Tung Ho and Haoxing Ren and Brucek Khailany}, 81 | # year={2024}, 82 | # eprint={2408.08927}, 83 | # archivePrefix={arXiv}, 84 | # primaryClass={cs.AI}, 85 | # url={https://arxiv.org/abs/2408.08927}, 86 | # } 87 | 88 | ACTION_OUTPUT_PROMPT = r""" 89 | Output after running given action: 90 | 91 | {action_output} 92 | 93 | """ 94 | 95 | EXAMPLE_OUTPUT = { 96 | "reasoning": "All reasoning steps", 97 | "action_input": { 98 | "command": "replace_content_by_matching", 99 | "args": { 100 | "old_content": "content to be replaced", 101 | "new_content": "content to replace", 102 | }, 103 | }, 104 | } 105 | 106 | 107 | class ActionInput(BaseModel): 108 | command: str 109 | args: Dict[str, Any] 110 | 111 | 112 | class RTLEditorStepOutput(BaseModel): 113 | reasoning: str 114 | action_input: ActionInput 115 | 116 | 117 | class RTLEditor: 118 | def __init__( 119 | self, 120 | token_counter: TokenCounter, 121 | sim_reviewer: 
SimReviewer, 122 | ): 123 | self.token_counter = token_counter 124 | self.history: List[ChatMessage] = [] 125 | self.max_trials = 15 126 | self.succeed_history_max_length = 10 127 | self.fail_history_max_length = 6 128 | self.is_done = False 129 | self.last_mismatch_cnt: int | None = None 130 | self.sim_reviewer = sim_reviewer 131 | 132 | def reset(self): 133 | self.is_done = False 134 | self.history = [] 135 | self.last_mismatch_cnt: int | None = None 136 | 137 | def write_rtl(self, content: str) -> None: 138 | with open(self.rtl_path, "w") as f: 139 | f.write(content) 140 | 141 | def read_rtl(self) -> str: 142 | with open(self.rtl_path, "r") as f: 143 | return f.read() 144 | 145 | def replace_sanity_check(self) -> Dict[str, Any]: 146 | # Run syntax check and simulation check sequentially 147 | is_syntax_pass, syntax_output = check_syntax(self.rtl_path) 148 | if is_syntax_pass: 149 | syntax_output = "Syntax check passed." 150 | if not is_syntax_pass: 151 | return { 152 | "is_syntax_pass": False, 153 | "is_sim_pass": False, 154 | "error_msg": syntax_output, 155 | "sim_mismatch_cnt": 0, 156 | } 157 | is_sim_pass, sim_mismatch_cnt, sim_output = self.sim_reviewer.review() 158 | assert isinstance(sim_mismatch_cnt, int) 159 | return { 160 | "is_syntax_pass": True, 161 | "is_sim_pass": is_sim_pass, 162 | "error_msg": "" if is_sim_pass else sim_output, 163 | "sim_mismatch_cnt": sim_mismatch_cnt, 164 | } 165 | 166 | def judge_replace_action_execution( 167 | self, 168 | old_content: str, 169 | new_content: str, 170 | action_name: str, 171 | old_file_content: str, 172 | ) -> Dict[str, Any]: 173 | sanity_check = self.replace_sanity_check() 174 | ret = { 175 | "is_action_executed": False, 176 | **sanity_check, 177 | } 178 | if not ret["is_syntax_pass"]: 179 | assert isinstance(ret["error_msg"], str) 180 | ret["error_msg"] += ( 181 | f"Syntax check failed. {action_name} not executed." 
182 | f"old_content: {old_content}," 183 | f"new_content: {new_content}" 184 | ) 185 | self.write_rtl(old_file_content) 186 | return ret 187 | sim_mismatch_cnt = ret["sim_mismatch_cnt"] 188 | if ( 189 | self.last_mismatch_cnt is not None 190 | and sim_mismatch_cnt > self.last_mismatch_cnt 191 | ): 192 | logger.info( 193 | f"Mismatch_cnt {sim_mismatch_cnt} > last {self.last_mismatch_cnt}. Action not executed." 194 | ) 195 | self.write_rtl(old_file_content) 196 | assert isinstance(ret["error_msg"], str) 197 | ret["error_msg"] += ( 198 | "Mismatch_cnt increased after the replacement. " 199 | f"{action_name} not executed." 200 | ) 201 | elif sim_mismatch_cnt == 0 and ret["is_sim_pass"] is False: 202 | logger.info( 203 | f"Mismatch_cnt {sim_mismatch_cnt} == 0 but sim failed. Action not executed." 204 | ) 205 | self.write_rtl(old_file_content) 206 | assert isinstance(ret["error_msg"], str) 207 | ret["error_msg"] += ( 208 | "Mismatch_cnt is 0 but sim failed. " f"{action_name} not executed." 209 | ) 210 | else: 211 | # Accept replace 212 | logger.info( 213 | f"Mismatch_cnt {sim_mismatch_cnt} <= last {self.last_mismatch_cnt}. Action executed." 214 | ) 215 | self.last_mismatch_cnt = ret["sim_mismatch_cnt"] 216 | ret["is_action_executed"] = True 217 | if self.last_mismatch_cnt == 0: 218 | self.is_done = True 219 | 220 | return ret 221 | 222 | def replace_content_by_matching( 223 | self, old_content: str, new_content: str 224 | ) -> Dict[str, Any]: 225 | """ 226 | Replace the content of the matching line in the file with the new content. 227 | Syntax check is performed after the replacement. 228 | Please ONLY replace the content that NEEDS to be modified. Don't change the content that is correct. 229 | Please make sure old_content only occurs once in the file. 230 | Input: 231 | old_content: The old content of the file. 232 | new_content: The new content to replace the matching line. 233 | Output: 234 | A dictionary containing : 235 | 1. Whether the action is executed. 
236 | 2. The error message if the action is not executed. 237 | 3. Other information like syntax check result and simulation check result. 238 | Example: 239 | Before: 240 | 241 | 1 module test; 242 | 2 reg a; 243 | 3 reg b; 244 | 4 endmodule 245 | 246 | Action: 247 | 248 | "command": "replace_content_by_matching", 249 | "args": { 250 | "old_content": " reg a;\n reg b;", 251 | "new_content": " logic a;", 252 | }, 253 | 254 | Now: 255 | 256 | 1 module test; 257 | 2 logic a; 258 | 4 endmodule 259 | 260 | """ 261 | old_file_content = self.read_rtl().expandtabs(4) 262 | old_content = old_content.expandtabs(4) 263 | new_content = new_content.expandtabs(4) 264 | 265 | # Check if old_str is unique in the file 266 | logger.info(f"Old File Content:") 267 | logger.info(old_file_content) 268 | logger.info(f"Target old Content:") 269 | logger.info(old_content) 270 | logger.info(f"Target new Content:") 271 | logger.info(new_content) 272 | occurrences = old_file_content.count(old_content) 273 | if occurrences == 0: 274 | return { 275 | "is_action_executed": False, 276 | "new_content": "", 277 | "error_msg": f"Cannot find old_content in current RTL. replace_content_by_matching not executed.", 278 | } 279 | elif occurrences > 1: 280 | return { 281 | "is_action_executed": False, 282 | "new_content": "", 283 | "error_msg": f"Find multiple old_content in current RTL. 
replace_content_by_matching not executed.", 284 | } 285 | 286 | # Replace old_str with new_str 287 | new_file_content = old_file_content.replace(old_content, new_content) 288 | self.write_rtl(new_file_content) 289 | ret = self.judge_replace_action_execution( 290 | old_content, new_content, "replace_content_by_matching", old_file_content 291 | ) 292 | # ret["new_file_content"] = new_file_content 293 | return ret 294 | 295 | def generate(self, messages: List[ChatMessage]) -> ChatResponse: 296 | logger.info(f"RTL editor input message: {messages}") 297 | resp, token_cnt = self.token_counter.count_chat(messages) 298 | logger.info(f"Token count: {token_cnt}") 299 | logger.info(f"{resp.message.content}") 300 | return resp 301 | 302 | def gen_action_prompt(self, function) -> str: 303 | return ACTION_PROMPT.format( 304 | command=function.__name__, 305 | signature=signature(function), 306 | description=function.__doc__, 307 | ) 308 | 309 | def get_init_prompt_messages(self) -> List[ChatMessage]: 310 | actions = [self.replace_content_by_matching] 311 | actions_prompt = SYSTEM_PROMPT.format( 312 | actions="".join([self.gen_action_prompt(action) for action in actions]) 313 | ) 314 | system_prompt = ChatMessage(content=actions_prompt, role=MessageRole.SYSTEM) 315 | with open(self.tb_path, "r") as f: 316 | generated_tb = f.read() 317 | edit_init_prompt = ChatMessage( 318 | content=INIT_EDITION_PROMPT.format( 319 | input_spec=self.spec, 320 | generated_tb=generated_tb, 321 | sim_failed_log=self.sim_failed_log, 322 | ), 323 | role=MessageRole.USER, 324 | ) 325 | ret = [system_prompt, edit_init_prompt] 326 | if ( 327 | isinstance(self.token_counter, TokenCounterCached) 328 | and self.token_counter.enable_cache 329 | ): 330 | self.token_counter.add_cache_tag(ret[-1]) 331 | return ret 332 | 333 | def get_order_prompt_messages(self) -> List[ChatMessage]: 334 | with open(self.rtl_path, "r") as f: 335 | rtl_code = f.read() 336 | return [ 337 | ChatMessage( 338 | 
content=ORDER_PROMPT.format( 339 | output_format="".join(json.dumps(EXAMPLE_OUTPUT, indent=4)) 340 | ) 341 | + EXTRA_ORDER_PROMPT.format(rtl_code=rtl_code), 342 | role=MessageRole.USER, 343 | ), 344 | ] 345 | 346 | def parse_output(self, response: ChatResponse) -> RTLEditorStepOutput: 347 | output_json_obj: Dict = json.loads(response.message.content, strict=False) 348 | action_input = output_json_obj["action_input"] 349 | command = action_input["command"] 350 | 351 | args = action_input["args"] 352 | return RTLEditorStepOutput( 353 | reasoning=output_json_obj["reasoning"], 354 | action_input=ActionInput(command=command, args=args), 355 | ) 356 | 357 | def run_action(self, action_input: ActionInput) -> Dict[str, Any]: 358 | logger.info(f"Action input: {action_input}") 359 | action = getattr(self, action_input.command) 360 | action_output = action(**action_input.args) 361 | logger.info(f"Action output: {action_output}") 362 | return action_output 363 | 364 | def get_action_output_message(self, output: Dict[str, Any]) -> List[ChatMessage]: 365 | return [ 366 | ChatMessage( 367 | content=ACTION_OUTPUT_PROMPT.format( 368 | action_output=json.dumps(output, indent=4) 369 | ), 370 | role=MessageRole.USER, 371 | ), 372 | ] 373 | 374 | def chat( 375 | self, 376 | spec: str, 377 | output_dir_per_run: str, 378 | sim_failed_log: str, 379 | sim_mismatch_cnt: int, 380 | ) -> Tuple[bool, str]: 381 | # 1. Initialize the history 382 | # 2. Generate the initial prompt messages (with functool information) 383 | # 3. 
Loop for the max trials: 384 | # - Generate the order prompt messages 385 | # - Generate & parse the response 386 | # - Generate & parse the tool call 387 | # - If called 388 | if isinstance(self.token_counter, TokenCounterCached): 389 | self.token_counter.set_enable_cache(True) 390 | self.history = [] 391 | self.token_counter.set_cur_tag(self.__class__.__name__) 392 | self.spec = spec 393 | self.output_dir_per_run = output_dir_per_run 394 | self.tb_path = f"{output_dir_per_run}/tb.sv" 395 | self.rtl_path = f"{output_dir_per_run}/rtl.sv" 396 | self.sim_failed_log = sim_failed_log 397 | self.last_mismatch_cnt = sim_mismatch_cnt 398 | 399 | self.history.extend(self.get_init_prompt_messages()) 400 | is_pass = False 401 | succeed_history: List[ChatMessage] = [] 402 | fail_history: List[ChatMessage] = [] 403 | for i in range(self.max_trials): 404 | logger.info(f"RTL Editing: round {i + 1} / {self.max_trials}") 405 | response = self.generate( 406 | self.history 407 | + succeed_history 408 | + fail_history 409 | + self.get_order_prompt_messages() 410 | ) 411 | new_contents = [response.message] 412 | action_input = self.parse_output(response).action_input 413 | action_output = self.run_action(action_input) 414 | if self.is_done: 415 | is_pass = True 416 | break 417 | new_contents.extend(self.get_action_output_message(action_output)) 418 | assert len(new_contents) == 2, f"new_contents: {new_contents}" 419 | if action_output["is_action_executed"]: 420 | fail_history = [] 421 | succeed_history.extend(new_contents) 422 | if len(succeed_history) > self.succeed_history_max_length: 423 | succeed_history = succeed_history[ 424 | -self.succeed_history_max_length : 425 | ] 426 | else: 427 | fail_history.extend(new_contents) 428 | if len(fail_history) > self.fail_history_max_length: 429 | fail_history = fail_history[-self.fail_history_max_length :] 430 | 431 | with open(self.rtl_path, "r") as f: 432 | rtl_code = f.read() 433 | return (is_pass, rtl_code) 434 | 
-------------------------------------------------------------------------------- /src/mage/rtl_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Dict, List, Tuple 3 | 4 | from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole 5 | from pydantic import BaseModel 6 | 7 | from .log_utils import get_logger 8 | from .prompts import FAILED_TRIAL_PROMPT, ORDER_PROMPT, RTL_4_SHOT_EXAMPLES 9 | from .sim_reviewer import check_syntax 10 | from .token_counter import TokenCounter, TokenCounterCached 11 | from .utils import add_lineno 12 | 13 | logger = get_logger(__name__) 14 | 15 | SYSTEM_PROMPT = r""" 16 | You are an expert in RTL design. You can always write SystemVerilog code with no syntax errors and always reach correct functionality. 17 | """ 18 | 19 | GENERATION_PROMPT = r""" 20 | Please write a module in SystemVerilog RTL language regarding to the given natural language specification. 21 | Try to understand the requirements above and give reasoning steps in natural language to achieve it. 22 | In addition, try to give advice to avoid syntax error. 23 | An SystemVerilog RTL module always starts with a line starting with the keyword 'module' followed by the module name. 24 | It ends with the keyword 'endmodule'. 25 | 26 | [Hints]: 27 | For implementing kmap (Karnaugh map), you need to think step by step. 28 | Carefully example how the kmap in input_spec specifies the order of the inputs. 29 | Note that x[i] in x[N:1] means x[i-1] in x[N-1:0]. 30 | Then find the inputs corresponding to output=1, 0, and don't-care for each case. 31 | 32 | Note in Verilog, for a signal "logic x[M:N]" where M > N, you CANNOT reversely select bits from it like x[1:2]; 33 | Instead, you should use concatations like {{x[1], x[2]}}. 34 | 35 | The module interface should EXACTLY MATCH module_interface if given. 36 | Otherwise, should EXACTLY MATCH with the description in input_spec. 
37 | (Including the module name, input/output ports names, and their types) 38 | 39 | 40 | {examples_prompt} 41 | 42 | {input_spec} 43 | 44 | """ 45 | 46 | EXTRA_ORDER_PROMPT = r""" 47 | Other requirements: 48 | 1. Don't use state_t to define the parameter. Use `localparam` or Use 'reg' or 'logic' for signals as registers or Flip-Flops. 49 | 2. Declare all ports and signals as logic. 50 | 3. Not all the sequential logic need to be reset to 0 when reset is asserted, 51 | but these without-reset logic should be initialized to a known value with an initial block instead of being X. 52 | 4. For combinational logic with an always block do not explicitly specify the sensitivity list; instead use always @(*). 53 | 5. NEVER USE 'inside' operator in RTL code. Code like 'state inside {STATE_B, STATE_C, STATE_D}' should NOT be used. 54 | 6. Never USE 'unique' or 'unique0' keywords in RTL code. Code like 'unique case' should NOT be used. 55 | """ 56 | # Some prompts above comes from: 57 | # @misc{ho2024verilogcoderautonomousverilogcoding, 58 | # title={VerilogCoder: Autonomous Verilog Coding Agents with Graph-based Planning and Abstract Syntax Tree (AST)-based Waveform Tracing Tool}, 59 | # author={Chia-Tung Ho and Haoxing Ren and Brucek Khailany}, 60 | # year={2024}, 61 | # eprint={2408.08927}, 62 | # archivePrefix={arXiv}, 63 | # primaryClass={cs.AI}, 64 | # url={https://arxiv.org/abs/2408.08927}, 65 | # } 66 | 67 | IF_PROMPT = r""" 68 | The module interface is given below: 69 | 70 | {module_interface} 71 | 72 | """ 73 | 74 | TB_PROMPT = r""" 75 | Another agent has generated a testbench regarding the given input_spec: 76 | 77 | {testbench} 78 | 79 | """ 80 | 81 | FORMAT_ERROR_PROMPT = r""" 82 | The error below has been reported by the format tool: 83 | 84 | {format_error} 85 | 86 | To understand the error message better, we offered a version of generated module with line number: 87 | 88 | {module_with_lineno} 89 | 90 | """ 91 | 92 | EXAMPLE_OUTPUT = { 93 | "reasoning": "All 
reasoning steps and advices to avoid syntax error", 94 | "module": "Pure SystemVerilog code, a complete module", 95 | } 96 | 97 | 98 | class RTLOutputFormat(BaseModel): 99 | reasoning: str 100 | module: str 101 | 102 | 103 | class RTLGenerator: 104 | def __init__( 105 | self, 106 | token_counter: TokenCounter, 107 | ): 108 | self.token_counter = token_counter 109 | self.generated_tb: str | None = None 110 | self.generated_if: str | None = None 111 | self.failed_trial: List[ChatMessage] = [] 112 | self.history: List[ChatMessage] = [] 113 | self.max_trials = 5 114 | self.enable_cache = False 115 | 116 | def reset(self): 117 | self.history = [] 118 | 119 | def set_failed_trial( 120 | self, failed_sim_log: str, previous_code: str, previous_tb: str 121 | ) -> None: 122 | cur_failed_trial = FAILED_TRIAL_PROMPT.format( 123 | failed_sim_log=failed_sim_log, 124 | previous_code=add_lineno(previous_code), 125 | previous_tb=add_lineno(previous_tb), 126 | ) 127 | self.failed_trial.append( 128 | ChatMessage(content=cur_failed_trial, role=MessageRole.USER) 129 | ) 130 | 131 | def generate(self, messages: List[ChatMessage]) -> ChatResponse: 132 | logger.info(f"RTL generator input message: {messages}") 133 | resp, token_cnt = self.token_counter.count_chat(messages) 134 | logger.info(f"Token count: {token_cnt}") 135 | logger.info(f"{resp.message.content}") 136 | return resp 137 | 138 | def batch_generate( 139 | self, messages_list: List[List[ChatMessage]] 140 | ) -> List[ChatResponse]: 141 | resp_token_cnt_list = self.token_counter.count_chat_batch(messages_list) 142 | responses = [] 143 | for i, ((resp, token_cnt), _) in enumerate( 144 | zip(resp_token_cnt_list, messages_list) 145 | ): 146 | logger.info(f"Message {i+1} token count: {token_cnt}") 147 | responses.append(resp) 148 | return responses 149 | 150 | def get_init_prompt_messages(self, input_spec: str) -> List[ChatMessage]: 151 | ret = [ 152 | ChatMessage(content=SYSTEM_PROMPT, role=MessageRole.SYSTEM), 153 | ChatMessage( 
154 | content=GENERATION_PROMPT.format( 155 | input_spec=input_spec, examples_prompt=RTL_4_SHOT_EXAMPLES 156 | ), 157 | role=MessageRole.USER, 158 | ), 159 | ] 160 | if self.generated_tb: 161 | ret.append( 162 | ChatMessage( 163 | content=TB_PROMPT.format(testbench=self.generated_tb), 164 | role=MessageRole.USER, 165 | ) 166 | ) 167 | if self.failed_trial: 168 | ret.extend(self.failed_trial) 169 | if self.generated_if: 170 | ret.append( 171 | ChatMessage( 172 | content=IF_PROMPT.format(module_interface=self.generated_if), 173 | role=MessageRole.USER, 174 | ) 175 | ) 176 | if ( 177 | isinstance(self.token_counter, TokenCounterCached) 178 | and self.token_counter.enable_cache 179 | ): 180 | self.token_counter.add_cache_tag(ret[-1]) 181 | return ret 182 | 183 | def get_order_prompt_messages(self) -> List[ChatMessage]: 184 | return [ 185 | ChatMessage( 186 | content=ORDER_PROMPT.format( 187 | output_format="".join(json.dumps(EXAMPLE_OUTPUT, indent=4)) 188 | ) 189 | + EXTRA_ORDER_PROMPT, 190 | role=MessageRole.USER, 191 | ), 192 | ] 193 | 194 | def get_format_error_prompt_messages( 195 | self, format_error: str, rtl_code: str 196 | ) -> List[ChatMessage]: 197 | return [ 198 | ChatMessage( 199 | content=FORMAT_ERROR_PROMPT.format( 200 | format_error=format_error, module_with_lineno=add_lineno(rtl_code) 201 | ), 202 | role=MessageRole.USER, 203 | ), 204 | ] 205 | 206 | def parse_output(self, response: ChatResponse) -> RTLOutputFormat: 207 | try: 208 | output_json_obj: Dict = json.loads(response.message.content, strict=False) 209 | ret = RTLOutputFormat( 210 | reasoning=output_json_obj["reasoning"], module=output_json_obj["module"] 211 | ) 212 | except json.decoder.JSONDecodeError as e: 213 | ret = RTLOutputFormat(reasoning=f"Json Decode Error: {str(e)}", module="") 214 | return ret 215 | 216 | def chat( 217 | self, 218 | input_spec: str, 219 | testbench: str, 220 | interface: str, 221 | rtl_path: str, 222 | enable_cache: bool = False, 223 | ) -> Tuple[bool, str]: 224 | if 
isinstance(self.token_counter, TokenCounterCached): 225 | self.token_counter.set_enable_cache(enable_cache) 226 | self.history = [] 227 | self.token_counter.set_cur_tag(self.__class__.__name__) 228 | self.generated_tb = testbench 229 | self.generated_if = interface 230 | self.history.extend(self.get_init_prompt_messages(input_spec)) 231 | for _ in range(self.max_trials): 232 | response = self.generate(self.history + self.get_order_prompt_messages()) 233 | resp_obj = self.parse_output(response) 234 | if resp_obj.reasoning.startswith("Json Decode Error"): 235 | logger.info( 236 | f"RTL generation Error: {resp_obj.reasoning}, drop this response" 237 | ) 238 | continue 239 | rtl_code = resp_obj.module 240 | with open(rtl_path, "w") as f: 241 | f.write(rtl_code) 242 | syntax_correct, syntax_output = check_syntax(rtl_path=rtl_path) 243 | if syntax_correct: 244 | break 245 | self.history.extend( 246 | [response.message] 247 | + self.get_format_error_prompt_messages(syntax_output, rtl_code) 248 | ) 249 | return (syntax_correct, rtl_code) 250 | 251 | def gen_candidates( 252 | self, 253 | input_spec: str, 254 | testbench: str, 255 | interface: str, 256 | rtl_path: str, 257 | candidates_num: int, 258 | enable_cache: bool = False, 259 | ) -> List[Tuple[bool, str]]: 260 | if isinstance(self.token_counter, TokenCounterCached): 261 | self.token_counter.set_enable_cache(enable_cache) 262 | self.history = [] 263 | self.token_counter.set_cur_tag(self.__class__.__name__) 264 | self.generated_tb = testbench 265 | self.generated_if = interface 266 | self.history.extend(self.get_init_prompt_messages(input_spec)) 267 | ret: List[Tuple[bool, str]] = [(False, "") for _ in range(candidates_num)] 268 | messages = [ 269 | self.history + self.get_order_prompt_messages() 270 | for _ in range(candidates_num) 271 | ] 272 | logger.info(f"gen_candidates init input message: {messages[0]}") 273 | init_responses = self.batch_generate(messages) 274 | for i, response in enumerate(init_responses): 275 | 
rtl_code = self.parse_output(response).module 276 | candidate_history: List[ChatMessage] = [response.message] 277 | for j in range(self.max_trials): 278 | with open(rtl_path, "w") as f: 279 | f.write(rtl_code) 280 | syntax_correct, syntax_output = check_syntax(rtl_path=rtl_path) 281 | ret[i] = (syntax_correct, rtl_code) 282 | logger.info( 283 | f"Candidate {i + 1} / {candidates_num} trial {j + 1} / {self.max_trials} syntax_correct: {syntax_correct}" 284 | ) 285 | logger.info(f"RTL code: {rtl_code}") 286 | if syntax_correct: 287 | break 288 | elif j < self.max_trials - 1: 289 | candidate_history.extend( 290 | self.get_format_error_prompt_messages(syntax_output, rtl_code) 291 | ) 292 | response = self.generate( 293 | self.history 294 | + candidate_history 295 | + self.get_order_prompt_messages() 296 | ) 297 | rtl_code = self.parse_output(response).module 298 | return ret 299 | 300 | def ablation_chat(self, input_spec: str, rtl_path: str) -> Tuple[bool, str]: 301 | if isinstance(self.token_counter, TokenCounterCached): 302 | self.token_counter.set_enable_cache(False) 303 | self.history = [] 304 | self.token_counter.set_cur_tag(self.__class__.__name__) 305 | self.generated_tb = None 306 | self.generated_if = None 307 | self.history.extend(self.get_init_prompt_messages(input_spec)) 308 | for _ in range(self.max_trials): 309 | # Don't add order message into history, to save token 310 | response = self.generate(self.history + self.get_order_prompt_messages()) 311 | self.history.append(response.message) 312 | rtl_code = self.parse_output(response).module 313 | with open(rtl_path, "w") as f: 314 | f.write(rtl_code) 315 | syntax_correct, syntax_output = check_syntax(rtl_path=rtl_path) 316 | if syntax_correct: 317 | break 318 | self.history.extend( 319 | self.get_format_error_prompt_messages(syntax_output, rtl_code) 320 | ) 321 | return (syntax_correct, rtl_code) 322 | -------------------------------------------------------------------------------- /src/mage/sim_judge.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | from typing import Dict, List 3 | 4 | from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole 5 | from pydantic import BaseModel 6 | 7 | from .log_utils import get_logger 8 | from .prompts import ORDER_PROMPT 9 | from .token_counter import TokenCounter, TokenCounterCached 10 | from .utils import add_lineno 11 | 12 | logger = get_logger(__name__) 13 | 14 | SYSTEM_PROMPT = r""" 15 | You are an expert in SystemVerilog design. 16 | You can always write SystemVerilog code with no syntax errors and always reach correct functionality. 17 | """ 18 | 19 | GENERATION_PROMPT = r""" 20 | A simulation has failed for the rtl and testbench generated by other agents; 21 | The related failed_sim_log was given below. The input_spec which the rtl and testbench should follow was also given below. 22 | Please judge whether the TESTBENCH need to be modified. 23 | If you think so, set tb_needs_fix = True, otherwise set tb_needs_fix = False. 24 | Try to understand the requirements above and give reasoning steps in natural language to achieve it. 25 | 26 | 27 | {input_spec} 28 | 29 | 30 | {failed_sim_log} 31 | 32 | 33 | {failed_rtl} 34 | 35 | 36 | {failed_testbench} 37 | 38 | """ 39 | 40 | EXAMPLE_OUTPUT = { 41 | "reasoning": "All reasoning steps", 42 | "tb_needs_fix": False, 43 | } 44 | 45 | 46 | class TBOutputFormat(BaseModel): 47 | reasoning: str 48 | tb_needs_fix: bool 49 | 50 | 51 | EXTRA_ORDER_PROMPT = r""" 52 | Especially, FORCE SET tb_needs_fix to True if failed_sim_log says there is ANY syntax error in testbench(tb.sv), 53 | Even if the syntax error looks not related to the failed test case or the testbench looks correct. 
54 | """ 55 | 56 | 57 | class SimJudge: 58 | def __init__( 59 | self, 60 | token_counter: TokenCounter, 61 | ): 62 | self.token_counter = token_counter 63 | self.history: List[ChatMessage] = [] 64 | 65 | def reset(self): 66 | self.history = [] 67 | 68 | def generate(self, messages: List[ChatMessage]) -> ChatResponse: 69 | logger.info(f"Sim judge input message: {messages}") 70 | resp, token_cnt = self.token_counter.count_chat(messages) 71 | logger.info(f"Token count: {token_cnt}") 72 | logger.info(f"{resp.message.content}") 73 | return resp 74 | 75 | def get_init_prompt_messages( 76 | self, 77 | input_spec: str, 78 | failed_sim_log: str, 79 | failed_rtl: str, 80 | failed_testbench: str, 81 | ) -> List[ChatMessage]: 82 | ret = [ 83 | ChatMessage(content=SYSTEM_PROMPT, role=MessageRole.SYSTEM), 84 | ChatMessage( 85 | content=GENERATION_PROMPT.format( 86 | input_spec=input_spec, 87 | failed_sim_log=failed_sim_log, 88 | failed_rtl=add_lineno(failed_rtl), 89 | failed_testbench=add_lineno(failed_testbench), 90 | ), 91 | role=MessageRole.USER, 92 | ), 93 | ] 94 | return ret 95 | 96 | def get_order_prompt_messages(self) -> List[ChatMessage]: 97 | return [ 98 | ChatMessage( 99 | content=ORDER_PROMPT.format( 100 | output_format="".join(json.dumps(EXAMPLE_OUTPUT, indent=4)) 101 | + EXTRA_ORDER_PROMPT 102 | ), 103 | role=MessageRole.USER, 104 | ), 105 | ] 106 | 107 | def parse_output(self, response: ChatResponse) -> TBOutputFormat: 108 | output_json_obj: Dict = json.loads(response.message.content, strict=False) 109 | return TBOutputFormat( 110 | reasoning=output_json_obj["reasoning"], 111 | tb_needs_fix=output_json_obj["tb_needs_fix"], 112 | ) 113 | 114 | def chat( 115 | self, 116 | input_spec: str, 117 | failed_sim_log: str, 118 | failed_rtl: str, 119 | failed_testbench: str, 120 | ) -> bool: 121 | if isinstance(self.token_counter, TokenCounterCached): 122 | self.token_counter.set_enable_cache(False) 123 | self.history = [] 124 | 
self.token_counter.set_cur_tag(self.__class__.__name__)
        self.history.extend(
            self.get_init_prompt_messages(
                input_spec, failed_sim_log, failed_rtl, failed_testbench
            )
        )
        self.history.extend(self.get_order_prompt_messages())
        response = self.generate(self.history)
        resp_obj = self.parse_output(response)
        return resp_obj.tb_needs_fix
--------------------------------------------------------------------------------
/src/mage/sim_reviewer.py:
--------------------------------------------------------------------------------
import json
import os
import re
from typing import Dict, List, Tuple

from .bash_tools import CommandResult, run_bash_command
from .benchmark_read_helper import TypeBenchmark
from .log_utils import get_logger, set_log_dir

logger = get_logger(__name__)

# stderr lines matching any of these patterns are treated as harmless
# iverilog warnings rather than real failures.
BENIGN_STDERRS = [
    r"^\S+:\d+: sorry: constant selects in always_\* processes are not currently supported \(all bits will be included\)\.$"
]


def stderr_all_lines_benign(stderr: str) -> bool:
    # True iff every stderr line matches one of the benign patterns.
    # Vacuously True for empty stderr (callers also test stderr == "").
    return all(
        any(re.match(pattern, line) for pattern in BENIGN_STDERRS)
        for line in stderr.splitlines()
    )


def check_syntax(rtl_path: str) -> Tuple[bool, str]:
    # Compile-only pass (-t null, output to /dev/null) with iverilog to check
    # SystemVerilog syntax. Returns (passed, raw CommandResult JSON string).
    cmd = f"iverilog -t null -Wall -Winfloop -Wno-timescale -g2012 -o /dev/null {rtl_path}"
    is_pass, sim_output = run_bash_command(cmd, timeout=60)
    sim_output_obj = CommandResult.model_validate_json(sim_output)
    is_pass = (
        is_pass
        and "syntax error" not in sim_output_obj.stdout
        and (
            sim_output_obj.stderr == ""
            or stderr_all_lines_benign(sim_output_obj.stderr)
        )
    )
    logger.info(f"Syntax check is_pass: {is_pass}, \noutput: {sim_output}")
    return is_pass, sim_output


def sim_review_mismatch_cnt(stdout: str) -> int:
    # Extract N from "SIMULATION FAILED - N MISMATCHES DETECTED" in the sim
    # log; 0 when the failure banner is absent. Asserts if the banner is
    # present but unparsable.
    mismatch_cnt = 0
    if "SIMULATION FAILED" in stdout:
        re_str = r"SIMULATION FAILED - (\d*) MISMATCHES DETECTED"
        m = re.search(re_str, stdout)
        assert m is not None, f"Failed to parse mismatch count from: {stdout}"
        mismatch_cnt = int(m.group(1))
    return mismatch_cnt


def sim_review(
    output_path_per_run: str,
    golden_rtl_path: str | None = None,
) -> Tuple[bool, int, str]:
    # Compile tb.sv + rtl.sv (+ optional golden RTL) and run the simulation.
    # Returns (passed, mismatch_count, raw CommandResult JSON string).
    rtl_path = f"{output_path_per_run}/rtl.sv"
    vvp_name = f"{output_path_per_run}/sim_output.vvp"
    tb_path = f"{output_path_per_run}/tb.sv"
    if golden_rtl_path is None:
        golden_rtl_path = ""
    if os.path.isfile(vvp_name):
        # Remove stale compiled output from a previous run.
        os.remove(vvp_name)
    cmd = "iverilog -Wall -Winfloop -Wno-timescale -g2012 -o {} {} {} {}; vvp -n {}".format(
        vvp_name, tb_path, rtl_path, golden_rtl_path, vvp_name
    )
    is_pass, sim_output = run_bash_command(cmd, timeout=60)
    sim_output_obj = CommandResult.model_validate_json(sim_output)
    # Pass requires the testbench's explicit banner plus clean (or benign)
    # stderr; the testbench prompt mandates "SIMULATION PASSED" on success.
    is_pass = (
        is_pass
        and "SIMULATION PASSED" in sim_output_obj.stdout
        and (
            sim_output_obj.stderr == ""
            or stderr_all_lines_benign(sim_output_obj.stderr)
        )
    )
    mismatch_cnt = sim_review_mismatch_cnt(sim_output_obj.stdout)
    logger.info(
        f"Simulation is_pass: {is_pass}, mismatch_cnt: {mismatch_cnt}\noutput: {sim_output}"
    )
    assert isinstance(sim_output, str) and isinstance(is_pass, bool)
    return is_pass, mismatch_cnt, sim_output


class SimReviewer:
    # Thin stateful wrapper binding sim_review() to one run directory.
    def __init__(
        self,
        output_path_per_run: str,
        golden_rtl_path: str | None = None,
    ):
        self.output_path_per_run = output_path_per_run
        self.golden_rtl_path = golden_rtl_path

    def review(self) -> Tuple[bool, int, str]:
        # Delegate to the module-level sim_review with the stored paths.
        return sim_review(
            self.output_path_per_run,
            self.golden_rtl_path,
        )


def sim_review_golden(
    rtl_path: str,
    task_id: str,
    benchmark_type: TypeBenchmark,
    benchmark_path: str,
    output_path_per_run: str,
) -> Tuple[bool, str]:
    # Run the benchmark's own (golden) testbench against generated RTL.
    if (
        benchmark_type ==
TypeBenchmark.VERILOG_EVAL_V2
        or benchmark_type == TypeBenchmark.VERILOG_EVAL_V1
    ):
        # v1 and v2 of verilog-eval only differ in their dataset folder name.
        folder = (
            "dataset_code-complete-iccad2023"
            if benchmark_type == TypeBenchmark.VERILOG_EVAL_V1
            else "dataset_spec-to-rtl"
        )
        tb_path = f"{benchmark_path}/{folder}/{task_id}_test.sv"
        ref_path = f"{benchmark_path}/{folder}/{task_id}_ref.sv"
        vvp_name = f"{output_path_per_run}/sim_golden.vvp"
        if os.path.isfile(vvp_name):
            os.remove(vvp_name)
        # -s tb pins the toplevel to the benchmark's testbench module.
        cmd = "iverilog -Wall -Winfloop -Wno-timescale -g2012 -s tb -o {} {} {} {}; vvp -n {}".format(
            vvp_name, tb_path, rtl_path, ref_path, vvp_name
        )
        is_pass, sim_output = run_bash_command(cmd, timeout=60)
        sim_output_obj = CommandResult.model_validate_json(sim_output)
        # The golden testbench reports failure via this phrase instead of a
        # "SIMULATION FAILED" banner.
        is_pass = (
            is_pass
            and "First mismatch occurred at time" not in sim_output_obj.stdout
            and (
                sim_output_obj.stderr == ""
                or stderr_all_lines_benign(sim_output_obj.stderr)
            )
        )
        logger.info(f"Golden simulation is_pass: {is_pass}, \noutput: {sim_output}")
        return is_pass, sim_output
    raise NotImplementedError  # Should not reach here


def sim_review_golden_benchmark(
    task_id: str,
    output_path: str,
    benchmark_type: TypeBenchmark,
    benchmark_path: str,
) -> Tuple[bool, str]:
    # Golden-run one task and persist the verdict as JSON in the run folder.
    output_path_per_run = f"{output_path}/{benchmark_type.name}_{task_id}"
    rtl_path = f"{output_path_per_run}/rtl.sv"
    is_pass, sim_output = sim_review_golden(
        rtl_path, task_id, benchmark_type, benchmark_path, output_path_per_run
    )
    with open(f"{output_path_per_run}/sim_review_output.json", "w") as f:
        f.write(
            json.dumps(
                {"is_pass": is_pass, "sim_output": json.loads(sim_output)}, indent=4
            )
        )
    return (is_pass, sim_output)


def sim_review_golden_benchmark_batch(
    task_id_list: List[str],
    log_path: str,
    output_path: str,
    benchmark_type: TypeBenchmark,
    benchmark_path: str,
) -> Dict[str, Tuple[bool, str]]:
    # Sequentially golden-review each task, redirecting logs per task.
    ret: Dict[str, Tuple[bool, str]] = {}
    for task_id in task_id_list:
        set_log_dir(f"{log_path}/golden_review_{benchmark_type.name}_{task_id}")
        ret[task_id] = sim_review_golden_benchmark(
            task_id, output_path, benchmark_type, benchmark_path
        )
    return ret
--------------------------------------------------------------------------------
/src/mage/tb_generator.py:
--------------------------------------------------------------------------------
import json
from typing import Dict, List, Tuple

from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole
from pydantic import BaseModel

from .log_utils import get_logger
from .prompts import FAILED_TRIAL_PROMPT, ORDER_PROMPT, TB_4_SHOT_EXAMPLES
from .token_counter import TokenCounter, TokenCounterCached
from .utils import add_lineno

logger = get_logger(__name__)

SYSTEM_PROMPT = r"""
You are an expert in SystemVerilog design.
You can always write SystemVerilog code with no syntax errors and always reach correct functionality.
"""

# NOTE(review): delimiter tags originally wrapping {input_spec} (and similar
# placeholders below) appear stripped by the dump — verify against the repo.
NON_GOLDEN_TB_PROMPT = r"""
In order to test a module generated with the given natural language specification:
1. Please write an IO interface for that module;
2. Please write a testbench to test the module.

The module interface should EXACTLY MATCH the description in input_spec.
(Including the module name, input/output ports names, and their types)

{input_spec}

The testbench should:
1. Instantiate the module according to the IO interface;
2. Generate input stimulate signals and expected output signals according to input_spec;
3. Apply the input signals to the module, count the number of mismatches between the output signals with the expected output signals;
4. Every time when a check occurs, no matter match or mismatch, display input signals, output signals and expected output signals;
5. When simulation ends, ADD DISPLAY "SIMULATION PASSED" if no mismatch occurs, otherwise display:
"SIMULATION FAILED - x MISMATCHES DETECTED, FIRST AT TIME y".
6. To avoid ambiguity, please use the reverse edge to do output check. (If RTL runs at posedge, use negedge to check the output)
7. For pure combinational module (especially those without clk),
the expected output should be checked at the exact moment when the input is changed;
8. Avoid using keyword "continue"

Try to understand the requirements above and give reasoning steps in natural language to achieve it.
In addition, try to give advice to avoid syntax error.
An SystemVerilog module always starts with a line starting with the keyword 'module' followed by the module name.
It ends with the keyword 'endmodule'.

{examples_prompt}

Please also follow the display prompt below:
{display_prompt}
"""

GOLDEN_TB_PROMPT = r"""
In order to test a module generated with the given natural language specification:
1. Please write an IO interface for that module;
2. Please improve the given golden testbench to test the module.

The module interface should EXACTLY MATCH the description in input_spec.
(Including the module name, input/output ports names, and their types)

{input_spec}

To improve the golden testbench, you should add more display to it, while keeping the original functionality.
In detail, the testbench you generated should:
1. MAINTAIN the EXACT SAME functionality, interface and module instantiation as the golden testbench;
2. If the golden testbench contradicts the input_spec, ALWAYS FOLLOW the golden testbench;
3. MAINTAIN the original logic of error counting;
4. When simulation ends, ADD DISPLAY "SIMULATION PASSED" if no mismatch occurs, otherwise display:
"SIMULATION FAILED - x MISMATCHES DETECTED, FIRST AT TIME y".
Please also follow the display prompt below:
{display_prompt}

Try to understand the requirements above and give reasoning steps in natural language to achieve it.
In addition, try to give advice to avoid syntax error.
An SystemVerilog module always starts with a line starting with the keyword 'module' followed by the module name.
It ends with the keyword 'endmodule'.

Below is the golden testbench code for the module generated with the given natural language specification.

{golden_testbench}
"""

# Simpler display policy: dump signals at the first mismatch only.
DISPLAY_MOMENT_PROMPT = r"""
1. When the first mismatch occurs, display the input signals, output signals and expected output signals at that time.
2. For multiple-bit signals displayed in HEX format, also display the BINARY format if its width <= 64.
"""

# Richer display policy: keep a bounded history queue for sequential designs.
DISPLAY_QUEUE_PROMPT = r"""
1. If module to test is sequential logic (like including an FSM):
1.1. Store input signals, output signals, expected output signals and reset signals in a queue with MAX_QUEUE_SIZE;
When the first mismatch occurs, display the queue content after storing it. Make sure the mismatched signal can be displayed.
1.2. MAX_QUEUE_SIZE should be set according to the requirement of the module.
For example, if the module has a 3-bit state, MAX_QUEUE_SIZE should be at least 2 ** 3 = 8.
And if the module was to detect a pattern of 8 bits, MAX_QUEUE_SIZE should be at least (8 + 1) = 9.
However, to control log size, NEVER set MAX_QUEUE_SIZE > 10.
1.3. The clocking of queue and display should be same with the clocking of tb_match detection.
For example, if 'always @(posedge clk, negedge clk)' is used to detect mismatch,
It should also be used to push queue and display first error.
2. If module to test is combinational logic:
When the first mismatch occurs, display the input signals, output signals and expected output signals at that time.
3. For multiple-bit signals displayed in HEX format, also display the BINARY format if its width <= 64.

// Queue-based simulation mismatch display

reg [INPUT_WIDTH-1:0] input_queue [$];
reg [OUTPUT_WIDTH-1:0] got_output_queue [$];
reg [OUTPUT_WIDTH-1:0] golden_queue [$];
reg reset_queue [$];

localparam MAX_QUEUE_SIZE = 5;

always @(posedge clk, negedge clk) begin
    if (input_queue.size() >= MAX_QUEUE_SIZE - 1) begin
        input_queue.delete(0);
        got_output_queue.delete(0);
        golden_queue.delete(0);
        reset_queue.delete(0);
    end

    input_queue.push_back(input_data);
    got_output_queue.push_back(got_output);
    golden_queue.push_back(golden_output);
    reset_queue.push_back(rst);

    // Check for first mismatch
    if (got_output !== golden_output) begin
        $display("Mismatch detected at time %t", $time);
        $display("\nLast %d cycles of simulation:", input_queue.size());

        for (int i = 0; i < input_queue.size(); i++) begin
            if (got_output_queue[i] === golden_queue[i]) begin
                $display("Got Match at");
            end else begin
                $display("Got Mismatch at");
            end
            $display("Cycle %d, reset %b, input %h, got output %h, exp output %h",
                i,
                reset_queue[i],
                input_queue[i],
                got_output_queue[i],
                golden_queue[i]
            );
        end
    end

end
"""


# Example of the JSON object the generator LLM must emit.
EXAMPLE_OUTPUT = {
    "reasoning": "All reasoning steps and advices to avoid syntax error",
    "interface": "The IO part of a SystemVerilog module, not containing the module implementation",
    "testbench": "The testbench code to test the module",
}


class TBOutputFormat(BaseModel):
    # Parsed LLM reply: reasoning text, module IO interface, testbench code.
    reasoning: str
    interface: str
    testbench: str


EXTRA_ORDER_GOLDEN_TB_PROMPT = r"""
Remember that if the golden testbench contradicts the input_spec, ALWAYS FOLLOW the golden testbench;
Especially if the input_spec say some input should not exist, but as long as the golden testbench uses it, you should use it.
Remember to display "SIMULATION PASSED" when simulation ends if no mismatch occurs, otherwise display "SIMULATION FAILED - x MISMATCHES DETECTED, FIRST AT TIME y".
Remember to add display for the FIRST mismatch, while maintaining the original logic of error counting;
ALWAYS generate the complete testbench, no matter how long it is.
Generate interface according to golden testbench, even if it contradicts the input_spec. Declare all ports as logic.
"""

EXTRA_ORDER_NON_GOLDEN_TB_PROMPT = r"""
For pattern detecter, if no specification is found in input_spec,
suppose the "detected" output will be asserted on the cycle AFTER the pattern appears in input.
Like when detecting pattern "11", should be like:
// Test case : Two consecutive ones
@(posedge clk); in_ = 1; expected_out = 0;
@(posedge clk); in_ = 1; expected_out = 0;
@(posedge clk); in_ = 0; expected_out = 1;
"""

COVERAGE_PROMPT = r""" Your task involves a Verilog Design Under Test (DUT) that is currently in its initial phase of testing.
The assignment requires you to generate a binary input sequence to maximize code coverage.
To achieve this, you need to analyze the DUT, considering the logic operations and transitions within the circuit.
This careful analysis will allow you to discern the relationship between the input sequence and the uncovered lines, and thus generate an effective input sequence.)";
// task_prompt += input_signal_prompt_;"""


class TBGenerator:
    """Generates a SystemVerilog testbench + IO interface from a spec."""

    def __init__(
        self,
        token_counter: TokenCounter,
    ):
        self.token_counter = token_counter
        # Messages describing earlier failed (spec, rtl, tb) attempts.
        self.failed_trial: List[ChatMessage] = []
        self.history: List[ChatMessage] = []
        # When set, improve the benchmark's golden TB instead of writing one.
        self.golden_tb_path: str | None = None
        self.json_decode_max_trial = 3
        # Choose queue-based vs moment-based mismatch display prompt.
        self.gen_display_queue = True

    def reset(self):
        # NOTE(review): failed_trial is NOT cleared here — confirm whether
        # carrying failed trials across reset() is intended.
        self.history = []

    def set_golden_tb_path(self, golden_tb_path: str | None) -> None:
        self.golden_tb_path = golden_tb_path

    def set_failed_trial(
        self, failed_sim_log: str, previous_code: str, previous_tb: str
    ) -> None:
        # Record one failed attempt (with line numbers) for future prompts.
        cur_failed_trial = FAILED_TRIAL_PROMPT.format(
            failed_sim_log=failed_sim_log,
            previous_code=add_lineno(previous_code),
            previous_tb=add_lineno(previous_tb),
        )
        self.failed_trial.append(
            ChatMessage(content=cur_failed_trial, role=MessageRole.USER)
        )

    def generate(self, messages: List[ChatMessage]) -> ChatResponse:
        # Run one LLM call through the token counter and log it.
        logger.info(f"TB generator input message: {messages}")
        resp, token_cnt = self.token_counter.count_chat(messages)
        logger.info(f"Token count: {token_cnt}")
        logger.info(f"{resp.message.content}")
        return resp

    def get_init_prompt_messages(self, input_spec: str) -> List[ChatMessage]:
        # Build system + generation messages; golden-TB mode takes priority.
        display_prompt = (
            DISPLAY_QUEUE_PROMPT if self.gen_display_queue else DISPLAY_MOMENT_PROMPT
        )
        if self.golden_tb_path:
            with open(self.golden_tb_path, "r") as f:
                golden_testbench = f.read()
            generation_content = GOLDEN_TB_PROMPT.format(
                input_spec=input_spec,
                golden_testbench=golden_testbench,
                display_prompt=display_prompt,
            )
        else:
            generation_content = NON_GOLDEN_TB_PROMPT.format(
                input_spec=input_spec,
                examples_prompt=TB_4_SHOT_EXAMPLES,
                display_prompt=display_prompt,
            )
        ret = [
            ChatMessage(content=SYSTEM_PROMPT, role=MessageRole.SYSTEM),
            ChatMessage(content=generation_content, role=MessageRole.USER),
        ]
        if self.failed_trial:
            ret.extend(self.failed_trial)
        return ret

    def get_order_prompt_messages(self) -> List[ChatMessage]:
        # Output-format order message; the extra instruction block depends on
        # golden vs non-golden mode and is appended AFTER the format example.
        if self.golden_tb_path:
            order_prompt_message = ChatMessage(
                content=ORDER_PROMPT.format(
                    output_format="".join(json.dumps(EXAMPLE_OUTPUT, indent=4))
                )
                + EXTRA_ORDER_GOLDEN_TB_PROMPT,
                role=MessageRole.USER,
            )
        else:
            order_prompt_message = ChatMessage(
                content=ORDER_PROMPT.format(
                    output_format="".join(json.dumps(EXAMPLE_OUTPUT, indent=4))
                )
                + EXTRA_ORDER_NON_GOLDEN_TB_PROMPT,
                role=MessageRole.USER,
            )

        return [order_prompt_message]

    def parse_output(self, response: ChatResponse) -> TBOutputFormat:
        # Parse the JSON reply; on decode failure return a sentinel object
        # whose reasoning starts with "Json Decode Error" (checked by chat()).
        try:
            output_json_obj: Dict = json.loads(response.message.content, strict=False)
            ret = TBOutputFormat(
                reasoning=output_json_obj["reasoning"],
                interface=output_json_obj["interface"],
                testbench=output_json_obj["testbench"],
            )
        except json.decoder.JSONDecodeError as e:
            ret = TBOutputFormat(
                reasoning=f"Json Decode Error: {str(e)}",
                interface="",
                testbench="",
            )
        return ret

    def chat(self, input_spec: str) -> Tuple[str, str]:
        # Generate (testbench, interface); retry on JSON decode failures.
        if isinstance(self.token_counter, TokenCounterCached):
            self.token_counter.set_enable_cache(False)
        self.history = []
        self.token_counter.set_cur_tag(self.__class__.__name__)
        self.history.extend(self.get_init_prompt_messages(input_spec))
        for _ in range(self.json_decode_max_trial):
            response = self.generate(self.history +
self.get_order_prompt_messages())
            resp_obj = self.parse_output(response)
            if not resp_obj.reasoning.startswith("Json Decode Error"):
                break
            # Feed the decode error back so the model can fix its formatting.
            error_msg = ChatMessage(role=MessageRole.USER, content=resp_obj.reasoning)
            self.history.extend([response.message, error_msg])
        if resp_obj.reasoning.startswith("Json Decode Error"):
            # Exhausted all retries without a parsable reply.
            raise ValueError(
                f"Json Decode Error when decoding: {response.message.content}"
            )
        return (resp_obj.testbench, resp_obj.interface)
--------------------------------------------------------------------------------
/src/mage/token_counter.py:
--------------------------------------------------------------------------------
import asyncio
import time
from typing import Dict, List, Tuple

import tiktoken
from anthropic.types import Usage
from llama_index.core.base.llms.types import ChatMessage, ChatResponse
from llama_index.core.llms.llm import LLM
from llama_index.llms.anthropic import Anthropic
from llama_index.llms.openai import OpenAI
from llama_index.llms.vertex import Vertex
from pydantic import BaseModel
from vertexai.preview.generative_models import GenerativeModel

from .gen_config import get_exp_setting
from .log_utils import get_logger
from .utils import reformat_json_string

logger = get_logger(__name__)

# Experiment-wide sampling settings, shared by every LLM call below.
settings = get_exp_setting()
setting_args = {
    "temperature": settings.temperature,
    "top_p": settings.top_p,
}


class TokenCount(BaseModel):
    """Token count of an LLM call"""

    in_token_cnt: int
    out_token_cnt: int

    class Config:
        # Immutable so counts can be safely shared/summed.
        frozen = True

    def __add__(self, other: "TokenCount"):
        # Component-wise sum; enables sum(counts, start=TokenCount(...)).
        return TokenCount(
            in_token_cnt=self.in_token_cnt + other.in_token_cnt,
            out_token_cnt=self.out_token_cnt + other.out_token_cnt,
        )

    def __str__(self) -> str:
        return f"in {self.in_token_cnt:>8} tokens, out {self.out_token_cnt:>8} tokens"


class TokenCountCached(TokenCount):
    """TokenCount extended with prompt-cache write/read token counts."""

    cache_write_cnt: int = 0
    cache_read_cnt: int = 0

    class Config:
        frozen = True

    def __add__(self, other: "TokenCountCached"):
        return TokenCountCached(
            in_token_cnt=self.in_token_cnt + other.in_token_cnt,
            out_token_cnt=self.out_token_cnt + other.out_token_cnt,
            cache_read_cnt=self.cache_read_cnt + other.cache_read_cnt,
            cache_write_cnt=self.cache_write_cnt + other.cache_write_cnt,
        )

    def __str__(self) -> str:
        # Fall back to the plain format when no caching happened.
        if not (self.cache_read_cnt or self.cache_write_cnt):
            return super().__str__()
        return (
            f"in {self.in_token_cnt:>8} tokens, "
            f"out {self.out_token_cnt:>8} tokens, "
            f"cache write {self.cache_write_cnt:>8} tokens, "
            f"cache read {self.cache_read_cnt:>8} tokens"
        )


class TokenCost(BaseModel):
    """Token cost of an LLM call"""

    # USD per token; defaults of 0.0 mean "unknown model, report no cost".
    in_token_cost_per_token: float = 0.0
    out_token_cost_per_token: float = 0.0


# Per-model USD pricing (per single token, i.e. $/1M tokens divided by 1e6).
TOKEN_COSTS = {
    "claude-3-5-sonnet-20241022": TokenCost(
        in_token_cost_per_token=3.0 / 1000000, out_token_cost_per_token=15.0 / 1000000
    ),
    "claude-3-5-sonnet@20241022": TokenCost(
        in_token_cost_per_token=3.0 / 1000000, out_token_cost_per_token=15.0 / 1000000
    ),
    "claude-3-7-sonnet-20250219": TokenCost(
        in_token_cost_per_token=3.0 / 1000000, out_token_cost_per_token=15.0 / 1000000
    ),
    "claude-3-7-sonnet@20250219": TokenCost(
        in_token_cost_per_token=3.0 / 1000000, out_token_cost_per_token=15.0 / 1000000
    ),
    "gpt-4o-2024-08-06": TokenCost(
        in_token_cost_per_token=2.5 / 1000000, out_token_cost_per_token=10.0 / 1000000
    ),
    "o1-preview-2024-09-12": TokenCost(
        in_token_cost_per_token=15.0 / 1000000, out_token_cost_per_token=60.0 / 1000000
    ),
    "o1-mini-2024-09-12": TokenCost(
        in_token_cost_per_token=3.0 / 1000000, out_token_cost_per_token=12.0 /
1000000
    ),
    "gpt-4o-2024-05-13": TokenCost(
        in_token_cost_per_token=5.0 / 1000000, out_token_cost_per_token=15.0 / 1000000
    ),
    "gemini-1.5-pro-002": TokenCost(
        in_token_cost_per_token=1.25 / 1000000, out_token_cost_per_token=5.0 / 1000000
    ),
    "gemini-2.0-flash-001": TokenCost(
        in_token_cost_per_token=0.1 / 1000000, out_token_cost_per_token=0.4 / 1000000
    ),
}


class TokenCounter:
    """Token counter based on tiktoken / Anthropic"""

    def __init__(self, llm: LLM) -> None:
        self.llm = llm
        # Per-tag lists of TokenCount; "" is the default/untagged bucket.
        self.token_cnts: Dict[str, List[TokenCount]] = {"": []}
        self.token_cnts_lock = asyncio.Lock()
        self.cur_tag = ""
        self.max_parallel_requests: int = 10
        # Vertex/Gemini replies need JSON reformatting before parsing.
        self.enable_reformat_json = isinstance(llm, Vertex)
        model = llm.metadata.model_name
        # Pick a tokenizer per provider; unknown providers count 0 tokens.
        if isinstance(llm, OpenAI):
            self.encoding = tiktoken.encoding_for_model(model)
        elif isinstance(llm, Anthropic):
            self.encoding = llm.tokenizer
        elif isinstance(llm, Vertex):
            assert llm.model.startswith(
                "gemini"
            ), f"Non-gemini Vertex model is not supported: {llm.model}"
            assert isinstance(llm._client, GenerativeModel)

            # Adapter giving Vertex the same .encode() interface as tiktoken;
            # only the length of the returned list is ever used.
            class VertexEncoding:
                def __init__(self, client: GenerativeModel):
                    self.client = client

                def encode(self, text: str) -> List[str]:
                    token_len = self.client.count_tokens(text).total_tokens
                    return ["placeholder" for _ in range(token_len)]

            self.encoding = VertexEncoding(llm._client)
            self.activate_structure_output = True
        else:
            logger.warning(
                f"Cannot find tokenizer for model '{model}'. "
                "May need to change mage.token_counter.TokenCounter.__init__"
            )
            self.encoding = None
            self.token_cost = TokenCost()
            return
        logger.info(f"Found tokenizer for model '{model}'")
        if model not in TOKEN_COSTS:
            # Known tokenizer but unknown pricing: count tokens, cost 0.
            self.token_cost = TokenCost()
            logger.warning(
                f"Cannot find token cost for model '{model}' in record. "
                "May need to append to mage.token_counter.TOKEN_COSTS"
            )
            return
        self.token_cost = TOKEN_COSTS[model]

    def set_cur_tag(self, tag: str) -> None:
        # Switch the bucket subsequent counts are recorded under.
        self.cur_tag = tag
        if tag not in self.token_cnts:
            self.token_cnts[tag] = []

    def count(self, string: str) -> int:
        # 0 when no tokenizer could be resolved in __init__.
        if self.encoding is None:
            return 0
        return len(self.encoding.encode(string))

    def reset(self) -> None:
        # Drop all recorded counts (keeps the default bucket).
        self.token_cnts = {"": []}

    def count_chat(
        self, messages: List[ChatMessage], llm: LLM | None = None
    ) -> Tuple[ChatResponse, TokenCount]:
        # Synchronous chat call; records the in/out token count under cur_tag.
        llm = llm or self.llm
        in_token_cnt = self.count(llm.messages_to_prompt(messages))
        logger.info(
            "TokenCounter count_chat Triggered at temp: %s, top_p: %s"
            % (settings.temperature, settings.top_p)
        )
        response = llm.chat(
            messages, top_p=settings.top_p, temperature=settings.temperature
        )
        out_token_cnt = self.count(response.message.content)
        token_cnt = TokenCount(in_token_cnt=in_token_cnt, out_token_cnt=out_token_cnt)
        self.token_cnts[self.cur_tag].append(token_cnt)
        if self.enable_reformat_json:
            response.message.content = reformat_json_string(response.message.content)
        return (response, token_cnt)

    async def count_achat(
        self, messages: List[ChatMessage], llm: LLM | None = None
    ) -> Tuple[ChatResponse, TokenCount]:
        # Async twin of count_chat; token_cnts access is lock-protected.
        llm = llm or self.llm
        in_token_cnt = self.count(llm.messages_to_prompt(messages))
        logger.info(
            "TokenCounter count_achat Triggered at temp: %s, top_p: %s"
202 | % (settings.temperature, settings.top_p) 203 | ) 204 | response = await llm.achat( 205 | messages, top_p=settings.top_p, temperature=settings.temperature 206 | ) 207 | out_token_cnt = self.count(response.message.content) 208 | token_cnt = TokenCount(in_token_cnt=in_token_cnt, out_token_cnt=out_token_cnt) 209 | async with self.token_cnts_lock: 210 | self.token_cnts[self.cur_tag].append(token_cnt) 211 | if self.enable_reformat_json: 212 | response.message.content = reformat_json_string(response.message.content) 213 | return (response, token_cnt) 214 | 215 | async def count_achat_batch( 216 | self, chat_inputs: List[List[ChatMessage]], llm: LLM | None = None 217 | ) -> List[Tuple[ChatResponse, TokenCount]]: 218 | llm = llm or self.llm 219 | results = [] 220 | for i in range(0, len(chat_inputs), self.max_parallel_requests): 221 | batch = chat_inputs[i : i + self.max_parallel_requests] 222 | tasks = [ 223 | self.count_achat(llm=llm, messages=chat_input) for chat_input in batch 224 | ] 225 | batch_results = await asyncio.gather(*tasks) 226 | results.extend(batch_results) 227 | return results 228 | 229 | def count_chat_batch( 230 | self, chat_inputs: List[List[ChatMessage]], llm: LLM | None = None 231 | ) -> List[Tuple[ChatResponse, TokenCount]]: 232 | llm = llm or self.llm 233 | try: 234 | # Get the current event loop 235 | loop = asyncio.get_event_loop() 236 | except RuntimeError: 237 | # If there is no current event loop, create a new one 238 | loop = asyncio.new_event_loop() 239 | asyncio.set_event_loop(loop) 240 | start_time = time.time() 241 | results = loop.run_until_complete( 242 | self.count_achat_batch(llm=llm, chat_inputs=chat_inputs) 243 | ) 244 | logger.info(f"Total batch chat time: {time.time() - start_time:.2f}s") 245 | return results 246 | 247 | def log_token_stats(self) -> None: 248 | total_sum_cnt = TokenCount(in_token_cnt=0, out_token_cnt=0) 249 | for tag in self.token_cnts: 250 | token_cnt = self.token_cnts[tag] 251 | if not token_cnt: 252 | 
continue 253 | sum_cnt = sum(token_cnt, start=TokenCount(in_token_cnt=0, out_token_cnt=0)) 254 | assert isinstance(sum_cnt, TokenCount) 255 | total_sum_cnt += sum_cnt 256 | logger.info(f"{tag + ' cnt':<25}: {sum_cnt}") 257 | logger.info((f"{'Total cnt':<25}: {total_sum_cnt}")) 258 | if self.token_cost: 259 | total_cost = ( 260 | total_sum_cnt.in_token_cnt * self.token_cost.in_token_cost_per_token 261 | + total_sum_cnt.out_token_cnt * self.token_cost.out_token_cost_per_token 262 | ) 263 | logger.info(f"{'Total cost':<25}: ${total_cost:.2f} USD") 264 | 265 | def get_sum_count(self, tag: str | None = None) -> TokenCount: 266 | # If have tag: return sum of token counts with that tag 267 | # If no tag: return sum of all token counts 268 | if tag: 269 | token_cnt = self.token_cnts[tag] 270 | sum_cnt = sum(token_cnt, start=TokenCount(in_token_cnt=0, out_token_cnt=0)) 271 | else: 272 | sum_cnt = TokenCount(in_token_cnt=0, out_token_cnt=0) 273 | for token_cnt in self.token_cnts.values(): 274 | sum_cnt += sum( 275 | token_cnt, start=TokenCount(in_token_cnt=0, out_token_cnt=0) 276 | ) 277 | assert isinstance(sum_cnt, TokenCount) 278 | return sum_cnt 279 | 280 | def get_total_token(self) -> int: 281 | """Return token number regarding to token limit""" 282 | sum_cnt = TokenCount(in_token_cnt=0, out_token_cnt=0) 283 | for token_cnt in self.token_cnts.values(): 284 | tag_cnt = sum(token_cnt, start=TokenCount(in_token_cnt=0, out_token_cnt=0)) 285 | assert isinstance(tag_cnt, TokenCount) 286 | sum_cnt += tag_cnt 287 | assert isinstance(sum_cnt, TokenCount) 288 | return sum_cnt.in_token_cnt + sum_cnt.out_token_cnt 289 | 290 | 291 | class TokenCounterCached(TokenCounter): 292 | """Token counter with cache based on Anthropic""" 293 | 294 | def __init__(self, llm: LLM) -> None: 295 | super().__init__(llm) 296 | assert isinstance(llm, Anthropic) 297 | self.write_cost_ratio: float = 1.25 298 | self.read_cost_ratio: float = 0.1 299 | self.enable_cache = True 300 | 301 | def 
set_enable_cache(self, enable_cache: bool) -> None: 302 | self.enable_cache = enable_cache 303 | 304 | def equivalent_cost(self, token_count_cached: TokenCountCached) -> TokenCount: 305 | equi_cost = round( 306 | token_count_cached.in_token_cnt 307 | + token_count_cached.cache_write_cnt * self.write_cost_ratio 308 | + token_count_cached.cache_read_cnt * self.read_cost_ratio 309 | ) 310 | return TokenCount( 311 | in_token_cnt=equi_cost, 312 | out_token_cnt=token_count_cached.out_token_cnt, 313 | ) 314 | 315 | @classmethod 316 | def is_cache_enabled(cls, llm: LLM) -> bool: 317 | return isinstance(llm, Anthropic) 318 | 319 | def add_cache_tag(self, target: ChatMessage) -> None: 320 | target.additional_kwargs["cache_control"] = {"type": "ephemeral"} 321 | 322 | def count_chat( 323 | self, messages: List[ChatMessage], llm: LLM | None = None 324 | ) -> Tuple[ChatResponse, TokenCountCached]: 325 | llm = llm or self.llm 326 | logger.info( 327 | "TokenCounterCached count_chat Triggered at temp: %s, top_p: %s" 328 | % (settings.temperature, settings.top_p) 329 | ) 330 | response = llm.chat( 331 | messages, 332 | top_p=settings.top_p, 333 | temperature=settings.temperature, 334 | ) 335 | usage = response.raw["usage"] 336 | assert isinstance(usage, Usage), f"Unknown usage type: {type(usage)}" 337 | token_cnt = TokenCountCached( 338 | in_token_cnt=usage.input_tokens, 339 | out_token_cnt=usage.output_tokens, 340 | cache_write_cnt=( 341 | usage.cache_creation_input_tokens 342 | if hasattr(usage, "cache_creation_input_tokens") 343 | else 0 344 | ), 345 | cache_read_cnt=( 346 | usage.cache_read_input_tokens 347 | if hasattr(usage, "cache_read_input_tokens") 348 | else 0 349 | ), 350 | ) 351 | self.token_cnts[self.cur_tag].append(token_cnt) 352 | if self.enable_reformat_json: 353 | response.message.content = reformat_json_string(response.message.content) 354 | return (response, token_cnt) 355 | 356 | async def count_achat( 357 | self, messages: List[ChatMessage], llm: LLM | None = 
    def log_token_stats(self) -> None:
        """Log per-tag and total token statistics, including cache-adjusted
        ("equal") counts and, when pricing is known, the total dollar cost."""
        total_sum_cnt = TokenCountCached(in_token_cnt=0, out_token_cnt=0)
        for tag in self.token_cnts:
            token_cnt = self.token_cnts[tag]
            if not token_cnt:
                # Skip tags that never recorded a chat.
                continue
            sum_cnt = sum(
                token_cnt, start=TokenCountCached(in_token_cnt=0, out_token_cnt=0)
            )
            assert isinstance(sum_cnt, TokenCountCached)

            total_sum_cnt += sum_cnt
            # Cost-equivalent count: cache writes/reads folded into in_token_cnt.
            sum_equal_cnt = self.equivalent_cost(sum_cnt)

            if sum_cnt.cache_write_cnt or sum_cnt.cache_read_cnt:
                # Cache was used: show both the raw and the equivalent count.
                logger.info(f"{tag + ' cnt':<25}: {sum_cnt}")
                logger.info(f"{tag + ' equal cnt':<25}: {sum_equal_cnt}")
            else:
                # No cache activity: the equivalent count equals the raw one.
                logger.info(f"{tag + ' cnt':<25}: {sum_equal_cnt}")

        total_sum_equal_cnt = self.equivalent_cost(total_sum_cnt)
        if total_sum_cnt.cache_write_cnt or total_sum_cnt.cache_read_cnt:
            # Net tokens saved by caching: reads save (1 - read_cost_ratio)
            # each, writes cost extra since write_cost_ratio > 1 (negative term).
            saved_tokens = round(
                total_sum_cnt.cache_write_cnt * (1 - self.write_cost_ratio)
                + total_sum_cnt.cache_read_cnt * (1 - self.read_cost_ratio)
            )
            logger.info(
                f"{'Total cached cnt':<25}: {total_sum_cnt}, saved {saved_tokens:>8} tokens"
            )
            logger.info(f"{'Total equal cnt':<25}: {total_sum_equal_cnt}")
        else:
            logger.info(f"{'Total cnt':<25}: {total_sum_equal_cnt}")
        if self.token_cost:
            # Price using the cache-adjusted (equivalent) counts.
            total_cost = (
                total_sum_equal_cnt.in_token_cnt
                * self.token_cost.in_token_cost_per_token
                + total_sum_equal_cnt.out_token_cnt
                * self.token_cost.out_token_cost_per_token
            )
            logger.info(f"{'Total cost':<25}: ${total_cost:.2f} USD")
def reformat_json_string(output: str) -> str:
    """Strip a markdown code fence from an LLM response.

    Some models (e.g. Gemini) wrap structured output in markdown fences such
    as ```json ... ``` or ```xml ... ```. Return the fenced payload with
    surrounding whitespace removed; if no known fence is present, return the
    whole string stripped.

    Args:
        output: Raw model output, possibly fenced.

    Returns:
        The unfenced (or merely stripped) payload.
    """
    # Check fence languages in priority order: a ```json fence wins even if a
    # ```xml fence appears earlier in the string (mirrors the original
    # two-branch behavior; the duplicated branches are collapsed into a loop).
    for fence_lang in ("json", "xml"):
        match = re.search(rf"```{fence_lang}(.*?)```", output, re.DOTALL)
        if match:
            return match.group(1).strip()

    return output.strip()
49 | if region and project_id and not aws_region: 50 | self._client = anthropic.AnthropicVertex( 51 | region=region, 52 | project_id=project_id, 53 | credentials=credentials, # extra argument 54 | timeout=self.timeout, 55 | max_retries=self.max_retries, 56 | default_headers=kwargs.get("default_headers"), 57 | ) 58 | self._aclient = anthropic.AsyncAnthropicVertex( 59 | region=region, 60 | project_id=project_id, 61 | credentials=credentials, # extra argument 62 | timeout=self.timeout, 63 | max_retries=self.max_retries, 64 | default_headers=kwargs.get("default_headers"), 65 | ) 66 | # Optionally, you could add similar overrides for the aws_region branch if needed. 67 | -------------------------------------------------------------------------------- /src/sim/.gitignore: -------------------------------------------------------------------------------- 1 | obj_dir 2 | -------------------------------------------------------------------------------- /src/sim/Makefile: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # 3 | # DESCRIPTION: Verilator Example: Small Makefile 4 | # 5 | # This calls the object directory makefile. That allows the objects to 6 | # be placed in the "current directory" which simplifies the Makefile. 7 | # 8 | # This file ONLY is placed under the Creative Commons Public Domain, for 9 | # any use, without warranty, 2020 by Wilson Snyder. 
10 | # SPDX-License-Identifier: CC0-1.0 11 | # 12 | ###################################################################### 13 | # Check for sanity to avoid later confusion 14 | 15 | ifneq ($(words $(CURDIR)),1) 16 | $(error Unsupported: GNU Make cannot build in directories containing spaces, build elsewhere: '$(CURDIR)') 17 | endif 18 | 19 | ###################################################################### 20 | # Set up variables 21 | 22 | # If $VERILATOR_ROOT isn't in the environment, we assume it is part of a 23 | # package install, and verilator is in your path. Otherwise find the 24 | # binary relative to $VERILATOR_ROOT (such as when inside the git sources). 25 | ifeq ($(VERILATOR_ROOT),) 26 | VERILATOR = verilator 27 | VERILATOR_COVERAGE = verilator_coverage 28 | else 29 | export VERILATOR_ROOT 30 | VERILATOR = $(VERILATOR_ROOT)/bin/verilator 31 | VERILATOR_COVERAGE = $(VERILATOR_ROOT)/bin/verilator_coverage 32 | endif 33 | 34 | # Generate C++ in executable form 35 | VERILATOR_FLAGS += -cc --exe 36 | # Generate makefile dependencies (not shown as complicates the Makefile) 37 | #VERILATOR_FLAGS += -MMD 38 | # Optimize 39 | VERILATOR_FLAGS += -x-assign fast 40 | # Warn abount lint issues; may not want this on less solid designs 41 | VERILATOR_FLAGS += -Wall 42 | # Make waveforms 43 | # VERILATOR_FLAGS += --trace 44 | # Check SystemVerilog assertions 45 | VERILATOR_FLAGS += --assert 46 | # Generate coverage analysis 47 | VERILATOR_FLAGS += --coverage-line 48 | 49 | 50 | # Run Verilator in debug mode 51 | #VERILATOR_FLAGS += --debug 52 | # Add this trace to get a backtrace in gdb 53 | #VERILATOR_FLAGS += --gdbbt 54 | 55 | # Ignore warnings 56 | VERILATOR_FLAGS += -Wno-WIDTHEXPAND -Wno-BLKSEQ -Wno-VARHIDDEN -Wno-WIDTHTRUNC -Wno-UNUSEDSIGNAL 57 | 58 | # Input files for Verilator 59 | VERILATOR_INPUT = -f input.vc top.sv ../src-basic/sim-main.cpp ../src-basic/rfuzz-harness.cpp ../llm-guidance/src/LLMGuidance.cpp ../llm-guidance/src/LLMGuidance4CodeCov.cpp 60 | 
61 | ###################################################################### 62 | default: run 63 | 64 | run: 65 | @echo 66 | @echo "-- Verilator tracing example" 67 | 68 | @echo 69 | @echo "-- VERILATE ----------------" 70 | $(VERILATOR) $(VERILATOR_FLAGS) $(VERILATOR_INPUT) 71 | 72 | @echo 73 | @echo "-- BUILD -------------------" 74 | # To compile, we can either 75 | # 1. Pass --build to Verilator by editing VERILATOR_FLAGS above. 76 | # 2. Or, run the make rules Verilator does: 77 | # $(MAKE) -j -C obj_dir -f Vtop.mk 78 | # 3. Or, call a submakefile where we can override the rules ourselves: 79 | $(MAKE) -j -C obj_dir -f ../Makefile_obj 80 | 81 | @echo 82 | @echo "-- RUN ---------------------" 83 | @rm -rf logs 84 | @mkdir -p logs 85 | obj_dir/Vtop 86 | 87 | @echo 88 | @echo "-- TOTAL COVERAGE ----------------" 89 | @rm -rf logs/annotated 90 | # $(VERILATOR_COVERAGE) --annotate logs/total-annotated logs/coverage.dat.total --annotate-min 1 91 | $(VERILATOR_COVERAGE) --annotate logs/total-annotated logs/coverage.dat.total --annotate-min 1 92 | 93 | # @echo 94 | # @echo "-- DONE --------------------" 95 | # @echo "To see waveforms, open vlt_dump.vcd in a waveform viewer" 96 | # @echo 97 | 98 | 99 | ###################################################################### 100 | # Other targets 101 | 102 | show-config: 103 | $(VERILATOR) -V 104 | 105 | maintainer-copy:: 106 | clean mostlyclean distclean maintainer-clean:: 107 | -rm -rf obj_dir logs *.log *.dmp *.vpd coverage.dat core 108 | -------------------------------------------------------------------------------- /src/sim/Makefile_obj: -------------------------------------------------------------------------------- 1 | # -*- Makefile -*- 2 | ####################################################################### 3 | # 4 | # DESCRIPTION: Verilator Example: Makefile for inside object directory 5 | # 6 | # This is executed in the object directory, and called by ../Makefile 7 | # 8 | # This file ONLY is placed under 
the Creative Commons Public Domain, for 9 | # any use, without warranty, 2020 by Wilson Snyder. 10 | # SPDX-License-Identifier: CC0-1.0 11 | # 12 | ####################################################################### 13 | 14 | default: Vtop 15 | 16 | # Include the rules made by Verilator 17 | include Vtop.mk 18 | 19 | # Link flags 20 | # curl is needed in openai-cpp 21 | SC_LIBS += -lcurl 22 | CXX = g++ 23 | 24 | # Use OBJCACHE (ccache) if using gmake and its installed 25 | COMPILE.cc = $(OBJCACHE) $(CXX) $(CXXFLAGS) $(CPPFLAGS) $(TARGET_ARCH) -c 26 | 27 | ####################################################################### 28 | # Compile flags 29 | 30 | # Turn on creating .d make dependency files 31 | CPPFLAGS += -MMD -MP 32 | CPPFLAGS += -I ../../include/json-develop/include/ 33 | 34 | # Compile in Verilator runtime debugging, so +verilator+debug works 35 | CPPFLAGS += -DVL_DEBUG=1 36 | 37 | 38 | 39 | # Turn on some more compiler lint flags (when configured appropriately) 40 | # For testing inside Verilator, "configure --enable-ccwarn" will do this 41 | # automatically; otherwise you may want this unconditionally enabled 42 | ifeq ($(CFG_WITH_CCWARN),yes) # Local... Else don't burden users 43 | USER_CPPFLAGS_WALL += -W -Werror -Wall 44 | endif 45 | 46 | # See the benchmarking section of bin/verilator. 47 | # Support class optimizations. This includes the tracing and symbol table. 48 | # SystemC takes minutes to optimize, thus it is off by default. 49 | OPT_SLOW = 50 | 51 | # Fast path optimizations. Most time is spent in these classes. 
52 | OPT_FAST = -Os -fstrict-aliasing 53 | #OPT_FAST = -O 54 | #OPT_FAST = 55 | 56 | ###################################################################### 57 | ###################################################################### 58 | # Automatically understand dependencies 59 | 60 | DEPS := $(wildcard *.d) 61 | ifneq ($(DEPS),) 62 | include $(DEPS) 63 | endif 64 | -------------------------------------------------------------------------------- /src/sim/input.vc: -------------------------------------------------------------------------------- 1 | // This file typically lists flags required by a large project, e.g. include directories 2 | +librescan +libext+.v+.sv+.vh+.svh -y . 3 | -------------------------------------------------------------------------------- /src/sim/sim_golden.vvp: -------------------------------------------------------------------------------- 1 | #! /usr/local/bin/vvp 2 | :ivl_version "12.0 (stable)" "(v12_0)"; 3 | :ivl_delay_selection "TYPICAL"; 4 | :vpi_time_precision - 12; 5 | :vpi_module "/usr/local/lib/ivl/system.vpi"; 6 | :vpi_module "/usr/local/lib/ivl/vhdl_sys.vpi"; 7 | :vpi_module "/usr/local/lib/ivl/vhdl_textio.vpi"; 8 | :vpi_module "/usr/local/lib/ivl/v2005_math.vpi"; 9 | :vpi_module "/usr/local/lib/ivl/va_math.vpi"; 10 | :vpi_module "/usr/local/lib/ivl/v2009.vpi"; 11 | S_0x650d19394860 .scope package, "$unit" "$unit" 2 1; 12 | .timescale 0 0; 13 | S_0x650d193949f0 .scope module, "tb" "tb" 3 19; 14 | .timescale -12 -12; 15 | L_0x650d19383b90 .functor NOT 1, L_0x650d193aeff0, C4<0>, C4<0>, C4<0>; 16 | L_0x7b29ff3550f0 .functor BUFT 1, C4<0>, C4<0>, C4<0>, C4<0>; 17 | L_0x7b29ff355138 .functor BUFT 1, C4<0>, C4<0>, C4<0>, C4<0>; 18 | L_0x650d1938ccb0 .functor XOR 1, L_0x7b29ff3550f0, L_0x7b29ff355138, C4<0>, C4<0>; 19 | L_0x7b29ff355180 .functor BUFT 1, C4<0>, C4<0>, C4<0>, C4<0>; 20 | L_0x650d193aef00 .functor XOR 1, L_0x650d1938ccb0, L_0x7b29ff355180, C4<0>, C4<0>; 21 | v0x650d193ae3e0_0 .net *"_ivl_10", 0 0, L_0x7b29ff355180; 1 
drivers 22 | v0x650d193ae4e0_0 .net *"_ivl_12", 0 0, L_0x650d193aef00; 1 drivers 23 | L_0x7b29ff3550a8 .functor BUFT 1, C4<0>, C4<0>, C4<0>, C4<0>; 24 | v0x650d193ae5c0_0 .net *"_ivl_2", 0 0, L_0x7b29ff3550a8; 1 drivers 25 | v0x650d193ae680_0 .net *"_ivl_4", 0 0, L_0x7b29ff3550f0; 1 drivers 26 | v0x650d193ae760_0 .net *"_ivl_6", 0 0, L_0x7b29ff355138; 1 drivers 27 | v0x650d193ae890_0 .net *"_ivl_8", 0 0, L_0x650d1938ccb0; 1 drivers 28 | v0x650d193ae970_0 .var "clk", 0 0; 29 | L_0x7b29ff355060 .functor BUFT 1, C4<0>, C4<0>, C4<0>, C4<0>; 30 | v0x650d193aea10_0 .net "out_dut", 0 0, L_0x7b29ff355060; 1 drivers 31 | L_0x7b29ff355018 .functor BUFT 1, C4<0>, C4<0>, C4<0>, C4<0>; 32 | v0x650d193aeab0_0 .net "out_ref", 0 0, L_0x7b29ff355018; 1 drivers 33 | v0x650d193aeb50_0 .var/2u "stats1", 159 0; 34 | v0x650d193aebf0_0 .var/2u "strobe", 0 0; 35 | v0x650d193aec90_0 .net "tb_match", 0 0, L_0x650d193aeff0; 1 drivers 36 | v0x650d193aed50_0 .net "tb_mismatch", 0 0, L_0x650d19383b90; 1 drivers 37 | L_0x650d193aeff0 .cmp/eeq 1, L_0x7b29ff3550a8, L_0x650d193aef00; 38 | S_0x650d19352cf0 .scope module, "good1" "RefModule" 3 56, 4 2 0, S_0x650d193949f0; 39 | .timescale -12 -12; 40 | .port_info 0 /OUTPUT 1 "out"; 41 | v0x650d1937e5b0_0 .net "out", 0 0, L_0x7b29ff355018; alias, 1 drivers 42 | S_0x650d193add10 .scope module, "stim1" "stimulus_gen" 3 53, 3 5 0, S_0x650d193949f0; 43 | .timescale -12 -12; 44 | .port_info 0 /INPUT 1 "clk"; 45 | v0x650d1937e650_0 .net "clk", 0 0, v0x650d193ae970_0; 1 drivers 46 | E_0x650d19394eb0/0 .event negedge, v0x650d1937e650_0; 47 | E_0x650d19394eb0/1 .event posedge, v0x650d1937e650_0; 48 | E_0x650d19394eb0 .event/or E_0x650d19394eb0/0, E_0x650d19394eb0/1; 49 | S_0x650d193adf80 .scope module, "top_module1" "topModule" 3 59, 5 1 0, S_0x650d193949f0; 50 | .timescale -12 -12; 51 | .port_info 0 /OUTPUT 1 "out"; 52 | v0x650d193ae110_0 .net "out", 0 0, L_0x7b29ff355060; alias, 1 drivers 53 | S_0x650d193ae230 .scope task, "wait_for_end_of_timestep" 
"wait_for_end_of_timestep" 3 64, 3 64 0, S_0x650d193949f0; 54 | .timescale -12 -12; 55 | E_0x650d19394ef0 .event anyedge, v0x650d193aebf0_0; 56 | TD_tb.wait_for_end_of_timestep ; 57 | %pushi/vec4 5, 0, 32; 58 | T_0.0 %dup/vec4; 59 | %pushi/vec4 0, 0, 32; 60 | %cmp/s; 61 | %jmp/1xz T_0.1, 5; 62 | %jmp/1 T_0.1, 4; 63 | %pushi/vec4 1, 0, 32; 64 | %sub; 65 | %load/vec4 v0x650d193aebf0_0; 66 | %nor/r; 67 | %assign/vec4 v0x650d193aebf0_0, 0; 68 | %wait E_0x650d19394ef0; 69 | %jmp T_0.0; 70 | T_0.1 ; 71 | %pop/vec4 1; 72 | %end; 73 | .scope S_0x650d193add10; 74 | T_1 ; 75 | %pushi/vec4 100, 0, 32; 76 | T_1.0 %dup/vec4; 77 | %pushi/vec4 0, 0, 32; 78 | %cmp/s; 79 | %jmp/1xz T_1.1, 5; 80 | %jmp/1 T_1.1, 4; 81 | %pushi/vec4 1, 0, 32; 82 | %sub; 83 | %wait E_0x650d19394eb0; 84 | %jmp T_1.0; 85 | T_1.1 ; 86 | %pop/vec4 1; 87 | %delay 1, 0; 88 | %vpi_call/w 3 14 "$finish" {0 0 0}; 89 | %end; 90 | .thread T_1; 91 | .scope S_0x650d193949f0; 92 | T_2 ; 93 | %pushi/vec4 0, 0, 1; 94 | %store/vec4 v0x650d193ae970_0, 0, 1; 95 | %pushi/vec4 0, 0, 1; 96 | %store/vec4 v0x650d193aebf0_0, 0, 1; 97 | %end; 98 | .thread T_2, $init; 99 | .scope S_0x650d193949f0; 100 | T_3 ; 101 | T_3.0 ; 102 | %delay 5, 0; 103 | %load/vec4 v0x650d193ae970_0; 104 | %inv; 105 | %store/vec4 v0x650d193ae970_0, 0, 1; 106 | %jmp T_3.0; 107 | %end; 108 | .thread T_3; 109 | .scope S_0x650d193949f0; 110 | T_4 ; 111 | %vpi_call/w 3 45 "$dumpfile", "wave.vcd" {0 0 0}; 112 | %vpi_call/w 3 46 "$dumpvars", 32'sb00000000000000000000000000000001, v0x650d1937e650_0, v0x650d193aed50_0, v0x650d193aeab0_0, v0x650d193aea10_0 {0 0 0}; 113 | %end; 114 | .thread T_4; 115 | .scope S_0x650d193949f0; 116 | T_5 ; 117 | %load/vec4 v0x650d193aeb50_0; 118 | %parti/u 32, 64, 32; 119 | %cmpi/ne 0, 0, 32; 120 | %jmp/0xz T_5.0, 4; 121 | %load/vec4 v0x650d193aeb50_0; 122 | %parti/u 32, 64, 32; 123 | %load/vec4 v0x650d193aeb50_0; 124 | %parti/u 32, 32, 32; 125 | %vpi_call/w 3 73 "$display", "Hint: Output '%s' has %0d mismatches. 
First mismatch occurred at time %0d.", "out", S<1,vec4,s32>, S<0,vec4,s32> {2 0 0}; 126 | %jmp T_5.1; 127 | T_5.0 ; 128 | %vpi_call/w 3 74 "$display", "Hint: Output '%s' has no mismatches.", "out" {0 0 0}; 129 | T_5.1 ; 130 | %load/vec4 v0x650d193aeb50_0; 131 | %parti/u 32, 128, 32; 132 | %load/vec4 v0x650d193aeb50_0; 133 | %parti/u 32, 0, 32; 134 | %vpi_call/w 3 76 "$display", "Hint: Total mismatched samples is %1d out of %1d samples\012", S<1,vec4,s32>, S<0,vec4,s32> {2 0 0}; 135 | %vpi_call/w 3 77 "$display", "Simulation finished at %0d ps", $time {0 0 0}; 136 | %load/vec4 v0x650d193aeb50_0; 137 | %parti/u 32, 128, 32; 138 | %load/vec4 v0x650d193aeb50_0; 139 | %parti/u 32, 0, 32; 140 | %vpi_call/w 3 78 "$display", "Mismatches: %1d in %1d samples", S<1,vec4,s32>, S<0,vec4,s32> {2 0 0}; 141 | %end; 142 | .thread T_5, $final; 143 | .scope S_0x650d193949f0; 144 | T_6 ; 145 | %wait E_0x650d19394eb0; 146 | ; show_stmt_assign_vector: Get l-value for compressed += operand 147 | %load/vec4 v0x650d193aeb50_0; 148 | %pushi/vec4 0, 0, 32; 149 | %part/u 32; 150 | %pushi/vec4 1, 0, 32; 151 | %add; 152 | %cast2; 153 | %ix/load 4, 0, 0; 154 | %flag_set/imm 4, 0; 155 | %store/vec4 v0x650d193aeb50_0, 4, 32; 156 | %load/vec4 v0x650d193aec90_0; 157 | %nor/r; 158 | %flag_set/vec4 8; 159 | %jmp/0xz T_6.0, 8; 160 | %load/vec4 v0x650d193aeb50_0; 161 | %parti/u 32, 128, 32; 162 | %cmpi/e 0, 0, 32; 163 | %jmp/0xz T_6.2, 4; 164 | %vpi_func 3 89 "$time" 64 {0 0 0}; 165 | %cast2; 166 | %pad/u 32; 167 | %ix/load 4, 96, 0; 168 | %flag_set/imm 4, 0; 169 | %store/vec4 v0x650d193aeb50_0, 4, 32; 170 | T_6.2 ; 171 | ; show_stmt_assign_vector: Get l-value for compressed += operand 172 | %load/vec4 v0x650d193aeb50_0; 173 | %pushi/vec4 128, 0, 32; 174 | %part/u 32; 175 | %pushi/vec4 1, 0, 32; 176 | %add; 177 | %cast2; 178 | %ix/load 4, 128, 0; 179 | %flag_set/imm 4, 0; 180 | %store/vec4 v0x650d193aeb50_0, 4, 32; 181 | T_6.0 ; 182 | %load/vec4 v0x650d193aeab0_0; 183 | %load/vec4 v0x650d193aeab0_0; 184 
// Registered priority mapper: on each rising clock edge, map input_data into
// output_data through a descending-priority range table (synchronous reset).
module top (
    input clk,
    input rst,                      // synchronous, active-high reset
    input [7:0] input_data,         // value to transform
    output reg [7:0] output_data    // registered transformed value
);


always @(posedge clk) begin
    if (rst) begin
        // Synchronous reset: clear the output register.
        output_data <= 8'h00;
    end else begin
        // Priority if/else chain: first matching range wins (checked from
        // the top of the value range downward).
        if (input_data == 8'hFF) begin
            // 8-bit shift left truncates the MSB: 8'hFF -> 8'hFE.
            output_data <= input_data << 1;
        end else if (input_data > 8'hF8) begin
            output_data <= input_data + 8'h0F;
        end else if (input_data > 8'he0) begin
            output_data <= input_data - 8'h10;
        end else if (input_data > 8'ha0) begin
            output_data <= input_data - 8'h11;
        end else if (input_data > 8'h80) begin
            output_data <= input_data - 8'h22;
        end else if (input_data > 8'h40) begin
            output_data <= input_data - 8'h33;
        end else if (input_data > 8'h00) begin
            output_data <= input_data - 8'h44;
        end else begin
            // input_data == 0: pass through unchanged.
            output_data <= input_data;
        end
    end
end



endmodule
def main():
    """Smoke test: build an LLM client from the module-level args_dict.

    Succeeds if get_llm constructs a client without raising; no chat is sent.
    """
    cfg = argparse.Namespace(**args_dict)
    get_llm(
        provider=cfg.provider,
        model=cfg.model,
        cfg_path=cfg.key_cfg_path,
        max_token=cfg.max_token,
    )
# Quick smoke test for RTLGenerator.
# For more functionality testing please see test_top_agent.py
import argparse
import contextlib
import os

from mage.benchmark_read_helper import (
    TypeBenchmark,
    TypeBenchmarkFile,
    get_benchmark_contents,
)
from mage.gen_config import get_llm, set_exp_setting
from mage.log_utils import get_logger
from mage.rtl_generator import RTLGenerator
from mage.token_counter import TokenCounter, TokenCounterCached

logger = get_logger(__name__)

args_dict = {
    "provider": "vertexanthropic",
    "model": "claude-3-7-sonnet@20250219",
    "filter_instance": "^(Prob070_ece241_2013_q2|Prob151_review2015_fsm)$",
    "type_benchmark": "verilog_eval_v2",
    "path_benchmark": "./verilog-eval",
    "temperature": 0.85,
    "top_p": 0.95,
    "max_token": 8192,
    "key_cfg_path": "./key.cfg",
}


def main():
    """Run RTL generation on each filtered benchmark instance.

    For every spec/testbench pair selected by ``filter_instance``, ask the
    generator for RTL, log the pass flag and the generated code, and always
    clean up the temporary output file afterwards.
    """
    args = argparse.Namespace(**args_dict)
    llm = get_llm(
        model=args.model,
        cfg_path=args.key_cfg_path,
        max_token=args.max_token,
        provider=args.provider,
    )
    # Prefer the caching counter when the backend supports prompt caching.
    token_counter = (
        TokenCounterCached(llm)
        if TokenCounterCached.is_cache_enabled(llm)
        else TokenCounter(llm)
    )
    set_exp_setting(temperature=args.temperature, top_p=args.top_p)
    type_benchmark = TypeBenchmark[args.type_benchmark.upper()]

    rtl_gen = RTLGenerator(token_counter)
    spec_dict = get_benchmark_contents(
        type_benchmark,
        TypeBenchmarkFile.SPEC,
        args.path_benchmark,
        args.filter_instance,
    )
    golden_tb_path_dict = get_benchmark_contents(
        type_benchmark,
        TypeBenchmarkFile.TEST_PATH,
        args.path_benchmark,
        args.filter_instance,
    )
    for key, spec in spec_dict.items():
        rtl_gen.reset()
        logger.info(spec)
        testbench_path = golden_tb_path_dict.get(key)
        if not testbench_path:
            logger.error(f"Testbench path not found for {key}")
            continue
        with open(testbench_path, "r") as f:
            testbench = f.read()
        # Temporary per-instance output path for the generated RTL.
        rtl_path = f"./output_{key}_rtl.sv"
        try:
            is_pass, code = rtl_gen.chat(
                input_spec=spec,
                testbench=testbench,
                interface=None,
                rtl_path=rtl_path,
                enable_cache=True,
            )
            logger.info(is_pass)
            logger.info(code)
        finally:
            # Remove the generated RTL file even if chat() raised; the file
            # may not exist when generation failed before writing, so
            # tolerate a missing file instead of masking the real error.
            with contextlib.suppress(FileNotFoundError):
                os.remove(rtl_path)


if __name__ == "__main__":
    main()
20 64 | max_retries = 10 65 | 66 | 67 | results_columns = [ 68 | "Folder", 69 | "Generation", 70 | "Attempt", 71 | "Result", 72 | "Syntax Check", 73 | "Functionality Check", 74 | ] 75 | results_df = pd.DataFrame(columns=results_columns) 76 | # Initialize conversation history 77 | conversation_history = [] 78 | 79 | 80 | @backoff.on_exception(backoff.expo, openai.error.OpenAIError, max_tries=2) 81 | def completions_with_backoff(**kwargs): 82 | return openai.ChatCompletion.create(**kwargs) 83 | 84 | 85 | verilog_examples = [ 86 | """Here are Some examples of rtl verilog code: 87 | Gray to Binary Code Converter: 88 | This Verilog example demonstrates a parameterized gray to binary code converter using a Verilog generate loop 89 | module gray2bin 90 | #(parameter SIZE = 8) 91 | ( 92 | input [SIZE-1:0] gray, 93 | output [SIZE-1:0] bin 94 | ) 95 | 96 | Genvar gi; 97 | for (gi=0; gi