├── .github └── FUNDING.yml ├── .gitignore ├── .project ├── LICENSE ├── Module.manifest ├── README.md ├── build.gradle ├── data └── README.txt ├── extension.properties ├── ghidra_scripts └── README.txt ├── res ├── screenshot1.png └── screenshots_anim.gif └── src └── main ├── help └── help │ ├── TOC_Source.xml │ └── topics │ └── ghidrassist │ └── help.html ├── java └── ghidrassist │ ├── AnalysisDB.java │ ├── GAUtils.java │ ├── GhidrAssistPlugin.java │ ├── GhidrAssistProvider.java │ ├── LlmApi.java │ ├── RLHFDatabase.java │ ├── SettingsDialog.java │ ├── apiprovider │ ├── APIProvider.java │ ├── APIProviderConfig.java │ ├── APIProviderLogger.java │ ├── AnthropicProvider.java │ ├── ChatMessage.java │ ├── ErrorAction.java │ ├── ErrorMessageBuilder.java │ ├── LMStudioProvider.java │ ├── OllamaProvider.java │ ├── OpenAIProvider.java │ ├── OpenWebUiProvider.java │ ├── RetryHandler.java │ ├── capabilities │ │ ├── ChatProvider.java │ │ ├── EmbeddingProvider.java │ │ ├── FunctionCallingProvider.java │ │ └── ModelListProvider.java │ ├── exceptions │ │ ├── APIProviderException.java │ │ ├── AuthenticationException.java │ │ ├── ModelException.java │ │ ├── NetworkException.java │ │ ├── RateLimitException.java │ │ ├── ResponseException.java │ │ └── StreamCancelledException.java │ └── factory │ │ ├── APIProviderFactory.java │ │ ├── AnthropicProviderFactory.java │ │ ├── LMStudioProviderFactory.java │ │ ├── OllamaProviderFactory.java │ │ ├── OpenAIProviderFactory.java │ │ ├── OpenWebUiProviderFactory.java │ │ ├── ProviderRegistry.java │ │ └── UnsupportedProviderException.java │ ├── core │ ├── ActionConstants.java │ ├── ActionExecutor.java │ ├── ActionParser.java │ ├── CodeUtils.java │ ├── ConversationalToolHandler.java │ ├── LlmApiClient.java │ ├── LlmErrorHandler.java │ ├── LlmTaskExecutor.java │ ├── MarkdownHelper.java │ ├── QueryProcessor.java │ ├── RAGEngine.java │ ├── ResponseProcessor.java │ ├── TabController.java │ └── UIState.java │ ├── mcp2 │ ├── protocol │ │ ├── 
MCPMessage.java │ │ ├── MCPProtocolClient.java │ │ ├── MCPRequest.java │ │ └── MCPResponse.java │ ├── server │ │ ├── MCPServerConfig.java │ │ └── MCPServerRegistry.java │ ├── tools │ │ ├── MCPTool.java │ │ ├── MCPToolManager.java │ │ └── MCPToolResult.java │ └── transport │ │ ├── MCPTransport.java │ │ └── SSETransport.java │ ├── resources │ └── GhidrAssistIcons.java │ ├── services │ ├── ActionAnalysisService.java │ ├── AnalysisDataService.java │ ├── CodeAnalysisService.java │ ├── FeedbackService.java │ ├── QueryService.java │ └── RAGManagementService.java │ └── ui │ ├── EnhancedErrorDialog.java │ ├── GhidrAssistUI.java │ ├── common │ ├── PlaceholderTextField.java │ └── UIConstants.java │ └── tabs │ ├── ActionsTab.java │ ├── AnalysisOptionsTab.java │ ├── ExplainTab.java │ ├── MCPServerDialog.java │ ├── MCPServersTab.java │ ├── QueryTab.java │ └── RAGManagementTab.java └── resources └── images ├── README.txt ├── robot.svg ├── robot16.png └── robot32.png /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: jtang613 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | polar: # Replace with a single Polar username 13 | buy_me_a_coffee: jtang613 14 | thanks_dev: # Replace with a single thanks.dev username 15 | custom: 16 | -------------------------------------------------------------------------------- 
/.gitignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | lib/ 3 | dist/ 4 | os/ 5 | .gradle/ 6 | .settings/ 7 | .classpath 8 | .pydevproject 9 | ghidrassist_rlhf.db 10 | ghidrassist_analysis.db 11 | lucene/ 12 | node_modules/ 13 | package.json 14 | package-lock.json 15 | CLAUDE.md 16 | *_PLAN.md 17 | -------------------------------------------------------------------------------- /.project: -------------------------------------------------------------------------------- 1 | 2 | 3 | GhidrAssist 4 | 5 | 6 | 7 | 8 | 9 | org.python.pydev.PyDevBuilder 10 | 11 | 12 | 13 | 14 | org.eclipse.jdt.core.javabuilder 15 | 16 | 17 | 18 | 19 | 20 | org.eclipse.jdt.core.javanature 21 | org.python.pydev.pythonNature 22 | 23 | 24 | 25 | Ghidra 26 | 2 27 | /home/jtang613/tools/ghidra_11.3.2_PUBLIC 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024 Jason Tang 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /Module.manifest: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jtang613/GhidrAssist/f6bd628c273e8ff3da9b55703bcaa90e26c58f4d/Module.manifest -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GhidrAssist 2 | Author: **Jason Tang** 3 | 4 | _A plugin that provides LLM helpers to explain code and assist in RE._ 5 | 6 | ## Description: 7 | 8 | This is an LLM plugin aimed at enabling the use of local LLMs (Ollama, Open-WebUI, LM-Studio, etc.) for assisting with binary exploration and reverse engineering. It supports any OpenAI v1-compatible API. Recommended models are LLaMA-based models such as llama3.1:8b, but others such as DeepSeek and ChatGPT work as well. 9 | 10 | Current features include: 11 | * Explain the current function - Works for disassembly and pseudo-C. 12 | * Explain the current instruction - Works for disassembly and pseudo-C. 13 | * General query - Query the LLM directly from the UI. 14 | * MCP client - Leverage MCP tools like [GhidraMCP](https://github.com/LaurieWired/GhidraMCP) from the interactive LLM chat. 15 | * Agentic RE using the MCP Client and GhidraMCP. 16 | * Propose actions - Provide a list of proposed actions to apply. 17 | * Function calling - Allow the agent to call functions to navigate the binary, rename functions and variables. 18 | * Retrieval Augmented Generation - Supports adding contextual documents to refine query effectiveness. 19 | * RLHF dataset generation - To enable model fine-tuning.
20 | * Settings to modify API host, key, model name and max tokens. 21 | 22 | Future Roadmap: 23 | * Model fine-tuning - Leverage the RLHF dataset to fine-tune the model. 24 | 25 | ## Screenshots 26 | 27 | ![Screenshot](https://github.com/user-attachments/assets/29fcaa14-277c-4eb2-816a-dd1b8ef52259) 28 | 29 | 30 | https://github.com/user-attachments/assets/bd79474a-c82f-4083-b432-96625fef1387 31 | 32 | 33 | ## Quickstart 34 | 35 | * If necessary, copy the binary release ZIP archive to the Ghidra_Install/Extensions/Ghidra directory. 36 | * Launch Ghidra -> File -> Install Extension -> Enable GhidrAssist. 37 | * Load a binary and launch the CodeBrowser. 38 | * CodeBrowser -> File -> Configure -> Miscellaneous -> Enable GhidrAssist. 39 | * CodeBrowser -> Tools -> GhidrAssist Settings. 40 | * Ensure the RLHF and RAG database paths are appropriate for your environment. 41 | * Point the API host to your preferred API provider and set the API key. 42 | * Open GhidrAssist with the GhidrAssist option in the Window menu and start exploring. 43 | 44 | ## LLMs 45 | 46 | General LLM setup is a bit outside the scope of this project since there are so many different options, and plenty of sources cover the topic much better than I could. GhidrAssist assumes you already have access to an OpenAI-compatible API provider. 47 | Here are a few resources that might get you started: 48 | 49 | - https://lmstudio.ai/docs/basics 50 | - https://github.com/ollama/ollama#running-local-builds 51 | - https://help.openai.com/en/articles/4936850-where-do-i-find-my-openai-api-key 52 | - https://docs.anthropic.com/en/docs/initial-setup 53 | 54 | For local LLMs, I've found that Llama3.3:70b, Llama3.1:8b and DeepSeek-r1 produce good results. 55 | From OpenAI, o4-mini produces good results. Anthropic's Claude Sonnet also produces good results.
56 | 57 | ## GhidraMCP 58 | 59 | To use with GhidraMCP, launch the bridge in SSE mode from a terminal: 60 | 61 | `python bridge_mcp_ghidra.py --transport sse --mcp-host 127.0.0.1 --mcp-port 8081 --ghidra-server http://127.0.0.1:8080/` 62 | 63 | Then open Tools -> GhidrAssist and add `http://127.0.0.1:8081` as `GhidraMCP` with `SSE` as the type. 64 | 65 | Enable "Use MCP" in the Custom Query tab. Try a simple query like "What does the current function do?" 66 | 67 | ## Homepage 68 | https://github.com/jtang613/GhidrAssist 69 | 70 | 71 | ## Minimum Version 72 | 73 | This plugin requires the following minimum version of Ghidra: 74 | 75 | * 11.0 76 | 77 | ## License 78 | 79 | This plugin is released under an MIT license. 80 | -------------------------------------------------------------------------------- /build.gradle: -------------------------------------------------------------------------------- 1 | /* ### 2 | * IP: GHIDRA 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | // Builds a Ghidra Extension for a given Ghidra installation.
17 | // 18 | // An absolute path to the Ghidra installation directory must be supplied either by setting the 19 | // GHIDRA_INSTALL_DIR environment variable or Gradle project property: 20 | // 21 | // > export GHIDRA_INSTALL_DIR=<Absolute path to Ghidra> 22 | // > gradle 23 | // 24 | // or 25 | // 26 | // > gradle -PGHIDRA_INSTALL_DIR=<Absolute path to Ghidra> 27 | // 28 | // Gradle should be invoked from the directory of the project to build. Please see the 29 | // application.gradle.version property in <GHIDRA_INSTALL_DIR>/Ghidra/application.properties 30 | // for the correct version of Gradle to use for the Ghidra installation you specify. 31 | 32 | plugins { 33 | id 'java' 34 | id 'eclipse' 35 | } 36 | 37 | //----------------------START "DO NOT MODIFY" SECTION------------------------------ 38 | def ghidraInstallDir 39 | 40 | if (System.env.GHIDRA_INSTALL_DIR) { 41 | ghidraInstallDir = System.env.GHIDRA_INSTALL_DIR 42 | } 43 | else if (project.hasProperty("GHIDRA_INSTALL_DIR")) { 44 | ghidraInstallDir = project.getProperty("GHIDRA_INSTALL_DIR") 45 | } 46 | 47 | if (ghidraInstallDir) { 48 | apply from: new File(ghidraInstallDir).getCanonicalPath() + "/support/buildExtension.gradle" 49 | } 50 | else { 51 | throw new GradleException("GHIDRA_INSTALL_DIR is not defined!") 52 | } 53 | //----------------------END "DO NOT MODIFY" SECTION------------------------------- 54 | 55 | repositories { 56 | // Declare dependency repositories here. This is not needed if dependencies are manually 57 | // dropped into the lib/ directory. 58 | // See https://docs.gradle.org/current/userguide/declaring_repositories.html for more info. 59 | mavenCentral() 60 | } 61 | 62 | dependencies { 63 | // Any external dependencies added here will automatically be copied to the lib/ directory when 64 | // this extension is built.
65 | // NOTE(review): the Java sources import com.google.gson (e.g. GhidrAssistPlugin) but Gson is not declared below; presumably Ghidra's own classpath supplies it — confirm. 66 | implementation 'com.fasterxml.jackson.core:jackson-databind:2.15.0' 67 | implementation "io.reactivex.rxjava3:rxjava:3.1.9" 68 | implementation 'com.vladsch.flexmark:flexmark:0.64.0' 69 | implementation 'com.vladsch.flexmark:flexmark-html2md-converter:0.64.0' 70 | implementation 'org.xerial:sqlite-jdbc:3.46.1.0' 71 | implementation 'org.apache.lucene:lucene-core:9.11.1' 72 | implementation 'org.apache.lucene:lucene-analysis-common:9.11.1' 73 | implementation 'org.apache.lucene:lucene-queryparser:9.11.1' 74 | implementation 'com.squareup.okio:okio:3.10.2' 75 | implementation "com.squareup.okhttp3:okhttp:4.12.0" 76 | } 77 | 78 | // Exclude additional files from the built extension 79 | // Ex: buildExtension.exclude '.idea/**' 80 | -------------------------------------------------------------------------------- /data/README.txt: -------------------------------------------------------------------------------- 1 | The "data" directory is intended to hold data files that will be used by this module and will 2 | not end up in the .jar file, but will be present in the zip or tar file. Typically, data 3 | files are placed here rather than in the resources directory if the user may need to edit them. 4 | 5 | An optional data/languages directory can exist for the purpose of containing various Sleigh language 6 | specification files and importer opinion files. 7 | 8 | The data/buildLanguage.xml is used for building the contents of the data/languages directory. 9 | 10 | The skel language definition has been commented-out within the skel.ldefs file so that the 11 | skeleton language does not show-up within Ghidra. 12 | 13 | See the Sleigh language documentation (docs/languages/index.html) for details Sleigh language 14 | specification syntax.
15 | -------------------------------------------------------------------------------- /extension.properties: -------------------------------------------------------------------------------- 1 | name=@extname@ 2 | description=A plugin that provides LLM helpers to explain code and assist in RE. 3 | author=Jason Tang 4 | createdOn= 5 | version=@extversion@ 6 | -------------------------------------------------------------------------------- /ghidra_scripts/README.txt: -------------------------------------------------------------------------------- 1 | Java source directory to hold module-specific Ghidra scripts. 2 | -------------------------------------------------------------------------------- /res/screenshot1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jtang613/GhidrAssist/f6bd628c273e8ff3da9b55703bcaa90e26c58f4d/res/screenshot1.png -------------------------------------------------------------------------------- /res/screenshots_anim.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jtang613/GhidrAssist/f6bd628c273e8ff3da9b55703bcaa90e26c58f4d/res/screenshots_anim.gif -------------------------------------------------------------------------------- /src/main/help/help/TOC_Source.xml: -------------------------------------------------------------------------------- 1 | 2 | 49 | 50 | 51 | 52 | 57 | 58 | -------------------------------------------------------------------------------- /src/main/help/help/topics/ghidrassist/help.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | 8 | 9 | 10 | 11 | 12 | Skeleton Help File for a Module 13 | 14 | 15 | 16 | 17 |

Skeleton Help File for a Module

18 | 19 |

This is a simple skeleton help topic. For a better description of what should and should not 20 | go in here, see the "sample" Ghidra extension in the Extensions/Ghidra directory, or see your 21 | favorite help topic. In general, language modules do not have their own help topics.

