/Ghidra/application.properties
30 | // for the correct version of Gradle to use for the Ghidra installation you specify.
31 |
// Core Gradle plugins: 'java' compiles the extension's Java sources; 'eclipse' supports
// generating Eclipse project files for development.
plugins {
    id 'java'
    id 'eclipse'
}
36 |
//----------------------START "DO NOT MODIFY" SECTION------------------------------
// Locate the Ghidra installation. The GHIDRA_INSTALL_DIR environment variable takes precedence;
// otherwise fall back to a GHIDRA_INSTALL_DIR Gradle project property (e.g. -PGHIDRA_INSTALL_DIR=...).
def ghidraInstallDir

if (System.env.GHIDRA_INSTALL_DIR) {
    ghidraInstallDir = System.env.GHIDRA_INSTALL_DIR
}
else if (project.hasProperty("GHIDRA_INSTALL_DIR")) {
    ghidraInstallDir = project.getProperty("GHIDRA_INSTALL_DIR")
}

// Delegate the extension build logic to the support script shipped with that Ghidra
// installation; fail the build immediately if no installation directory was supplied.
if (ghidraInstallDir) {
    apply from: new File(ghidraInstallDir).getCanonicalPath() + "/support/buildExtension.gradle"
}
else {
    throw new GradleException("GHIDRA_INSTALL_DIR is not defined!")
}
//----------------------END "DO NOT MODIFY" SECTION-------------------------------
54 |
repositories {
    // Declare dependency repositories here. This is not needed if dependencies are manually
    // dropped into the lib/ directory.
    // See https://docs.gradle.org/current/userguide/declaring_repositories.html for more info.

    // The external artifacts declared in the dependencies block below are resolved from Maven Central.
    mavenCentral()
}
61 |
dependencies {
    // Any external dependencies added here will automatically be copied to the lib/ directory when
    // this extension is built.
    //
    // NOTE(review): the Java sources also import com.google.gson.*, which is not declared here;
    // presumably it is supplied by the Ghidra installation at runtime - confirm.

    // JSON (de)serialization
    implementation 'com.fasterxml.jackson.core:jackson-databind:2.15.0'
    // Reactive streams
    implementation "io.reactivex.rxjava3:rxjava:3.1.9"
    // Markdown parsing and HTML-to-Markdown conversion
    implementation 'com.vladsch.flexmark:flexmark:0.64.0'
    implementation 'com.vladsch.flexmark:flexmark-html2md-converter:0.64.0'
    // SQLite JDBC driver (the sources open "jdbc:sqlite:" connections)
    implementation 'org.xerial:sqlite-jdbc:3.46.1.0'
    // Lucene indexing/search; keep all Lucene artifacts on the same version.
    implementation 'org.apache.lucene:lucene-core:9.11.1'
    implementation 'org.apache.lucene:lucene-analysis-common:9.11.1'
    implementation 'org.apache.lucene:lucene-queryparser:9.11.1'
    // HTTP client; okio is declared explicitly alongside okhttp.
    implementation 'com.squareup.okio:okio:3.10.2'
    implementation "com.squareup.okhttp3:okhttp:4.12.0"
}
77 |
78 | // Exclude additional files from the built extension
79 | // Ex: buildExtension.exclude '.idea/**'
80 |
--------------------------------------------------------------------------------
/data/README.txt:
--------------------------------------------------------------------------------
1 | The "data" directory is intended to hold data files that will be used by this module and will
2 | not end up in the .jar file, but will be present in the zip or tar file. Typically, data
3 | files are placed here rather than in the resources directory if the user may need to edit them.
4 |
5 | An optional data/languages directory can exist for the purpose of containing various Sleigh language
6 | specification files and importer opinion files.
7 |
8 | The data/buildLanguage.xml is used for building the contents of the data/languages directory.
9 |
10 | The skel language definition has been commented out within the skel.ldefs file so that the
11 | skeleton language does not show up within Ghidra.
12 |
13 | See the Sleigh language documentation (docs/languages/index.html) for details on Sleigh language
14 | specification syntax.
15 |
--------------------------------------------------------------------------------
/extension.properties:
--------------------------------------------------------------------------------
1 | name=@extname@
2 | description=A plugin that provides LLM helpers to explain code and assist in RE.
3 | author=Jason Tang
4 | createdOn=
5 | version=@extversion@
6 |
--------------------------------------------------------------------------------
/ghidra_scripts/README.txt:
--------------------------------------------------------------------------------
1 | Java source directory to hold module-specific Ghidra scripts.
2 |
--------------------------------------------------------------------------------
/res/screenshot1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jtang613/GhidrAssist/f6bd628c273e8ff3da9b55703bcaa90e26c58f4d/res/screenshot1.png
--------------------------------------------------------------------------------
/res/screenshots_anim.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jtang613/GhidrAssist/f6bd628c273e8ff3da9b55703bcaa90e26c58f4d/res/screenshots_anim.gif
--------------------------------------------------------------------------------
/src/main/help/help/TOC_Source.xml:
--------------------------------------------------------------------------------
1 |
2 |
49 |
50 |
51 |
52 |
57 |
58 |
--------------------------------------------------------------------------------
/src/main/help/help/topics/ghidrassist/help.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
7 |
8 |
9 |
10 |
11 |
12 | Skeleton Help File for a Module
13 |
14 |
15 |
16 |
17 | Skeleton Help File for a Module
18 |
19 | This is a simple skeleton help topic. For a better description of what should and should not
20 | go in here, see the "sample" Ghidra extension in the Extensions/Ghidra directory, or see your
21 | favorite help topic. In general, language modules do not have their own help topics.
22 |
23 |
24 |
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/GAUtils.java:
--------------------------------------------------------------------------------
1 | package ghidrassist;
2 |
3 | import java.io.File;
4 |
/**
 * Platform helpers for GhidrAssist: host OS detection and per-OS default storage paths.
 */
public class GAUtils {

    /**
     * Coarse classification of the host operating system, derived from the {@code os.name}
     * system property.
     */
    public enum OperatingSystem {
        WINDOWS, MAC, LINUX, UNKNOWN;

        /**
         * Detects the current operating system from the {@code os.name} system property.
         *
         * <p>The property is lower-cased with {@link java.util.Locale#ROOT} so the result does not
         * depend on the user's default locale. In a Turkish locale, for example, "AIX" would
         * otherwise lower-case to "aıx" and miss the "aix" check.
         *
         * @return the detected OS, or {@link #UNKNOWN} if the property is unset or unrecognized
         */
        public static OperatingSystem detect() {
            String osName = System.getProperty("os.name");
            if (osName == null) {
                return UNKNOWN;
            }
            String os = osName.toLowerCase(java.util.Locale.ROOT);
            if (os.contains("win")) {
                return WINDOWS;
            } else if (os.contains("mac")) {
                return MAC;
            } else if (os.contains("nix") || os.contains("nux") || os.contains("aix")) {
                return LINUX;
            }
            return UNKNOWN;
        }
    }