22 | 23 | 24 | -------------------------------------------------------------------------------- /src/main/java/ghidrassist/GAUtils.java: -------------------------------------------------------------------------------- 1 | package ghidrassist; 2 | 3 | import java.io.File; 4 | 5 | public class GAUtils { 6 | public enum OperatingSystem { 7 | WINDOWS, MAC, LINUX, UNKNOWN; 8 | 9 | public static OperatingSystem detect() { 10 | String os = System.getProperty("os.name").toLowerCase(); 11 | if (os.contains("win")) { 12 | return WINDOWS; 13 | } else if (os.contains("mac")) { 14 | return MAC; 15 | } else if (os.contains("nix") || os.contains("nux") || os.contains("aix")) { 16 | return LINUX; 17 | } else { 18 | return UNKNOWN; 19 | } 20 | } 21 | } 22 | 23 | public static String getDefaultLucenePath(OperatingSystem os) { 24 | String basePath; 25 | switch (os) { 26 | case WINDOWS: 27 | basePath = System.getenv("LOCALAPPDATA"); 28 | if (basePath == null) { 29 | throw new RuntimeException("Unable to access LOCALAPPDATA environment variable."); 30 | } 31 | break; 32 | 33 | case MAC: 34 | basePath = System.getProperty("user.home") + "/Library/Application Support"; 35 | break; 36 | 37 | case LINUX: 38 | basePath = System.getProperty("user.home") + "/.config"; 39 | break; 40 | 41 | default: 42 | throw new UnsupportedOperationException("Unsupported operating system: " + os); 43 | } 44 | return basePath + File.separator + "GhidrAssist" + File.separator + "LuceneIndex"; 45 | } 46 | 47 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/GhidrAssistPlugin.java: -------------------------------------------------------------------------------- 1 | package ghidrassist; 2 | 3 | import java.lang.reflect.Type; 4 | import java.util.List; 5 | 6 | import com.google.gson.Gson; 7 | import com.google.gson.reflect.TypeToken; 8 | 9 | import docking.ActionContext; 10 | import docking.action.DockingAction; 11 | import docking.action.MenuData; 12 | 
import ghidra.app.decompiler.DecompilerLocation; 13 | import ghidra.app.plugin.PluginCategoryNames; 14 | import ghidra.app.plugin.ProgramPlugin; 15 | import ghidra.framework.plugintool.*; 16 | import ghidra.framework.plugintool.util.PluginStatus; 17 | import ghidra.framework.preferences.Preferences; 18 | import ghidra.program.model.address.Address; 19 | import ghidra.program.model.listing.Function; 20 | import ghidra.program.model.listing.FunctionManager; 21 | import ghidra.program.model.listing.Program; 22 | import ghidra.program.util.ProgramLocation; 23 | import ghidrassist.apiprovider.APIProviderConfig; 24 | 25 | @PluginInfo( 26 | status = PluginStatus.STABLE, 27 | packageName = "GhidrAssist", 28 | category = PluginCategoryNames.COMMON, 29 | shortDescription = "GhidrAssist LLM Plugin", 30 | description = "A plugin that provides code assistance using a language model." 31 | ) 32 | public class GhidrAssistPlugin extends ProgramPlugin { 33 | public enum CodeViewType { 34 | IS_DECOMPILER, 35 | IS_DISASSEMBLER, 36 | UNKNOWN 37 | } 38 | private GhidrAssistProvider provider; 39 | private String lastActiveProvider; 40 | 41 | public GhidrAssistPlugin(PluginTool tool) { 42 | super(tool); 43 | String pluginName = getName(); 44 | provider = new GhidrAssistProvider(this, pluginName); 45 | } 46 | 47 | @Override 48 | public void init() { 49 | super.init(); 50 | 51 | // Add a menu action for settings 52 | DockingAction settingsAction = new DockingAction("GhidrAssist Settings", getName()) { 53 | @Override 54 | public void actionPerformed(ActionContext context) { 55 | showSettingsDialog(); 56 | } 57 | }; 58 | settingsAction.setMenuBarData(new MenuData(new String[] { "Tools", "GhidrAssist Settings" }, null, "GhidrAssist")); 59 | tool.addAction(settingsAction); 60 | } 61 | 62 | @Override 63 | protected void dispose() { 64 | if (provider != null) { 65 | tool.removeComponentProvider(provider); 66 | provider = null; 67 | } 68 | super.dispose(); 69 | } 70 | 71 | private void 
showSettingsDialog() { 72 | SettingsDialog dialog = new SettingsDialog(tool.getToolFrame(), "GhidrAssist Settings", this); 73 | tool.showDialog(dialog); 74 | } 75 | 76 | @Override 77 | public void locationChanged(ProgramLocation loc) { 78 | if (provider != null) { 79 | provider.getUI().updateLocation(loc); 80 | } 81 | } 82 | 83 | public Program getCurrentProgram() { 84 | return currentProgram; 85 | } 86 | 87 | public Address getCurrentAddress() { 88 | if (currentLocation != null) { 89 | return currentLocation.getAddress(); 90 | } 91 | return null; 92 | } 93 | 94 | public Function getCurrentFunction() { 95 | Program program = getCurrentProgram(); 96 | Address address = getCurrentAddress(); 97 | 98 | if (program != null && address != null) { 99 | FunctionManager functionManager = program.getFunctionManager(); 100 | return functionManager.getFunctionContaining(address); 101 | } 102 | return null; 103 | } 104 | 105 | public String getLastActiveProvider() { 106 | return lastActiveProvider; 107 | } 108 | 109 | public CodeViewType checkLastActiveCodeView() { 110 | if (currentLocation instanceof DecompilerLocation) { 111 | return CodeViewType.IS_DECOMPILER; 112 | } else if (currentLocation != null) { 113 | return CodeViewType.IS_DISASSEMBLER; 114 | } else { 115 | return CodeViewType.UNKNOWN; 116 | } 117 | } 118 | 119 | public static APIProviderConfig getCurrentProviderConfig() { 120 | // Load the list of API providers from preferences 121 | String providersJson = Preferences.getProperty("GhidrAssist.APIProviders", "[]"); 122 | Gson gson = new Gson(); 123 | Type listType = new TypeToken>() {}.getType(); 124 | List apiProviders = gson.fromJson(providersJson, listType); 125 | 126 | // Load the selected provider name 127 | String selectedProviderName = Preferences.getProperty("GhidrAssist.SelectedAPIProvider", ""); 128 | 129 | // Load the global API timeout setting 130 | String apiTimeoutStr = Preferences.getProperty("GhidrAssist.APITimeout", "120"); 131 | Integer apiTimeout = 
120; // Default value 132 | try { 133 | apiTimeout = Integer.parseInt(apiTimeoutStr); 134 | } catch (NumberFormatException e) { 135 | // Use default if there's an error 136 | } 137 | 138 | for (APIProviderConfig provider : apiProviders) { 139 | if (provider.getName().equals(selectedProviderName)) { 140 | // If the provider doesn't have a timeout set, use the global setting 141 | if (provider.getTimeout() == null) { 142 | provider.setTimeout(apiTimeout); 143 | } 144 | return provider; 145 | } 146 | } 147 | 148 | return null; 149 | } 150 | 151 | public static Integer getGlobalApiTimeout() { 152 | String apiTimeoutStr = Preferences.getProperty("GhidrAssist.APITimeout", "120"); 153 | try { 154 | return Integer.parseInt(apiTimeoutStr); 155 | } catch (NumberFormatException e) { 156 | return 120; // Default value 157 | } 158 | } 159 | 160 | public GhidrAssistPlugin getInstance() { 161 | return this; 162 | } 163 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/GhidrAssistProvider.java: -------------------------------------------------------------------------------- 1 | package ghidrassist; 2 | 3 | import java.awt.event.MouseEvent; 4 | import java.util.ArrayList; 5 | import java.util.List; 6 | 7 | import javax.swing.*; 8 | 9 | import docking.ActionContext; 10 | import docking.ComponentProvider; 11 | import docking.DefaultActionContext; 12 | import docking.action.DockingAction; 13 | import docking.action.ToolBarData; 14 | import ghidra.util.Msg; 15 | import ghidrassist.resources.GhidrAssistIcons; 16 | import ghidrassist.ui.GhidrAssistUI; 17 | import resources.Icons; 18 | 19 | public class GhidrAssistProvider extends ComponentProvider { 20 | private GhidrAssistPlugin plugin; 21 | private GhidrAssistUI ui; 22 | private JComponent mainPanel; 23 | private List actions; 24 | 25 | public GhidrAssistProvider(GhidrAssistPlugin plugin, String owner) { 26 | super(plugin.getTool(), owner, owner); 27 | this.plugin = plugin; 
28 | this.actions = new ArrayList<>(); 29 | 30 | buildPanel(); 31 | createActions(); 32 | setIcon(GhidrAssistIcons.ROBOT_ICON); 33 | } 34 | 35 | private void buildPanel() { 36 | ui = new GhidrAssistUI(plugin); 37 | mainPanel = ui.getComponent(); 38 | setVisible(true); 39 | } 40 | 41 | private void createActions() { 42 | DockingAction refreshAction = new DockingAction("Refresh GhidrAssist", getName()) { 43 | @Override 44 | public void actionPerformed(ActionContext context) { 45 | refresh(); 46 | } 47 | }; 48 | refreshAction.setToolBarData(new ToolBarData(Icons.REFRESH_ICON, null)); 49 | refreshAction.setEnabled(true); 50 | refreshAction.markHelpUnnecessary(); 51 | actions.add(refreshAction); 52 | 53 | // Add actions to the tool 54 | for (DockingAction action : actions) { 55 | plugin.getTool().addLocalAction(this, action); 56 | } 57 | } 58 | 59 | public GhidrAssistUI getUI() { 60 | return ui; 61 | } 62 | 63 | @Override 64 | public JComponent getComponent() { 65 | return mainPanel; 66 | } 67 | 68 | @Override 69 | public ActionContext getActionContext(MouseEvent event) { 70 | return new DefaultActionContext(this, mainPanel); 71 | } 72 | 73 | public void refresh() { 74 | try { 75 | Msg.info(this, "GhidrAssist UI refreshed"); 76 | } 77 | catch (Exception e) { 78 | Msg.error(this, "Error refreshing GhidrAssist UI", e); 79 | } 80 | } 81 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/LlmApi.java: -------------------------------------------------------------------------------- 1 | package ghidrassist; 2 | 3 | import ghidrassist.apiprovider.APIProviderConfig; 4 | import ghidrassist.core.ConversationalToolHandler; 5 | import ghidrassist.core.LlmApiClient; 6 | import ghidrassist.core.LlmErrorHandler; 7 | import ghidrassist.core.LlmTaskExecutor; 8 | import ghidrassist.core.ResponseProcessor; 9 | 10 | import java.util.List; 11 | import java.util.Map; 12 | 13 | /** 14 | * - LlmApiClient: Provider management and API 
calls 15 | * - ResponseProcessor: Text filtering and processing 16 | * - LlmTaskExecutor: Background task execution 17 | * - LlmErrorHandler: Error handling and user feedback 18 | */ 19 | public class LlmApi { 20 | 21 | private final LlmApiClient apiClient; 22 | private final ResponseProcessor responseProcessor; 23 | private final LlmTaskExecutor taskExecutor; 24 | private final LlmErrorHandler errorHandler; 25 | private volatile ConversationalToolHandler activeConversationalHandler; 26 | 27 | public LlmApi(APIProviderConfig config, GhidrAssistPlugin plugin) { 28 | this.apiClient = new LlmApiClient(config, plugin); 29 | this.responseProcessor = new ResponseProcessor(); 30 | this.taskExecutor = new LlmTaskExecutor(); 31 | this.errorHandler = new LlmErrorHandler(plugin, this); 32 | } 33 | 34 | /** 35 | * Get the system prompt for regular queries 36 | */ 37 | public String getSystemPrompt() { 38 | return apiClient.getSystemPrompt(); 39 | } 40 | 41 | /** 42 | * Send a streaming request with enhanced error handling 43 | */ 44 | public void sendRequestAsync(String prompt, LlmResponseHandler responseHandler) { 45 | if (!apiClient.isProviderAvailable()) { 46 | errorHandler.handleError( 47 | new IllegalStateException("LLM provider is not initialized."), 48 | "send request", 49 | null 50 | ); 51 | return; 52 | } 53 | 54 | // Create enhanced response handler that includes error handling 55 | LlmTaskExecutor.LlmResponseHandler enhancedHandler = new LlmTaskExecutor.LlmResponseHandler() { 56 | @Override 57 | public void onStart() { 58 | responseHandler.onStart(); 59 | } 60 | 61 | @Override 62 | public void onUpdate(String partialResponse) { 63 | responseHandler.onUpdate(partialResponse); 64 | } 65 | 66 | @Override 67 | public void onComplete(String fullResponse) { 68 | responseHandler.onComplete(fullResponse); 69 | } 70 | 71 | @Override 72 | public void onError(Throwable error) { 73 | // Handle error with enhanced error handling 74 | Runnable retryAction = () -> 
sendRequestAsync(prompt, responseHandler); 75 | errorHandler.handleError(error, "stream chat completion", retryAction); 76 | responseHandler.onError(error); 77 | } 78 | 79 | @Override 80 | public boolean shouldContinue() { 81 | return responseHandler.shouldContinue(); 82 | } 83 | }; 84 | 85 | taskExecutor.executeStreamingRequest(apiClient, prompt, responseProcessor, enhancedHandler); 86 | } 87 | 88 | /** 89 | * Send a conversational tool calling request that handles multiple turns 90 | * Monitors finish_reason to determine when to execute tools vs. complete 91 | */ 92 | public void sendConversationalToolRequest(String prompt, List> functions, LlmResponseHandler responseHandler) { 93 | if (!apiClient.isProviderAvailable()) { 94 | errorHandler.handleError( 95 | new IllegalStateException("LLM provider is not initialized."), 96 | "send conversational tool request", 97 | null 98 | ); 99 | return; 100 | } 101 | 102 | // Create completion callback to clear reference 103 | Runnable onCompletion = () -> { 104 | activeConversationalHandler = null; 105 | }; 106 | 107 | // Create enhanced response handler for conversational tool calling 108 | ConversationalToolHandler toolHandler = new ConversationalToolHandler( 109 | apiClient, functions, responseProcessor, responseHandler, errorHandler, onCompletion); 110 | 111 | // Store reference for cancellation 112 | activeConversationalHandler = toolHandler; 113 | 114 | // Start the conversation 115 | toolHandler.startConversation(prompt); 116 | } 117 | 118 | /** 119 | * Send a function calling request with enhanced error handling (legacy method) 120 | */ 121 | public void sendRequestAsyncWithFunctions(String prompt, List> functions, LlmResponseHandler responseHandler) { 122 | if (!apiClient.isProviderAvailable()) { 123 | errorHandler.handleError( 124 | new IllegalStateException("LLM provider is not initialized."), 125 | "send function request", 126 | null 127 | ); 128 | return; 129 | } 130 | 131 | // Create enhanced response handler 
that includes error handling 132 | LlmTaskExecutor.LlmResponseHandler enhancedHandler = new LlmTaskExecutor.LlmResponseHandler() { 133 | @Override 134 | public void onStart() { 135 | responseHandler.onStart(); 136 | } 137 | 138 | @Override 139 | public void onUpdate(String partialResponse) { 140 | responseHandler.onUpdate(partialResponse); 141 | } 142 | 143 | @Override 144 | public void onComplete(String fullResponse) { 145 | responseHandler.onComplete(fullResponse); 146 | } 147 | 148 | @Override 149 | public void onError(Throwable error) { 150 | // Handle error with enhanced error handling 151 | Runnable retryAction = () -> sendRequestAsyncWithFunctions(prompt, functions, responseHandler); 152 | errorHandler.handleError(error, "chat completion with functions", retryAction); 153 | responseHandler.onError(error); 154 | } 155 | 156 | @Override 157 | public boolean shouldContinue() { 158 | return responseHandler.shouldContinue(); 159 | } 160 | }; 161 | 162 | taskExecutor.executeFunctionRequest(apiClient, prompt, functions, responseProcessor, enhancedHandler); 163 | } 164 | 165 | /** 166 | * Cancel the current request 167 | */ 168 | public void cancelCurrentRequest() { 169 | // Cancel conversational tool handler if active 170 | if (activeConversationalHandler != null) { 171 | activeConversationalHandler.cancel(); 172 | activeConversationalHandler = null; 173 | } 174 | 175 | // Cancel regular task executor 176 | taskExecutor.cancelCurrentRequest(); 177 | } 178 | 179 | /** 180 | * Check if currently processing a request 181 | */ 182 | public boolean isStreaming() { 183 | return taskExecutor.isStreaming(); 184 | } 185 | 186 | /** 187 | * Get provider information for debugging/logging 188 | */ 189 | public String getProviderInfo() { 190 | return String.format("Provider: %s, Model: %s", 191 | apiClient.getProviderName(), 192 | apiClient.getProviderModel()); 193 | } 194 | 195 | /** 196 | * Interface for handling LLM responses - maintains compatibility with existing code 197 
| */ 198 | public interface LlmResponseHandler { 199 | void onStart(); 200 | void onUpdate(String partialResponse); 201 | void onComplete(String fullResponse); 202 | void onError(Throwable error); 203 | default boolean shouldContinue() { 204 | return true; 205 | } 206 | } 207 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/RLHFDatabase.java: -------------------------------------------------------------------------------- 1 | package ghidrassist; 2 | 3 | import ghidra.framework.preferences.Preferences; 4 | import ghidra.util.Msg; 5 | 6 | import java.sql.*; 7 | 8 | public class RLHFDatabase { 9 | 10 | private static final String DB_PATH_PROPERTY = "GhidrAssist.RLHFDatabasePath"; 11 | private static final String DEFAULT_DB_PATH = "ghidrassist_rlhf.db"; 12 | private Connection connection; 13 | 14 | public RLHFDatabase() { 15 | String dbPath = Preferences.getProperty(DB_PATH_PROPERTY, DEFAULT_DB_PATH); 16 | initializeDatabase(dbPath); 17 | } 18 | 19 | private void initializeDatabase(String dbPath) { 20 | try { 21 | connection = DriverManager.getConnection("jdbc:sqlite:" + dbPath); 22 | createFeedbackTable(); 23 | } catch (SQLException e) { 24 | Msg.showError(this, null, "Database Error", "Failed to initialize RLHF database: " + e.getMessage()); 25 | } 26 | } 27 | 28 | private void createFeedbackTable() throws SQLException { 29 | String createTableSQL = "CREATE TABLE IF NOT EXISTS feedback (" 30 | + "id INTEGER PRIMARY KEY AUTOINCREMENT," 31 | + "model_name TEXT NOT NULL," 32 | + "prompt_context TEXT NOT NULL," 33 | + "system_context TEXT NOT NULL," 34 | + "response TEXT NOT NULL," 35 | + "feedback INTEGER NOT NULL" // 1 for thumbs up, 0 for thumbs down 36 | + ")"; 37 | Statement stmt = connection.createStatement(); 38 | stmt.execute(createTableSQL); 39 | stmt.close(); 40 | } 41 | 42 | public void storeFeedback(String modelName, String promptContext, String systemContext, String response, int feedback) { 43 
| String insertSQL = "INSERT INTO feedback (model_name, prompt_context, system_context, response, feedback) " 44 | + "VALUES (?, ?, ?, ?, ?)"; 45 | try (PreparedStatement pstmt = connection.prepareStatement(insertSQL)) { 46 | pstmt.setString(1, modelName); 47 | pstmt.setString(2, promptContext); 48 | pstmt.setString(3, systemContext); 49 | pstmt.setString(4, response); 50 | pstmt.setInt(5, feedback); 51 | pstmt.executeUpdate(); 52 | } catch (SQLException e) { 53 | Msg.showError(this, null, "Database Error", "Failed to store feedback: " + e.getMessage()); 54 | } 55 | } 56 | 57 | public void close() { 58 | try { 59 | connection.close(); 60 | } catch (SQLException e) { 61 | Msg.showError(this, null, "Database Error", "Failed to close RLHF database connection: " + e.getMessage()); 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /src/main/java/ghidrassist/apiprovider/APIProviderConfig.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.apiprovider; 2 | 3 | import ghidrassist.GhidrAssistPlugin; 4 | import ghidrassist.apiprovider.factory.ProviderRegistry; 5 | import ghidrassist.apiprovider.factory.UnsupportedProviderException; 6 | 7 | public class APIProviderConfig { 8 | private String name; 9 | private String model; 10 | private Integer maxTokens; 11 | private String url; 12 | private String key; 13 | private boolean disableTlsVerification; 14 | private APIProvider.ProviderType type; 15 | private Integer timeout; 16 | 17 | public APIProviderConfig( 18 | String name, 19 | APIProvider.ProviderType type, 20 | String model, 21 | Integer maxTokens, 22 | String url, 23 | String key, 24 | boolean disableTlsVerification) { 25 | this(name, type, model, maxTokens, url, key, disableTlsVerification, 120); // Default timeout of 120 seconds 26 | } 27 | 28 | public APIProviderConfig( 29 | String name, 30 | APIProvider.ProviderType type, 31 | String model, 32 | Integer 
maxTokens, 33 | String url, 34 | String key, 35 | boolean disableTlsVerification, 36 | Integer timeout) { 37 | this.name = name; 38 | this.type = type; 39 | this.model = model; 40 | this.maxTokens = maxTokens; 41 | this.url = url; 42 | this.key = key; 43 | this.disableTlsVerification = disableTlsVerification; 44 | this.timeout = timeout; 45 | } 46 | 47 | // Getters 48 | public String getName() { return name; } 49 | public APIProvider.ProviderType getType() { return type; } 50 | public String getModel() { return model; } 51 | public Integer getMaxTokens() { return maxTokens; } 52 | public String getUrl() { return url; } 53 | public String getKey() { return key; } 54 | public boolean isDisableTlsVerification() { return disableTlsVerification; } 55 | public Integer getTimeout() { return timeout; } 56 | 57 | // Setters 58 | public void setName(String name) { this.name = name; } 59 | public void setType(APIProvider.ProviderType type) { this.type = type; } 60 | public void setModel(String model) { this.model = model; } 61 | public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } 62 | public void setUrl(String url) { this.url = url; } 63 | public void setKey(String key) { this.key = key; } 64 | public void setDisableTlsVerification(boolean disableTlsVerification) { this.disableTlsVerification = disableTlsVerification; } 65 | public void setTimeout(Integer timeout) { this.timeout = timeout; } 66 | 67 | /** 68 | * Create a provider using the factory pattern 69 | * @return Configured API provider instance 70 | * @throws RuntimeException if provider creation fails 71 | */ 72 | public APIProvider createProvider() { 73 | this.timeout = GhidrAssistPlugin.getGlobalApiTimeout(); 74 | 75 | try { 76 | return ProviderRegistry.getInstance().createProvider(this); 77 | } catch (UnsupportedProviderException e) { 78 | throw new IllegalArgumentException("Failed to create provider: " + e.getMessage(), e); 79 | } 80 | } 81 | 82 | /** 83 | * Check if this provider type is 
supported 84 | * @return true if the provider type is supported 85 | */ 86 | public boolean isSupported() { 87 | return ProviderRegistry.getInstance().isSupported(type); 88 | } 89 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/apiprovider/ChatMessage.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.apiprovider; 2 | 3 | import com.fasterxml.jackson.databind.JsonNode; 4 | import com.google.gson.JsonArray; 5 | 6 | public class ChatMessage { 7 | private String role; 8 | private String content; 9 | private FunctionCall functionCall; 10 | private JsonArray toolCalls; // For assistant messages with tool calls 11 | private String toolCallId; // For tool response messages 12 | 13 | public ChatMessage(String role, String content) { 14 | this.role = role; 15 | this.content = content; 16 | } 17 | 18 | public String getRole() { 19 | return role; 20 | } 21 | 22 | public String getContent() { 23 | return content; 24 | } 25 | 26 | public void setContent(String content) { 27 | this.content = content; 28 | } 29 | 30 | public FunctionCall getFunctionCall() { 31 | return functionCall; 32 | } 33 | 34 | public void setFunctionCall(FunctionCall functionCall) { 35 | this.functionCall = functionCall; 36 | } 37 | 38 | public JsonArray getToolCalls() { 39 | return toolCalls; 40 | } 41 | 42 | public void setToolCalls(JsonArray toolCalls) { 43 | this.toolCalls = toolCalls; 44 | } 45 | 46 | public String getToolCallId() { 47 | return toolCallId; 48 | } 49 | 50 | public void setToolCallId(String toolCallId) { 51 | this.toolCallId = toolCallId; 52 | } 53 | 54 | public static class FunctionCall { 55 | private String name; 56 | private JsonNode arguments; 57 | 58 | public String getName() { 59 | return name; 60 | } 61 | 62 | public JsonNode getArguments() { 63 | return arguments; 64 | } 65 | } 66 | 67 | public static class ChatMessageRole { 68 | public static final String 
SYSTEM = "system"; 69 | public static final String USER = "user"; 70 | public static final String ASSISTANT = "assistant"; 71 | public static final String FUNCTION = "function"; 72 | public static final String TOOL = "tool"; 73 | } 74 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/apiprovider/ErrorAction.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.apiprovider; 2 | 3 | import javax.swing.*; 4 | import java.awt.event.ActionEvent; 5 | import java.awt.event.ActionListener; 6 | 7 | /** 8 | * Represents an action that can be taken in response to an error 9 | */ 10 | public class ErrorAction { 11 | private final String actionText; 12 | private final String description; 13 | private final Runnable action; 14 | private final boolean isPrimary; 15 | 16 | public ErrorAction(String actionText, String description, Runnable action, boolean isPrimary) { 17 | this.actionText = actionText; 18 | this.description = description; 19 | this.action = action; 20 | this.isPrimary = isPrimary; 21 | } 22 | 23 | public ErrorAction(String actionText, Runnable action) { 24 | this(actionText, null, action, false); 25 | } 26 | 27 | // Getters 28 | public String getActionText() { return actionText; } 29 | public String getDescription() { return description; } 30 | public Runnable getAction() { return action; } 31 | public boolean isPrimary() { return isPrimary; } 32 | 33 | /** 34 | * Create a button for this action 35 | */ 36 | public JButton createButton() { 37 | JButton button = new JButton(actionText); 38 | if (description != null) { 39 | button.setToolTipText(description); 40 | } 41 | button.addActionListener(new ActionListener() { 42 | @Override 43 | public void actionPerformed(ActionEvent e) { 44 | if (action != null) { 45 | try { 46 | action.run(); 47 | } catch (Exception ex) { 48 | // Log error but don't propagate to avoid cascading errors 49 | 
System.err.println("Error executing action: " + ex.getMessage()); 50 | } 51 | } 52 | } 53 | }); 54 | return button; 55 | } 56 | 57 | // Common action factory methods 58 | public static ErrorAction createSettingsAction(Runnable openSettingsAction) { 59 | return new ErrorAction( 60 | "Open Settings", 61 | "Open the settings dialog to configure API providers", 62 | openSettingsAction, 63 | true 64 | ); 65 | } 66 | 67 | public static ErrorAction createRetryAction(Runnable retryAction) { 68 | return new ErrorAction( 69 | "Retry", 70 | "Try the operation again", 71 | retryAction, 72 | true 73 | ); 74 | } 75 | 76 | public static ErrorAction createCopyErrorAction(String errorDetails) { 77 | return new ErrorAction( 78 | "Copy Details", 79 | "Copy error details to clipboard", 80 | () -> copyToClipboard(errorDetails), 81 | false 82 | ); 83 | } 84 | 85 | public static ErrorAction createSwitchProviderAction(Runnable switchAction) { 86 | return new ErrorAction( 87 | "Switch Provider", 88 | "Try using a different API provider", 89 | switchAction, 90 | false 91 | ); 92 | } 93 | 94 | public static ErrorAction createDismissAction() { 95 | return new ErrorAction( 96 | "Dismiss", 97 | "Close this error dialog", 98 | () -> {}, // No-op, dialog will handle dismissal 99 | false 100 | ); 101 | } 102 | 103 | private static void copyToClipboard(String text) { 104 | try { 105 | java.awt.datatransfer.StringSelection stringSelection = 106 | new java.awt.datatransfer.StringSelection(text); 107 | java.awt.datatransfer.Clipboard clipboard = 108 | java.awt.Toolkit.getDefaultToolkit().getSystemClipboard(); 109 | clipboard.setContents(stringSelection, null); 110 | } catch (Exception e) { 111 | // Silently fail if clipboard is not available 112 | } 113 | } 114 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/apiprovider/RetryHandler.java: -------------------------------------------------------------------------------- 1 | package 
ghidrassist.apiprovider; 2 | 3 | import ghidrassist.apiprovider.exceptions.*; 4 | import ghidra.util.Msg; 5 | 6 | import java.util.concurrent.Callable; 7 | import java.util.function.Supplier; 8 | 9 | /** 10 | * Handles retry logic for API provider operations 11 | */ 12 | public class RetryHandler { 13 | private static final int DEFAULT_MAX_RETRIES = 3; 14 | private static final int BASE_BACKOFF_MS = 1000; // 1 second 15 | private static final int MAX_BACKOFF_MS = 90000; // 90 seconds 16 | 17 | private final int maxRetries; 18 | private final Object source; // For logging 19 | 20 | public RetryHandler() { 21 | this(DEFAULT_MAX_RETRIES, null); 22 | } 23 | 24 | public RetryHandler(int maxRetries, Object source) { 25 | this.maxRetries = maxRetries; 26 | this.source = source; 27 | } 28 | 29 | /** 30 | * Execute an operation with retry logic 31 | */ 32 | public T executeWithRetry(Supplier operation, String operationName) throws APIProviderException { 33 | return executeWithRetryCallable(() -> operation.get(), operationName); 34 | } 35 | 36 | /** 37 | * Execute a callable operation with retry logic 38 | */ 39 | public T executeWithRetryCallable(Callable operation, String operationName) throws APIProviderException { 40 | APIProviderException lastException = null; 41 | 42 | for (int attempt = 1; attempt <= maxRetries; attempt++) { 43 | try { 44 | return operation.call(); 45 | } catch (APIProviderException e) { 46 | lastException = e; 47 | 48 | if (!shouldRetry(e, attempt)) { 49 | throw e; 50 | } 51 | 52 | logRetryAttempt(operationName, attempt, e); 53 | 54 | if (attempt < maxRetries) { 55 | waitForRetry(e, attempt); 56 | } 57 | } catch (Exception e) { 58 | // Convert non-API exceptions to APIProviderException 59 | throw new APIProviderException( 60 | APIProviderException.ErrorCategory.SERVICE_ERROR, 61 | "Unknown", 62 | operationName, 63 | "Unexpected error: " + e.getMessage() 64 | ); 65 | } 66 | } 67 | 68 | // If we get here, all retries failed 69 | throw lastException; 70 
| } 71 | 72 | /** 73 | * Execute an operation with retry logic that doesn't return a value 74 | */ 75 | public void executeWithRetryRunnable(Runnable operation, String operationName) throws APIProviderException { 76 | executeWithRetryCallable(() -> { 77 | operation.run(); 78 | return null; 79 | }, operationName); 80 | } 81 | 82 | private boolean shouldRetry(APIProviderException e, int attempt) { 83 | // Don't retry if we've exceeded max attempts 84 | if (attempt >= maxRetries) { 85 | return false; 86 | } 87 | 88 | // Check if the error is retryable based on category 89 | switch (e.getCategory()) { 90 | case RATE_LIMIT: 91 | case NETWORK: 92 | case TIMEOUT: 93 | case SERVICE_ERROR: 94 | return true; 95 | 96 | case AUTHENTICATION: 97 | case MODEL_ERROR: 98 | case CONFIGURATION: 99 | case RESPONSE_ERROR: 100 | case CANCELLED: 101 | return false; 102 | 103 | default: 104 | // For unknown errors, check the isRetryable flag 105 | return e.isRetryable(); 106 | } 107 | } 108 | 109 | private void waitForRetry(APIProviderException e, int attempt) { 110 | int waitTimeMs = calculateWaitTime(e, attempt); 111 | 112 | if (source != null) { 113 | Msg.info(source, String.format("Waiting %d seconds before retry...", waitTimeMs / 1000 )); 114 | } 115 | 116 | try { 117 | Thread.sleep(waitTimeMs); 118 | } catch (InterruptedException ie) { 119 | Thread.currentThread().interrupt(); 120 | throw new RuntimeException("Retry interrupted", ie); 121 | } 122 | } 123 | 124 | private int calculateWaitTime(APIProviderException e, int attempt) { 125 | // For rate limit errors, use the provided retry-after if available 126 | if (e.getCategory() == APIProviderException.ErrorCategory.RATE_LIMIT && 127 | e.getRetryAfterSeconds() != null) { 128 | return e.getRetryAfterSeconds() * 1000; 129 | } 130 | 131 | // For other errors, use exponential backoff with jitter 132 | int backoffMs = BASE_BACKOFF_MS * (int) Math.pow(2, attempt - 1); 133 | 134 | // Add jitter (±25%) 135 | int jitter = (int) (backoffMs * 
0.25 * (Math.random() - 0.5)); 136 | backoffMs += jitter; 137 | 138 | // Cap at maximum backoff 139 | return Math.min(backoffMs, MAX_BACKOFF_MS); 140 | } 141 | 142 | private void logRetryAttempt(String operationName, int attempt, APIProviderException e) { 143 | if (source != null) { 144 | String message = String.format( 145 | "Retry attempt %d/%d for %s: %s (%s)", 146 | attempt, maxRetries, operationName, 147 | e.getCategory().getDisplayName(), e.getProviderName() 148 | ); 149 | Msg.warn(source, message); 150 | } 151 | } 152 | 153 | /** 154 | * Check if an exception indicates a transient error that might succeed on retry 155 | */ 156 | public static boolean isTransientError(Throwable error) { 157 | if (error instanceof APIProviderException) { 158 | APIProviderException ape = (APIProviderException) error; 159 | return ape.isRetryable() || isTransientCategory(ape.getCategory()); 160 | } 161 | 162 | // Check for common transient error indicators in message 163 | String message = error.getMessage(); 164 | if (message != null) { 165 | message = message.toLowerCase(); 166 | return message.contains("timeout") || 167 | message.contains("connection reset") || 168 | message.contains("temporary") || 169 | message.contains("service unavailable"); 170 | } 171 | 172 | return false; 173 | } 174 | 175 | private static boolean isTransientCategory(APIProviderException.ErrorCategory category) { 176 | switch (category) { 177 | case RATE_LIMIT: 178 | case NETWORK: 179 | case TIMEOUT: 180 | case SERVICE_ERROR: 181 | return true; 182 | default: 183 | return false; 184 | } 185 | } 186 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/apiprovider/capabilities/ChatProvider.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.apiprovider.capabilities; 2 | 3 | import ghidrassist.LlmApi; 4 | import ghidrassist.apiprovider.ChatMessage; 5 | import 
ghidrassist.apiprovider.exceptions.APIProviderException; 6 | 7 | import java.util.List; 8 | 9 | /** 10 | * Interface for providers that support basic chat completion. 11 | * This is the core capability that all LLM providers should support. 12 | */ 13 | public interface ChatProvider { 14 | 15 | /** 16 | * Create a chat completion (blocking/synchronous) 17 | * @param messages The conversation messages 18 | * @return The completion response 19 | * @throws APIProviderException if the request fails 20 | */ 21 | String createChatCompletion(List messages) throws APIProviderException; 22 | 23 | /** 24 | * Stream a chat completion (non-blocking/asynchronous) 25 | * @param messages The conversation messages 26 | * @param handler Handler for streaming response chunks 27 | * @throws APIProviderException if the request fails 28 | */ 29 | void streamChatCompletion(List messages, LlmApi.LlmResponseHandler handler) throws APIProviderException; 30 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/apiprovider/capabilities/EmbeddingProvider.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.apiprovider.capabilities; 2 | 3 | import ghidrassist.apiprovider.APIProvider; 4 | 5 | /** 6 | * Interface for providers that support text embeddings. 7 | * Not all providers support this capability. 
8 | */ 9 | public interface EmbeddingProvider { 10 | 11 | /** 12 | * Generate embeddings for text asynchronously 13 | * @param text The text to embed 14 | * @param callback Callback to handle the embedding result 15 | */ 16 | void getEmbeddingsAsync(String text, APIProvider.EmbeddingCallback callback); 17 | 18 | /** 19 | * Check if this provider supports embeddings 20 | * @return true if embeddings are supported 21 | */ 22 | default boolean supportsEmbeddings() { 23 | return true; 24 | } 25 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/apiprovider/capabilities/FunctionCallingProvider.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.apiprovider.capabilities; 2 | 3 | import ghidrassist.apiprovider.ChatMessage; 4 | import ghidrassist.apiprovider.exceptions.APIProviderException; 5 | 6 | import java.util.List; 7 | import java.util.Map; 8 | 9 | /** 10 | * Interface for providers that support function calling / tool calling. 11 | * Not all providers support this capability. 
12 | */ 13 | public interface FunctionCallingProvider { 14 | 15 | /** 16 | * Create a chat completion with function calling support 17 | * @param messages The conversation messages 18 | * @param functions Available functions/tools that the model can call 19 | * @return The completion response, potentially containing function calls 20 | * @throws APIProviderException if the request fails 21 | */ 22 | String createChatCompletionWithFunctions(List messages, List> functions) throws APIProviderException; 23 | 24 | /** 25 | * Check if this provider supports function calling 26 | * @return true if function calling is supported 27 | */ 28 | default boolean supportsFunctionCalling() { 29 | return true; 30 | } 31 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/apiprovider/capabilities/ModelListProvider.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.apiprovider.capabilities; 2 | 3 | import ghidrassist.apiprovider.exceptions.APIProviderException; 4 | 5 | import java.util.List; 6 | 7 | /** 8 | * Interface for providers that can list available models. 9 | * Not all providers support this capability. 
10 | */ 11 | public interface ModelListProvider { 12 | 13 | /** 14 | * Get list of available models from this provider 15 | * @return List of model identifiers 16 | * @throws APIProviderException if the request fails 17 | */ 18 | List getAvailableModels() throws APIProviderException; 19 | 20 | /** 21 | * Check if this provider supports model listing 22 | * @return true if model listing is supported 23 | */ 24 | default boolean supportsModelListing() { 25 | return true; 26 | } 27 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/apiprovider/exceptions/APIProviderException.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.apiprovider.exceptions; 2 | 3 | /** 4 | * Base exception for all API provider errors with structured error information 5 | */ 6 | public class APIProviderException extends Exception { 7 | private final ErrorCategory category; 8 | private final String providerName; 9 | private final String operation; 10 | private final int httpStatusCode; 11 | private final String apiErrorCode; 12 | private final boolean isRetryable; 13 | private final Integer retryAfterSeconds; 14 | 15 | public enum ErrorCategory { 16 | AUTHENTICATION("Authentication Error", "Check your API key and credentials"), 17 | NETWORK("Network Error", "Check your internet connection and API URL"), 18 | RATE_LIMIT("Rate Limit Exceeded", "Too many requests - please wait before retrying"), 19 | MODEL_ERROR("Model Error", "Issue with the specified model or unsupported feature"), 20 | CONFIGURATION("Configuration Error", "Invalid settings or configuration"), 21 | RESPONSE_ERROR("Response Error", "Invalid or unexpected response from API"), 22 | SERVICE_ERROR("Service Error", "API service is experiencing issues"), 23 | TIMEOUT("Timeout Error", "Request took too long to complete"), 24 | CANCELLED("Request Cancelled", "Operation was cancelled"); 25 | 26 | private final String 
displayName; 27 | private final String description; 28 | 29 | ErrorCategory(String displayName, String description) { 30 | this.displayName = displayName; 31 | this.description = description; 32 | } 33 | 34 | public String getDisplayName() { return displayName; } 35 | public String getDescription() { return description; } 36 | } 37 | 38 | public APIProviderException(ErrorCategory category, String providerName, String operation, 39 | String message) { 40 | this(category, providerName, operation, -1, null, message, false, null, null); 41 | } 42 | 43 | public APIProviderException(ErrorCategory category, String providerName, String operation, 44 | int httpStatusCode, String apiErrorCode, String message) { 45 | this(category, providerName, operation, httpStatusCode, apiErrorCode, message, false, null, null); 46 | } 47 | 48 | public APIProviderException(ErrorCategory category, String providerName, String operation, 49 | int httpStatusCode, String apiErrorCode, String message, 50 | boolean isRetryable, Integer retryAfterSeconds, Throwable cause) { 51 | super(message, cause); 52 | this.category = category; 53 | this.providerName = providerName; 54 | this.operation = operation; 55 | this.httpStatusCode = httpStatusCode; 56 | this.apiErrorCode = apiErrorCode; 57 | this.isRetryable = isRetryable; 58 | this.retryAfterSeconds = retryAfterSeconds; 59 | } 60 | 61 | // Getters 62 | public ErrorCategory getCategory() { return category; } 63 | public String getProviderName() { return providerName; } 64 | public String getOperation() { return operation; } 65 | public int getHttpStatusCode() { return httpStatusCode; } 66 | public String getApiErrorCode() { return apiErrorCode; } 67 | public boolean isRetryable() { return isRetryable; } 68 | public Integer getRetryAfterSeconds() { return retryAfterSeconds; } 69 | 70 | /** 71 | * Get technical details for debugging 72 | */ 73 | public String getTechnicalDetails() { 74 | StringBuilder details = new StringBuilder(); 75 | 
details.append("Provider: ").append(providerName).append("\n"); 76 | details.append("Operation: ").append(operation).append("\n"); 77 | details.append("Category: ").append(category.getDisplayName()).append("\n"); 78 | 79 | if (httpStatusCode > 0) { 80 | details.append("HTTP Status: ").append(httpStatusCode).append("\n"); 81 | } 82 | 83 | if (apiErrorCode != null && !apiErrorCode.isEmpty()) { 84 | details.append("API Error Code: ").append(apiErrorCode).append("\n"); 85 | } 86 | 87 | if (getMessage() != null) { 88 | details.append("Message: ").append(getMessage()).append("\n"); 89 | } 90 | 91 | if (getCause() != null) { 92 | details.append("Cause: ").append(getCause().getClass().getSimpleName()).append("\n"); 93 | } 94 | 95 | return details.toString(); 96 | } 97 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/apiprovider/exceptions/AuthenticationException.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.apiprovider.exceptions; 2 | 3 | /** 4 | * Exception for authentication and authorization failures 5 | */ 6 | public class AuthenticationException extends APIProviderException { 7 | 8 | public AuthenticationException(String providerName, String operation, String message) { 9 | super(ErrorCategory.AUTHENTICATION, providerName, operation, message); 10 | } 11 | 12 | public AuthenticationException(String providerName, String operation, int httpStatusCode, 13 | String apiErrorCode, String message) { 14 | super(ErrorCategory.AUTHENTICATION, providerName, operation, httpStatusCode, apiErrorCode, 15 | message, false, null, null); 16 | } 17 | 18 | public AuthenticationException(String providerName, String operation, String message, Throwable cause) { 19 | super(ErrorCategory.AUTHENTICATION, providerName, operation, -1, null, message, false, null, cause); 20 | } 21 | } -------------------------------------------------------------------------------- 
/src/main/java/ghidrassist/apiprovider/exceptions/ModelException.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.apiprovider.exceptions; 2 | 3 | /** 4 | * Exception for model-related errors 5 | */ 6 | public class ModelException extends APIProviderException { 7 | 8 | public enum ModelErrorType { 9 | MODEL_NOT_FOUND("The specified model was not found or is not available"), 10 | UNSUPPORTED_FEATURE("The model does not support this feature"), 11 | CONTEXT_LENGTH_EXCEEDED("Input exceeds the model's maximum context length"), 12 | TOKEN_LIMIT_EXCEEDED("Response would exceed the maximum token limit"), 13 | MODEL_OVERLOADED("The model is currently overloaded"); 14 | 15 | private final String description; 16 | 17 | ModelErrorType(String description) { 18 | this.description = description; 19 | } 20 | 21 | public String getDescription() { return description; } 22 | } 23 | 24 | private final ModelErrorType modelErrorType; 25 | 26 | public ModelException(String providerName, String operation, ModelErrorType errorType) { 27 | super(ErrorCategory.MODEL_ERROR, providerName, operation, errorType.getDescription()); 28 | this.modelErrorType = errorType; 29 | } 30 | 31 | public ModelException(String providerName, String operation, ModelErrorType errorType, 32 | int httpStatusCode, String apiErrorCode) { 33 | super(ErrorCategory.MODEL_ERROR, providerName, operation, httpStatusCode, apiErrorCode, 34 | errorType.getDescription(), false, null, null); 35 | this.modelErrorType = errorType; 36 | } 37 | 38 | public ModelException(String providerName, String operation, String message) { 39 | super(ErrorCategory.MODEL_ERROR, providerName, operation, message); 40 | this.modelErrorType = null; 41 | } 42 | 43 | public ModelErrorType getModelErrorType() { 44 | return modelErrorType; 45 | } 46 | } -------------------------------------------------------------------------------- 
/src/main/java/ghidrassist/apiprovider/exceptions/NetworkException.java:
--------------------------------------------------------------------------------
package ghidrassist.apiprovider.exceptions;

/**
 * Exception for network-related failures
 */
public class NetworkException extends APIProviderException {

    /**
     * Kind of network failure. Each constant carries the human-readable
     * description that is used as the exception message.
     */
    public enum NetworkErrorType {
        CONNECTION_FAILED("Cannot connect to server"),
        TIMEOUT("Request timed out"),
        SSL_ERROR("SSL/TLS connection failed"),
        DNS_ERROR("Cannot resolve hostname"),
        CONNECTION_LOST("Connection was lost during request");

        private final String description;

        NetworkErrorType(String description) {
            this.description = description;
        }

        /** @return the human-readable description of this error type */
        public String getDescription() { return description; }
    }

    /** The failure kind, or {@code null} when built from a free-form message. */
    private final NetworkErrorType networkErrorType;

    /**
     * Creates a network exception whose message is {@code errorType}'s description.
     */
    public NetworkException(String providerName, String operation, NetworkErrorType errorType) {
        super(ErrorCategory.NETWORK, providerName, operation, errorType.getDescription());
        this.networkErrorType = errorType;
    }

    /**
     * Creates a network exception that wraps an underlying cause.
     * No HTTP status (-1) and no API error code are attached.
     */
    public NetworkException(String providerName, String operation, NetworkErrorType errorType,
            Throwable cause) {
        // NOTE(review): the boolean argument is presumably APIProviderException's
        // "retryable" flag (network failures -> true) - confirm against the base class.
        super(ErrorCategory.NETWORK, providerName, operation, -1, null, errorType.getDescription(),
                true, null, cause);
        this.networkErrorType = errorType;
    }

    /**
     * Creates a network exception from a free-form message.
     * {@link #getNetworkErrorType()} returns {@code null} for instances built this way.
     */
    public NetworkException(String providerName, String operation, String message) {
        super(ErrorCategory.NETWORK, providerName, operation, message);
        this.networkErrorType = null;
    }

    /** @return the failure kind, or {@code null} if built from a free-form message */
    public NetworkErrorType getNetworkErrorType() {
        return networkErrorType;
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/exceptions/RateLimitException.java:
--------------------------------------------------------------------------------
package ghidrassist.apiprovider.exceptions;

/**
 * Exception for rate limiting errors
 */
public class RateLimitException extends APIProviderException {

    /**
     * Creates a rate-limit exception with HTTP status 429, API error code
     * {@code "rate_limit_exceeded"} and a default message.
     *
     * @param retryAfterSeconds server-suggested wait before retrying (boxed; presumably
     *                          {@code null} when the server gave no hint)
     */
    public RateLimitException(String providerName, String operation, Integer retryAfterSeconds) {
        super(ErrorCategory.RATE_LIMIT, providerName, operation, 429, "rate_limit_exceeded",
                "Rate limit exceeded. Please wait before retrying.", true, retryAfterSeconds, null);
    }

    /**
     * Creates a rate-limit exception with HTTP status 429, API error code
     * {@code "rate_limit_exceeded"} and a caller-supplied message.
     */
    public RateLimitException(String providerName, String operation, String message,
            Integer retryAfterSeconds) {
        super(ErrorCategory.RATE_LIMIT, providerName, operation, 429, "rate_limit_exceeded",
                message, true, retryAfterSeconds, null);
    }

    /**
     * Creates a rate-limit exception with an explicit HTTP status and API error code.
     * This is for providers that signal throttling with something other than 429.
     */
    public RateLimitException(String providerName, String operation, int httpStatusCode,
            String apiErrorCode, String message, Integer retryAfterSeconds) {
        super(ErrorCategory.RATE_LIMIT, providerName, operation, httpStatusCode, apiErrorCode,
                message, true, retryAfterSeconds, null);
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/exceptions/ResponseException.java:
--------------------------------------------------------------------------------
package ghidrassist.apiprovider.exceptions;

/**
 * Exception for response parsing and format errors
 */
public class ResponseException extends APIProviderException {

    /**
     * Kind of response problem. Each constant carries the description that is
     * used as the exception message.
     */
    public enum ResponseErrorType {
        MALFORMED_JSON("Response contains invalid JSON"),
        MISSING_REQUIRED_FIELD("Required field missing from response"),
        UNEXPECTED_FORMAT("Response format is not as expected"),
        EMPTY_RESPONSE("Received empty response"),
        STREAM_INTERRUPTED("Response stream was interrupted");

        private final String description;

        ResponseErrorType(String description) {
            this.description = description;
        }

        /** @return the human-readable description of this error type */
        public String getDescription() { return description; }
    }

    /** The problem kind, or {@code null} when built from a free-form message. */
    private final ResponseErrorType responseErrorType;

    /**
     * Creates a response exception whose message is {@code errorType}'s description.
     */
    public ResponseException(String providerName, String operation, ResponseErrorType errorType) {
        super(ErrorCategory.RESPONSE_ERROR, providerName, operation, errorType.getDescription());
        this.responseErrorType = errorType;
    }

    /**
     * Creates a response exception that wraps an underlying cause (for example, a
     * parser exception). No HTTP status (-1) and no API error code are attached.
     */
    public ResponseException(String providerName, String operation, ResponseErrorType errorType,
            Throwable cause) {
        super(ErrorCategory.RESPONSE_ERROR, providerName, operation, -1, null,
                errorType.getDescription(), false, null, cause);
        this.responseErrorType = errorType;
    }

    /**
     * Creates a response exception from a free-form message.
     * {@link #getResponseErrorType()} returns {@code null} for instances built this way.
     */
    public ResponseException(String providerName, String operation, String message) {
        super(ErrorCategory.RESPONSE_ERROR, providerName, operation, message);
        this.responseErrorType = null;
    }

    /** @return the problem kind, or {@code null} if built from a free-form message */
    public ResponseErrorType getResponseErrorType() {
        return responseErrorType;
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/exceptions/StreamCancelledException.java:
--------------------------------------------------------------------------------
package ghidrassist.apiprovider.exceptions;

/**
 * Exception for stream cancellation scenarios
 */
public class StreamCancelledException extends APIProviderException {

    /**
     * Why a stream was cancelled. Each constant carries the description that is
     * used as the exception message.
     */
    public enum CancellationReason {
        USER_REQUESTED("User cancelled the request"),
        TIMEOUT("Request timed out"),
        CONNECTION_LOST("Network connection was lost"),
        PROVIDER_ERROR("API provider terminated the stream"),
        SHUTDOWN("Application is shutting down");

        private final String description;

        CancellationReason(String description) {
            this.description = description;
        }

        /** @return the human-readable description of this cancellation reason */
        public String getDescription() { return description; }
    }

    /** Why the stream was cancelled; always set by both constructors. */
    private final CancellationReason cancellationReason;

    /**
     * Creates a cancellation exception whose message is {@code reason}'s description.
     */
    public StreamCancelledException(String providerName, String operation, CancellationReason reason) {
        super(ErrorCategory.CANCELLED, providerName, operation, reason.getDescription());
        this.cancellationReason = reason;
    }

    /**
     * Creates a cancellation exception that wraps an underlying cause.
     * No HTTP status (-1) and no API error code are attached.
     */
    public StreamCancelledException(String providerName, String operation, CancellationReason reason,
            Throwable cause) {
        super(ErrorCategory.CANCELLED, providerName, operation, -1, null, reason.getDescription(),
                false, null, cause);
        this.cancellationReason = reason;
    }

    /** @return why the stream was cancelled */
    public CancellationReason getCancellationReason() {
        return cancellationReason;
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/APIProviderFactory.java:
--------------------------------------------------------------------------------
package ghidrassist.apiprovider.factory;

import ghidrassist.apiprovider.APIProvider;
import ghidrassist.apiprovider.APIProviderConfig;

/**
 * Factory interface for creating API providers.
 * Follows the Factory Method pattern to allow extensibility.
 */
public interface APIProviderFactory {

    /**
     * Create an API provider instance from configuration
     * @param config The provider configuration
     * @return A configured API provider instance
     * @throws UnsupportedProviderException if this factory cannot create the requested provider type
     */
    APIProvider createProvider(APIProviderConfig config) throws UnsupportedProviderException;

    /**
     * Check if this factory supports creating the given provider type
     * @param type The provider type to check
     * @return true if this factory can create providers of the given type
     */
    boolean supports(APIProvider.ProviderType type);

    /**
     * Get the provider type this factory creates. The registry uses it as the
     * registration key, so it must be non-null.
     * @return The provider type
     */
    APIProvider.ProviderType getProviderType();

    /**
     * Get a human-readable name for this factory
     * @return Factory name
     */
    String getFactoryName();
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/AnthropicProviderFactory.java:
--------------------------------------------------------------------------------
package ghidrassist.apiprovider.factory;

import ghidrassist.apiprovider.APIProvider;
import ghidrassist.apiprovider.APIProviderConfig;
import ghidrassist.apiprovider.AnthropicProvider;

/**
 * Factory for creating Anthropic API providers.
 */
public class AnthropicProviderFactory implements APIProviderFactory {

    /**
     * Builds an {@link AnthropicProvider} from the configuration's name, model,
     * max tokens, URL, key, TLS-verification flag and timeout.
     *
     * @throws UnsupportedProviderException if {@code config} is not of type ANTHROPIC
     */
    @Override
    public APIProvider createProvider(APIProviderConfig config) throws UnsupportedProviderException {
        if (!supports(config.getType())) {
            throw new UnsupportedProviderException(config.getType(), getFactoryName());
        }

        return new AnthropicProvider(
            config.getName(),
            config.getModel(),
            config.getMaxTokens(),
            config.getUrl(),
            config.getKey(),
            config.isDisableTlsVerification(),
            config.getTimeout()
        );
    }

    @Override
    public boolean supports(APIProvider.ProviderType type) {
        return type == APIProvider.ProviderType.ANTHROPIC;
    }

    @Override
    public APIProvider.ProviderType getProviderType() {
        return APIProvider.ProviderType.ANTHROPIC;
    }

    @Override
    public String getFactoryName() {
        return "AnthropicProviderFactory";
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/LMStudioProviderFactory.java:
--------------------------------------------------------------------------------
package ghidrassist.apiprovider.factory;

import ghidrassist.apiprovider.APIProvider;
import ghidrassist.apiprovider.APIProviderConfig;
import ghidrassist.apiprovider.LMStudioProvider;

/**
 * Factory for creating LM Studio API providers.
 */
public class LMStudioProviderFactory implements APIProviderFactory {

    /**
     * Builds an {@link LMStudioProvider} from the configuration's name, model,
     * max tokens, URL, key, TLS-verification flag and timeout.
     *
     * @throws UnsupportedProviderException if {@code config} is not of type LMSTUDIO
     */
    @Override
    public APIProvider createProvider(APIProviderConfig config) throws UnsupportedProviderException {
        if (!supports(config.getType())) {
            throw new UnsupportedProviderException(config.getType(), getFactoryName());
        }

        return new LMStudioProvider(
            config.getName(),
            config.getModel(),
            config.getMaxTokens(),
            config.getUrl(),
            config.getKey(),
            config.isDisableTlsVerification(),
            config.getTimeout()
        );
    }

    @Override
    public boolean supports(APIProvider.ProviderType type) {
        return type == APIProvider.ProviderType.LMSTUDIO;
    }

    @Override
    public APIProvider.ProviderType getProviderType() {
        return APIProvider.ProviderType.LMSTUDIO;
    }

    @Override
    public String getFactoryName() {
        return "LMStudioProviderFactory";
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/OllamaProviderFactory.java:
--------------------------------------------------------------------------------
package ghidrassist.apiprovider.factory;

import ghidrassist.apiprovider.APIProvider;
import ghidrassist.apiprovider.APIProviderConfig;
import ghidrassist.apiprovider.OllamaProvider;

/**
 * Factory for creating Ollama API providers.
 */
public class OllamaProviderFactory implements APIProviderFactory {

    /**
     * Builds an {@link OllamaProvider} from the configuration's name, model,
     * max tokens, URL, key, TLS-verification flag and timeout.
     *
     * @throws UnsupportedProviderException if {@code config} is not of type OLLAMA
     */
    @Override
    public APIProvider createProvider(APIProviderConfig config) throws UnsupportedProviderException {
        if (!supports(config.getType())) {
            throw new UnsupportedProviderException(config.getType(), getFactoryName());
        }

        return new OllamaProvider(
            config.getName(),
            config.getModel(),
            config.getMaxTokens(),
            config.getUrl(),
            config.getKey(),
            config.isDisableTlsVerification(),
            config.getTimeout()
        );
    }

    @Override
    public boolean supports(APIProvider.ProviderType type) {
        return type == APIProvider.ProviderType.OLLAMA;
    }

    @Override
    public APIProvider.ProviderType getProviderType() {
        return APIProvider.ProviderType.OLLAMA;
    }

    @Override
    public String getFactoryName() {
        return "OllamaProviderFactory";
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/OpenAIProviderFactory.java:
--------------------------------------------------------------------------------
package ghidrassist.apiprovider.factory;

import ghidrassist.apiprovider.APIProvider;
import ghidrassist.apiprovider.APIProviderConfig;
import ghidrassist.apiprovider.OpenAIProvider;

/**
 * Factory for creating OpenAI API providers.
 */
public class OpenAIProviderFactory implements APIProviderFactory {

    /**
     * Builds an {@link OpenAIProvider} from the configuration's name, model,
     * max tokens, URL, key, TLS-verification flag and timeout.
     *
     * @throws UnsupportedProviderException if {@code config} is not of type OPENAI
     */
    @Override
    public APIProvider createProvider(APIProviderConfig config) throws UnsupportedProviderException {
        if (!supports(config.getType())) {
            throw new UnsupportedProviderException(config.getType(), getFactoryName());
        }

        return new OpenAIProvider(
            config.getName(),
            config.getModel(),
            config.getMaxTokens(),
            config.getUrl(),
            config.getKey(),
            config.isDisableTlsVerification(),
            config.getTimeout()
        );
    }

    @Override
    public boolean supports(APIProvider.ProviderType type) {
        return type == APIProvider.ProviderType.OPENAI;
    }

    @Override
    public APIProvider.ProviderType getProviderType() {
        return APIProvider.ProviderType.OPENAI;
    }

    @Override
    public String getFactoryName() {
        return "OpenAIProviderFactory";
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/OpenWebUiProviderFactory.java:
--------------------------------------------------------------------------------
package ghidrassist.apiprovider.factory;

import ghidrassist.apiprovider.APIProvider;
import ghidrassist.apiprovider.APIProviderConfig;
import ghidrassist.apiprovider.OpenWebUiProvider;

/**
 * Factory for creating OpenWebUI API providers.
 */
public class OpenWebUiProviderFactory implements APIProviderFactory {

    /**
     * Builds an {@link OpenWebUiProvider} from the configuration's name, model,
     * max tokens, URL, key, TLS-verification flag and timeout.
     *
     * @throws UnsupportedProviderException if {@code config} is not of type OPENWEBUI
     */
    @Override
    public APIProvider createProvider(APIProviderConfig config) throws UnsupportedProviderException {
        if (!supports(config.getType())) {
            throw new UnsupportedProviderException(config.getType(), getFactoryName());
        }

        return new OpenWebUiProvider(
            config.getName(),
            config.getModel(),
            config.getMaxTokens(),
            config.getUrl(),
            config.getKey(),
            config.isDisableTlsVerification(),
            config.getTimeout()
        );
    }

    @Override
    public boolean supports(APIProvider.ProviderType type) {
        return type == APIProvider.ProviderType.OPENWEBUI;
    }

    @Override
    public APIProvider.ProviderType getProviderType() {
        return APIProvider.ProviderType.OPENWEBUI;
    }

    @Override
    public String getFactoryName() {
        return "OpenWebUiProviderFactory";
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/ProviderRegistry.java:
--------------------------------------------------------------------------------
package ghidrassist.apiprovider.factory;

import ghidrassist.apiprovider.APIProvider;
import ghidrassist.apiprovider.APIProviderConfig;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Registry for API provider factories.
 * Manages the creation of providers through registered factories.
 * Thread-safe and follows the Registry pattern.
 */
public class ProviderRegistry {

    /**
     * Factories keyed by the provider type they create.
     * ConcurrentHashMap does not permit null keys: get/remove/containsKey with a
     * null key throw NullPointerException. The lookup methods below therefore
     * treat a null type explicitly as "not registered".
     */
    private final Map<APIProvider.ProviderType, APIProviderFactory> factories = new ConcurrentHashMap<>();
    private static final ProviderRegistry INSTANCE = new ProviderRegistry();

    /**
     * Get the singleton instance of the provider registry
     */
    public static ProviderRegistry getInstance() {
        return INSTANCE;
    }

    /**
     * Private constructor for singleton; registers the built-in factories.
     */
    private ProviderRegistry() {
        registerDefaultFactories();
    }

    /**
     * Register a factory under the provider type it reports, replacing any factory
     * previously registered for that type.
     * @param factory The factory to register
     * @throws IllegalArgumentException if {@code factory} is null or reports a null provider type
     */
    public void registerFactory(APIProviderFactory factory) {
        if (factory == null) {
            throw new IllegalArgumentException("Factory cannot be null");
        }

        APIProvider.ProviderType type = factory.getProviderType();
        if (type == null) {
            throw new IllegalArgumentException("Factory must specify a provider type");
        }

        factories.put(type, factory);
    }

    /**
     * Unregister a factory for a specific provider type
     * @param type The provider type to unregister; null is treated as "not registered"
     * @return The previously registered factory, or null if none was registered
     */
    public APIProviderFactory unregisterFactory(APIProvider.ProviderType type) {
        return type == null ? null : factories.remove(type);
    }

    /**
     * Create a provider using the appropriate factory
     * @param config The provider configuration
     * @return A configured provider instance
     * @throws IllegalArgumentException if {@code config} is null
     * @throws UnsupportedProviderException if no factory is registered for the provider type
     *         (including a null type)
     */
    public APIProvider createProvider(APIProviderConfig config) throws UnsupportedProviderException {
        if (config == null) {
            throw new IllegalArgumentException("Provider config cannot be null");
        }

        APIProvider.ProviderType type = config.getType();
        // Use the null-safe lookup. A raw factories.get(null) would throw
        // NullPointerException instead of the documented UnsupportedProviderException.
        APIProviderFactory factory = getFactory(type);

        if (factory == null) {
            throw new UnsupportedProviderException(type, "ProviderRegistry",
                    "No factory registered for provider type: " + type);
        }

        return factory.createProvider(config);
    }

    /**
     * Check if a provider type is supported
     * @param type The provider type to check
     * @return true if a factory is registered for this type; false for null
     */
    public boolean isSupported(APIProvider.ProviderType type) {
        return type != null && factories.containsKey(type);
    }

    /**
     * Get all supported provider types
     * @return A new, independent set of the supported provider types
     */
    public Set<APIProvider.ProviderType> getSupportedTypes() {
        return new HashSet<>(factories.keySet());
    }

    /**
     * Get all registered factories
     * @return A snapshot copy mapping provider types to their factories; changes to it
     *         do not affect the registry
     */
    public Map<APIProvider.ProviderType, APIProviderFactory> getRegisteredFactories() {
        return new HashMap<>(factories);
    }

    /**
     * Get the factory for a specific provider type
     * @param type The provider type
     * @return The factory, or null if none is registered (or {@code type} is null)
     */
    public APIProviderFactory getFactory(APIProvider.ProviderType type) {
        return type == null ? null : factories.get(type);
    }

    /**
     * Clear all registered factories (mainly for testing)
     */
    public void clearFactories() {
        factories.clear();
    }

    /**
     * Register the default built-in factories
     */
    private void registerDefaultFactories() {
        registerFactory(new AnthropicProviderFactory());
        registerFactory(new OpenAIProviderFactory());
        registerFactory(new OllamaProviderFactory());
        registerFactory(new LMStudioProviderFactory());
        registerFactory(new OpenWebUiProviderFactory());
    }

    /**
     * Get information about all registered factories
     * @return Human-readable string with one "type -> factory name" line per
     *         registration (iteration order is that of the underlying ConcurrentHashMap)
     */
    public String getRegistryInfo() {
        StringBuilder sb = new StringBuilder();
        sb.append("Registered Provider Factories:\n");

        for (Map.Entry<APIProvider.ProviderType, APIProviderFactory> entry : factories.entrySet()) {
            sb.append(String.format("  %s -> %s\n",
                    entry.getKey(),
                    entry.getValue().getFactoryName()));
        }

        return sb.toString();
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/UnsupportedProviderException.java:
--------------------------------------------------------------------------------
package ghidrassist.apiprovider.factory;

import ghidrassist.apiprovider.APIProvider;

/**
 * Exception thrown when a factory cannot create a requested provider type.
 */
public class UnsupportedProviderException extends Exception {

    /** The provider type that was requested (may be null if the config had no type). */
    private final APIProvider.ProviderType requestedType;
    /** Name of the factory or registry that rejected the request. */
    private final String factoryName;

    /**
     * Creates the exception with a generated message naming the factory and the type.
     */
    public UnsupportedProviderException(APIProvider.ProviderType requestedType, String factoryName) {
        super(String.format("Factory '%s' does not support provider type '%s'", factoryName, requestedType));
        this.requestedType = requestedType;
        this.factoryName = factoryName;
    }

    /**
     * Creates the exception with a caller-supplied message.
     */
    public UnsupportedProviderException(APIProvider.ProviderType requestedType, String factoryName, String message) {
        super(message);
        this.requestedType = requestedType;
        this.factoryName = factoryName;
    }

    /**
     * Creates the exception with a caller-supplied message and an underlying cause.
     */
    public UnsupportedProviderException(APIProvider.ProviderType requestedType, String factoryName, String message, Throwable cause) {
        super(message, cause);
        this.requestedType = requestedType;
        this.factoryName = factoryName;
    }

    /** @return the provider type that was requested */
    public APIProvider.ProviderType getRequestedType() {
        return requestedType;
    }

    /** @return the name of the factory or registry that rejected the request */
    public String getFactoryName() {
        return factoryName;
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/core/ActionConstants.java:
-------------------------------------------------------------------------------- 1 | package ghidrassist.core; 2 | 3 | import java.util.*; 4 | 5 | public class ActionConstants { 6 | 7 | public static final List> FN_TEMPLATES = Arrays.asList( 8 | createFunctionTemplate( 9 | "rename_function", 10 | "Rename a function", 11 | createParameters( 12 | createParameter("new_name", "string", "The new name for the function. (e.g., recv_data)") 13 | ) 14 | ), 15 | createFunctionTemplate( 16 | "rename_variable", 17 | "Rename a variable within a function", 18 | createParameters( 19 | createParameter("func_name", "string", "The name of the function containing the variable. (e.g., sub_40001234)"), 20 | createParameter("var_name", "string", "The current name of the variable. (e.g., var_20)"), 21 | createParameter("new_name", "string", "The new name for the variable. (e.g., recv_buf)") 22 | ) 23 | ), 24 | createFunctionTemplate( 25 | "retype_variable", 26 | "Set a variable data type within a function", 27 | createParameters( 28 | createParameter("func_name", "string", "The name of the function containing the variable. (e.g., sub_40001234)"), 29 | createParameter("var_name", "string", "The current name of the variable. (e.g., rax_12)"), 30 | createParameter("new_type", "string", "The new type for the variable. (e.g., int32_t)") 31 | ) 32 | ), 33 | createFunctionTemplate( 34 | "auto_create_struct", 35 | "Automatically create a structure datatype from a variable given its offset uses in a given function.", 36 | createParameters( 37 | createParameter("func_name", "string", "The name of the function containing the variable. (e.g., sub_40001234)"), 38 | createParameter("var_name", "string", "The current name of the variable. 
(e.g., rax_12)") 39 | ) 40 | ) 41 | ); 42 | 43 | public static final Map ACTION_PROMPTS = new HashMap<>(); 44 | 45 | static { 46 | ACTION_PROMPTS.put("rename_function", 47 | "Use the 'rename_function' tool:\n```\n{code}\n```\n" + 48 | "Examine the code functionality, strings and log parameters.\n" + 49 | "If you detect C++ Super::Derived::Method or Class::Method style class names, recommend that name first, OTHERWISE USE PROCEDURAL NAMING.\n" + 50 | "CREATE A JSON TOOL_CALL LIST WITH SUGGESTIONS FOR THREE POSSIBLE FUNCTION NAMES " + 51 | "THAT ALIGN AS CLOSELY AS POSSIBLE TO WHAT THE CODE ABOVE DOES.\n" + 52 | "RESPOND ONLY WITH THE RENAME_FUNCTION PARAMETER (new_name). DO NOT INCLUDE ANY OTHER TEXT.\n" + 53 | "ALL JSON MUST BE PROPERLY FORMATTED WITH NO EMBEDDED COMMENTS.\n" 54 | ); 55 | ACTION_PROMPTS.put("rename_variable", 56 | "Use the 'rename_variable' tool:\n```\n{code}\n```\n" + 57 | "Examine the code functionality, strings, and log parameters.\n" + 58 | "SUGGEST VARIABLE NAMES THAT BETTER ALIGN WITH THE CODE FUNCTIONALITY.\n" + 59 | "RESPOND ONLY WITH THE RENAME_VARIABLE PARAMETERS (func_name, var_name, new_name). DO NOT INCLUDE ANY OTHER TEXT.\n" + 60 | "ALL JSON VALUES MUST BE TEXT STRINGS, INCLUDING NUMBERS AND ADDRESSES, e.g., \"0x1234abcd\".\n" + 61 | "ALL JSON MUST BE PROPERLY FORMATTED WITH NO EMBEDDED COMMENTS.\n" 62 | ); 63 | ACTION_PROMPTS.put("retype_variable", 64 | "Use the 'retype_variable' tool:\n```\n{code}\n```\n" + 65 | "Examine the code functionality, strings, and log parameters.\n" + 66 | "SUGGEST VARIABLE TYPES THAT BETTER ALIGN WITH THE CODE FUNCTIONALITY.\n" + 67 | "RESPOND ONLY WITH THE RETYPE_VARIABLE PARAMETERS (func_name, var_name, new_type). 
DO NOT INCLUDE ANY OTHER TEXT.\n" + 68 | "ALL JSON VALUES MUST BE TEXT STRINGS, INCLUDING NUMBERS AND ADDRESSES, e.g., \"0x1234abcd\".\n" + 69 | "ALL JSON MUST BE PROPERLY FORMATTED WITH NO EMBEDDED COMMENTS.\n" 70 | ); 71 | ACTION_PROMPTS.put("auto_create_struct", 72 | "Use the 'auto_create_struct' tool:\n```\n{code}\n```\n" + 73 | "Examine the code functionality, parameters, and variables being used.\n" + 74 | "IF YOU DETECT A VARIABLE THAT USES OFFSET ACCESS SUCH AS `*(arg1 + 0xc)` OR VARIABLES LIKELY TO BE STRUCTURES OR CLASSES,\n" + 75 | "RESPOND ONLY WITH THE AUTO_CREATE_STRUCT PARAMETERS (func_name, var_name). DO NOT INCLUDE ANY OTHER TEXT.\n" + 76 | "ALL JSON VALUES MUST BE TEXT STRINGS, INCLUDING NUMBERS AND ADDRESSES, e.g., \"0x1234abcd\".\n" + 77 | "ALL JSON MUST BE PROPERLY FORMATTED WITH NO EMBEDDED COMMENTS.\n" 78 | ); 79 | } 80 | 81 | // Helper methods for creating function templates 82 | private static Map createFunctionTemplate(String name, String description, Map parameters) { 83 | Map functionMap = new HashMap<>(); 84 | functionMap.put("name", name); 85 | functionMap.put("description", description); 86 | functionMap.put("parameters", parameters); 87 | 88 | Map template = new HashMap<>(); 89 | template.put("type", "function"); 90 | template.put("function", functionMap); 91 | return template; 92 | } 93 | 94 | private static Map createParameters(Map... 
parameters) { 95 | Map parametersMap = new HashMap<>(); 96 | parametersMap.put("type", "object"); 97 | 98 | Map properties = new HashMap<>(); 99 | List required = new ArrayList<>(); 100 | 101 | for (Map param : parameters) { 102 | String name = (String) param.get("name"); 103 | properties.put(name, param); 104 | required.add(name); 105 | } 106 | 107 | parametersMap.put("properties", properties); 108 | parametersMap.put("required", required); 109 | return parametersMap; 110 | } 111 | 112 | private static Map createParameter(String name, String type, String description) { 113 | Map param = new HashMap<>(); 114 | param.put("name", name); 115 | param.put("type", type); 116 | param.put("description", description); 117 | return param; 118 | } 119 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/core/CodeUtils.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.core; 2 | 3 | import ghidra.app.decompiler.ClangNode; 4 | import ghidra.app.decompiler.ClangToken; 5 | import ghidra.app.decompiler.ClangTokenGroup; 6 | import ghidra.app.decompiler.DecompInterface; 7 | import ghidra.app.decompiler.DecompileResults; 8 | import ghidra.program.model.address.Address; 9 | import ghidra.program.model.listing.Function; 10 | import ghidra.program.model.listing.Instruction; 11 | import ghidra.program.model.listing.InstructionIterator; 12 | import ghidra.program.model.listing.Listing; 13 | import ghidra.program.model.listing.Program; 14 | import ghidra.util.task.TaskMonitor; 15 | 16 | public class CodeUtils { 17 | 18 | /** 19 | * Gets the decompiled code for a function. 
20 | * @param function The function to decompile 21 | * @param monitor Task monitor for tracking progress 22 | * @return The decompiled code as a string 23 | */ 24 | public static String getFunctionCode(Function function, TaskMonitor monitor) { 25 | DecompInterface decompiler = new DecompInterface(); 26 | decompiler.openProgram(function.getProgram()); 27 | 28 | try { 29 | DecompileResults results = decompiler.decompileFunction(function, 60, monitor); 30 | if (results != null && results.decompileCompleted()) { 31 | return results.getDecompiledFunction().getC(); 32 | } else { 33 | return "Failed to decompile function."; 34 | } 35 | } catch (Exception e) { 36 | return "Failed to decompile function: " + e.getMessage(); 37 | } finally { 38 | decompiler.dispose(); 39 | } 40 | } 41 | 42 | /** 43 | * Gets the disassembly for a function. 44 | * @param function The function to disassemble 45 | * @return The disassembled code as a string 46 | */ 47 | public static String getFunctionDisassembly(Function function) { 48 | StringBuilder sb = new StringBuilder(); 49 | Listing listing = function.getProgram().getListing(); 50 | InstructionIterator instructions = listing.getInstructions(function.getBody(), true); 51 | 52 | while (instructions.hasNext()) { 53 | Instruction instr = instructions.next(); 54 | sb.append(formatInstruction(instr)).append("\n"); 55 | } 56 | 57 | return sb.toString(); 58 | } 59 | 60 | /** 61 | * Gets the decompiled code for a specific line at an address. 
62 | * @param address The address to get code for 63 | * @param monitor Task monitor for tracking progress 64 | * @param program The current program 65 | * @return The decompiled line as a string 66 | */ 67 | public static String getLineCode(Address address, TaskMonitor monitor, Program program) { 68 | DecompInterface decompiler = new DecompInterface(); 69 | decompiler.openProgram(program); 70 | 71 | try { 72 | Function function = program.getFunctionManager().getFunctionContaining(address); 73 | if (function == null) { 74 | return "No function containing the address."; 75 | } 76 | 77 | DecompileResults results = decompiler.decompileFunction(function, 60, monitor); 78 | if (results != null && results.decompileCompleted()) { 79 | ClangTokenGroup tokens = results.getCCodeMarkup(); 80 | if (tokens != null) { 81 | StringBuilder codeLineBuilder = new StringBuilder(); 82 | boolean found = collectCodeLine(tokens, address, codeLineBuilder); 83 | if (found && codeLineBuilder.length() > 0) { 84 | return codeLineBuilder.toString(); 85 | } else { 86 | return "No code line found at the address."; 87 | } 88 | } else { 89 | return "Failed to get code tokens."; 90 | } 91 | } else { 92 | return "Failed to decompile function."; 93 | } 94 | } catch (Exception e) { 95 | return "Failed to decompile line: " + e.getMessage(); 96 | } finally { 97 | decompiler.dispose(); 98 | } 99 | } 100 | 101 | /** 102 | * Gets the disassembly for a specific address. 103 | * @param address The address to get disassembly for 104 | * @param program The current program 105 | * @return The disassembled instruction as a string 106 | */ 107 | public static String getLineDisassembly(Address address, Program program) { 108 | Instruction instruction = program.getListing().getInstructionAt(address); 109 | if (instruction != null) { 110 | return formatInstruction(instruction); 111 | } 112 | return null; 113 | } 114 | 115 | /** 116 | * Collects all code tokens for a specific line containing an address. 
117 | * @param node The current ClangNode 118 | * @param address The target address 119 | * @param codeLineBuilder StringBuilder to collect the code 120 | * @return true if the address was found and code collected 121 | */ 122 | private static boolean collectCodeLine(ClangNode node, Address address, StringBuilder codeLineBuilder) { 123 | if (node instanceof ClangToken) { 124 | ClangToken token = (ClangToken) node; 125 | if (token.getMinAddress() != null && token.getMaxAddress() != null) { 126 | if (token.getMinAddress().compareTo(address) <= 0 && 127 | token.getMaxAddress().compareTo(address) >= 0) { 128 | // Found the token corresponding to the address 129 | ClangNode parent = token.Parent(); 130 | if (parent != null) { 131 | for (int i = 0; i < parent.numChildren(); i++) { 132 | ClangNode sibling = parent.Child(i); 133 | if (sibling instanceof ClangToken) { 134 | codeLineBuilder.append(((ClangToken) sibling).getText()); 135 | } 136 | } 137 | } else { 138 | codeLineBuilder.append(token.getText()); 139 | } 140 | return true; 141 | } 142 | } 143 | } else if (node instanceof ClangTokenGroup) { 144 | ClangTokenGroup group = (ClangTokenGroup) node; 145 | for (int i = 0; i < group.numChildren(); i++) { 146 | ClangNode child = group.Child(i); 147 | boolean found = collectCodeLine(child, address, codeLineBuilder); 148 | if (found) { 149 | return true; 150 | } 151 | } 152 | } 153 | return false; 154 | } 155 | 156 | /** 157 | * Formats an instruction with its address and representation. 
158 | * @param instruction The instruction to format 159 | * @return A formatted string representation of the instruction 160 | */ 161 | private static String formatInstruction(Instruction instruction) { 162 | return String.format("%s %s", 163 | instruction.getAddressString(true, true), 164 | instruction.toString()); 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /src/main/java/ghidrassist/core/LlmApiClient.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.core; 2 | 3 | import ghidrassist.AnalysisDB; 4 | import ghidrassist.GhidrAssistPlugin; 5 | import ghidrassist.LlmApi; 6 | import ghidrassist.apiprovider.APIProvider; 7 | import ghidrassist.apiprovider.APIProviderConfig; 8 | import ghidrassist.apiprovider.ChatMessage; 9 | import ghidrassist.apiprovider.exceptions.APIProviderException; 10 | 11 | import java.util.ArrayList; 12 | import java.util.List; 13 | import java.util.Map; 14 | 15 | /** 16 | * Handles API provider management and low-level API calls. 17 | * Focused solely on provider configuration and basic API interactions. 18 | */ 19 | public class LlmApiClient { 20 | private APIProvider provider; 21 | private final AnalysisDB analysisDB; 22 | private final GhidrAssistPlugin plugin; 23 | 24 | private final String DEFAULT_SYSTEM_PROMPT = 25 | "You are a professional software reverse engineer specializing in cybersecurity. You are intimately \n" 26 | + "familiar with x86_64, ARM, PPC and MIPS architectures. You are an expert C and C++ developer.\n" 27 | + "You are an expert Python and Rust developer. You are familiar with common frameworks and libraries \n" 28 | + "such as WinSock, OpenSSL, MFC, etc. You are an expert in TCP/IP network programming and packet analysis.\n" 29 | + "You always respond to queries in a structured format using Markdown styling for headings and lists. 
\n" 30 | + "You format code blocks using back-tick code-fencing.\n"; 31 | 32 | private final String FUNCTION_PROMPT = "USE THE PROVIDED TOOLS WHEN NECESSARY. YOU ALWAYS RESPOND WITH TOOL CALLS WHEN POSSIBLE."; 33 | 34 | public LlmApiClient(APIProviderConfig config, GhidrAssistPlugin plugin) { 35 | this.provider = config.createProvider(); 36 | this.analysisDB = new AnalysisDB(); 37 | this.plugin = plugin; 38 | 39 | // Get the global API timeout and set it if the provider doesn't have one 40 | if (provider != null && provider.getTimeout() == null) { 41 | Integer timeout = GhidrAssistPlugin.getGlobalApiTimeout(); 42 | provider.setTimeout(timeout); 43 | } 44 | } 45 | 46 | public String getSystemPrompt() { 47 | return this.DEFAULT_SYSTEM_PROMPT; 48 | } 49 | 50 | public GhidrAssistPlugin getPlugin() { 51 | return plugin; 52 | } 53 | 54 | public String getCurrentContext() { 55 | if (plugin.getCurrentProgram() != null) { 56 | String programHash = plugin.getCurrentProgram().getExecutableSHA256(); 57 | String context = analysisDB.getContext(programHash); 58 | if (context != null) { 59 | return context; 60 | } 61 | } 62 | return DEFAULT_SYSTEM_PROMPT; 63 | } 64 | 65 | /** 66 | * Create messages for regular chat completion 67 | */ 68 | public List createChatMessages(String prompt) { 69 | String systemUser = ChatMessage.ChatMessageRole.SYSTEM; 70 | if (isO1OrO3Model()) { 71 | systemUser = ChatMessage.ChatMessageRole.USER; 72 | } 73 | 74 | List messages = new ArrayList<>(); 75 | messages.add(new ChatMessage(systemUser, getCurrentContext())); 76 | messages.add(new ChatMessage(ChatMessage.ChatMessageRole.USER, prompt)); 77 | return messages; 78 | } 79 | 80 | /** 81 | * Create messages for function calling 82 | */ 83 | public List createFunctionMessages(String prompt) { 84 | String systemUser = ChatMessage.ChatMessageRole.SYSTEM; 85 | if (isO1OrO3Model()) { 86 | systemUser = ChatMessage.ChatMessageRole.USER; 87 | } 88 | 89 | List messages = new ArrayList<>(); 90 | messages.add(new 
ChatMessage(ChatMessage.ChatMessageRole.USER, prompt)); 91 | return messages; 92 | } 93 | 94 | /** 95 | * Check if the current model is O1 or O3 series (which handle system prompts differently) 96 | */ 97 | private boolean isO1OrO3Model() { 98 | return provider != null && ( 99 | provider.getModel().startsWith("o1-") || 100 | provider.getModel().startsWith("o3-") || 101 | provider.getModel().startsWith("o4-") 102 | ); 103 | } 104 | 105 | /** 106 | * Stream chat completion 107 | */ 108 | public void streamChatCompletion(List messages, LlmApi.LlmResponseHandler handler) 109 | throws APIProviderException { 110 | if (provider == null) { 111 | throw new IllegalStateException("LLM provider is not initialized."); 112 | } 113 | provider.streamChatCompletion(messages, handler); 114 | } 115 | 116 | /** 117 | * Create chat completion with functions 118 | */ 119 | public String createChatCompletionWithFunctions(List messages, List> functions) 120 | throws APIProviderException { 121 | if (provider == null) { 122 | throw new IllegalStateException("LLM provider is not initialized."); 123 | } 124 | return provider.createChatCompletionWithFunctions(messages, functions); 125 | } 126 | 127 | /** 128 | * Create chat completion with functions - returns full response including finish_reason 129 | */ 130 | public String createChatCompletionWithFunctionsFullResponse(List messages, List> functions) 131 | throws APIProviderException { 132 | if (provider == null) { 133 | throw new IllegalStateException("LLM provider is not initialized."); 134 | } 135 | return provider.createChatCompletionWithFunctionsFullResponse(messages, functions); 136 | } 137 | 138 | /** 139 | * Check if provider is available 140 | */ 141 | public boolean isProviderAvailable() { 142 | return provider != null; 143 | } 144 | 145 | /** 146 | * Get provider name for logging/error handling 147 | */ 148 | public String getProviderName() { 149 | return provider != null ? 
provider.getName() : "Unknown"; 150 | } 151 | 152 | /** 153 | * Get provider model for logging/error handling 154 | */ 155 | public String getProviderModel() { 156 | return provider != null ? provider.getModel() : "Unknown"; 157 | } 158 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/core/LlmErrorHandler.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.core; 2 | 3 | import ghidra.util.Msg; 4 | import ghidrassist.GhidrAssistPlugin; 5 | import ghidrassist.apiprovider.APIProviderLogger; 6 | import ghidrassist.apiprovider.ErrorAction; 7 | import ghidrassist.apiprovider.exceptions.APIProviderException; 8 | import ghidrassist.apiprovider.exceptions.StreamCancelledException; 9 | import ghidrassist.ui.EnhancedErrorDialog; 10 | 11 | import java.util.ArrayList; 12 | import java.util.List; 13 | 14 | /** 15 | * Handles error processing, logging, and user interaction for LLM operations. 16 | * Focused solely on error handling logic and user feedback. 
17 | */ 18 | public class LlmErrorHandler { 19 | 20 | private final GhidrAssistPlugin plugin; 21 | private final Object source; 22 | 23 | public LlmErrorHandler(GhidrAssistPlugin plugin, Object source) { 24 | this.plugin = plugin; 25 | this.source = source; 26 | } 27 | 28 | /** 29 | * Handle an error with enhanced error dialogs and logging 30 | */ 31 | public void handleError(Throwable error, String operation, Runnable retryAction) { 32 | if (error instanceof APIProviderException) { 33 | APIProviderException ape = (APIProviderException) error; 34 | 35 | // Log the error with structured information 36 | APIProviderLogger.logError(source, ape); 37 | 38 | // Skip showing error dialog for cancellations unless it's unexpected 39 | if (shouldSkipErrorDialog(ape)) { 40 | return; 41 | } 42 | 43 | // Create appropriate error actions 44 | List actions = createErrorActions(ape, retryAction); 45 | 46 | // Show enhanced error dialog 47 | java.awt.Window parentWindow = getParentWindow(); 48 | EnhancedErrorDialog.showError(parentWindow, ape, actions); 49 | 50 | } else { 51 | // Handle non-API provider exceptions (fallback) 52 | handleGenericError(error, operation); 53 | } 54 | } 55 | 56 | /** 57 | * Handle generic (non-API provider) errors 58 | */ 59 | private void handleGenericError(Throwable error, String operation) { 60 | String message = error.getMessage() != null ? 
error.getMessage() : error.getClass().getSimpleName(); 61 | Msg.showError(source, null, "Unexpected Error", 62 | "An unexpected error occurred during " + operation + ": " + message); 63 | 64 | // Log the error 65 | Msg.error(source, "Unexpected error during " + operation, error); 66 | } 67 | 68 | /** 69 | * Determine if error dialog should be skipped for certain cancellation types 70 | */ 71 | private boolean shouldSkipErrorDialog(APIProviderException ape) { 72 | if (ape.getCategory() == APIProviderException.ErrorCategory.CANCELLED) { 73 | if (ape instanceof StreamCancelledException) { 74 | StreamCancelledException sce = (StreamCancelledException) ape; 75 | if (sce.getCancellationReason() == StreamCancelledException.CancellationReason.USER_REQUESTED) { 76 | return true; // Don't show dialog for user-requested cancellations 77 | } 78 | } 79 | } 80 | return false; 81 | } 82 | 83 | /** 84 | * Create appropriate error actions based on the exception type 85 | */ 86 | private List createErrorActions(APIProviderException ape, Runnable retryAction) { 87 | List actions = new ArrayList<>(); 88 | 89 | // Add retry action for retryable errors 90 | if (ape.isRetryable() && retryAction != null) { 91 | actions.add(ErrorAction.createRetryAction(retryAction)); 92 | } 93 | 94 | // Add settings action for configuration-related errors 95 | if (isConfigurationError(ape)) { 96 | actions.add(ErrorAction.createSettingsAction(() -> openSettings())); 97 | } 98 | 99 | // Add provider switching action for persistent errors 100 | APIProviderLogger.ErrorStats stats = APIProviderLogger.getErrorStats(ape.getProviderName()); 101 | if (stats != null && stats.isFrequentErrorsDetected()) { 102 | actions.add(ErrorAction.createSwitchProviderAction(() -> suggestProviderSwitch())); 103 | } 104 | 105 | // Add copy error details action 106 | actions.add(ErrorAction.createCopyErrorAction(ape.getTechnicalDetails())); 107 | 108 | // Add dismiss action 109 | actions.add(ErrorAction.createDismissAction()); 110 
| 111 | return actions; 112 | } 113 | 114 | /** 115 | * Check if error is configuration-related 116 | */ 117 | private boolean isConfigurationError(APIProviderException ape) { 118 | return ape.getCategory() == APIProviderException.ErrorCategory.AUTHENTICATION || 119 | ape.getCategory() == APIProviderException.ErrorCategory.CONFIGURATION || 120 | ape.getCategory() == APIProviderException.ErrorCategory.MODEL_ERROR; 121 | } 122 | 123 | /** 124 | * Get the parent window for error dialogs 125 | */ 126 | private java.awt.Window getParentWindow() { 127 | try { 128 | // Try to get the main Ghidra window 129 | if (plugin != null && plugin.getTool() != null) { 130 | return plugin.getTool().getToolFrame(); 131 | } 132 | } catch (Exception e) { 133 | // Ignore errors getting parent window 134 | } 135 | return null; 136 | } 137 | 138 | /** 139 | * Open the settings dialog 140 | */ 141 | private void openSettings() { 142 | try { 143 | if (plugin != null) { 144 | // This would typically call the plugin's settings dialog 145 | // The actual implementation depends on how settings are accessed 146 | Msg.showInfo(source, null, "Settings", 147 | "Please go to Tools -> GhidrAssist Settings to configure API providers."); 148 | } 149 | } catch (Exception e) { 150 | Msg.showError(source, null, "Error", "Could not open settings: " + e.getMessage()); 151 | } 152 | } 153 | 154 | /** 155 | * Suggest switching to a different provider 156 | */ 157 | private void suggestProviderSwitch() { 158 | try { 159 | // Generate a simple suggestion message 160 | StringBuilder suggestion = new StringBuilder(); 161 | suggestion.append("Current provider is experiencing frequent errors.\n\n"); 162 | suggestion.append("Consider switching to a different provider in Settings.\n\n"); 163 | suggestion.append("Provider Error Statistics:\n"); 164 | suggestion.append(APIProviderLogger.generateDiagnosticsReport()); 165 | 166 | Msg.showInfo(source, null, "Provider Reliability", suggestion.toString()); 167 | } catch 
(Exception e) { 168 | Msg.showError(source, null, "Error", "Could not generate provider statistics: " + e.getMessage()); 169 | } 170 | } 171 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/core/LlmTaskExecutor.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.core; 2 | 3 | import ghidrassist.LlmApi; 4 | import ghidrassist.apiprovider.exceptions.APIProviderException; 5 | 6 | import javax.swing.SwingWorker; 7 | import java.util.List; 8 | import java.util.Map; 9 | import java.util.concurrent.atomic.AtomicBoolean; 10 | 11 | /** 12 | * Handles background task execution for LLM operations. 13 | * Focused on managing SwingWorker tasks and request lifecycle. 14 | */ 15 | public class LlmTaskExecutor { 16 | 17 | private final Object streamLock = new Object(); 18 | private volatile boolean isStreaming = false; 19 | private final AtomicBoolean shouldCancel = new AtomicBoolean(false); 20 | 21 | /** 22 | * Execute a streaming chat request in the background 23 | */ 24 | public void executeStreamingRequest( 25 | LlmApiClient client, 26 | String prompt, 27 | ResponseProcessor responseProcessor, 28 | LlmResponseHandler responseHandler) { 29 | 30 | if (!client.isProviderAvailable()) { 31 | responseHandler.onError(new IllegalStateException("LLM provider is not initialized.")); 32 | return; 33 | } 34 | 35 | // Cancel any existing stream 36 | cancelCurrentRequest(); 37 | shouldCancel.set(false); 38 | 39 | try { 40 | synchronized (streamLock) { 41 | isStreaming = true; 42 | ResponseProcessor.StreamingResponseFilter filter = responseProcessor.createStreamingFilter(); 43 | 44 | client.streamChatCompletion(client.createChatMessages(prompt), new LlmApi.LlmResponseHandler() { 45 | private boolean isFirst = true; 46 | 47 | @Override 48 | public void onStart() { 49 | if (isFirst && shouldCancel.get() == false) { 50 | responseHandler.onStart(); 51 | isFirst = false; 52 | 
} 53 | } 54 | 55 | @Override 56 | public void onUpdate(String partialResponse) { 57 | if (shouldCancel.get()) { 58 | return; 59 | } 60 | String filteredContent = filter.processChunk(partialResponse); 61 | if (filteredContent != null && !filteredContent.isEmpty()) { 62 | responseHandler.onUpdate(filteredContent); 63 | } 64 | } 65 | 66 | @Override 67 | public void onComplete(String fullResponse) { 68 | synchronized (streamLock) { 69 | isStreaming = false; 70 | } 71 | if (!shouldCancel.get()) { 72 | responseHandler.onComplete(filter.getFilteredContent()); 73 | } 74 | } 75 | 76 | @Override 77 | public void onError(Throwable error) { 78 | synchronized (streamLock) { 79 | isStreaming = false; 80 | } 81 | if (!shouldCancel.get()) { 82 | responseHandler.onError(error); 83 | } 84 | } 85 | 86 | @Override 87 | public boolean shouldContinue() { 88 | return !shouldCancel.get() && responseHandler.shouldContinue(); 89 | } 90 | }); 91 | } 92 | } catch (Exception e) { 93 | synchronized (streamLock) { 94 | isStreaming = false; 95 | } 96 | if (!shouldCancel.get()) { 97 | responseHandler.onError(e); 98 | } 99 | } 100 | } 101 | 102 | /** 103 | * Execute a function calling request in the background 104 | */ 105 | public void executeFunctionRequest( 106 | LlmApiClient client, 107 | String prompt, 108 | List> functions, 109 | ResponseProcessor responseProcessor, 110 | LlmResponseHandler responseHandler) { 111 | 112 | if (!client.isProviderAvailable()) { 113 | responseHandler.onError(new IllegalStateException("LLM provider is not initialized.")); 114 | return; 115 | } 116 | 117 | shouldCancel.set(false); 118 | 119 | // Create a background task 120 | SwingWorker worker = new SwingWorker<>() { 121 | @Override 122 | protected Void doInBackground() { 123 | try { 124 | synchronized (streamLock) { 125 | isStreaming = true; 126 | } 127 | 128 | if (!shouldCancel.get()) { 129 | responseHandler.onStart(); 130 | String response = client.createChatCompletionWithFunctions( 131 | 
client.createFunctionMessages(prompt), functions); 132 | 133 | if (!shouldCancel.get() && responseHandler.shouldContinue()) { 134 | String filteredResponse = responseProcessor.filterThinkBlocks(response); 135 | responseHandler.onComplete(filteredResponse); 136 | } 137 | } 138 | } catch (APIProviderException e) { 139 | if (!shouldCancel.get() && responseHandler.shouldContinue()) { 140 | responseHandler.onError(e); 141 | } 142 | } finally { 143 | synchronized (streamLock) { 144 | isStreaming = false; 145 | } 146 | } 147 | return null; 148 | } 149 | 150 | @Override 151 | protected void done() { 152 | try { 153 | get(); // Check for exceptions 154 | } catch (Exception e) { 155 | if (!shouldCancel.get() && responseHandler.shouldContinue()) { 156 | responseHandler.onError(e); 157 | } 158 | } 159 | } 160 | }; 161 | 162 | worker.execute(); 163 | } 164 | 165 | /** 166 | * Cancel the current request 167 | */ 168 | public void cancelCurrentRequest() { 169 | shouldCancel.set(true); 170 | synchronized (streamLock) { 171 | isStreaming = false; 172 | } 173 | } 174 | 175 | /** 176 | * Check if currently streaming 177 | */ 178 | public boolean isStreaming() { 179 | synchronized (streamLock) { 180 | return isStreaming; 181 | } 182 | } 183 | 184 | /** 185 | * Interface for handling LLM responses 186 | */ 187 | public interface LlmResponseHandler { 188 | void onStart(); 189 | void onUpdate(String partialResponse); 190 | void onComplete(String fullResponse); 191 | void onError(Throwable error); 192 | default boolean shouldContinue() { 193 | return true; 194 | } 195 | } 196 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/core/MarkdownHelper.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.core; 2 | 3 | import com.vladsch.flexmark.html.HtmlRenderer; 4 | import com.vladsch.flexmark.html2md.converter.FlexmarkHtmlConverter; 5 | import 
com.vladsch.flexmark.parser.Parser; 6 | import com.vladsch.flexmark.util.ast.Document; 7 | import com.vladsch.flexmark.util.data.MutableDataSet; 8 | 9 | import java.util.regex.Matcher; 10 | import java.util.regex.Pattern; 11 | 12 | public class MarkdownHelper { 13 | private final Parser parser; 14 | private final HtmlRenderer renderer; 15 | private final FlexmarkHtmlConverter htmlToMdConverter; 16 | 17 | public MarkdownHelper() { 18 | MutableDataSet options = new MutableDataSet(); 19 | options.set(HtmlRenderer.SOFT_BREAK, "
\n"); 20 | this.parser = Parser.builder(options).build(); 21 | this.renderer = HtmlRenderer.builder(options).build(); 22 | this.htmlToMdConverter = FlexmarkHtmlConverter.builder().build(); 23 | } 24 | 25 | /** 26 | * Convert Markdown text to HTML for display 27 | * Includes feedback buttons in the HTML output 28 | * 29 | * @param markdown The markdown text to convert 30 | * @return HTML representation of the markdown 31 | */ 32 | public String markdownToHtml(String markdown) { 33 | if (markdown == null) { 34 | return ""; 35 | } 36 | 37 | Document document = parser.parse(markdown); 38 | String html = renderer.render(document); 39 | 40 | // Add feedback buttons 41 | String feedbackLinks = "
" + 42 | "👍 | 👎
"; 43 | 44 | return "" + html + feedbackLinks + ""; 45 | } 46 | 47 | /** 48 | * Convert Markdown text to HTML without adding feedback buttons 49 | * Used for preview or when feedback isn't needed 50 | * 51 | * @param markdown The markdown text to convert 52 | * @return HTML representation of the markdown 53 | */ 54 | public String markdownToHtmlSimple(String markdown) { 55 | if (markdown == null) { 56 | return ""; 57 | } 58 | 59 | Document document = parser.parse(markdown); 60 | String html = renderer.render(document); 61 | 62 | return "" + html + ""; 63 | } 64 | 65 | /** 66 | * Convert HTML to Markdown 67 | * 68 | * @param html The HTML to convert 69 | * @return Markdown representation of the HTML 70 | */ 71 | public String htmlToMarkdown(String html) { 72 | if (html == null || html.isEmpty()) { 73 | return ""; 74 | } 75 | 76 | // Remove feedback buttons if present 77 | html = removeFeedbackButtons(html); 78 | 79 | // Remove html wrapper tags if present 80 | html = removeHtmlWrapperTags(html); 81 | 82 | // Use flexmark converter for the HTML to Markdown conversion 83 | return htmlToMdConverter.convert(html); 84 | } 85 | 86 | /** 87 | * Extract markdown from a response that might be in various formats 88 | * 89 | * @param response The response to extract markdown from 90 | * @return Extracted markdown content 91 | */ 92 | public String extractMarkdownFromLlmResponse(String response) { 93 | if (response == null || response.isEmpty()) { 94 | return ""; 95 | } 96 | 97 | // Check if it's HTML 98 | if (response.toLowerCase().contains("") || response.toLowerCase().contains("")) { 99 | return htmlToMarkdown(response); 100 | } 101 | 102 | // Otherwise, assume it's already markdown or plain text 103 | return response; 104 | } 105 | 106 | /** 107 | * Remove feedback buttons from HTML string 108 | */ 109 | private String removeFeedbackButtons(String html) { 110 | // Pattern to match the feedback buttons div 111 | Pattern feedbackPattern = Pattern.compile("
.*?
"); 112 | Matcher matcher = feedbackPattern.matcher(html); 113 | return matcher.replaceAll(""); 114 | } 115 | 116 | /** 117 | * Remove HTML and BODY wrapper tags 118 | */ 119 | private String removeHtmlWrapperTags(String html) { 120 | return html.replaceAll("(?i)|||", ""); 121 | } 122 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/core/QueryProcessor.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.core; 2 | 3 | import java.util.List; 4 | import java.util.regex.Pattern; 5 | import java.util.regex.Matcher; 6 | 7 | import ghidra.program.model.address.Address; 8 | import ghidra.program.model.address.AddressFactory; 9 | import ghidra.program.model.listing.Function; 10 | import ghidra.program.model.listing.Program; 11 | import ghidra.util.task.TaskMonitor; 12 | import ghidrassist.GhidrAssistPlugin; 13 | import ghidrassist.GhidrAssistPlugin.CodeViewType; 14 | 15 | public class QueryProcessor { 16 | 17 | private static final Pattern RANGE_PATTERN = Pattern.compile("#range\\(([^,]+),\\s*([^\\)]+)\\)"); 18 | private static final int MAX_SEARCH_RESULTS = 5; 19 | 20 | /** 21 | * Process all macros in the query and replace them with actual content. 
22 | * @param query The original query containing macros 23 | * @param plugin The GhidrAssist plugin instance 24 | * @return Processed query with macros replaced 25 | */ 26 | public static String processMacrosInQuery(String query, GhidrAssistPlugin plugin) { 27 | String processedQuery = query; 28 | 29 | try { 30 | CodeViewType viewType = plugin.checkLastActiveCodeView(); 31 | TaskMonitor monitor = TaskMonitor.DUMMY; 32 | 33 | // Process #line macro 34 | if (processedQuery.contains("#line")) { 35 | String codeLine = getCurrentLine(plugin, viewType, monitor); 36 | if (codeLine != null) { 37 | processedQuery = processedQuery.replace("#line", codeLine); 38 | } 39 | } 40 | 41 | // Process #func macro 42 | if (processedQuery.contains("#func")) { 43 | String functionCode = getCurrentFunction(plugin, viewType, monitor); 44 | if (functionCode != null) { 45 | processedQuery = processedQuery.replace("#func", functionCode); 46 | } 47 | } 48 | 49 | // Process #addr macro 50 | if (processedQuery.contains("#addr")) { 51 | String addressString = getCurrentAddress(plugin); 52 | processedQuery = processedQuery.replace("#addr", addressString); 53 | } 54 | 55 | // Process #range macros 56 | processedQuery = processRangeMacros(processedQuery, plugin); 57 | 58 | } catch (Exception e) { 59 | throw new RuntimeException("Failed to process macros: " + e.getMessage(), e); 60 | } 61 | 62 | return processedQuery; 63 | } 64 | 65 | /** 66 | * Append RAG context to the query based on similarity search. 
67 | * @param query The original query 68 | * @return Query with RAG context prepended 69 | * @throws Exception if RAG search fails 70 | */ 71 | public static String appendRAGContext(String query) throws Exception { 72 | List results = RAGEngine.hybridSearch(query, MAX_SEARCH_RESULTS); 73 | if (results.isEmpty()) { 74 | return query; 75 | } 76 | 77 | StringBuilder contextBuilder = new StringBuilder(); 78 | contextBuilder.append("\n"); 79 | 80 | for (SearchResult result : results) { 81 | contextBuilder.append("\n"); 82 | contextBuilder.append("
").append(result.getFilename()).append("\n"); 83 | contextBuilder.append("
").append(result.getChunkId()).append("\n"); 84 | contextBuilder.append("
").append(result.getScore()).append("\n"); 85 | contextBuilder.append("
\n").append(result.getSnippet()).append("\n\n"); 86 | contextBuilder.append("\n
\n\n"); 87 | } 88 | 89 | contextBuilder.append("\n
\n"); 90 | return contextBuilder.toString() + query; 91 | } 92 | 93 | /** 94 | * Get the current line based on view type. 95 | */ 96 | private static String getCurrentLine(GhidrAssistPlugin plugin, CodeViewType viewType, TaskMonitor monitor) { 97 | Address currentAddress = plugin.getCurrentAddress(); 98 | if (currentAddress == null) { 99 | return "No current address available."; 100 | } 101 | 102 | if (viewType == CodeViewType.IS_DECOMPILER) { 103 | return CodeUtils.getLineCode(currentAddress, monitor, plugin.getCurrentProgram()); 104 | } else if (viewType == CodeViewType.IS_DISASSEMBLER) { 105 | return CodeUtils.getLineDisassembly(currentAddress, plugin.getCurrentProgram()); 106 | } 107 | 108 | return "Unknown code view type."; 109 | } 110 | 111 | /** 112 | * Get the current function based on view type. 113 | */ 114 | private static String getCurrentFunction(GhidrAssistPlugin plugin, CodeViewType viewType, TaskMonitor monitor) { 115 | Function currentFunction = plugin.getCurrentFunction(); 116 | if (currentFunction == null) { 117 | return "No function at current location."; 118 | } 119 | 120 | if (viewType == CodeViewType.IS_DECOMPILER) { 121 | return CodeUtils.getFunctionCode(currentFunction, monitor); 122 | } else if (viewType == CodeViewType.IS_DISASSEMBLER) { 123 | return CodeUtils.getFunctionDisassembly(currentFunction); 124 | } 125 | 126 | return "Unknown code view type."; 127 | } 128 | 129 | /** 130 | * Get the current address as a string. 131 | */ 132 | private static String getCurrentAddress(GhidrAssistPlugin plugin) { 133 | Address currentAddress = plugin.getCurrentAddress(); 134 | return (currentAddress != null) ? currentAddress.toString() : "No address available."; 135 | } 136 | 137 | /** 138 | * Process all #range macros in the query. 
139 | */ 140 | private static String processRangeMacros(String query, GhidrAssistPlugin plugin) { 141 | Matcher matcher = RANGE_PATTERN.matcher(query); 142 | while (matcher.find()) { 143 | String startStr = matcher.group(1); 144 | String endStr = matcher.group(2); 145 | String rangeData = getRangeData(startStr.trim(), endStr.trim(), plugin); 146 | query = query.replace(matcher.group(0), rangeData); 147 | matcher = RANGE_PATTERN.matcher(query); 148 | } 149 | return query; 150 | } 151 | 152 | /** 153 | * Get the data for a specific address range. 154 | */ 155 | private static String getRangeData(String startStr, String endStr, GhidrAssistPlugin plugin) { 156 | try { 157 | Program program = plugin.getCurrentProgram(); 158 | if (program == null) { 159 | return "No program loaded."; 160 | } 161 | 162 | AddressFactory addressFactory = program.getAddressFactory(); 163 | Address startAddr = addressFactory.getAddress(startStr); 164 | Address endAddr = addressFactory.getAddress(endStr); 165 | 166 | if (startAddr == null || endAddr == null) { 167 | return "Invalid addresses."; 168 | } 169 | 170 | // Get the bytes in the range 171 | long size = endAddr.getOffset() - startAddr.getOffset() + 1; 172 | if (size <= 0 || size > 1024) { // Limit to reasonable size 173 | return "Invalid range size."; 174 | } 175 | 176 | byte[] bytes = new byte[(int) size]; 177 | program.getMemory().getBytes(startAddr, bytes); 178 | 179 | // Convert bytes to hex string 180 | StringBuilder sb = new StringBuilder(); 181 | for (byte b : bytes) { 182 | sb.append(String.format("%02X ", b)); 183 | } 184 | return sb.toString().trim(); 185 | 186 | } catch (Exception e) { 187 | return "Failed to get range data: " + e.getMessage(); 188 | } 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /src/main/java/ghidrassist/core/ResponseProcessor.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.core; 2 | 3 | 
import java.util.regex.Pattern; 4 | 5 | /** 6 | * Handles response processing including streaming filters and thinking block removal. 7 | * Focused solely on text processing and filtering logic. 8 | */ 9 | public class ResponseProcessor { 10 | 11 | // Pattern for matching complete blocks and opening/closing tags 12 | private static final Pattern COMPLETE_THINK_PATTERN = Pattern.compile(".*?", Pattern.DOTALL); 13 | 14 | /** 15 | * Create a new streaming filter for processing chunks 16 | */ 17 | public StreamingResponseFilter createStreamingFilter() { 18 | return new StreamingResponseFilter(); 19 | } 20 | 21 | /** 22 | * Filter thinking blocks from a complete response 23 | */ 24 | public String filterThinkBlocks(String response) { 25 | if (response == null) { 26 | return null; 27 | } 28 | return COMPLETE_THINK_PATTERN.matcher(response).replaceAll("").trim(); 29 | } 30 | 31 | /** 32 | * Streaming filter that processes chunks of text and removes thinking blocks in real-time 33 | */ 34 | public static class StreamingResponseFilter { 35 | private StringBuilder buffer = new StringBuilder(); 36 | private StringBuilder visibleBuffer = new StringBuilder(); 37 | private boolean insideThinkBlock = false; 38 | 39 | /** 40 | * Process a chunk of streaming text, filtering out thinking blocks 41 | * @param chunk The text chunk to process 42 | * @return The filtered content that should be displayed, or null if nothing to display 43 | */ 44 | public String processChunk(String chunk) { 45 | if (chunk == null) { 46 | return null; 47 | } 48 | 49 | buffer.append(chunk); 50 | 51 | // Process the buffer until we can't anymore 52 | String currentBuffer = buffer.toString(); 53 | int lastSafeIndex = 0; 54 | 55 | for (int i = 0; i < currentBuffer.length(); i++) { 56 | // Look for start tag 57 | if (!insideThinkBlock && currentBuffer.startsWith("", i)) { 58 | // Append everything up to this point to visible buffer 59 | visibleBuffer.append(currentBuffer.substring(lastSafeIndex, i)); 60 | 
insideThinkBlock = true; 61 | lastSafeIndex = i + 7; // Skip "" 62 | i += 6; // Move past "" 63 | } 64 | // Look for end tag 65 | else if (insideThinkBlock && currentBuffer.startsWith("", i)) { 66 | insideThinkBlock = false; 67 | lastSafeIndex = i + 8; // Skip "" 68 | i += 7; // Move past "" 69 | } 70 | } 71 | 72 | // If we're not in a think block, append any remaining safe content 73 | if (!insideThinkBlock) { 74 | visibleBuffer.append(currentBuffer.substring(lastSafeIndex)); 75 | // Clear processed content from buffer 76 | buffer.setLength(0); 77 | } else { 78 | // Keep everything from lastSafeIndex in buffer 79 | buffer = new StringBuilder(currentBuffer.substring(lastSafeIndex)); 80 | } 81 | 82 | return visibleBuffer.toString(); 83 | } 84 | 85 | /** 86 | * Get the complete filtered content processed so far 87 | */ 88 | public String getFilteredContent() { 89 | return visibleBuffer.toString(); 90 | } 91 | 92 | /** 93 | * Reset the filter state for reuse 94 | */ 95 | public void reset() { 96 | buffer.setLength(0); 97 | visibleBuffer.setLength(0); 98 | insideThinkBlock = false; 99 | } 100 | 101 | /** 102 | * Check if currently inside a thinking block 103 | */ 104 | public boolean isInsideThinkBlock() { 105 | return insideThinkBlock; 106 | } 107 | } 108 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/core/UIState.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.core; 2 | 3 | public class UIState { 4 | private volatile boolean isQueryRunning; 5 | private int activeRunners; 6 | 7 | public UIState() { 8 | this.isQueryRunning = false; 9 | this.activeRunners = 0; 10 | } 11 | 12 | public synchronized boolean isQueryRunning() { 13 | return isQueryRunning; 14 | } 15 | 16 | public synchronized void setQueryRunning(boolean running) { 17 | this.isQueryRunning = running; 18 | } 19 | 20 | public synchronized void incrementRunners() { 21 | activeRunners++; 
22 | } 23 | 24 | public synchronized void decrementRunners() { 25 | activeRunners--; 26 | if (activeRunners <= 0) { 27 | activeRunners = 0; 28 | isQueryRunning = false; 29 | } 30 | } 31 | 32 | public synchronized int getActiveRunners() { 33 | return activeRunners; 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/ghidrassist/mcp2/protocol/MCPMessage.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.mcp2.protocol; 2 | 3 | import com.google.gson.JsonObject; 4 | import com.google.gson.JsonElement; 5 | 6 | /** 7 | * Base class for all MCP JSON-RPC 2.0 messages. 8 | * Implements the core JSON-RPC 2.0 message structure. 9 | */ 10 | public abstract class MCPMessage { 11 | 12 | public static final String JSONRPC_VERSION = "2.0"; 13 | 14 | protected String jsonrpc = JSONRPC_VERSION; 15 | protected String method; 16 | protected JsonObject params; 17 | 18 | public MCPMessage(String method) { 19 | this.method = method; 20 | this.params = new JsonObject(); 21 | } 22 | 23 | public MCPMessage(String method, JsonObject params) { 24 | this.method = method; 25 | this.params = params != null ? 
params : new JsonObject(); 26 | } 27 | 28 | public String getJsonrpc() { 29 | return jsonrpc; 30 | } 31 | 32 | public String getMethod() { 33 | return method; 34 | } 35 | 36 | public JsonObject getParams() { 37 | return params; 38 | } 39 | 40 | public void setParam(String key, String value) { 41 | params.addProperty(key, value); 42 | } 43 | 44 | public void setParam(String key, JsonElement value) { 45 | params.add(key, value); 46 | } 47 | 48 | public void setParam(String key, Number value) { 49 | params.addProperty(key, value); 50 | } 51 | 52 | public void setParam(String key, Boolean value) { 53 | params.addProperty(key, value); 54 | } 55 | 56 | /** 57 | * Convert message to JSON string for transmission 58 | */ 59 | public abstract String toJson(); 60 | 61 | /** 62 | * Validate message format 63 | */ 64 | public boolean isValid() { 65 | return JSONRPC_VERSION.equals(jsonrpc) && method != null && !method.isEmpty(); 66 | } 67 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/mcp2/protocol/MCPRequest.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.mcp2.protocol; 2 | 3 | import com.google.gson.Gson; 4 | import com.google.gson.JsonObject; 5 | import com.google.gson.JsonElement; 6 | 7 | /** 8 | * Represents an MCP JSON-RPC 2.0 request message. 9 | * Used for tool discovery (tools/list) and tool execution (tools/call). 
10 | */ 11 | public class MCPRequest extends MCPMessage { 12 | 13 | private final Object id; 14 | 15 | public MCPRequest(Object id, String method) { 16 | super(method); 17 | this.id = id; 18 | } 19 | 20 | public MCPRequest(Object id, String method, JsonObject params) { 21 | super(method, params); 22 | this.id = id; 23 | } 24 | 25 | public Object getId() { 26 | return id; 27 | } 28 | 29 | @Override 30 | public String toJson() { 31 | JsonObject json = new JsonObject(); 32 | json.addProperty("jsonrpc", jsonrpc); 33 | 34 | // Only add ID if this is not a notification 35 | if (!"notification".equals(id)) { 36 | json.addProperty("id", id.toString()); 37 | } 38 | 39 | json.addProperty("method", method); 40 | 41 | if (params != null && params.size() > 0) { 42 | json.add("params", params); 43 | } 44 | 45 | return new Gson().toJson(json); 46 | } 47 | 48 | @Override 49 | public boolean isValid() { 50 | return super.isValid() && id != null; 51 | } 52 | 53 | /** 54 | * Create a tools/list request 55 | */ 56 | public static MCPRequest createToolsListRequest(Object id) { 57 | return createToolsListRequest(id, null); 58 | } 59 | 60 | /** 61 | * Create a tools/list request with cursor for pagination 62 | */ 63 | public static MCPRequest createToolsListRequest(Object id, String cursor) { 64 | MCPRequest request = new MCPRequest(id, "tools/list"); 65 | if (cursor != null) { 66 | request.setParam("cursor", cursor); 67 | } 68 | return request; 69 | } 70 | 71 | /** 72 | * Create a tools/call request 73 | */ 74 | public static MCPRequest createToolsCallRequest(Object id, String toolName, JsonObject arguments) { 75 | MCPRequest request = new MCPRequest(id, "tools/call"); 76 | request.setParam("name", toolName); 77 | if (arguments != null) { 78 | request.setParam("arguments", arguments); 79 | } 80 | return request; 81 | } 82 | 83 | /** 84 | * Create an initialize request for protocol handshake 85 | */ 86 | public static MCPRequest createInitializeRequest(Object id, String protocolVersion, 
String clientInfo) { 87 | MCPRequest request = new MCPRequest(id, "initialize"); 88 | request.setParam("protocolVersion", protocolVersion); 89 | 90 | JsonObject clientInfoObj = new JsonObject(); 91 | clientInfoObj.addProperty("name", "GhidrAssist"); 92 | clientInfoObj.addProperty("version", "1.0.0"); 93 | if (clientInfo != null) { 94 | clientInfoObj.addProperty("description", clientInfo); 95 | } 96 | request.setParam("clientInfo", clientInfoObj); 97 | 98 | JsonObject capabilities = new JsonObject(); 99 | capabilities.addProperty("tools", true); 100 | request.setParam("capabilities", capabilities); 101 | 102 | return request; 103 | } 104 | 105 | /** 106 | * Create an initialized notification (sent after initialize response) 107 | */ 108 | public static MCPRequest createInitializedNotification() { 109 | // Notifications don't have IDs in JSON-RPC 2.0 110 | // The bridge expects "notifications/initialized" method 111 | return new MCPRequest("notification", "notifications/initialized"); 112 | } 113 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/mcp2/protocol/MCPResponse.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.mcp2.protocol; 2 | 3 | import com.google.gson.Gson; 4 | import com.google.gson.JsonObject; 5 | import com.google.gson.JsonElement; 6 | import com.google.gson.JsonArray; 7 | 8 | /** 9 | * Represents an MCP JSON-RPC 2.0 response message. 10 | * Can contain either a result (success) or an error (failure). 
11 | */ 12 | public class MCPResponse { 13 | 14 | private String jsonrpc = MCPMessage.JSONRPC_VERSION; 15 | private Object id; 16 | private JsonElement result; 17 | private MCPError error; 18 | 19 | public MCPResponse(Object id) { 20 | this.id = id; 21 | } 22 | 23 | public String getJsonrpc() { 24 | return jsonrpc; 25 | } 26 | 27 | public Object getId() { 28 | return id; 29 | } 30 | 31 | public JsonElement getResult() { 32 | return result; 33 | } 34 | 35 | public void setResult(JsonElement result) { 36 | this.result = result; 37 | this.error = null; // Clear error if setting result 38 | } 39 | 40 | public MCPError getError() { 41 | return error; 42 | } 43 | 44 | public void setError(MCPError error) { 45 | this.error = error; 46 | this.result = null; // Clear result if setting error 47 | } 48 | 49 | public void setError(int code, String message, JsonElement data) { 50 | setError(new MCPError(code, message, data)); 51 | } 52 | 53 | public boolean isSuccess() { 54 | return result != null && error == null; 55 | } 56 | 57 | public boolean isError() { 58 | return error != null; 59 | } 60 | 61 | public String toJson() { 62 | JsonObject json = new JsonObject(); 63 | json.addProperty("jsonrpc", jsonrpc); 64 | json.addProperty("id", id.toString()); 65 | 66 | if (result != null) { 67 | json.add("result", result); 68 | } else if (error != null) { 69 | json.add("error", error.toJson()); 70 | } 71 | 72 | return new Gson().toJson(json); 73 | } 74 | 75 | /** 76 | * Parse response from JSON string 77 | */ 78 | public static MCPResponse fromJson(String jsonStr) { 79 | Gson gson = new Gson(); 80 | JsonObject json = gson.fromJson(jsonStr, JsonObject.class); 81 | 82 | Object id = json.has("id") ? 
json.get("id").getAsString() : null; 83 | MCPResponse response = new MCPResponse(id); 84 | 85 | if (json.has("result")) { 86 | response.setResult(json.get("result")); 87 | } else if (json.has("error")) { 88 | JsonObject errorObj = json.getAsJsonObject("error"); 89 | int code = errorObj.get("code").getAsInt(); 90 | String message = errorObj.get("message").getAsString(); 91 | JsonElement data = errorObj.has("data") ? errorObj.get("data") : null; 92 | response.setError(code, message, data); 93 | } 94 | 95 | return response; 96 | } 97 | 98 | /** 99 | * Extract tools array from tools/list response 100 | */ 101 | public JsonArray getToolsArray() { 102 | if (isSuccess() && result.isJsonObject()) { 103 | JsonObject resultObj = result.getAsJsonObject(); 104 | if (resultObj.has("tools") && resultObj.get("tools").isJsonArray()) { 105 | return resultObj.getAsJsonArray("tools"); 106 | } 107 | } 108 | return new JsonArray(); 109 | } 110 | 111 | /** 112 | * Extract tool call result content 113 | */ 114 | public JsonElement getToolCallResult() { 115 | if (isSuccess() && result.isJsonObject()) { 116 | JsonObject resultObj = result.getAsJsonObject(); 117 | if (resultObj.has("content")) { 118 | return resultObj.get("content"); 119 | } 120 | } 121 | return result; 122 | } 123 | 124 | /** 125 | * Get cursor for pagination in tools/list response 126 | */ 127 | public String getNextCursor() { 128 | if (isSuccess() && result.isJsonObject()) { 129 | JsonObject resultObj = result.getAsJsonObject(); 130 | if (resultObj.has("nextCursor")) { 131 | return resultObj.get("nextCursor").getAsString(); 132 | } 133 | } 134 | return null; 135 | } 136 | 137 | /** 138 | * Inner class representing JSON-RPC 2.0 error object 139 | */ 140 | public static class MCPError { 141 | private int code; 142 | private String message; 143 | private JsonElement data; 144 | 145 | public MCPError(int code, String message, JsonElement data) { 146 | this.code = code; 147 | this.message = message; 148 | this.data = data; 
149 | } 150 | 151 | public int getCode() { 152 | return code; 153 | } 154 | 155 | public String getMessage() { 156 | return message; 157 | } 158 | 159 | public JsonElement getData() { 160 | return data; 161 | } 162 | 163 | public JsonObject toJson() { 164 | JsonObject json = new JsonObject(); 165 | json.addProperty("code", code); 166 | json.addProperty("message", message); 167 | if (data != null) { 168 | json.add("data", data); 169 | } 170 | return json; 171 | } 172 | 173 | @Override 174 | public String toString() { 175 | return String.format("MCPError{code=%d, message='%s'}", code, message); 176 | } 177 | } 178 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/mcp2/server/MCPServerConfig.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.mcp2.server; 2 | 3 | import com.google.gson.Gson; 4 | import com.google.gson.JsonObject; 5 | 6 | /** 7 | * Configuration for an MCP server connection. 8 | * Stores all necessary information to connect to and manage an MCP server. 
9 | */ 10 | public class MCPServerConfig { 11 | 12 | public enum TransportType { 13 | SSE("Server-Sent Events"), 14 | STDIO("Standard I/O"); 15 | 16 | private final String displayName; 17 | 18 | TransportType(String displayName) { 19 | this.displayName = displayName; 20 | } 21 | 22 | public String getDisplayName() { 23 | return displayName; 24 | } 25 | } 26 | 27 | private String name; // Display name (e.g., "GhidraMCP Local") 28 | private String url; // Server URL (e.g., "http://localhost:8081") 29 | private TransportType transport; // Transport mechanism 30 | private int connectionTimeout; // Connection timeout in seconds 31 | private int requestTimeout; // Request timeout in seconds 32 | private boolean enabled; // Whether this server is active 33 | private String description; // Optional description 34 | 35 | // Default constructor for JSON deserialization 36 | public MCPServerConfig() { 37 | this.transport = TransportType.SSE; 38 | this.connectionTimeout = 5; // Reduced from 10 seconds 39 | this.requestTimeout = 15; // Reduced from 30 seconds 40 | this.enabled = true; 41 | } 42 | 43 | public MCPServerConfig(String name, String url) { 44 | this(); 45 | this.name = name; 46 | this.url = url; 47 | } 48 | 49 | public MCPServerConfig(String name, String url, TransportType transport) { 50 | this(name, url); 51 | this.transport = transport; 52 | } 53 | 54 | public MCPServerConfig(String name, String url, TransportType transport, boolean enabled) { 55 | this(name, url, transport); 56 | this.enabled = enabled; 57 | } 58 | 59 | // Getters and setters 60 | 61 | public String getName() { 62 | return name; 63 | } 64 | 65 | public void setName(String name) { 66 | this.name = name; 67 | } 68 | 69 | public String getUrl() { 70 | return url; 71 | } 72 | 73 | public void setUrl(String url) { 74 | this.url = url; 75 | } 76 | 77 | public TransportType getTransport() { 78 | return transport; 79 | } 80 | 81 | public void setTransport(TransportType transport) { 82 | this.transport = 
transport; 83 | } 84 | 85 | public int getConnectionTimeout() { 86 | return connectionTimeout; 87 | } 88 | 89 | public void setConnectionTimeout(int connectionTimeout) { 90 | this.connectionTimeout = connectionTimeout; 91 | } 92 | 93 | public int getRequestTimeout() { 94 | return requestTimeout; 95 | } 96 | 97 | public void setRequestTimeout(int requestTimeout) { 98 | this.requestTimeout = requestTimeout; 99 | } 100 | 101 | public boolean isEnabled() { 102 | return enabled; 103 | } 104 | 105 | public void setEnabled(boolean enabled) { 106 | this.enabled = enabled; 107 | } 108 | 109 | public String getDescription() { 110 | return description; 111 | } 112 | 113 | public void setDescription(String description) { 114 | this.description = description; 115 | } 116 | 117 | /** 118 | * Get the base URL for HTTP connections 119 | */ 120 | public String getBaseUrl() { 121 | if (url == null) return null; 122 | 123 | // Ensure URL has protocol 124 | if (!url.startsWith("http://") && !url.startsWith("https://")) { 125 | return "http://" + url; 126 | } 127 | return url; 128 | } 129 | 130 | /** 131 | * Get the host from the URL 132 | */ 133 | public String getHost() { 134 | try { 135 | java.net.URL urlObj = new java.net.URL(getBaseUrl()); 136 | return urlObj.getHost(); 137 | } catch (Exception e) { 138 | return "localhost"; 139 | } 140 | } 141 | 142 | /** 143 | * Get the port from the URL 144 | */ 145 | public int getPort() { 146 | try { 147 | java.net.URL urlObj = new java.net.URL(getBaseUrl()); 148 | int port = urlObj.getPort(); 149 | return port != -1 ? port : (urlObj.getProtocol().equals("https") ? 
443 : 80); 150 | } catch (Exception e) { 151 | return 8081; // Default MCP port 152 | } 153 | } 154 | 155 | /** 156 | * Validate configuration 157 | */ 158 | public boolean isValid() { 159 | return name != null && !name.trim().isEmpty() && 160 | url != null && !url.trim().isEmpty() && 161 | transport != null && 162 | connectionTimeout > 0 && 163 | requestTimeout > 0; 164 | } 165 | 166 | /** 167 | * Create a copy of this configuration 168 | */ 169 | public MCPServerConfig copy() { 170 | MCPServerConfig copy = new MCPServerConfig(name, url, transport); 171 | copy.setConnectionTimeout(connectionTimeout); 172 | copy.setRequestTimeout(requestTimeout); 173 | copy.setEnabled(enabled); 174 | copy.setDescription(description); 175 | return copy; 176 | } 177 | 178 | /** 179 | * Serialize to JSON 180 | */ 181 | public String toJson() { 182 | return new Gson().toJson(this); 183 | } 184 | 185 | /** 186 | * Deserialize from JSON 187 | */ 188 | public static MCPServerConfig fromJson(String json) { 189 | return new Gson().fromJson(json, MCPServerConfig.class); 190 | } 191 | 192 | @Override 193 | public String toString() { 194 | return String.format("%s (%s) - %s", name, transport.getDisplayName(), 195 | enabled ? "Enabled" : "Disabled"); 196 | } 197 | 198 | @Override 199 | public boolean equals(Object obj) { 200 | if (this == obj) return true; 201 | if (obj == null || getClass() != obj.getClass()) return false; 202 | 203 | MCPServerConfig that = (MCPServerConfig) obj; 204 | return name != null ? name.equals(that.name) : that.name == null; 205 | } 206 | 207 | @Override 208 | public int hashCode() { 209 | return name != null ? 
name.hashCode() : 0; 210 | } 211 | 212 | /** 213 | * Create default GhidraMCP configuration 214 | */ 215 | public static MCPServerConfig createGhidraMCPDefault() { 216 | MCPServerConfig config = new MCPServerConfig("GhidraMCP Local", "http://localhost:8081"); 217 | config.setDescription("Local GhidraMCP server instance"); 218 | config.setTransport(TransportType.SSE); 219 | return config; 220 | } 221 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/mcp2/tools/MCPTool.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.mcp2.tools; 2 | 3 | import com.google.gson.JsonObject; 4 | import com.google.gson.JsonElement; 5 | import java.util.Map; 6 | import java.util.HashMap; 7 | 8 | /** 9 | * Represents an MCP tool discovered from a server. 10 | * This is server-agnostic and follows the MCP specification. 11 | */ 12 | public class MCPTool { 13 | 14 | private final String name; 15 | private final String description; 16 | private final JsonObject inputSchema; 17 | private final String serverName; // Which server provides this tool 18 | 19 | public MCPTool(String name, String description, JsonObject inputSchema, String serverName) { 20 | this.name = name; 21 | this.description = description; 22 | this.inputSchema = inputSchema; 23 | this.serverName = serverName; 24 | } 25 | 26 | public String getName() { 27 | return name; 28 | } 29 | 30 | public String getDescription() { 31 | return description; 32 | } 33 | 34 | public JsonObject getInputSchema() { 35 | return inputSchema; 36 | } 37 | 38 | public String getServerName() { 39 | return serverName; 40 | } 41 | 42 | /** 43 | * Check if this tool has input parameters 44 | */ 45 | public boolean hasInputSchema() { 46 | return inputSchema != null && inputSchema.size() > 0; 47 | } 48 | 49 | /** 50 | * Convert to function schema for LLM function calling 51 | * This follows the OpenAI function calling format 52 | */ 53 
| public Map toFunctionSchema() { 54 | Map function = new HashMap<>(); 55 | function.put("type", "function"); 56 | 57 | Map functionDef = new HashMap<>(); 58 | functionDef.put("name", name); 59 | functionDef.put("description", description); 60 | 61 | if (inputSchema != null) { 62 | // Convert JsonObject to Map for compatibility 63 | functionDef.put("parameters", jsonObjectToMap(inputSchema)); 64 | } else { 65 | // Empty parameters schema 66 | Map emptyParams = new HashMap<>(); 67 | emptyParams.put("type", "object"); 68 | emptyParams.put("properties", new HashMap<>()); 69 | functionDef.put("parameters", emptyParams); 70 | } 71 | 72 | function.put("function", functionDef); 73 | return function; 74 | } 75 | 76 | /** 77 | * Create MCPTool from MCP tools/list response 78 | */ 79 | public static MCPTool fromToolsListEntry(JsonObject toolEntry, String serverName) { 80 | String name = toolEntry.has("name") ? toolEntry.get("name").getAsString() : null; 81 | String description = toolEntry.has("description") ? toolEntry.get("description").getAsString() : ""; 82 | JsonObject inputSchema = toolEntry.has("inputSchema") ? 
83 | toolEntry.getAsJsonObject("inputSchema") : null; 84 | 85 | return new MCPTool(name, description, inputSchema, serverName); 86 | } 87 | 88 | /** 89 | * Helper method to convert JsonObject to Map recursively 90 | */ 91 | private Map jsonObjectToMap(JsonObject jsonObject) { 92 | Map map = new HashMap<>(); 93 | 94 | for (String key : jsonObject.keySet()) { 95 | Object value = jsonElementToObject(jsonObject.get(key)); 96 | map.put(key, value); 97 | } 98 | 99 | return map; 100 | } 101 | 102 | /** 103 | * Helper method to convert JsonElement to Java object 104 | */ 105 | private Object jsonElementToObject(JsonElement element) { 106 | if (element.isJsonPrimitive()) { 107 | if (element.getAsJsonPrimitive().isString()) { 108 | return element.getAsString(); 109 | } else if (element.getAsJsonPrimitive().isNumber()) { 110 | return element.getAsNumber(); 111 | } else if (element.getAsJsonPrimitive().isBoolean()) { 112 | return element.getAsBoolean(); 113 | } 114 | } else if (element.isJsonObject()) { 115 | return jsonObjectToMap(element.getAsJsonObject()); 116 | } else if (element.isJsonArray()) { 117 | com.google.gson.JsonArray array = element.getAsJsonArray(); 118 | java.util.List list = new java.util.ArrayList<>(); 119 | for (int i = 0; i < array.size(); i++) { 120 | list.add(jsonElementToObject(array.get(i))); 121 | } 122 | return list; 123 | } 124 | return null; 125 | } 126 | 127 | @Override 128 | public String toString() { 129 | return String.format("MCPTool{name='%s', server='%s', description='%s'}", 130 | name, serverName, description); 131 | } 132 | 133 | @Override 134 | public boolean equals(Object obj) { 135 | if (this == obj) return true; 136 | if (obj == null || getClass() != obj.getClass()) return false; 137 | 138 | MCPTool mcpTool = (MCPTool) obj; 139 | return name.equals(mcpTool.name) && serverName.equals(mcpTool.serverName); 140 | } 141 | 142 | @Override 143 | public int hashCode() { 144 | return java.util.Objects.hash(name, serverName); 145 | } 146 | 147 | 
/** 148 | * Get a display name that includes the server 149 | */ 150 | public String getDisplayName() { 151 | return String.format("%s (%s)", name, serverName); 152 | } 153 | 154 | /** 155 | * Check if tool name matches (case-insensitive) 156 | */ 157 | public boolean matchesName(String toolName) { 158 | return name != null && name.equalsIgnoreCase(toolName); 159 | } 160 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/mcp2/tools/MCPToolResult.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.mcp2.tools; 2 | 3 | /** 4 | * Result of an MCP tool execution. 5 | * Similar to the original MCPToolResult but for MCP 2.0. 6 | */ 7 | public class MCPToolResult { 8 | 9 | private final boolean success; 10 | private final String content; 11 | private final String error; 12 | 13 | public MCPToolResult(boolean success, String content, String error) { 14 | this.success = success; 15 | this.content = content; 16 | this.error = error; 17 | } 18 | 19 | /** 20 | * Create successful result 21 | */ 22 | public static MCPToolResult success(String content) { 23 | return new MCPToolResult(true, content, null); 24 | } 25 | 26 | /** 27 | * Create error result 28 | */ 29 | public static MCPToolResult error(String error) { 30 | return new MCPToolResult(false, null, error); 31 | } 32 | 33 | public boolean isSuccess() { 34 | return success; 35 | } 36 | 37 | public String getContent() { 38 | return content; 39 | } 40 | 41 | public String getError() { 42 | return error; 43 | } 44 | 45 | /** 46 | * Get result as string for display 47 | */ 48 | public String getResultText() { 49 | if (success) { 50 | return content != null ? content : ""; 51 | } else { 52 | return "Error: " + (error != null ? 
error : "Unknown error"); 53 | } 54 | } 55 | 56 | @Override 57 | public String toString() { 58 | return String.format("MCPToolResult{success=%s, content='%s', error='%s'}", 59 | success, content, error); 60 | } 61 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/mcp2/transport/MCPTransport.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.mcp2.transport; 2 | 3 | import ghidrassist.mcp2.protocol.MCPRequest; 4 | import ghidrassist.mcp2.protocol.MCPResponse; 5 | import java.util.concurrent.CompletableFuture; 6 | 7 | /** 8 | * Abstract transport layer for MCP communication. 9 | * Supports different transport mechanisms (SSE, stdio, etc.). 10 | */ 11 | public abstract class MCPTransport { 12 | 13 | protected boolean connected = false; 14 | protected MCPTransportHandler handler; 15 | 16 | /** 17 | * Interface for handling transport events 18 | */ 19 | public interface MCPTransportHandler { 20 | void onConnected(); 21 | void onDisconnected(); 22 | void onResponse(MCPResponse response); 23 | void onError(Throwable error); 24 | } 25 | 26 | /** 27 | * Set the transport event handler 28 | */ 29 | public void setHandler(MCPTransportHandler handler) { 30 | this.handler = handler; 31 | } 32 | 33 | /** 34 | * Connect to the MCP server 35 | */ 36 | public abstract CompletableFuture connect(); 37 | 38 | /** 39 | * Disconnect from the MCP server 40 | */ 41 | public abstract CompletableFuture disconnect(); 42 | 43 | /** 44 | * Send a request to the server 45 | */ 46 | public abstract CompletableFuture sendRequest(MCPRequest request); 47 | 48 | /** 49 | * Send a notification to the server (no response expected) 50 | */ 51 | public abstract CompletableFuture sendNotification(MCPRequest notification); 52 | 53 | /** 54 | * Check if transport is connected 55 | */ 56 | public boolean isConnected() { 57 | return connected; 58 | } 59 | 60 | /** 61 | * Get transport type 
name 62 | */ 63 | public abstract String getTransportType(); 64 | 65 | /** 66 | * Get connection info for debugging 67 | */ 68 | public abstract String getConnectionInfo(); 69 | 70 | /** 71 | * Notify handler of connection 72 | */ 73 | protected void notifyConnected() { 74 | connected = true; 75 | if (handler != null) { 76 | handler.onConnected(); 77 | } 78 | } 79 | 80 | /** 81 | * Notify handler of disconnection 82 | */ 83 | protected void notifyDisconnected() { 84 | connected = false; 85 | if (handler != null) { 86 | handler.onDisconnected(); 87 | } 88 | } 89 | 90 | /** 91 | * Notify handler of response 92 | */ 93 | protected void notifyResponse(MCPResponse response) { 94 | if (handler != null) { 95 | handler.onResponse(response); 96 | } 97 | } 98 | 99 | /** 100 | * Notify handler of error 101 | */ 102 | protected void notifyError(Throwable error) { 103 | if (handler != null) { 104 | handler.onError(error); 105 | } 106 | } 107 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/resources/GhidrAssistIcons.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.resources; 2 | 3 | import javax.swing.ImageIcon; 4 | import resources.ResourceManager; 5 | 6 | public class GhidrAssistIcons { 7 | public static final ImageIcon ROBOT_ICON = ResourceManager.loadImage("images/robot32.png"); 8 | } 9 | -------------------------------------------------------------------------------- /src/main/java/ghidrassist/services/AnalysisDataService.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.services; 2 | 3 | import ghidrassist.AnalysisDB; 4 | import ghidrassist.GhidrAssistPlugin; 5 | import ghidrassist.LlmApi; 6 | import ghidrassist.apiprovider.APIProviderConfig; 7 | 8 | /** 9 | * Service for managing analysis context and program-specific data. 
10 | * Responsible for context storage, retrieval, and management operations. 11 | */ 12 | public class AnalysisDataService { 13 | 14 | private final GhidrAssistPlugin plugin; 15 | private final AnalysisDB analysisDB; 16 | 17 | public AnalysisDataService(GhidrAssistPlugin plugin) { 18 | this.plugin = plugin; 19 | this.analysisDB = new AnalysisDB(); 20 | } 21 | 22 | /** 23 | * Save context for the current program 24 | */ 25 | public void saveContext(String context) { 26 | if (plugin.getCurrentProgram() == null) { 27 | throw new IllegalStateException("No active program to save context for."); 28 | } 29 | 30 | String programHash = plugin.getCurrentProgram().getExecutableSHA256(); 31 | analysisDB.upsertContext(programHash, context); 32 | } 33 | 34 | /** 35 | * Get context for the current program 36 | */ 37 | public String getContext() { 38 | if (plugin.getCurrentProgram() == null) { 39 | return getDefaultContext(); 40 | } 41 | 42 | String programHash = plugin.getCurrentProgram().getExecutableSHA256(); 43 | String context = analysisDB.getContext(programHash); 44 | 45 | if (context == null) { 46 | return getDefaultContext(); 47 | } 48 | 49 | return context; 50 | } 51 | 52 | /** 53 | * Revert context to default for the current program 54 | */ 55 | public String revertToDefaultContext() { 56 | String defaultContext = getDefaultContext(); 57 | 58 | if (plugin.getCurrentProgram() != null) { 59 | // Clear custom context, will fall back to default 60 | String programHash = plugin.getCurrentProgram().getExecutableSHA256(); 61 | analysisDB.upsertContext(programHash, null); 62 | } 63 | 64 | return defaultContext; 65 | } 66 | 67 | /** 68 | * Check if current program has custom context 69 | */ 70 | public boolean hasCustomContext() { 71 | if (plugin.getCurrentProgram() == null) { 72 | return false; 73 | } 74 | 75 | String programHash = plugin.getCurrentProgram().getExecutableSHA256(); 76 | String context = analysisDB.getContext(programHash); 77 | return context != null && 
!context.equals(getDefaultContext()); 78 | } 79 | 80 | /** 81 | * Get default system context 82 | */ 83 | private String getDefaultContext() { 84 | APIProviderConfig config = GhidrAssistPlugin.getCurrentProviderConfig(); 85 | if (config == null) { 86 | return "You are a professional software reverse engineer."; // Fallback 87 | } 88 | 89 | LlmApi llmApi = new LlmApi(config, plugin); 90 | return llmApi.getSystemPrompt(); 91 | } 92 | 93 | /** 94 | * Get context statistics 95 | */ 96 | public ContextStats getContextStats() { 97 | String currentContext = getContext(); 98 | boolean isCustom = hasCustomContext(); 99 | String programName = plugin.getCurrentProgram() != null ? 100 | plugin.getCurrentProgram().getName() : "No Program"; 101 | 102 | return new ContextStats(programName, currentContext.length(), isCustom); 103 | } 104 | 105 | /** 106 | * Close database resources 107 | */ 108 | public void close() { 109 | if (analysisDB != null) { 110 | analysisDB.close(); 111 | } 112 | } 113 | 114 | /** 115 | * Statistics about the current context 116 | */ 117 | public static class ContextStats { 118 | private final String programName; 119 | private final int contextLength; 120 | private final boolean isCustom; 121 | 122 | public ContextStats(String programName, int contextLength, boolean isCustom) { 123 | this.programName = programName; 124 | this.contextLength = contextLength; 125 | this.isCustom = isCustom; 126 | } 127 | 128 | public String getProgramName() { return programName; } 129 | public int getContextLength() { return contextLength; } 130 | public boolean isCustom() { return isCustom; } 131 | 132 | @Override 133 | public String toString() { 134 | return String.format("Program: %s, Context: %d chars (%s)", 135 | programName, contextLength, isCustom ? 
"Custom" : "Default"); 136 | } 137 | } 138 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/services/CodeAnalysisService.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.services; 2 | 3 | import ghidra.program.model.address.Address; 4 | import ghidra.program.model.listing.Function; 5 | import ghidra.program.model.listing.Program; 6 | import ghidra.util.task.TaskMonitor; 7 | import ghidrassist.AnalysisDB; 8 | import ghidrassist.GhidrAssistPlugin; 9 | import ghidrassist.LlmApi; 10 | import ghidrassist.apiprovider.APIProviderConfig; 11 | import ghidrassist.core.CodeUtils; 12 | 13 | /** 14 | * Service for handling code analysis operations. 15 | * Responsible for explaining functions and lines of code. 16 | */ 17 | public class CodeAnalysisService { 18 | 19 | private final GhidrAssistPlugin plugin; 20 | private final AnalysisDB analysisDB; 21 | 22 | public CodeAnalysisService(GhidrAssistPlugin plugin) { 23 | this.plugin = plugin; 24 | this.analysisDB = new AnalysisDB(); 25 | } 26 | 27 | /** 28 | * Analyze and explain a function 29 | */ 30 | public AnalysisRequest createFunctionAnalysisRequest(Function function) throws Exception { 31 | if (function == null) { 32 | throw new IllegalArgumentException("No function at current location."); 33 | } 34 | 35 | String functionCode = null; 36 | String codeType = null; 37 | 38 | GhidrAssistPlugin.CodeViewType viewType = plugin.checkLastActiveCodeView(); 39 | if (viewType == GhidrAssistPlugin.CodeViewType.IS_DECOMPILER) { 40 | functionCode = CodeUtils.getFunctionCode(function, TaskMonitor.DUMMY); 41 | codeType = "pseudo-C"; 42 | } else if (viewType == GhidrAssistPlugin.CodeViewType.IS_DISASSEMBLER) { 43 | functionCode = CodeUtils.getFunctionDisassembly(function); 44 | codeType = "assembly"; 45 | } else { 46 | throw new Exception("Unknown code view type."); 47 | } 48 | 49 | String prompt = "Explain the following 
" + codeType + " code:\n```\n" + functionCode + "\n```"; 50 | return new AnalysisRequest(AnalysisRequest.Type.FUNCTION, prompt, function); 51 | } 52 | 53 | /** 54 | * Analyze and explain a line of code 55 | */ 56 | public AnalysisRequest createLineAnalysisRequest(Address address) throws Exception { 57 | if (address == null) { 58 | throw new IllegalArgumentException("No address at current location."); 59 | } 60 | 61 | String codeLine = null; 62 | String codeType = null; 63 | 64 | GhidrAssistPlugin.CodeViewType viewType = plugin.checkLastActiveCodeView(); 65 | if (viewType == GhidrAssistPlugin.CodeViewType.IS_DECOMPILER) { 66 | codeLine = CodeUtils.getLineCode(address, TaskMonitor.DUMMY, plugin.getCurrentProgram()); 67 | codeType = "pseudo-C"; 68 | } else if (viewType == GhidrAssistPlugin.CodeViewType.IS_DISASSEMBLER) { 69 | codeLine = CodeUtils.getLineDisassembly(address, plugin.getCurrentProgram()); 70 | codeType = "assembly"; 71 | } else { 72 | throw new Exception("Unknown code view type."); 73 | } 74 | 75 | String prompt = "Explain the following " + codeType + " line:\n```\n" + codeLine + "\n```"; 76 | return new AnalysisRequest(AnalysisRequest.Type.LINE, prompt, address); 77 | } 78 | 79 | /** 80 | * Execute an analysis request 81 | */ 82 | public void executeAnalysis(AnalysisRequest request, LlmApi.LlmResponseHandler handler) throws Exception { 83 | APIProviderConfig config = GhidrAssistPlugin.getCurrentProviderConfig(); 84 | if (config == null) { 85 | throw new Exception("No API provider configured."); 86 | } 87 | 88 | LlmApi llmApi = new LlmApi(config, plugin); 89 | llmApi.sendRequestAsync(request.getPrompt(), handler); 90 | } 91 | 92 | /** 93 | * Store analysis result in database 94 | */ 95 | public void storeAnalysisResult(Function function, String prompt, String response) { 96 | if (function != null && plugin.getCurrentProgram() != null) { 97 | analysisDB.upsertAnalysis( 98 | plugin.getCurrentProgram().getExecutableSHA256(), 99 | function.getEntryPoint(), 
100 | prompt, 101 | response 102 | ); 103 | } 104 | } 105 | 106 | /** 107 | * Get existing analysis for a function 108 | */ 109 | public AnalysisDB.Analysis getExistingAnalysis(Function function) { 110 | if (function == null || plugin.getCurrentProgram() == null) { 111 | return null; 112 | } 113 | 114 | return analysisDB.getAnalysis( 115 | plugin.getCurrentProgram().getExecutableSHA256(), 116 | function.getEntryPoint() 117 | ); 118 | } 119 | 120 | /** 121 | * Update existing analysis 122 | */ 123 | public void updateAnalysis(Function function, String updatedContent) { 124 | if (function == null || plugin.getCurrentProgram() == null) { 125 | throw new IllegalArgumentException("No active program or function"); 126 | } 127 | 128 | String programHash = plugin.getCurrentProgram().getExecutableSHA256(); 129 | Address functionAddress = function.getEntryPoint(); 130 | 131 | // Get existing analysis to preserve the query 132 | AnalysisDB.Analysis existingAnalysis = analysisDB.getAnalysis(programHash, functionAddress); 133 | 134 | if (existingAnalysis == null) { 135 | // Create new entry with generic query 136 | analysisDB.upsertAnalysis(programHash, functionAddress, "Edited explanation", updatedContent); 137 | } else { 138 | // Update existing entry, preserving original query 139 | analysisDB.upsertAnalysis(programHash, functionAddress, existingAnalysis.getQuery(), updatedContent); 140 | } 141 | } 142 | 143 | /** 144 | * Clear analysis data for a function 145 | */ 146 | public boolean clearAnalysis(Function function) { 147 | if (function == null || plugin.getCurrentProgram() == null) { 148 | return false; 149 | } 150 | 151 | String programHash = plugin.getCurrentProgram().getExecutableSHA256(); 152 | Address functionAddress = function.getEntryPoint(); 153 | 154 | return analysisDB.deleteAnalysis(programHash, functionAddress); 155 | } 156 | 157 | /** 158 | * Close database resources 159 | */ 160 | public void close() { 161 | if (analysisDB != null) { 162 | 
analysisDB.close(); 163 | } 164 | } 165 | 166 | /** 167 | * Request object for analysis operations 168 | */ 169 | public static class AnalysisRequest { 170 | public enum Type { 171 | FUNCTION, LINE 172 | } 173 | 174 | private final Type type; 175 | private final String prompt; 176 | private final Object context; // Function or Address 177 | 178 | public AnalysisRequest(Type type, String prompt, Object context) { 179 | this.type = type; 180 | this.prompt = prompt; 181 | this.context = context; 182 | } 183 | 184 | public Type getType() { return type; } 185 | public String getPrompt() { return prompt; } 186 | public Object getContext() { return context; } 187 | 188 | public Function getFunction() { 189 | return context instanceof Function ? (Function) context : null; 190 | } 191 | 192 | public Address getAddress() { 193 | return context instanceof Address ? (Address) context : null; 194 | } 195 | } 196 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/services/FeedbackService.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.services; 2 | 3 | import ghidrassist.GhidrAssistPlugin; 4 | import ghidrassist.LlmApi; 5 | import ghidrassist.RLHFDatabase; 6 | import ghidrassist.apiprovider.APIProviderConfig; 7 | 8 | /** 9 | * Service for handling RLHF (Reinforcement Learning from Human Feedback) operations. 10 | * Responsible for storing user feedback and managing feedback data. 
11 | */ 12 | public class FeedbackService { 13 | 14 | private final GhidrAssistPlugin plugin; 15 | private final RLHFDatabase rlhfDB; 16 | 17 | // Cache for last interaction 18 | private String lastPrompt; 19 | private String lastResponse; 20 | 21 | public FeedbackService(GhidrAssistPlugin plugin) { 22 | this.plugin = plugin; 23 | this.rlhfDB = new RLHFDatabase(); 24 | } 25 | 26 | /** 27 | * Cache the last prompt and response for feedback 28 | */ 29 | public void cacheLastInteraction(String prompt, String response) { 30 | this.lastPrompt = prompt; 31 | this.lastResponse = response; 32 | } 33 | 34 | /** 35 | * Store positive feedback (thumbs up) 36 | */ 37 | public void storePositiveFeedback() { 38 | storeFeedback(1); 39 | } 40 | 41 | /** 42 | * Store negative feedback (thumbs down) 43 | */ 44 | public void storeNegativeFeedback() { 45 | storeFeedback(0); 46 | } 47 | 48 | /** 49 | * Store feedback with specified rating 50 | */ 51 | public void storeFeedback(int feedback) { 52 | if (lastPrompt == null || lastResponse == null) { 53 | throw new IllegalStateException("No recent interaction to provide feedback for."); 54 | } 55 | 56 | APIProviderConfig config = GhidrAssistPlugin.getCurrentProviderConfig(); 57 | if (config == null) { 58 | throw new IllegalStateException("No API provider configured."); 59 | } 60 | 61 | LlmApi llmApi = new LlmApi(config, plugin); 62 | String modelName = config.getModel(); 63 | String systemContext = llmApi.getSystemPrompt(); 64 | 65 | rlhfDB.storeFeedback(modelName, lastPrompt, systemContext, lastResponse, feedback); 66 | } 67 | 68 | /** 69 | * Check if there's a recent interaction available for feedback 70 | */ 71 | public boolean hasPendingFeedback() { 72 | return lastPrompt != null && lastResponse != null; 73 | } 74 | 75 | /** 76 | * Clear cached interaction (e.g., after feedback is provided) 77 | */ 78 | public void clearCachedInteraction() { 79 | lastPrompt = null; 80 | lastResponse = null; 81 | } 82 | 83 | /** 84 | * Get feedback 
statistics 85 | */ 86 | public FeedbackStats getFeedbackStats() { 87 | // Note: This would require extending RLHFDatabase with stats methods 88 | // For now, return basic info 89 | return new FeedbackStats(hasPendingFeedback()); 90 | } 91 | 92 | /** 93 | * Get the last cached prompt (for debugging/info) 94 | */ 95 | public String getLastPrompt() { 96 | return lastPrompt; 97 | } 98 | 99 | /** 100 | * Get the last cached response (for debugging/info) 101 | */ 102 | public String getLastResponse() { 103 | return lastResponse; 104 | } 105 | 106 | /** 107 | * Close database resources 108 | */ 109 | public void close() { 110 | if (rlhfDB != null) { 111 | rlhfDB.close(); 112 | } 113 | } 114 | 115 | /** 116 | * Statistics about feedback state 117 | */ 118 | public static class FeedbackStats { 119 | private final boolean hasPendingFeedback; 120 | 121 | public FeedbackStats(boolean hasPendingFeedback) { 122 | this.hasPendingFeedback = hasPendingFeedback; 123 | } 124 | 125 | public boolean hasPendingFeedback() { return hasPendingFeedback; } 126 | 127 | @Override 128 | public String toString() { 129 | return String.format("Feedback: %s", 130 | hasPendingFeedback ? "Available" : "No pending feedback"); 131 | } 132 | } 133 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/services/RAGManagementService.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.services; 2 | 3 | import ghidrassist.core.RAGEngine; 4 | 5 | import java.io.File; 6 | import java.io.IOException; 7 | import java.util.Arrays; 8 | import java.util.List; 9 | 10 | /** 11 | * Service for managing RAG (Retrieval Augmented Generation) documents. 12 | * Responsible for document ingestion, deletion, and listing operations. 
13 | */ 14 | public class RAGManagementService { 15 | 16 | /** 17 | * Add documents to the RAG index 18 | */ 19 | public void addDocuments(File[] files) throws IOException { 20 | if (files == null || files.length == 0) { 21 | throw new IllegalArgumentException("No files provided for ingestion."); 22 | } 23 | 24 | RAGEngine.ingestDocuments(Arrays.asList(files)); 25 | } 26 | 27 | /** 28 | * Delete selected documents from the RAG index 29 | */ 30 | public void deleteDocuments(List fileNames) throws IOException { 31 | if (fileNames == null || fileNames.isEmpty()) { 32 | throw new IllegalArgumentException("No documents selected for deletion."); 33 | } 34 | 35 | for (String fileName : fileNames) { 36 | RAGEngine.deleteDocument(fileName); 37 | } 38 | } 39 | 40 | /** 41 | * Get list of indexed files 42 | */ 43 | public List getIndexedFiles() throws IOException { 44 | return RAGEngine.listIndexedFiles(); 45 | } 46 | 47 | /** 48 | * Check if the RAG index is available and working 49 | */ 50 | public boolean isRAGAvailable() { 51 | try { 52 | RAGEngine.listIndexedFiles(); 53 | return true; 54 | } catch (IOException e) { 55 | return false; 56 | } 57 | } 58 | 59 | /** 60 | * Get RAG index statistics 61 | */ 62 | public RAGIndexStats getIndexStats() throws IOException { 63 | List files = getIndexedFiles(); 64 | return new RAGIndexStats(files.size(), files); 65 | } 66 | 67 | /** 68 | * Clear all documents from the RAG index 69 | */ 70 | public void clearAllDocuments() throws IOException { 71 | List allFiles = getIndexedFiles(); 72 | deleteDocuments(allFiles); 73 | } 74 | 75 | /** 76 | * Statistics about the RAG index 77 | */ 78 | public static class RAGIndexStats { 79 | private final int totalFiles; 80 | private final List fileNames; 81 | 82 | public RAGIndexStats(int totalFiles, List fileNames) { 83 | this.totalFiles = totalFiles; 84 | this.fileNames = fileNames; 85 | } 86 | 87 | public int getTotalFiles() { return totalFiles; } 88 | public List getFileNames() { return 
fileNames; } 89 | 90 | @Override 91 | public String toString() { 92 | return String.format("RAG Index: %d files indexed", totalFiles); 93 | } 94 | } 95 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/ui/GhidrAssistUI.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.ui; 2 | 3 | import javax.swing.*; 4 | 5 | import ghidra.program.util.ProgramLocation; 6 | 7 | import java.awt.*; 8 | import ghidrassist.GhidrAssistPlugin; 9 | import ghidrassist.ui.tabs.*; 10 | import ghidrassist.core.TabController; 11 | import ghidrassist.ui.common.UIConstants; 12 | 13 | public class GhidrAssistUI extends JPanel { 14 | private static final long serialVersionUID = 1L; 15 | private final GhidrAssistPlugin plugin; 16 | private final TabController controller; 17 | private final JTabbedPane tabbedPane; 18 | private final ExplainTab explainTab; 19 | private final QueryTab queryTab; 20 | private final ActionsTab actionsTab; 21 | private final RAGManagementTab ragManagementTab; 22 | private final AnalysisOptionsTab analysisOptionsTab; 23 | 24 | public GhidrAssistUI(GhidrAssistPlugin plugin) { 25 | super(new BorderLayout()); 26 | this.plugin = plugin; 27 | this.controller = new TabController(plugin); 28 | 29 | // Initialize components 30 | this.tabbedPane = new JTabbedPane(); 31 | 32 | // Create tabs 33 | this.explainTab = new ExplainTab(controller); 34 | this.queryTab = new QueryTab(controller); 35 | this.actionsTab = new ActionsTab(controller); 36 | this.ragManagementTab = new RAGManagementTab(controller); 37 | this.analysisOptionsTab = new AnalysisOptionsTab(controller); 38 | 39 | // Set tab references in controller 40 | controller.setExplainTab(explainTab); 41 | controller.setQueryTab(queryTab); 42 | controller.setActionsTab(actionsTab); 43 | controller.setRAGManagementTab(ragManagementTab); 44 | controller.setAnalysisOptionsTab(analysisOptionsTab); 45 | 46 | 
initializeUI(); 47 | } 48 | 49 | private void initializeUI() { 50 | setBorder(UIConstants.PANEL_BORDER); 51 | 52 | // Add tabs 53 | tabbedPane.addTab("Explain", explainTab); 54 | tabbedPane.addTab("Custom Query", queryTab); 55 | tabbedPane.addTab("Actions", actionsTab); 56 | tabbedPane.addTab("RAG Management", ragManagementTab); 57 | tabbedPane.addTab("Analysis Options", analysisOptionsTab); 58 | 59 | add(tabbedPane, BorderLayout.CENTER); 60 | 61 | // Initialize tabs that need startup data 62 | SwingUtilities.invokeLater(() -> { 63 | // Load initial context 64 | controller.handleContextRevert(); 65 | 66 | // Load RAG file list 67 | controller.loadIndexedFiles(ragManagementTab.getDocumentList()); 68 | }); 69 | 70 | tabbedPane.addChangeListener(e -> { 71 | if (tabbedPane.getSelectedComponent() == analysisOptionsTab) { 72 | // Refresh context when Analysis Options tab is selected 73 | controller.handleContextRevert(); 74 | } 75 | }); 76 | } 77 | 78 | public void updateLocation(ProgramLocation loc) { 79 | if (loc != null && loc.getAddress() != null) { 80 | explainTab.updateOffset(loc.getAddress().toString()); 81 | controller.updateAnalysis(loc); 82 | } 83 | } 84 | 85 | public JComponent getComponent() { 86 | return this; 87 | } 88 | 89 | public GhidrAssistPlugin getPlugin() { 90 | return plugin; 91 | } 92 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/ui/common/PlaceholderTextField.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.ui.common; 2 | 3 | import javax.swing.JTextField; 4 | import java.awt.Color; 5 | import java.awt.event.FocusEvent; 6 | import java.awt.event.FocusListener; 7 | 8 | public class PlaceholderTextField extends JTextField { 9 | private boolean showingPlaceholder; 10 | private final String placeholder; 11 | private final Color placeholderColor; 12 | private final Color textColor; 13 | 14 | public PlaceholderTextField(String 
placeholder, int columns) { 15 | super(columns); 16 | this.placeholder = placeholder; 17 | this.showingPlaceholder = true; 18 | this.placeholderColor = UIConstants.PLACEHOLDER_COLOR; 19 | this.textColor = getForeground(); 20 | 21 | setupPlaceholder(); 22 | } 23 | 24 | private void setupPlaceholder() { 25 | setText(placeholder); 26 | setForeground(placeholderColor); 27 | 28 | addFocusListener(new FocusListener() { 29 | @Override 30 | public void focusGained(FocusEvent e) { 31 | if (showingPlaceholder) { 32 | showingPlaceholder = false; 33 | setText(""); 34 | setForeground(textColor); 35 | } 36 | } 37 | 38 | @Override 39 | public void focusLost(FocusEvent e) { 40 | if (getText().isEmpty()) { 41 | showingPlaceholder = true; 42 | setText(placeholder); 43 | setForeground(placeholderColor); 44 | } 45 | } 46 | }); 47 | } 48 | 49 | @Override 50 | public String getText() { 51 | return showingPlaceholder ? "" : super.getText(); 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/main/java/ghidrassist/ui/common/UIConstants.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.ui.common; 2 | 3 | import java.awt.Color; 4 | import java.awt.Dimension; 5 | import javax.swing.BorderFactory; 6 | import javax.swing.border.Border; 7 | 8 | public class UIConstants { 9 | public static final int PADDING = 5; 10 | public static final int BUTTON_SPACING = 5; 11 | public static final int TEXT_FIELD_COLUMNS = 20; 12 | public static final Dimension PREFERRED_BUTTON_SIZE = new Dimension(120, 30); 13 | public static final Border PANEL_BORDER = BorderFactory.createEmptyBorder(PADDING, PADDING, PADDING, PADDING); 14 | public static final Color PLACEHOLDER_COLOR = Color.GRAY; 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/ghidrassist/ui/tabs/ActionsTab.java: 
-------------------------------------------------------------------------------- 1 | package ghidrassist.ui.tabs; 2 | 3 | import javax.swing.*; 4 | import javax.swing.table.DefaultTableModel; 5 | import java.awt.*; 6 | import java.util.Map; 7 | import java.util.HashMap; 8 | import ghidrassist.core.TabController; 9 | import ghidrassist.core.ActionConstants; 10 | 11 | public class ActionsTab extends JPanel { 12 | private static final long serialVersionUID = 1L; 13 | private final TabController controller; 14 | private JTable actionsTable; 15 | private Map filterCheckBoxes; 16 | private JButton analyzeFunctionButton; 17 | private JButton analyzeClearButton; 18 | private JButton applyActionsButton; 19 | 20 | public ActionsTab(TabController controller) { 21 | super(new BorderLayout()); 22 | this.controller = controller; 23 | initializeComponents(); 24 | layoutComponents(); 25 | setupListeners(); 26 | } 27 | 28 | private void initializeComponents() { 29 | // Initialize table 30 | actionsTable = createActionsTable(); 31 | 32 | // Initialize filter checkboxes 33 | filterCheckBoxes = createFilterCheckboxes(); 34 | 35 | // Initialize buttons 36 | analyzeFunctionButton = new JButton("Analyze Function"); 37 | analyzeClearButton = new JButton("Clear"); 38 | applyActionsButton = new JButton("Apply Actions"); 39 | } 40 | 41 | private JTable createActionsTable() { 42 | DefaultTableModel model = new DefaultTableModel( 43 | new Object[]{"Select", "Action", "Description", "Status", "Arguments"}, 0) { 44 | private static final long serialVersionUID = 1L; 45 | 46 | @Override 47 | public Class getColumnClass(int column) { 48 | return column == 0 ? 
Boolean.class : String.class; 49 | } 50 | }; 51 | 52 | JTable table = new JTable(model); 53 | int w = table.getColumnModel().getColumn(0).getWidth(); 54 | table.getColumnModel().getColumn(0).setMaxWidth((int)((double) (w*0.8))); 55 | return table; 56 | } 57 | 58 | private Map createFilterCheckboxes() { 59 | Map checkboxes = new HashMap<>(); 60 | for (Map fnTemplate : ActionConstants.FN_TEMPLATES) { 61 | if (fnTemplate.get("type").equals("function")) { 62 | @SuppressWarnings("unchecked") 63 | Map functionMap = (Map) fnTemplate.get("function"); 64 | String fnName = functionMap.get("name").toString(); 65 | String fnDescription = functionMap.get("description").toString(); 66 | String checkboxLabel = fnName.replace("_", " ") + ": " + fnDescription; 67 | checkboxes.put(fnName, new JCheckBox(checkboxLabel, true)); 68 | } 69 | } 70 | return checkboxes; 71 | } 72 | 73 | private void layoutComponents() { 74 | // Filter panel 75 | JPanel filterPanel = new JPanel(); 76 | filterPanel.setLayout(new BoxLayout(filterPanel, BoxLayout.Y_AXIS)); 77 | filterPanel.setBorder(BorderFactory.createTitledBorder("Filters")); 78 | filterCheckBoxes.values().forEach(filterPanel::add); 79 | 80 | JScrollPane filterScrollPane = new JScrollPane(filterPanel); 81 | filterScrollPane.setPreferredSize(new Dimension(200, 150)); 82 | add(filterScrollPane, BorderLayout.NORTH); 83 | 84 | // Table 85 | add(new JScrollPane(actionsTable), BorderLayout.CENTER); 86 | 87 | // Buttons 88 | JPanel buttonPanel = new JPanel(); 89 | buttonPanel.add(analyzeFunctionButton); 90 | buttonPanel.add(analyzeClearButton); 91 | buttonPanel.add(applyActionsButton); 92 | add(buttonPanel, BorderLayout.SOUTH); 93 | } 94 | 95 | private void setupListeners() { 96 | analyzeFunctionButton.addActionListener(e -> 97 | controller.handleAnalyzeFunction(filterCheckBoxes)); 98 | analyzeClearButton.addActionListener(e -> 99 | ((DefaultTableModel)actionsTable.getModel()).setRowCount(0)); 100 | applyActionsButton.addActionListener(e -> 101 | 
controller.handleApplyActions(actionsTable)); 102 | } 103 | 104 | public DefaultTableModel getTableModel() { 105 | return (DefaultTableModel)actionsTable.getModel(); 106 | } 107 | 108 | public void setAnalyzeFunctionButtonText(String text) { 109 | analyzeFunctionButton.setText(text); 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /src/main/java/ghidrassist/ui/tabs/AnalysisOptionsTab.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.ui.tabs; 2 | 3 | import javax.swing.*; 4 | import java.awt.*; 5 | import ghidrassist.core.TabController; 6 | 7 | public class AnalysisOptionsTab extends JPanel { 8 | private final TabController controller; 9 | private JTextArea contextArea; 10 | private JButton saveButton; 11 | private JButton revertButton; 12 | 13 | public AnalysisOptionsTab(TabController controller) { 14 | super(new BorderLayout()); 15 | this.controller = controller; 16 | initializeComponents(); 17 | layoutComponents(); 18 | setupListeners(); 19 | } 20 | 21 | private void initializeComponents() { 22 | contextArea = new JTextArea(); 23 | contextArea.setFont(new Font("Monospaced", Font.PLAIN, 12)); 24 | contextArea.setLineWrap(true); 25 | contextArea.setWrapStyleWord(true); 26 | 27 | saveButton = new JButton("Save"); 28 | revertButton = new JButton("Revert"); 29 | } 30 | 31 | private void layoutComponents() { 32 | // Add header label 33 | JLabel headerLabel = new JLabel("System Context"); 34 | headerLabel.setBorder(BorderFactory.createEmptyBorder(5, 5, 5, 5)); 35 | add(headerLabel, BorderLayout.NORTH); 36 | 37 | // Add text area with scroll pane 38 | JScrollPane scrollPane = new JScrollPane(contextArea); 39 | add(scrollPane, BorderLayout.CENTER); 40 | 41 | // Add button panel 42 | JPanel buttonPanel = new JPanel(new FlowLayout(FlowLayout.RIGHT)); 43 | buttonPanel.add(revertButton); 44 | buttonPanel.add(saveButton); 45 | add(buttonPanel, 
BorderLayout.SOUTH); 46 | } 47 | 48 | private void setupListeners() { 49 | saveButton.addActionListener(e -> controller.handleContextSave(contextArea.getText())); 50 | revertButton.addActionListener(e -> controller.handleContextRevert()); 51 | } 52 | 53 | public void setContextText(String text) { 54 | contextArea.setText(text); 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/main/java/ghidrassist/ui/tabs/ExplainTab.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.ui.tabs; 2 | 3 | import javax.swing.*; 4 | import java.awt.*; 5 | import ghidrassist.core.MarkdownHelper; 6 | import ghidrassist.core.TabController; 7 | 8 | public class ExplainTab extends JPanel { 9 | private static final long serialVersionUID = 1L; 10 | private final TabController controller; 11 | private final MarkdownHelper markdownHelper; 12 | private JLabel offsetLabel; 13 | private JTextField offsetField; 14 | private JEditorPane explainTextPane; 15 | private JTextArea markdownTextArea; 16 | private JButton explainFunctionButton; 17 | private JButton explainLineButton; 18 | private JButton clearExplainButton; 19 | private JButton editSaveButton; 20 | private JPanel contentPanel; 21 | private CardLayout contentLayout; 22 | private boolean isEditMode = false; 23 | private String currentMarkdown = ""; 24 | 25 | public ExplainTab(TabController controller) { 26 | super(new BorderLayout()); 27 | this.controller = controller; 28 | this.markdownHelper = new MarkdownHelper(); 29 | initializeComponents(); 30 | layoutComponents(); 31 | setupListeners(); 32 | } 33 | 34 | private void initializeComponents() { 35 | // Initialize offset field 36 | offsetLabel = new JLabel("Offset: "); 37 | offsetField = new JTextField(16); 38 | offsetField.setEditable(false); 39 | 40 | // Initialize text pane for HTML viewing 41 | explainTextPane = new JEditorPane(); 42 | explainTextPane.setEditable(false); 43 | 
explainTextPane.setContentType("text/html"); 44 | explainTextPane.addHyperlinkListener(controller::handleHyperlinkEvent); 45 | 46 | // Initialize text area for Markdown editing 47 | markdownTextArea = new JTextArea(); 48 | markdownTextArea.setFont(new Font("Monospaced", Font.PLAIN, 12)); 49 | markdownTextArea.setLineWrap(true); 50 | markdownTextArea.setWrapStyleWord(true); 51 | 52 | // Initialize buttons 53 | explainFunctionButton = new JButton("Explain Function"); 54 | explainLineButton = new JButton("Explain Line"); 55 | clearExplainButton = new JButton("Clear"); 56 | editSaveButton = new JButton("Edit"); 57 | 58 | // Setup card layout for switching between view and edit modes 59 | contentLayout = new CardLayout(); 60 | contentPanel = new JPanel(contentLayout); 61 | contentPanel.add(new JScrollPane(explainTextPane), "view"); 62 | contentPanel.add(new JScrollPane(markdownTextArea), "edit"); 63 | } 64 | 65 | private void layoutComponents() { 66 | // Offset and Edit/Save panel 67 | JPanel topPanel = new JPanel(new BorderLayout()); 68 | 69 | JPanel offsetPanel = new JPanel(); 70 | offsetPanel.add(offsetLabel); 71 | offsetPanel.add(offsetField); 72 | topPanel.add(offsetPanel, BorderLayout.WEST); 73 | 74 | JPanel editPanel = new JPanel(); 75 | editPanel.add(editSaveButton); 76 | topPanel.add(editPanel, BorderLayout.EAST); 77 | 78 | add(topPanel, BorderLayout.NORTH); 79 | 80 | // Text content panel with card layout 81 | add(contentPanel, BorderLayout.CENTER); 82 | 83 | // Button panel 84 | JPanel buttonPanel = new JPanel(); 85 | buttonPanel.add(explainFunctionButton); 86 | buttonPanel.add(explainLineButton); 87 | buttonPanel.add(clearExplainButton); 88 | add(buttonPanel, BorderLayout.SOUTH); 89 | } 90 | 91 | private void setupListeners() { 92 | explainFunctionButton.addActionListener(e -> controller.handleExplainFunction()); 93 | explainLineButton.addActionListener(e -> controller.handleExplainLine()); 94 | clearExplainButton.addActionListener(e -> { 95 | // Clear the 
UI 96 | explainTextPane.setText(""); 97 | markdownTextArea.setText(""); 98 | currentMarkdown = ""; 99 | 100 | // Also clear from database 101 | controller.handleClearAnalysisData(); 102 | }); 103 | 104 | editSaveButton.addActionListener(e -> { 105 | if (isEditMode) { 106 | // Save mode - save the markdown and switch to view mode 107 | currentMarkdown = markdownTextArea.getText(); 108 | String html = markdownHelper.markdownToHtml(currentMarkdown); 109 | explainTextPane.setText(html); 110 | 111 | // Save to database 112 | controller.handleUpdateAnalysis(currentMarkdown); 113 | 114 | // Switch to view mode 115 | contentLayout.show(contentPanel, "view"); 116 | editSaveButton.setText("Edit"); 117 | isEditMode = false; 118 | } else { 119 | // Edit mode - switch to the markdown editor 120 | markdownTextArea.setText(currentMarkdown); 121 | 122 | // Switch to edit mode 123 | contentLayout.show(contentPanel, "edit"); 124 | editSaveButton.setText("Save"); 125 | isEditMode = true; 126 | } 127 | }); 128 | } 129 | 130 | public void updateOffset(String offset) { 131 | offsetField.setText(offset); 132 | } 133 | 134 | public void setExplanationText(String text) { 135 | explainTextPane.setText(text); 136 | explainTextPane.setCaretPosition(0); 137 | 138 | // Store the markdown equivalent 139 | currentMarkdown = markdownHelper.extractMarkdownFromLlmResponse(text); 140 | 141 | // If we're in edit mode, update the markdown text area too 142 | if (isEditMode) { 143 | markdownTextArea.setText(currentMarkdown); 144 | } 145 | 146 | // Switch to view mode if we're setting new content 147 | if (isEditMode) { 148 | contentLayout.show(contentPanel, "view"); 149 | editSaveButton.setText("Edit"); 150 | isEditMode = false; 151 | } 152 | } 153 | 154 | public void setFunctionButtonText(String text) { 155 | explainFunctionButton.setText(text); 156 | } 157 | 158 | public void setLineButtonText(String text) { 159 | explainLineButton.setText(text); 160 | } 161 | } 
-------------------------------------------------------------------------------- /src/main/java/ghidrassist/ui/tabs/MCPServerDialog.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.ui.tabs; 2 | 3 | import java.awt.*; 4 | import java.awt.event.ActionEvent; 5 | import java.awt.event.ActionListener; 6 | import javax.swing.*; 7 | 8 | import ghidrassist.mcp2.server.MCPServerConfig; 9 | 10 | public class MCPServerDialog extends JDialog { 11 | private static final long serialVersionUID = 1L; 12 | 13 | private JTextField nameField; 14 | private JTextField urlField; 15 | private JComboBox transportCombo; 16 | private JCheckBox enabledCheckBox; 17 | private JButton okButton; 18 | private JButton cancelButton; 19 | private boolean confirmed = false; 20 | 21 | public MCPServerDialog(Window parent, MCPServerConfig existingServer) { 22 | super(parent, existingServer == null ? "Add MCP Server" : "Edit MCP Server", 23 | ModalityType.APPLICATION_MODAL); 24 | 25 | initializeComponents(); 26 | layoutComponents(); 27 | setupEventHandlers(); 28 | 29 | if (existingServer != null) { 30 | populateFields(existingServer); 31 | } else { 32 | setDefaults(); 33 | } 34 | 35 | pack(); 36 | setLocationRelativeTo(parent); 37 | nameField.requestFocusInWindow(); 38 | } 39 | 40 | private void initializeComponents() { 41 | nameField = new JTextField(20); 42 | urlField = new JTextField(30); 43 | transportCombo = new JComboBox<>(MCPServerConfig.TransportType.values()); 44 | enabledCheckBox = new JCheckBox("Enabled", true); 45 | 46 | okButton = new JButton("OK"); 47 | cancelButton = new JButton("Cancel"); 48 | 49 | getRootPane().setDefaultButton(okButton); 50 | } 51 | 52 | private void layoutComponents() { 53 | setLayout(new BorderLayout()); 54 | 55 | // Form panel 56 | JPanel formPanel = new JPanel(new GridBagLayout()); 57 | GridBagConstraints gbc = new GridBagConstraints(); 58 | gbc.insets = new Insets(5, 5, 5, 5); 59 | 60 | // Name 61 | 
gbc.gridx = 0; gbc.gridy = 0; gbc.anchor = GridBagConstraints.EAST; 62 | formPanel.add(new JLabel("Name:"), gbc); 63 | gbc.gridx = 1; gbc.anchor = GridBagConstraints.WEST; gbc.fill = GridBagConstraints.HORIZONTAL; 64 | formPanel.add(nameField, gbc); 65 | 66 | // URL 67 | gbc.gridx = 0; gbc.gridy = 1; gbc.anchor = GridBagConstraints.EAST; gbc.fill = GridBagConstraints.NONE; 68 | formPanel.add(new JLabel("URL:"), gbc); 69 | gbc.gridx = 1; gbc.anchor = GridBagConstraints.WEST; gbc.fill = GridBagConstraints.HORIZONTAL; 70 | formPanel.add(urlField, gbc); 71 | 72 | // Transport 73 | gbc.gridx = 0; gbc.gridy = 2; gbc.anchor = GridBagConstraints.EAST; gbc.fill = GridBagConstraints.NONE; 74 | formPanel.add(new JLabel("Transport:"), gbc); 75 | gbc.gridx = 1; gbc.anchor = GridBagConstraints.WEST; gbc.fill = GridBagConstraints.HORIZONTAL; 76 | formPanel.add(transportCombo, gbc); 77 | 78 | // Enabled 79 | gbc.gridx = 1; gbc.gridy = 3; gbc.anchor = GridBagConstraints.WEST; 80 | formPanel.add(enabledCheckBox, gbc); 81 | 82 | // Help text 83 | JPanel helpPanel = new JPanel(new BorderLayout()); 84 | JTextArea helpText = new JTextArea( 85 | "Examples:\n" + 86 | "• Name: GhidraMCP, URL: http://localhost:8080\n" + 87 | "• Name: Local Tools, URL: http://127.0.0.1:3000\n\n" + 88 | "The server must implement the Model Context Protocol (MCP) specification." 
89 | ); 90 | helpText.setEditable(false); 91 | helpText.setOpaque(false); 92 | helpText.setFont(helpText.getFont().deriveFont(Font.ITALIC, 11f)); 93 | helpPanel.add(helpText, BorderLayout.CENTER); 94 | helpPanel.setBorder(BorderFactory.createTitledBorder("Help")); 95 | 96 | // Button panel 97 | JPanel buttonPanel = new JPanel(new FlowLayout(FlowLayout.RIGHT)); 98 | buttonPanel.add(okButton); 99 | buttonPanel.add(cancelButton); 100 | 101 | // Layout 102 | add(formPanel, BorderLayout.CENTER); 103 | add(helpPanel, BorderLayout.NORTH); 104 | add(buttonPanel, BorderLayout.SOUTH); 105 | 106 | // Add border to content pane instead 107 | getRootPane().setBorder(BorderFactory.createEmptyBorder(10, 10, 10, 10)); 108 | } 109 | 110 | private void setupEventHandlers() { 111 | okButton.addActionListener(e -> { 112 | if (validateInput()) { 113 | confirmed = true; 114 | dispose(); 115 | } 116 | }); 117 | 118 | cancelButton.addActionListener(e -> { 119 | confirmed = false; 120 | dispose(); 121 | }); 122 | 123 | // Transport selection updates URL placeholder 124 | transportCombo.addActionListener(e -> updateUrlPlaceholder()); 125 | } 126 | 127 | private void setDefaults() { 128 | transportCombo.setSelectedItem(MCPServerConfig.TransportType.SSE); 129 | updateUrlPlaceholder(); 130 | } 131 | 132 | private void populateFields(MCPServerConfig server) { 133 | nameField.setText(server.getName()); 134 | urlField.setText(server.getBaseUrl()); 135 | transportCombo.setSelectedItem(server.getTransport()); 136 | enabledCheckBox.setSelected(server.isEnabled()); 137 | } 138 | 139 | private void updateUrlPlaceholder() { 140 | MCPServerConfig.TransportType transport = 141 | (MCPServerConfig.TransportType) transportCombo.getSelectedItem(); 142 | 143 | if (transport == MCPServerConfig.TransportType.SSE) { 144 | urlField.setToolTipText("HTTP(S) URL for Server-Sent Events transport (e.g., http://localhost:8080)"); 145 | } else { 146 | urlField.setToolTipText("Command or path for stdio transport"); 147 | 
} 148 | } 149 | 150 | private boolean validateInput() { 151 | String name = nameField.getText().trim(); 152 | String url = urlField.getText().trim(); 153 | 154 | if (name.isEmpty()) { 155 | showError("Name cannot be empty."); 156 | nameField.requestFocus(); 157 | return false; 158 | } 159 | 160 | if (url.isEmpty()) { 161 | showError("URL cannot be empty."); 162 | urlField.requestFocus(); 163 | return false; 164 | } 165 | 166 | // Basic URL validation for SSE transport 167 | MCPServerConfig.TransportType transport = 168 | (MCPServerConfig.TransportType) transportCombo.getSelectedItem(); 169 | 170 | if (transport == MCPServerConfig.TransportType.SSE) { 171 | if (!url.startsWith("http://") && !url.startsWith("https://")) { 172 | showError("URL must start with http:// or https:// for SSE transport."); 173 | urlField.requestFocus(); 174 | return false; 175 | } 176 | } 177 | 178 | return true; 179 | } 180 | 181 | private void showError(String message) { 182 | JOptionPane.showMessageDialog(this, message, "Validation Error", JOptionPane.ERROR_MESSAGE); 183 | } 184 | 185 | public boolean isConfirmed() { 186 | return confirmed; 187 | } 188 | 189 | public MCPServerConfig getServerConfig() { 190 | if (!confirmed) return null; 191 | 192 | return new MCPServerConfig( 193 | nameField.getText().trim(), 194 | urlField.getText().trim(), 195 | (MCPServerConfig.TransportType) transportCombo.getSelectedItem(), 196 | enabledCheckBox.isSelected() 197 | ); 198 | } 199 | } -------------------------------------------------------------------------------- /src/main/java/ghidrassist/ui/tabs/RAGManagementTab.java: -------------------------------------------------------------------------------- 1 | package ghidrassist.ui.tabs; 2 | 3 | import javax.swing.*; 4 | import java.awt.*; 5 | import ghidrassist.core.TabController; 6 | 7 | public class RAGManagementTab extends JPanel { 8 | private static final long serialVersionUID = 1L; 9 | private final TabController controller; 10 | private JList 
documentList; 11 | private JButton addDocumentsButton; 12 | private JButton deleteSelectedButton; 13 | private JButton refreshListButton; 14 | 15 | public RAGManagementTab(TabController controller) { 16 | super(new BorderLayout()); 17 | this.controller = controller; 18 | initializeComponents(); 19 | layoutComponents(); 20 | setupListeners(); 21 | } 22 | 23 | private void initializeComponents() { 24 | addDocumentsButton = new JButton("Add Documents to RAG"); 25 | documentList = new JList<>(); 26 | deleteSelectedButton = new JButton("Delete Selected"); 27 | refreshListButton = new JButton("Refresh List"); 28 | } 29 | 30 | private void layoutComponents() { 31 | add(addDocumentsButton, BorderLayout.NORTH); 32 | add(new JScrollPane(documentList), BorderLayout.CENTER); 33 | 34 | JPanel buttonPanel = new JPanel(); 35 | buttonPanel.add(deleteSelectedButton); 36 | buttonPanel.add(refreshListButton); 37 | add(buttonPanel, BorderLayout.SOUTH); 38 | } 39 | 40 | private void setupListeners() { 41 | addDocumentsButton.addActionListener(e -> 42 | controller.handleAddDocuments(documentList)); 43 | deleteSelectedButton.addActionListener(e -> 44 | controller.handleDeleteSelected(documentList)); 45 | refreshListButton.addActionListener(e -> 46 | controller.loadIndexedFiles(documentList)); 47 | } 48 | 49 | public void updateDocumentList(String[] files) { 50 | documentList.setListData(files); 51 | } 52 | 53 | public JList getDocumentList() { 54 | return documentList; 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/main/resources/images/README.txt: -------------------------------------------------------------------------------- 1 | The "src/resources/images" directory is intended to hold all image/icon files used by 2 | this module. 
3 | -------------------------------------------------------------------------------- /src/main/resources/images/robot.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /src/main/resources/images/robot16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jtang613/GhidrAssist/f6bd628c273e8ff3da9b55703bcaa90e26c58f4d/src/main/resources/images/robot16.png -------------------------------------------------------------------------------- /src/main/resources/images/robot32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jtang613/GhidrAssist/f6bd628c273e8ff3da9b55703bcaa90e26c58f4d/src/main/resources/images/robot32.png --------------------------------------------------------------------------------