    /**
     * Returns the default directory for the GhidrAssist Lucene index on the given OS.
     *
     * <ul>
     *   <li>WINDOWS: {@code %LOCALAPPDATA%}</li>
     *   <li>MAC: {@code ~/Library/Application Support}</li>
     *   <li>LINUX: {@code ~/.config}</li>
     * </ul>
     * Each base is followed by {@code GhidrAssist/LuceneIndex}, joined with the platform separator.
     *
     * @param os the target operating system (must not be null)
     * @return the index directory path
     * @throws RuntimeException if {@code os} is WINDOWS and LOCALAPPDATA is not set
     * @throws UnsupportedOperationException if {@code os} is {@link OperatingSystem#UNKNOWN}
     */
    public static String getDefaultLucenePath(OperatingSystem os) {
        String basePath;
        switch (os) {
            case WINDOWS:
                basePath = System.getenv("LOCALAPPDATA");
                if (basePath == null) {
                    throw new RuntimeException("Unable to access LOCALAPPDATA environment variable.");
                }
                break;

            case MAC:
                basePath = System.getProperty("user.home") + "/Library/Application Support";
                break;

            case LINUX:
                basePath = System.getProperty("user.home") + "/.config";
                break;

            default:
                throw new UnsupportedOperationException("Unsupported operating system: " + os);
        }
        return basePath + File.separator + "GhidrAssist" + File.separator + "LuceneIndex";
    }

}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/GhidrAssistPlugin.java:
--------------------------------------------------------------------------------
1 | package ghidrassist;
2 |
3 | import java.lang.reflect.Type;
4 | import java.util.List;
5 |
6 | import com.google.gson.Gson;
7 | import com.google.gson.reflect.TypeToken;
8 |
9 | import docking.ActionContext;
10 | import docking.action.DockingAction;
11 | import docking.action.MenuData;
12 | import ghidra.app.decompiler.DecompilerLocation;
13 | import ghidra.app.plugin.PluginCategoryNames;
14 | import ghidra.app.plugin.ProgramPlugin;
15 | import ghidra.framework.plugintool.*;
16 | import ghidra.framework.plugintool.util.PluginStatus;
17 | import ghidra.framework.preferences.Preferences;
18 | import ghidra.program.model.address.Address;
19 | import ghidra.program.model.listing.Function;
20 | import ghidra.program.model.listing.FunctionManager;
21 | import ghidra.program.model.listing.Program;
22 | import ghidra.program.util.ProgramLocation;
23 | import ghidrassist.apiprovider.APIProviderConfig;
24 |
25 | @PluginInfo(
26 | status = PluginStatus.STABLE,
27 | packageName = "GhidrAssist",
28 | category = PluginCategoryNames.COMMON,
29 | shortDescription = "GhidrAssist LLM Plugin",
30 | description = "A plugin that provides code assistance using a language model."
31 | )
32 | public class GhidrAssistPlugin extends ProgramPlugin {
33 | public enum CodeViewType {
34 | IS_DECOMPILER,
35 | IS_DISASSEMBLER,
36 | UNKNOWN
37 | }
38 | private GhidrAssistProvider provider;
39 | private String lastActiveProvider;
40 |
41 | public GhidrAssistPlugin(PluginTool tool) {
42 | super(tool);
43 | String pluginName = getName();
44 | provider = new GhidrAssistProvider(this, pluginName);
45 | }
46 |
47 | @Override
48 | public void init() {
49 | super.init();
50 |
51 | // Add a menu action for settings
52 | DockingAction settingsAction = new DockingAction("GhidrAssist Settings", getName()) {
53 | @Override
54 | public void actionPerformed(ActionContext context) {
55 | showSettingsDialog();
56 | }
57 | };
58 | settingsAction.setMenuBarData(new MenuData(new String[] { "Tools", "GhidrAssist Settings" }, null, "GhidrAssist"));
59 | tool.addAction(settingsAction);
60 | }
61 |
62 | @Override
63 | protected void dispose() {
64 | if (provider != null) {
65 | tool.removeComponentProvider(provider);
66 | provider = null;
67 | }
68 | super.dispose();
69 | }
70 |
71 | private void showSettingsDialog() {
72 | SettingsDialog dialog = new SettingsDialog(tool.getToolFrame(), "GhidrAssist Settings", this);
73 | tool.showDialog(dialog);
74 | }
75 |
76 | @Override
77 | public void locationChanged(ProgramLocation loc) {
78 | if (provider != null) {
79 | provider.getUI().updateLocation(loc);
80 | }
81 | }
82 |
83 | public Program getCurrentProgram() {
84 | return currentProgram;
85 | }
86 |
87 | public Address getCurrentAddress() {
88 | if (currentLocation != null) {
89 | return currentLocation.getAddress();
90 | }
91 | return null;
92 | }
93 |
94 | public Function getCurrentFunction() {
95 | Program program = getCurrentProgram();
96 | Address address = getCurrentAddress();
97 |
98 | if (program != null && address != null) {
99 | FunctionManager functionManager = program.getFunctionManager();
100 | return functionManager.getFunctionContaining(address);
101 | }
102 | return null;
103 | }
104 |
105 | public String getLastActiveProvider() {
106 | return lastActiveProvider;
107 | }
108 |
109 | public CodeViewType checkLastActiveCodeView() {
110 | if (currentLocation instanceof DecompilerLocation) {
111 | return CodeViewType.IS_DECOMPILER;
112 | } else if (currentLocation != null) {
113 | return CodeViewType.IS_DISASSEMBLER;
114 | } else {
115 | return CodeViewType.UNKNOWN;
116 | }
117 | }
118 |
119 | public static APIProviderConfig getCurrentProviderConfig() {
120 | // Load the list of API providers from preferences
121 | String providersJson = Preferences.getProperty("GhidrAssist.APIProviders", "[]");
122 | Gson gson = new Gson();
123 | Type listType = new TypeToken>() {}.getType();
124 | List apiProviders = gson.fromJson(providersJson, listType);
125 |
126 | // Load the selected provider name
127 | String selectedProviderName = Preferences.getProperty("GhidrAssist.SelectedAPIProvider", "");
128 |
129 | // Load the global API timeout setting
130 | String apiTimeoutStr = Preferences.getProperty("GhidrAssist.APITimeout", "120");
131 | Integer apiTimeout = 120; // Default value
132 | try {
133 | apiTimeout = Integer.parseInt(apiTimeoutStr);
134 | } catch (NumberFormatException e) {
135 | // Use default if there's an error
136 | }
137 |
138 | for (APIProviderConfig provider : apiProviders) {
139 | if (provider.getName().equals(selectedProviderName)) {
140 | // If the provider doesn't have a timeout set, use the global setting
141 | if (provider.getTimeout() == null) {
142 | provider.setTimeout(apiTimeout);
143 | }
144 | return provider;
145 | }
146 | }
147 |
148 | return null;
149 | }
150 |
151 | public static Integer getGlobalApiTimeout() {
152 | String apiTimeoutStr = Preferences.getProperty("GhidrAssist.APITimeout", "120");
153 | try {
154 | return Integer.parseInt(apiTimeoutStr);
155 | } catch (NumberFormatException e) {
156 | return 120; // Default value
157 | }
158 | }
159 |
160 | public GhidrAssistPlugin getInstance() {
161 | return this;
162 | }
163 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/GhidrAssistProvider.java:
--------------------------------------------------------------------------------
1 | package ghidrassist;
2 |
3 | import java.awt.event.MouseEvent;
4 | import java.util.ArrayList;
5 | import java.util.List;
6 |
7 | import javax.swing.*;
8 |
9 | import docking.ActionContext;
10 | import docking.ComponentProvider;
11 | import docking.DefaultActionContext;
12 | import docking.action.DockingAction;
13 | import docking.action.ToolBarData;
14 | import ghidra.util.Msg;
15 | import ghidrassist.resources.GhidrAssistIcons;
16 | import ghidrassist.ui.GhidrAssistUI;
17 | import resources.Icons;
18 |
19 | public class GhidrAssistProvider extends ComponentProvider {
20 | private GhidrAssistPlugin plugin;
21 | private GhidrAssistUI ui;
22 | private JComponent mainPanel;
23 | private List actions;
24 |
25 | public GhidrAssistProvider(GhidrAssistPlugin plugin, String owner) {
26 | super(plugin.getTool(), owner, owner);
27 | this.plugin = plugin;
28 | this.actions = new ArrayList<>();
29 |
30 | buildPanel();
31 | createActions();
32 | setIcon(GhidrAssistIcons.ROBOT_ICON);
33 | }
34 |
35 | private void buildPanel() {
36 | ui = new GhidrAssistUI(plugin);
37 | mainPanel = ui.getComponent();
38 | setVisible(true);
39 | }
40 |
41 | private void createActions() {
42 | DockingAction refreshAction = new DockingAction("Refresh GhidrAssist", getName()) {
43 | @Override
44 | public void actionPerformed(ActionContext context) {
45 | refresh();
46 | }
47 | };
48 | refreshAction.setToolBarData(new ToolBarData(Icons.REFRESH_ICON, null));
49 | refreshAction.setEnabled(true);
50 | refreshAction.markHelpUnnecessary();
51 | actions.add(refreshAction);
52 |
53 | // Add actions to the tool
54 | for (DockingAction action : actions) {
55 | plugin.getTool().addLocalAction(this, action);
56 | }
57 | }
58 |
59 | public GhidrAssistUI getUI() {
60 | return ui;
61 | }
62 |
63 | @Override
64 | public JComponent getComponent() {
65 | return mainPanel;
66 | }
67 |
68 | @Override
69 | public ActionContext getActionContext(MouseEvent event) {
70 | return new DefaultActionContext(this, mainPanel);
71 | }
72 |
73 | public void refresh() {
74 | try {
75 | Msg.info(this, "GhidrAssist UI refreshed");
76 | }
77 | catch (Exception e) {
78 | Msg.error(this, "Error refreshing GhidrAssist UI", e);
79 | }
80 | }
81 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/LlmApi.java:
--------------------------------------------------------------------------------
1 | package ghidrassist;
2 |
3 | import ghidrassist.apiprovider.APIProviderConfig;
4 | import ghidrassist.core.ConversationalToolHandler;
5 | import ghidrassist.core.LlmApiClient;
6 | import ghidrassist.core.LlmErrorHandler;
7 | import ghidrassist.core.LlmTaskExecutor;
8 | import ghidrassist.core.ResponseProcessor;
9 |
10 | import java.util.List;
11 | import java.util.Map;
12 |
13 | /**
14 | * - LlmApiClient: Provider management and API calls
15 | * - ResponseProcessor: Text filtering and processing
16 | * - LlmTaskExecutor: Background task execution
17 | * - LlmErrorHandler: Error handling and user feedback
18 | */
19 | public class LlmApi {
20 |
21 | private final LlmApiClient apiClient;
22 | private final ResponseProcessor responseProcessor;
23 | private final LlmTaskExecutor taskExecutor;
24 | private final LlmErrorHandler errorHandler;
25 | private volatile ConversationalToolHandler activeConversationalHandler;
26 |
27 | public LlmApi(APIProviderConfig config, GhidrAssistPlugin plugin) {
28 | this.apiClient = new LlmApiClient(config, plugin);
29 | this.responseProcessor = new ResponseProcessor();
30 | this.taskExecutor = new LlmTaskExecutor();
31 | this.errorHandler = new LlmErrorHandler(plugin, this);
32 | }
33 |
34 | /**
35 | * Get the system prompt for regular queries
36 | */
37 | public String getSystemPrompt() {
38 | return apiClient.getSystemPrompt();
39 | }
40 |
41 | /**
42 | * Send a streaming request with enhanced error handling
43 | */
44 | public void sendRequestAsync(String prompt, LlmResponseHandler responseHandler) {
45 | if (!apiClient.isProviderAvailable()) {
46 | errorHandler.handleError(
47 | new IllegalStateException("LLM provider is not initialized."),
48 | "send request",
49 | null
50 | );
51 | return;
52 | }
53 |
54 | // Create enhanced response handler that includes error handling
55 | LlmTaskExecutor.LlmResponseHandler enhancedHandler = new LlmTaskExecutor.LlmResponseHandler() {
56 | @Override
57 | public void onStart() {
58 | responseHandler.onStart();
59 | }
60 |
61 | @Override
62 | public void onUpdate(String partialResponse) {
63 | responseHandler.onUpdate(partialResponse);
64 | }
65 |
66 | @Override
67 | public void onComplete(String fullResponse) {
68 | responseHandler.onComplete(fullResponse);
69 | }
70 |
71 | @Override
72 | public void onError(Throwable error) {
73 | // Handle error with enhanced error handling
74 | Runnable retryAction = () -> sendRequestAsync(prompt, responseHandler);
75 | errorHandler.handleError(error, "stream chat completion", retryAction);
76 | responseHandler.onError(error);
77 | }
78 |
79 | @Override
80 | public boolean shouldContinue() {
81 | return responseHandler.shouldContinue();
82 | }
83 | };
84 |
85 | taskExecutor.executeStreamingRequest(apiClient, prompt, responseProcessor, enhancedHandler);
86 | }
87 |
88 | /**
89 | * Send a conversational tool calling request that handles multiple turns
90 | * Monitors finish_reason to determine when to execute tools vs. complete
91 | */
92 | public void sendConversationalToolRequest(String prompt, List> functions, LlmResponseHandler responseHandler) {
93 | if (!apiClient.isProviderAvailable()) {
94 | errorHandler.handleError(
95 | new IllegalStateException("LLM provider is not initialized."),
96 | "send conversational tool request",
97 | null
98 | );
99 | return;
100 | }
101 |
102 | // Create completion callback to clear reference
103 | Runnable onCompletion = () -> {
104 | activeConversationalHandler = null;
105 | };
106 |
107 | // Create enhanced response handler for conversational tool calling
108 | ConversationalToolHandler toolHandler = new ConversationalToolHandler(
109 | apiClient, functions, responseProcessor, responseHandler, errorHandler, onCompletion);
110 |
111 | // Store reference for cancellation
112 | activeConversationalHandler = toolHandler;
113 |
114 | // Start the conversation
115 | toolHandler.startConversation(prompt);
116 | }
117 |
118 | /**
119 | * Send a function calling request with enhanced error handling (legacy method)
120 | */
121 | public void sendRequestAsyncWithFunctions(String prompt, List> functions, LlmResponseHandler responseHandler) {
122 | if (!apiClient.isProviderAvailable()) {
123 | errorHandler.handleError(
124 | new IllegalStateException("LLM provider is not initialized."),
125 | "send function request",
126 | null
127 | );
128 | return;
129 | }
130 |
131 | // Create enhanced response handler that includes error handling
132 | LlmTaskExecutor.LlmResponseHandler enhancedHandler = new LlmTaskExecutor.LlmResponseHandler() {
133 | @Override
134 | public void onStart() {
135 | responseHandler.onStart();
136 | }
137 |
138 | @Override
139 | public void onUpdate(String partialResponse) {
140 | responseHandler.onUpdate(partialResponse);
141 | }
142 |
143 | @Override
144 | public void onComplete(String fullResponse) {
145 | responseHandler.onComplete(fullResponse);
146 | }
147 |
148 | @Override
149 | public void onError(Throwable error) {
150 | // Handle error with enhanced error handling
151 | Runnable retryAction = () -> sendRequestAsyncWithFunctions(prompt, functions, responseHandler);
152 | errorHandler.handleError(error, "chat completion with functions", retryAction);
153 | responseHandler.onError(error);
154 | }
155 |
156 | @Override
157 | public boolean shouldContinue() {
158 | return responseHandler.shouldContinue();
159 | }
160 | };
161 |
162 | taskExecutor.executeFunctionRequest(apiClient, prompt, functions, responseProcessor, enhancedHandler);
163 | }
164 |
165 | /**
166 | * Cancel the current request
167 | */
168 | public void cancelCurrentRequest() {
169 | // Cancel conversational tool handler if active
170 | if (activeConversationalHandler != null) {
171 | activeConversationalHandler.cancel();
172 | activeConversationalHandler = null;
173 | }
174 |
175 | // Cancel regular task executor
176 | taskExecutor.cancelCurrentRequest();
177 | }
178 |
179 | /**
180 | * Check if currently processing a request
181 | */
182 | public boolean isStreaming() {
183 | return taskExecutor.isStreaming();
184 | }
185 |
186 | /**
187 | * Get provider information for debugging/logging
188 | */
189 | public String getProviderInfo() {
190 | return String.format("Provider: %s, Model: %s",
191 | apiClient.getProviderName(),
192 | apiClient.getProviderModel());
193 | }
194 |
195 | /**
196 | * Interface for handling LLM responses - maintains compatibility with existing code
197 | */
198 | public interface LlmResponseHandler {
199 | void onStart();
200 | void onUpdate(String partialResponse);
201 | void onComplete(String fullResponse);
202 | void onError(Throwable error);
203 | default boolean shouldContinue() {
204 | return true;
205 | }
206 | }
207 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/RLHFDatabase.java:
--------------------------------------------------------------------------------
1 | package ghidrassist;
2 |
3 | import ghidra.framework.preferences.Preferences;
4 | import ghidra.util.Msg;
5 |
6 | import java.sql.*;
7 |
8 | public class RLHFDatabase {
9 |
10 | private static final String DB_PATH_PROPERTY = "GhidrAssist.RLHFDatabasePath";
11 | private static final String DEFAULT_DB_PATH = "ghidrassist_rlhf.db";
12 | private Connection connection;
13 |
14 | public RLHFDatabase() {
15 | String dbPath = Preferences.getProperty(DB_PATH_PROPERTY, DEFAULT_DB_PATH);
16 | initializeDatabase(dbPath);
17 | }
18 |
19 | private void initializeDatabase(String dbPath) {
20 | try {
21 | connection = DriverManager.getConnection("jdbc:sqlite:" + dbPath);
22 | createFeedbackTable();
23 | } catch (SQLException e) {
24 | Msg.showError(this, null, "Database Error", "Failed to initialize RLHF database: " + e.getMessage());
25 | }
26 | }
27 |
28 | private void createFeedbackTable() throws SQLException {
29 | String createTableSQL = "CREATE TABLE IF NOT EXISTS feedback ("
30 | + "id INTEGER PRIMARY KEY AUTOINCREMENT,"
31 | + "model_name TEXT NOT NULL,"
32 | + "prompt_context TEXT NOT NULL,"
33 | + "system_context TEXT NOT NULL,"
34 | + "response TEXT NOT NULL,"
35 | + "feedback INTEGER NOT NULL" // 1 for thumbs up, 0 for thumbs down
36 | + ")";
37 | Statement stmt = connection.createStatement();
38 | stmt.execute(createTableSQL);
39 | stmt.close();
40 | }
41 |
42 | public void storeFeedback(String modelName, String promptContext, String systemContext, String response, int feedback) {
43 | String insertSQL = "INSERT INTO feedback (model_name, prompt_context, system_context, response, feedback) "
44 | + "VALUES (?, ?, ?, ?, ?)";
45 | try (PreparedStatement pstmt = connection.prepareStatement(insertSQL)) {
46 | pstmt.setString(1, modelName);
47 | pstmt.setString(2, promptContext);
48 | pstmt.setString(3, systemContext);
49 | pstmt.setString(4, response);
50 | pstmt.setInt(5, feedback);
51 | pstmt.executeUpdate();
52 | } catch (SQLException e) {
53 | Msg.showError(this, null, "Database Error", "Failed to store feedback: " + e.getMessage());
54 | }
55 | }
56 |
57 | public void close() {
58 | try {
59 | connection.close();
60 | } catch (SQLException e) {
61 | Msg.showError(this, null, "Database Error", "Failed to close RLHF database connection: " + e.getMessage());
62 | }
63 | }
64 | }
65 |
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/APIProviderConfig.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider;
2 |
3 | import ghidrassist.GhidrAssistPlugin;
4 | import ghidrassist.apiprovider.factory.ProviderRegistry;
5 | import ghidrassist.apiprovider.factory.UnsupportedProviderException;
6 |
7 | public class APIProviderConfig {
8 | private String name;
9 | private String model;
10 | private Integer maxTokens;
11 | private String url;
12 | private String key;
13 | private boolean disableTlsVerification;
14 | private APIProvider.ProviderType type;
15 | private Integer timeout;
16 |
17 | public APIProviderConfig(
18 | String name,
19 | APIProvider.ProviderType type,
20 | String model,
21 | Integer maxTokens,
22 | String url,
23 | String key,
24 | boolean disableTlsVerification) {
25 | this(name, type, model, maxTokens, url, key, disableTlsVerification, 120); // Default timeout of 120 seconds
26 | }
27 |
28 | public APIProviderConfig(
29 | String name,
30 | APIProvider.ProviderType type,
31 | String model,
32 | Integer maxTokens,
33 | String url,
34 | String key,
35 | boolean disableTlsVerification,
36 | Integer timeout) {
37 | this.name = name;
38 | this.type = type;
39 | this.model = model;
40 | this.maxTokens = maxTokens;
41 | this.url = url;
42 | this.key = key;
43 | this.disableTlsVerification = disableTlsVerification;
44 | this.timeout = timeout;
45 | }
46 |
47 | // Getters
48 | public String getName() { return name; }
49 | public APIProvider.ProviderType getType() { return type; }
50 | public String getModel() { return model; }
51 | public Integer getMaxTokens() { return maxTokens; }
52 | public String getUrl() { return url; }
53 | public String getKey() { return key; }
54 | public boolean isDisableTlsVerification() { return disableTlsVerification; }
55 | public Integer getTimeout() { return timeout; }
56 |
57 | // Setters
58 | public void setName(String name) { this.name = name; }
59 | public void setType(APIProvider.ProviderType type) { this.type = type; }
60 | public void setModel(String model) { this.model = model; }
61 | public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; }
62 | public void setUrl(String url) { this.url = url; }
63 | public void setKey(String key) { this.key = key; }
64 | public void setDisableTlsVerification(boolean disableTlsVerification) { this.disableTlsVerification = disableTlsVerification; }
65 | public void setTimeout(Integer timeout) { this.timeout = timeout; }
66 |
67 | /**
68 | * Create a provider using the factory pattern
69 | * @return Configured API provider instance
70 | * @throws RuntimeException if provider creation fails
71 | */
72 | public APIProvider createProvider() {
73 | this.timeout = GhidrAssistPlugin.getGlobalApiTimeout();
74 |
75 | try {
76 | return ProviderRegistry.getInstance().createProvider(this);
77 | } catch (UnsupportedProviderException e) {
78 | throw new IllegalArgumentException("Failed to create provider: " + e.getMessage(), e);
79 | }
80 | }
81 |
82 | /**
83 | * Check if this provider type is supported
84 | * @return true if the provider type is supported
85 | */
86 | public boolean isSupported() {
87 | return ProviderRegistry.getInstance().isSupported(type);
88 | }
89 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/ChatMessage.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider;
2 |
3 | import com.fasterxml.jackson.databind.JsonNode;
4 | import com.google.gson.JsonArray;
5 |
6 | public class ChatMessage {
7 | private String role;
8 | private String content;
9 | private FunctionCall functionCall;
10 | private JsonArray toolCalls; // For assistant messages with tool calls
11 | private String toolCallId; // For tool response messages
12 |
13 | public ChatMessage(String role, String content) {
14 | this.role = role;
15 | this.content = content;
16 | }
17 |
18 | public String getRole() {
19 | return role;
20 | }
21 |
22 | public String getContent() {
23 | return content;
24 | }
25 |
26 | public void setContent(String content) {
27 | this.content = content;
28 | }
29 |
30 | public FunctionCall getFunctionCall() {
31 | return functionCall;
32 | }
33 |
34 | public void setFunctionCall(FunctionCall functionCall) {
35 | this.functionCall = functionCall;
36 | }
37 |
38 | public JsonArray getToolCalls() {
39 | return toolCalls;
40 | }
41 |
42 | public void setToolCalls(JsonArray toolCalls) {
43 | this.toolCalls = toolCalls;
44 | }
45 |
46 | public String getToolCallId() {
47 | return toolCallId;
48 | }
49 |
50 | public void setToolCallId(String toolCallId) {
51 | this.toolCallId = toolCallId;
52 | }
53 |
54 | public static class FunctionCall {
55 | private String name;
56 | private JsonNode arguments;
57 |
58 | public String getName() {
59 | return name;
60 | }
61 |
62 | public JsonNode getArguments() {
63 | return arguments;
64 | }
65 | }
66 |
67 | public static class ChatMessageRole {
68 | public static final String SYSTEM = "system";
69 | public static final String USER = "user";
70 | public static final String ASSISTANT = "assistant";
71 | public static final String FUNCTION = "function";
72 | public static final String TOOL = "tool";
73 | }
74 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/ErrorAction.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider;
2 |
3 | import javax.swing.*;
4 | import java.awt.event.ActionEvent;
5 | import java.awt.event.ActionListener;
6 |
/**
 * A user-selectable response to an error (retry, open settings, copy details, ...), renderable as
 * a Swing button.
 */
public class ErrorAction {
    private final String actionText;
    private final String description;
    private final Runnable action;
    private final boolean isPrimary;

    /**
     * @param actionText  button label
     * @param description tooltip text, or null for none
     * @param action      work to run when the button is pressed (may be null)
     * @param isPrimary   whether this is a primary (emphasized) action
     */
    public ErrorAction(String actionText, String description, Runnable action, boolean isPrimary) {
        this.actionText = actionText;
        this.description = description;
        this.action = action;
        this.isPrimary = isPrimary;
    }

    /** Creates a non-primary action without a description. */
    public ErrorAction(String actionText, Runnable action) {
        this(actionText, null, action, false);
    }

    public String getActionText() {
        return actionText;
    }

    public String getDescription() {
        return description;
    }

    public Runnable getAction() {
        return action;
    }

    public boolean isPrimary() {
        return isPrimary;
    }

    /**
     * Builds a button labelled with the action text. Its tooltip is the description, when present.
     * Clicking the button runs the action; any exception it throws is printed to stderr and not
     * rethrown.
     */
    public JButton createButton() {
        JButton button = new JButton(actionText);
        if (description != null) {
            button.setToolTipText(description);
        }
        button.addActionListener(event -> runGuarded());
        return button;
    }

    /** Runs the action if present, containing failures so they cannot cascade into the UI. */
    private void runGuarded() {
        if (action == null) {
            return;
        }
        try {
            action.run();
        } catch (Exception failure) {
            System.err.println("Error executing action: " + failure.getMessage());
        }
    }

    // ---- factories for common actions ----

    public static ErrorAction createSettingsAction(Runnable openSettingsAction) {
        return new ErrorAction("Open Settings", "Open the settings dialog to configure API providers",
            openSettingsAction, true);
    }

    public static ErrorAction createRetryAction(Runnable retryAction) {
        return new ErrorAction("Retry", "Try the operation again", retryAction, true);
    }

    public static ErrorAction createCopyErrorAction(String errorDetails) {
        Runnable copy = () -> copyToClipboard(errorDetails);
        return new ErrorAction("Copy Details", "Copy error details to clipboard", copy, false);
    }

    public static ErrorAction createSwitchProviderAction(Runnable switchAction) {
        return new ErrorAction("Switch Provider", "Try using a different API provider",
            switchAction, false);
    }

    /** A no-op action; the surrounding dialog is expected to handle dismissal itself. */
    public static ErrorAction createDismissAction() {
        Runnable noop = () -> {};
        return new ErrorAction("Dismiss", "Close this error dialog", noop, false);
    }

    /** Best-effort copy to the system clipboard; any failure is silently ignored. */
    private static void copyToClipboard(String text) {
        try {
            java.awt.Toolkit.getDefaultToolkit()
                .getSystemClipboard()
                .setContents(new java.awt.datatransfer.StringSelection(text), null);
        } catch (Exception ignored) {
            // Clipboard unavailable (e.g. headless): nothing useful to report.
        }
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/RetryHandler.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider;
2 |
3 | import ghidrassist.apiprovider.exceptions.*;
4 | import ghidra.util.Msg;
5 |
6 | import java.util.concurrent.Callable;
7 | import java.util.function.Supplier;
8 |
9 | /**
10 | * Handles retry logic for API provider operations
11 | */
12 | public class RetryHandler {
13 | private static final int DEFAULT_MAX_RETRIES = 3;
14 | private static final int BASE_BACKOFF_MS = 1000; // 1 second
15 | private static final int MAX_BACKOFF_MS = 90000; // 90 seconds
16 |
17 | private final int maxRetries;
18 | private final Object source; // For logging
19 |
20 | public RetryHandler() {
21 | this(DEFAULT_MAX_RETRIES, null);
22 | }
23 |
24 | public RetryHandler(int maxRetries, Object source) {
25 | this.maxRetries = maxRetries;
26 | this.source = source;
27 | }
28 |
29 | /**
30 | * Execute an operation with retry logic
31 | */
32 | public T executeWithRetry(Supplier operation, String operationName) throws APIProviderException {
33 | return executeWithRetryCallable(() -> operation.get(), operationName);
34 | }
35 |
36 | /**
37 | * Execute a callable operation with retry logic
38 | */
39 | public T executeWithRetryCallable(Callable operation, String operationName) throws APIProviderException {
40 | APIProviderException lastException = null;
41 |
42 | for (int attempt = 1; attempt <= maxRetries; attempt++) {
43 | try {
44 | return operation.call();
45 | } catch (APIProviderException e) {
46 | lastException = e;
47 |
48 | if (!shouldRetry(e, attempt)) {
49 | throw e;
50 | }
51 |
52 | logRetryAttempt(operationName, attempt, e);
53 |
54 | if (attempt < maxRetries) {
55 | waitForRetry(e, attempt);
56 | }
57 | } catch (Exception e) {
58 | // Convert non-API exceptions to APIProviderException
59 | throw new APIProviderException(
60 | APIProviderException.ErrorCategory.SERVICE_ERROR,
61 | "Unknown",
62 | operationName,
63 | "Unexpected error: " + e.getMessage()
64 | );
65 | }
66 | }
67 |
68 | // If we get here, all retries failed
69 | throw lastException;
70 | }
71 |
72 | /**
73 | * Execute an operation with retry logic that doesn't return a value
74 | */
75 | public void executeWithRetryRunnable(Runnable operation, String operationName) throws APIProviderException {
76 | executeWithRetryCallable(() -> {
77 | operation.run();
78 | return null;
79 | }, operationName);
80 | }
81 |
82 | private boolean shouldRetry(APIProviderException e, int attempt) {
83 | // Don't retry if we've exceeded max attempts
84 | if (attempt >= maxRetries) {
85 | return false;
86 | }
87 |
88 | // Check if the error is retryable based on category
89 | switch (e.getCategory()) {
90 | case RATE_LIMIT:
91 | case NETWORK:
92 | case TIMEOUT:
93 | case SERVICE_ERROR:
94 | return true;
95 |
96 | case AUTHENTICATION:
97 | case MODEL_ERROR:
98 | case CONFIGURATION:
99 | case RESPONSE_ERROR:
100 | case CANCELLED:
101 | return false;
102 |
103 | default:
104 | // For unknown errors, check the isRetryable flag
105 | return e.isRetryable();
106 | }
107 | }
108 |
109 | private void waitForRetry(APIProviderException e, int attempt) {
110 | int waitTimeMs = calculateWaitTime(e, attempt);
111 |
112 | if (source != null) {
113 | Msg.info(source, String.format("Waiting %d seconds before retry...", waitTimeMs / 1000 ));
114 | }
115 |
116 | try {
117 | Thread.sleep(waitTimeMs);
118 | } catch (InterruptedException ie) {
119 | Thread.currentThread().interrupt();
120 | throw new RuntimeException("Retry interrupted", ie);
121 | }
122 | }
123 |
124 | private int calculateWaitTime(APIProviderException e, int attempt) {
125 | // For rate limit errors, use the provided retry-after if available
126 | if (e.getCategory() == APIProviderException.ErrorCategory.RATE_LIMIT &&
127 | e.getRetryAfterSeconds() != null) {
128 | return e.getRetryAfterSeconds() * 1000;
129 | }
130 |
131 | // For other errors, use exponential backoff with jitter
132 | int backoffMs = BASE_BACKOFF_MS * (int) Math.pow(2, attempt - 1);
133 |
134 | // Add jitter (±25%)
135 | int jitter = (int) (backoffMs * 0.25 * (Math.random() - 0.5));
136 | backoffMs += jitter;
137 |
138 | // Cap at maximum backoff
139 | return Math.min(backoffMs, MAX_BACKOFF_MS);
140 | }
141 |
142 | private void logRetryAttempt(String operationName, int attempt, APIProviderException e) {
143 | if (source != null) {
144 | String message = String.format(
145 | "Retry attempt %d/%d for %s: %s (%s)",
146 | attempt, maxRetries, operationName,
147 | e.getCategory().getDisplayName(), e.getProviderName()
148 | );
149 | Msg.warn(source, message);
150 | }
151 | }
152 |
153 | /**
154 | * Check if an exception indicates a transient error that might succeed on retry
155 | */
156 | public static boolean isTransientError(Throwable error) {
157 | if (error instanceof APIProviderException) {
158 | APIProviderException ape = (APIProviderException) error;
159 | return ape.isRetryable() || isTransientCategory(ape.getCategory());
160 | }
161 |
162 | // Check for common transient error indicators in message
163 | String message = error.getMessage();
164 | if (message != null) {
165 | message = message.toLowerCase();
166 | return message.contains("timeout") ||
167 | message.contains("connection reset") ||
168 | message.contains("temporary") ||
169 | message.contains("service unavailable");
170 | }
171 |
172 | return false;
173 | }
174 |
175 | private static boolean isTransientCategory(APIProviderException.ErrorCategory category) {
176 | switch (category) {
177 | case RATE_LIMIT:
178 | case NETWORK:
179 | case TIMEOUT:
180 | case SERVICE_ERROR:
181 | return true;
182 | default:
183 | return false;
184 | }
185 | }
186 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/capabilities/ChatProvider.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.capabilities;
2 |
3 | import ghidrassist.LlmApi;
4 | import ghidrassist.apiprovider.ChatMessage;
5 | import ghidrassist.apiprovider.exceptions.APIProviderException;
6 |
7 | import java.util.List;
8 |
9 | /**
10 | * Interface for providers that support basic chat completion.
11 | * This is the core capability that all LLM providers should support.
12 | */
13 | public interface ChatProvider {
14 |
15 | /**
16 | * Create a chat completion (blocking/synchronous)
17 | * @param messages The conversation messages
18 | * @return The completion response
19 | * @throws APIProviderException if the request fails
20 | */
21 | String createChatCompletion(List messages) throws APIProviderException;
22 |
23 | /**
24 | * Stream a chat completion (non-blocking/asynchronous)
25 | * @param messages The conversation messages
26 | * @param handler Handler for streaming response chunks
27 | * @throws APIProviderException if the request fails
28 | */
29 | void streamChatCompletion(List messages, LlmApi.LlmResponseHandler handler) throws APIProviderException;
30 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/capabilities/EmbeddingProvider.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.capabilities;
2 |
3 | import ghidrassist.apiprovider.APIProvider;
4 |
5 | /**
6 | * Interface for providers that support text embeddings.
7 | * Not all providers support this capability.
8 | */
9 | public interface EmbeddingProvider {
10 |
11 | /**
12 | * Generate embeddings for text asynchronously
13 | * @param text The text to embed
14 | * @param callback Callback to handle the embedding result
15 | */
16 | void getEmbeddingsAsync(String text, APIProvider.EmbeddingCallback callback);
17 |
18 | /**
19 | * Check if this provider supports embeddings
20 | * @return true if embeddings are supported
21 | */
22 | default boolean supportsEmbeddings() {
23 | return true;
24 | }
25 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/capabilities/FunctionCallingProvider.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.capabilities;
2 |
3 | import ghidrassist.apiprovider.ChatMessage;
4 | import ghidrassist.apiprovider.exceptions.APIProviderException;
5 |
6 | import java.util.List;
7 | import java.util.Map;
8 |
9 | /**
10 | * Interface for providers that support function calling / tool calling.
11 | * Not all providers support this capability.
12 | */
13 | public interface FunctionCallingProvider {
14 |
15 | /**
16 | * Create a chat completion with function calling support
17 | * @param messages The conversation messages
18 | * @param functions Available functions/tools that the model can call
19 | * @return The completion response, potentially containing function calls
20 | * @throws APIProviderException if the request fails
21 | */
22 | String createChatCompletionWithFunctions(List messages, List> functions) throws APIProviderException;
23 |
24 | /**
25 | * Check if this provider supports function calling
26 | * @return true if function calling is supported
27 | */
28 | default boolean supportsFunctionCalling() {
29 | return true;
30 | }
31 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/capabilities/ModelListProvider.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.capabilities;
2 |
3 | import ghidrassist.apiprovider.exceptions.APIProviderException;
4 |
5 | import java.util.List;
6 |
7 | /**
8 | * Interface for providers that can list available models.
9 | * Not all providers support this capability.
10 | */
11 | public interface ModelListProvider {
12 |
13 | /**
14 | * Get list of available models from this provider
15 | * @return List of model identifiers
16 | * @throws APIProviderException if the request fails
17 | */
18 | List getAvailableModels() throws APIProviderException;
19 |
20 | /**
21 | * Check if this provider supports model listing
22 | * @return true if model listing is supported
23 | */
24 | default boolean supportsModelListing() {
25 | return true;
26 | }
27 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/exceptions/APIProviderException.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.exceptions;
2 |
/**
 * Base exception for all API provider errors with structured error information: category,
 * provider name, operation, HTTP status, API error code, retryability and retry-after hint.
 */
public class APIProviderException extends Exception {
    private final ErrorCategory category;
    private final String providerName;
    private final String operation;
    private final int httpStatusCode;      // -1 when not applicable
    private final String apiErrorCode;     // may be null
    private final boolean isRetryable;
    private final Integer retryAfterSeconds; // may be null

    /** Broad classification of a provider failure, with a display name and user hint. */
    public enum ErrorCategory {
        AUTHENTICATION("Authentication Error", "Check your API key and credentials"),
        NETWORK("Network Error", "Check your internet connection and API URL"),
        RATE_LIMIT("Rate Limit Exceeded", "Too many requests - please wait before retrying"),
        MODEL_ERROR("Model Error", "Issue with the specified model or unsupported feature"),
        CONFIGURATION("Configuration Error", "Invalid settings or configuration"),
        RESPONSE_ERROR("Response Error", "Invalid or unexpected response from API"),
        SERVICE_ERROR("Service Error", "API service is experiencing issues"),
        TIMEOUT("Timeout Error", "Request took too long to complete"),
        CANCELLED("Request Cancelled", "Operation was cancelled");

        private final String displayName;
        private final String description;

        ErrorCategory(String displayName, String description) {
            this.displayName = displayName;
            this.description = description;
        }

        public String getDisplayName() { return displayName; }
        public String getDescription() { return description; }
    }

    /** Creates a non-retryable error without HTTP status, API error code or cause. */
    public APIProviderException(ErrorCategory category, String providerName, String operation,
                                String message) {
        this(category, providerName, operation, -1, null, message, false, null, null);
    }

    /** Creates a non-retryable error with HTTP status and API error code but no cause. */
    public APIProviderException(ErrorCategory category, String providerName, String operation,
                                int httpStatusCode, String apiErrorCode, String message) {
        this(category, providerName, operation, httpStatusCode, apiErrorCode, message, false, null, null);
    }

    /**
     * Full constructor.
     *
     * @param category          error classification (not validated; may be null)
     * @param providerName      provider that failed
     * @param operation         operation being performed
     * @param httpStatusCode    HTTP status, or a value &lt;= 0 when not applicable
     * @param apiErrorCode      provider-specific error code, or null
     * @param message           human-readable message
     * @param isRetryable       whether the failure is flagged as retryable
     * @param retryAfterSeconds server retry-after hint in seconds, or null
     * @param cause             underlying cause, or null
     */
    public APIProviderException(ErrorCategory category, String providerName, String operation,
                                int httpStatusCode, String apiErrorCode, String message,
                                boolean isRetryable, Integer retryAfterSeconds, Throwable cause) {
        super(message, cause);
        this.category = category;
        this.providerName = providerName;
        this.operation = operation;
        this.httpStatusCode = httpStatusCode;
        this.apiErrorCode = apiErrorCode;
        this.isRetryable = isRetryable;
        this.retryAfterSeconds = retryAfterSeconds;
    }

    // Getters
    public ErrorCategory getCategory() { return category; }
    public String getProviderName() { return providerName; }
    public String getOperation() { return operation; }
    public int getHttpStatusCode() { return httpStatusCode; }
    public String getApiErrorCode() { return apiErrorCode; }
    public boolean isRetryable() { return isRetryable; }
    public Integer getRetryAfterSeconds() { return retryAfterSeconds; }

    /**
     * Get technical details for debugging: one "Label: value" line each for provider,
     * operation and category. HTTP status, API error code, message and cause class are added
     * only when present.
     *
     * @return newline-terminated multi-line description
     */
    public String getTechnicalDetails() {
        StringBuilder details = new StringBuilder();
        details.append("Provider: ").append(providerName).append("\n");
        details.append("Operation: ").append(operation).append("\n");
        // The constructors accept a null category; never fail while describing an error.
        details.append("Category: ")
               .append(category != null ? category.getDisplayName() : "Unknown")
               .append("\n");

        if (httpStatusCode > 0) {
            details.append("HTTP Status: ").append(httpStatusCode).append("\n");
        }

        if (apiErrorCode != null && !apiErrorCode.isEmpty()) {
            details.append("API Error Code: ").append(apiErrorCode).append("\n");
        }

        if (getMessage() != null) {
            details.append("Message: ").append(getMessage()).append("\n");
        }

        if (getCause() != null) {
            details.append("Cause: ").append(getCause().getClass().getSimpleName()).append("\n");
        }

        return details.toString();
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/exceptions/AuthenticationException.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.exceptions;
2 |
3 | /**
4 | * Exception for authentication and authorization failures
5 | */
6 | public class AuthenticationException extends APIProviderException {
7 |
8 | public AuthenticationException(String providerName, String operation, String message) {
9 | super(ErrorCategory.AUTHENTICATION, providerName, operation, message);
10 | }
11 |
12 | public AuthenticationException(String providerName, String operation, int httpStatusCode,
13 | String apiErrorCode, String message) {
14 | super(ErrorCategory.AUTHENTICATION, providerName, operation, httpStatusCode, apiErrorCode,
15 | message, false, null, null);
16 | }
17 |
18 | public AuthenticationException(String providerName, String operation, String message, Throwable cause) {
19 | super(ErrorCategory.AUTHENTICATION, providerName, operation, -1, null, message, false, null, cause);
20 | }
21 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/exceptions/ModelException.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.exceptions;
2 |
3 | /**
4 | * Exception for model-related errors
5 | */
6 | public class ModelException extends APIProviderException {
7 |
8 | public enum ModelErrorType {
9 | MODEL_NOT_FOUND("The specified model was not found or is not available"),
10 | UNSUPPORTED_FEATURE("The model does not support this feature"),
11 | CONTEXT_LENGTH_EXCEEDED("Input exceeds the model's maximum context length"),
12 | TOKEN_LIMIT_EXCEEDED("Response would exceed the maximum token limit"),
13 | MODEL_OVERLOADED("The model is currently overloaded");
14 |
15 | private final String description;
16 |
17 | ModelErrorType(String description) {
18 | this.description = description;
19 | }
20 |
21 | public String getDescription() { return description; }
22 | }
23 |
24 | private final ModelErrorType modelErrorType;
25 |
26 | public ModelException(String providerName, String operation, ModelErrorType errorType) {
27 | super(ErrorCategory.MODEL_ERROR, providerName, operation, errorType.getDescription());
28 | this.modelErrorType = errorType;
29 | }
30 |
31 | public ModelException(String providerName, String operation, ModelErrorType errorType,
32 | int httpStatusCode, String apiErrorCode) {
33 | super(ErrorCategory.MODEL_ERROR, providerName, operation, httpStatusCode, apiErrorCode,
34 | errorType.getDescription(), false, null, null);
35 | this.modelErrorType = errorType;
36 | }
37 |
38 | public ModelException(String providerName, String operation, String message) {
39 | super(ErrorCategory.MODEL_ERROR, providerName, operation, message);
40 | this.modelErrorType = null;
41 | }
42 |
43 | public ModelErrorType getModelErrorType() {
44 | return modelErrorType;
45 | }
46 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/exceptions/NetworkException.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.exceptions;
2 |
3 | /**
4 | * Exception for network-related failures
5 | */
6 | public class NetworkException extends APIProviderException {
7 |
8 | public enum NetworkErrorType {
9 | CONNECTION_FAILED("Cannot connect to server"),
10 | TIMEOUT("Request timed out"),
11 | SSL_ERROR("SSL/TLS connection failed"),
12 | DNS_ERROR("Cannot resolve hostname"),
13 | CONNECTION_LOST("Connection was lost during request");
14 |
15 | private final String description;
16 |
17 | NetworkErrorType(String description) {
18 | this.description = description;
19 | }
20 |
21 | public String getDescription() { return description; }
22 | }
23 |
24 | private final NetworkErrorType networkErrorType;
25 |
26 | public NetworkException(String providerName, String operation, NetworkErrorType errorType) {
27 | super(ErrorCategory.NETWORK, providerName, operation, errorType.getDescription());
28 | this.networkErrorType = errorType;
29 | }
30 |
31 | public NetworkException(String providerName, String operation, NetworkErrorType errorType,
32 | Throwable cause) {
33 | super(ErrorCategory.NETWORK, providerName, operation, -1, null, errorType.getDescription(),
34 | true, null, cause);
35 | this.networkErrorType = errorType;
36 | }
37 |
38 | public NetworkException(String providerName, String operation, String message) {
39 | super(ErrorCategory.NETWORK, providerName, operation, message);
40 | this.networkErrorType = null;
41 | }
42 |
43 | public NetworkErrorType getNetworkErrorType() {
44 | return networkErrorType;
45 | }
46 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/exceptions/RateLimitException.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.exceptions;
2 |
3 | /**
4 | * Exception for rate limiting errors
5 | */
6 | public class RateLimitException extends APIProviderException {
7 |
8 | public RateLimitException(String providerName, String operation, Integer retryAfterSeconds) {
9 | super(ErrorCategory.RATE_LIMIT, providerName, operation, 429, "rate_limit_exceeded",
10 | "Rate limit exceeded. Please wait before retrying.", true, retryAfterSeconds, null);
11 | }
12 |
13 | public RateLimitException(String providerName, String operation, String message,
14 | Integer retryAfterSeconds) {
15 | super(ErrorCategory.RATE_LIMIT, providerName, operation, 429, "rate_limit_exceeded",
16 | message, true, retryAfterSeconds, null);
17 | }
18 |
19 | public RateLimitException(String providerName, String operation, int httpStatusCode,
20 | String apiErrorCode, String message, Integer retryAfterSeconds) {
21 | super(ErrorCategory.RATE_LIMIT, providerName, operation, httpStatusCode, apiErrorCode,
22 | message, true, retryAfterSeconds, null);
23 | }
24 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/exceptions/ResponseException.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.exceptions;
2 |
3 | /**
4 | * Exception for response parsing and format errors
5 | */
6 | public class ResponseException extends APIProviderException {
7 |
8 | public enum ResponseErrorType {
9 | MALFORMED_JSON("Response contains invalid JSON"),
10 | MISSING_REQUIRED_FIELD("Required field missing from response"),
11 | UNEXPECTED_FORMAT("Response format is not as expected"),
12 | EMPTY_RESPONSE("Received empty response"),
13 | STREAM_INTERRUPTED("Response stream was interrupted");
14 |
15 | private final String description;
16 |
17 | ResponseErrorType(String description) {
18 | this.description = description;
19 | }
20 |
21 | public String getDescription() { return description; }
22 | }
23 |
24 | private final ResponseErrorType responseErrorType;
25 |
26 | public ResponseException(String providerName, String operation, ResponseErrorType errorType) {
27 | super(ErrorCategory.RESPONSE_ERROR, providerName, operation, errorType.getDescription());
28 | this.responseErrorType = errorType;
29 | }
30 |
31 | public ResponseException(String providerName, String operation, ResponseErrorType errorType,
32 | Throwable cause) {
33 | super(ErrorCategory.RESPONSE_ERROR, providerName, operation, -1, null,
34 | errorType.getDescription(), false, null, cause);
35 | this.responseErrorType = errorType;
36 | }
37 |
38 | public ResponseException(String providerName, String operation, String message) {
39 | super(ErrorCategory.RESPONSE_ERROR, providerName, operation, message);
40 | this.responseErrorType = null;
41 | }
42 |
43 | public ResponseErrorType getResponseErrorType() {
44 | return responseErrorType;
45 | }
46 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/exceptions/StreamCancelledException.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.exceptions;
2 |
3 | /**
4 | * Exception for stream cancellation scenarios
5 | */
6 | public class StreamCancelledException extends APIProviderException {
7 |
8 | public enum CancellationReason {
9 | USER_REQUESTED("User cancelled the request"),
10 | TIMEOUT("Request timed out"),
11 | CONNECTION_LOST("Network connection was lost"),
12 | PROVIDER_ERROR("API provider terminated the stream"),
13 | SHUTDOWN("Application is shutting down");
14 |
15 | private final String description;
16 |
17 | CancellationReason(String description) {
18 | this.description = description;
19 | }
20 |
21 | public String getDescription() { return description; }
22 | }
23 |
24 | private final CancellationReason cancellationReason;
25 |
26 | public StreamCancelledException(String providerName, String operation, CancellationReason reason) {
27 | super(ErrorCategory.CANCELLED, providerName, operation, reason.getDescription());
28 | this.cancellationReason = reason;
29 | }
30 |
31 | public StreamCancelledException(String providerName, String operation, CancellationReason reason,
32 | Throwable cause) {
33 | super(ErrorCategory.CANCELLED, providerName, operation, -1, null, reason.getDescription(),
34 | false, null, cause);
35 | this.cancellationReason = reason;
36 | }
37 |
38 | public CancellationReason getCancellationReason() {
39 | return cancellationReason;
40 | }
41 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/APIProviderFactory.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.factory;
2 |
3 | import ghidrassist.apiprovider.APIProvider;
4 | import ghidrassist.apiprovider.APIProviderConfig;
5 |
6 | /**
7 | * Factory interface for creating API providers.
8 | * Follows the Factory Method pattern to allow extensibility.
9 | */
10 | public interface APIProviderFactory {
11 |
12 | /**
13 | * Create an API provider instance from configuration
14 | * @param config The provider configuration
15 | * @return A configured API provider instance
16 | * @throws UnsupportedProviderException if this factory cannot create the requested provider type
17 | */
18 | APIProvider createProvider(APIProviderConfig config) throws UnsupportedProviderException;
19 |
20 | /**
21 | * Check if this factory supports creating the given provider type
22 | * @param type The provider type to check
23 | * @return true if this factory can create providers of the given type
24 | */
25 | boolean supports(APIProvider.ProviderType type);
26 |
27 | /**
28 | * Get the provider type this factory creates
29 | * @return The provider type
30 | */
31 | APIProvider.ProviderType getProviderType();
32 |
33 | /**
34 | * Get a human-readable name for this factory
35 | * @return Factory name
36 | */
37 | String getFactoryName();
38 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/AnthropicProviderFactory.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.factory;
2 |
3 | import ghidrassist.apiprovider.APIProvider;
4 | import ghidrassist.apiprovider.APIProviderConfig;
5 | import ghidrassist.apiprovider.AnthropicProvider;
6 |
7 | /**
8 | * Factory for creating Anthropic API providers.
9 | */
10 | public class AnthropicProviderFactory implements APIProviderFactory {
11 |
12 | @Override
13 | public APIProvider createProvider(APIProviderConfig config) throws UnsupportedProviderException {
14 | if (!supports(config.getType())) {
15 | throw new UnsupportedProviderException(config.getType(), getFactoryName());
16 | }
17 |
18 | return new AnthropicProvider(
19 | config.getName(),
20 | config.getModel(),
21 | config.getMaxTokens(),
22 | config.getUrl(),
23 | config.getKey(),
24 | config.isDisableTlsVerification(),
25 | config.getTimeout()
26 | );
27 | }
28 |
29 | @Override
30 | public boolean supports(APIProvider.ProviderType type) {
31 | return type == APIProvider.ProviderType.ANTHROPIC;
32 | }
33 |
34 | @Override
35 | public APIProvider.ProviderType getProviderType() {
36 | return APIProvider.ProviderType.ANTHROPIC;
37 | }
38 |
39 | @Override
40 | public String getFactoryName() {
41 | return "AnthropicProviderFactory";
42 | }
43 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/LMStudioProviderFactory.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.factory;
2 |
3 | import ghidrassist.apiprovider.APIProvider;
4 | import ghidrassist.apiprovider.APIProviderConfig;
5 | import ghidrassist.apiprovider.LMStudioProvider;
6 |
7 | /**
8 | * Factory for creating LM Studio API providers.
9 | */
10 | public class LMStudioProviderFactory implements APIProviderFactory {
11 |
12 | @Override
13 | public APIProvider createProvider(APIProviderConfig config) throws UnsupportedProviderException {
14 | if (!supports(config.getType())) {
15 | throw new UnsupportedProviderException(config.getType(), getFactoryName());
16 | }
17 |
18 | return new LMStudioProvider(
19 | config.getName(),
20 | config.getModel(),
21 | config.getMaxTokens(),
22 | config.getUrl(),
23 | config.getKey(),
24 | config.isDisableTlsVerification(),
25 | config.getTimeout()
26 | );
27 | }
28 |
29 | @Override
30 | public boolean supports(APIProvider.ProviderType type) {
31 | return type == APIProvider.ProviderType.LMSTUDIO;
32 | }
33 |
34 | @Override
35 | public APIProvider.ProviderType getProviderType() {
36 | return APIProvider.ProviderType.LMSTUDIO;
37 | }
38 |
39 | @Override
40 | public String getFactoryName() {
41 | return "LMStudioProviderFactory";
42 | }
43 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/OllamaProviderFactory.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.factory;
2 |
3 | import ghidrassist.apiprovider.APIProvider;
4 | import ghidrassist.apiprovider.APIProviderConfig;
5 | import ghidrassist.apiprovider.OllamaProvider;
6 |
7 | /**
8 | * Factory for creating Ollama API providers.
9 | */
10 | public class OllamaProviderFactory implements APIProviderFactory {
11 |
12 | @Override
13 | public APIProvider createProvider(APIProviderConfig config) throws UnsupportedProviderException {
14 | if (!supports(config.getType())) {
15 | throw new UnsupportedProviderException(config.getType(), getFactoryName());
16 | }
17 |
18 | return new OllamaProvider(
19 | config.getName(),
20 | config.getModel(),
21 | config.getMaxTokens(),
22 | config.getUrl(),
23 | config.getKey(),
24 | config.isDisableTlsVerification(),
25 | config.getTimeout()
26 | );
27 | }
28 |
29 | @Override
30 | public boolean supports(APIProvider.ProviderType type) {
31 | return type == APIProvider.ProviderType.OLLAMA;
32 | }
33 |
34 | @Override
35 | public APIProvider.ProviderType getProviderType() {
36 | return APIProvider.ProviderType.OLLAMA;
37 | }
38 |
39 | @Override
40 | public String getFactoryName() {
41 | return "OllamaProviderFactory";
42 | }
43 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/OpenAIProviderFactory.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.factory;
2 |
3 | import ghidrassist.apiprovider.APIProvider;
4 | import ghidrassist.apiprovider.APIProviderConfig;
5 | import ghidrassist.apiprovider.OpenAIProvider;
6 |
7 | /**
8 | * Factory for creating OpenAI API providers.
9 | */
10 | public class OpenAIProviderFactory implements APIProviderFactory {
11 |
12 | @Override
13 | public APIProvider createProvider(APIProviderConfig config) throws UnsupportedProviderException {
14 | if (!supports(config.getType())) {
15 | throw new UnsupportedProviderException(config.getType(), getFactoryName());
16 | }
17 |
18 | return new OpenAIProvider(
19 | config.getName(),
20 | config.getModel(),
21 | config.getMaxTokens(),
22 | config.getUrl(),
23 | config.getKey(),
24 | config.isDisableTlsVerification(),
25 | config.getTimeout()
26 | );
27 | }
28 |
29 | @Override
30 | public boolean supports(APIProvider.ProviderType type) {
31 | return type == APIProvider.ProviderType.OPENAI;
32 | }
33 |
34 | @Override
35 | public APIProvider.ProviderType getProviderType() {
36 | return APIProvider.ProviderType.OPENAI;
37 | }
38 |
39 | @Override
40 | public String getFactoryName() {
41 | return "OpenAIProviderFactory";
42 | }
43 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/OpenWebUiProviderFactory.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.factory;
2 |
3 | import ghidrassist.apiprovider.APIProvider;
4 | import ghidrassist.apiprovider.APIProviderConfig;
5 | import ghidrassist.apiprovider.OpenWebUiProvider;
6 |
7 | /**
8 | * Factory for creating OpenWebUI API providers.
9 | */
10 | public class OpenWebUiProviderFactory implements APIProviderFactory {
11 |
12 | @Override
13 | public APIProvider createProvider(APIProviderConfig config) throws UnsupportedProviderException {
14 | if (!supports(config.getType())) {
15 | throw new UnsupportedProviderException(config.getType(), getFactoryName());
16 | }
17 |
18 | return new OpenWebUiProvider(
19 | config.getName(),
20 | config.getModel(),
21 | config.getMaxTokens(),
22 | config.getUrl(),
23 | config.getKey(),
24 | config.isDisableTlsVerification(),
25 | config.getTimeout()
26 | );
27 | }
28 |
29 | @Override
30 | public boolean supports(APIProvider.ProviderType type) {
31 | return type == APIProvider.ProviderType.OPENWEBUI;
32 | }
33 |
34 | @Override
35 | public APIProvider.ProviderType getProviderType() {
36 | return APIProvider.ProviderType.OPENWEBUI;
37 | }
38 |
39 | @Override
40 | public String getFactoryName() {
41 | return "OpenWebUiProviderFactory";
42 | }
43 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/ProviderRegistry.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.factory;
2 |
3 | import ghidrassist.apiprovider.APIProvider;
4 | import ghidrassist.apiprovider.APIProviderConfig;
5 |
6 | import java.util.*;
7 | import java.util.concurrent.ConcurrentHashMap;
8 |
9 | /**
10 | * Registry for API provider factories.
11 | * Manages the creation of providers through registered factories.
12 | * Thread-safe and follows the Registry pattern.
13 | */
14 | public class ProviderRegistry {
15 |
16 | private final Map factories = new ConcurrentHashMap<>();
17 | private static final ProviderRegistry INSTANCE = new ProviderRegistry();
18 |
19 | /**
20 | * Get the singleton instance of the provider registry
21 | */
22 | public static ProviderRegistry getInstance() {
23 | return INSTANCE;
24 | }
25 |
26 | /**
27 | * Private constructor for singleton
28 | */
29 | private ProviderRegistry() {
30 | // Register default factories
31 | registerDefaultFactories();
32 | }
33 |
34 | /**
35 | * Register a factory for a specific provider type
36 | * @param factory The factory to register
37 | */
38 | public void registerFactory(APIProviderFactory factory) {
39 | if (factory == null) {
40 | throw new IllegalArgumentException("Factory cannot be null");
41 | }
42 |
43 | APIProvider.ProviderType type = factory.getProviderType();
44 | if (type == null) {
45 | throw new IllegalArgumentException("Factory must specify a provider type");
46 | }
47 |
48 | factories.put(type, factory);
49 | }
50 |
51 | /**
52 | * Unregister a factory for a specific provider type
53 | * @param type The provider type to unregister
54 | * @return The previously registered factory, or null if none was registered
55 | */
56 | public APIProviderFactory unregisterFactory(APIProvider.ProviderType type) {
57 | return factories.remove(type);
58 | }
59 |
60 | /**
61 | * Create a provider using the appropriate factory
62 | * @param config The provider configuration
63 | * @return A configured provider instance
64 | * @throws UnsupportedProviderException if no factory is registered for the provider type
65 | */
66 | public APIProvider createProvider(APIProviderConfig config) throws UnsupportedProviderException {
67 | if (config == null) {
68 | throw new IllegalArgumentException("Provider config cannot be null");
69 | }
70 |
71 | APIProvider.ProviderType type = config.getType();
72 | APIProviderFactory factory = factories.get(type);
73 |
74 | if (factory == null) {
75 | throw new UnsupportedProviderException(type, "ProviderRegistry",
76 | "No factory registered for provider type: " + type);
77 | }
78 |
79 | return factory.createProvider(config);
80 | }
81 |
82 | /**
83 | * Check if a provider type is supported
84 | * @param type The provider type to check
85 | * @return true if a factory is registered for this type
86 | */
87 | public boolean isSupported(APIProvider.ProviderType type) {
88 | return factories.containsKey(type);
89 | }
90 |
91 | /**
92 | * Get all supported provider types
93 | * @return Set of supported provider types
94 | */
95 | public Set getSupportedTypes() {
96 | return new HashSet<>(factories.keySet());
97 | }
98 |
99 | /**
100 | * Get all registered factories
101 | * @return Map of provider types to their factories
102 | */
103 | public Map getRegisteredFactories() {
104 | return new HashMap<>(factories);
105 | }
106 |
107 | /**
108 | * Get the factory for a specific provider type
109 | * @param type The provider type
110 | * @return The factory, or null if none is registered
111 | */
112 | public APIProviderFactory getFactory(APIProvider.ProviderType type) {
113 | return factories.get(type);
114 | }
115 |
116 | /**
117 | * Clear all registered factories (mainly for testing)
118 | */
119 | public void clearFactories() {
120 | factories.clear();
121 | }
122 |
123 | /**
124 | * Register the default built-in factories
125 | */
126 | private void registerDefaultFactories() {
127 | registerFactory(new AnthropicProviderFactory());
128 | registerFactory(new OpenAIProviderFactory());
129 | registerFactory(new OllamaProviderFactory());
130 | registerFactory(new LMStudioProviderFactory());
131 | registerFactory(new OpenWebUiProviderFactory());
132 | }
133 |
134 | /**
135 | * Get information about all registered factories
136 | * @return Human-readable string describing registered factories
137 | */
138 | public String getRegistryInfo() {
139 | StringBuilder sb = new StringBuilder();
140 | sb.append("Registered Provider Factories:\n");
141 |
142 | for (Map.Entry entry : factories.entrySet()) {
143 | sb.append(String.format(" %s -> %s\n",
144 | entry.getKey(),
145 | entry.getValue().getFactoryName()));
146 | }
147 |
148 | return sb.toString();
149 | }
150 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/apiprovider/factory/UnsupportedProviderException.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.apiprovider.factory;
2 |
3 | import ghidrassist.apiprovider.APIProvider;
4 |
5 | /**
6 | * Exception thrown when a factory cannot create a requested provider type.
7 | */
8 | public class UnsupportedProviderException extends Exception {
9 |
10 | private final APIProvider.ProviderType requestedType;
11 | private final String factoryName;
12 |
13 | public UnsupportedProviderException(APIProvider.ProviderType requestedType, String factoryName) {
14 | super(String.format("Factory '%s' does not support provider type '%s'", factoryName, requestedType));
15 | this.requestedType = requestedType;
16 | this.factoryName = factoryName;
17 | }
18 |
19 | public UnsupportedProviderException(APIProvider.ProviderType requestedType, String factoryName, String message) {
20 | super(message);
21 | this.requestedType = requestedType;
22 | this.factoryName = factoryName;
23 | }
24 |
25 | public UnsupportedProviderException(APIProvider.ProviderType requestedType, String factoryName, String message, Throwable cause) {
26 | super(message, cause);
27 | this.requestedType = requestedType;
28 | this.factoryName = factoryName;
29 | }
30 |
31 | public APIProvider.ProviderType getRequestedType() {
32 | return requestedType;
33 | }
34 |
35 | public String getFactoryName() {
36 | return factoryName;
37 | }
38 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/core/ActionConstants.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.core;
2 |
3 | import java.util.*;
4 |
/**
 * Static definitions for the LLM "action" tools.
 * <p>
 * {@link #FN_TEMPLATES} holds one tool schema per action, each shaped as
 * {@code {"type": "function", "function": {"name", "description", "parameters"}}}.
 * {@link #ACTION_PROMPTS} holds the matching prompt text, keyed by the same tool name.
 * Each prompt contains a {@code {code}} placeholder (presumably substituted by the
 * caller — confirm at the call site).
 */
public class ActionConstants {

    /** Tool schemas for rename_function, rename_variable, retype_variable and auto_create_struct. */
    public static final List<Map<String, Object>> FN_TEMPLATES = Arrays.asList(
            createFunctionTemplate(
                    "rename_function",
                    "Rename a function",
                    createParameters(
                            createParameter("new_name", "string", "The new name for the function. (e.g., recv_data)")
                    )
            ),
            createFunctionTemplate(
                    "rename_variable",
                    "Rename a variable within a function",
                    createParameters(
                            createParameter("func_name", "string", "The name of the function containing the variable. (e.g., sub_40001234)"),
                            createParameter("var_name", "string", "The current name of the variable. (e.g., var_20)"),
                            createParameter("new_name", "string", "The new name for the variable. (e.g., recv_buf)")
                    )
            ),
            createFunctionTemplate(
                    "retype_variable",
                    "Set a variable data type within a function",
                    createParameters(
                            createParameter("func_name", "string", "The name of the function containing the variable. (e.g., sub_40001234)"),
                            createParameter("var_name", "string", "The current name of the variable. (e.g., rax_12)"),
                            createParameter("new_type", "string", "The new type for the variable. (e.g., int32_t)")
                    )
            ),
            createFunctionTemplate(
                    "auto_create_struct",
                    "Automatically create a structure datatype from a variable given its offset uses in a given function.",
                    createParameters(
                            createParameter("func_name", "string", "The name of the function containing the variable. (e.g., sub_40001234)"),
                            createParameter("var_name", "string", "The current name of the variable. (e.g., rax_12)")
                    )
            )
    );

    /** Prompt text for each action, keyed by tool name. */
    public static final Map<String, String> ACTION_PROMPTS = new HashMap<>();

    static {
        ACTION_PROMPTS.put("rename_function",
                "Use the 'rename_function' tool:\n```\n{code}\n```\n" +
                "Examine the code functionality, strings and log parameters.\n" +
                "If you detect C++ Super::Derived::Method or Class::Method style class names, recommend that name first, OTHERWISE USE PROCEDURAL NAMING.\n" +
                "CREATE A JSON TOOL_CALL LIST WITH SUGGESTIONS FOR THREE POSSIBLE FUNCTION NAMES " +
                "THAT ALIGN AS CLOSELY AS POSSIBLE TO WHAT THE CODE ABOVE DOES.\n" +
                "RESPOND ONLY WITH THE RENAME_FUNCTION PARAMETER (new_name). DO NOT INCLUDE ANY OTHER TEXT.\n" +
                "ALL JSON MUST BE PROPERLY FORMATTED WITH NO EMBEDDED COMMENTS.\n"
        );
        ACTION_PROMPTS.put("rename_variable",
                "Use the 'rename_variable' tool:\n```\n{code}\n```\n" +
                "Examine the code functionality, strings, and log parameters.\n" +
                "SUGGEST VARIABLE NAMES THAT BETTER ALIGN WITH THE CODE FUNCTIONALITY.\n" +
                "RESPOND ONLY WITH THE RENAME_VARIABLE PARAMETERS (func_name, var_name, new_name). DO NOT INCLUDE ANY OTHER TEXT.\n" +
                "ALL JSON VALUES MUST BE TEXT STRINGS, INCLUDING NUMBERS AND ADDRESSES, e.g., \"0x1234abcd\".\n" +
                "ALL JSON MUST BE PROPERLY FORMATTED WITH NO EMBEDDED COMMENTS.\n"
        );
        ACTION_PROMPTS.put("retype_variable",
                "Use the 'retype_variable' tool:\n```\n{code}\n```\n" +
                "Examine the code functionality, strings, and log parameters.\n" +
                "SUGGEST VARIABLE TYPES THAT BETTER ALIGN WITH THE CODE FUNCTIONALITY.\n" +
                "RESPOND ONLY WITH THE RETYPE_VARIABLE PARAMETERS (func_name, var_name, new_type). DO NOT INCLUDE ANY OTHER TEXT.\n" +
                "ALL JSON VALUES MUST BE TEXT STRINGS, INCLUDING NUMBERS AND ADDRESSES, e.g., \"0x1234abcd\".\n" +
                "ALL JSON MUST BE PROPERLY FORMATTED WITH NO EMBEDDED COMMENTS.\n"
        );
        ACTION_PROMPTS.put("auto_create_struct",
                "Use the 'auto_create_struct' tool:\n```\n{code}\n```\n" +
                "Examine the code functionality, parameters, and variables being used.\n" +
                "IF YOU DETECT A VARIABLE THAT USES OFFSET ACCESS SUCH AS `*(arg1 + 0xc)` OR VARIABLES LIKELY TO BE STRUCTURES OR CLASSES,\n" +
                "RESPOND ONLY WITH THE AUTO_CREATE_STRUCT PARAMETERS (func_name, var_name). DO NOT INCLUDE ANY OTHER TEXT.\n" +
                "ALL JSON VALUES MUST BE TEXT STRINGS, INCLUDING NUMBERS AND ADDRESSES, e.g., \"0x1234abcd\".\n" +
                "ALL JSON MUST BE PROPERLY FORMATTED WITH NO EMBEDDED COMMENTS.\n"
        );
    }

    /**
     * Wraps a function description in a {@code {"type": "function", "function": {...}}} map.
     *
     * @param name        tool name
     * @param description human-readable description of the tool
     * @param parameters  parameter schema built by {@link #createParameters}
     * @return the tool template
     */
    private static Map<String, Object> createFunctionTemplate(String name, String description,
            Map<String, Object> parameters) {
        // Insertion-ordered so serialized output lists keys in declaration order.
        Map<String, Object> functionMap = new LinkedHashMap<>();
        functionMap.put("name", name);
        functionMap.put("description", description);
        functionMap.put("parameters", parameters);

        Map<String, Object> template = new LinkedHashMap<>();
        template.put("type", "function");
        template.put("function", functionMap);
        return template;
    }

    /**
     * Builds an {@code "object"} parameter schema in which every given parameter is required.
     *
     * @param parameters parameter maps built by {@link #createParameter}; each must have a "name"
     * @return map with "type", "properties" (name to parameter map) and "required" (names in order)
     */
    @SafeVarargs // the array is only iterated, never stored or exposed
    private static Map<String, Object> createParameters(Map<String, Object>... parameters) {
        Map<String, Object> parametersMap = new LinkedHashMap<>();
        parametersMap.put("type", "object");

        // LinkedHashMap keeps property order aligned with the "required" list.
        Map<String, Object> properties = new LinkedHashMap<>();
        List<String> required = new ArrayList<>();

        for (Map<String, Object> param : parameters) {
            String name = (String) param.get("name");
            properties.put(name, param);
            required.add(name);
        }

        parametersMap.put("properties", properties);
        parametersMap.put("required", required);
        return parametersMap;
    }

    /**
     * Builds a single parameter description.
     *
     * @param name        parameter name
     * @param type        JSON type name (e.g. "string")
     * @param description human-readable description
     * @return map with "name", "type" and "description" entries
     */
    private static Map<String, Object> createParameter(String name, String type, String description) {
        Map<String, Object> param = new LinkedHashMap<>();
        param.put("name", name);
        param.put("type", type);
        param.put("description", description);
        return param;
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/core/CodeUtils.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.core;
2 |
3 | import ghidra.app.decompiler.ClangNode;
4 | import ghidra.app.decompiler.ClangToken;
5 | import ghidra.app.decompiler.ClangTokenGroup;
6 | import ghidra.app.decompiler.DecompInterface;
7 | import ghidra.app.decompiler.DecompileResults;
8 | import ghidra.program.model.address.Address;
9 | import ghidra.program.model.listing.Function;
10 | import ghidra.program.model.listing.Instruction;
11 | import ghidra.program.model.listing.InstructionIterator;
12 | import ghidra.program.model.listing.Listing;
13 | import ghidra.program.model.listing.Program;
14 | import ghidra.util.task.TaskMonitor;
15 |
16 | public class CodeUtils {
17 |
18 | /**
19 | * Gets the decompiled code for a function.
20 | * @param function The function to decompile
21 | * @param monitor Task monitor for tracking progress
22 | * @return The decompiled code as a string
23 | */
24 | public static String getFunctionCode(Function function, TaskMonitor monitor) {
25 | DecompInterface decompiler = new DecompInterface();
26 | decompiler.openProgram(function.getProgram());
27 |
28 | try {
29 | DecompileResults results = decompiler.decompileFunction(function, 60, monitor);
30 | if (results != null && results.decompileCompleted()) {
31 | return results.getDecompiledFunction().getC();
32 | } else {
33 | return "Failed to decompile function.";
34 | }
35 | } catch (Exception e) {
36 | return "Failed to decompile function: " + e.getMessage();
37 | } finally {
38 | decompiler.dispose();
39 | }
40 | }
41 |
42 | /**
43 | * Gets the disassembly for a function.
44 | * @param function The function to disassemble
45 | * @return The disassembled code as a string
46 | */
47 | public static String getFunctionDisassembly(Function function) {
48 | StringBuilder sb = new StringBuilder();
49 | Listing listing = function.getProgram().getListing();
50 | InstructionIterator instructions = listing.getInstructions(function.getBody(), true);
51 |
52 | while (instructions.hasNext()) {
53 | Instruction instr = instructions.next();
54 | sb.append(formatInstruction(instr)).append("\n");
55 | }
56 |
57 | return sb.toString();
58 | }
59 |
60 | /**
61 | * Gets the decompiled code for a specific line at an address.
62 | * @param address The address to get code for
63 | * @param monitor Task monitor for tracking progress
64 | * @param program The current program
65 | * @return The decompiled line as a string
66 | */
67 | public static String getLineCode(Address address, TaskMonitor monitor, Program program) {
68 | DecompInterface decompiler = new DecompInterface();
69 | decompiler.openProgram(program);
70 |
71 | try {
72 | Function function = program.getFunctionManager().getFunctionContaining(address);
73 | if (function == null) {
74 | return "No function containing the address.";
75 | }
76 |
77 | DecompileResults results = decompiler.decompileFunction(function, 60, monitor);
78 | if (results != null && results.decompileCompleted()) {
79 | ClangTokenGroup tokens = results.getCCodeMarkup();
80 | if (tokens != null) {
81 | StringBuilder codeLineBuilder = new StringBuilder();
82 | boolean found = collectCodeLine(tokens, address, codeLineBuilder);
83 | if (found && codeLineBuilder.length() > 0) {
84 | return codeLineBuilder.toString();
85 | } else {
86 | return "No code line found at the address.";
87 | }
88 | } else {
89 | return "Failed to get code tokens.";
90 | }
91 | } else {
92 | return "Failed to decompile function.";
93 | }
94 | } catch (Exception e) {
95 | return "Failed to decompile line: " + e.getMessage();
96 | } finally {
97 | decompiler.dispose();
98 | }
99 | }
100 |
101 | /**
102 | * Gets the disassembly for a specific address.
103 | * @param address The address to get disassembly for
104 | * @param program The current program
105 | * @return The disassembled instruction as a string
106 | */
107 | public static String getLineDisassembly(Address address, Program program) {
108 | Instruction instruction = program.getListing().getInstructionAt(address);
109 | if (instruction != null) {
110 | return formatInstruction(instruction);
111 | }
112 | return null;
113 | }
114 |
115 | /**
116 | * Collects all code tokens for a specific line containing an address.
117 | * @param node The current ClangNode
118 | * @param address The target address
119 | * @param codeLineBuilder StringBuilder to collect the code
120 | * @return true if the address was found and code collected
121 | */
122 | private static boolean collectCodeLine(ClangNode node, Address address, StringBuilder codeLineBuilder) {
123 | if (node instanceof ClangToken) {
124 | ClangToken token = (ClangToken) node;
125 | if (token.getMinAddress() != null && token.getMaxAddress() != null) {
126 | if (token.getMinAddress().compareTo(address) <= 0 &&
127 | token.getMaxAddress().compareTo(address) >= 0) {
128 | // Found the token corresponding to the address
129 | ClangNode parent = token.Parent();
130 | if (parent != null) {
131 | for (int i = 0; i < parent.numChildren(); i++) {
132 | ClangNode sibling = parent.Child(i);
133 | if (sibling instanceof ClangToken) {
134 | codeLineBuilder.append(((ClangToken) sibling).getText());
135 | }
136 | }
137 | } else {
138 | codeLineBuilder.append(token.getText());
139 | }
140 | return true;
141 | }
142 | }
143 | } else if (node instanceof ClangTokenGroup) {
144 | ClangTokenGroup group = (ClangTokenGroup) node;
145 | for (int i = 0; i < group.numChildren(); i++) {
146 | ClangNode child = group.Child(i);
147 | boolean found = collectCodeLine(child, address, codeLineBuilder);
148 | if (found) {
149 | return true;
150 | }
151 | }
152 | }
153 | return false;
154 | }
155 |
156 | /**
157 | * Formats an instruction with its address and representation.
158 | * @param instruction The instruction to format
159 | * @return A formatted string representation of the instruction
160 | */
161 | private static String formatInstruction(Instruction instruction) {
162 | return String.format("%s %s",
163 | instruction.getAddressString(true, true),
164 | instruction.toString());
165 | }
166 | }
167 |
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/core/LlmApiClient.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.core;
2 |
3 | import ghidrassist.AnalysisDB;
4 | import ghidrassist.GhidrAssistPlugin;
5 | import ghidrassist.LlmApi;
6 | import ghidrassist.apiprovider.APIProvider;
7 | import ghidrassist.apiprovider.APIProviderConfig;
8 | import ghidrassist.apiprovider.ChatMessage;
9 | import ghidrassist.apiprovider.exceptions.APIProviderException;
10 |
11 | import java.util.ArrayList;
12 | import java.util.List;
13 | import java.util.Map;
14 |
15 | /**
16 | * Handles API provider management and low-level API calls.
17 | * Focused solely on provider configuration and basic API interactions.
18 | */
19 | public class LlmApiClient {
20 | private APIProvider provider;
21 | private final AnalysisDB analysisDB;
22 | private final GhidrAssistPlugin plugin;
23 |
24 | private final String DEFAULT_SYSTEM_PROMPT =
25 | "You are a professional software reverse engineer specializing in cybersecurity. You are intimately \n"
26 | + "familiar with x86_64, ARM, PPC and MIPS architectures. You are an expert C and C++ developer.\n"
27 | + "You are an expert Python and Rust developer. You are familiar with common frameworks and libraries \n"
28 | + "such as WinSock, OpenSSL, MFC, etc. You are an expert in TCP/IP network programming and packet analysis.\n"
29 | + "You always respond to queries in a structured format using Markdown styling for headings and lists. \n"
30 | + "You format code blocks using back-tick code-fencing.\n";
31 |
32 | private final String FUNCTION_PROMPT = "USE THE PROVIDED TOOLS WHEN NECESSARY. YOU ALWAYS RESPOND WITH TOOL CALLS WHEN POSSIBLE.";
33 |
34 | public LlmApiClient(APIProviderConfig config, GhidrAssistPlugin plugin) {
35 | this.provider = config.createProvider();
36 | this.analysisDB = new AnalysisDB();
37 | this.plugin = plugin;
38 |
39 | // Get the global API timeout and set it if the provider doesn't have one
40 | if (provider != null && provider.getTimeout() == null) {
41 | Integer timeout = GhidrAssistPlugin.getGlobalApiTimeout();
42 | provider.setTimeout(timeout);
43 | }
44 | }
45 |
46 | public String getSystemPrompt() {
47 | return this.DEFAULT_SYSTEM_PROMPT;
48 | }
49 |
50 | public GhidrAssistPlugin getPlugin() {
51 | return plugin;
52 | }
53 |
54 | public String getCurrentContext() {
55 | if (plugin.getCurrentProgram() != null) {
56 | String programHash = plugin.getCurrentProgram().getExecutableSHA256();
57 | String context = analysisDB.getContext(programHash);
58 | if (context != null) {
59 | return context;
60 | }
61 | }
62 | return DEFAULT_SYSTEM_PROMPT;
63 | }
64 |
65 | /**
66 | * Create messages for regular chat completion
67 | */
68 | public List createChatMessages(String prompt) {
69 | String systemUser = ChatMessage.ChatMessageRole.SYSTEM;
70 | if (isO1OrO3Model()) {
71 | systemUser = ChatMessage.ChatMessageRole.USER;
72 | }
73 |
74 | List messages = new ArrayList<>();
75 | messages.add(new ChatMessage(systemUser, getCurrentContext()));
76 | messages.add(new ChatMessage(ChatMessage.ChatMessageRole.USER, prompt));
77 | return messages;
78 | }
79 |
80 | /**
81 | * Create messages for function calling
82 | */
83 | public List createFunctionMessages(String prompt) {
84 | String systemUser = ChatMessage.ChatMessageRole.SYSTEM;
85 | if (isO1OrO3Model()) {
86 | systemUser = ChatMessage.ChatMessageRole.USER;
87 | }
88 |
89 | List messages = new ArrayList<>();
90 | messages.add(new ChatMessage(ChatMessage.ChatMessageRole.USER, prompt));
91 | return messages;
92 | }
93 |
94 | /**
95 | * Check if the current model is O1 or O3 series (which handle system prompts differently)
96 | */
97 | private boolean isO1OrO3Model() {
98 | return provider != null && (
99 | provider.getModel().startsWith("o1-") ||
100 | provider.getModel().startsWith("o3-") ||
101 | provider.getModel().startsWith("o4-")
102 | );
103 | }
104 |
105 | /**
106 | * Stream chat completion
107 | */
108 | public void streamChatCompletion(List messages, LlmApi.LlmResponseHandler handler)
109 | throws APIProviderException {
110 | if (provider == null) {
111 | throw new IllegalStateException("LLM provider is not initialized.");
112 | }
113 | provider.streamChatCompletion(messages, handler);
114 | }
115 |
116 | /**
117 | * Create chat completion with functions
118 | */
119 | public String createChatCompletionWithFunctions(List messages, List> functions)
120 | throws APIProviderException {
121 | if (provider == null) {
122 | throw new IllegalStateException("LLM provider is not initialized.");
123 | }
124 | return provider.createChatCompletionWithFunctions(messages, functions);
125 | }
126 |
127 | /**
128 | * Create chat completion with functions - returns full response including finish_reason
129 | */
130 | public String createChatCompletionWithFunctionsFullResponse(List messages, List> functions)
131 | throws APIProviderException {
132 | if (provider == null) {
133 | throw new IllegalStateException("LLM provider is not initialized.");
134 | }
135 | return provider.createChatCompletionWithFunctionsFullResponse(messages, functions);
136 | }
137 |
138 | /**
139 | * Check if provider is available
140 | */
141 | public boolean isProviderAvailable() {
142 | return provider != null;
143 | }
144 |
145 | /**
146 | * Get provider name for logging/error handling
147 | */
148 | public String getProviderName() {
149 | return provider != null ? provider.getName() : "Unknown";
150 | }
151 |
152 | /**
153 | * Get provider model for logging/error handling
154 | */
155 | public String getProviderModel() {
156 | return provider != null ? provider.getModel() : "Unknown";
157 | }
158 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/core/LlmErrorHandler.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.core;
2 |
3 | import ghidra.util.Msg;
4 | import ghidrassist.GhidrAssistPlugin;
5 | import ghidrassist.apiprovider.APIProviderLogger;
6 | import ghidrassist.apiprovider.ErrorAction;
7 | import ghidrassist.apiprovider.exceptions.APIProviderException;
8 | import ghidrassist.apiprovider.exceptions.StreamCancelledException;
9 | import ghidrassist.ui.EnhancedErrorDialog;
10 |
11 | import java.util.ArrayList;
12 | import java.util.List;
13 |
14 | /**
15 | * Handles error processing, logging, and user interaction for LLM operations.
16 | * Focused solely on error handling logic and user feedback.
17 | */
18 | public class LlmErrorHandler {
19 |
20 | private final GhidrAssistPlugin plugin;
21 | private final Object source;
22 |
23 | public LlmErrorHandler(GhidrAssistPlugin plugin, Object source) {
24 | this.plugin = plugin;
25 | this.source = source;
26 | }
27 |
28 | /**
29 | * Handle an error with enhanced error dialogs and logging
30 | */
31 | public void handleError(Throwable error, String operation, Runnable retryAction) {
32 | if (error instanceof APIProviderException) {
33 | APIProviderException ape = (APIProviderException) error;
34 |
35 | // Log the error with structured information
36 | APIProviderLogger.logError(source, ape);
37 |
38 | // Skip showing error dialog for cancellations unless it's unexpected
39 | if (shouldSkipErrorDialog(ape)) {
40 | return;
41 | }
42 |
43 | // Create appropriate error actions
44 | List actions = createErrorActions(ape, retryAction);
45 |
46 | // Show enhanced error dialog
47 | java.awt.Window parentWindow = getParentWindow();
48 | EnhancedErrorDialog.showError(parentWindow, ape, actions);
49 |
50 | } else {
51 | // Handle non-API provider exceptions (fallback)
52 | handleGenericError(error, operation);
53 | }
54 | }
55 |
56 | /**
57 | * Handle generic (non-API provider) errors
58 | */
59 | private void handleGenericError(Throwable error, String operation) {
60 | String message = error.getMessage() != null ? error.getMessage() : error.getClass().getSimpleName();
61 | Msg.showError(source, null, "Unexpected Error",
62 | "An unexpected error occurred during " + operation + ": " + message);
63 |
64 | // Log the error
65 | Msg.error(source, "Unexpected error during " + operation, error);
66 | }
67 |
68 | /**
69 | * Determine if error dialog should be skipped for certain cancellation types
70 | */
71 | private boolean shouldSkipErrorDialog(APIProviderException ape) {
72 | if (ape.getCategory() == APIProviderException.ErrorCategory.CANCELLED) {
73 | if (ape instanceof StreamCancelledException) {
74 | StreamCancelledException sce = (StreamCancelledException) ape;
75 | if (sce.getCancellationReason() == StreamCancelledException.CancellationReason.USER_REQUESTED) {
76 | return true; // Don't show dialog for user-requested cancellations
77 | }
78 | }
79 | }
80 | return false;
81 | }
82 |
83 | /**
84 | * Create appropriate error actions based on the exception type
85 | */
86 | private List createErrorActions(APIProviderException ape, Runnable retryAction) {
87 | List actions = new ArrayList<>();
88 |
89 | // Add retry action for retryable errors
90 | if (ape.isRetryable() && retryAction != null) {
91 | actions.add(ErrorAction.createRetryAction(retryAction));
92 | }
93 |
94 | // Add settings action for configuration-related errors
95 | if (isConfigurationError(ape)) {
96 | actions.add(ErrorAction.createSettingsAction(() -> openSettings()));
97 | }
98 |
99 | // Add provider switching action for persistent errors
100 | APIProviderLogger.ErrorStats stats = APIProviderLogger.getErrorStats(ape.getProviderName());
101 | if (stats != null && stats.isFrequentErrorsDetected()) {
102 | actions.add(ErrorAction.createSwitchProviderAction(() -> suggestProviderSwitch()));
103 | }
104 |
105 | // Add copy error details action
106 | actions.add(ErrorAction.createCopyErrorAction(ape.getTechnicalDetails()));
107 |
108 | // Add dismiss action
109 | actions.add(ErrorAction.createDismissAction());
110 |
111 | return actions;
112 | }
113 |
114 | /**
115 | * Check if error is configuration-related
116 | */
117 | private boolean isConfigurationError(APIProviderException ape) {
118 | return ape.getCategory() == APIProviderException.ErrorCategory.AUTHENTICATION ||
119 | ape.getCategory() == APIProviderException.ErrorCategory.CONFIGURATION ||
120 | ape.getCategory() == APIProviderException.ErrorCategory.MODEL_ERROR;
121 | }
122 |
123 | /**
124 | * Get the parent window for error dialogs
125 | */
126 | private java.awt.Window getParentWindow() {
127 | try {
128 | // Try to get the main Ghidra window
129 | if (plugin != null && plugin.getTool() != null) {
130 | return plugin.getTool().getToolFrame();
131 | }
132 | } catch (Exception e) {
133 | // Ignore errors getting parent window
134 | }
135 | return null;
136 | }
137 |
138 | /**
139 | * Open the settings dialog
140 | */
141 | private void openSettings() {
142 | try {
143 | if (plugin != null) {
144 | // This would typically call the plugin's settings dialog
145 | // The actual implementation depends on how settings are accessed
146 | Msg.showInfo(source, null, "Settings",
147 | "Please go to Tools -> GhidrAssist Settings to configure API providers.");
148 | }
149 | } catch (Exception e) {
150 | Msg.showError(source, null, "Error", "Could not open settings: " + e.getMessage());
151 | }
152 | }
153 |
154 | /**
155 | * Suggest switching to a different provider
156 | */
157 | private void suggestProviderSwitch() {
158 | try {
159 | // Generate a simple suggestion message
160 | StringBuilder suggestion = new StringBuilder();
161 | suggestion.append("Current provider is experiencing frequent errors.\n\n");
162 | suggestion.append("Consider switching to a different provider in Settings.\n\n");
163 | suggestion.append("Provider Error Statistics:\n");
164 | suggestion.append(APIProviderLogger.generateDiagnosticsReport());
165 |
166 | Msg.showInfo(source, null, "Provider Reliability", suggestion.toString());
167 | } catch (Exception e) {
168 | Msg.showError(source, null, "Error", "Could not generate provider statistics: " + e.getMessage());
169 | }
170 | }
171 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/core/LlmTaskExecutor.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.core;
2 |
3 | import ghidrassist.LlmApi;
4 | import ghidrassist.apiprovider.exceptions.APIProviderException;
5 |
6 | import javax.swing.SwingWorker;
7 | import java.util.List;
8 | import java.util.Map;
9 | import java.util.concurrent.atomic.AtomicBoolean;
10 |
11 | /**
12 | * Handles background task execution for LLM operations.
13 | * Focused on managing SwingWorker tasks and request lifecycle.
14 | */
15 | public class LlmTaskExecutor {
16 |
17 | private final Object streamLock = new Object();
18 | private volatile boolean isStreaming = false;
19 | private final AtomicBoolean shouldCancel = new AtomicBoolean(false);
20 |
21 | /**
22 | * Execute a streaming chat request in the background
23 | */
24 | public void executeStreamingRequest(
25 | LlmApiClient client,
26 | String prompt,
27 | ResponseProcessor responseProcessor,
28 | LlmResponseHandler responseHandler) {
29 |
30 | if (!client.isProviderAvailable()) {
31 | responseHandler.onError(new IllegalStateException("LLM provider is not initialized."));
32 | return;
33 | }
34 |
35 | // Cancel any existing stream
36 | cancelCurrentRequest();
37 | shouldCancel.set(false);
38 |
39 | try {
40 | synchronized (streamLock) {
41 | isStreaming = true;
42 | ResponseProcessor.StreamingResponseFilter filter = responseProcessor.createStreamingFilter();
43 |
44 | client.streamChatCompletion(client.createChatMessages(prompt), new LlmApi.LlmResponseHandler() {
45 | private boolean isFirst = true;
46 |
47 | @Override
48 | public void onStart() {
49 | if (isFirst && shouldCancel.get() == false) {
50 | responseHandler.onStart();
51 | isFirst = false;
52 | }
53 | }
54 |
55 | @Override
56 | public void onUpdate(String partialResponse) {
57 | if (shouldCancel.get()) {
58 | return;
59 | }
60 | String filteredContent = filter.processChunk(partialResponse);
61 | if (filteredContent != null && !filteredContent.isEmpty()) {
62 | responseHandler.onUpdate(filteredContent);
63 | }
64 | }
65 |
66 | @Override
67 | public void onComplete(String fullResponse) {
68 | synchronized (streamLock) {
69 | isStreaming = false;
70 | }
71 | if (!shouldCancel.get()) {
72 | responseHandler.onComplete(filter.getFilteredContent());
73 | }
74 | }
75 |
76 | @Override
77 | public void onError(Throwable error) {
78 | synchronized (streamLock) {
79 | isStreaming = false;
80 | }
81 | if (!shouldCancel.get()) {
82 | responseHandler.onError(error);
83 | }
84 | }
85 |
86 | @Override
87 | public boolean shouldContinue() {
88 | return !shouldCancel.get() && responseHandler.shouldContinue();
89 | }
90 | });
91 | }
92 | } catch (Exception e) {
93 | synchronized (streamLock) {
94 | isStreaming = false;
95 | }
96 | if (!shouldCancel.get()) {
97 | responseHandler.onError(e);
98 | }
99 | }
100 | }
101 |
102 | /**
103 | * Execute a function calling request in the background
104 | */
105 | public void executeFunctionRequest(
106 | LlmApiClient client,
107 | String prompt,
108 | List> functions,
109 | ResponseProcessor responseProcessor,
110 | LlmResponseHandler responseHandler) {
111 |
112 | if (!client.isProviderAvailable()) {
113 | responseHandler.onError(new IllegalStateException("LLM provider is not initialized."));
114 | return;
115 | }
116 |
117 | shouldCancel.set(false);
118 |
119 | // Create a background task
120 | SwingWorker worker = new SwingWorker<>() {
121 | @Override
122 | protected Void doInBackground() {
123 | try {
124 | synchronized (streamLock) {
125 | isStreaming = true;
126 | }
127 |
128 | if (!shouldCancel.get()) {
129 | responseHandler.onStart();
130 | String response = client.createChatCompletionWithFunctions(
131 | client.createFunctionMessages(prompt), functions);
132 |
133 | if (!shouldCancel.get() && responseHandler.shouldContinue()) {
134 | String filteredResponse = responseProcessor.filterThinkBlocks(response);
135 | responseHandler.onComplete(filteredResponse);
136 | }
137 | }
138 | } catch (APIProviderException e) {
139 | if (!shouldCancel.get() && responseHandler.shouldContinue()) {
140 | responseHandler.onError(e);
141 | }
142 | } finally {
143 | synchronized (streamLock) {
144 | isStreaming = false;
145 | }
146 | }
147 | return null;
148 | }
149 |
150 | @Override
151 | protected void done() {
152 | try {
153 | get(); // Check for exceptions
154 | } catch (Exception e) {
155 | if (!shouldCancel.get() && responseHandler.shouldContinue()) {
156 | responseHandler.onError(e);
157 | }
158 | }
159 | }
160 | };
161 |
162 | worker.execute();
163 | }
164 |
165 | /**
166 | * Cancel the current request
167 | */
168 | public void cancelCurrentRequest() {
169 | shouldCancel.set(true);
170 | synchronized (streamLock) {
171 | isStreaming = false;
172 | }
173 | }
174 |
175 | /**
176 | * Check if currently streaming
177 | */
178 | public boolean isStreaming() {
179 | synchronized (streamLock) {
180 | return isStreaming;
181 | }
182 | }
183 |
184 | /**
185 | * Interface for handling LLM responses
186 | */
187 | public interface LlmResponseHandler {
188 | void onStart();
189 | void onUpdate(String partialResponse);
190 | void onComplete(String fullResponse);
191 | void onError(Throwable error);
192 | default boolean shouldContinue() {
193 | return true;
194 | }
195 | }
196 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/core/MarkdownHelper.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.core;
2 |
3 | import com.vladsch.flexmark.html.HtmlRenderer;
4 | import com.vladsch.flexmark.html2md.converter.FlexmarkHtmlConverter;
5 | import com.vladsch.flexmark.parser.Parser;
6 | import com.vladsch.flexmark.util.ast.Document;
7 | import com.vladsch.flexmark.util.data.MutableDataSet;
8 |
9 | import java.util.regex.Matcher;
10 | import java.util.regex.Pattern;
11 |
/**
 * Converts between Markdown and HTML for the GhidrAssist UI using flexmark:
 * Markdown to HTML via Parser/HtmlRenderer, and HTML to Markdown via FlexmarkHtmlConverter.
 *
 * NOTE(review): several string literals in this copy look as if their markup was stripped.
 * Examples are the wrapper {@code "" + html + ""}, {@code contains("")}, the
 * {@code "(?i)|||"} regex, and a regex literal in removeFeedbackButtons that spans two lines
 * and would not compile. Confirm these against the real source before relying on the
 * behaviour described here.
 */
public class MarkdownHelper {
    private final Parser parser;
    private final HtmlRenderer renderer;
    private final FlexmarkHtmlConverter htmlToMdConverter;

    /**
     * Build the flexmark parser, renderer and HTML-to-Markdown converter once; every
     * conversion reuses them.
     */
    public MarkdownHelper() {
        MutableDataSet options = new MutableDataSet();
        // Override how flexmark renders soft line breaks (HtmlRenderer.SOFT_BREAK option)
        options.set(HtmlRenderer.SOFT_BREAK, " \n");
        this.parser = Parser.builder(options).build();
        this.renderer = HtmlRenderer.builder(options).build();
        this.htmlToMdConverter = FlexmarkHtmlConverter.builder().build();
    }

    /**
     * Convert Markdown text to HTML for display.
     * Intended to include feedback buttons in the HTML output.
     *
     * NOTE(review): as written here, feedbackLinks is a single space, so this copy emits no
     * buttons. Verify the literal.
     *
     * @param markdown The markdown text to convert
     * @return HTML representation of the markdown, or "" when markdown is null
     */
    public String markdownToHtml(String markdown) {
        if (markdown == null) {
            return "";
        }

        Document document = parser.parse(markdown);
        String html = renderer.render(document);

        // Add feedback buttons
        String feedbackLinks = " ";

        return "" + html + feedbackLinks + "";
    }

    /**
     * Convert Markdown text to HTML without adding feedback buttons.
     * Used for preview or when feedback isn't needed.
     *
     * @param markdown The markdown text to convert
     * @return HTML representation of the markdown, or "" when markdown is null
     */
    public String markdownToHtmlSimple(String markdown) {
        if (markdown == null) {
            return "";
        }

        Document document = parser.parse(markdown);
        String html = renderer.render(document);

        return "" + html + "";
    }

    /**
     * Convert HTML to Markdown.
     * Feedback buttons and html/body wrapper tags are removed first, then flexmark converts
     * the remainder.
     *
     * @param html The HTML to convert
     * @return Markdown representation of the HTML, or "" for null/empty input
     */
    public String htmlToMarkdown(String html) {
        if (html == null || html.isEmpty()) {
            return "";
        }

        // Remove feedback buttons if present
        html = removeFeedbackButtons(html);

        // Remove html wrapper tags if present
        html = removeHtmlWrapperTags(html);

        // Use flexmark converter for the HTML to Markdown conversion
        return htmlToMdConverter.convert(html);
    }

    /**
     * Extract markdown from a response that might be in various formats.
     * Responses that look like HTML are converted with htmlToMarkdown; everything else is
     * returned unchanged.
     *
     * NOTE(review): contains("") is always true, so as written every non-empty response is
     * treated as HTML. toLowerCase() also uses the default locale (Locale.ROOT is safer).
     *
     * @param response The response to extract markdown from
     * @return Extracted markdown content, or "" for null/empty input
     */
    public String extractMarkdownFromLlmResponse(String response) {
        if (response == null || response.isEmpty()) {
            return "";
        }

        // Check if it's HTML
        if (response.toLowerCase().contains("") || response.toLowerCase().contains("")) {
            return htmlToMarkdown(response);
        }

        // Otherwise, assume it's already markdown or plain text
        return response;
    }

    /**
     * Remove feedback buttons from an HTML string.
     * NOTE(review): the Pattern is compiled on every call; a static final Pattern would avoid
     * that. The literal below is broken across two lines in this copy.
     */
    private String removeFeedbackButtons(String html) {
        // Pattern to match the feedback buttons div
        Pattern feedbackPattern = Pattern.compile(".*?
");
        Matcher matcher = feedbackPattern.matcher(html);
        return matcher.replaceAll("");
    }

    /**
     * Remove HTML and BODY wrapper tags (case-insensitive).
     * NOTE(review): as written, "(?i)|||" only matches empty strings, so this is a no-op.
     */
    private String removeHtmlWrapperTags(String html) {
        return html.replaceAll("(?i)|||", "");
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/core/QueryProcessor.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.core;
2 |
3 | import java.util.List;
4 | import java.util.regex.Pattern;
5 | import java.util.regex.Matcher;
6 |
7 | import ghidra.program.model.address.Address;
8 | import ghidra.program.model.address.AddressFactory;
9 | import ghidra.program.model.listing.Function;
10 | import ghidra.program.model.listing.Program;
11 | import ghidra.util.task.TaskMonitor;
12 | import ghidrassist.GhidrAssistPlugin;
13 | import ghidrassist.GhidrAssistPlugin.CodeViewType;
14 |
15 | public class QueryProcessor {
16 |
17 | private static final Pattern RANGE_PATTERN = Pattern.compile("#range\\(([^,]+),\\s*([^\\)]+)\\)");
18 | private static final int MAX_SEARCH_RESULTS = 5;
19 |
20 | /**
21 | * Process all macros in the query and replace them with actual content.
22 | * @param query The original query containing macros
23 | * @param plugin The GhidrAssist plugin instance
24 | * @return Processed query with macros replaced
25 | */
26 | public static String processMacrosInQuery(String query, GhidrAssistPlugin plugin) {
27 | String processedQuery = query;
28 |
29 | try {
30 | CodeViewType viewType = plugin.checkLastActiveCodeView();
31 | TaskMonitor monitor = TaskMonitor.DUMMY;
32 |
33 | // Process #line macro
34 | if (processedQuery.contains("#line")) {
35 | String codeLine = getCurrentLine(plugin, viewType, monitor);
36 | if (codeLine != null) {
37 | processedQuery = processedQuery.replace("#line", codeLine);
38 | }
39 | }
40 |
41 | // Process #func macro
42 | if (processedQuery.contains("#func")) {
43 | String functionCode = getCurrentFunction(plugin, viewType, monitor);
44 | if (functionCode != null) {
45 | processedQuery = processedQuery.replace("#func", functionCode);
46 | }
47 | }
48 |
49 | // Process #addr macro
50 | if (processedQuery.contains("#addr")) {
51 | String addressString = getCurrentAddress(plugin);
52 | processedQuery = processedQuery.replace("#addr", addressString);
53 | }
54 |
55 | // Process #range macros
56 | processedQuery = processRangeMacros(processedQuery, plugin);
57 |
58 | } catch (Exception e) {
59 | throw new RuntimeException("Failed to process macros: " + e.getMessage(), e);
60 | }
61 |
62 | return processedQuery;
63 | }
64 |
65 | /**
66 | * Append RAG context to the query based on similarity search.
67 | * @param query The original query
68 | * @return Query with RAG context prepended
69 | * @throws Exception if RAG search fails
70 | */
71 | public static String appendRAGContext(String query) throws Exception {
72 | List results = RAGEngine.hybridSearch(query, MAX_SEARCH_RESULTS);
73 | if (results.isEmpty()) {
74 | return query;
75 | }
76 |
77 | StringBuilder contextBuilder = new StringBuilder();
78 | contextBuilder.append("\n");
79 |
80 | for (SearchResult result : results) {
81 | contextBuilder.append("\n");
82 | contextBuilder.append("").append(result.getFilename()).append(" \n");
83 | contextBuilder.append("").append(result.getChunkId()).append(" \n");
84 | contextBuilder.append("").append(result.getScore()).append(" \n");
85 | contextBuilder.append("\n").append(result.getSnippet()).append("\n \n");
86 | contextBuilder.append("\n \n\n");
87 | }
88 |
89 | contextBuilder.append("\n \n");
90 | return contextBuilder.toString() + query;
91 | }
92 |
93 | /**
94 | * Get the current line based on view type.
95 | */
96 | private static String getCurrentLine(GhidrAssistPlugin plugin, CodeViewType viewType, TaskMonitor monitor) {
97 | Address currentAddress = plugin.getCurrentAddress();
98 | if (currentAddress == null) {
99 | return "No current address available.";
100 | }
101 |
102 | if (viewType == CodeViewType.IS_DECOMPILER) {
103 | return CodeUtils.getLineCode(currentAddress, monitor, plugin.getCurrentProgram());
104 | } else if (viewType == CodeViewType.IS_DISASSEMBLER) {
105 | return CodeUtils.getLineDisassembly(currentAddress, plugin.getCurrentProgram());
106 | }
107 |
108 | return "Unknown code view type.";
109 | }
110 |
111 | /**
112 | * Get the current function based on view type.
113 | */
114 | private static String getCurrentFunction(GhidrAssistPlugin plugin, CodeViewType viewType, TaskMonitor monitor) {
115 | Function currentFunction = plugin.getCurrentFunction();
116 | if (currentFunction == null) {
117 | return "No function at current location.";
118 | }
119 |
120 | if (viewType == CodeViewType.IS_DECOMPILER) {
121 | return CodeUtils.getFunctionCode(currentFunction, monitor);
122 | } else if (viewType == CodeViewType.IS_DISASSEMBLER) {
123 | return CodeUtils.getFunctionDisassembly(currentFunction);
124 | }
125 |
126 | return "Unknown code view type.";
127 | }
128 |
129 | /**
130 | * Get the current address as a string.
131 | */
132 | private static String getCurrentAddress(GhidrAssistPlugin plugin) {
133 | Address currentAddress = plugin.getCurrentAddress();
134 | return (currentAddress != null) ? currentAddress.toString() : "No address available.";
135 | }
136 |
137 | /**
138 | * Process all #range macros in the query.
139 | */
140 | private static String processRangeMacros(String query, GhidrAssistPlugin plugin) {
141 | Matcher matcher = RANGE_PATTERN.matcher(query);
142 | while (matcher.find()) {
143 | String startStr = matcher.group(1);
144 | String endStr = matcher.group(2);
145 | String rangeData = getRangeData(startStr.trim(), endStr.trim(), plugin);
146 | query = query.replace(matcher.group(0), rangeData);
147 | matcher = RANGE_PATTERN.matcher(query);
148 | }
149 | return query;
150 | }
151 |
152 | /**
153 | * Get the data for a specific address range.
154 | */
155 | private static String getRangeData(String startStr, String endStr, GhidrAssistPlugin plugin) {
156 | try {
157 | Program program = plugin.getCurrentProgram();
158 | if (program == null) {
159 | return "No program loaded.";
160 | }
161 |
162 | AddressFactory addressFactory = program.getAddressFactory();
163 | Address startAddr = addressFactory.getAddress(startStr);
164 | Address endAddr = addressFactory.getAddress(endStr);
165 |
166 | if (startAddr == null || endAddr == null) {
167 | return "Invalid addresses.";
168 | }
169 |
170 | // Get the bytes in the range
171 | long size = endAddr.getOffset() - startAddr.getOffset() + 1;
172 | if (size <= 0 || size > 1024) { // Limit to reasonable size
173 | return "Invalid range size.";
174 | }
175 |
176 | byte[] bytes = new byte[(int) size];
177 | program.getMemory().getBytes(startAddr, bytes);
178 |
179 | // Convert bytes to hex string
180 | StringBuilder sb = new StringBuilder();
181 | for (byte b : bytes) {
182 | sb.append(String.format("%02X ", b));
183 | }
184 | return sb.toString().trim();
185 |
186 | } catch (Exception e) {
187 | return "Failed to get range data: " + e.getMessage();
188 | }
189 | }
190 | }
191 |
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/core/ResponseProcessor.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.core;
2 |
3 | import java.util.regex.Pattern;
4 |
/**
 * Handles response processing including streaming filters and thinking block removal.
 * Focused solely on text processing and filtering logic.
 *
 * NOTE(review): the tag literals in this copy appear stripped. The pattern reads
 * {@code ".*? "}, and the streaming filter tests {@code startsWith("", i)}, which is always
 * true. The +7/+8 index arithmetic below suggests 7- and 8-character opening/closing tags
 * (presumably {@code <think>} / {@code </think>}). Confirm against the real source.
 */
public class ResponseProcessor {

    // Intended to match one complete thinking block (opening tag through closing tag,
    // non-greedy); DOTALL lets '.' span newlines.
    private static final Pattern COMPLETE_THINK_PATTERN = Pattern.compile(".*? ", Pattern.DOTALL);

    /**
     * Create a new streaming filter for processing chunks.
     * Each call returns an independent filter with its own state.
     */
    public StreamingResponseFilter createStreamingFilter() {
        return new StreamingResponseFilter();
    }

    /**
     * Filter thinking blocks from a complete response.
     *
     * @param response full response text; may be null
     * @return the response with every match of COMPLETE_THINK_PATTERN removed and the result
     *         trimmed, or null when response is null
     */
    public String filterThinkBlocks(String response) {
        if (response == null) {
            return null;
        }
        return COMPLETE_THINK_PATTERN.matcher(response).replaceAll("").trim();
    }

    /**
     * Streaming filter that processes chunks of text and removes thinking blocks as they arrive.
     * Holds mutable per-stream state; no synchronization is visible here.
     */
    public static class StreamingResponseFilter {
        // Unprocessed text carried over between chunks (content after an unmatched opening tag)
        private StringBuilder buffer = new StringBuilder();
        // All text seen so far that lies outside thinking blocks
        private StringBuilder visibleBuffer = new StringBuilder();
        // True between an opening tag and its closing tag
        private boolean insideThinkBlock = false;

        /**
         * Process a chunk of streaming text, filtering out thinking blocks.
         *
         * Text outside thinking blocks is moved into visibleBuffer. Text following an opening
         * tag whose closing tag has not arrived yet stays in buffer and is rescanned together
         * with the next chunk.
         *
         * NOTE(review): the return value is the whole visible text accumulated so far, not just
         * what this chunk contributed. Confirm that callers expect cumulative content.
         * NOTE(review): outside a thinking block, everything after the last tag is flushed to
         * visibleBuffer and buffer is cleared. An opening tag split across two chunks would
         * therefore not be recognised.
         *
         * @param chunk The text chunk to process
         * @return The accumulated filtered content, or null if chunk is null
         */
        public String processChunk(String chunk) {
            if (chunk == null) {
                return null;
            }

            buffer.append(chunk);

            // Scan the carried-over text plus the new chunk for tag boundaries
            String currentBuffer = buffer.toString();
            int lastSafeIndex = 0;

            for (int i = 0; i < currentBuffer.length(); i++) {
                // Opening tag: flush preceding visible text, then enter the thinking block
                if (!insideThinkBlock && currentBuffer.startsWith("", i)) {
                    visibleBuffer.append(currentBuffer.substring(lastSafeIndex, i));
                    insideThinkBlock = true;
                    lastSafeIndex = i + 7; // first index after the (7-character) opening tag
                    i += 6; // loop increment adds the 7th
                }
                // Closing tag: leave the thinking block; its content is never copied out
                else if (insideThinkBlock && currentBuffer.startsWith(" ", i)) {
                    insideThinkBlock = false;
                    lastSafeIndex = i + 8; // first index after the (8-character) closing tag
                    i += 7; // loop increment adds the 8th
                }
            }

            // Outside a block: everything after the last tag is visible, nothing to carry over
            if (!insideThinkBlock) {
                visibleBuffer.append(currentBuffer.substring(lastSafeIndex));
                buffer.setLength(0);
            } else {
                // Inside a block: keep the unmatched tail so the closing tag can be found later
                buffer = new StringBuilder(currentBuffer.substring(lastSafeIndex));
            }

            return visibleBuffer.toString();
        }

        /**
         * Get the complete filtered (visible) content processed so far.
         */
        public String getFilteredContent() {
            return visibleBuffer.toString();
        }

        /**
         * Reset the filter state for reuse: clears both buffers and leaves any thinking block.
         */
        public void reset() {
            buffer.setLength(0);
            visibleBuffer.setLength(0);
            insideThinkBlock = false;
        }

        /**
         * Check if the filter is currently inside a thinking block.
         */
        public boolean isInsideThinkBlock() {
            return insideThinkBlock;
        }
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/core/UIState.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.core;
2 |
/**
 * Shared UI bookkeeping: whether a query is in flight and how many runners are active.
 * Every accessor synchronizes on this instance. The runner count never drops below zero,
 * and reaching zero also clears the query-running flag.
 */
public class UIState {
    private volatile boolean isQueryRunning = false;
    private int activeRunners = 0;

    /** Start idle: no query running and no active runners. */
    public UIState() {
        // Fields are initialized at their declarations.
    }

    /** @return whether a query is currently marked as running */
    public synchronized boolean isQueryRunning() {
        return isQueryRunning;
    }

    /** Mark a query as running (true) or finished (false). */
    public synchronized void setQueryRunning(boolean running) {
        isQueryRunning = running;
    }

    /** Record that one more runner has started. */
    public synchronized void incrementRunners() {
        activeRunners += 1;
    }

    /**
     * Record that a runner has finished. When no runners remain (or the counter was already
     * zero), the count is clamped at zero and the query is marked as no longer running.
     */
    public synchronized void decrementRunners() {
        int remaining = activeRunners - 1;
        if (remaining > 0) {
            activeRunners = remaining;
            return;
        }
        activeRunners = 0;
        isQueryRunning = false;
    }

    /** @return the number of runners currently active */
    public synchronized int getActiveRunners() {
        return activeRunners;
    }
}
36 |
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/mcp2/protocol/MCPMessage.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.mcp2.protocol;
2 |
3 | import com.google.gson.JsonObject;
4 | import com.google.gson.JsonElement;
5 |
6 | /**
7 | * Base class for all MCP JSON-RPC 2.0 messages.
8 | * Implements the core JSON-RPC 2.0 message structure.
9 | */
10 | public abstract class MCPMessage {
11 |
12 | public static final String JSONRPC_VERSION = "2.0";
13 |
14 | protected String jsonrpc = JSONRPC_VERSION;
15 | protected String method;
16 | protected JsonObject params;
17 |
18 | public MCPMessage(String method) {
19 | this.method = method;
20 | this.params = new JsonObject();
21 | }
22 |
23 | public MCPMessage(String method, JsonObject params) {
24 | this.method = method;
25 | this.params = params != null ? params : new JsonObject();
26 | }
27 |
28 | public String getJsonrpc() {
29 | return jsonrpc;
30 | }
31 |
32 | public String getMethod() {
33 | return method;
34 | }
35 |
36 | public JsonObject getParams() {
37 | return params;
38 | }
39 |
40 | public void setParam(String key, String value) {
41 | params.addProperty(key, value);
42 | }
43 |
44 | public void setParam(String key, JsonElement value) {
45 | params.add(key, value);
46 | }
47 |
48 | public void setParam(String key, Number value) {
49 | params.addProperty(key, value);
50 | }
51 |
52 | public void setParam(String key, Boolean value) {
53 | params.addProperty(key, value);
54 | }
55 |
56 | /**
57 | * Convert message to JSON string for transmission
58 | */
59 | public abstract String toJson();
60 |
61 | /**
62 | * Validate message format
63 | */
64 | public boolean isValid() {
65 | return JSONRPC_VERSION.equals(jsonrpc) && method != null && !method.isEmpty();
66 | }
67 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/mcp2/protocol/MCPRequest.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.mcp2.protocol;
2 |
3 | import com.google.gson.Gson;
4 | import com.google.gson.JsonObject;
5 | import com.google.gson.JsonElement;
6 |
7 | /**
8 | * Represents an MCP JSON-RPC 2.0 request message.
9 | * Used for tool discovery (tools/list) and tool execution (tools/call).
10 | */
11 | public class MCPRequest extends MCPMessage {
12 |
13 | private final Object id;
14 |
15 | public MCPRequest(Object id, String method) {
16 | super(method);
17 | this.id = id;
18 | }
19 |
20 | public MCPRequest(Object id, String method, JsonObject params) {
21 | super(method, params);
22 | this.id = id;
23 | }
24 |
25 | public Object getId() {
26 | return id;
27 | }
28 |
29 | @Override
30 | public String toJson() {
31 | JsonObject json = new JsonObject();
32 | json.addProperty("jsonrpc", jsonrpc);
33 |
34 | // Only add ID if this is not a notification
35 | if (!"notification".equals(id)) {
36 | json.addProperty("id", id.toString());
37 | }
38 |
39 | json.addProperty("method", method);
40 |
41 | if (params != null && params.size() > 0) {
42 | json.add("params", params);
43 | }
44 |
45 | return new Gson().toJson(json);
46 | }
47 |
48 | @Override
49 | public boolean isValid() {
50 | return super.isValid() && id != null;
51 | }
52 |
53 | /**
54 | * Create a tools/list request
55 | */
56 | public static MCPRequest createToolsListRequest(Object id) {
57 | return createToolsListRequest(id, null);
58 | }
59 |
60 | /**
61 | * Create a tools/list request with cursor for pagination
62 | */
63 | public static MCPRequest createToolsListRequest(Object id, String cursor) {
64 | MCPRequest request = new MCPRequest(id, "tools/list");
65 | if (cursor != null) {
66 | request.setParam("cursor", cursor);
67 | }
68 | return request;
69 | }
70 |
71 | /**
72 | * Create a tools/call request
73 | */
74 | public static MCPRequest createToolsCallRequest(Object id, String toolName, JsonObject arguments) {
75 | MCPRequest request = new MCPRequest(id, "tools/call");
76 | request.setParam("name", toolName);
77 | if (arguments != null) {
78 | request.setParam("arguments", arguments);
79 | }
80 | return request;
81 | }
82 |
83 | /**
84 | * Create an initialize request for protocol handshake
85 | */
86 | public static MCPRequest createInitializeRequest(Object id, String protocolVersion, String clientInfo) {
87 | MCPRequest request = new MCPRequest(id, "initialize");
88 | request.setParam("protocolVersion", protocolVersion);
89 |
90 | JsonObject clientInfoObj = new JsonObject();
91 | clientInfoObj.addProperty("name", "GhidrAssist");
92 | clientInfoObj.addProperty("version", "1.0.0");
93 | if (clientInfo != null) {
94 | clientInfoObj.addProperty("description", clientInfo);
95 | }
96 | request.setParam("clientInfo", clientInfoObj);
97 |
98 | JsonObject capabilities = new JsonObject();
99 | capabilities.addProperty("tools", true);
100 | request.setParam("capabilities", capabilities);
101 |
102 | return request;
103 | }
104 |
105 | /**
106 | * Create an initialized notification (sent after initialize response)
107 | */
108 | public static MCPRequest createInitializedNotification() {
109 | // Notifications don't have IDs in JSON-RPC 2.0
110 | // The bridge expects "notifications/initialized" method
111 | return new MCPRequest("notification", "notifications/initialized");
112 | }
113 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/mcp2/protocol/MCPResponse.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.mcp2.protocol;
2 |
3 | import com.google.gson.Gson;
4 | import com.google.gson.JsonObject;
5 | import com.google.gson.JsonElement;
6 | import com.google.gson.JsonArray;
7 |
8 | /**
9 | * Represents an MCP JSON-RPC 2.0 response message.
10 | * Can contain either a result (success) or an error (failure).
11 | */
12 | public class MCPResponse {
13 |
14 | private String jsonrpc = MCPMessage.JSONRPC_VERSION;
15 | private Object id;
16 | private JsonElement result;
17 | private MCPError error;
18 |
19 | public MCPResponse(Object id) {
20 | this.id = id;
21 | }
22 |
23 | public String getJsonrpc() {
24 | return jsonrpc;
25 | }
26 |
27 | public Object getId() {
28 | return id;
29 | }
30 |
31 | public JsonElement getResult() {
32 | return result;
33 | }
34 |
35 | public void setResult(JsonElement result) {
36 | this.result = result;
37 | this.error = null; // Clear error if setting result
38 | }
39 |
40 | public MCPError getError() {
41 | return error;
42 | }
43 |
44 | public void setError(MCPError error) {
45 | this.error = error;
46 | this.result = null; // Clear result if setting error
47 | }
48 |
49 | public void setError(int code, String message, JsonElement data) {
50 | setError(new MCPError(code, message, data));
51 | }
52 |
53 | public boolean isSuccess() {
54 | return result != null && error == null;
55 | }
56 |
57 | public boolean isError() {
58 | return error != null;
59 | }
60 |
61 | public String toJson() {
62 | JsonObject json = new JsonObject();
63 | json.addProperty("jsonrpc", jsonrpc);
64 | json.addProperty("id", id.toString());
65 |
66 | if (result != null) {
67 | json.add("result", result);
68 | } else if (error != null) {
69 | json.add("error", error.toJson());
70 | }
71 |
72 | return new Gson().toJson(json);
73 | }
74 |
75 | /**
76 | * Parse response from JSON string
77 | */
78 | public static MCPResponse fromJson(String jsonStr) {
79 | Gson gson = new Gson();
80 | JsonObject json = gson.fromJson(jsonStr, JsonObject.class);
81 |
82 | Object id = json.has("id") ? json.get("id").getAsString() : null;
83 | MCPResponse response = new MCPResponse(id);
84 |
85 | if (json.has("result")) {
86 | response.setResult(json.get("result"));
87 | } else if (json.has("error")) {
88 | JsonObject errorObj = json.getAsJsonObject("error");
89 | int code = errorObj.get("code").getAsInt();
90 | String message = errorObj.get("message").getAsString();
91 | JsonElement data = errorObj.has("data") ? errorObj.get("data") : null;
92 | response.setError(code, message, data);
93 | }
94 |
95 | return response;
96 | }
97 |
98 | /**
99 | * Extract tools array from tools/list response
100 | */
101 | public JsonArray getToolsArray() {
102 | if (isSuccess() && result.isJsonObject()) {
103 | JsonObject resultObj = result.getAsJsonObject();
104 | if (resultObj.has("tools") && resultObj.get("tools").isJsonArray()) {
105 | return resultObj.getAsJsonArray("tools");
106 | }
107 | }
108 | return new JsonArray();
109 | }
110 |
111 | /**
112 | * Extract tool call result content
113 | */
114 | public JsonElement getToolCallResult() {
115 | if (isSuccess() && result.isJsonObject()) {
116 | JsonObject resultObj = result.getAsJsonObject();
117 | if (resultObj.has("content")) {
118 | return resultObj.get("content");
119 | }
120 | }
121 | return result;
122 | }
123 |
124 | /**
125 | * Get cursor for pagination in tools/list response
126 | */
127 | public String getNextCursor() {
128 | if (isSuccess() && result.isJsonObject()) {
129 | JsonObject resultObj = result.getAsJsonObject();
130 | if (resultObj.has("nextCursor")) {
131 | return resultObj.get("nextCursor").getAsString();
132 | }
133 | }
134 | return null;
135 | }
136 |
137 | /**
138 | * Inner class representing JSON-RPC 2.0 error object
139 | */
140 | public static class MCPError {
141 | private int code;
142 | private String message;
143 | private JsonElement data;
144 |
145 | public MCPError(int code, String message, JsonElement data) {
146 | this.code = code;
147 | this.message = message;
148 | this.data = data;
149 | }
150 |
151 | public int getCode() {
152 | return code;
153 | }
154 |
155 | public String getMessage() {
156 | return message;
157 | }
158 |
159 | public JsonElement getData() {
160 | return data;
161 | }
162 |
163 | public JsonObject toJson() {
164 | JsonObject json = new JsonObject();
165 | json.addProperty("code", code);
166 | json.addProperty("message", message);
167 | if (data != null) {
168 | json.add("data", data);
169 | }
170 | return json;
171 | }
172 |
173 | @Override
174 | public String toString() {
175 | return String.format("MCPError{code=%d, message='%s'}", code, message);
176 | }
177 | }
178 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/mcp2/server/MCPServerConfig.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.mcp2.server;
2 |
3 | import com.google.gson.Gson;
4 | import com.google.gson.JsonObject;
5 |
6 | /**
7 | * Configuration for an MCP server connection.
8 | * Stores all necessary information to connect to and manage an MCP server.
9 | */
10 | public class MCPServerConfig {
11 |
12 | public enum TransportType {
13 | SSE("Server-Sent Events"),
14 | STDIO("Standard I/O");
15 |
16 | private final String displayName;
17 |
18 | TransportType(String displayName) {
19 | this.displayName = displayName;
20 | }
21 |
22 | public String getDisplayName() {
23 | return displayName;
24 | }
25 | }
26 |
27 | private String name; // Display name (e.g., "GhidraMCP Local")
28 | private String url; // Server URL (e.g., "http://localhost:8081")
29 | private TransportType transport; // Transport mechanism
30 | private int connectionTimeout; // Connection timeout in seconds
31 | private int requestTimeout; // Request timeout in seconds
32 | private boolean enabled; // Whether this server is active
33 | private String description; // Optional description
34 |
35 | // Default constructor for JSON deserialization
36 | public MCPServerConfig() {
37 | this.transport = TransportType.SSE;
38 | this.connectionTimeout = 5; // Reduced from 10 seconds
39 | this.requestTimeout = 15; // Reduced from 30 seconds
40 | this.enabled = true;
41 | }
42 |
43 | public MCPServerConfig(String name, String url) {
44 | this();
45 | this.name = name;
46 | this.url = url;
47 | }
48 |
49 | public MCPServerConfig(String name, String url, TransportType transport) {
50 | this(name, url);
51 | this.transport = transport;
52 | }
53 |
54 | public MCPServerConfig(String name, String url, TransportType transport, boolean enabled) {
55 | this(name, url, transport);
56 | this.enabled = enabled;
57 | }
58 |
59 | // Getters and setters
60 |
61 | public String getName() {
62 | return name;
63 | }
64 |
65 | public void setName(String name) {
66 | this.name = name;
67 | }
68 |
69 | public String getUrl() {
70 | return url;
71 | }
72 |
73 | public void setUrl(String url) {
74 | this.url = url;
75 | }
76 |
77 | public TransportType getTransport() {
78 | return transport;
79 | }
80 |
81 | public void setTransport(TransportType transport) {
82 | this.transport = transport;
83 | }
84 |
85 | public int getConnectionTimeout() {
86 | return connectionTimeout;
87 | }
88 |
89 | public void setConnectionTimeout(int connectionTimeout) {
90 | this.connectionTimeout = connectionTimeout;
91 | }
92 |
93 | public int getRequestTimeout() {
94 | return requestTimeout;
95 | }
96 |
97 | public void setRequestTimeout(int requestTimeout) {
98 | this.requestTimeout = requestTimeout;
99 | }
100 |
101 | public boolean isEnabled() {
102 | return enabled;
103 | }
104 |
105 | public void setEnabled(boolean enabled) {
106 | this.enabled = enabled;
107 | }
108 |
109 | public String getDescription() {
110 | return description;
111 | }
112 |
113 | public void setDescription(String description) {
114 | this.description = description;
115 | }
116 |
117 | /**
118 | * Get the base URL for HTTP connections
119 | */
120 | public String getBaseUrl() {
121 | if (url == null) return null;
122 |
123 | // Ensure URL has protocol
124 | if (!url.startsWith("http://") && !url.startsWith("https://")) {
125 | return "http://" + url;
126 | }
127 | return url;
128 | }
129 |
130 | /**
131 | * Get the host from the URL
132 | */
133 | public String getHost() {
134 | try {
135 | java.net.URL urlObj = new java.net.URL(getBaseUrl());
136 | return urlObj.getHost();
137 | } catch (Exception e) {
138 | return "localhost";
139 | }
140 | }
141 |
142 | /**
143 | * Get the port from the URL
144 | */
145 | public int getPort() {
146 | try {
147 | java.net.URL urlObj = new java.net.URL(getBaseUrl());
148 | int port = urlObj.getPort();
149 | return port != -1 ? port : (urlObj.getProtocol().equals("https") ? 443 : 80);
150 | } catch (Exception e) {
151 | return 8081; // Default MCP port
152 | }
153 | }
154 |
155 | /**
156 | * Validate configuration
157 | */
158 | public boolean isValid() {
159 | return name != null && !name.trim().isEmpty() &&
160 | url != null && !url.trim().isEmpty() &&
161 | transport != null &&
162 | connectionTimeout > 0 &&
163 | requestTimeout > 0;
164 | }
165 |
166 | /**
167 | * Create a copy of this configuration
168 | */
169 | public MCPServerConfig copy() {
170 | MCPServerConfig copy = new MCPServerConfig(name, url, transport);
171 | copy.setConnectionTimeout(connectionTimeout);
172 | copy.setRequestTimeout(requestTimeout);
173 | copy.setEnabled(enabled);
174 | copy.setDescription(description);
175 | return copy;
176 | }
177 |
178 | /**
179 | * Serialize to JSON
180 | */
181 | public String toJson() {
182 | return new Gson().toJson(this);
183 | }
184 |
185 | /**
186 | * Deserialize from JSON
187 | */
188 | public static MCPServerConfig fromJson(String json) {
189 | return new Gson().fromJson(json, MCPServerConfig.class);
190 | }
191 |
192 | @Override
193 | public String toString() {
194 | return String.format("%s (%s) - %s", name, transport.getDisplayName(),
195 | enabled ? "Enabled" : "Disabled");
196 | }
197 |
198 | @Override
199 | public boolean equals(Object obj) {
200 | if (this == obj) return true;
201 | if (obj == null || getClass() != obj.getClass()) return false;
202 |
203 | MCPServerConfig that = (MCPServerConfig) obj;
204 | return name != null ? name.equals(that.name) : that.name == null;
205 | }
206 |
207 | @Override
208 | public int hashCode() {
209 | return name != null ? name.hashCode() : 0;
210 | }
211 |
212 | /**
213 | * Create default GhidraMCP configuration
214 | */
215 | public static MCPServerConfig createGhidraMCPDefault() {
216 | MCPServerConfig config = new MCPServerConfig("GhidraMCP Local", "http://localhost:8081");
217 | config.setDescription("Local GhidraMCP server instance");
218 | config.setTransport(TransportType.SSE);
219 | return config;
220 | }
221 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/mcp2/tools/MCPTool.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.mcp2.tools;
2 |
3 | import com.google.gson.JsonObject;
4 | import com.google.gson.JsonElement;
5 | import java.util.Map;
6 | import java.util.HashMap;
7 |
8 | /**
9 | * Represents an MCP tool discovered from a server.
10 | * This is server-agnostic and follows the MCP specification.
11 | */
12 | public class MCPTool {
13 |
14 | private final String name;
15 | private final String description;
16 | private final JsonObject inputSchema;
17 | private final String serverName; // Which server provides this tool
18 |
19 | public MCPTool(String name, String description, JsonObject inputSchema, String serverName) {
20 | this.name = name;
21 | this.description = description;
22 | this.inputSchema = inputSchema;
23 | this.serverName = serverName;
24 | }
25 |
26 | public String getName() {
27 | return name;
28 | }
29 |
30 | public String getDescription() {
31 | return description;
32 | }
33 |
34 | public JsonObject getInputSchema() {
35 | return inputSchema;
36 | }
37 |
38 | public String getServerName() {
39 | return serverName;
40 | }
41 |
42 | /**
43 | * Check if this tool has input parameters
44 | */
45 | public boolean hasInputSchema() {
46 | return inputSchema != null && inputSchema.size() > 0;
47 | }
48 |
49 | /**
50 | * Convert to function schema for LLM function calling
51 | * This follows the OpenAI function calling format
52 | */
53 | public Map toFunctionSchema() {
54 | Map function = new HashMap<>();
55 | function.put("type", "function");
56 |
57 | Map functionDef = new HashMap<>();
58 | functionDef.put("name", name);
59 | functionDef.put("description", description);
60 |
61 | if (inputSchema != null) {
62 | // Convert JsonObject to Map for compatibility
63 | functionDef.put("parameters", jsonObjectToMap(inputSchema));
64 | } else {
65 | // Empty parameters schema
66 | Map emptyParams = new HashMap<>();
67 | emptyParams.put("type", "object");
68 | emptyParams.put("properties", new HashMap<>());
69 | functionDef.put("parameters", emptyParams);
70 | }
71 |
72 | function.put("function", functionDef);
73 | return function;
74 | }
75 |
76 | /**
77 | * Create MCPTool from MCP tools/list response
78 | */
79 | public static MCPTool fromToolsListEntry(JsonObject toolEntry, String serverName) {
80 | String name = toolEntry.has("name") ? toolEntry.get("name").getAsString() : null;
81 | String description = toolEntry.has("description") ? toolEntry.get("description").getAsString() : "";
82 | JsonObject inputSchema = toolEntry.has("inputSchema") ?
83 | toolEntry.getAsJsonObject("inputSchema") : null;
84 |
85 | return new MCPTool(name, description, inputSchema, serverName);
86 | }
87 |
88 | /**
89 | * Helper method to convert JsonObject to Map recursively
90 | */
91 | private Map jsonObjectToMap(JsonObject jsonObject) {
92 | Map map = new HashMap<>();
93 |
94 | for (String key : jsonObject.keySet()) {
95 | Object value = jsonElementToObject(jsonObject.get(key));
96 | map.put(key, value);
97 | }
98 |
99 | return map;
100 | }
101 |
102 | /**
103 | * Helper method to convert JsonElement to Java object
104 | */
105 | private Object jsonElementToObject(JsonElement element) {
106 | if (element.isJsonPrimitive()) {
107 | if (element.getAsJsonPrimitive().isString()) {
108 | return element.getAsString();
109 | } else if (element.getAsJsonPrimitive().isNumber()) {
110 | return element.getAsNumber();
111 | } else if (element.getAsJsonPrimitive().isBoolean()) {
112 | return element.getAsBoolean();
113 | }
114 | } else if (element.isJsonObject()) {
115 | return jsonObjectToMap(element.getAsJsonObject());
116 | } else if (element.isJsonArray()) {
117 | com.google.gson.JsonArray array = element.getAsJsonArray();
118 | java.util.List list = new java.util.ArrayList<>();
119 | for (int i = 0; i < array.size(); i++) {
120 | list.add(jsonElementToObject(array.get(i)));
121 | }
122 | return list;
123 | }
124 | return null;
125 | }
126 |
127 | @Override
128 | public String toString() {
129 | return String.format("MCPTool{name='%s', server='%s', description='%s'}",
130 | name, serverName, description);
131 | }
132 |
133 | @Override
134 | public boolean equals(Object obj) {
135 | if (this == obj) return true;
136 | if (obj == null || getClass() != obj.getClass()) return false;
137 |
138 | MCPTool mcpTool = (MCPTool) obj;
139 | return name.equals(mcpTool.name) && serverName.equals(mcpTool.serverName);
140 | }
141 |
142 | @Override
143 | public int hashCode() {
144 | return java.util.Objects.hash(name, serverName);
145 | }
146 |
147 | /**
148 | * Get a display name that includes the server
149 | */
150 | public String getDisplayName() {
151 | return String.format("%s (%s)", name, serverName);
152 | }
153 |
154 | /**
155 | * Check if tool name matches (case-insensitive)
156 | */
157 | public boolean matchesName(String toolName) {
158 | return name != null && name.equalsIgnoreCase(toolName);
159 | }
160 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/mcp2/tools/MCPToolResult.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.mcp2.tools;
2 |
/**
 * Outcome of running an MCP tool: a success flag plus either the tool's textual output or an
 * error description.
 * Similar to the original MCPToolResult but for MCP 2.0.
 */
public class MCPToolResult {

    private final boolean success;
    private final String content;
    private final String error;

    /**
     * @param success whether the tool call succeeded
     * @param content tool output (the factories set it only for successes)
     * @param error   error text (the factories set it only for failures)
     */
    public MCPToolResult(boolean success, String content, String error) {
        this.error = error;
        this.content = content;
        this.success = success;
    }

    /** Builds a successful result holding {@code content} and no error. */
    public static MCPToolResult success(String content) {
        MCPToolResult ok = new MCPToolResult(true, content, null);
        return ok;
    }

    /** Builds a failed result holding {@code error} and no content. */
    public static MCPToolResult error(String error) {
        MCPToolResult failed = new MCPToolResult(false, null, error);
        return failed;
    }

    public boolean isSuccess() {
        return success;
    }

    public String getContent() {
        return content;
    }

    public String getError() {
        return error;
    }

    /**
     * Renders the result for display: the content (or "" when absent) on success, otherwise
     * {@code "Error: "} followed by the error text (or "Unknown error" when absent).
     */
    public String getResultText() {
        if (!success) {
            String reason;
            if (error == null) {
                reason = "Unknown error";
            } else {
                reason = error;
            }
            return "Error: " + reason;
        }
        if (content == null) {
            return "";
        }
        return content;
    }

    @Override
    public String toString() {
        String pattern = "MCPToolResult{success=%s, content='%s', error='%s'}";
        return String.format(pattern, success, content, error);
    }
}
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/mcp2/transport/MCPTransport.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.mcp2.transport;
2 |
3 | import ghidrassist.mcp2.protocol.MCPRequest;
4 | import ghidrassist.mcp2.protocol.MCPResponse;
5 | import java.util.concurrent.CompletableFuture;
6 |
7 | /**
8 | * Abstract transport layer for MCP communication.
9 | * Supports different transport mechanisms (SSE, stdio, etc.).
10 | */
11 | public abstract class MCPTransport {
12 |
13 | protected boolean connected = false;
14 | protected MCPTransportHandler handler;
15 |
16 | /**
17 | * Interface for handling transport events
18 | */
19 | public interface MCPTransportHandler {
20 | void onConnected();
21 | void onDisconnected();
22 | void onResponse(MCPResponse response);
23 | void onError(Throwable error);
24 | }
25 |
26 | /**
27 | * Set the transport event handler
28 | */
29 | public void setHandler(MCPTransportHandler handler) {
30 | this.handler = handler;
31 | }
32 |
33 | /**
34 | * Connect to the MCP server
35 | */
36 | public abstract CompletableFuture connect();
37 |
38 | /**
39 | * Disconnect from the MCP server
40 | */
41 | public abstract CompletableFuture disconnect();
42 |
43 | /**
44 | * Send a request to the server
45 | */
46 | public abstract CompletableFuture sendRequest(MCPRequest request);
47 |
48 | /**
49 | * Send a notification to the server (no response expected)
50 | */
51 | public abstract CompletableFuture sendNotification(MCPRequest notification);
52 |
53 | /**
54 | * Check if transport is connected
55 | */
56 | public boolean isConnected() {
57 | return connected;
58 | }
59 |
60 | /**
61 | * Get transport type name
62 | */
63 | public abstract String getTransportType();
64 |
65 | /**
66 | * Get connection info for debugging
67 | */
68 | public abstract String getConnectionInfo();
69 |
70 | /**
71 | * Notify handler of connection
72 | */
73 | protected void notifyConnected() {
74 | connected = true;
75 | if (handler != null) {
76 | handler.onConnected();
77 | }
78 | }
79 |
80 | /**
81 | * Notify handler of disconnection
82 | */
83 | protected void notifyDisconnected() {
84 | connected = false;
85 | if (handler != null) {
86 | handler.onDisconnected();
87 | }
88 | }
89 |
90 | /**
91 | * Notify handler of response
92 | */
93 | protected void notifyResponse(MCPResponse response) {
94 | if (handler != null) {
95 | handler.onResponse(response);
96 | }
97 | }
98 |
99 | /**
100 | * Notify handler of error
101 | */
102 | protected void notifyError(Throwable error) {
103 | if (handler != null) {
104 | handler.onError(error);
105 | }
106 | }
107 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/resources/GhidrAssistIcons.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.resources;
2 |
3 | import javax.swing.ImageIcon;
4 | import resources.ResourceManager;
5 |
6 | public class GhidrAssistIcons {
7 | public static final ImageIcon ROBOT_ICON = ResourceManager.loadImage("images/robot32.png");
8 | }
9 |
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/services/AnalysisDataService.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.services;
2 |
3 | import ghidrassist.AnalysisDB;
4 | import ghidrassist.GhidrAssistPlugin;
5 | import ghidrassist.LlmApi;
6 | import ghidrassist.apiprovider.APIProviderConfig;
7 |
8 | /**
9 | * Service for managing analysis context and program-specific data.
10 | * Responsible for context storage, retrieval, and management operations.
11 | */
12 | public class AnalysisDataService {
13 |
14 | private final GhidrAssistPlugin plugin;
15 | private final AnalysisDB analysisDB;
16 |
17 | public AnalysisDataService(GhidrAssistPlugin plugin) {
18 | this.plugin = plugin;
19 | this.analysisDB = new AnalysisDB();
20 | }
21 |
22 | /**
23 | * Save context for the current program
24 | */
25 | public void saveContext(String context) {
26 | if (plugin.getCurrentProgram() == null) {
27 | throw new IllegalStateException("No active program to save context for.");
28 | }
29 |
30 | String programHash = plugin.getCurrentProgram().getExecutableSHA256();
31 | analysisDB.upsertContext(programHash, context);
32 | }
33 |
34 | /**
35 | * Get context for the current program
36 | */
37 | public String getContext() {
38 | if (plugin.getCurrentProgram() == null) {
39 | return getDefaultContext();
40 | }
41 |
42 | String programHash = plugin.getCurrentProgram().getExecutableSHA256();
43 | String context = analysisDB.getContext(programHash);
44 |
45 | if (context == null) {
46 | return getDefaultContext();
47 | }
48 |
49 | return context;
50 | }
51 |
52 | /**
53 | * Revert context to default for the current program
54 | */
55 | public String revertToDefaultContext() {
56 | String defaultContext = getDefaultContext();
57 |
58 | if (plugin.getCurrentProgram() != null) {
59 | // Clear custom context, will fall back to default
60 | String programHash = plugin.getCurrentProgram().getExecutableSHA256();
61 | analysisDB.upsertContext(programHash, null);
62 | }
63 |
64 | return defaultContext;
65 | }
66 |
67 | /**
68 | * Check if current program has custom context
69 | */
70 | public boolean hasCustomContext() {
71 | if (plugin.getCurrentProgram() == null) {
72 | return false;
73 | }
74 |
75 | String programHash = plugin.getCurrentProgram().getExecutableSHA256();
76 | String context = analysisDB.getContext(programHash);
77 | return context != null && !context.equals(getDefaultContext());
78 | }
79 |
80 | /**
81 | * Get default system context
82 | */
83 | private String getDefaultContext() {
84 | APIProviderConfig config = GhidrAssistPlugin.getCurrentProviderConfig();
85 | if (config == null) {
86 | return "You are a professional software reverse engineer."; // Fallback
87 | }
88 |
89 | LlmApi llmApi = new LlmApi(config, plugin);
90 | return llmApi.getSystemPrompt();
91 | }
92 |
93 | /**
94 | * Get context statistics
95 | */
96 | public ContextStats getContextStats() {
97 | String currentContext = getContext();
98 | boolean isCustom = hasCustomContext();
99 | String programName = plugin.getCurrentProgram() != null ?
100 | plugin.getCurrentProgram().getName() : "No Program";
101 |
102 | return new ContextStats(programName, currentContext.length(), isCustom);
103 | }
104 |
105 | /**
106 | * Close database resources
107 | */
108 | public void close() {
109 | if (analysisDB != null) {
110 | analysisDB.close();
111 | }
112 | }
113 |
114 | /**
115 | * Statistics about the current context
116 | */
117 | public static class ContextStats {
118 | private final String programName;
119 | private final int contextLength;
120 | private final boolean isCustom;
121 |
122 | public ContextStats(String programName, int contextLength, boolean isCustom) {
123 | this.programName = programName;
124 | this.contextLength = contextLength;
125 | this.isCustom = isCustom;
126 | }
127 |
128 | public String getProgramName() { return programName; }
129 | public int getContextLength() { return contextLength; }
130 | public boolean isCustom() { return isCustom; }
131 |
132 | @Override
133 | public String toString() {
134 | return String.format("Program: %s, Context: %d chars (%s)",
135 | programName, contextLength, isCustom ? "Custom" : "Default");
136 | }
137 | }
138 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/services/CodeAnalysisService.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.services;
2 |
3 | import ghidra.program.model.address.Address;
4 | import ghidra.program.model.listing.Function;
5 | import ghidra.program.model.listing.Program;
6 | import ghidra.util.task.TaskMonitor;
7 | import ghidrassist.AnalysisDB;
8 | import ghidrassist.GhidrAssistPlugin;
9 | import ghidrassist.LlmApi;
10 | import ghidrassist.apiprovider.APIProviderConfig;
11 | import ghidrassist.core.CodeUtils;
12 |
13 | /**
14 | * Service for handling code analysis operations.
15 | * Responsible for explaining functions and lines of code.
16 | */
17 | public class CodeAnalysisService {
18 |
19 | private final GhidrAssistPlugin plugin;
20 | private final AnalysisDB analysisDB;
21 |
22 | public CodeAnalysisService(GhidrAssistPlugin plugin) {
23 | this.plugin = plugin;
24 | this.analysisDB = new AnalysisDB();
25 | }
26 |
27 | /**
28 | * Analyze and explain a function
29 | */
30 | public AnalysisRequest createFunctionAnalysisRequest(Function function) throws Exception {
31 | if (function == null) {
32 | throw new IllegalArgumentException("No function at current location.");
33 | }
34 |
35 | String functionCode = null;
36 | String codeType = null;
37 |
38 | GhidrAssistPlugin.CodeViewType viewType = plugin.checkLastActiveCodeView();
39 | if (viewType == GhidrAssistPlugin.CodeViewType.IS_DECOMPILER) {
40 | functionCode = CodeUtils.getFunctionCode(function, TaskMonitor.DUMMY);
41 | codeType = "pseudo-C";
42 | } else if (viewType == GhidrAssistPlugin.CodeViewType.IS_DISASSEMBLER) {
43 | functionCode = CodeUtils.getFunctionDisassembly(function);
44 | codeType = "assembly";
45 | } else {
46 | throw new Exception("Unknown code view type.");
47 | }
48 |
49 | String prompt = "Explain the following " + codeType + " code:\n```\n" + functionCode + "\n```";
50 | return new AnalysisRequest(AnalysisRequest.Type.FUNCTION, prompt, function);
51 | }
52 |
53 | /**
54 | * Analyze and explain a line of code
55 | */
56 | public AnalysisRequest createLineAnalysisRequest(Address address) throws Exception {
57 | if (address == null) {
58 | throw new IllegalArgumentException("No address at current location.");
59 | }
60 |
61 | String codeLine = null;
62 | String codeType = null;
63 |
64 | GhidrAssistPlugin.CodeViewType viewType = plugin.checkLastActiveCodeView();
65 | if (viewType == GhidrAssistPlugin.CodeViewType.IS_DECOMPILER) {
66 | codeLine = CodeUtils.getLineCode(address, TaskMonitor.DUMMY, plugin.getCurrentProgram());
67 | codeType = "pseudo-C";
68 | } else if (viewType == GhidrAssistPlugin.CodeViewType.IS_DISASSEMBLER) {
69 | codeLine = CodeUtils.getLineDisassembly(address, plugin.getCurrentProgram());
70 | codeType = "assembly";
71 | } else {
72 | throw new Exception("Unknown code view type.");
73 | }
74 |
75 | String prompt = "Explain the following " + codeType + " line:\n```\n" + codeLine + "\n```";
76 | return new AnalysisRequest(AnalysisRequest.Type.LINE, prompt, address);
77 | }
78 |
79 | /**
80 | * Execute an analysis request
81 | */
82 | public void executeAnalysis(AnalysisRequest request, LlmApi.LlmResponseHandler handler) throws Exception {
83 | APIProviderConfig config = GhidrAssistPlugin.getCurrentProviderConfig();
84 | if (config == null) {
85 | throw new Exception("No API provider configured.");
86 | }
87 |
88 | LlmApi llmApi = new LlmApi(config, plugin);
89 | llmApi.sendRequestAsync(request.getPrompt(), handler);
90 | }
91 |
92 | /**
93 | * Store analysis result in database
94 | */
95 | public void storeAnalysisResult(Function function, String prompt, String response) {
96 | if (function != null && plugin.getCurrentProgram() != null) {
97 | analysisDB.upsertAnalysis(
98 | plugin.getCurrentProgram().getExecutableSHA256(),
99 | function.getEntryPoint(),
100 | prompt,
101 | response
102 | );
103 | }
104 | }
105 |
106 | /**
107 | * Get existing analysis for a function
108 | */
109 | public AnalysisDB.Analysis getExistingAnalysis(Function function) {
110 | if (function == null || plugin.getCurrentProgram() == null) {
111 | return null;
112 | }
113 |
114 | return analysisDB.getAnalysis(
115 | plugin.getCurrentProgram().getExecutableSHA256(),
116 | function.getEntryPoint()
117 | );
118 | }
119 |
120 | /**
121 | * Update existing analysis
122 | */
123 | public void updateAnalysis(Function function, String updatedContent) {
124 | if (function == null || plugin.getCurrentProgram() == null) {
125 | throw new IllegalArgumentException("No active program or function");
126 | }
127 |
128 | String programHash = plugin.getCurrentProgram().getExecutableSHA256();
129 | Address functionAddress = function.getEntryPoint();
130 |
131 | // Get existing analysis to preserve the query
132 | AnalysisDB.Analysis existingAnalysis = analysisDB.getAnalysis(programHash, functionAddress);
133 |
134 | if (existingAnalysis == null) {
135 | // Create new entry with generic query
136 | analysisDB.upsertAnalysis(programHash, functionAddress, "Edited explanation", updatedContent);
137 | } else {
138 | // Update existing entry, preserving original query
139 | analysisDB.upsertAnalysis(programHash, functionAddress, existingAnalysis.getQuery(), updatedContent);
140 | }
141 | }
142 |
143 | /**
144 | * Clear analysis data for a function
145 | */
146 | public boolean clearAnalysis(Function function) {
147 | if (function == null || plugin.getCurrentProgram() == null) {
148 | return false;
149 | }
150 |
151 | String programHash = plugin.getCurrentProgram().getExecutableSHA256();
152 | Address functionAddress = function.getEntryPoint();
153 |
154 | return analysisDB.deleteAnalysis(programHash, functionAddress);
155 | }
156 |
157 | /**
158 | * Close database resources
159 | */
160 | public void close() {
161 | if (analysisDB != null) {
162 | analysisDB.close();
163 | }
164 | }
165 |
166 | /**
167 | * Request object for analysis operations
168 | */
169 | public static class AnalysisRequest {
170 | public enum Type {
171 | FUNCTION, LINE
172 | }
173 |
174 | private final Type type;
175 | private final String prompt;
176 | private final Object context; // Function or Address
177 |
178 | public AnalysisRequest(Type type, String prompt, Object context) {
179 | this.type = type;
180 | this.prompt = prompt;
181 | this.context = context;
182 | }
183 |
184 | public Type getType() { return type; }
185 | public String getPrompt() { return prompt; }
186 | public Object getContext() { return context; }
187 |
188 | public Function getFunction() {
189 | return context instanceof Function ? (Function) context : null;
190 | }
191 |
192 | public Address getAddress() {
193 | return context instanceof Address ? (Address) context : null;
194 | }
195 | }
196 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/services/FeedbackService.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.services;
2 |
3 | import ghidrassist.GhidrAssistPlugin;
4 | import ghidrassist.LlmApi;
5 | import ghidrassist.RLHFDatabase;
6 | import ghidrassist.apiprovider.APIProviderConfig;
7 |
8 | /**
9 | * Service for handling RLHF (Reinforcement Learning from Human Feedback) operations.
10 | * Responsible for storing user feedback and managing feedback data.
11 | */
12 | public class FeedbackService {
13 |
14 | private final GhidrAssistPlugin plugin;
15 | private final RLHFDatabase rlhfDB;
16 |
17 | // Cache for last interaction
18 | private String lastPrompt;
19 | private String lastResponse;
20 |
21 | public FeedbackService(GhidrAssistPlugin plugin) {
22 | this.plugin = plugin;
23 | this.rlhfDB = new RLHFDatabase();
24 | }
25 |
26 | /**
27 | * Cache the last prompt and response for feedback
28 | */
29 | public void cacheLastInteraction(String prompt, String response) {
30 | this.lastPrompt = prompt;
31 | this.lastResponse = response;
32 | }
33 |
34 | /**
35 | * Store positive feedback (thumbs up)
36 | */
37 | public void storePositiveFeedback() {
38 | storeFeedback(1);
39 | }
40 |
41 | /**
42 | * Store negative feedback (thumbs down)
43 | */
44 | public void storeNegativeFeedback() {
45 | storeFeedback(0);
46 | }
47 |
48 | /**
49 | * Store feedback with specified rating
50 | */
51 | public void storeFeedback(int feedback) {
52 | if (lastPrompt == null || lastResponse == null) {
53 | throw new IllegalStateException("No recent interaction to provide feedback for.");
54 | }
55 |
56 | APIProviderConfig config = GhidrAssistPlugin.getCurrentProviderConfig();
57 | if (config == null) {
58 | throw new IllegalStateException("No API provider configured.");
59 | }
60 |
61 | LlmApi llmApi = new LlmApi(config, plugin);
62 | String modelName = config.getModel();
63 | String systemContext = llmApi.getSystemPrompt();
64 |
65 | rlhfDB.storeFeedback(modelName, lastPrompt, systemContext, lastResponse, feedback);
66 | }
67 |
68 | /**
69 | * Check if there's a recent interaction available for feedback
70 | */
71 | public boolean hasPendingFeedback() {
72 | return lastPrompt != null && lastResponse != null;
73 | }
74 |
75 | /**
76 | * Clear cached interaction (e.g., after feedback is provided)
77 | */
78 | public void clearCachedInteraction() {
79 | lastPrompt = null;
80 | lastResponse = null;
81 | }
82 |
83 | /**
84 | * Get feedback statistics
85 | */
86 | public FeedbackStats getFeedbackStats() {
87 | // Note: This would require extending RLHFDatabase with stats methods
88 | // For now, return basic info
89 | return new FeedbackStats(hasPendingFeedback());
90 | }
91 |
92 | /**
93 | * Get the last cached prompt (for debugging/info)
94 | */
95 | public String getLastPrompt() {
96 | return lastPrompt;
97 | }
98 |
99 | /**
100 | * Get the last cached response (for debugging/info)
101 | */
102 | public String getLastResponse() {
103 | return lastResponse;
104 | }
105 |
106 | /**
107 | * Close database resources
108 | */
109 | public void close() {
110 | if (rlhfDB != null) {
111 | rlhfDB.close();
112 | }
113 | }
114 |
115 | /**
116 | * Statistics about feedback state
117 | */
118 | public static class FeedbackStats {
119 | private final boolean hasPendingFeedback;
120 |
121 | public FeedbackStats(boolean hasPendingFeedback) {
122 | this.hasPendingFeedback = hasPendingFeedback;
123 | }
124 |
125 | public boolean hasPendingFeedback() { return hasPendingFeedback; }
126 |
127 | @Override
128 | public String toString() {
129 | return String.format("Feedback: %s",
130 | hasPendingFeedback ? "Available" : "No pending feedback");
131 | }
132 | }
133 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/services/RAGManagementService.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.services;
2 |
3 | import ghidrassist.core.RAGEngine;
4 |
5 | import java.io.File;
6 | import java.io.IOException;
7 | import java.util.Arrays;
8 | import java.util.List;
9 |
10 | /**
11 | * Service for managing RAG (Retrieval Augmented Generation) documents.
12 | * Responsible for document ingestion, deletion, and listing operations.
13 | */
14 | public class RAGManagementService {
15 |
16 | /**
17 | * Add documents to the RAG index
18 | */
19 | public void addDocuments(File[] files) throws IOException {
20 | if (files == null || files.length == 0) {
21 | throw new IllegalArgumentException("No files provided for ingestion.");
22 | }
23 |
24 | RAGEngine.ingestDocuments(Arrays.asList(files));
25 | }
26 |
27 | /**
28 | * Delete selected documents from the RAG index
29 | */
30 | public void deleteDocuments(List fileNames) throws IOException {
31 | if (fileNames == null || fileNames.isEmpty()) {
32 | throw new IllegalArgumentException("No documents selected for deletion.");
33 | }
34 |
35 | for (String fileName : fileNames) {
36 | RAGEngine.deleteDocument(fileName);
37 | }
38 | }
39 |
40 | /**
41 | * Get list of indexed files
42 | */
43 | public List getIndexedFiles() throws IOException {
44 | return RAGEngine.listIndexedFiles();
45 | }
46 |
47 | /**
48 | * Check if the RAG index is available and working
49 | */
50 | public boolean isRAGAvailable() {
51 | try {
52 | RAGEngine.listIndexedFiles();
53 | return true;
54 | } catch (IOException e) {
55 | return false;
56 | }
57 | }
58 |
59 | /**
60 | * Get RAG index statistics
61 | */
62 | public RAGIndexStats getIndexStats() throws IOException {
63 | List files = getIndexedFiles();
64 | return new RAGIndexStats(files.size(), files);
65 | }
66 |
67 | /**
68 | * Clear all documents from the RAG index
69 | */
70 | public void clearAllDocuments() throws IOException {
71 | List allFiles = getIndexedFiles();
72 | deleteDocuments(allFiles);
73 | }
74 |
75 | /**
76 | * Statistics about the RAG index
77 | */
78 | public static class RAGIndexStats {
79 | private final int totalFiles;
80 | private final List fileNames;
81 |
82 | public RAGIndexStats(int totalFiles, List fileNames) {
83 | this.totalFiles = totalFiles;
84 | this.fileNames = fileNames;
85 | }
86 |
87 | public int getTotalFiles() { return totalFiles; }
88 | public List getFileNames() { return fileNames; }
89 |
90 | @Override
91 | public String toString() {
92 | return String.format("RAG Index: %d files indexed", totalFiles);
93 | }
94 | }
95 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/ui/GhidrAssistUI.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.ui;
2 |
3 | import javax.swing.*;
4 |
5 | import ghidra.program.util.ProgramLocation;
6 |
7 | import java.awt.*;
8 | import ghidrassist.GhidrAssistPlugin;
9 | import ghidrassist.ui.tabs.*;
10 | import ghidrassist.core.TabController;
11 | import ghidrassist.ui.common.UIConstants;
12 |
13 | public class GhidrAssistUI extends JPanel {
14 | private static final long serialVersionUID = 1L;
15 | private final GhidrAssistPlugin plugin;
16 | private final TabController controller;
17 | private final JTabbedPane tabbedPane;
18 | private final ExplainTab explainTab;
19 | private final QueryTab queryTab;
20 | private final ActionsTab actionsTab;
21 | private final RAGManagementTab ragManagementTab;
22 | private final AnalysisOptionsTab analysisOptionsTab;
23 |
24 | public GhidrAssistUI(GhidrAssistPlugin plugin) {
25 | super(new BorderLayout());
26 | this.plugin = plugin;
27 | this.controller = new TabController(plugin);
28 |
29 | // Initialize components
30 | this.tabbedPane = new JTabbedPane();
31 |
32 | // Create tabs
33 | this.explainTab = new ExplainTab(controller);
34 | this.queryTab = new QueryTab(controller);
35 | this.actionsTab = new ActionsTab(controller);
36 | this.ragManagementTab = new RAGManagementTab(controller);
37 | this.analysisOptionsTab = new AnalysisOptionsTab(controller);
38 |
39 | // Set tab references in controller
40 | controller.setExplainTab(explainTab);
41 | controller.setQueryTab(queryTab);
42 | controller.setActionsTab(actionsTab);
43 | controller.setRAGManagementTab(ragManagementTab);
44 | controller.setAnalysisOptionsTab(analysisOptionsTab);
45 |
46 | initializeUI();
47 | }
48 |
49 | private void initializeUI() {
50 | setBorder(UIConstants.PANEL_BORDER);
51 |
52 | // Add tabs
53 | tabbedPane.addTab("Explain", explainTab);
54 | tabbedPane.addTab("Custom Query", queryTab);
55 | tabbedPane.addTab("Actions", actionsTab);
56 | tabbedPane.addTab("RAG Management", ragManagementTab);
57 | tabbedPane.addTab("Analysis Options", analysisOptionsTab);
58 |
59 | add(tabbedPane, BorderLayout.CENTER);
60 |
61 | // Initialize tabs that need startup data
62 | SwingUtilities.invokeLater(() -> {
63 | // Load initial context
64 | controller.handleContextRevert();
65 |
66 | // Load RAG file list
67 | controller.loadIndexedFiles(ragManagementTab.getDocumentList());
68 | });
69 |
70 | tabbedPane.addChangeListener(e -> {
71 | if (tabbedPane.getSelectedComponent() == analysisOptionsTab) {
72 | // Refresh context when Analysis Options tab is selected
73 | controller.handleContextRevert();
74 | }
75 | });
76 | }
77 |
78 | public void updateLocation(ProgramLocation loc) {
79 | if (loc != null && loc.getAddress() != null) {
80 | explainTab.updateOffset(loc.getAddress().toString());
81 | controller.updateAnalysis(loc);
82 | }
83 | }
84 |
85 | public JComponent getComponent() {
86 | return this;
87 | }
88 |
89 | public GhidrAssistPlugin getPlugin() {
90 | return plugin;
91 | }
92 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/ui/common/PlaceholderTextField.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.ui.common;
2 |
3 | import javax.swing.JTextField;
4 | import java.awt.Color;
5 | import java.awt.event.FocusEvent;
6 | import java.awt.event.FocusListener;
7 |
8 | public class PlaceholderTextField extends JTextField {
9 | private boolean showingPlaceholder;
10 | private final String placeholder;
11 | private final Color placeholderColor;
12 | private final Color textColor;
13 |
14 | public PlaceholderTextField(String placeholder, int columns) {
15 | super(columns);
16 | this.placeholder = placeholder;
17 | this.showingPlaceholder = true;
18 | this.placeholderColor = UIConstants.PLACEHOLDER_COLOR;
19 | this.textColor = getForeground();
20 |
21 | setupPlaceholder();
22 | }
23 |
24 | private void setupPlaceholder() {
25 | setText(placeholder);
26 | setForeground(placeholderColor);
27 |
28 | addFocusListener(new FocusListener() {
29 | @Override
30 | public void focusGained(FocusEvent e) {
31 | if (showingPlaceholder) {
32 | showingPlaceholder = false;
33 | setText("");
34 | setForeground(textColor);
35 | }
36 | }
37 |
38 | @Override
39 | public void focusLost(FocusEvent e) {
40 | if (getText().isEmpty()) {
41 | showingPlaceholder = true;
42 | setText(placeholder);
43 | setForeground(placeholderColor);
44 | }
45 | }
46 | });
47 | }
48 |
49 | @Override
50 | public String getText() {
51 | return showingPlaceholder ? "" : super.getText();
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/ui/common/UIConstants.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.ui.common;
2 |
3 | import java.awt.Color;
4 | import java.awt.Dimension;
5 | import javax.swing.BorderFactory;
6 | import javax.swing.border.Border;
7 |
/**
 * Shared layout and styling constants for GhidrAssist Swing components.
 *
 * <p>Constants holder only; it cannot be instantiated or subclassed.
 */
public final class UIConstants {
    /** Padding in pixels, applied on every side of {@link #PANEL_BORDER}. */
    public static final int PADDING = 5;
    /** Gap in pixels between adjacent buttons. */
    public static final int BUTTON_SPACING = 5;
    /** Default column count for text fields. */
    public static final int TEXT_FIELD_COLUMNS = 20;
    /**
     * Preferred button size (120 x 30 pixels).
     *
     * <p>NOTE: {@link Dimension} is mutable and this instance is shared; never modify it.
     */
    public static final Dimension PREFERRED_BUTTON_SIZE = new Dimension(120, 30);
    /** Empty border of {@link #PADDING} pixels on each side. */
    public static final Border PANEL_BORDER = BorderFactory.createEmptyBorder(PADDING, PADDING, PADDING, PADDING);
    /** Foreground colour used for placeholder/hint text. */
    public static final Color PLACEHOLDER_COLOR = Color.GRAY;

    private UIConstants() {
        // Constants holder; not instantiable.
    }
}
16 |
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/ui/tabs/ActionsTab.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.ui.tabs;
2 |
import javax.swing.*;
import javax.swing.table.DefaultTableModel;
import java.awt.*;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import ghidrassist.core.TabController;
import ghidrassist.core.ActionConstants;
10 |
11 | public class ActionsTab extends JPanel {
12 | private static final long serialVersionUID = 1L;
13 | private final TabController controller;
14 | private JTable actionsTable;
15 | private Map filterCheckBoxes;
16 | private JButton analyzeFunctionButton;
17 | private JButton analyzeClearButton;
18 | private JButton applyActionsButton;
19 |
20 | public ActionsTab(TabController controller) {
21 | super(new BorderLayout());
22 | this.controller = controller;
23 | initializeComponents();
24 | layoutComponents();
25 | setupListeners();
26 | }
27 |
28 | private void initializeComponents() {
29 | // Initialize table
30 | actionsTable = createActionsTable();
31 |
32 | // Initialize filter checkboxes
33 | filterCheckBoxes = createFilterCheckboxes();
34 |
35 | // Initialize buttons
36 | analyzeFunctionButton = new JButton("Analyze Function");
37 | analyzeClearButton = new JButton("Clear");
38 | applyActionsButton = new JButton("Apply Actions");
39 | }
40 |
41 | private JTable createActionsTable() {
42 | DefaultTableModel model = new DefaultTableModel(
43 | new Object[]{"Select", "Action", "Description", "Status", "Arguments"}, 0) {
44 | private static final long serialVersionUID = 1L;
45 |
46 | @Override
47 | public Class> getColumnClass(int column) {
48 | return column == 0 ? Boolean.class : String.class;
49 | }
50 | };
51 |
52 | JTable table = new JTable(model);
53 | int w = table.getColumnModel().getColumn(0).getWidth();
54 | table.getColumnModel().getColumn(0).setMaxWidth((int)((double) (w*0.8)));
55 | return table;
56 | }
57 |
58 | private Map createFilterCheckboxes() {
59 | Map checkboxes = new HashMap<>();
60 | for (Map fnTemplate : ActionConstants.FN_TEMPLATES) {
61 | if (fnTemplate.get("type").equals("function")) {
62 | @SuppressWarnings("unchecked")
63 | Map functionMap = (Map) fnTemplate.get("function");
64 | String fnName = functionMap.get("name").toString();
65 | String fnDescription = functionMap.get("description").toString();
66 | String checkboxLabel = fnName.replace("_", " ") + ": " + fnDescription;
67 | checkboxes.put(fnName, new JCheckBox(checkboxLabel, true));
68 | }
69 | }
70 | return checkboxes;
71 | }
72 |
73 | private void layoutComponents() {
74 | // Filter panel
75 | JPanel filterPanel = new JPanel();
76 | filterPanel.setLayout(new BoxLayout(filterPanel, BoxLayout.Y_AXIS));
77 | filterPanel.setBorder(BorderFactory.createTitledBorder("Filters"));
78 | filterCheckBoxes.values().forEach(filterPanel::add);
79 |
80 | JScrollPane filterScrollPane = new JScrollPane(filterPanel);
81 | filterScrollPane.setPreferredSize(new Dimension(200, 150));
82 | add(filterScrollPane, BorderLayout.NORTH);
83 |
84 | // Table
85 | add(new JScrollPane(actionsTable), BorderLayout.CENTER);
86 |
87 | // Buttons
88 | JPanel buttonPanel = new JPanel();
89 | buttonPanel.add(analyzeFunctionButton);
90 | buttonPanel.add(analyzeClearButton);
91 | buttonPanel.add(applyActionsButton);
92 | add(buttonPanel, BorderLayout.SOUTH);
93 | }
94 |
95 | private void setupListeners() {
96 | analyzeFunctionButton.addActionListener(e ->
97 | controller.handleAnalyzeFunction(filterCheckBoxes));
98 | analyzeClearButton.addActionListener(e ->
99 | ((DefaultTableModel)actionsTable.getModel()).setRowCount(0));
100 | applyActionsButton.addActionListener(e ->
101 | controller.handleApplyActions(actionsTable));
102 | }
103 |
104 | public DefaultTableModel getTableModel() {
105 | return (DefaultTableModel)actionsTable.getModel();
106 | }
107 |
108 | public void setAnalyzeFunctionButtonText(String text) {
109 | analyzeFunctionButton.setText(text);
110 | }
111 | }
112 |
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/ui/tabs/AnalysisOptionsTab.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.ui.tabs;
2 |
3 | import javax.swing.*;
4 | import java.awt.*;
5 | import ghidrassist.core.TabController;
6 |
7 | public class AnalysisOptionsTab extends JPanel {
8 | private final TabController controller;
9 | private JTextArea contextArea;
10 | private JButton saveButton;
11 | private JButton revertButton;
12 |
13 | public AnalysisOptionsTab(TabController controller) {
14 | super(new BorderLayout());
15 | this.controller = controller;
16 | initializeComponents();
17 | layoutComponents();
18 | setupListeners();
19 | }
20 |
21 | private void initializeComponents() {
22 | contextArea = new JTextArea();
23 | contextArea.setFont(new Font("Monospaced", Font.PLAIN, 12));
24 | contextArea.setLineWrap(true);
25 | contextArea.setWrapStyleWord(true);
26 |
27 | saveButton = new JButton("Save");
28 | revertButton = new JButton("Revert");
29 | }
30 |
31 | private void layoutComponents() {
32 | // Add header label
33 | JLabel headerLabel = new JLabel("System Context");
34 | headerLabel.setBorder(BorderFactory.createEmptyBorder(5, 5, 5, 5));
35 | add(headerLabel, BorderLayout.NORTH);
36 |
37 | // Add text area with scroll pane
38 | JScrollPane scrollPane = new JScrollPane(contextArea);
39 | add(scrollPane, BorderLayout.CENTER);
40 |
41 | // Add button panel
42 | JPanel buttonPanel = new JPanel(new FlowLayout(FlowLayout.RIGHT));
43 | buttonPanel.add(revertButton);
44 | buttonPanel.add(saveButton);
45 | add(buttonPanel, BorderLayout.SOUTH);
46 | }
47 |
48 | private void setupListeners() {
49 | saveButton.addActionListener(e -> controller.handleContextSave(contextArea.getText()));
50 | revertButton.addActionListener(e -> controller.handleContextRevert());
51 | }
52 |
53 | public void setContextText(String text) {
54 | contextArea.setText(text);
55 | }
56 | }
57 |
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/ui/tabs/ExplainTab.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.ui.tabs;
2 |
3 | import javax.swing.*;
4 | import java.awt.*;
5 | import ghidrassist.core.MarkdownHelper;
6 | import ghidrassist.core.TabController;
7 |
8 | public class ExplainTab extends JPanel {
9 | private static final long serialVersionUID = 1L;
10 | private final TabController controller;
11 | private final MarkdownHelper markdownHelper;
12 | private JLabel offsetLabel;
13 | private JTextField offsetField;
14 | private JEditorPane explainTextPane;
15 | private JTextArea markdownTextArea;
16 | private JButton explainFunctionButton;
17 | private JButton explainLineButton;
18 | private JButton clearExplainButton;
19 | private JButton editSaveButton;
20 | private JPanel contentPanel;
21 | private CardLayout contentLayout;
22 | private boolean isEditMode = false;
23 | private String currentMarkdown = "";
24 |
25 | public ExplainTab(TabController controller) {
26 | super(new BorderLayout());
27 | this.controller = controller;
28 | this.markdownHelper = new MarkdownHelper();
29 | initializeComponents();
30 | layoutComponents();
31 | setupListeners();
32 | }
33 |
34 | private void initializeComponents() {
35 | // Initialize offset field
36 | offsetLabel = new JLabel("Offset: ");
37 | offsetField = new JTextField(16);
38 | offsetField.setEditable(false);
39 |
40 | // Initialize text pane for HTML viewing
41 | explainTextPane = new JEditorPane();
42 | explainTextPane.setEditable(false);
43 | explainTextPane.setContentType("text/html");
44 | explainTextPane.addHyperlinkListener(controller::handleHyperlinkEvent);
45 |
46 | // Initialize text area for Markdown editing
47 | markdownTextArea = new JTextArea();
48 | markdownTextArea.setFont(new Font("Monospaced", Font.PLAIN, 12));
49 | markdownTextArea.setLineWrap(true);
50 | markdownTextArea.setWrapStyleWord(true);
51 |
52 | // Initialize buttons
53 | explainFunctionButton = new JButton("Explain Function");
54 | explainLineButton = new JButton("Explain Line");
55 | clearExplainButton = new JButton("Clear");
56 | editSaveButton = new JButton("Edit");
57 |
58 | // Setup card layout for switching between view and edit modes
59 | contentLayout = new CardLayout();
60 | contentPanel = new JPanel(contentLayout);
61 | contentPanel.add(new JScrollPane(explainTextPane), "view");
62 | contentPanel.add(new JScrollPane(markdownTextArea), "edit");
63 | }
64 |
65 | private void layoutComponents() {
66 | // Offset and Edit/Save panel
67 | JPanel topPanel = new JPanel(new BorderLayout());
68 |
69 | JPanel offsetPanel = new JPanel();
70 | offsetPanel.add(offsetLabel);
71 | offsetPanel.add(offsetField);
72 | topPanel.add(offsetPanel, BorderLayout.WEST);
73 |
74 | JPanel editPanel = new JPanel();
75 | editPanel.add(editSaveButton);
76 | topPanel.add(editPanel, BorderLayout.EAST);
77 |
78 | add(topPanel, BorderLayout.NORTH);
79 |
80 | // Text content panel with card layout
81 | add(contentPanel, BorderLayout.CENTER);
82 |
83 | // Button panel
84 | JPanel buttonPanel = new JPanel();
85 | buttonPanel.add(explainFunctionButton);
86 | buttonPanel.add(explainLineButton);
87 | buttonPanel.add(clearExplainButton);
88 | add(buttonPanel, BorderLayout.SOUTH);
89 | }
90 |
91 | private void setupListeners() {
92 | explainFunctionButton.addActionListener(e -> controller.handleExplainFunction());
93 | explainLineButton.addActionListener(e -> controller.handleExplainLine());
94 | clearExplainButton.addActionListener(e -> {
95 | // Clear the UI
96 | explainTextPane.setText("");
97 | markdownTextArea.setText("");
98 | currentMarkdown = "";
99 |
100 | // Also clear from database
101 | controller.handleClearAnalysisData();
102 | });
103 |
104 | editSaveButton.addActionListener(e -> {
105 | if (isEditMode) {
106 | // Save mode - save the markdown and switch to view mode
107 | currentMarkdown = markdownTextArea.getText();
108 | String html = markdownHelper.markdownToHtml(currentMarkdown);
109 | explainTextPane.setText(html);
110 |
111 | // Save to database
112 | controller.handleUpdateAnalysis(currentMarkdown);
113 |
114 | // Switch to view mode
115 | contentLayout.show(contentPanel, "view");
116 | editSaveButton.setText("Edit");
117 | isEditMode = false;
118 | } else {
119 | // Edit mode - switch to the markdown editor
120 | markdownTextArea.setText(currentMarkdown);
121 |
122 | // Switch to edit mode
123 | contentLayout.show(contentPanel, "edit");
124 | editSaveButton.setText("Save");
125 | isEditMode = true;
126 | }
127 | });
128 | }
129 |
130 | public void updateOffset(String offset) {
131 | offsetField.setText(offset);
132 | }
133 |
134 | public void setExplanationText(String text) {
135 | explainTextPane.setText(text);
136 | explainTextPane.setCaretPosition(0);
137 |
138 | // Store the markdown equivalent
139 | currentMarkdown = markdownHelper.extractMarkdownFromLlmResponse(text);
140 |
141 | // If we're in edit mode, update the markdown text area too
142 | if (isEditMode) {
143 | markdownTextArea.setText(currentMarkdown);
144 | }
145 |
146 | // Switch to view mode if we're setting new content
147 | if (isEditMode) {
148 | contentLayout.show(contentPanel, "view");
149 | editSaveButton.setText("Edit");
150 | isEditMode = false;
151 | }
152 | }
153 |
154 | public void setFunctionButtonText(String text) {
155 | explainFunctionButton.setText(text);
156 | }
157 |
158 | public void setLineButtonText(String text) {
159 | explainLineButton.setText(text);
160 | }
161 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/ui/tabs/MCPServerDialog.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.ui.tabs;
2 |
3 | import java.awt.*;
4 | import java.awt.event.ActionEvent;
5 | import java.awt.event.ActionListener;
6 | import javax.swing.*;
7 |
8 | import ghidrassist.mcp2.server.MCPServerConfig;
9 |
10 | public class MCPServerDialog extends JDialog {
11 | private static final long serialVersionUID = 1L;
12 |
13 | private JTextField nameField;
14 | private JTextField urlField;
15 | private JComboBox transportCombo;
16 | private JCheckBox enabledCheckBox;
17 | private JButton okButton;
18 | private JButton cancelButton;
19 | private boolean confirmed = false;
20 |
21 | public MCPServerDialog(Window parent, MCPServerConfig existingServer) {
22 | super(parent, existingServer == null ? "Add MCP Server" : "Edit MCP Server",
23 | ModalityType.APPLICATION_MODAL);
24 |
25 | initializeComponents();
26 | layoutComponents();
27 | setupEventHandlers();
28 |
29 | if (existingServer != null) {
30 | populateFields(existingServer);
31 | } else {
32 | setDefaults();
33 | }
34 |
35 | pack();
36 | setLocationRelativeTo(parent);
37 | nameField.requestFocusInWindow();
38 | }
39 |
40 | private void initializeComponents() {
41 | nameField = new JTextField(20);
42 | urlField = new JTextField(30);
43 | transportCombo = new JComboBox<>(MCPServerConfig.TransportType.values());
44 | enabledCheckBox = new JCheckBox("Enabled", true);
45 |
46 | okButton = new JButton("OK");
47 | cancelButton = new JButton("Cancel");
48 |
49 | getRootPane().setDefaultButton(okButton);
50 | }
51 |
52 | private void layoutComponents() {
53 | setLayout(new BorderLayout());
54 |
55 | // Form panel
56 | JPanel formPanel = new JPanel(new GridBagLayout());
57 | GridBagConstraints gbc = new GridBagConstraints();
58 | gbc.insets = new Insets(5, 5, 5, 5);
59 |
60 | // Name
61 | gbc.gridx = 0; gbc.gridy = 0; gbc.anchor = GridBagConstraints.EAST;
62 | formPanel.add(new JLabel("Name:"), gbc);
63 | gbc.gridx = 1; gbc.anchor = GridBagConstraints.WEST; gbc.fill = GridBagConstraints.HORIZONTAL;
64 | formPanel.add(nameField, gbc);
65 |
66 | // URL
67 | gbc.gridx = 0; gbc.gridy = 1; gbc.anchor = GridBagConstraints.EAST; gbc.fill = GridBagConstraints.NONE;
68 | formPanel.add(new JLabel("URL:"), gbc);
69 | gbc.gridx = 1; gbc.anchor = GridBagConstraints.WEST; gbc.fill = GridBagConstraints.HORIZONTAL;
70 | formPanel.add(urlField, gbc);
71 |
72 | // Transport
73 | gbc.gridx = 0; gbc.gridy = 2; gbc.anchor = GridBagConstraints.EAST; gbc.fill = GridBagConstraints.NONE;
74 | formPanel.add(new JLabel("Transport:"), gbc);
75 | gbc.gridx = 1; gbc.anchor = GridBagConstraints.WEST; gbc.fill = GridBagConstraints.HORIZONTAL;
76 | formPanel.add(transportCombo, gbc);
77 |
78 | // Enabled
79 | gbc.gridx = 1; gbc.gridy = 3; gbc.anchor = GridBagConstraints.WEST;
80 | formPanel.add(enabledCheckBox, gbc);
81 |
82 | // Help text
83 | JPanel helpPanel = new JPanel(new BorderLayout());
84 | JTextArea helpText = new JTextArea(
85 | "Examples:\n" +
86 | "• Name: GhidraMCP, URL: http://localhost:8080\n" +
87 | "• Name: Local Tools, URL: http://127.0.0.1:3000\n\n" +
88 | "The server must implement the Model Context Protocol (MCP) specification."
89 | );
90 | helpText.setEditable(false);
91 | helpText.setOpaque(false);
92 | helpText.setFont(helpText.getFont().deriveFont(Font.ITALIC, 11f));
93 | helpPanel.add(helpText, BorderLayout.CENTER);
94 | helpPanel.setBorder(BorderFactory.createTitledBorder("Help"));
95 |
96 | // Button panel
97 | JPanel buttonPanel = new JPanel(new FlowLayout(FlowLayout.RIGHT));
98 | buttonPanel.add(okButton);
99 | buttonPanel.add(cancelButton);
100 |
101 | // Layout
102 | add(formPanel, BorderLayout.CENTER);
103 | add(helpPanel, BorderLayout.NORTH);
104 | add(buttonPanel, BorderLayout.SOUTH);
105 |
106 | // Add border to content pane instead
107 | getRootPane().setBorder(BorderFactory.createEmptyBorder(10, 10, 10, 10));
108 | }
109 |
110 | private void setupEventHandlers() {
111 | okButton.addActionListener(e -> {
112 | if (validateInput()) {
113 | confirmed = true;
114 | dispose();
115 | }
116 | });
117 |
118 | cancelButton.addActionListener(e -> {
119 | confirmed = false;
120 | dispose();
121 | });
122 |
123 | // Transport selection updates URL placeholder
124 | transportCombo.addActionListener(e -> updateUrlPlaceholder());
125 | }
126 |
127 | private void setDefaults() {
128 | transportCombo.setSelectedItem(MCPServerConfig.TransportType.SSE);
129 | updateUrlPlaceholder();
130 | }
131 |
132 | private void populateFields(MCPServerConfig server) {
133 | nameField.setText(server.getName());
134 | urlField.setText(server.getBaseUrl());
135 | transportCombo.setSelectedItem(server.getTransport());
136 | enabledCheckBox.setSelected(server.isEnabled());
137 | }
138 |
139 | private void updateUrlPlaceholder() {
140 | MCPServerConfig.TransportType transport =
141 | (MCPServerConfig.TransportType) transportCombo.getSelectedItem();
142 |
143 | if (transport == MCPServerConfig.TransportType.SSE) {
144 | urlField.setToolTipText("HTTP(S) URL for Server-Sent Events transport (e.g., http://localhost:8080)");
145 | } else {
146 | urlField.setToolTipText("Command or path for stdio transport");
147 | }
148 | }
149 |
150 | private boolean validateInput() {
151 | String name = nameField.getText().trim();
152 | String url = urlField.getText().trim();
153 |
154 | if (name.isEmpty()) {
155 | showError("Name cannot be empty.");
156 | nameField.requestFocus();
157 | return false;
158 | }
159 |
160 | if (url.isEmpty()) {
161 | showError("URL cannot be empty.");
162 | urlField.requestFocus();
163 | return false;
164 | }
165 |
166 | // Basic URL validation for SSE transport
167 | MCPServerConfig.TransportType transport =
168 | (MCPServerConfig.TransportType) transportCombo.getSelectedItem();
169 |
170 | if (transport == MCPServerConfig.TransportType.SSE) {
171 | if (!url.startsWith("http://") && !url.startsWith("https://")) {
172 | showError("URL must start with http:// or https:// for SSE transport.");
173 | urlField.requestFocus();
174 | return false;
175 | }
176 | }
177 |
178 | return true;
179 | }
180 |
181 | private void showError(String message) {
182 | JOptionPane.showMessageDialog(this, message, "Validation Error", JOptionPane.ERROR_MESSAGE);
183 | }
184 |
185 | public boolean isConfirmed() {
186 | return confirmed;
187 | }
188 |
189 | public MCPServerConfig getServerConfig() {
190 | if (!confirmed) return null;
191 |
192 | return new MCPServerConfig(
193 | nameField.getText().trim(),
194 | urlField.getText().trim(),
195 | (MCPServerConfig.TransportType) transportCombo.getSelectedItem(),
196 | enabledCheckBox.isSelected()
197 | );
198 | }
199 | }
--------------------------------------------------------------------------------
/src/main/java/ghidrassist/ui/tabs/RAGManagementTab.java:
--------------------------------------------------------------------------------
1 | package ghidrassist.ui.tabs;
2 |
3 | import javax.swing.*;
4 | import java.awt.*;
5 | import ghidrassist.core.TabController;
6 |
7 | public class RAGManagementTab extends JPanel {
8 | private static final long serialVersionUID = 1L;
9 | private final TabController controller;
10 | private JList documentList;
11 | private JButton addDocumentsButton;
12 | private JButton deleteSelectedButton;
13 | private JButton refreshListButton;
14 |
15 | public RAGManagementTab(TabController controller) {
16 | super(new BorderLayout());
17 | this.controller = controller;
18 | initializeComponents();
19 | layoutComponents();
20 | setupListeners();
21 | }
22 |
23 | private void initializeComponents() {
24 | addDocumentsButton = new JButton("Add Documents to RAG");
25 | documentList = new JList<>();
26 | deleteSelectedButton = new JButton("Delete Selected");
27 | refreshListButton = new JButton("Refresh List");
28 | }
29 |
30 | private void layoutComponents() {
31 | add(addDocumentsButton, BorderLayout.NORTH);
32 | add(new JScrollPane(documentList), BorderLayout.CENTER);
33 |
34 | JPanel buttonPanel = new JPanel();
35 | buttonPanel.add(deleteSelectedButton);
36 | buttonPanel.add(refreshListButton);
37 | add(buttonPanel, BorderLayout.SOUTH);
38 | }
39 |
40 | private void setupListeners() {
41 | addDocumentsButton.addActionListener(e ->
42 | controller.handleAddDocuments(documentList));
43 | deleteSelectedButton.addActionListener(e ->
44 | controller.handleDeleteSelected(documentList));
45 | refreshListButton.addActionListener(e ->
46 | controller.loadIndexedFiles(documentList));
47 | }
48 |
49 | public void updateDocumentList(String[] files) {
50 | documentList.setListData(files);
51 | }
52 |
53 | public JList getDocumentList() {
54 | return documentList;
55 | }
56 | }
57 |
--------------------------------------------------------------------------------
/src/main/resources/images/README.txt:
--------------------------------------------------------------------------------
1 | The "src/main/resources/images" directory holds all image/icon files used by
2 | this module.
3 |
--------------------------------------------------------------------------------
/src/main/resources/images/robot.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/src/main/resources/images/robot16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jtang613/GhidrAssist/f6bd628c273e8ff3da9b55703bcaa90e26c58f4d/src/main/resources/images/robot16.png
--------------------------------------------------------------------------------
/src/main/resources/images/robot32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jtang613/GhidrAssist/f6bd628c273e8ff3da9b55703bcaa90e26c58f4d/src/main/resources/images/robot32.png
--------------------------------------------------------------------------------