├── .config ├── __init__.py ├── config.json └── model_config.json ├── .dockerignore ├── .gitignore ├── BaseMachine ├── __init__.py ├── action_utils.py ├── agent_action_utils.py ├── code_filling │ ├── __init__.py │ ├── code_filling_config.py │ ├── code_filling_context.py │ └── code_filling_tools.py ├── config_loader.py ├── llm_helpers.py ├── logger.py ├── model_manager.py └── state_machine.py ├── Dockerfile ├── QLWorkflow ├── _01_ql_query_modification │ ├── __init__.py │ ├── modification_config.py │ ├── modification_context.py │ └── modification_tools.py ├── _02_run_ql_query │ ├── __init__.py │ ├── query_config.py │ ├── query_context.py │ └── query_tools.py ├── _03_output_validation │ ├── __init__.py │ ├── validation_config.py │ ├── validation_context.py │ └── validation_tools.py ├── _04_iteration_control │ ├── __init__.py │ ├── iteration_config.py │ ├── iteration_context.py │ └── iteration_tools.py ├── __init__.py ├── pipeline_config.py └── util │ ├── __init__.py │ ├── evaluation_utils.py │ ├── function_dump.ql │ └── logging_utils.py ├── README.md ├── docs └── images │ └── architecture.png ├── draw ├── initial_vs_final_comparison.png └── plot_initial_vs_final.py ├── requirements.txt ├── run_juliet.py ├── run_ql_workflow.py └── start_docker.sh /.config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/P1umer/QL-Relax/941ffa7fb0b0b196ba9e4e49e99f5b7fe6e47138/.config/__init__.py -------------------------------------------------------------------------------- /.config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "selected_model": "azure-gpt4o", 3 | "api_key": "dummy_key_for_openai", 4 | "azure_key": "dummy_key_for_xxxx", 5 | "azure_endpoint": "https://gpt4-func-sweden.openai.azure.com/openai/deployments/caozong-exp/chat/completions?api-version=2024-02-15-preview", 6 | "siliconflow_key": "dummy_key_for_siliconflow", 7 | "ds_azure_key": "dummy_key_for_xxx", 8 | "ds_azure_endpoint": "https://ai-test324936320153.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview", 9 | "openrouter_api_key": "dummy_key_for_openrouter", 10 | "max_tokens": 8192, 11 | "temperature": 0.1, 12 | "top_p": 1.0, 13 | "stop_sequences": ["\n"] 14 | } -------------------------------------------------------------------------------- /.config/model_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "models": { 3 | "gpt-4o": { 4 | "provider": "openai", 5 | "model_name": "gpt-4o", 6 | "max_tokens": 8192, 7 | "description": "OpenAI GPT-4o model" 8 | }, 9 | "gpt-3.5-turbo": { 10 | "provider": "openai", 11 | "model_name": "gpt-3.5-turbo", 12 | "max_tokens": 4096, 13 | "description": "OpenAI GPT-3.5 Turbo model" 14 | }, 15 | "azure-gpt4o": { 16 | "provider": "azure", 17 | "model_name": "gpt-4o", 18 | "deployment_name": "caozong-exp", 19 | "api_version": "2024-08-01-preview", 20 | "max_tokens": 8192, 21 | "description": "Azure OpenAI GPT-4o deployment" 22 | }, 23 | "deepseek-v2.5": { 24 | "provider": "siliconflow", 25 | "model_name": "deepseek-ai/DeepSeek-V2.5", 26 | "max_tokens": 4096, 27 | "description": "DeepSeek V2.5 model via SiliconFlow" 28 | }, 29 | "deepseek-v3": { 30 | "provider": "siliconflow", 31 | "model_name": "Pro/deepseek-ai/DeepSeek-V3", 32 | "max_tokens": 4096, 33 | "description": "DeepSeek V3 model via SiliconFlow" 34 | }, 35 | "deepseek-r1": { 36 | "provider": "siliconflow", 37 | "model_name": 
"Pro/deepseek-ai/DeepSeek-R1", 38 | "max_tokens": 8192, 39 | "description": "DeepSeek R1 model via SiliconFlow" 40 | }, 41 | "azure-deepseek-r1": { 42 | "provider": "azure-deepseek", 43 | "model_name": "DeepSeek-R1", 44 | "deployment_name": "deepseek-r1", 45 | "api_version": "2024-05-01-preview", 46 | "max_tokens": 8192, 47 | "description": "Azure DeepSeek R1 deployment" 48 | }, 49 | "openrouter-deepseek-r1": { 50 | "description": "OpenRouter Model", 51 | "provider": "openrouter", 52 | "model_name": "deepseek/deepseek-r1", 53 | "max_tokens": 163840, 54 | "api_version": "v1", 55 | "openrouter_provider": { 56 | "order": ["Anthropic", "Fireworks"], 57 | "allow_fallbacks": true, 58 | "sort": "throughput", 59 | "ignore": ["Meta"], 60 | "require_parameters": false 61 | } 62 | }, 63 | "qwq-32b": { 64 | "provider": "openrouter", 65 | "model_name": "qwen/qwq-32b", 66 | "max_tokens": 100000, 67 | "description": "QwQ 32B model", 68 | "api_version": "v1", 69 | "openrouter_provider": { 70 | "order": ["Fireworks"], 71 | "allow_fallbacks": false 72 | }, 73 | "reasoning": { 74 | "effort": "high", 75 | "exclude": true 76 | } 77 | }, 78 | "gemini-2.0-flash-lite-preview": { 79 | "provider": "openrouter", 80 | "model_name": "google/gemini-2.0-flash-lite-preview-02-05:free", 81 | "max_tokens": 1000000, 82 | "description": "Gemini 2.0 Flash Lite Preview model", 83 | "openrouter_provider": { 84 | "order": ["Google"], 85 | "allow_fallbacks": false, 86 | "data_collection": "disallow" 87 | } 88 | } 89 | }, 90 | "provider_configs": { 91 | "openai": { 92 | "base_url": "https://api.openai.com/v1", 93 | "requires": ["api_key"] 94 | }, 95 | "azure": { 96 | "requires": ["azure_key", "azure_endpoint"] 97 | }, 98 | "siliconflow": { 99 | "base_url": "https://api.siliconflow.cn/v1", 100 | "requires": ["siliconflow_key"] 101 | }, 102 | "azure-deepseek": { 103 | "requires": ["ds_azure_key", "ds_azure_endpoint"] 104 | }, 105 | "openrouter": { 106 | "requires": ["openrouter_api_key"] 107 | } 108 | } 109 | } -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | juliet-test-suite-c/ 2 | qlresult-origin/ 3 | *.log 4 | **/*.log 5 | **/log/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Log files 2 | logs/ 3 | *.log 4 | 5 | # Python cache 6 | __pycache__/ 7 | *.pyc 8 | *.pyo 9 | *.pyd 10 | .Python 11 | *.so 12 | 13 | # IDE files 14 | .vscode/ 15 | .idea/ 16 | *.swp 17 | *.swo 18 | 19 | # OS files 20 | .DS_Store 21 | Thumbs.db 22 | 23 | # Virtual environments 24 | venv/ 25 | env/ 26 | .env 27 | 28 | # Temporary files 29 | *.tmp 30 | *.temp 31 | 32 | # Personal/scratch files 33 | run.sh 34 | claude-sdk-readme.md 35 | 36 | juliet-test-suite-c/* 37 | qlworkspace/* 38 | 39 | juliet-test-suite-c/ -------------------------------------------------------------------------------- /BaseMachine/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | BaseMachine package for LLMFrontend2 3 | This package contains the core functionality for model management and state machine implementation. 
4 | """ 5 | 6 | # Core components 7 | from .state_machine import StateMachine, BaseState, ExitState 8 | from .model_manager import ModelManager 9 | from .config_loader import load_config 10 | 11 | # Action utilities 12 | from .action_utils import ( 13 | create_chat_action, 14 | create_new_chat_action, 15 | create_context_filling_new_chat_action, 16 | create_context_filling_new_chat_json_action, 17 | call_sub_state_machine_action 18 | ) 19 | 20 | # Agent action utilities 21 | from .agent_action_utils import create_agent_action 22 | 23 | __all__ = [ 24 | # Core 25 | 'StateMachine', 26 | 'BaseState', 27 | 'ExitState', 28 | 'ModelManager', 29 | 'load_config', 30 | # Chat actions 31 | 'create_chat_action', 32 | 'create_new_chat_action', 33 | 'create_context_filling_new_chat_action', 34 | 'create_context_filling_new_chat_json_action', 35 | 'call_sub_state_machine_action', 36 | # Agent actions 37 | 'create_agent_action', 38 | ] 39 | -------------------------------------------------------------------------------- /BaseMachine/action_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | LLM Action Functions Module 3 | Contains various action functions for interacting with LLMs 4 | """ 5 | 6 | from typing import Any, List 7 | from pydantic import BaseModel, Field 8 | import logging 9 | from colorama import Fore 10 | import json 11 | import requests 12 | # Fix import errors by adapting to different OpenAI library versions 13 | try: 14 | # Try importing from newer version 15 | from openai.types.chat import ChatCompletion, ChatCompletionMessage 16 | from openai.types.chat.chat_completion import Choice 17 | except ImportError: 18 | try: 19 | # Try importing from another possible location 20 | from openai.types.chat import ChatCompletion, ChatCompletionMessage 21 | from openai.types import Choice 22 | except ImportError: 23 | # If both fail, create a simple implementation 24 | class ChatCompletion: 25 | def __init__(self, id, choices, created, model, object, system_fingerprint, usage): 26 | self.id = id 27 | self.choices = choices 28 | self.created = created 29 | self.model = model 30 | self.object = object 31 | self.system_fingerprint = system_fingerprint 32 | self.usage = usage 33 | 34 | class ChatCompletionMessage: 35 | def __init__(self, content, role, function_call=None, tool_calls=None): 36 | self.content = content 37 | self.role = role 38 | self.function_call = function_call 39 | self.tool_calls = tool_calls 40 | 41 | class Choice: 42 | def __init__(self, finish_reason, index, message, logprobs=None): 43 | self.finish_reason = finish_reason 44 | self.index = index 45 | self.message = message 46 | self.logprobs = logprobs 47 | 48 | 49 | class ContextCode(BaseModel): 50 | name: str = Field(description="Function or Class name") 51 | reason: str = Field(description="Brief reason why this function's code is needed for analysis") 52 | code_line: str = Field(description="The single line of code where where this context object is referenced.") 53 | file_path: str = Field(description="The file path of the code line.") 54 | 55 | class Response(BaseModel): 56 | analysis: str = Field(description="The analysis result of the question.") 57 | context_code: List[str] = Field(description="If you need additional context code to analyze the question, please provide the context code's names you need additionally to analyze the question.") 58 | 59 | # Import helper functions 60 | from BaseMachine.llm_helpers import ( 61 | reliable_parse, 62 | safe_format, 63 | 
extract_code_snippets,
64 |     parse_and_validate_json_response
65 | )
66 | 
67 | def create_chat_action(prompt_template, response_parser=None, save_option='both', model_name='azure-gpt4o', debug=False):
68 |     """
69 |     Create a chat action function for sending prompts and handling responses.
70 |     Maintains the complete chat history.
71 | 
72 |     :param prompt_template: The prompt template
73 |     :param response_parser: Optional response parser
74 |     :param save_option: Save option, can be 'both', 'prompt', 'result', or 'none'
75 |     :param model_name: The model name to use
76 |     :param debug: Whether to enable debugging
77 |     :return: The action function
78 |     """
79 |     def chat_action(machine, **kwargs):
80 |         from BaseMachine.state_machine import StateMachine  # Move import here
81 |         prompt = prompt_template.format(**kwargs)
82 | 
83 |         if debug:
84 |             logging.info(Fore.BLUE + f'Chat Action Prompt: {prompt}')
85 | 
86 |         machine.messages.append({"role": "user", "content": prompt})
87 | 
88 |         # Select the appropriate client based on model_name
89 |         client_to_use, info = next(((client, info) for client, info in machine.clients if info['name'] == model_name), (None, None))
90 |         if client_to_use is None:
91 |             raise ValueError(f"Model '{model_name}' not found in initialized clients.")
92 | 
93 |         # Ensure info is a dictionary
94 |         if isinstance(info, tuple):
95 |             info = dict(info)
96 | 
97 |         # Build request parameters based on whether response_parser is None or not
98 |         request_params = {
99 |             'model': info['model_name'],
100 |             'messages': machine.messages,
101 |             **(
102 |                 {"temperature": 0.01, "top_p": machine.config.top_p}
103 |                 if info['model_name'] not in ["o1-mini", "o1-preview"]
104 |                 else {}
105 |             ),
106 |         }
107 |         if response_parser is not None:
108 |             request_params['response_format'] = response_parser
109 | 
110 |         # Use the reliable_parse helper (with retries) to make the request
111 |         # through the selected client
112 |         logging.info(Fore.YELLOW + f'Waiting for the model {info["model_name"]} to process the request...')
113 |         response = reliable_parse(client_to_use, request_params, max_retries=3, debug=debug, model_info=info)
114 |         logging.info(Fore.GREEN + f'Model {info["model_name"]} processed the request successfully.')
115 |         # machine.total_input_tokens += response.usage.prompt_tokens
116 |         # machine.total_output_tokens += response.usage.completion_tokens
117 | 
118 |         # Add the assistant's reply to the message list
119 |         machine.messages.append(
120 |             {"role": "assistant", "content": response.choices[0].message.content}
121 |         )
122 | 
123 |         # Parse the assistant's reply
124 |         message = response.choices[0].message
125 |         # parsed_result = getattr(message, "parsed", message.content)
126 |         parsed_result = message.content if getattr(message, "parsed", None) is None else message.parsed
127 | 
128 |         # Save content based on the save_option parameter
129 |         if save_option == 'prompt':
130 |             machine.analysis_result.append(prompt)
131 |         elif save_option == 'result':
132 |             machine.analysis_result.append(parsed_result)
133 |         elif save_option == 'both':
134 |             machine.analysis_result.append({'prompt': prompt, 'result': parsed_result})
135 |         elif save_option == 'none':
136 |             pass
137 |         else:
138 |             # If an invalid save_option is provided, raise an exception
139 |             raise ValueError("Invalid save_option value. Choose from 'prompt', 'result', 'both', or 'none'.")
140 | 
141 |         return parsed_result
142 | 
143 |     return chat_action
144 | 
145 | 
146 | def create_new_chat_action(prompt_template, response_parser=None, save_option='both', model_name='azure-gpt4o', debug=False):
147 |     """
148 |     Create a new chat action function that ignores previous messages but updates the machine's message history.
149 | 
150 |     :param prompt_template: The prompt template
151 |     :param response_parser: Optional response parser
152 |     :param save_option: Save option, can be 'both', 'prompt', 'result', or 'none'
153 |     :param model_name: The model name to use
154 |     :param debug: Whether to enable debugging
155 |     :return: The action function
156 |     """
157 |     pass  # Stub: not yet implemented
158 | 
159 | 
160 | def create_context_filling_new_chat_action(prompt_template, response_parser=None, save_option='both', model_name='azure-gpt4o'):
161 |     """
162 |     Create a context-filling chat action function.
163 |     The first response includes a general chat result and a context filling field.
164 |     Then, include the filled context code at the end of the prompt and re-ask.
165 | 
166 |     :param prompt_template: The prompt template
167 |     :param response_parser: Optional response parser
168 |     :param save_option: Save option, can be 'both', 'prompt', 'result', or 'none'
169 |     :param model_name: The model name to use
170 |     :return: The action function
171 |     """
172 |     pass  # Stub: not yet implemented
173 | 
174 | 
175 | def create_context_filling_new_chat_json_action(prompt_template, response_parser=None, save_option='both', model_name='azure-gpt4o', debug=False, use_hardcoded_json=False, accelerated_mode=True):
176 |     """
177 |     Create a context-filling chat action function with JSON response format.
178 |     Provides accelerated mode and debugging features.
179 | 
180 |     :param prompt_template: The prompt template
181 |     :param response_parser: Optional response parser
182 |     :param save_option: Save option, can be 'both', 'prompt', 'result', or 'none'
183 |     :param model_name: The model name to use
184 |     :param debug: Whether to enable debugging
185 |     :param use_hardcoded_json: Whether to use hardcoded JSON (for debugging)
186 |     :param accelerated_mode: Whether to enable accelerated mode
187 |     :return: The action function
188 |     """
189 |     pass  # Stub: not yet implemented
190 | 
191 | 
192 | def call_sub_state_machine_action(sub_state_definitions, sub_initial_state, sub_context_cls, save_option='both'):
193 |     """
194 |     Create an action function that calls a sub-state machine
195 | 
196 |     :param sub_state_definitions: The sub-state machine's state definitions
197 |     :param sub_initial_state: The sub-state machine's initial state
198 |     :param sub_context_cls: The sub-state machine's context class
199 |     :param save_option: Save option, can be 'both', 'prompt', or 'result' ('none' is not supported here)
200 |     :return: The action function
201 |     """
202 |     def sub_state_machine_action(machine, **kwargs):
203 |         from BaseMachine.state_machine import StateMachine  # Move import here
204 |         # Create the sub-state machine's context
205 |         sub_context = sub_context_cls(**kwargs)
206 | 
207 |         # Create and run the sub-state machine
208 |         sub_machine = StateMachine(
209 |             context=sub_context,
210 |             state_definitions=sub_state_definitions,
211 |             initial_state=sub_initial_state,
212 |             config_path=machine.config.config_path
213 |         )
214 |         sub_result = sub_machine.process()
215 | 
216 |         # Merge the sub-state machine's results and resource consumption
217 |         machine.total_input_tokens += sub_machine.total_input_tokens
218 |         machine.total_output_tokens += sub_machine.total_output_tokens
219 | 
machine.messages.extend(sub_machine.messages) 220 | 221 | # Save content based on the save_option parameter 222 | if save_option == 'prompt': 223 | machine.analysis_result.append(sub_context) 224 | elif save_option == 'result': 225 | machine.analysis_result.append(sub_result) 226 | elif save_option == 'both': 227 | machine.analysis_result.append({'context': sub_context, 'result': sub_result}) 228 | else: 229 | # If an invalid save_option is provided, throw an exception or perform default handling 230 | raise ValueError("Invalid save_option value. Choose from 'prompt', 'result', or 'both'.") 231 | 232 | return sub_result 233 | return sub_state_machine_action 234 | -------------------------------------------------------------------------------- /BaseMachine/agent_action_utils.py: -------------------------------------------------------------------------------- 1 | # agent_action_utils.py 2 | 3 | import anyio 4 | import logging 5 | import json 6 | import os 7 | from datetime import datetime 8 | from pathlib import Path 9 | from typing import Any, Callable, Dict, Optional, List, AsyncIterator, Union, TYPE_CHECKING 10 | 11 | # Get the directory of the script for relative paths 12 | SCRIPT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 13 | 14 | if TYPE_CHECKING: 15 | from BaseMachine.state_machine import StateMachine 16 | from claude_code_sdk import ( 17 | query, 18 | ClaudeCodeOptions, 19 | AssistantMessage, 20 | TextBlock, 21 | ToolUseBlock, 22 | ToolResultBlock, 23 | UserMessage, 24 | SystemMessage, 25 | ResultMessage, 26 | CLINotFoundError, 27 | ProcessError, 28 | CLIJSONDecodeError 29 | ) 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class StreamingJSONLogger: 35 | """ 36 | Logger for saving streaming JSON messages from Claude Code SDK. 37 | """ 38 | def __init__(self, base_log_dir: str = None): 39 | if base_log_dir is None: 40 | base_log_dir = os.path.join(SCRIPT_DIR, 'qlworkspace') 41 | self.base_log_dir = Path(base_log_dir) 42 | self.base_log_dir.mkdir(parents=True, exist_ok=True) 43 | self.current_session_id = None 44 | self.session_messages = [] 45 | 46 | def create_session(self, context: Dict[str, Any]) -> str: 47 | """ 48 | Create a new logging session. 49 | 50 | Args: 51 | context: Context containing session metadata 52 | 53 | Returns: 54 | Session ID 55 | """ 56 | timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') 57 | # Include action type in session ID if available 58 | action_type = context.get('action_type', 'general') 59 | self.current_session_id = f"{action_type}_{timestamp}" 60 | 61 | # Build log directory path 62 | # Use custom log path if provided, otherwise use default 63 | if 'log_path' in context: 64 | self.session_log_dir = Path(context['log_path']) 65 | else: 66 | # Create a general session directory 67 | self.session_log_dir = self.base_log_dir / 'sessions' / self.current_session_id 68 | 69 | self.session_log_dir.mkdir(parents=True, exist_ok=True) 70 | self.session_messages = [] 71 | 72 | # Log session start 73 | self.log_message({ 74 | 'type': 'session_start', 75 | 'session_id': self.current_session_id, 76 | 'timestamp': timestamp, 77 | 'context': context 78 | }) 79 | 80 | return self.current_session_id 81 | 82 | def log_message(self, message: Dict[str, Any]): 83 | """ 84 | Log a streaming JSON message. 85 | 86 | Args: 87 | message: The message to log 88 | """ 89 | if not self.current_session_id: 90 | logger.warning("No active session. 
Message not logged.") 91 | return 92 | 93 | # Add timestamp if not present 94 | if 'timestamp' not in message: 95 | message['timestamp'] = datetime.now().isoformat() 96 | 97 | self.session_messages.append(message) 98 | 99 | # Write to single log file (append mode) - no longer separate streaming file 100 | log_path = self.session_log_dir / f"{self.current_session_id}.jsonl" 101 | with open(log_path, 'a') as f: 102 | f.write(json.dumps(message) + '\n') 103 | 104 | def finalize_session(self, result: Any = None): 105 | """ 106 | Finalize the current session and save complete log. 107 | 108 | Args: 109 | result: Final result of the session 110 | """ 111 | if not self.current_session_id: 112 | return 113 | 114 | # Log session end 115 | self.log_message({ 116 | 'type': 'session_end', 117 | 'session_id': self.current_session_id, 118 | 'timestamp': datetime.now().isoformat(), 119 | 'result_summary': str(result) if result else None 120 | }) 121 | 122 | # No need to save a separate complete log file - the JSONL file has everything 123 | # Just log the final statistics 124 | logger.info(f"Session {self.current_session_id} completed with {len(self.session_messages)} messages") 125 | 126 | logger.info(f"Session {self.current_session_id} finalized. Logs saved to {self.session_log_dir}") 127 | 128 | # Reset session 129 | self.current_session_id = None 130 | self.session_messages = [] 131 | 132 | 133 | # Global logger instance 134 | streaming_logger = StreamingJSONLogger() 135 | 136 | 137 | def parse_streaming_json_message(message: Any) -> Dict[str, Any]: 138 | """ 139 | Parse a streaming JSON message from Claude Code SDK. 140 | 141 | Args: 142 | message: The message object from the SDK 143 | 144 | Returns: 145 | A dictionary representing the JSON message 146 | """ 147 | message_dict = { 148 | 'type': type(message).__name__.lower().replace('message', '') 149 | } 150 | 151 | if isinstance(message, (SystemMessage, UserMessage)): 152 | message_dict['content'] = str(message.content) if hasattr(message, 'content') else '' 153 | elif isinstance(message, AssistantMessage): 154 | content_list = [] 155 | for block in message.content: 156 | if isinstance(block, TextBlock): 157 | content_list.append({ 158 | 'type': 'text', 159 | 'text': block.text 160 | }) 161 | elif isinstance(block, ToolUseBlock): 162 | content_list.append({ 163 | 'type': 'tool_use', 164 | 'name': block.name, 165 | 'input': block.input 166 | }) 167 | message_dict['content'] = content_list 168 | elif isinstance(message, ResultMessage): 169 | message_dict.update({ 170 | 'session_id': getattr(message, 'session_id', None), 171 | 'duration': getattr(message, 'duration', None), 172 | 'total_cost': getattr(message, 'total_cost', None), 173 | 'turn_count': getattr(message, 'turn_count', None) 174 | }) 175 | 176 | return message_dict 177 | 178 | 179 | def create_agent_action( 180 | prompt_template: str, 181 | response_parser: Optional[Callable[[str], Any]] = None, 182 | save_option: str = 'both', 183 | allowed_tools: Optional[List[str]] = None, 184 | permission_mode: str = 'default', 185 | system_prompt: Optional[str] = None, 186 | max_turns: Optional[int] = None, 187 | output_format: str = 'default', 188 | stream_callback: Optional[Callable[[Dict[str, Any]], None]] = None, 189 | enable_stream_logging: bool = False, 190 | debug: bool = False 191 | ) -> Callable[['StateMachine', Dict[str, Any]], Any]: 192 | """ 193 | Create an agent action function using Claude Code SDK. 
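
    Example (illustrative sketch -- the prompt, tool list, and query path below
    are hypothetical; the parameters match this function's signature):

        run_query = create_agent_action(
            prompt_template='Run the CodeQL query at {query_path} and summarize the findings.',
            allowed_tools=['Bash', 'Read'],
            permission_mode='acceptEdits',
            max_turns=10,
            enable_stream_logging=True,
        )
        result = run_query(machine, query_path='demo/query.ql')
        # result is a dict with 'response' (joined text) and 'tool_uses'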
194 | 195 | Args: 196 | prompt_template: Template string for the prompt 197 | response_parser: Optional function to parse response 198 | save_option: How to save the interaction (unused in agent mode) 199 | allowed_tools: List of tools the agent can use 200 | permission_mode: 'acceptEdits', 'bypassPermissions', 'default', or 'plan' 201 | system_prompt: System prompt for the agent 202 | max_turns: Maximum conversation turns 203 | output_format: Output format - 'default' or 'stream-json' 204 | stream_callback: Optional callback for streaming JSON messages 205 | enable_stream_logging: Enable automatic streaming JSON logging 206 | debug: Enable debug logging 207 | 208 | Returns: 209 | An action function that can be used in state definitions 210 | """ 211 | def action(machine: 'StateMachine', **kwargs) -> Any: 212 | # Format the prompt using template 213 | formatted_prompt = prompt_template.format(**kwargs) 214 | 215 | if debug: 216 | logger.debug(f"Agent action prompt: {formatted_prompt}") 217 | 218 | # Initialize streaming logger session if enabled 219 | session_id = None 220 | if enable_stream_logging or output_format == 'stream-json': 221 | # Extract context from machine for logging 222 | log_context = { 223 | 'action_type': machine.context.get('action_type', 'general'), 224 | 'working_directory': machine.context.get('working_directory'), 225 | 'prompt': formatted_prompt[:200] + '...' if len(formatted_prompt) > 200 else formatted_prompt 226 | } 227 | # Allow machine context to override log path 228 | if hasattr(machine.context, 'session_log_path'): 229 | log_context['log_path'] = machine.context.session_log_path 230 | session_id = streaming_logger.create_session(log_context) 231 | 232 | # Configure Claude Code options 233 | options = ClaudeCodeOptions( 234 | cwd=machine.context.get('working_directory', None), 235 | allowed_tools=allowed_tools or [], 236 | permission_mode=permission_mode, 237 | system_prompt=system_prompt, 238 | max_turns=max_turns 239 | ) 240 | 241 | # Collect all responses 242 | responses = [] 243 | tool_uses = [] 244 | streaming_messages = [] 245 | 246 | try: 247 | # Run the async query synchronously 248 | async def run_query(): 249 | async for message in query(prompt=formatted_prompt, options=options): 250 | if debug: 251 | logger.debug(f"Received message: {type(message).__name__}") 252 | 253 | # Handle streaming JSON output 254 | if output_format == 'stream-json' or enable_stream_logging: 255 | json_message = parse_streaming_json_message(message) 256 | streaming_messages.append(json_message) 257 | 258 | # Log to file if enabled 259 | if enable_stream_logging: 260 | streaming_logger.log_message(json_message) 261 | 262 | # Call stream callback if provided 263 | if stream_callback: 264 | stream_callback(json_message) 265 | 266 | if isinstance(message, AssistantMessage): 267 | for block in message.content: 268 | if isinstance(block, TextBlock): 269 | responses.append(block.text) 270 | # Check for Claude AI usage limit error 271 | if "Claude AI usage limit reached" in block.text: 272 | logger.error("Claude AI usage limit reached - stopping pipeline") 273 | raise RuntimeError("Claude AI usage limit reached") 274 | elif isinstance(block, ToolUseBlock): 275 | tool_uses.append({ 276 | 'tool': block.name, 277 | 'input': block.input 278 | }) 279 | elif isinstance(message, ResultMessage): 280 | # Handle result message with metadata 281 | if debug: 282 | logger.debug(f"Result message: {message}") 283 | 284 | return responses, tool_uses, streaming_messages 285 | 286 | responses, 
tool_uses, streaming_messages = anyio.run(run_query)
287 | 
288 |             # Join all text responses
289 |             full_response = '\n'.join(responses)
290 | 
291 |             # Store in machine's context for hybrid mode
292 |             if hasattr(machine, 'agent_results'):
293 |                 machine.agent_results.append({
294 |                     'prompt': formatted_prompt,
295 |                     'response': full_response,
296 |                     'tool_uses': tool_uses
297 |                 })
298 | 
299 |             # Finalize streaming logger session
300 |             if session_id:
301 |                 streaming_logger.finalize_session({
302 |                     'response_length': len(full_response),
303 |                     'tool_use_count': len(tool_uses),
304 |                     'message_count': len(streaming_messages)
305 |                 })
306 | 
307 |             # Parse response if parser provided
308 |             if response_parser:
309 |                 parsed_result = response_parser(full_response)
310 |                 if output_format == 'stream-json':
311 |                     return {
312 |                         'parsed': parsed_result,
313 |                         'streaming_messages': streaming_messages
314 |                     }
315 |                 return parsed_result
316 | 
317 |             result = {
318 |                 'response': full_response,
319 |                 'tool_uses': tool_uses
320 |             }
321 | 
322 |             if output_format == 'stream-json':
323 |                 result['streaming_messages'] = streaming_messages
324 | 
325 |             return result
326 | 
327 |         except CLINotFoundError:
328 |             logger.error("Claude Code CLI not found. Please install with: npm install -g @anthropic-ai/claude-code")
329 |             raise
330 |         except ProcessError as e:
331 |             logger.error(f"Process failed with exit code: {e.exit_code}")
332 |             raise
333 |         except RuntimeError as e:
334 |             # Re-raise RuntimeError (including usage limit) without wrapping
335 |             if "Claude AI usage limit reached" in str(e):
336 |                 logger.error("Claude AI usage limit reached - stopping execution")
337 |             raise
338 |         except Exception as e:
339 |             logger.error(f"Agent action failed: {e}")
340 |             raise
341 | 
342 |     return action
343 | 
344 | 
345 | 
--------------------------------------------------------------------------------
/BaseMachine/code_filling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/P1umer/QL-Relax/941ffa7fb0b0b196ba9e4e49e99f5b7fe6e47138/BaseMachine/code_filling/__init__.py
--------------------------------------------------------------------------------
/BaseMachine/code_filling/code_filling_config.py:
--------------------------------------------------------------------------------
1 | from BaseMachine.code_filling.code_filling_tools import query_symbol_definition
2 | from BaseMachine.action_utils import create_chat_action
3 | 
4 | # Control whether to guess the code definition or simply return "missing" message
5 | # Set to True to enable guessing, False to return missing definition message
6 | DEF_GUESS = False
7 | 
8 | def initialize_system_prompt_action(machine) -> None:
9 |     machine.messages = [
10 |         {
11 |             "role": "system",
12 |             "content": 'You are an expert in code search.',
13 |         }
14 |     ]
15 |     symbol = machine.context.name
16 |     machine.search_results = query_symbol_definition(symbol)
17 |     machine.code_snippet = machine.context.context_code
18 |     return None
19 | 
20 | def guess_the_code_action(machine):
21 |     machine.definition = create_chat_action(
22 |         prompt_template='''
23 | Guess the definition of the function {name}, including a plausible implementation, based on the following code snippet:
24 | {code_snippet}
25 | '''
26 |     )(
27 |         machine,
28 |         name=machine.context.name,
29 |         code_snippet=machine.code_snippet
30 |     )
31 | 
32 |     return None
33 | 
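# Illustrative sketch (not wiring that exists in this file): these
# state_definitions are meant to run as a sub-machine, e.g. via
# BaseMachine.action_utils.call_sub_state_machine_action. The parent machine,
# symbol name, and code snippet below are hypothetical.
#
#   from BaseMachine.action_utils import call_sub_state_machine_action
#   from BaseMachine.code_filling.code_filling_context import CFContext
#
#   fill_definition = call_sub_state_machine_action(
#       sub_state_definitions=state_definitions,
#       sub_initial_state='InitializeSystemPrompt',
#       sub_context_cls=CFContext,
#   )
#   result = fill_definition(machine, name='parse_header',
#                            context_code='len = parse_header(buf);')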
34 | def return_missing_definition_action(machine):
35 |     # Simply set the definition to indicate it's missing
36 |     machine.definition = f"ERROR!: {machine.context.name} definition missing because the symbol-definition search tool failed while looking up {machine.context.name}."
37 |     return None
38 | 
39 | def use_single_result_action(machine):
40 |     """
41 |     Directly use the single search result without LLM selection.
42 |     """
43 |     if len(machine.search_results) == 1:
44 |         machine.definition = machine.search_results[0]
45 |     else:
46 |         # This should not happen if state transitions are correct
47 |         machine.definition = f"ERROR!: Expected single result but got {len(machine.search_results)} results."
48 |     return None
49 | 
50 | def choose_most_related_result_action(machine):
51 |     machine.definition = create_chat_action(
52 |         prompt_template='''
53 | Below are definitions returned by the OpenGrok code search platform for the symbol name {name}.
54 | This symbol is used in the following code snippet:
55 | {code_snippet}
56 | 
57 | Please choose the most relevant one based on the symbol's usage context, and return the full definition of the chosen result.
58 | The search results are stored in the following array:
59 | {search_results}
60 | ''',
61 |         save_option='both'
62 |     )(
63 |         machine,
64 |         name=machine.context.name,
65 |         code_snippet=machine.code_snippet,
66 |         search_results=machine.search_results
67 |     )
68 |     # print(machine.definition)
69 |     return None
70 | 
71 | def exit_action(machine):
72 |     return None
73 | 
74 | 
75 | # 1. Initialize the system prompt (input is a symbol and its usage context)
76 | # 2. Choose the most related result (only if there is more than one result)
77 | # 3. Use the single result directly if there is exactly one result
78 | state_definitions = {
79 |     'InitializeSystemPrompt': {
80 |         'action': initialize_system_prompt_action,
81 |         'next_state_func': lambda result, machine: (
82 |             'SelectAndChooseMostRelatedResult' if len(machine.search_results) > 1
83 |             else 'UseSingleResult' if len(machine.search_results) == 1
84 |             else ('GuessTheCode' if DEF_GUESS else 'ReturnMissingDefinition')
85 |         ),
86 |     },
87 |     'UseSingleResult': {
88 |         'action': use_single_result_action,
89 |         'next_state_func': lambda result, machine: 'Exit',
90 |     },
91 |     'SelectAndChooseMostRelatedResult': {
92 |         'action': choose_most_related_result_action,
93 |         'next_state_func': lambda result, machine: 'Exit',
94 |     },
95 |     'GuessTheCode': {
96 |         'action': guess_the_code_action,
97 |         'next_state_func': lambda result, machine: 'Exit',
98 |     },
99 |     'ReturnMissingDefinition': {
100 |         'action': return_missing_definition_action,
101 |         'next_state_func': lambda result, machine: 'Exit',
102 |     },
103 |     'Exit': {
104 |         'action': exit_action,
105 |         'next_state_func': None,
106 |     },
107 | }
--------------------------------------------------------------------------------
/BaseMachine/code_filling/code_filling_context.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 | from typing import List
3 | 
4 | # Context data model
5 | class CFContext(BaseModel):
6 |     name: str = Field(description="Function or Class name")
7 |     context_code: str = Field(description="The callsite code of the function")
8 | 
9 | # class Definition(BaseModel):
10 | 
11 | 
12 | 
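# Illustrative usage sketch (hypothetical values): CFContext is the context the
# code-filling sub-machine receives -- 'name' is the symbol to resolve and
# 'context_code' is the call site that references it.
#
#   ctx = CFContext(name='parse_header',
#                   context_code='len = parse_header(buf);')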
--------------------------------------------------------------------------------
/BaseMachine/code_filling/code_filling_tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import logging
4 | from colorama import Fore
5 | # from flow_analysis.flow_context import execStep
6 | 
7 | current_dir = os.path.dirname(os.path.abspath(__file__))
8 | parent_dir = os.path.abspath(os.path.join(current_dir, '../../AdvancedTools/CodeSearch'))
9 | sys.path.append(parent_dir)
10 | 
11 | from AdvancedTools.CodeSearch.opengrok_search import CodeQueryManager  # Corrected import path
12 | 
13 | def query_symbol_definition(symbol, port=8080):
14 |     """
15 |     Query the definition of a symbol using CodeQueryManager.
16 |     """
17 |     os.environ['OPENGROK_STATUS'] = 'ready'
18 |     query_manager = CodeQueryManager(port)
19 |     results = query_manager.query_definition(symbol)
20 |     return results
21 | 
22 | # Simple manual test entry point
23 | if __name__ == '__main__':
24 |     project_name = 'VBox'
25 |     symbol = 'fetch_raw_setting_copy'
26 |     # Temporary env variable to avoid the readiness-check error
27 |     os.environ['OPENGROK_STATUS'] = 'ready'
28 |     results = query_symbol_definition(symbol)
29 |     print(results)
--------------------------------------------------------------------------------
/BaseMachine/config_loader.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | 
4 | class chatGPTConfig:
5 |     def __init__(self, api_key, model, max_tokens, temperature, top_p, stop_sequences, azure_key, azure_endpoint, use_provider="openai", siliconflow_key=None, siliconflow_base_url=None, siliconflow_model=None):
6 |         self.api_key = api_key
7 |         self.model = model
8 |         self.max_tokens = max_tokens
9 |         self.temperature = temperature
10 |         self.top_p = top_p
11 |         self.stop_sequences = stop_sequences
12 |         self.azure_key = azure_key
13 |         self.azure_endpoint = azure_endpoint
14 |         self.use_provider = use_provider
15 |         self.siliconflow_key = siliconflow_key
16 |         self.siliconflow_base_url = siliconflow_base_url
17 |         self.siliconflow_model = siliconflow_model
18 | 
19 |     def __str__(self):
20 |         return f"api_key: {self.api_key}, model: {self.model}, max_tokens: {self.max_tokens}, temperature: {self.temperature}, top_p: {self.top_p}, stop_sequences: {self.stop_sequences}"
21 | 
22 |     def __repr__(self):
23 |         return f"api_key: {self.api_key}, model: {self.model}, max_tokens: {self.max_tokens}, temperature: {self.temperature}, top_p: {self.top_p}, stop_sequences: {self.stop_sequences}"
24 | 
25 |     def __eq__(self, other):
26 |         if not isinstance(other, chatGPTConfig):
27 |             return False
28 |         return self.api_key == other.api_key and self.model == other.model and self.max_tokens == other.max_tokens and self.temperature == other.temperature and self.top_p == other.top_p and self.stop_sequences == other.stop_sequences
29 | 
30 | def load_config(file_path) -> chatGPTConfig:
31 |     log_file_path = os.path.join(os.path.dirname(__file__), file_path)
32 |     # print("[I] Loading ChatGPT config from", log_file_path)
33 |     with open(log_file_path, 'r') as file:
34 |         config = json.load(file)
35 | 
36 |     default_model = 'gpt-4o'
37 |     default_max_tokens = 8192
38 |     default_temperature = 0.7
39 |     default_top_p = 1.0
40 |     default_stop_sequences = '\n'
41 |     default_provider = 'openai'
42 | 
43 |     model = config.get('model')
44 |     max_tokens = config.get('max_tokens')
45 |     temperature = config.get('temperature')
46 |     top_p = config.get('top_p')
47 |     stop_sequences = config.get('stop_sequences')
48 |     azure_key = config.get('azure_key')
49 |     azure_endpoint = config.get('azure_endpoint')
50 |     use_provider = config.get('use_provider', default_provider)
51 |     siliconflow_key = config.get('siliconflow_key')
52 |     siliconflow_base_url = config.get('siliconflow_base_url')
53 |     siliconflow_model = config.get('siliconflow_model')
54 | 
55 |     # if (not model):
56 |     #     print("[W] Model is not specified. Use", default_model, "by default.")
57 |     #     model = default_model
58 |     if (not max_tokens):
59 |         print("[W] Max token length is not specified. Using", default_max_tokens, "by default.")
60 |         max_tokens = default_max_tokens
61 |     if (not temperature):
62 |         print("[W] Temperature is not specified. Using", default_temperature, "by default.")
63 |         temperature = default_temperature
64 |     if (not top_p):
65 |         print("[W] Top_p is not specified. Using", default_top_p, "by default.")
66 |         top_p = default_top_p
67 |     if (not stop_sequences):
68 |         print("[W] Stop sequence characters are not specified. Using", repr(default_stop_sequences), "by default.")
69 |         stop_sequences = default_stop_sequences
70 | 
71 |     config_obj = chatGPTConfig(
72 |         config['api_key'],
73 |         model,
74 |         max_tokens,
75 |         temperature,
76 |         top_p,
77 |         stop_sequences,
78 |         azure_key,
79 |         azure_endpoint,
80 |         use_provider,
81 |         siliconflow_key,
82 |         siliconflow_base_url,
83 |         siliconflow_model
84 |     )
85 |     return config_obj
86 | 
87 | 
88 | if __name__ == "__main__":
89 |     from openai import OpenAI
90 |     config = load_config('../.config/config.json')  # the repo's config directory is named '.config'
91 |     client = OpenAI(api_key=config.api_key)
92 | 
93 |     print("ChatGPT CLI. Type 'exit' to quit.")
94 |     while True:
95 |         user_input = input("You: ")
96 |         if user_input.lower() == 'exit':
97 |             break
98 |         try:
99 |             completion = client.chat.completions.create(
100 |                 model=config.model,
101 |                 messages=[
102 |                     {"role": "user", "content": user_input}
103 |                 ],
104 |                 max_tokens=config.max_tokens,
105 |                 temperature=config.temperature,
106 |                 top_p=config.top_p,
107 |                 stop=config.stop_sequences
108 |             )
109 |             message = completion.choices[0].message.content
110 |             print(f"ChatGPT: {message}")
111 |         except Exception as e:
112 |             print(f"[E]: {e}")
113 | 
--------------------------------------------------------------------------------
/BaseMachine/llm_helpers.py:
--------------------------------------------------------------------------------
1 | """
2 | LLM Response Processing and Tool Functions Module
3 | Contains helper functions for LLM API interactions, response parsing, and formatting
4 | """
5 | 
6 | import logging
7 | import json
8 | import re
9 | import requests
10 | from colorama import Fore
11 | 
12 | # Fix import errors by adapting to different OpenAI library versions
13 | try:
14 |     # Try importing from newer version
15 |     from openai.types.chat import ChatCompletion, ChatCompletionMessage
16 |     from openai.types.chat.chat_completion import Choice
17 | except ImportError:
18 |     try:
19 |         # Try importing from another possible location
20 |         from openai.types.chat import ChatCompletion, ChatCompletionMessage
21 |         from openai.types import Choice
22 |     except ImportError:
23 |         # If both fail, create a simple implementation
24 |         class ChatCompletion:
25 |             def __init__(self, id, choices, created, model, object, system_fingerprint, usage):
26 |                 self.id = id
27 |                 self.choices = choices
28 |                 self.created = created
29 |                 self.model = model
30 |                 self.object = object
31 |                 self.system_fingerprint = system_fingerprint
32 |                 self.usage = usage
33 | 
34 |         class ChatCompletionMessage:
35 |             def __init__(self, content, role, function_call=None, tool_calls=None):
36 |                 self.content = content
37 |                 self.role = role
38 |                 self.function_call = function_call
39 |                 self.tool_calls = tool_calls
40 | 
41 |         class Choice:
42 |             def __init__(self, finish_reason, index, message, logprobs=None):
43 |                 self.finish_reason = finish_reason
44 |                 self.index = 
index 45 | self.message = message 46 | self.logprobs = logprobs 47 | 48 | 49 | def reliable_parse(client, request_params, max_retries=3, debug=False, model_info=None): 50 | """ 51 | Reliably parse LLM completion responses with retry logic. 52 | 53 | :param client: The client to use for parsing (OpenAI client) 54 | :param request_params: The parameters for the parsing request 55 | :param max_retries: Maximum number of retries 56 | :param debug: Enable debug logging 57 | :param model_info: Optional model information containing additional parameters 58 | :return: The message content if successful, None otherwise 59 | """ 60 | # Initialize retry counter 61 | retries = 0 62 | 63 | # Create a copy of request parameters 64 | merged_params = {**request_params} 65 | 66 | # Handle OpenRouter provider 67 | if model_info and model_info.get('provider') == 'openrouter': 68 | # Get OpenRouter API key 69 | openrouter_api_key = model_info.get('openrouter_api_key', None) 70 | if not openrouter_api_key and 'additional_kwargs' in model_info and 'openrouter_api_key' in model_info['additional_kwargs']: 71 | openrouter_api_key = model_info['additional_kwargs']['openrouter_api_key'] 72 | 73 | if not openrouter_api_key: 74 | raise ValueError("OpenRouter API key not found in model_info") 75 | 76 | # Set up extra headers for OpenRouter 77 | extra_headers = {} 78 | if 'additional_kwargs' in model_info: 79 | if 'http_referer' in model_info['additional_kwargs']: 80 | extra_headers["HTTP-Referer"] = model_info['additional_kwargs']['http_referer'] 81 | if 'x_title' in model_info['additional_kwargs']: 82 | extra_headers["X-Title"] = model_info['additional_kwargs']['x_title'] 83 | 84 | # Set up extra body parameters for OpenRouter 85 | extra_body = {} 86 | if 'additional_kwargs' in model_info: 87 | # Add reasoning parameter support 88 | if 'reasoning' in model_info['additional_kwargs']: 89 | extra_body['reasoning'] = model_info['additional_kwargs']['reasoning'] 90 | logging.info(Fore.CYAN + f"Using OpenRouter reasoning parameter: {extra_body['reasoning']}") 91 | 92 | # Add provider if specified 93 | if 'provider' in model_info['additional_kwargs']: 94 | extra_body['provider'] = model_info['additional_kwargs']['provider'] 95 | 96 | # Add models array if specified 97 | if 'models' in model_info['additional_kwargs']: 98 | extra_body['models'] = model_info['additional_kwargs']['models'] 99 | 100 | # Add any other OpenRouter-specific parameters 101 | for param in ['routes', 'transforms', 'stream_options']: 102 | if param in model_info['additional_kwargs']: 103 | extra_body[param] = model_info['additional_kwargs'][param] 104 | 105 | # Add extra parameters to merged_params 106 | if extra_headers: 107 | merged_params['extra_headers'] = extra_headers 108 | if extra_body: 109 | merged_params['extra_body'] = extra_body 110 | 111 | # Use the client directly with OpenRouter base URL 112 | while retries < max_retries: 113 | try: 114 | if debug: 115 | logging.info(Fore.BLUE + f"OpenRouter request params: {json.dumps(merged_params, default=str, ensure_ascii=False)}") 116 | 117 | # Use the OpenAI client to make the request 118 | response = client.beta.chat.completions.parse(**merged_params) 119 | 120 | # Check if we have a valid message content 121 | if not response.choices or not response.choices[0].message.content: 122 | logging.info(Fore.YELLOW + "Message content is empty, resending request...") 123 | retries += 1 124 | continue 125 | 126 | return response 127 | 128 | except Exception as e: 129 | logging.error(Fore.RED + f"OpenRouter 
request exception: {str(e)}") 130 | retries += 1 131 | 132 | logging.error(Fore.RED + "Unable to get a valid OpenRouter response after maximum retries.") 133 | return None 134 | 135 | # Other providers use standard OpenAI client 136 | else: 137 | while retries < max_retries: 138 | response = client.beta.chat.completions.parse(**request_params) 139 | message = response.choices[0].message 140 | 141 | if message.content: 142 | return response 143 | else: 144 | logging.info(Fore.YELLOW + "The message content is null or empty, re-running the request...") 145 | retries += 1 146 | 147 | logging.error(Fore.RED + "Failed to get a valid response after maximum retries.") 148 | return None 149 | 150 | 151 | def safe_format(template_str, **kwargs): 152 | """ 153 | Safely format a string, preserving original placeholders if parameters are missing 154 | """ 155 | class SafeDict(dict): 156 | def __missing__(self, key): 157 | return '{' + key + '}' 158 | 159 | try: 160 | return template_str.format_map(SafeDict(kwargs)) 161 | except Exception: 162 | return template_str 163 | 164 | 165 | def extract_code_snippets(prompt): 166 | """Extract all code snippets from the prompt""" 167 | pass 168 | 169 | 170 | def parse_and_validate_json_response(message, machine, debug=False): 171 | """Process and validate JSON responses, automatically fixing format issues""" 172 | pass -------------------------------------------------------------------------------- /BaseMachine/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unified Logging System for WorkFlow Framework 3 | Provides consistent logging configuration across all modules. 4 | """ 5 | 6 | import logging 7 | import sys 8 | from datetime import datetime 9 | from pathlib import Path 10 | from colorama import Fore, Style, init 11 | 12 | # Initialize colorama for cross-platform colored output 13 | init(autoreset=True) 14 | 15 | 16 | class ColoredFormatter(logging.Formatter): 17 | """Custom formatter for colored console output.""" 18 | 19 | COLORS = { 20 | 'DEBUG': Fore.BLUE, 21 | 'INFO': Fore.GREEN, 22 | 'WARNING': Fore.YELLOW, 23 | 'ERROR': Fore.RED, 24 | 'CRITICAL': Fore.MAGENTA + Style.BRIGHT, 25 | } 26 | 27 | def format(self, record): 28 | """Format log record with color coding.""" 29 | # Add color to level name 30 | level_color = self.COLORS.get(record.levelname, '') 31 | record.levelname = f"{level_color}{record.levelname}{Style.RESET_ALL}" 32 | 33 | # Add color to module name for better readability 34 | if hasattr(record, 'module'): 35 | record.module = f"{Fore.CYAN}{record.module}{Style.RESET_ALL}" 36 | 37 | return super().format(record) 38 | 39 | 40 | class WorkflowLogger: 41 | """ 42 | Unified logger for the workflow framework. 43 | Provides consistent logging across all modules with proper formatting and colors. 
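
    Example (illustrative):

        from BaseMachine.logger import get_logger, WorkflowLogger

        log = get_logger(__name__)
        log.info('pipeline started')
        WorkflowLogger.log_step_start('RunQLQuery', 'Execute the relaxed query')
        WorkflowLogger.log_step_complete('RunQLQuery', result='12 findings')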
44 | """ 45 | 46 | _instance = None 47 | _initialized = False 48 | 49 | def __new__(cls): 50 | """Singleton pattern to ensure only one logger configuration.""" 51 | if cls._instance is None: 52 | cls._instance = super().__new__(cls) 53 | return cls._instance 54 | 55 | def __init__(self): 56 | """Initialize the logger configuration.""" 57 | if self._initialized: 58 | return 59 | 60 | self._setup_logging() 61 | self._initialized = True 62 | 63 | def _setup_logging(self): 64 | """Setup the unified logging configuration.""" 65 | # Get root logger 66 | root_logger = logging.getLogger() 67 | 68 | # Clear any existing handlers to avoid duplicates 69 | for handler in root_logger.handlers[:]: 70 | root_logger.removeHandler(handler) 71 | 72 | # Set logging level 73 | root_logger.setLevel(logging.INFO) 74 | 75 | # Create console handler 76 | console_handler = logging.StreamHandler(sys.stdout) 77 | console_handler.setLevel(logging.INFO) 78 | 79 | # Create colored formatter for console 80 | console_format = ( 81 | f"{Fore.WHITE}[%(asctime)s]{Style.RESET_ALL} " 82 | f"{Fore.CYAN}[%(name)s]{Style.RESET_ALL} " 83 | f"[%(levelname)s] " 84 | f"%(message)s" 85 | ) 86 | console_formatter = ColoredFormatter( 87 | console_format, 88 | datefmt='%H:%M:%S' 89 | ) 90 | console_handler.setFormatter(console_formatter) 91 | 92 | # Add handler to root logger 93 | root_logger.addHandler(console_handler) 94 | 95 | # Optionally create file handler for persistent logging 96 | self._setup_file_logging() 97 | 98 | def _setup_file_logging(self): 99 | """Setup file logging for persistent logs.""" 100 | try: 101 | # Create logs directory if it doesn't exist 102 | log_dir = Path('logs') 103 | log_dir.mkdir(exist_ok=True) 104 | 105 | # Create file handler with timestamp 106 | timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') 107 | log_file = log_dir / f'workflow_{timestamp}.log' 108 | 109 | file_handler = logging.FileHandler(log_file, encoding='utf-8') 110 | file_handler.setLevel(logging.DEBUG) # File logs can be more verbose 111 | 112 | # Create plain formatter for file (no colors) 113 | file_format = ( 114 | '[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s' 115 | ) 116 | file_formatter = logging.Formatter( 117 | file_format, 118 | datefmt='%Y-%m-%d %H:%M:%S' 119 | ) 120 | file_handler.setFormatter(file_formatter) 121 | 122 | # Add file handler to root logger 123 | logging.getLogger().addHandler(file_handler) 124 | 125 | except Exception as e: 126 | # If file logging fails, continue with console logging only 127 | print(f"Warning: Could not setup file logging: {e}") 128 | 129 | @staticmethod 130 | def get_logger(name=None): 131 | """ 132 | Get a logger instance for a specific module. 
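    For example, a logger requested as 'BaseMachine.model_manager' is returned
    under the shortened name 'BM.model_manager' (see the prefix mapping below).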
133 | 134 | Args: 135 | name: Logger name (usually __name__ or module name) 136 | 137 | Returns: 138 | Configured logger instance 139 | """ 140 | # Ensure the unified logger is initialized 141 | WorkflowLogger() 142 | 143 | if name is None: 144 | name = 'workflow' 145 | 146 | # Clean up the name for better readability 147 | if name.startswith('WorkflowTemplate.'): 148 | name = name.replace('WorkflowTemplate.', 'WF.') 149 | elif name.startswith('BaseMachine.'): 150 | name = name.replace('BaseMachine.', 'BM.') 151 | 152 | return logging.getLogger(name) 153 | 154 | @staticmethod 155 | def log_step_start(step_name, description=""): 156 | """Log the start of a workflow step with consistent formatting.""" 157 | logger = WorkflowLogger.get_logger('workflow') 158 | separator = "=" * 60 159 | logger.info(f"\n{Fore.CYAN}{separator}") 160 | logger.info(f"{Fore.CYAN}🚀 Starting: {step_name}") 161 | if description: 162 | logger.info(f"{Fore.CYAN}📝 Description: {description}") 163 | logger.info(f"{Fore.CYAN}{separator}{Style.RESET_ALL}") 164 | 165 | @staticmethod 166 | def log_step_complete(step_name, result=None): 167 | """Log the completion of a workflow step.""" 168 | logger = WorkflowLogger.get_logger('workflow') 169 | logger.info(f"{Fore.GREEN}✅ Completed: {step_name}{Style.RESET_ALL}") 170 | if result: 171 | logger.info(f"{Fore.WHITE}📊 Result: {result}{Style.RESET_ALL}") 172 | 173 | @staticmethod 174 | def log_step_error(step_name, error): 175 | """Log a workflow step error with proper formatting.""" 176 | logger = WorkflowLogger.get_logger('workflow') 177 | logger.error(f"{Fore.RED}❌ Failed: {step_name}") 178 | logger.error(f"{Fore.RED}💥 Error: {error}{Style.RESET_ALL}") 179 | 180 | @staticmethod 181 | def log_workflow_summary(total_steps, completed_steps, errors=None): 182 | """Log a workflow execution summary.""" 183 | logger = WorkflowLogger.get_logger('workflow') 184 | separator = "=" * 60 185 | logger.info(f"\n{Fore.MAGENTA}{separator}") 186 | logger.info(f"{Fore.MAGENTA}📋 Workflow Summary") 187 | logger.info(f"{Fore.MAGENTA}{separator}") 188 | logger.info(f"{Fore.WHITE}📊 Total Steps: {total_steps}") 189 | logger.info(f"{Fore.GREEN}✅ Completed: {completed_steps}") 190 | if errors: 191 | logger.info(f"{Fore.RED}❌ Errors: {len(errors)}") 192 | for error in errors: 193 | logger.error(f"{Fore.RED} • {error}") 194 | success_rate = (completed_steps / total_steps * 100) if total_steps > 0 else 0 195 | logger.info(f"{Fore.CYAN}📈 Success Rate: {success_rate:.1f}%") 196 | logger.info(f"{Fore.MAGENTA}{separator}{Style.RESET_ALL}") 197 | 198 | 199 | # Convenience functions for easy access 200 | def get_logger(name=None): 201 | """Get a logger instance - convenience function.""" 202 | return WorkflowLogger.get_logger(name) 203 | 204 | 205 | def setup_logging(): 206 | """Initialize the unified logging system - convenience function.""" 207 | WorkflowLogger() 208 | 209 | 210 | # Initialize logging when module is imported 211 | setup_logging() -------------------------------------------------------------------------------- /BaseMachine/model_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import Dict, Any 4 | from openai import OpenAI, AzureOpenAI 5 | 6 | class ModelManager: 7 | def __init__(self, config_dir: str): 8 | self.config_dir = config_dir 9 | self.model_config = self._load_model_config() 10 | self.config = self._load_main_config() 11 | 12 | def _load_model_config(self) -> Dict[str, Any]: 13 | model_config_path 
= os.path.join(self.config_dir, 'model_config.json') 14 | with open(model_config_path, 'r') as f: 15 | return json.load(f) 16 | 17 | def _load_main_config(self) -> Dict[str, Any]: 18 | config_path = os.path.join(self.config_dir, 'config.json') 19 | with open(config_path, 'r') as f: 20 | return json.load(f) 21 | 22 | def get_available_models(self): 23 | """Return list of available models with their descriptions""" 24 | return { 25 | model_id: { 26 | "description": info["description"], 27 | "provider": info["provider"] 28 | } 29 | for model_id, info in self.model_config["models"].items() 30 | } 31 | 32 | def initialize_client(self): 33 | """Initialize and return a list of all model clients based on model_config.""" 34 | clients = [] 35 | for model_id, model_info in self.model_config['models'].items(): 36 | provider = model_info['provider'] 37 | provider_config = self.model_config['provider_configs'][provider] 38 | 39 | # Verify required credentials are present 40 | for required_key in provider_config.get('requires', []): 41 | if not self.config.get(required_key): 42 | raise ValueError(f"Missing required credential: {required_key}") 43 | 44 | # Add 'name' to model_info for easy access 45 | model_info['name'] = model_id 46 | 47 | # Initialize appropriate client 48 | if provider == 'azure': 49 | client = AzureOpenAI( 50 | api_key=self.config['azure_key'], 51 | api_version=model_info['api_version'], 52 | azure_endpoint=self.config['azure_endpoint'] 53 | ) 54 | elif provider == 'siliconflow': 55 | client = OpenAI( 56 | api_key=self.config['siliconflow_key'], 57 | base_url=provider_config['base_url'] 58 | ) 59 | elif provider == 'azure-deepseek': 60 | client = AzureOpenAI( 61 | api_key=self.config['ds_azure_key'], 62 | api_version=model_info['api_version'], 63 | azure_endpoint=self.config['ds_azure_endpoint'] 64 | ) 65 | elif provider == 'openrouter': 66 | # OpenRouter integration with API version 67 | client = OpenAI( 68 | api_key=self.config['openrouter_api_key'], 69 | base_url='https://openrouter.ai/api/v1' 70 | ) 71 | # Only add API key to model_info, without site information 72 | model_info['openrouter_api_key'] = self.config['openrouter_api_key'] 73 | 74 | else: # openai 75 | client = OpenAI( 76 | api_key=self.config['api_key'], 77 | base_url=provider_config.get('base_url') 78 | ) 79 | 80 | # Store completion kwargs in model_info 81 | model_info = self.get_completion_kwargs(model_info) 82 | 83 | clients.append((client, model_info)) 84 | 85 | return clients 86 | 87 | def get_completion_kwargs(self, model_info): 88 | """Get kwargs for completion API call based on selected model""" 89 | 90 | provider = model_info.get("provider") 91 | 92 | kwargs = { 93 | "model": model_info["model_name"], 94 | "max_tokens": model_info["max_tokens"], 95 | "temperature": self.config["temperature"], 96 | "top_p": self.config["top_p"], 97 | "stop": self.config["stop_sequences"] 98 | } 99 | 100 | # Handle provider-specific parameters 101 | if provider == "azure": 102 | if "deployment_name" in model_info: 103 | kwargs["deployment_id"] = model_info["deployment_name"] 104 | # For Azure, deployment_id is used instead of model 105 | del kwargs["model"] 106 | elif provider == "azure-deepseek": 107 | if "deployment_name" in model_info: 108 | kwargs["deployment_id"] = model_info["deployment_name"] 109 | del kwargs["model"] 110 | elif provider == "openrouter": 111 | # Add OpenRouter-specific parameters 112 | # Set up headers with minimum required information 113 | 114 | # Handle OpenRouter's API version if specified 115 | if 
"api_version" in model_info: 116 | kwargs["openrouter_version"] = model_info["api_version"] 117 | 118 | # Add OpenRouter provider routing parameters if specified 119 | if "openrouter_provider" in model_info: 120 | # Add provider preferences to request 121 | routing_params = {} 122 | provider_config = model_info["openrouter_provider"] 123 | 124 | # Map provider preferences to OpenRouter parameters 125 | if "order" in provider_config: 126 | routing_params["order"] = provider_config["order"] 127 | 128 | if "allow_fallbacks" in provider_config: 129 | routing_params["allow_fallbacks"] = provider_config["allow_fallbacks"] 130 | 131 | if "sort" in provider_config: 132 | routing_params["sort"] = provider_config["sort"] 133 | 134 | if "ignore" in provider_config: 135 | routing_params["skip"] = provider_config["ignore"] 136 | 137 | if "require_parameters" in provider_config: 138 | routing_params["filterParams"] = provider_config["require_parameters"] 139 | 140 | if "data_collection" in provider_config: 141 | routing_params["data_collection"] = provider_config["data_collection"] 142 | 143 | if "quantizations" in provider_config: 144 | routing_params["quantizations"] = provider_config["quantizations"] 145 | 146 | # Add provider parameters directly to the request body 147 | if routing_params: 148 | kwargs["provider"] = routing_params 149 | 150 | # Add OpenRouter reasoning parameter support 151 | if "reasoning" in model_info: 152 | # Store reasoning as a separate parameter, not nested in kwargs 153 | # This way it can be used directly in reliable_parse 154 | kwargs["reasoning"] = model_info["reasoning"] 155 | 156 | # Store the kwargs in model_info's additional_kwargs parameter instead of returning them 157 | model_info["additional_kwargs"] = kwargs 158 | return model_info -------------------------------------------------------------------------------- /BaseMachine/state_machine.py: -------------------------------------------------------------------------------- 1 | # statemachine.py 2 | 3 | import os 4 | import sys 5 | from typing import Any, Callable, Dict, Tuple 6 | import logging 7 | 8 | # Import configuration loading function 9 | from BaseMachine.config_loader import load_config 10 | from BaseMachine.model_manager import ModelManager 11 | 12 | from openai import OpenAI 13 | from openai import AzureOpenAI 14 | 15 | # Add utils directory to system path (as needed) 16 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))) 17 | 18 | class BaseState: 19 | def __init__( 20 | self, 21 | name: str, 22 | action: Callable[..., Any], 23 | next_state_func: Callable[[Any, 'StateMachine'], Tuple[str, Dict[str, Any]]] = None, 24 | ): 25 | self.name = name 26 | self.action = action 27 | self.next_state_func = next_state_func 28 | 29 | def process(self, machine: 'StateMachine', **kwargs): 30 | # Execute action, pass in machine and optional parameters, get result 31 | result = self.action(machine, **kwargs) 32 | # Call next_state_func, may return next state name or (state name, parameters) 33 | next_state_info = self.next_state_func(result, machine) 34 | return next_state_info, result # Return next state info and result 35 | 36 | class ExitState(BaseState): 37 | def __init__(self): 38 | super().__init__(name="Exit", action=lambda machine: None) 39 | 40 | def process(self, machine, **kwargs): 41 | pass # Exit state, no processing needed 42 | 43 | class StateMachine: 44 | def __init__(self, context, state_definitions: Dict[str, Dict], initial_state: str, config_path='', unified_config=None, 
mode='chat', require_models=True, cwd=None): 45 | self.state = None 46 | self.context = context 47 | self.unified_config = unified_config 48 | self.mode = mode # 'chat', 'agent', 'hybrid', or 'action' 49 | self.require_models = require_models 50 | self.cwd = cwd or os.getcwd() # Use provided cwd or current working directory 51 | 52 | # Change to the specified working directory if provided 53 | if cwd and os.path.isdir(cwd): 54 | self.original_cwd = os.getcwd() 55 | os.chdir(self.cwd) 56 | else: 57 | self.original_cwd = None 58 | 59 | # Initialize model manager only if needed 60 | if self.require_models and mode != 'action': 61 | config_dir = os.path.dirname(config_path) if config_path else os.path.join(os.path.dirname(__file__), '../.config') 62 | self.model_manager = ModelManager(config_dir) 63 | else: 64 | self.model_manager = None 65 | 66 | # Initialize agent-specific attributes 67 | self.agent_results = [] # Store agent action results 68 | self.hybrid_history = [] # Store hybrid mode history 69 | 70 | # # Test model 'create_success' functionality 71 | # self._test_model_create_success() 72 | 73 | # Load configuration 74 | self.config = self._load_config(config_path) 75 | 76 | self.analysis_result = [] 77 | self.messages = getattr(self.context, 'messages', []) 78 | self.total_input_tokens = 0 79 | self.total_output_tokens = 0 80 | 81 | # Initialize all model clients only if model manager exists 82 | if self.model_manager: 83 | self.clients = self.model_manager.initialize_client() 84 | else: 85 | self.clients = [] # Match the list type returned by initialize_client() 86 | 87 | # Note: Client selection is handled by create_chat_action, not here. 88 | 89 | # Create state instances 90 | self.states = self._create_states(state_definitions) 91 | self.state = self.states.get(initial_state, None) 92 | if self.state is None: 93 | raise ValueError(f"Initial state '{initial_state}' is not defined in state_definitions.") 94 | 95 | def _load_config(self, config_path): 96 | if not config_path: 97 | default_config_path = '../.config/config.json' 98 | config_path = os.path.join(os.path.dirname(__file__), default_config_path) 99 | config = load_config(config_path) 100 | config.config_path = config_path # Set config_path attribute 101 | return config 102 | 103 | def _create_states(self, state_definitions): 104 | states = {} 105 | for name, config in state_definitions.items(): 106 | if name == "Exit": 107 | states[name] = ExitState() 108 | else: 109 | states[name] = BaseState( 110 | name=name, 111 | action=config['action'], 112 | next_state_func=config.get("next_state_func", None), 113 | ) 114 | return states 115 | 116 | def process(self): 117 | previous_result = None # Save the result of the previous action 118 | extra_args = {} # Store parameters to pass to the next action 119 | while True: 120 | try: 121 | if isinstance(self.state, ExitState) or self.state is None: 122 | return previous_result # or self.analysis_result 123 | else: 124 | # Call action function, pass in machine and optional parameters 125 | action_func = self.state.action 126 | 127 | # Get the parameter list of action_func 128 | args_spec = action_func.__code__.co_varnames 129 | if len(args_spec) > 1: 130 | # There are other parameters besides 'machine' 131 | # Prepare parameters 132 | kwargs = extra_args if extra_args else {} 133 | result = action_func(self, **kwargs) 134 | extra_args = {} # Clear extra_args 135 | else: 136 | result = action_func(self) 137 | 138 | # Call next_state_func, may return next state name or (state name, parameter dict) 139 | next_state_info = self.state.next_state_func(result, self) 140 | if isinstance(next_state_info, tuple): 141 | next_state_name = next_state_info[0] 142 | extra_args = next_state_info[1] if len(next_state_info) > 1 else {} 143 | self.state = self.states.get(next_state_name, ExitState()) 144 | elif isinstance(next_state_info, str): 145 | next_state_name = next_state_info 146 | self.state = self.states.get(next_state_name, ExitState()) 147 | extra_args = {} 148 | else: 149 | raise ValueError("next_state_func must return a string or a tuple (state_name, args_dict)") 150 | previous_result = result # Update previous_result 151 | except RuntimeError as e: 152 | # Check for Claude AI usage limit 153 | if "Claude AI usage limit reached" in str(e): 154 | logging.error(f"\033[91mClaude AI usage limit reached in state '{self.state.name}'\033[0m") 155 | # Re-raise to let the caller handle it 156 | raise 157 | else: 158 | logging.error(f"\033[91mRuntime error in state '{self.state.name}': {e}\033[0m") 159 | import traceback 160 | tb_str = ''.join(traceback.format_exception(None, e, e.__traceback__)) 161 | logging.error(f"\033[90m{tb_str}\033[0m") 162 | break 163 | except Exception as e: 164 | logging.error(f"\033[91mError in state '{self.state.name}': {e}\033[0m") 165 | import traceback 166 | tb_str = ''.join(traceback.format_exception(None, e, e.__traceback__)) 167 | logging.error(f"\033[90m{tb_str}\033[0m") 168 | break 169 | 170 | def results(self): 171 | return self.analysis_result 172 | 173 | def get_completion_kwargs(self, model_info): 174 | """Get the kwargs for the completion API call for the given model""" 175 | return self.model_manager.get_completion_kwargs(model_info) 176 | 177 | --------------------------------------------------------------------------------
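For orientation, here is a minimal, self-contained sketch of the state_definitions contract that StateMachine expects. It runs without any model credentials (require_models=False, mode='action'), although _load_config still reads the repository's .config/config.json; the counter states are invented purely for illustration.

from BaseMachine.state_machine import StateMachine

class CounterContext:
    """Toy context; StateMachine only looks for an optional `messages` attribute."""
    def __init__(self):
        self.count = 0
        self.messages = []

def increment_action(machine):
    machine.context.count += 1
    return machine.context.count

state_definitions = {
    "Increment": {
        "action": increment_action,
        # next_state_func receives (result, machine) and returns either a state
        # name or a (state_name, kwargs_for_next_action) tuple.
        "next_state_func": lambda result, machine: "Increment" if result < 3 else "Exit",
    },
    "Exit": {},  # mapped to ExitState by _create_states
}

machine = StateMachine(CounterContext(), state_definitions, "Increment",
                       mode="action", require_models=False)
machine.process()
print(machine.context.count)  # -> 3

On the ModelManager side, the additional_kwargs prepared by get_completion_kwargs mix standard OpenAI parameters with OpenRouter-only fields such as provider and reasoning. The OpenAI SDK rejects unknown keyword arguments, so those fields have to travel in extra_body; note also that OpenRouter's documented provider-preference keys include ignore and require_parameters, so the skip/filterParams names produced above are worth checking against the current API reference. The sketch below only illustrates the split; the real call site lives in llm_helpers.py (not shown in this listing), and call_openrouter is a hypothetical helper name.

from openai import OpenAI

def call_openrouter(client: OpenAI, model_info: dict, messages: list):
    # Copy so the cached additional_kwargs stay untouched.
    kwargs = dict(model_info["additional_kwargs"])
    # OpenRouter-only request fields must go through extra_body; passing them
    # as top-level keyword arguments would raise a TypeError in the SDK.
    extra_body = {k: kwargs.pop(k) for k in ("provider", "reasoning") if k in kwargs}
    # Internal bookkeeping value, not an API parameter.
    kwargs.pop("openrouter_version", None)
    return client.chat.completions.create(messages=messages, extra_body=extra_body, **kwargs)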
/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | # Install system dependencies 4 | RUN apt-get update && apt-get -y upgrade && \ 5 | apt-get -y install git build-essential cmake wget curl netcat socat net-tools \ 6 | python3 python3-pip sudo xz-utils 7 | 8 | # Create non-root user 'user' 9 | RUN useradd -ms /bin/bash user && \ 10 | usermod -aG sudo user && \ 11 | echo 'user ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers 12 | 13 | # Install CodeQL 14 | RUN wget https://github.com/github/codeql-action/releases/download/codeql-bundle-v2.20.1/codeql-bundle-linux64.tar.gz && \ 15 | tar -xzf codeql-bundle-linux64.tar.gz && \ 16 | mv codeql /opt/codeql && \ 17 | rm codeql-bundle-linux64.tar.gz 18 | 19 | # Download and install Node.js 18 manually 20 | RUN curl -fsSL https://nodejs.org/dist/v18.20.8/node-v18.20.8-linux-x64.tar.xz -o node.tar.xz && \ 21 | mkdir -p /opt/node && \ 22 | tar -xf node.tar.xz -C /opt/node --strip-components=1 && \ 23 | rm node.tar.xz && \ 24 | ln -s /opt/node/bin/node /usr/local/bin/node && \ 25 | ln -s /opt/node/bin/npm /usr/local/bin/npm 26 | 27 | # Switch to non-root user 28 | USER user 29 | WORKDIR /home/user 30 | 31 | # Install Claude Code and SDK 32 | RUN npm install --prefix ~/.local @anthropic-ai/claude-code && \ 33 | python3 -m pip install --user claude-code-sdk 34 | 35 | # Set PATH for pip and npm user-level installs 36 | ENV PATH="/home/user/.local/bin:/home/user/.local/node_modules/.bin:$PATH" 37 | 38 | # Set working directory 39 | WORKDIR /workspace 40 | 41 | # Default command 42 | CMD ["bash"] -------------------------------------------------------------------------------- /QLWorkflow/_01_ql_query_modification/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/P1umer/QL-Relax/941ffa7fb0b0b196ba9e4e49e99f5b7fe6e47138/QLWorkflow/_01_ql_query_modification/__init__.py -------------------------------------------------------------------------------- /QLWorkflow/_01_ql_query_modification/modification_context.py: -------------------------------------------------------------------------------- 1 | """ 2 | Context for QL Query Modification 3 | """ 4 | 5 | import json 6 | import os 7 | from datetime import datetime 8 | 9 | # Get the directory of the script for relative paths 10 | SCRIPT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | 12 | 13 | class ModificationContext: 14 | """ 15 | Context for the QL query modification step. 16 | """ 17 | 18 | def __init__(self, cwe_number=None, ql_file_path=None, current_iteration=1, query_name=None, **kwargs): 19 | """ 20 | Initialize the modification context. 21 | 22 | Args: 23 | cwe_number: The CWE number being processed 24 | ql_file_path: Path to the original QL file 25 | current_iteration: Current iteration number 26 | **kwargs: Additional parameters 27 | """ 28 | # Core parameters 29 | self.cwe_number = cwe_number 30 | self.ql_file_path = ql_file_path 31 | self.current_iteration = current_iteration 32 | self.query_name = query_name or (os.path.splitext(os.path.basename(ql_file_path))[0] if ql_file_path else None) 33 | 34 | # Query content 35 | self.current_ql_content = None 36 | self.modified_ql_path = None 37 | 38 | # Results from previous iterations 39 | self.previous_results = kwargs.get('previous_results', {}) 40 | 41 | # Original QL path (for path resolution in modification_config.py) 42 | self.original_ql_path = kwargs.get('original_ql_path', ql_file_path) 43 | 44 | # Output directory 45 | default_output_dir = os.path.join(SCRIPT_DIR, 'qlworkspace', f'CWE-{cwe_number:03d}_{query_name}' if query_name else f'CWE-{cwe_number:03d}') 46 | self.output_dir = kwargs.get('output_dir', default_output_dir) 47 | 48 | # Working directory for agent 49 | self.working_directory = kwargs.get('working_directory', self.output_dir) 50 | 51 | # For LLM interactions 52 | self.messages = [] 53 | 54 | # Logging 55 | self.interactions_log = [] 56 | 57 | # Load the original QL content 58 | if ql_file_path and os.path.exists(ql_file_path): 59 | with open(ql_file_path, 'r') as f: 60 | self.current_ql_content = f.read() 61 | 62 | def log_interaction(self, action_type, request, response): 63 | """Log request and response for tracking.""" 64 | interaction = { 65 | 'timestamp': datetime.now().isoformat(), 66 | 'iteration': self.current_iteration, 67 | 'action': action_type, 68 | 'request': request, 69 | 'response': response 70 | } 71 | self.interactions_log.append(interaction) 72 | 73 | # Save to main interactions log file 74 | log_file = os.path.join(self.output_dir, 'interactions_log.json') 75 | os.makedirs(os.path.dirname(log_file), exist_ok=True) 76 | 77 | # Load existing log if exists 78 | existing_log = [] 79 | if os.path.exists(log_file): 80 | with open(log_file, 'r') as f: 81 | existing_log = json.load(f) 82 | 83 | # Append new interaction 84 | existing_log.append(interaction) 85 | 86 | # Save updated log 87 | with open(log_file, 'w') as f: 88 | json.dump(existing_log, f, indent=2) 89 | 90 | # Also save to iteration-specific directory 91 | iteration_dir = os.path.join(self.output_dir, f"iteration_{self.current_iteration}") 92 | os.makedirs(iteration_dir, exist_ok=True) 93 | 94 | # Save this specific interaction 95 | interaction_file = 
os.path.join(iteration_dir, f"{action_type}_interaction_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json") 96 | with open(interaction_file, 'w') as f: 97 | json.dump(interaction, f, indent=2) 98 | 99 | def __str__(self): 100 | return f"ModificationContext(cwe={self.cwe_number}, iteration={self.current_iteration})" 101 | 102 | def __repr__(self): 103 | return self.__str__() 104 | 105 | def get(self, key, default=None): 106 | """Get attribute value with default fallback.""" 107 | # Handle key mapping for compatibility 108 | if key == 'iteration': 109 | return self.current_iteration 110 | return getattr(self, key, default) -------------------------------------------------------------------------------- /QLWorkflow/_01_ql_query_modification/modification_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tools for QL Query Modification 3 | """ 4 | 5 | import os 6 | import re 7 | 8 | 9 | def extract_ql_metadata(ql_content): 10 | """Extract metadata from QL query content.""" 11 | metadata = { 12 | 'tags': [], 13 | 'kind': None, 14 | 'description': None, 15 | 'cwe_numbers': [] 16 | } 17 | 18 | # Extract @tags 19 | tag_matches = re.findall(r'\* @tags?\s+(.+)', ql_content) 20 | for match in tag_matches: 21 | tags = [tag.strip() for tag in match.split()] 22 | metadata['tags'].extend(tags) 23 | 24 | # Extract @kind 25 | kind_match = re.search(r'\* @kind\s+(.+)', ql_content) 26 | if kind_match: 27 | metadata['kind'] = kind_match.group(1).strip() 28 | 29 | # Extract @description 30 | desc_match = re.search(r'\* @description\s+(.+)', ql_content) 31 | if desc_match: 32 | metadata['description'] = desc_match.group(1).strip() 33 | 34 | # Extract CWE numbers from tags 35 | for tag in metadata['tags']: 36 | cwe_match = re.match(r'external/cwe/cwe-(\d+)', tag, re.IGNORECASE) 37 | if cwe_match: 38 | metadata['cwe_numbers'].append(int(cwe_match.group(1))) 39 | 40 | return metadata 41 | 42 | 43 | def validate_ql_syntax(ql_content): 44 | """Basic validation of QL syntax (placeholder for actual validation).""" 45 | # Check for basic QL structure 46 | required_patterns = [ 47 | r'import\s+\w+', # Import statements 48 | r'from\s+.+\s+where\s+.+\s+select', # Basic query structure 49 | ] 50 | 51 | for pattern in required_patterns: 52 | if not re.search(pattern, ql_content, re.IGNORECASE | re.DOTALL): 53 | return False, f"Missing required pattern: {pattern}" 54 | 55 | return True, "Basic syntax validation passed" 56 | 57 | 58 | def compare_query_versions(original_content, modified_content): 59 | """Compare original and modified queries to identify changes.""" 60 | changes = { 61 | 'lines_added': 0, 62 | 'lines_removed': 0, 63 | 'structural_changes': [] 64 | } 65 | 66 | original_lines = original_content.split('\n') 67 | modified_lines = modified_content.split('\n') 68 | 69 | # Simple line count comparison 70 | changes['lines_added'] = max(0, len(modified_lines) - len(original_lines)) 71 | changes['lines_removed'] = max(0, len(original_lines) - len(modified_lines)) 72 | 73 | # Check for structural changes (simplified) 74 | original_imports = len(re.findall(r'^import\s+', original_content, re.MULTILINE)) 75 | modified_imports = len(re.findall(r'^import\s+', modified_content, re.MULTILINE)) 76 | 77 | if modified_imports > original_imports: 78 | changes['structural_changes'].append('Added new imports') 79 | 80 | # Check for predicate additions 81 | original_predicates = len(re.findall(r'^predicate\s+\w+', original_content, re.MULTILINE)) 82 | modified_predicates = 
len(re.findall(r'^predicate\s+\w+', modified_content, re.MULTILINE)) 83 | 84 | if modified_predicates > original_predicates: 85 | changes['structural_changes'].append('Added new predicates') 86 | 87 | return changes -------------------------------------------------------------------------------- /QLWorkflow/_02_run_ql_query/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/P1umer/QL-Relax/941ffa7fb0b0b196ba9e4e49e99f5b7fe6e47138/QLWorkflow/_02_run_ql_query/__init__.py -------------------------------------------------------------------------------- /QLWorkflow/_02_run_ql_query/query_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | QL Query Execution Configuration 3 | Defines the state machine for running QL queries using CodeQL. 4 | """ 5 | 6 | import subprocess 7 | import os 8 | import csv 9 | import json 10 | 11 | # Get the directory of the script for relative paths 12 | SCRIPT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | 14 | 15 | def run_ql_query_action(machine): 16 | """ 17 | Action to execute the QL query using run_juliet.py. 18 | """ 19 | # Get the QL file path to run 20 | ql_path = machine.context.ql_file_path 21 | cwe_number = machine.context.cwe_number 22 | 23 | print(f"\n[Run QL Query] Executing query for CWE-{cwe_number} iteration {machine.context.current_iteration}") 24 | print(f"[Run QL Query] Input QL file: {ql_path}") 25 | 26 | # Create output directory for this iteration 27 | iteration_dir = os.path.join(machine.context.output_dir, f"iteration_{machine.context.current_iteration}") 28 | os.makedirs(iteration_dir, exist_ok=True) 29 | 30 | # Use query_results directory for both input and output 31 | query_output_dir = os.path.join(iteration_dir, 'query_results') 32 | os.makedirs(query_output_dir, exist_ok=True) 33 | 34 | # The modified query should already be in the passed ql_path 35 | print(f"[Run QL Query] Using query from: {ql_path}") 36 | 37 | # For proper module resolution, we need to run from the project's codeql directory 38 | # If the query is not already in the codeql directory, copy it there with a different name 39 | if hasattr(machine.context, 'original_ql_path') and not ql_path.startswith(os.path.dirname(machine.context.original_ql_path)): 40 | import shutil 41 | original_dir = os.path.dirname(machine.context.original_ql_path) 42 | original_name = os.path.basename(machine.context.original_ql_path) 43 | 44 | # Create a temporary file with a unique name to avoid conflicts 45 | temp_name = f"{os.path.splitext(original_name)[0]}_modified_{machine.context.current_iteration}.ql" 46 | temp_ql_path = os.path.join(original_dir, temp_name) 47 | 48 | # Copy the modified QL file to the codeql directory with temp name 49 | shutil.copy2(ql_path, temp_ql_path) 50 | print(f"[Run QL Query] Copied modified QL to codeql directory as: {temp_ql_path}") 51 | 52 | # Use the temp path in codeql directory for execution 53 | ql_path = temp_ql_path 54 | machine.context.temp_ql_path = temp_ql_path # Store for cleanup later 55 | else: 56 | print(f"[Run QL Query] Query already in codeql directory: {ql_path}") 57 | 58 | # Construct the command with custom output directory 59 | command = [ 60 | 'python3', 61 | os.path.join(SCRIPT_DIR, 'run_juliet.py'), 62 | '--run-queries', 63 | '--cwe', f'{cwe_number:03d}', 64 | '--ql', ql_path, 65 | '--output', query_output_dir 66 | ] 67 | 68 | # Run the command 69 | try: 70 | 
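# Illustration (hypothetical values): for CWE-190 on iteration 2, the command
# assembled above would expand to roughly
#
#   python3 run_juliet.py --run-queries --cwe 190 \
#       --ql <codeql-dir>/SomeQuery_modified_2.ql \
#       --output qlworkspace/CWE-190_SomeQuery/iteration_2/query_results
#
# The CWE number, query name, and paths are placeholders, not values from a real run.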
print(f"[Run QL Query] Running command: {' '.join(command)}") 71 | # record running time 72 | import time 73 | start_time = time.time() 74 | result = subprocess.run(command, capture_output=True, text=True) 75 | end_time = time.time() 76 | running_time = end_time - start_time 77 | print(f"[Run QL Query] Running time: {running_time:.2f} seconds") 78 | 79 | # Save command output 80 | output_log = { 81 | 'command': ' '.join(command), 82 | 'stdout': result.stdout, 83 | 'stderr': result.stderr, 84 | 'returncode': result.returncode, 85 | 'running_time': running_time 86 | } 87 | 88 | # Save log in the query output directory 89 | log_file = os.path.join(query_output_dir, 'query_execution_log.json') 90 | with open(log_file, 'w') as f: 91 | json.dump(output_log, f, indent=2) 92 | 93 | if result.returncode != 0: 94 | return f"Query execution failed: {result.stderr}" 95 | 96 | # Look for CSV file in the output directory 97 | # Both CSV and SARIF files are generated in the same directory 98 | csv_file = None 99 | 100 | # Find CSV file in the output directory 101 | for file in os.listdir(query_output_dir): 102 | if file.endswith('.csv'): 103 | csv_file = os.path.join(query_output_dir, file) 104 | print(f"[Run QL Query] Found CSV file: {csv_file}") 105 | break 106 | 107 | if not csv_file: 108 | print(f"[Run QL Query] No CSV file found in {query_output_dir}") 109 | 110 | machine.context.query_result_file = csv_file 111 | 112 | if not csv_file: 113 | print(f"[Run QL Query] WARNING: No CSV file found after query execution") 114 | 115 | # Clean up temporary file if created 116 | if hasattr(machine.context, 'temp_ql_path') and os.path.exists(machine.context.temp_ql_path): 117 | os.remove(machine.context.temp_ql_path) 118 | print(f"[Run QL Query] Cleaned up temporary QL file: {machine.context.temp_ql_path}") 119 | 120 | return "Query executed successfully" 121 | 122 | except Exception as e: 123 | # Clean up temporary file in case of error 124 | if hasattr(machine.context, 'temp_ql_path') and os.path.exists(machine.context.temp_ql_path): 125 | os.remove(machine.context.temp_ql_path) 126 | print(f"[Run QL Query] Cleaned up temporary QL file after error: {machine.context.temp_ql_path}") 127 | 128 | return f"Error executing query: {str(e)}" 129 | 130 | 131 | def parse_query_results_action(machine): 132 | """ 133 | Parse the SARIF results from the query execution and count threadFlows. 
134 | """ 135 | print(f"[Run QL Query] Parsing query results...") 136 | print(f"[Run QL Query] CSV file path: {machine.context.query_result_file}") 137 | 138 | # Try to find the corresponding SARIF file 139 | if machine.context.query_result_file: 140 | # SARIF file should be in the same directory as CSV file 141 | sarif_path = machine.context.query_result_file.replace('.csv', '.sarif') 142 | else: 143 | sarif_path = None 144 | 145 | # First try to parse SARIF for threadFlow count 146 | threadflow_count = 0 147 | if sarif_path and os.path.exists(sarif_path): 148 | try: 149 | with open(sarif_path, 'r', encoding='utf-8') as f: 150 | sarif_data = json.load(f) 151 | 152 | # Count all threadFlows 153 | for run in sarif_data.get('runs', []): 154 | for result in run.get('results', []): 155 | for code_flow in result.get('codeFlows', []): 156 | threadflow_count += len(code_flow.get('threadFlows', [])) 157 | 158 | print(f"[Run QL Query] Found SARIF file with {threadflow_count} threadFlows") 159 | except Exception as e: 160 | print(f"[Run QL Query] Error parsing SARIF: {str(e)}") 161 | 162 | # Fall back to CSV parsing if needed 163 | if not machine.context.query_result_file or not os.path.exists(machine.context.query_result_file): 164 | machine.context.query_results = [] 165 | machine.context.result_count = threadflow_count if threadflow_count > 0 else 0 166 | print(f"[Run QL Query] No results file found at: {machine.context.query_result_file}") 167 | return "No results file found" 168 | 169 | try: 170 | results = [] 171 | with open(machine.context.query_result_file, 'r') as f: 172 | csv_reader = csv.DictReader(f) 173 | for row in csv_reader: 174 | results.append(row) 175 | 176 | machine.context.query_results = results 177 | # Use threadFlow count if available, otherwise use CSV row count 178 | machine.context.result_count = threadflow_count if threadflow_count > 0 else len(results) 179 | 180 | # Calculate result distribution 181 | from QLWorkflow._02_run_ql_query.query_tools import analyze_result_distribution 182 | machine.context.result_distribution = analyze_result_distribution(results) 183 | 184 | # Determine output directory based on context 185 | if machine.context.current_iteration == 1 and hasattr(machine.context, 'is_origin_run') and machine.context.is_origin_run: 186 | # For origin run in first iteration, save to initial/query_results/ 187 | output_dir = os.path.join(machine.context.output_dir, 'initial', 'query_results') 188 | else: 189 | # For all modified queries, save to iteration_X/query_results/ 190 | iteration_dir = os.path.join(machine.context.output_dir, f"iteration_{machine.context.current_iteration}") 191 | output_dir = os.path.join(iteration_dir, 'query_results') 192 | 193 | # Perform evaluation if SARIF exists 194 | evaluation_metrics = {} 195 | if sarif_path and os.path.exists(sarif_path): 196 | from QLWorkflow.util.evaluation_utils import evaluate_sarif_results 197 | # Pass output_dir to save good/bad results 198 | # Find the actual CWE directory in Juliet test suite 199 | testcases_base = os.path.join(SCRIPT_DIR, 'juliet-test-suite-c', 'testcases') 200 | source_base_dir = None 201 | if os.path.exists(testcases_base): 202 | for dirname in os.listdir(testcases_base): 203 | if dirname.startswith(f'CWE{machine.context.cwe_number}_'): 204 | source_base_dir = os.path.join(testcases_base, dirname) 205 | break 206 | 207 | evaluation_metrics = evaluate_sarif_results(sarif_path, output_dir, source_base_dir) 208 | print(f"[Run QL Query] Evaluation: 
TP={evaluation_metrics['true_positive_count']}, FP={evaluation_metrics['false_positive_count']}") 209 | print(f"[Run QL Query] Saved good_results.json and bad_results.json to {output_dir}") 210 | 211 | # Save complete results with evaluation metrics 212 | complete_results = { 213 | 'ql_file': machine.context.ql_file_path, 214 | 'result_count': machine.context.result_count, 215 | 'csv_file': machine.context.query_result_file, 216 | 'sarif_file': sarif_path if sarif_path and os.path.exists(sarif_path) else None 217 | } 218 | 219 | # Add evaluation metrics if available 220 | if evaluation_metrics: 221 | complete_results.update(evaluation_metrics) 222 | # Store in context for later use 223 | machine.context.evaluation_metrics = evaluation_metrics 224 | 225 | complete_results_file = os.path.join(output_dir, 'results_log.json') 226 | with open(complete_results_file, 'w') as f: 227 | json.dump(complete_results, f, indent=2) 228 | 229 | print(f"[Run QL Query] Parsed {machine.context.result_count} results") 230 | return f"Parsed {machine.context.result_count} results" 231 | 232 | except Exception as e: 233 | machine.context.query_results = [] 234 | machine.context.result_count = 0 235 | return f"Error parsing results: {str(e)}" 236 | 237 | 238 | def exit_action(machine): 239 | """Exit action - cleanup temp files and return the result count.""" 240 | # Clean up temporary QL file if it was created 241 | if hasattr(machine.context, 'temp_ql_path') and machine.context.temp_ql_path: 242 | if os.path.exists(machine.context.temp_ql_path): 243 | try: 244 | os.remove(machine.context.temp_ql_path) 245 | print(f"[Run QL Query] Cleaned up temporary file: {machine.context.temp_ql_path}") 246 | except Exception as e: 247 | print(f"[Run QL Query] Warning: Failed to clean up temp file: {e}") 248 | 249 | return machine.context.result_count 250 | 251 | 252 | # State machine configuration for query execution 253 | state_definitions = { 254 | 'RunQLQuery': { 255 | 'action': run_ql_query_action, 256 | 'next_state_func': lambda result, machine: 'ParseResults' if 'successfully' in result.lower() else 'Exit', 257 | }, 258 | 'ParseResults': { 259 | 'action': parse_query_results_action, 260 | 'next_state_func': lambda result, machine: 'Exit', 261 | }, 262 | 'Exit': { 263 | 'action': exit_action, 264 | 'next_state_func': None, 265 | }, 266 | } -------------------------------------------------------------------------------- /QLWorkflow/_02_run_ql_query/query_context.py: -------------------------------------------------------------------------------- 1 | """ 2 | Context for QL Query Execution 3 | """ 4 | 5 | import os 6 | 7 | # Get the directory of the script for relative paths 8 | SCRIPT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 9 | 10 | 11 | class QueryContext: 12 | """ 13 | Context for the QL query execution step. 14 | """ 15 | 16 | def __init__(self, cwe_number=None, ql_file_path=None, current_iteration=1, query_name=None, **kwargs): 17 | """ 18 | Initialize the query execution context. 
19 | 20 | Args: 21 | cwe_number: The CWE number being processed 22 | ql_file_path: Path to the QL file to execute 23 | current_iteration: Current iteration number 24 | query_name: Name of the query 25 | **kwargs: Additional parameters 26 | """ 27 | # Core parameters 28 | self.cwe_number = cwe_number 29 | self.ql_file_path = ql_file_path 30 | self.current_iteration = current_iteration 31 | self.query_name = query_name or (os.path.basename(ql_file_path).split('.')[0] if ql_file_path else None) 32 | 33 | # Query execution results 34 | self.query_result_file = None 35 | self.query_results = [] 36 | self.result_count = 0 37 | self.result_distribution = {} 38 | 39 | # Output directory 40 | default_output_dir = os.path.join(SCRIPT_DIR, 'qlworkspace', f'CWE-{cwe_number:03d}_{self.query_name}' if self.query_name else f'CWE-{cwe_number:03d}') 41 | self.output_dir = kwargs.get('output_dir', default_output_dir) 42 | 43 | # For LLM interactions (if needed) 44 | self.messages = [] 45 | 46 | # Previous iteration data 47 | self.previous_result_count = kwargs.get('previous_result_count', 0) 48 | 49 | # Store original QL path for module resolution 50 | self.original_ql_path = kwargs.get('original_ql_path', None) 51 | 52 | # Working directory for agent compatibility 53 | self.working_directory = kwargs.get('working_directory', self.output_dir) 54 | 55 | def get(self, key, default=None): 56 | """ 57 | Get attribute value with dictionary-style access for compatibility with agent action utils. 58 | """ 59 | # Handle key mapping for compatibility 60 | if key == 'iteration': 61 | return self.current_iteration 62 | return getattr(self, key, default) 63 | 64 | def __str__(self): 65 | return f"QueryContext(cwe={self.cwe_number}, iteration={self.current_iteration})" 66 | 67 | def __repr__(self): 68 | return self.__str__() -------------------------------------------------------------------------------- /QLWorkflow/_02_run_ql_query/query_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tools for QL Query Execution 3 | """ 4 | 5 | import os 6 | import subprocess 7 | import json 8 | 9 | # Get the directory of the script for relative paths 10 | SCRIPT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | 12 | 13 | def check_database_exists(cwe_number): 14 | """Check if the CodeQL database exists for the given CWE.""" 15 | db_path = os.path.join(SCRIPT_DIR, 'juliet-test-suite-c', 'testcasesdb', f'CWE{cwe_number}_cpp-db') 16 | return os.path.exists(db_path) 17 | 18 | 19 | def create_database_if_needed(cwe_number): 20 | """Create the CodeQL database if it doesn't exist.""" 21 | if not check_database_exists(cwe_number): 22 | command = [ 23 | 'python3', 24 | os.path.join(SCRIPT_DIR, 'run_juliet.py'), 25 | '--create-db', 26 | '--cwe', f'{cwe_number:03d}' 27 | ] 28 | 29 | result = subprocess.run(command, capture_output=True, text=True) 30 | if result.returncode != 0: 31 | return False, f"Failed to create database: {result.stderr}" 32 | 33 | return True, "Database created successfully" 34 | 35 | return True, "Database already exists" 36 | 37 | 38 | def get_query_metadata(ql_file_path): 39 | """Extract metadata from the QL file.""" 40 | metadata = { 41 | 'filename': os.path.basename(ql_file_path), 42 | 'directory': os.path.dirname(ql_file_path), 43 | 'size': os.path.getsize(ql_file_path) if os.path.exists(ql_file_path) else 0 44 | } 45 | 46 | if os.path.exists(ql_file_path): 47 | with open(ql_file_path, 'r') as f: 48 | content = f.read() 49 | # 
Count lines 50 | metadata['lines'] = len(content.split('\n')) 51 | # Check for specific patterns 52 | metadata['has_dataflow'] = 'DataFlow' in content or 'TaintTracking' in content 53 | metadata['has_guards'] = 'isBarrier' in content or 'isSanitizer' in content 54 | 55 | return metadata 56 | 57 | 58 | def analyze_result_distribution(results): 59 | """Analyze the distribution of query results.""" 60 | distribution = {} 61 | 62 | for result in results: 63 | # Group by file or location 64 | if 'File' in result: 65 | file_path = result['File'] 66 | file_name = os.path.basename(file_path) 67 | distribution[file_name] = distribution.get(file_name, 0) + 1 68 | 69 | return distribution -------------------------------------------------------------------------------- /QLWorkflow/_03_output_validation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/P1umer/QL-Relax/941ffa7fb0b0b196ba9e4e49e99f5b7fe6e47138/QLWorkflow/_03_output_validation/__init__.py -------------------------------------------------------------------------------- /QLWorkflow/_03_output_validation/validation_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Output Validation Configuration 3 | Defines the state machine for validating query results and determining if broadening was successful. 4 | """ 5 | 6 | from BaseMachine.agent_action_utils import create_agent_action 7 | import json 8 | import os 9 | import subprocess 10 | from QLWorkflow.util.logging_utils import get_ql_workflow_log_path, get_action_type_from_prompt 11 | 12 | 13 | def check_query_results_action(machine): 14 | """ 15 | Check the query execution results and gather information. 16 | """ 17 | print(f"\n[Output Validation] Checking query execution results for CWE-{machine.context.cwe_number} iteration {machine.context.current_iteration}") 18 | 19 | iteration_dir = os.path.join(machine.context.output_dir, f"iteration_{machine.context.current_iteration}") 20 | # Query execution log is now in query_results directory 21 | execution_log_file = os.path.join(iteration_dir, 'query_results', 'query_execution_log.json') 22 | 23 | if not os.path.exists(execution_log_file): 24 | machine.context.has_compilation_errors = False 25 | machine.context.compilation_errors = [] 26 | return "no_execution_log" 27 | 28 | with open(execution_log_file, 'r') as f: 29 | execution_log = json.load(f) 30 | 31 | # Check for compilation errors in stderr 32 | stderr = execution_log.get('stderr', '') 33 | if 'ERROR:' in stderr: 34 | # Extract error messages 35 | error_lines = [line for line in stderr.split('\n') if 'ERROR:' in line] 36 | machine.context.compilation_errors = error_lines 37 | machine.context.has_compilation_errors = True 38 | print(f"[Output Validation] Found {len(error_lines)} compilation errors") 39 | else: 40 | machine.context.has_compilation_errors = False 41 | machine.context.compilation_errors = [] 42 | print(f"[Output Validation] No compilation errors found") 43 | 44 | return "continue" 45 | 46 | 47 | def analyze_results_action(machine): 48 | """ 49 | Analyze the query results to determine if broadening was successful. 
50 | """ 51 | current_count = machine.context.current_result_count 52 | previous_count = machine.context.previous_result_count 53 | iteration = machine.context.current_iteration 54 | 55 | # Calculate improvement 56 | if previous_count > 0: 57 | improvement_percentage = ((current_count - previous_count) / previous_count) * 100 58 | else: 59 | improvement_percentage = 100 if current_count > 0 else 0 60 | 61 | # Check if query had compilation errors 62 | had_compilation_errors = getattr(machine.context, 'has_compilation_errors', False) 63 | 64 | # Prepare analysis data 65 | analysis = { 66 | 'iteration': iteration, 67 | 'current_result_count': current_count, 68 | 'previous_result_count': previous_count, 69 | 'improvement_percentage': improvement_percentage, 70 | 'improved': current_count > previous_count, 71 | 'result_distribution': machine.context.result_distribution, 72 | 'had_compilation_errors': had_compilation_errors 73 | } 74 | 75 | machine.context.analysis_result = analysis 76 | 77 | # Save analysis 78 | iteration_dir = os.path.join(machine.context.output_dir, f"iteration_{iteration}") 79 | reports_dir = os.path.join(iteration_dir, 'reports') 80 | os.makedirs(reports_dir, exist_ok=True) 81 | analysis_file = os.path.join(reports_dir, 'validation_analysis.json') 82 | with open(analysis_file, 'w') as f: 83 | json.dump(analysis, f, indent=2) 84 | 85 | return "Analysis completed" 86 | 87 | 88 | def generate_validation_report_action(machine): 89 | """ 90 | Generate a validation report using chat mode. 91 | """ 92 | analysis = machine.context.analysis_result 93 | 94 | # Use appropriate model based on mode 95 | model_name = 'azure-gpt4o' 96 | if hasattr(machine, 'mode') and machine.mode == 'agent': 97 | # Agent mode - potentially use different settings 98 | model_name = 'azure-gpt4o' 99 | 100 | prompt_template = """Analyze the CodeQL query results for CWE-{cwe_number} iteration {current_iteration}. 101 | 102 | File Paths: 103 | - Original QL File: {original_ql_file} 104 | - Modified QL File: {modified_ql_file} 105 | - Current Results CSV(Results of modified ql): {current_csv_file} 106 | - Previous Results CSV(Results of original ql): {previous_csv_file} 107 | 108 | Summary Data: 109 | - Current Result Count: {current_result_count} 110 | - Previous Result Count: {previous_result_count} 111 | - Improvement: {improvement_percentage:.1f}% 112 | 113 | Please provide: 114 | 1. A summary of whether the query broadening was successful 115 | 2. Analysis of the result distribution (are we getting meaningful results or just noise?) 116 | 3. Recommendations for the next iteration: 117 | - Should we continue broadening? 118 | - What specific aspects should be modified? 119 | - Are there any concerning patterns in the results? 120 | 4. Risk assessment: Are we maintaining the integrity of the security check while broadening? 
121 | 122 | Keep your analysis concise and actionable.""" 123 | 124 | # Set up logging context for QLWorkflow 125 | log_context = { 126 | 'cwe_number': machine.context.cwe_number, 127 | 'query_name': machine.context.query_name if hasattr(machine.context, 'query_name') else f"CWE-{machine.context.cwe_number:03d}", 128 | 'iteration': machine.context.current_iteration, 129 | 'output_dir': machine.context.output_dir 130 | } 131 | 132 | # Get the log path and set action type 133 | log_path = get_ql_workflow_log_path(log_context) 134 | if log_path: 135 | machine.context.session_log_path = str(log_path) # Convert Path to string 136 | machine.context.action_type = 'validation' 137 | 138 | # Use agent action for agent mode with streaming JSON logging enabled 139 | action = create_agent_action( 140 | prompt_template=prompt_template, 141 | save_option='both', 142 | system_prompt="You are a CodeQL validation expert. Analyze query results and provide recommendations. You have access to Read tool to examine the CSV files for detailed analysis.", 143 | allowed_tools=["Read", "Grep"], 144 | enable_stream_logging=True 145 | ) 146 | 147 | # Find CSV file paths 148 | iteration_dir = os.path.join(machine.context.output_dir, f"iteration_{machine.context.current_iteration}") 149 | current_csv_file = None 150 | previous_csv_file = None 151 | 152 | # Find current CSV file in query_results 153 | current_results_dir = os.path.join(iteration_dir, "query_results") 154 | if os.path.exists(current_results_dir): 155 | for file in os.listdir(current_results_dir): 156 | if file.endswith('.csv'): 157 | current_csv_file = os.path.join(current_results_dir, file) 158 | break 159 | 160 | # Find previous CSV file - from initial/ for iteration 1, from previous iteration for others 161 | if machine.context.current_iteration == 1: 162 | previous_results_dir = os.path.join(machine.context.output_dir, "initial") 163 | else: 164 | prev_iteration = machine.context.current_iteration - 1 165 | previous_results_dir = os.path.join(machine.context.output_dir, f"iteration_{prev_iteration}", "query_results") 166 | 167 | if os.path.exists(previous_results_dir): 168 | for file in os.listdir(previous_results_dir): 169 | if file.endswith('.csv'): 170 | previous_csv_file = os.path.join(previous_results_dir, file) 171 | break 172 | 173 | # Format the prompt for saving 174 | formatted_prompt = prompt_template.format( 175 | cwe_number=machine.context.cwe_number, 176 | current_iteration=machine.context.current_iteration, 177 | current_result_count=analysis['current_result_count'], 178 | previous_result_count=analysis['previous_result_count'], 179 | improvement_percentage=analysis['improvement_percentage'], 180 | original_ql_file=machine.context.original_ql_file, 181 | modified_ql_file=machine.context.modified_ql_file, 182 | current_csv_file=current_csv_file or "No CSV file found", 183 | previous_csv_file=previous_csv_file or "No CSV file found" 184 | ) 185 | 186 | # Save the prompt to iteration/reports directory 187 | iteration_dir = os.path.join(machine.context.output_dir, f"iteration_{machine.context.current_iteration}") 188 | reports_dir = os.path.join(iteration_dir, "reports") 189 | os.makedirs(reports_dir, exist_ok=True) 190 | prompt_file = os.path.join(reports_dir, "03_validation_prompt.txt") 191 | with open(prompt_file, 'w') as f: 192 | f.write(formatted_prompt) 193 | 194 | result = action(machine, 195 | cwe_number=machine.context.cwe_number, 196 | current_iteration=machine.context.current_iteration, 197 | 
current_result_count=analysis['current_result_count'], 198 | previous_result_count=analysis['previous_result_count'], 199 | improvement_percentage=analysis['improvement_percentage'], 200 | original_ql_file=machine.context.original_ql_file, 201 | modified_ql_file=machine.context.modified_ql_file, 202 | current_csv_file=current_csv_file or "No CSV file found", 203 | previous_csv_file=previous_csv_file or "No CSV file found") 204 | 205 | # Save the response - agent mode returns a dict with 'response' key 206 | response_file = os.path.join(reports_dir, "03_validation_response.txt") 207 | if isinstance(result, dict) and 'response' in result: 208 | with open(response_file, 'w') as f: 209 | f.write(result['response']) 210 | # Store response for later use 211 | machine.context.validation_response = result['response'] 212 | elif isinstance(result, str): 213 | with open(response_file, 'w') as f: 214 | f.write(result) 215 | machine.context.validation_response = result 216 | 217 | return result 218 | 219 | 220 | def save_validation_conclusion_action(machine): 221 | """ 222 | Save the validation conclusion and recommendations. 223 | """ 224 | # Extract conclusion from agent response 225 | response = getattr(machine.context, 'validation_response', '') 226 | 227 | # Categorize the result based on requirements 228 | current_count = machine.context.current_result_count 229 | previous_count = machine.context.previous_result_count 230 | has_compile_error = machine.context.analysis_result.get('had_compilation_errors', False) 231 | query_failed = machine.context.analysis_result.get('query_failed', False) 232 | 233 | # Determine result category (a compile or execution error takes precedence) 234 | if has_compile_error or query_failed: 235 | result_category = "compile_error" 236 | continue_iteration = True # Continue to fix errors 237 | stop_reason = "Compilation errors need to be fixed" 238 | elif current_count > machine.context.initial_result_count: 239 | result_category = "success_increase" 240 | continue_iteration = False # Success, stop iteration 241 | stop_reason = "Successfully increased result count" 242 | elif current_count < machine.context.initial_result_count: 243 | result_category = "result_decrease" 244 | continue_iteration = True # Continue with warning about decrease 245 | stop_reason = "Result count decreased, need different approach" 246 | else: 247 | result_category = "no_change" 248 | continue_iteration = True # Continue trying 249 | stop_reason = "No change in results, continue iteration" 250 | 251 | conclusion = { 252 | 'iteration': machine.context.current_iteration, 253 | 'result_category': result_category, 254 | 'success': result_category == "success_increase", 255 | 'current_count': current_count, 256 | 'previous_count': previous_count, 257 | 'has_compile_error': has_compile_error, 258 | 'error_message': '\n'.join(machine.context.compilation_errors) if has_compile_error else '', 259 | 'continue_iteration': continue_iteration, 260 | 'stop_reason': stop_reason 261 | } 262 | 263 | machine.context.validation_conclusion = conclusion 264 | 265 | # Save conclusion 266 | iteration_dir = os.path.join(machine.context.output_dir, f"iteration_{machine.context.current_iteration}") 267 | conclusion_file = os.path.join(iteration_dir, 'validation_conclusion.json') 268 | with open(conclusion_file, 'w') as f: 269 | json.dump(conclusion, f, indent=2) 270 | 271 | # Log the interaction 272 | log_prompt = f"Validation report for CWE-{machine.context.cwe_number} iteration {machine.context.current_iteration}" 273 | if machine.context.messages: 274 | response = 
response = machine.context.messages[-1]['content'] 275 | machine.context.log_interaction('validation_report', log_prompt, response) 276 | 277 | return "Validation completed" 278 | 279 | 280 | def exit_action(machine): 281 | """Exit action - returns the validation conclusion.""" 282 | return machine.context.validation_conclusion 283 | 284 | 285 | # State machine configuration for output validation 286 | state_definitions = { 287 | 'CheckQueryResults': { 288 | 'action': check_query_results_action, 289 | 'next_state_func': lambda result, machine: 'AnalyzeResults', 290 | }, 291 | 'AnalyzeResults': { 292 | 'action': analyze_results_action, 293 | 'next_state_func': lambda result, machine: 'GenerateValidationReport', 294 | }, 295 | 'GenerateValidationReport': { 296 | 'action': generate_validation_report_action, 297 | 'next_state_func': lambda result, machine: 'SaveConclusion', 298 | }, 299 | 'SaveConclusion': { 300 | 'action': save_validation_conclusion_action, 301 | 'next_state_func': lambda result, machine: 'Exit', 302 | }, 303 | 'Exit': { 304 | 'action': exit_action, 305 | 'next_state_func': None, 306 | }, 307 | } -------------------------------------------------------------------------------- /QLWorkflow/_03_output_validation/validation_context.py: -------------------------------------------------------------------------------- 1 | """ 2 | Context for Output Validation 3 | """ 4 | 5 | import json 6 | import os 7 | from datetime import datetime 8 | 9 | # Get the directory of the script for relative paths 10 | SCRIPT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | 12 | 13 | class ValidationContext: 14 | """ 15 | Context for the output validation step. 16 | """ 17 | 18 | def __init__(self, cwe_number=None, current_iteration=1, query_name=None, **kwargs): 19 | """ 20 | Initialize the validation context. 
21 | 22 | Args: 23 | cwe_number: The CWE number being processed 24 | current_iteration: Current iteration number 25 | **kwargs: Additional parameters including result counts and files 26 | """ 27 | # Core parameters 28 | self.cwe_number = cwe_number 29 | self.current_iteration = current_iteration 30 | self.query_name = query_name 31 | 32 | # Result counts 33 | self.current_result_count = kwargs.get('current_result_count', 0) 34 | self.previous_result_count = kwargs.get('previous_result_count', 0) 35 | self.initial_result_count = kwargs.get('initial_result_count', 0) 36 | 37 | # Result analysis 38 | self.result_distribution = kwargs.get('result_distribution', {}) 39 | self.analysis_result = {} 40 | self.validation_conclusion = {} 41 | 42 | # File paths 43 | self.original_ql_file = kwargs.get('original_ql_file', '') 44 | self.modified_ql_file = kwargs.get('modified_ql_file', '') 45 | 46 | # Output directory 47 | default_output_dir = os.path.join(SCRIPT_DIR, 'qlworkspace', f'CWE-{cwe_number:03d}_{query_name}' if query_name else f'CWE-{cwe_number:03d}') 48 | self.output_dir = kwargs.get('output_dir', default_output_dir) 49 | 50 | # Working directory for agent 51 | self.working_directory = kwargs.get('working_directory', self.output_dir) 52 | 53 | # For LLM interactions 54 | self.messages = [] 55 | 56 | # Logging 57 | self.interactions_log = [] 58 | 59 | def log_interaction(self, action_type, request, response): 60 | """Log request and response for tracking.""" 61 | interaction = { 62 | 'timestamp': datetime.now().isoformat(), 63 | 'iteration': self.current_iteration, 64 | 'action': action_type, 65 | 'request': request, 66 | 'response': response 67 | } 68 | self.interactions_log.append(interaction) 69 | 70 | # Save to main interactions log file 71 | log_file = os.path.join(self.output_dir, 'interactions_log.json') 72 | os.makedirs(os.path.dirname(log_file), exist_ok=True) 73 | 74 | # Load existing log if exists 75 | existing_log = [] 76 | if os.path.exists(log_file): 77 | with open(log_file, 'r') as f: 78 | existing_log = json.load(f) 79 | 80 | # Append new interaction 81 | existing_log.append(interaction) 82 | 83 | # Save updated log 84 | with open(log_file, 'w') as f: 85 | json.dump(existing_log, f, indent=2) 86 | 87 | # Also save to iteration-specific directory 88 | iteration_dir = os.path.join(self.output_dir, f"iteration_{self.current_iteration}") 89 | os.makedirs(iteration_dir, exist_ok=True) 90 | 91 | # Save this specific interaction 92 | interaction_file = os.path.join(iteration_dir, f"{action_type}_interaction_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json") 93 | with open(interaction_file, 'w') as f: 94 | json.dump(interaction, f, indent=2) 95 | 96 | def __str__(self): 97 | return f"ValidationContext(cwe={self.cwe_number}, iteration={self.current_iteration}, current={self.current_result_count}, previous={self.previous_result_count})" 98 | 99 | def __repr__(self): 100 | return self.__str__() 101 | 102 | def get(self, key, default=None): 103 | """Get attribute value with default fallback.""" 104 | # Handle key mapping for compatibility 105 | if key == 'iteration': 106 | return self.current_iteration 107 | return getattr(self, key, default) -------------------------------------------------------------------------------- /QLWorkflow/_03_output_validation/validation_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tools for Output Validation 3 | """ 4 | 5 | import json 6 | import os 7 | import csv 8 | 9 | 10 | def 
calculate_metrics(current_count, previous_count): 11 | """Calculate improvement metrics.""" 12 | metrics = { 13 | 'absolute_change': current_count - previous_count, 14 | 'percentage_change': 0, 15 | 'multiplier': 1 16 | } 17 | 18 | if previous_count > 0: 19 | metrics['percentage_change'] = ((current_count - previous_count) / previous_count) * 100 20 | metrics['multiplier'] = current_count / previous_count 21 | elif current_count > 0: 22 | metrics['percentage_change'] = 100 23 | metrics['multiplier'] = float('inf') 24 | 25 | return metrics 26 | 27 | 28 | def assess_result_quality(result_distribution, total_count): 29 | """Assess the quality of results based on distribution.""" 30 | quality_indicators = { 31 | 'concentration_score': 0, # How concentrated are the results 32 | 'diversity_score': 0, # How diverse are the results 33 | 'likely_noise': False 34 | } 35 | 36 | if not result_distribution or total_count == 0: 37 | return quality_indicators 38 | 39 | # Calculate concentration (Gini coefficient approximation) 40 | sorted_counts = sorted(result_distribution.values(), reverse=True) 41 | cumulative_sum = 0 42 | for i, count in enumerate(sorted_counts): 43 | cumulative_sum += count * (i + 1) 44 | 45 | if sum(sorted_counts) > 0: 46 | quality_indicators['concentration_score'] = (2 * cumulative_sum) / (len(sorted_counts) * sum(sorted_counts)) - 1 47 | 48 | # Calculate diversity 49 | quality_indicators['diversity_score'] = len(result_distribution) / total_count 50 | 51 | # Check for likely noise (too many unique results) 52 | if quality_indicators['diversity_score'] > 0.8: 53 | quality_indicators['likely_noise'] = True 54 | 55 | return quality_indicators 56 | 57 | 58 | def generate_iteration_summary(iteration_data): 59 | """Generate a summary for the iteration.""" 60 | summary = { 61 | 'iteration': iteration_data.get('iteration', 0), 62 | 'timestamp': iteration_data.get('timestamp', ''), 63 | 'results': { 64 | 'count': iteration_data.get('current_count', 0), 65 | 'improvement': iteration_data.get('improvement_percentage', 0), 66 | 'quality': iteration_data.get('quality_assessment', {}) 67 | }, 68 | 'recommendation': iteration_data.get('recommendation', 'Continue iteration'), 69 | 'key_findings': [] 70 | } 71 | 72 | # Add key findings based on data 73 | if summary['results']['improvement'] > 50: 74 | summary['key_findings'].append('Significant improvement in result count') 75 | 76 | if iteration_data.get('quality_assessment', {}).get('likely_noise', False): 77 | summary['key_findings'].append('Results may contain noise - consider refining constraints') 78 | 79 | return summary 80 | 81 | 82 | def should_continue_iteration(validation_conclusion, current_iteration, max_iterations=5): 83 | """Determine if iteration should continue.""" 84 | # Check max iterations 85 | if current_iteration >= max_iterations: 86 | return False, "Maximum iterations reached" 87 | 88 | # Check improvement 89 | if not validation_conclusion.get('success', False): 90 | return False, "No improvement in results" 91 | 92 | # Check for diminishing returns 93 | analysis = validation_conclusion.get('agent_analysis', '') 94 | if 'noise' in analysis.lower() or 'too broad' in analysis.lower(): 95 | return False, "Query may be too broad" 96 | 97 | # Check explicit recommendation 98 | if not validation_conclusion.get('continue_iteration', True): 99 | return False, "Validation recommends stopping" 100 | 101 | return True, "Continue with next iteration" -------------------------------------------------------------------------------- 
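A quick, self-contained usage sketch of the validation helpers above; the counts, file names, and conclusion dict are invented for illustration:

from QLWorkflow._03_output_validation.validation_tools import (
    calculate_metrics,
    assess_result_quality,
    should_continue_iteration,
)

# 28 -> 42 findings: +14 absolute, +50% relative, a 1.5x multiplier
metrics = calculate_metrics(current_count=42, previous_count=28)

# Two files account for all 42 findings: low diversity, so not flagged as noise
quality = assess_result_quality({'alpha.c': 30, 'beta.c': 12}, total_count=42)

# Iteration 2 of the default 5, improving and not noisy: keep iterating
go_on, reason = should_continue_iteration(
    {'success': True, 'continue_iteration': True}, current_iteration=2)

print(metrics['percentage_change'], quality['likely_noise'], go_on, reason)
# 50.0 False True Continue with next iteration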
/QLWorkflow/_04_iteration_control/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/P1umer/QL-Relax/941ffa7fb0b0b196ba9e4e49e99f5b7fe6e47138/QLWorkflow/_04_iteration_control/__init__.py -------------------------------------------------------------------------------- /QLWorkflow/_04_iteration_control/iteration_context.py: -------------------------------------------------------------------------------- 1 | """ 2 | Context for Iteration Control 3 | """ 4 | 5 | import os 6 | 7 | # Get the directory of the script for relative paths 8 | SCRIPT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 9 | 10 | 11 | class IterationContext: 12 | """ 13 | Context for the iteration control workflow. 14 | """ 15 | 16 | def __init__(self, cwe_number=None, ql_file_path=None, max_iterations=5, query_name=None, **kwargs): 17 | """ 18 | Initialize the iteration control context. 19 | 20 | Args: 21 | cwe_number: The CWE number being processed 22 | ql_file_path: Original QL file path 23 | max_iterations: Maximum number of iterations 24 | **kwargs: Additional parameters 25 | """ 26 | # Core parameters 27 | self.cwe_number = cwe_number 28 | self.original_ql_path = ql_file_path 29 | self.current_ql_path = ql_file_path # Will be updated each iteration 30 | self.max_iterations = max_iterations 31 | self.query_name = query_name or (os.path.splitext(os.path.basename(ql_file_path))[0] if ql_file_path else None) 32 | 33 | # Iteration tracking 34 | self.current_iteration = 1 35 | self.iteration_history = [] 36 | 37 | # Result tracking 38 | self.initial_result_count = kwargs.get('initial_result_count', 0) 39 | self.previous_result_count = self.initial_result_count 40 | self.current_result_count = self.initial_result_count 41 | 42 | # Validation tracking 43 | self.last_validation = None 44 | self.stop_reason = None 45 | 46 | # Output directory 47 | default_output_dir = os.path.join(SCRIPT_DIR, 'qlworkspace', f'CWE-{cwe_number:03d}_{self.query_name}' if self.query_name else f'CWE-{cwe_number:03d}') 48 | self.output_dir = kwargs.get('output_dir', default_output_dir) 49 | 50 | # Final report 51 | self.final_report = None 52 | 53 | # For LLM interactions 54 | self.messages = [] 55 | 56 | # Working directory for agent compatibility 57 | self.working_directory = self.output_dir 58 | 59 | def get(self, key, default=None): 60 | """ 61 | Get attribute value with dictionary-style access for compatibility with agent action utils. 
62 | """ 63 | # Handle key mapping for compatibility 64 | if key == 'iteration': 65 | return self.current_iteration 66 | return getattr(self, key, default) 67 | 68 | def __str__(self): 69 | return f"IterationContext(cwe={self.cwe_number}, iteration={self.current_iteration}/{self.max_iterations})" 70 | 71 | def __repr__(self): 72 | return self.__str__() -------------------------------------------------------------------------------- /QLWorkflow/_04_iteration_control/iteration_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tools for Iteration Control 3 | """ 4 | 5 | import os 6 | import json 7 | from datetime import datetime 8 | 9 | 10 | def create_iteration_directory(output_dir, iteration_number): 11 | """Create a directory for the current iteration.""" 12 | iteration_dir = os.path.join(output_dir, f"iteration_{iteration_number}") 13 | os.makedirs(iteration_dir, exist_ok=True) 14 | return iteration_dir 15 | 16 | 17 | def load_iteration_history(output_dir): 18 | """Load iteration history from previous runs.""" 19 | history_file = os.path.join(output_dir, 'iteration_history.json') 20 | if os.path.exists(history_file): 21 | with open(history_file, 'r') as f: 22 | return json.load(f) 23 | return [] 24 | 25 | 26 | def save_iteration_history(output_dir, history): 27 | """Save iteration history to file.""" 28 | history_file = os.path.join(output_dir, 'iteration_history.json') 29 | with open(history_file, 'w') as f: 30 | json.dump(history, f, indent=2) 31 | 32 | 33 | def calculate_convergence_metrics(iteration_history): 34 | """Calculate metrics to determine if the iterations are converging.""" 35 | if len(iteration_history) < 2: 36 | return None 37 | 38 | metrics = { 39 | 'improvement_trend': [], 40 | 'is_converging': False, 41 | 'convergence_rate': 0 42 | } 43 | 44 | # Calculate improvement between consecutive iterations 45 | for i in range(1, len(iteration_history)): 46 | prev_count = iteration_history[i-1].get('result_count', 0) 47 | curr_count = iteration_history[i].get('result_count', 0) 48 | 49 | if prev_count > 0: 50 | improvement = ((curr_count - prev_count) / prev_count) * 100 51 | else: 52 | improvement = 100 if curr_count > 0 else 0 53 | 54 | metrics['improvement_trend'].append(improvement) 55 | 56 | # Check if improvements are decreasing (converging) 57 | if len(metrics['improvement_trend']) >= 2: 58 | recent_improvements = metrics['improvement_trend'][-3:] 59 | if all(imp < 20 for imp in recent_improvements): 60 | metrics['is_converging'] = True 61 | 62 | # Calculate average improvement rate 63 | metrics['convergence_rate'] = sum(recent_improvements) / len(recent_improvements) 64 | 65 | return metrics 66 | 67 | 68 | def generate_iteration_summary(iteration_data): 69 | """Generate a summary for a single iteration.""" 70 | summary = { 71 | 'iteration_number': iteration_data.get('iteration', 0), 72 | 'timestamp': datetime.now().isoformat(), 73 | 'ql_file': os.path.basename(iteration_data.get('ql_path', '')), 74 | 'results': { 75 | 'count': iteration_data.get('result_count', 0), 76 | 'validation_passed': iteration_data.get('validation', {}).get('success', False) 77 | }, 78 | 'next_action': 'continue' if iteration_data.get('validation', {}).get('continue_iteration', False) else 'stop' 79 | } 80 | 81 | return summary 82 | 83 | 84 | def should_early_stop(iteration_history, current_iteration): 85 | """Determine if we should stop early based on convergence or other factors.""" 86 | # Check for convergence 87 | convergence = 
calculate_convergence_metrics(iteration_history) 88 | if convergence and convergence['is_converging']: 89 | return True, "Iterations are converging with minimal improvement" 90 | 91 | # Check for oscillation (results going up and down) 92 | if len(iteration_history) >= 3: 93 | recent_counts = [h.get('result_count', 0) for h in iteration_history[-3:]] 94 | if recent_counts[0] < recent_counts[1] > recent_counts[2]: 95 | return True, "Results are oscillating" 96 | 97 | # Check for explosion (too many results) 98 | if iteration_history and iteration_history[-1].get('result_count', 0) > 1000: 99 | return True, "Result count exceeds reasonable threshold" 100 | 101 | return False, None -------------------------------------------------------------------------------- /QLWorkflow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/P1umer/QL-Relax/941ffa7fb0b0b196ba9e4e49e99f5b7fe6e47138/QLWorkflow/__init__.py -------------------------------------------------------------------------------- /QLWorkflow/pipeline_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | QL Workflow Pipeline Configuration 3 | Main pipeline that orchestrates the QL query broadening workflow. 4 | """ 5 | 6 | from BaseMachine.action_utils import call_sub_state_machine_action 7 | import subprocess 8 | import json 9 | import os 10 | 11 | # Get the absolute path of the script directory 12 | SCRIPT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 13 | 14 | # Import iteration control configuration 15 | from QLWorkflow._04_iteration_control.iteration_config import state_definitions as iteration_states 16 | from QLWorkflow._04_iteration_control.iteration_context import IterationContext 17 | 18 | # Import query execution for initial baseline 19 | from QLWorkflow._02_run_ql_query.query_config import state_definitions as query_states 20 | from QLWorkflow._02_run_ql_query.query_context import QueryContext 21 | 22 | 23 | class QLWorkflowContext: 24 | """ 25 | Root context for the QL workflow pipeline. 26 | """ 27 | 28 | def __init__(self, **kwargs): 29 | """ 30 | Initialize the QL workflow context. 31 | 32 | Args: 33 | **kwargs: Configuration parameters 34 | """ 35 | # Workflow configuration 36 | self.max_iterations = kwargs.get('max_iterations', 5) 37 | self.output_base_dir = kwargs.get('output_dir', os.path.join(SCRIPT_DIR, 'qlworkspace')) 38 | self.cwe_limit = kwargs.get('cwe_limit', None) 39 | self.specific_cwe = kwargs.get('specific_cwe', None) 40 | self.process_all_cwes = kwargs.get('process_all_cwes', False) 41 | self.specific_query = kwargs.get('specific_query', None) 42 | 43 | # CWE and QL data (will be populated by pipeline) 44 | self.common_cwes = [] 45 | self.cwe_ql_mapping = {} 46 | self.current_cwe = None 47 | self.current_ql_files = [] 48 | self.processed_cwes = set() # Track processed CWEs for --all mode 49 | 50 | # Results tracking 51 | self.workflow_results = {} 52 | 53 | # For LLM interactions 54 | self.messages = [] 55 | 56 | def __str__(self): 57 | return f"QLWorkflowContext(cwes={len(self.common_cwes)})" 58 | 59 | def __repr__(self): 60 | return self.__str__() 61 | 62 | 63 | def get_common_cwes_action(machine): 64 | """ 65 | Get the list of common CWEs using run_juliet.py --list-common-cwes. 
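The stdout parser below assumes output of the form 'CWE-190: <description>' followed by indented '- /path/to/Query.ql' lines; only the CWE number and the QL file paths are extracted (the description is ignored, and the example values here are illustrative).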
66 | """ 67 | command = ['python3', os.path.join(SCRIPT_DIR, 'run_juliet.py'), '--list-common-cwes'] 68 | 69 | try: 70 | result = subprocess.run(command, capture_output=True, text=True) 71 | 72 | if result.returncode != 0: 73 | return f"Failed to get common CWEs: {result.stderr}" 74 | 75 | # Parse the output to extract CWE numbers and QL files 76 | output_lines = result.stdout.split('\n') 77 | current_cwe = None 78 | cwe_ql_mapping = {} 79 | 80 | for line in output_lines: 81 | if line.startswith('CWE-'): 82 | # Extract CWE number 83 | cwe_num = int(line.split('-')[1].split(':')[0]) 84 | current_cwe = cwe_num 85 | cwe_ql_mapping[current_cwe] = [] 86 | elif line.strip().startswith('- /') and current_cwe: 87 | # Extract QL file path 88 | ql_path = line.strip().lstrip('- ') 89 | # Skip _after.ql and _temp files created by the workflow 90 | if not ql_path.endswith('_after.ql') and '_temp_' not in ql_path: 91 | cwe_ql_mapping[current_cwe].append(ql_path) 92 | 93 | machine.context.common_cwes = sorted(cwe_ql_mapping.keys()) 94 | machine.context.cwe_ql_mapping = cwe_ql_mapping 95 | 96 | # Save the mapping 97 | os.makedirs(machine.context.output_base_dir, exist_ok=True) 98 | mapping_file = os.path.join(machine.context.output_base_dir, 'cwe_ql_mapping.json') 99 | with open(mapping_file, 'w') as f: 100 | json.dump(cwe_ql_mapping, f, indent=2) 101 | 102 | return f"Found {len(machine.context.common_cwes)} common CWEs" 103 | 104 | except Exception as e: 105 | return f"Error getting common CWEs: {str(e)}" 106 | 107 | 108 | def select_next_cwe_action(machine): 109 | """ 110 | Select the next CWE to process. 111 | """ 112 | if machine.context.process_all_cwes: 113 | # Process all CWEs mode 114 | for cwe in machine.context.common_cwes: 115 | if cwe not in machine.context.processed_cwes: 116 | machine.context.current_cwe = cwe 117 | machine.context.current_ql_files = machine.context.cwe_ql_mapping[cwe] 118 | 119 | # Filter by specific query if provided 120 | if machine.context.specific_query: 121 | filtered_files = [f for f in machine.context.current_ql_files 122 | if machine.context.specific_query in os.path.basename(f)] 123 | if filtered_files: 124 | machine.context.current_ql_files = filtered_files 125 | else: 126 | print(f"Warning: Query '{machine.context.specific_query}' not found for CWE-{cwe}") 127 | continue 128 | 129 | machine.context.processed_cwes.add(cwe) 130 | return f"Selected CWE-{cwe} with {len(machine.context.current_ql_files)} QL files" 131 | 132 | # All CWEs processed 133 | return "all_processed" 134 | 135 | else: 136 | # Process specific CWE (original behavior) 137 | if machine.context.specific_cwe: 138 | if machine.context.specific_cwe in machine.context.common_cwes: 139 | if machine.context.specific_cwe not in machine.context.workflow_results: 140 | machine.context.current_cwe = machine.context.specific_cwe 141 | machine.context.current_ql_files = machine.context.cwe_ql_mapping[machine.context.specific_cwe] 142 | 143 | # Filter by specific query if provided 144 | if machine.context.specific_query: 145 | filtered_files = [f for f in machine.context.current_ql_files 146 | if machine.context.specific_query in os.path.basename(f)] 147 | if filtered_files: 148 | machine.context.current_ql_files = filtered_files 149 | else: 150 | return f"Error: Query '{machine.context.specific_query}' not found for CWE-{machine.context.specific_cwe}" 151 | 152 | return f"Selected CWE-{machine.context.specific_cwe} with {len(machine.context.current_ql_files)} QL files" 153 | else: 154 | return "all_processed" 
# Already processed 155 | else: 156 | return f"CWE-{machine.context.specific_cwe} not found in common CWEs" 157 | 158 | # Since we're only processing one specific CWE, we're done 159 | return "all_processed" 160 | 161 | 162 | def process_cwe_ql_files_action(machine): 163 | """ 164 | Process all QL files for the current CWE. 165 | """ 166 | cwe = machine.context.current_cwe 167 | ql_files = machine.context.current_ql_files 168 | cwe_results = [] 169 | 170 | print(f"\n{'='*80}") 171 | print(f"[Pipeline] Processing CWE-{cwe} with {len(ql_files)} QL file(s)") 172 | print(f"{'='*80}") 173 | 174 | for ql_file in ql_files: 175 | print(f"\n[Pipeline] Processing {ql_file} for CWE-{cwe}") 176 | 177 | # Create output directory for this CWE and QL file 178 | ql_name = os.path.splitext(os.path.basename(ql_file))[0] 179 | output_dir = os.path.join(machine.context.output_base_dir, f"CWE-{cwe:03d}_{ql_name}") 180 | os.makedirs(output_dir, exist_ok=True) 181 | 182 | # Initialize - the initial count will be determined in the first iteration 183 | initial_count = 0 # Will be set by save_origin_query_action in iteration 1 184 | print(f"[Pipeline] Starting iterations for {ql_file}") 185 | 186 | # Run the iteration workflow starting from iteration 1 187 | iteration_context = IterationContext( 188 | cwe_number=cwe, 189 | ql_file_path=ql_file, 190 | query_name=ql_name, 191 | max_iterations=machine.context.max_iterations, 192 | initial_result_count=initial_count, 193 | output_dir=output_dir 194 | ) 195 | iteration_context.current_iteration = 1 # Start from iteration 1 196 | 197 | # Convert origin path to project codeql path 198 | if 'qlworkspace/origin/codeql/' in ql_file: 199 | # Extract the relative path after origin/codeql/ 200 | relative_path = ql_file.split('qlworkspace/origin/codeql/')[-1] 201 | 202 | # Construct the project codeql path 203 | project_codeql_path = os.path.join(output_dir, 'codeql', relative_path) 204 | project_codeql_path = os.path.abspath(project_codeql_path) 205 | 206 | # Check if the file exists in project codeql directory 207 | if os.path.exists(project_codeql_path): 208 | iteration_context.original_ql_path = project_codeql_path 209 | print(f"[Pipeline] Using project CodeQL file: {project_codeql_path}") 210 | else: 211 | # Fallback to original if project copy doesn't exist 212 | iteration_context.original_ql_path = ql_file 213 | print(f"[Pipeline] WARNING: Project CodeQL file not found at {project_codeql_path}") 214 | print(f"[Pipeline] Using original: {ql_file}") 215 | else: 216 | iteration_context.original_ql_path = ql_file 217 | 218 | iteration_result = call_sub_state_machine_action( 219 | sub_state_definitions=iteration_states, 220 | sub_initial_state='CheckIterationLimit', 221 | sub_context_cls=lambda: iteration_context, 222 | save_option='result' 223 | )(machine) 224 | 225 | # Store results 226 | ql_result = { 227 | 'ql_file': ql_file, 228 | 'initial_count': initial_count, 229 | 'final_report': iteration_context.final_report or {} 230 | } 231 | cwe_results.append(ql_result) 232 | 233 | # Store CWE results 234 | machine.context.workflow_results[cwe] = cwe_results 235 | 236 | # Save intermediate results 237 | results_file = os.path.join(machine.context.output_base_dir, 'workflow_results.json') 238 | with open(results_file, 'w') as f: 239 | json.dump(machine.context.workflow_results, f, indent=2) 240 | 241 | return f"Processed {len(ql_files)} QL files for CWE-{cwe}" 242 | 243 | 244 | def generate_summary_report_action(machine): 245 | """ 246 | Generate a summary report for the 
processed CWE. 247 | """ 248 | if not machine.context.workflow_results: 249 | print("\nNo results to summarize") 250 | return "No results" 251 | 252 | summary = { 253 | 'total_cwes': len(machine.context.workflow_results), 254 | 'total_ql_files': sum(len(results) for results in machine.context.workflow_results.values()), 255 | 'cwe_summaries': {} 256 | } 257 | 258 | for cwe, ql_results in machine.context.workflow_results.items(): 259 | cwe_summary = { 260 | 'ql_files_processed': len(ql_results), 261 | 'total_improvement': 0, 262 | 'successful_modifications': 0, 263 | 'compilation_failures': 0, 264 | 'result_decreases': 0 265 | } 266 | 267 | for ql_result in ql_results: 268 | final_report = ql_result.get('final_report', {}) 269 | improvement = final_report.get('overall_improvement', {}) 270 | 271 | # Check the final iteration's result category 272 | iterations = final_report.get('iterations', []) 273 | if iterations: 274 | last_iteration = iterations[-1] 275 | validation = last_iteration.get('validation', {}) 276 | result_category = validation.get('result_category', '') 277 | 278 | if result_category == 'success_increase': 279 | cwe_summary['successful_modifications'] += 1 280 | cwe_summary['total_improvement'] += improvement.get('percentage', 0) 281 | elif result_category == 'compile_error': 282 | cwe_summary['compilation_failures'] += 1 283 | elif result_category == 'result_decrease': 284 | cwe_summary['result_decreases'] += 1 285 | 286 | summary['cwe_summaries'][f'CWE-{cwe}'] = cwe_summary 287 | 288 | # Save summary 289 | summary_file = os.path.join(machine.context.output_base_dir, 'workflow_summary.json') 290 | with open(summary_file, 'w') as f: 291 | json.dump(summary, f, indent=2) 292 | 293 | print(f"\nWorkflow Summary saved to: {summary_file}") 294 | return "Summary generated" 295 | 296 | 297 | def exit_action(machine): 298 | """Exit action.""" 299 | return None 300 | 301 | 302 | # Root workflow state machine configuration 303 | state_definitions = { 304 | 'GetCommonCWEs': { 305 | 'action': get_common_cwes_action, 306 | 'next_state_func': lambda result, machine: 'SelectNextCWE', 307 | }, 308 | 'SelectNextCWE': { 309 | 'action': select_next_cwe_action, 310 | 'next_state_func': lambda result, machine: 'ProcessCWEQLFiles' if result.startswith('Selected') else 'GenerateSummary',  # only a successful selection proceeds; 'all_processed' and error results fall through to the summary 311 | }, 312 | 'ProcessCWEQLFiles': { 313 | 'action': process_cwe_ql_files_action, 314 | 'next_state_func': lambda result, machine: 'SelectNextCWE', 315 | }, 316 | 'GenerateSummary': { 317 | 'action': generate_summary_report_action, 318 | 'next_state_func': lambda result, machine: 'Exit', 319 | }, 320 | 'Exit': { 321 | 'action': exit_action, 322 | 'next_state_func': None, 323 | }, 324 | } -------------------------------------------------------------------------------- /QLWorkflow/util/__init__.py: -------------------------------------------------------------------------------- 1 | # Utility modules for QLWorkflow -------------------------------------------------------------------------------- /QLWorkflow/util/evaluation_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Enhanced evaluation utilities using CodeQL-extracted function boundaries. 
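Classification follows the Juliet naming convention used throughout this module: findings located in functions whose names contain 'bad' are counted as true positives, while findings in 'good' or unmarked functions are counted as false positives.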
3 | """ 4 | 5 | import json 6 | import os 7 | import csv 8 | import subprocess 9 | import time 10 | 11 | # Get the directory of the script for relative paths 12 | SCRIPT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | 14 | 15 | def extract_functions_for_cwe(cwe_number): 16 | """ 17 | Extract function boundaries for a specific CWE using CodeQL. 18 | Returns a dictionary mapping (file_path, line_number) to function info. 19 | """ 20 | # Paths 21 | query_file = os.path.join(SCRIPT_DIR, 'QLWorkflow', 'util', 'function_dump.ql') 22 | util_dir = os.path.join(SCRIPT_DIR, 'qlworkspace', 'util') 23 | os.makedirs(util_dir, exist_ok=True) 24 | output_csv = os.path.join(util_dir, f'cwe{cwe_number}_functions.csv') 25 | 26 | # Check if cached CSV exists and is recent (within 1 hour) 27 | if os.path.exists(output_csv): 28 | file_age = time.time() - os.path.getmtime(output_csv) 29 | if file_age < 3600: # 1 hour cache 30 | print(f"Using cached function boundaries from {output_csv}") 31 | # Parse and return cached data 32 | function_map = {} 33 | with open(output_csv, 'r') as f: 34 | reader = csv.DictReader(f) 35 | for row in reader: 36 | func_name = row['col0'] 37 | file_path = row['col1'] 38 | start_line = int(row['col2']) 39 | end_line = int(row['col3']) 40 | 41 | # Normalize file path - remove /workspace prefix 42 | if file_path.startswith('/workspace/'): 43 | file_path = file_path[11:] 44 | 45 | # Store function info for all lines in its range 46 | for line_num in range(start_line, end_line + 1): 47 | key = (file_path, line_num) 48 | function_map[key] = { 49 | 'name': func_name, 50 | 'start_line': start_line, 51 | 'end_line': end_line, 52 | 'type': classify_function_name(func_name) 53 | } 54 | return function_map 55 | 56 | # Run query using run_juliet.py 57 | command = [ 58 | 'python3', 59 | os.path.join(SCRIPT_DIR, 'run_juliet.py'), 60 | '--run-queries', 61 | '--cwe', f'{cwe_number:03d}', 62 | '--ql', query_file, 63 | '--output', util_dir 64 | ] 65 | 66 | print(f"Extracting function boundaries for CWE-{cwe_number}...") 67 | result = subprocess.run(command, capture_output=True, text=True) 68 | 69 | if result.returncode != 0: 70 | print(f"Error running query: {result.stderr}") 71 | return {} 72 | 73 | # The output file should be named based on the query file 74 | # run_juliet.py outputs files as {cwe_name}_{ql_name}.csv 75 | # We need to find the correct one for this CWE 76 | expected_output = None 77 | 78 | # Look for CSV files that match our CWE number and contain 'function_dump' 79 | for file in os.listdir(util_dir): 80 | if file.endswith('.csv') and 'function_dump' in file: 81 | # Check if this file contains data for our CWE 82 | file_path = os.path.join(util_dir, file) 83 | with open(file_path, 'r') as f: 84 | # Read first few lines to check CWE number 85 | for _ in range(10): 86 | line = f.readline() 87 | if f'CWE{cwe_number}_' in line or f'CWE{cwe_number:03d}_' in line: 88 | expected_output = file_path 89 | break 90 | if expected_output: 91 | break 92 | 93 | # If still not found, try the most recent function_dump CSV 94 | if not expected_output: 95 | csv_files = [f for f in os.listdir(util_dir) if f.endswith('.csv') and 'function_dump' in f] 96 | if csv_files: 97 | # Get the most recently modified one 98 | csv_files_with_time = [(f, os.path.getmtime(os.path.join(util_dir, f))) for f in csv_files] 99 | csv_files_with_time.sort(key=lambda x: x[1], reverse=True) 100 | expected_output = os.path.join(util_dir, csv_files_with_time[0][0]) 101 | 102 | if 
expected_output and os.path.exists(expected_output) and expected_output != output_csv: 103 | # Move to our expected location 104 | os.rename(expected_output, output_csv) 105 | 106 | # Parse CSV and build lookup structure 107 | function_map = {} 108 | 109 | # Check if the output CSV exists before trying to open it 110 | if not os.path.exists(output_csv): 111 | print(f"Warning: Function boundary CSV not found at {output_csv}") 112 | return {} 113 | 114 | with open(output_csv, 'r') as f: 115 | reader = csv.DictReader(f) 116 | for row in reader: 117 | func_name = row['col0'] 118 | file_path = row['col1'] 119 | start_line = int(row['col2']) 120 | end_line = int(row['col3']) 121 | 122 | # Normalize file path - remove /workspace prefix 123 | if file_path.startswith('/workspace/'): 124 | file_path = file_path[11:] # Remove '/workspace/' 125 | 126 | # Store function info for all lines in its range 127 | for line_num in range(start_line, end_line + 1): 128 | key = (file_path, line_num) 129 | function_map[key] = { 130 | 'name': func_name, 131 | 'start_line': start_line, 132 | 'end_line': end_line, 133 | 'type': classify_function_name(func_name) 134 | } 135 | 136 | # Keep the CSV file for debugging/caching purposes 137 | # if os.path.exists(output_csv): 138 | # os.remove(output_csv) 139 | 140 | return function_map 141 | 142 | 143 | def classify_function_name(func_name): 144 | """ 145 | Classify function based on its name following Juliet conventions. 146 | """ 147 | if not func_name: 148 | return 'unknown' 149 | 150 | # Pattern matches 151 | func_lower = func_name.lower() 152 | 153 | # Bad function patterns 154 | if 'bad' in func_lower: 155 | return 'bad' 156 | 157 | # Good function patterns 158 | if 'good' in func_lower: 159 | return 'good' 160 | 161 | return 'unknown' 162 | 163 | 164 | def get_function_from_line(file_path, line_number, function_map): 165 | """ 166 | Get function type for a specific line using pre-extracted function data. 
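Tries several path variations against the function map; if the map has no entry for the file, falls back to scanning the source for the enclosing function declaration and, finally, to filename patterns.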
167 | 168 | Args: 169 | file_path: Path to the source file 170 | line_number: Line number to check 171 | function_map: Pre-extracted function boundaries from CodeQL 172 | 173 | Returns: 174 | str: 'bad', 'good', or 'unknown' 175 | """ 176 | # Try multiple path variations to find a match 177 | path_variations = [] 178 | 179 | # Original path 180 | path_variations.append(file_path) 181 | 182 | # If it's a short path, try to expand it 183 | if not file_path.startswith('/') and not file_path.startswith('juliet-test-suite-c'): 184 | # Get CWE number from filename 185 | import re 186 | cwe_match = re.search(r'CWE(\d+)_', file_path) 187 | 188 | if cwe_match: 189 | cwe_num = cwe_match.group(1).lstrip('0') 190 | 191 | # Check if the file has a subdirectory prefix like s01/ 192 | if file_path.startswith('s'): 193 | # Look for entries in function_map that contain our file path 194 | for key in function_map: 195 | key_path = key[0] 196 | # Check if this key contains our filename 197 | if file_path in key_path and f'CWE{cwe_num}_' in key_path: 198 | # Extract the CWE directory name 199 | if '/testcases/' in key_path: 200 | # Find where testcases/ ends and extract the full path after it 201 | idx = key_path.find('/testcases/') + len('/testcases/') 202 | relative_path = key_path[idx:] 203 | path_variations.append(relative_path) 204 | 205 | # Also try the full path from function map 206 | path_variations.append(key_path) 207 | 208 | # If the key path starts with /workspace/, also try without it 209 | if key_path.startswith('/workspace/'): 210 | path_variations.append(key_path[11:]) # Remove '/workspace/' 211 | else: 212 | # File without subdirectory (like CWE476 files) 213 | # Look for matching files in function_map 214 | filename = os.path.basename(file_path) 215 | for key in function_map: 216 | key_path = key[0] 217 | if filename in key_path and f'CWE{cwe_num}_' in key_path: 218 | # Extract various path patterns 219 | if '/testcases/' in key_path: 220 | idx = key_path.find('/testcases/') + len('/testcases/') 221 | relative_path = key_path[idx:] 222 | path_variations.append(relative_path) 223 | 224 | path_variations.append(key_path) 225 | 226 | if key_path.startswith('/workspace/'): 227 | path_variations.append(key_path[11:]) 228 | 229 | # Standard patterns - try to find the CWE directory dynamically 230 | # Look for any directory starting with CWE{num}_ 231 | for key in function_map: 232 | if f'/CWE{cwe_num}_' in key[0]: 233 | # Extract the CWE directory name 234 | path_parts = key[0].split('/') 235 | for i, part in enumerate(path_parts): 236 | if part.startswith(f'CWE{cwe_num}_'): 237 | cwe_dir = part 238 | if file_path.startswith('s'): 239 | # With subdirectory 240 | path_variations.append(f'juliet-test-suite-c/testcases/{cwe_dir}/{file_path}') 241 | path_variations.append(f'testcases/{cwe_dir}/{file_path}') 242 | path_variations.append(f'{cwe_dir}/{file_path}') 243 | else: 244 | # Without subdirectory 245 | path_variations.append(f'juliet-test-suite-c/testcases/{cwe_dir}/{file_path}') 246 | path_variations.append(f'testcases/{cwe_dir}/{file_path}') 247 | path_variations.append(f'{cwe_dir}/{file_path}') 248 | break 249 | break 250 | 251 | # Try to find in function map 252 | for normalized_path in path_variations: 253 | key = (normalized_path, line_number) 254 | if key in function_map: 255 | func_info = function_map[key] 256 | func_type = func_info['type'] 257 | if func_type in ['bad', 'good']: 258 | return func_type 259 | 260 | # If function map is empty or doesn't contain the file, use filename-based 
classification 261 | # This is important for cases where function boundaries weren't extracted properly 262 | if len(function_map) == 0 or not any(file_path in key[0] for key in function_map): 263 | # Read the file and check function name at the line 264 | try: 265 | # Try to find the file 266 | possible_paths = [ 267 | file_path, 268 | os.path.join(SCRIPT_DIR, 'juliet-test-suite-c', 'testcases', file_path) 269 | ] 270 | 271 | # Add CWE-specific paths 272 | import re 273 | cwe_match = re.search(r'CWE(\d+)_', file_path) 274 | if cwe_match: 275 | # Find CWE directory 276 | import glob 277 | cwe_num = cwe_match.group(1).lstrip('0') 278 | cwe_dirs = glob.glob(os.path.join(SCRIPT_DIR, 'juliet-test-suite-c', 'testcases', f'CWE{cwe_num}_*')) 279 | for cwe_dir in cwe_dirs: 280 | possible_paths.append(os.path.join(cwe_dir, os.path.basename(file_path))) 281 | 282 | # Try to read the file 283 | file_content = None 284 | for path in possible_paths: 285 | if os.path.exists(path): 286 | with open(path, 'r') as f: 287 | file_content = f.readlines() 288 | break 289 | 290 | if file_content and line_number <= len(file_content): 291 | # Look backwards from the line to find the function declaration 292 | for i in range(line_number - 1, -1, -1): 293 | line = file_content[i] 294 | # Check for function declarations (return type, '(', and a '{' within the next few lines) 295 | if ('void ' in line or 'int ' in line or 'char ' in line) and '(' in line and any('{' in l for l in file_content[i:i+3]): 296 | # Extract function name 297 | import re 298 | func_match = re.search(r'(\w+)\s*\(', line) 299 | if func_match: 300 | func_name = func_match.group(1) 301 | return classify_function_name(func_name) 302 | except Exception: 303 | pass 304 | 305 | # Final fallback to file name patterns if function not found 306 | if '_bad' in file_path or 'bad_' in file_path: 307 | return 'bad' 308 | elif '_good' in file_path or 'good_' in file_path: 309 | return 'good' 310 | 311 | return 'unknown' 312 | 313 | 314 | def classify_result(thread_flow, sarif_result, function_map): 315 | """ 316 | Enhanced classification using CodeQL-extracted function boundaries. 317 | """ 318 | # Check both thread flow locations and result locations 319 | all_locations = [] 320 | 321 | # Add thread flow locations 322 | thread_locations = thread_flow.get('locations', []) 323 | for location in thread_locations: 324 | loc = location.get('location', {}) 325 | all_locations.append(loc) 326 | 327 | # Add result locations 328 | result_locations = sarif_result.get('locations', []) 329 | all_locations.extend(result_locations) 330 | 331 | # Check each location using function map 332 | for loc in all_locations: 333 | phys_loc = loc.get('physicalLocation', {}) 334 | file_uri = phys_loc.get('artifactLocation', {}).get('uri', '') 335 | line_num = phys_loc.get('region', {}).get('startLine', 0) 336 | 337 | if file_uri and line_num > 0: 338 | # get_function_from_line already handles file name patterns as fallback 339 | func_type = get_function_from_line(file_uri, line_num, function_map) 340 | if func_type != 'unknown': 341 | return func_type 342 | 343 | return 'unknown' 344 | 345 | 346 | def evaluate_sarif_results(sarif_path, output_dir=None, source_base_dir=None): 347 | """ 348 | Enhanced evaluation using CodeQL-extracted function boundaries. 
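Path-problem results (those with codeFlows) are counted once per threadFlow; plain problem results are counted once per result location.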
349 | 350 | Args: 351 | sarif_path: Path to the SARIF file 352 | output_dir: Optional directory to save good_results.json and bad_results.json 353 | source_base_dir: Optional source path used to infer the CWE number when it cannot be parsed from sarif_path 354 | 355 | Returns: 356 | dict: Evaluation metrics including TP/FP counts and rates 357 | """ 358 | if not os.path.exists(sarif_path): 359 | return { 360 | 'good_result_count': 0, 361 | 'bad_result_count': 0, 362 | 'unknown_result_count': 0, 363 | 'true_positive_count': 0, 364 | 'false_positive_count': 0, 365 | 'true_positive_rate': 0.0, 366 | 'false_positive_rate': 0.0, 367 | 'total_threadflows': 0 368 | } 369 | 370 | # Extract CWE number from path or source_base_dir 371 | cwe_number = None 372 | import re 373 | 374 | # Try from sarif path first 375 | cwe_match = re.search(r'CWE-?(\d+)', sarif_path) 376 | if cwe_match: 377 | cwe_number = int(cwe_match.group(1).lstrip('0')) 378 | 379 | # Try from source_base_dir if not found 380 | if not cwe_number and source_base_dir: 381 | cwe_match = re.search(r'CWE(\d+)_', source_base_dir) 382 | if cwe_match: 383 | cwe_number = int(cwe_match.group(1).lstrip('0')) 384 | 385 | if not cwe_number: 386 | print("Warning: Could not determine CWE number, using text-based function detection") 387 | function_map = {} 388 | else: 389 | print(f"Extracting function boundaries for CWE-{cwe_number}...") 390 | function_map = extract_functions_for_cwe(cwe_number) 391 | print(f"Extracted {len(function_map)} function-line mappings") 392 | 393 | try: 394 | with open(sarif_path, 'r', encoding='utf-8') as f: 395 | sarif_data = json.load(f) 396 | 397 | good_count = 0 398 | bad_count = 0 399 | unknown_count = 0 400 | total_threadflows = 0 401 | good_results = [] 402 | bad_results = [] 403 | unknown_results = [] 404 | 405 | for run in sarif_data.get('runs', []): 406 | for result in run.get('results', []): 407 | # Get result location info (guard against an empty locations list) 408 | result_loc = (result.get('locations') or [{}])[0].get('physicalLocation', {}) 409 | result_file = result_loc.get('artifactLocation', {}).get('uri', '') 410 | result_line = result_loc.get('region', {}).get('startLine', 0) 411 | result_message = result.get('message', {}).get('text', '') 412 | 413 | # Check if this is a path-problem query (has codeFlows) 414 | code_flows = result.get('codeFlows', []) 415 | if code_flows: 416 | # Handle path-problem queries 417 | for code_flow in code_flows: 418 | for thread_flow in code_flow.get('threadFlows', []): 419 | total_threadflows += 1 420 | 421 | # Create a summary of this threadFlow 422 | thread_flow_summary = { 423 | 'result_location': { 424 | 'file': result_file, 425 | 'line': result_line, 426 | 'message': result_message[:200] # Truncate long messages 427 | }, 428 | 'thread_flow_locations': [] 429 | } 430 | 431 | # Add key locations from the threadFlow 432 | locations = thread_flow.get('locations', []) 433 | for i, location in enumerate(locations): 434 | if i == 0 or i == len(locations) - 1: # First and last locations 435 | loc = location.get('location', {}) 436 | phys_loc = loc.get('physicalLocation', {}) 437 | thread_flow_summary['thread_flow_locations'].append({ 438 | 'step': 'source' if i == 0 else 'sink', 439 | 'file': phys_loc.get('artifactLocation', {}).get('uri', ''), 440 | 'line': phys_loc.get('region', {}).get('startLine', 0), 441 | 'message': loc.get('message', {}).get('text', '')[:100] 442 | }) 443 | 444 | classification = classify_result(thread_flow, result, function_map) 445 | if classification == 'bad': 446 | bad_count += 1 447 | bad_results.append(thread_flow_summary) 448 | elif 
classification == 'good': 449 | good_count += 1 450 | good_results.append(thread_flow_summary) 451 | else: # unknown 452 | unknown_count += 1 453 | unknown_results.append(thread_flow_summary) 454 | else: 455 | # Handle regular problem queries (no codeFlows) 456 | total_threadflows += 1 457 | 458 | # Create a summary for this result 459 | result_summary = { 460 | 'result_location': { 461 | 'file': result_file, 462 | 'line': result_line, 463 | 'message': result_message[:200] # Truncate long messages 464 | } 465 | } 466 | 467 | # Classify based on result location instead of threadFlow 468 | classification = get_function_from_line(result_file, result_line, function_map) 469 | if classification == 'bad': 470 | bad_count += 1 471 | bad_results.append(result_summary) 472 | elif classification == 'good': 473 | good_count += 1 474 | good_results.append(result_summary) 475 | else: # unknown 476 | unknown_count += 1 477 | unknown_results.append(result_summary) 478 | 479 | # In Juliet test suite: 480 | # - True Positive (TP): Finding a vulnerability in a "bad" function 481 | # - False Positive (FP): Finding a vulnerability in a "good" function or unmarked function 482 | true_positive_count = bad_count 483 | false_positive_count = good_count + unknown_count 484 | 485 | # Calculate rates based on all results 486 | total = true_positive_count + false_positive_count 487 | true_positive_rate = (true_positive_count / total * 100) if total > 0 else 0.0 488 | false_positive_rate = (false_positive_count / total * 100) if total > 0 else 0.0 489 | 490 | # Save results by category if output directory is provided 491 | if output_dir and os.path.exists(output_dir): 492 | # Save good results (false positives) 493 | good_results_file = os.path.join(output_dir, 'good_results.json') 494 | with open(good_results_file, 'w', encoding='utf-8') as f: 495 | json.dump({ 496 | 'count': good_count, 497 | 'description': 'ThreadFlows in code marked as "good" (known false positives)', 498 | 'results': good_results 499 | }, f, indent=2) 500 | 501 | # Save bad results (true positives) 502 | bad_results_file = os.path.join(output_dir, 'bad_results.json') 503 | with open(bad_results_file, 'w', encoding='utf-8') as f: 504 | json.dump({ 505 | 'count': bad_count, 506 | 'description': 'ThreadFlows in code marked as "bad" (true positives)', 507 | 'results': bad_results 508 | }, f, indent=2) 509 | 510 | # Save unknown results 511 | unknown_results_file = os.path.join(output_dir, 'unknown_results.json') 512 | with open(unknown_results_file, 'w', encoding='utf-8') as f: 513 | json.dump({ 514 | 'count': unknown_count, 515 | 'description': 'ThreadFlows in unmarked code (classification unknown)', 516 | 'results': unknown_results 517 | }, f, indent=2) 518 | 519 | return { 520 | 'good_result_count': good_count, 521 | 'bad_result_count': bad_count, 522 | 'unknown_result_count': unknown_count, 523 | 'true_positive_count': true_positive_count, 524 | 'false_positive_count': false_positive_count, 525 | 'true_positive_rate': round(true_positive_rate, 2), 526 | 'false_positive_rate': round(false_positive_rate, 2), 527 | 'total_threadflows': total_threadflows 528 | } 529 | 530 | except Exception as e: 531 | print(f"[Evaluation] Error evaluating SARIF: {str(e)}") 532 | import traceback 533 | traceback.print_exc() 534 | return { 535 | 'good_result_count': 0, 536 | 'bad_result_count': 0, 537 | 'unknown_result_count': 0, 538 | 'true_positive_count': 0, 539 | 'false_positive_count': 0, 540 | 'true_positive_rate': 0.0, 541 | 'false_positive_rate': 0.0, 542 | 
'total_threadflows': 0, 543 | 'error': str(e) 544 | } 545 | 546 | -------------------------------------------------------------------------------- /QLWorkflow/util/function_dump.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @kind problem 3 | * @name Function Blocks 4 | * @id cpp/function-block-location 5 | * @description Lists functions with their block statement locations. 6 | */ 7 | 8 | import cpp 9 | 10 | from Function f, BlockStmt body 11 | where f.hasDefinition() and 12 | body = f.getBlock() and 13 | f.getLocation().getFile() = body.getLocation().getFile() 14 | select f.getName().toString(), body.getLocation().getFile().getAbsolutePath().toString(), 15 | f.getLocation().getStartLine().toString(), 16 | body.getLocation().getEndLine().toString() -------------------------------------------------------------------------------- /QLWorkflow/util/logging_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logging utilities for QLWorkflow 3 | """ 4 | 5 | from pathlib import Path 6 | from typing import Dict, Any, Optional 7 | 8 | 9 | def get_ql_workflow_log_path(context: Dict[str, Any]) -> Optional[Path]: 10 | """ 11 | Get the log path for QLWorkflow sessions based on CWE and iteration. 12 | 13 | Args: 14 | context: Context containing CWE number, query name, iteration, etc. 15 | 16 | Returns: 17 | Path object for the log directory, or None if insufficient context 18 | """ 19 | # The output_dir already contains the CWE-XXX_QueryName format 20 | output_dir = context.get('output_dir') 21 | if not output_dir: 22 | return None 23 | 24 | base_log_dir = Path(output_dir) 25 | iteration = context.get('iteration', 1) 26 | 27 | # Special handling for initial evaluation (iteration 0) 28 | if iteration == 0: 29 | return base_log_dir / 'initial' / 'session_log' 30 | else: 31 | return base_log_dir / f"iteration_{iteration}" / 'session_log' 32 | 33 | 34 | def get_action_type_from_prompt(prompt: str) -> str: 35 | """ 36 | Determine action type based on prompt content. 37 | 38 | Args: 39 | prompt: The prompt text 40 | 41 | Returns: 42 | Action type string ('modification', 'validation', or 'general') 43 | """ 44 | prompt_lower = prompt.lower() 45 | 46 | if any(keyword in prompt_lower for keyword in ['modifying ql', 'modify', 'modification', 'broaden', 'compile error']): 47 | return 'modification' 48 | elif any(keyword in prompt_lower for keyword in ['validation', 'analyze', 'result distribution', 'query results']): 49 | return 'validation' 50 | else: 51 | return 'general' -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # QL-Relax 2 | 3 | **An Experimental Approach to Relaxing CodeQL Constraints Using LLMs** 4 | 5 | This project explores an alternative to generating CodeQL queries from scratch (which often produces syntax errors). Instead, we experiment with using LLMs to systematically relax conservative constraints in existing official CodeQL queries. Our hypothesis is that many official queries prioritize low false positive rates through aggressive pruning, potentially missing edge cases that could be valuable in certain security contexts. 6 | 7 | ## Architecture 8 | 9 | QL-Relax Architecture 10 | 11 | ## Setup 12 | 13 | ### Requirements 14 | - Claude Code 15 | - Docker 16 | - Git 17 | - ~50GB disk space 18 | 19 | ### First Time Setup 20 | 0. Clone & CD the project 21 | 1. 
Clone required repositories: 22 | ```bash 23 | mkdir -p qlworkspace/origin 24 | cd qlworkspace/origin 25 | git clone https://github.com/github/codeql.git 26 | cd ../.. 27 | 28 | # Clone Juliet test suite 29 | git clone https://github.com/arichardson/juliet-test-suite-c.git 30 | ``` 31 | 32 | 2. Build and start Docker: 33 | ```bash 34 | ./start_docker.sh 35 | ``` 36 | 37 | 3. Create CodeQL databases: 38 | ```bash 39 | python3 run_juliet.py --create-db --all 40 | ``` 41 | 42 | 4. Create initial workspaces: 43 | ```bash 44 | python3 run_juliet.py --create-workspace --all 45 | ``` 46 | 47 | ### Running 48 | 49 | ```bash 50 | # Single CWE 51 | python3 run_ql_workflow.py --cwe 190 52 | 53 | # Multiple CWEs 54 | python3 run_ql_workflow.py --cwe 190 134 78 55 | 56 | # All supported CWEs 57 | python3 run_ql_workflow.py --all 58 | ``` 59 | 60 | ## How It Works 61 | 62 | This experimental workflow attempts to improve vulnerability detection through constraint relaxation: 63 | 64 | 1. **Start with Official Queries**: Use production-tested CodeQL queries as a reliable foundation 65 | 2. **LLM-Guided Relaxation**: Experiment with removing conservative filters and constraints 66 | 3. **Test on Juliet Suite**: Validate whether relaxed queries catch more known vulnerabilities 67 | 4. **Iterate Carefully**: Balance between finding more issues and maintaining query validity 68 | 69 | ## Docker Setup 70 | 71 | The system uses a fixed container name `ql-relax-container`. 72 | 73 | ```bash 74 | docker build -t ql-relax:latest . 75 | docker run -d --name ql-relax-container -v $(pwd):/workspace ql-relax:latest 76 | ``` 77 | 78 | ## Project Structure 79 | 80 | ``` 81 | QL-Relax/ 82 | ├── BaseMachine/ # LLM state machine framework 83 | ├── QLWorkflow/ # Query optimization workflow 84 | ├── juliet-test-suite-c/ # Juliet test cases (mount or symlink) 85 | ├── run_juliet.py # Single CWE runner 86 | └── run_ql_workflow.py # Multi-CWE pipeline 87 | ``` 88 | 89 | ## Configuration 90 | 91 | Environment variables (optional): 92 | - `JULIET_PATH`: Path to Juliet test suite 93 | - `CODEQL_DB_PATH`: Path to CodeQL databases 94 | 95 | ## Supported CWEs 96 | 97 | We select CWEs that have both path-problem queries and Juliet test suites, so the relax-and-evaluate loop can be closed end to end. 98 | -------------------------------------------------------------------------------- /docs/images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/P1umer/QL-Relax/941ffa7fb0b0b196ba9e4e49e99f5b7fe6e47138/docs/images/architecture.png -------------------------------------------------------------------------------- /draw/initial_vs_final_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/P1umer/QL-Relax/941ffa7fb0b0b196ba9e4e49e99f5b7fe6e47138/draw/initial_vs_final_comparison.png -------------------------------------------------------------------------------- /draw/plot_initial_vs_final.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Plot initial vs final results comparison charts. 
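Typical invocation (paths are illustrative; see the argparse defaults in main()):

    python3 draw/plot_initial_vs_final.py --workspace ./qlworkspace --output-dir ./draw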
4 | """ 5 | 6 | import os 7 | import json 8 | import glob 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from collections import defaultdict 12 | import argparse 13 | 14 | def collect_initial_and_final_results(qlworkspace_dir): 15 | """Collect both initial and final results from all CWE directories.""" 16 | initial_results = [] 17 | final_results = [] 18 | 19 | # Find all CWE directories 20 | cwe_pattern = os.path.join(qlworkspace_dir, "CWE-*") 21 | cwe_dirs = glob.glob(cwe_pattern) 22 | 23 | for cwe_dir in cwe_dirs: 24 | cwe_name = os.path.basename(cwe_dir) 25 | if '_' in cwe_name: 26 | cwe_part, query_part = cwe_name.split('_', 1) 27 | cwe_number = cwe_part.replace('CWE-', '') 28 | query_name = query_part 29 | else: 30 | continue 31 | 32 | # Collect initial results 33 | initial_pattern = os.path.join(cwe_dir, "initial/query_results/results_log.json") 34 | initial_files = glob.glob(initial_pattern) 35 | 36 | initial_data = None 37 | for log_file in initial_files: 38 | try: 39 | with open(log_file, 'r') as f: 40 | data = json.load(f) 41 | 42 | if data.get('total_threadflows', 0) > 0 and 'error' not in data: 43 | initial_data = { 44 | 'cwe': cwe_number, 45 | 'query': query_name, 46 | 'tp': data.get('true_positive_count', 0), 47 | 'fp': data.get('false_positive_count', 0), 48 | 'total': data.get('total_threadflows', 0), 49 | 'tp_rate': data.get('true_positive_rate', 0.0), 50 | 'fp_rate': data.get('false_positive_rate', 0.0) 51 | } 52 | break 53 | except (OSError, json.JSONDecodeError):  # skip unreadable or malformed logs 54 | continue 55 | 56 | # Collect final results (prefer latest iteration) 57 | final_pattern = os.path.join(cwe_dir, "iteration_*/query_results/results_log.json") 58 | final_files = sorted(glob.glob(final_pattern), reverse=True) # Lexicographically latest first (fine while max_iterations stays single-digit) 59 | 60 | final_data = None 61 | for log_file in final_files: 62 | try: 63 | with open(log_file, 'r') as f: 64 | data = json.load(f) 65 | 66 | if data.get('total_threadflows', 0) > 0 and 'error' not in data: 67 | final_data = { 68 | 'cwe': cwe_number, 69 | 'query': query_name, 70 | 'tp': data.get('true_positive_count', 0), 71 | 'fp': data.get('false_positive_count', 0), 72 | 'total': data.get('total_threadflows', 0), 73 | 'tp_rate': data.get('true_positive_rate', 0.0), 74 | 'fp_rate': data.get('false_positive_rate', 0.0) 75 | } 76 | break 77 | except (OSError, json.JSONDecodeError):  # skip unreadable or malformed logs 78 | continue 79 | 80 | # Only include if we have both initial and final data 81 | if initial_data and final_data: 82 | initial_results.append(initial_data) 83 | final_results.append(final_data) 84 | elif initial_data: # Only initial data available 85 | initial_results.append(initial_data) 86 | # Create empty final data for comparison 87 | final_results.append({ 88 | 'cwe': cwe_number, 89 | 'query': query_name, 90 | 'tp': 0, 'fp': 0, 'total': 0, 91 | 'tp_rate': 0.0, 'fp_rate': 0.0 92 | }) 93 | elif final_data: # Only final data available 94 | final_results.append(final_data) 95 | # Create empty initial data for comparison 96 | initial_results.append({ 97 | 'cwe': cwe_number, 98 | 'query': query_name, 99 | 'tp': 0, 'fp': 0, 'total': 0, 100 | 'tp_rate': 0.0, 'fp_rate': 0.0 101 | }) 102 | 103 | return initial_results, final_results 104 | 105 | def create_comparison_chart(initial_results, final_results, output_path): 106 | """Create side-by-side comparison charts.""" 107 | if not initial_results and not final_results: 108 | print("No data to plot") 109 | return 110 | 111 | # Ensure both lists have the same length and order 112 | all_queries = set() 113 | initial_dict = {f"{r['cwe']}_{r['query']}": r for r in initial_results} 114 | final_dict 
= {f"{r['cwe']}_{r['query']}": r for r in final_results} 115 | all_queries = set(initial_dict.keys()) | set(final_dict.keys()) 116 | 117 | # Prepare aligned data 118 | aligned_initial = [] 119 | aligned_final = [] 120 | 121 | for query_key in sorted(all_queries): 122 | initial_data = initial_dict.get(query_key, { 123 | 'cwe': query_key.split('_')[0], 'query': '_'.join(query_key.split('_')[1:]), 124 | 'tp': 0, 'fp': 0, 'total': 0, 'tp_rate': 0.0 125 | }) 126 | final_data = final_dict.get(query_key, { 127 | 'cwe': query_key.split('_')[0], 'query': '_'.join(query_key.split('_')[1:]), 128 | 'tp': 0, 'fp': 0, 'total': 0, 'tp_rate': 0.0 129 | }) 130 | 131 | aligned_initial.append(initial_data) 132 | aligned_final.append(final_data) 133 | 134 | # Sort by total increase (final total - initial total) for better visualization 135 | total_increases = [f['total'] - i['total'] for i, f in zip(aligned_initial, aligned_final)] 136 | sorted_indices = sorted(range(len(total_increases)), 137 | key=lambda i: total_increases[i], reverse=True) 138 | 139 | aligned_initial = [aligned_initial[i] for i in sorted_indices] 140 | aligned_final = [aligned_final[i] for i in sorted_indices] 141 | 142 | # Create figure with grouped bar chart 143 | fig, ax = plt.subplots(figsize=(20, 10)) 144 | 145 | # Prepare data 146 | labels = [f"CWE-{r['cwe']}\n{r['query']}" for r in aligned_final] 147 | 148 | initial_tp = [r['tp'] for r in aligned_initial] 149 | initial_fp = [r['fp'] for r in aligned_initial] 150 | initial_total = [r['total'] for r in aligned_initial] 151 | 152 | final_tp = [r['tp'] for r in aligned_final] 153 | final_fp = [r['fp'] for r in aligned_final] 154 | final_total = [r['total'] for r in aligned_final] 155 | 156 | x = np.arange(len(labels)) 157 | width = 0.35 158 | 159 | # Calculate max height for scaling 160 | max_height = max(max(initial_total + final_total), 1) 161 | 162 | # Create stacked bars - Initial on left, Final on right 163 | # Initial results (with diagonal hatching) 164 | for i in range(len(labels)): 165 | # Make bars visible even for 0 values 166 | display_total = max(initial_total[i], max_height * 0.005) if initial_total[i] == 0 else initial_total[i] 167 | 168 | if initial_total[i] > 0: 169 | # Stack: True Positives (red) at bottom, False Positives (blue) on top 170 | # True Positives 171 | ax.bar(x[i] - width/2, initial_tp[i], width, 172 | color='#DC143C', alpha=0.8, edgecolor='black', linewidth=1.5, 173 | hatch='///', label='True Positives' if i == 0 else "") 174 | # False Positives 175 | ax.bar(x[i] - width/2, initial_fp[i], width, bottom=initial_tp[i], 176 | color='#4169E1', alpha=0.8, edgecolor='black', linewidth=1.5, 177 | hatch='///', label='False Positives' if i == 0 else "") 178 | else: 179 | # Empty bar with initial hatching 180 | ax.bar(x[i] - width/2, display_total, width, 181 | color='lightgray', alpha=0.3, edgecolor='black', linewidth=1.5, 182 | hatch='///') 183 | 184 | # Final results (with dot hatching) 185 | for i in range(len(labels)): 186 | # Make bars visible even for 0 values 187 | display_total = max(final_total[i], max_height * 0.005) if final_total[i] == 0 else final_total[i] 188 | 189 | if final_total[i] > 0: 190 | # Stack: True Positives (red) at bottom, False Positives (blue) on top 191 | # True Positives 192 | ax.bar(x[i] + width/2, final_tp[i], width, 193 | color='#DC143C', alpha=0.8, edgecolor='black', linewidth=1.5, 194 | hatch='...') 195 | # False Positives 196 | ax.bar(x[i] + width/2, final_fp[i], width, bottom=final_tp[i], 197 | color='#4169E1', alpha=0.8, 
edgecolor='black', linewidth=1.5, 198 | hatch='...') 199 | else: 200 | # Empty bar with final hatching 201 | ax.bar(x[i] + width/2, display_total, width, 202 | color='lightgray', alpha=0.3, edgecolor='black', linewidth=1.5, 203 | hatch='...') 204 | 205 | # Add custom legend 206 | from matplotlib.patches import Patch, Rectangle 207 | from matplotlib.lines import Line2D 208 | 209 | # Create custom legend elements 210 | legend_elements = [ 211 | # Result types 212 | Rectangle((0,0), 1, 1, facecolor='#DC143C', alpha=0.8, edgecolor='black', linewidth=1, label='True Positives'), 213 | Rectangle((0,0), 1, 1, facecolor='#4169E1', alpha=0.8, edgecolor='black', linewidth=1, label='False Positives'), 214 | # Separator 215 | Line2D([0], [0], color='none', label=''), 216 | # Pattern types 217 | Rectangle((0,0), 1, 1, facecolor='gray', alpha=0.5, edgecolor='black', linewidth=1.5, 218 | hatch='///', label='Initial Results'), 219 | Rectangle((0,0), 1, 1, facecolor='gray', alpha=0.5, edgecolor='black', linewidth=1.5, 220 | hatch='...', label='Final Results'), 221 | ] 222 | ax.legend(handles=legend_elements, loc='upper right', fontsize=11, framealpha=0.9) 223 | 224 | ax.set_xlabel('CWE and Query', fontsize=14, fontweight='bold') 225 | ax.set_ylabel('Number of Results', fontsize=14, fontweight='bold') 226 | ax.set_title('CodeQL Query Results: Before vs After Optimization\n(Sorted by Total Result Increase)', 227 | fontsize=16, fontweight='bold') 228 | ax.set_xticks(x) 229 | ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=10) 230 | ax.grid(True, alpha=0.3, axis='y') 231 | 232 | # Add value labels and increase indicators 233 | for i in range(len(labels)): 234 | # Initial bar labels 235 | if initial_total[i] > 0: 236 | # Show total on top 237 | ax.text(x[i] - width/2, initial_total[i] + max_height * 0.01, 238 | str(initial_total[i]), ha='center', va='bottom', fontsize=7, fontweight='bold') 239 | # Show TP count in red section if significant 240 | if initial_tp[i] > initial_total[i] * 0.15: 241 | ax.text(x[i] - width/2, initial_tp[i]/2, str(initial_tp[i]), 242 | ha='center', va='center', color='white', fontsize=6, fontweight='bold') 243 | # Show FP count in blue section if significant 244 | if initial_fp[i] > initial_total[i] * 0.15: 245 | ax.text(x[i] - width/2, initial_tp[i] + initial_fp[i]/2, str(initial_fp[i]), 246 | ha='center', va='center', color='white', fontsize=6, fontweight='bold') 247 | else: 248 | ax.text(x[i] - width/2, max_height * 0.01, '0', 249 | ha='center', va='bottom', fontsize=7, fontweight='bold') 250 | 251 | # Final bar labels 252 | if final_total[i] > 0: 253 | # Show total on top 254 | ax.text(x[i] + width/2, final_total[i] + max_height * 0.01, 255 | str(final_total[i]), ha='center', va='bottom', fontsize=7, fontweight='bold') 256 | # Show TP count in red section if significant 257 | if final_tp[i] > final_total[i] * 0.15: 258 | ax.text(x[i] + width/2, final_tp[i]/2, str(final_tp[i]), 259 | ha='center', va='center', color='white', fontsize=6, fontweight='bold') 260 | # Show FP count in blue section if significant 261 | if final_fp[i] > final_total[i] * 0.15: 262 | ax.text(x[i] + width/2, final_tp[i] + final_fp[i]/2, str(final_fp[i]), 263 | ha='center', va='center', color='white', fontsize=6, fontweight='bold') 264 | else: 265 | ax.text(x[i] + width/2, max_height * 0.01, '0', 266 | ha='center', va='bottom', fontsize=7, fontweight='bold') 267 | 268 | # Add increase indicator 269 | increase = final_total[i] - initial_total[i] 270 | if increase > 0: 271 | # Draw arrow and show 
increase 272 | arrow_y = max(initial_total[i], final_total[i]) + max_height * 0.05 273 | ax.annotate(f'+{increase}', xy=(x[i], arrow_y), 274 | ha='center', va='bottom', fontsize=8, 275 | color='darkgreen', fontweight='bold', 276 | bbox=dict(boxstyle="round,pad=0.3", facecolor='lightgreen', alpha=0.7)) 277 | 278 | 279 | # Calculate and display summary statistics 280 | initial_total_tp = sum(initial_tp) 281 | initial_total_fp = sum(initial_fp) 282 | initial_total_all = sum(initial_total) 283 | 284 | final_total_tp = sum(final_tp) 285 | final_total_fp = sum(final_fp) 286 | final_total_all = sum(final_total) 287 | 288 | # Calculate average TP rates 289 | initial_tp_rates = [r['tp_rate'] for r in aligned_initial] 290 | final_tp_rates = [r['tp_rate'] for r in aligned_final] 291 | initial_avg_tp_rate = np.mean([r for r in initial_tp_rates if r > 0]) if any(r > 0 for r in initial_tp_rates) else 0 292 | final_avg_tp_rate = np.mean([r for r in final_tp_rates if r > 0]) if any(r > 0 for r in final_tp_rates) else 0 293 | 294 | # Add statistics text box 295 | stats_text = f"""Summary Statistics: 296 | Initial: {initial_total_all:,} results ({initial_total_tp:,} TP, {initial_total_fp:,} FP) - Avg TP Rate: {initial_avg_tp_rate:.1f}% 297 | Final: {final_total_all:,} results ({final_total_tp:,} TP, {final_total_fp:,} FP) - Avg TP Rate: {final_avg_tp_rate:.1f}% 298 | Improvement: {final_total_all - initial_total_all:+,} results ({final_total_tp - initial_total_tp:+,} TP)""" 299 | 300 | ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, fontsize=11, 301 | verticalalignment='top', 302 | bbox=dict(boxstyle='round,pad=0.5', facecolor='lightyellow', alpha=0.8)) 303 | 304 | plt.tight_layout() 305 | plt.savefig(output_path, dpi=300, bbox_inches='tight') 306 | print(f"Initial vs Final comparison chart saved to: {output_path}") 307 | plt.show() 308 | 309 | def create_tp_rate_comparison(initial_results, final_results, output_path): 310 | """Create TP rate comparison chart.""" 311 | # Align data similar to above 312 | all_queries = set() 313 | initial_dict = {f"{r['cwe']}_{r['query']}": r for r in initial_results} 314 | final_dict = {f"{r['cwe']}_{r['query']}": r for r in final_results} 315 | all_queries = set(initial_dict.keys()) | set(final_dict.keys()) 316 | 317 | aligned_initial = [] 318 | aligned_final = [] 319 | 320 | for query_key in sorted(all_queries): 321 | initial_data = initial_dict.get(query_key, { 322 | 'cwe': query_key.split('_')[0], 'query': '_'.join(query_key.split('_')[1:]), 323 | 'tp_rate': 0.0 324 | }) 325 | final_data = final_dict.get(query_key, { 326 | 'cwe': query_key.split('_')[0], 'query': '_'.join(query_key.split('_')[1:]), 327 | 'tp_rate': 0.0 328 | }) 329 | 330 | aligned_initial.append(initial_data) 331 | aligned_final.append(final_data) 332 | 333 | # Sort by improvement (final - initial) 334 | improvements = [f['tp_rate'] - i['tp_rate'] for i, f in zip(aligned_initial, aligned_final)] 335 | sorted_indices = sorted(range(len(improvements)), key=lambda i: improvements[i], reverse=True) 336 | 337 | aligned_initial = [aligned_initial[i] for i in sorted_indices] 338 | aligned_final = [aligned_final[i] for i in sorted_indices] 339 | improvements = [improvements[i] for i in sorted_indices] 340 | 341 | # Create the TP rate comparison chart 342 | fig, ax = plt.subplots(figsize=(20, 8)) 343 | 344 | labels = [f"CWE-{r['cwe']}\n{r['query']}" for r in aligned_final] 345 | initial_rates = [r['tp_rate'] for r in aligned_initial] 346 | final_rates = [r['tp_rate'] for r in aligned_final] 347 | 348 | x 
= np.arange(len(labels)) 349 | width = 0.35 350 | 351 | bars1 = ax.bar(x - width/2, initial_rates, width, label='Initial TP Rate', 352 | color='#FF6B6B', alpha=0.8) 353 | bars2 = ax.bar(x + width/2, final_rates, width, label='Final TP Rate', 354 | color='#4ECDC4', alpha=0.8) 355 | 356 | # Add improvement arrows 357 | for i, (init_rate, final_rate, improvement) in enumerate(zip(initial_rates, final_rates, improvements)): 358 | if abs(improvement) > 1: # Only show significant improvements 359 | arrow_color = 'green' if improvement > 0 else 'red' 360 | arrow_style = '↑' if improvement > 0 else '↓' 361 | ax.annotate(f'{arrow_style}{abs(improvement):.1f}%', 362 | xy=(i, max(init_rate, final_rate) + 2), 363 | ha='center', va='bottom', color=arrow_color, fontweight='bold', fontsize=8) 364 | 365 | ax.set_xlabel('CWE and Query', fontsize=12, fontweight='bold') 366 | ax.set_ylabel('True Positive Rate (%)', fontsize=12, fontweight='bold') 367 | ax.set_title('True Positive Rate Comparison: Initial vs Final\n(Sorted by Improvement)', 368 | fontsize=14, fontweight='bold') 369 | ax.set_xticks(x) 370 | ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=9) 371 | ax.legend() 372 | ax.grid(True, alpha=0.3, axis='y') 373 | ax.set_ylim(0, 105) 374 | 375 | # Add value labels on bars 376 | for bar in bars1: 377 | height = bar.get_height() 378 | if height > 0: 379 | ax.text(bar.get_x() + bar.get_width()/2., height + 0.5, 380 | f'{height:.1f}%', ha='center', va='bottom', fontsize=8) 381 | 382 | for bar in bars2: 383 | height = bar.get_height() 384 | if height > 0: 385 | ax.text(bar.get_x() + bar.get_width()/2., height + 0.5, 386 | f'{height:.1f}%', ha='center', va='bottom', fontsize=8) 387 | 388 | plt.tight_layout() 389 | plt.savefig(output_path, dpi=300, bbox_inches='tight') 390 | print(f"TP Rate comparison chart saved to: {output_path}") 391 | plt.show() 392 | 393 | def print_comparison_table(initial_results, final_results): 394 | """Print comparison table.""" 395 | # Align data 396 | all_queries = set() 397 | initial_dict = {f"{r['cwe']}_{r['query']}": r for r in initial_results} 398 | final_dict = {f"{r['cwe']}_{r['query']}": r for r in final_results} 399 | all_queries = set(initial_dict.keys()) | set(final_dict.keys()) 400 | 401 | comparisons = [] 402 | for query_key in sorted(all_queries): 403 | initial_data = initial_dict.get(query_key, { 404 | 'cwe': query_key.split('_')[0], 'query': '_'.join(query_key.split('_')[1:]), 405 | 'tp': 0, 'fp': 0, 'total': 0, 'tp_rate': 0.0 406 | }) 407 | final_data = final_dict.get(query_key, { 408 | 'cwe': query_key.split('_')[0], 'query': '_'.join(query_key.split('_')[1:]), 409 | 'tp': 0, 'fp': 0, 'total': 0, 'tp_rate': 0.0 410 | }) 411 | 412 | comparisons.append({ 413 | 'cwe': final_data['cwe'], 414 | 'query': final_data['query'], 415 | 'initial_tp': initial_data['tp'], 416 | 'initial_fp': initial_data['fp'], 417 | 'initial_total': initial_data['total'], 418 | 'initial_tp_rate': initial_data['tp_rate'], 419 | 'final_tp': final_data['tp'], 420 | 'final_fp': final_data['fp'], 421 | 'final_total': final_data['total'], 422 | 'final_tp_rate': final_data['tp_rate'], 423 | 'tp_improvement': final_data['tp'] - initial_data['tp'], 424 | 'total_improvement': final_data['total'] - initial_data['total'], 425 | 'tp_rate_improvement': final_data['tp_rate'] - initial_data['tp_rate'] 426 | }) 427 | 428 | # Sort by TP rate improvement 429 | comparisons.sort(key=lambda x: x['tp_rate_improvement'], reverse=True) 430 | 431 | print(f"\n{'='*150}") 432 | print(f"{'INITIAL vs FINAL 
COMPARISON':^150}") 433 | print(f"{'='*150}") 434 | print(f"{'CWE':<6} {'Query':<30} {'Initial':<25} {'Final':<25} {'Improvement':<25} {'TP Rate Δ'}") 435 | print(f"{'':<6} {'':<30} {'TP/FP/Total':<25} {'TP/FP/Total':<25} {'TP/Total':<25} {'(%)':>10}") 436 | print(f"{'-'*150}") 437 | 438 | for c in comparisons: 439 | initial_str = f"{c['initial_tp']}/{c['initial_fp']}/{c['initial_total']}" 440 | final_str = f"{c['final_tp']}/{c['final_fp']}/{c['final_total']}" 441 | improvement_str = f"{c['tp_improvement']:+}/{c['total_improvement']:+}" 442 | 443 | print(f"CWE-{c['cwe']:<3} {c['query']:<30} {initial_str:<25} {final_str:<25} {improvement_str:<25} {c['tp_rate_improvement']:+7.1f}%") 444 | 445 | def main(): 446 | parser = argparse.ArgumentParser(description='Generate initial vs final comparison charts') 447 | parser.add_argument('--workspace', type=str, default='/hdd2/QL-Relax/qlworkspace', 448 | help='Path to QL-Relax workspace directory') 449 | parser.add_argument('--output-dir', type=str, default='/hdd2/QL-Relax', 450 | help='Output directory for charts') 451 | 452 | args = parser.parse_args() 453 | 454 | print("Collecting initial and final results...") 455 | initial_results, final_results = collect_initial_and_final_results(args.workspace) 456 | 457 | if not initial_results and not final_results: 458 | print("No results found!") 459 | return 460 | 461 | print(f"Found {len(initial_results)} initial results and {len(final_results)} final results") 462 | 463 | # Print comparison table 464 | print_comparison_table(initial_results, final_results) 465 | 466 | # Generate comparison charts 467 | print("\nGenerating initial vs final comparison chart...") 468 | comparison_path = os.path.join(args.output_dir, 'initial_vs_final_comparison.png') 469 | create_comparison_chart(initial_results, final_results, comparison_path) 470 | 471 | print("Generating TP rate comparison chart...") 472 | tp_rate_path = os.path.join(args.output_dir, 'tp_rate_comparison.png') 473 | create_tp_rate_comparison(initial_results, final_results, tp_rate_path) 474 | 475 | if __name__ == "__main__": 476 | main() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openai>=1.0.0 2 | pydantic>=2.0.0 3 | colorama>=0.4.0 4 | requests>=2.25.0 5 | claude-code-sdk>=0.1.0 6 | anyio>=3.0.0 -------------------------------------------------------------------------------- /run_ql_workflow.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Main entry point for running the QL Workflow. 4 | This script executes the pipeline that broadens QL queries iteratively. 5 | """ 6 | 7 | import sys 8 | import os 9 | import argparse 10 | import json 11 | import glob 12 | 13 | # Get the directory of the script for relative paths 14 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 15 | 16 | from BaseMachine import StateMachine 17 | from QLWorkflow.pipeline_config import state_definitions, QLWorkflowContext 18 | from QLWorkflow.util.evaluation_utils import evaluate_sarif_results 19 | 20 | 21 | def run_evaluation_only(cwe_number, output_dir, specific_query=None): 22 | """ 23 | Run evaluation only on existing SARIF files for a CWE. 
24 | """ 25 | print(f"\nRunning evaluation only for CWE-{cwe_number}") 26 | 27 | # Find SARIF files for this CWE 28 | cwe_dir = os.path.join(output_dir, f"CWE-{cwe_number}_*") 29 | cwe_dirs = glob.glob(cwe_dir) 30 | 31 | if not cwe_dirs: 32 | print(f"No directories found for CWE-{cwe_number} in {output_dir}") 33 | return 34 | 35 | results = [] 36 | 37 | for cwe_path in cwe_dirs: 38 | # Skip if specific query is requested and this doesn't match 39 | if specific_query and specific_query not in os.path.basename(cwe_path): 40 | continue 41 | 42 | # Find all SARIF files in iterations and initial 43 | sarif_patterns = [ 44 | os.path.join(cwe_path, "initial/query_results/*.sarif"), 45 | os.path.join(cwe_path, "iteration_*/query_results/*.sarif") 46 | ] 47 | 48 | for pattern in sarif_patterns: 49 | sarif_files = glob.glob(pattern) 50 | 51 | for sarif_file in sarif_files: 52 | print(f"\nEvaluating: {sarif_file}") 53 | 54 | # Get output directory for this SARIF 55 | query_results_dir = os.path.dirname(sarif_file) 56 | 57 | # Find source base directory for CWE 58 | testcases_base = os.path.join(SCRIPT_DIR, 'juliet-test-suite-c', 'testcases') 59 | source_base_dir = None 60 | if os.path.exists(testcases_base): 61 | for dirname in os.listdir(testcases_base): 62 | if dirname.startswith(f'CWE{cwe_number}_'): 63 | source_base_dir = os.path.join(testcases_base, dirname) 64 | break 65 | 66 | # Run evaluation 67 | evaluation_metrics = evaluate_sarif_results(sarif_file, query_results_dir, source_base_dir) 68 | 69 | # Update results_log.json 70 | results_log_path = os.path.join(query_results_dir, 'results_log.json') 71 | if os.path.exists(results_log_path): 72 | with open(results_log_path, 'r') as f: 73 | results_log = json.load(f) 74 | 75 | # Remove any existing error and update with evaluation metrics 76 | if 'error' in results_log: 77 | del results_log['error'] 78 | results_log.update(evaluation_metrics) 79 | 80 | # Update result_count based on SARIF threadflows 81 | if 'total_threadflows' in evaluation_metrics: 82 | results_log['result_count'] = evaluation_metrics['total_threadflows'] 83 | 84 | with open(results_log_path, 'w') as f: 85 | json.dump(results_log, f, indent=2) 86 | 87 | print(f"Updated: {results_log_path}") 88 | else: 89 | # Create new results_log.json with evaluation metrics 90 | results_log = { 91 | 'sarif_file': sarif_file, 92 | **evaluation_metrics 93 | } 94 | 95 | with open(results_log_path, 'w') as f: 96 | json.dump(results_log, f, indent=2) 97 | 98 | print(f"Created: {results_log_path}") 99 | 100 | # Print evaluation results 101 | print(f" Good results (FP): {evaluation_metrics['good_result_count']}") 102 | print(f" Bad results (TP): {evaluation_metrics['bad_result_count']}") 103 | print(f" Unknown results: {evaluation_metrics['unknown_result_count']}") 104 | print(f" Total threadflows: {evaluation_metrics['total_threadflows']}") 105 | print(f" True positive rate: {evaluation_metrics['true_positive_rate']}%") 106 | print(f" False positive rate: {evaluation_metrics['false_positive_rate']}%") 107 | 108 | results.append({ 109 | 'sarif_file': sarif_file, 110 | 'metrics': evaluation_metrics 111 | }) 112 | 113 | if not results: 114 | print(f"\nNo SARIF files found for evaluation") 115 | else: 116 | print(f"\nEvaluated {len(results)} SARIF files") 117 | 118 | # Update final_report.json if it exists 119 | update_final_report_evaluation(cwe_number, output_dir, results, specific_query) 120 | 121 | return results 122 | 123 | 124 | def update_final_report_evaluation(cwe_number, output_dir, 
eval_results, specific_query=None): 125 | """ 126 | Update final_report.json files with evaluation metrics. 127 | """ 128 | # Find final_report.json files for this CWE 129 | cwe_dir = os.path.join(output_dir, f"CWE-{cwe_number}_*") 130 | cwe_dirs = glob.glob(cwe_dir) 131 | 132 | for cwe_path in cwe_dirs: 133 | # Skip if specific query is requested and this doesn't match 134 | if specific_query and specific_query not in os.path.basename(cwe_path): 135 | continue 136 | 137 | final_report_path = os.path.join(cwe_path, 'final_report.json') 138 | if os.path.exists(final_report_path): 139 | try: 140 | with open(final_report_path, 'r') as f: 141 | report_data = json.load(f) 142 | 143 | # Find evaluation metrics from the results for this specific CWE dir 144 | initial_metrics = None 145 | final_metrics = None 146 | 147 | # Get the query name from the directory 148 | dir_query_name = os.path.basename(cwe_path).split('_', 1)[1] if '_' in os.path.basename(cwe_path) else '' 149 | 150 | for result in eval_results: 151 | sarif_file = result['sarif_file'] 152 | metrics = result['metrics'] 153 | 154 | # Check if this result belongs to the current CWE directory 155 | if dir_query_name and dir_query_name in sarif_file: 156 | if '/initial/' in sarif_file: 157 | initial_metrics = metrics 158 | elif '/iteration_' in sarif_file: 159 | # Use the last iteration as final metrics 160 | final_metrics = metrics 161 | 162 | # Update report with evaluation metrics 163 | if initial_metrics: 164 | report_data.update({ 165 | "initial_true_positive": initial_metrics['true_positive_count'], 166 | "initial_false_positive": initial_metrics['false_positive_count'], 167 | "initial_unknown_result": initial_metrics['unknown_result_count'], 168 | "initial_true_positive_rate": initial_metrics['true_positive_rate'], 169 | "initial_false_positive_rate": initial_metrics['false_positive_rate'], 170 | "initial_good_result": initial_metrics['good_result_count'], 171 | "initial_bad_result": initial_metrics['bad_result_count'] 172 | }) 173 | # Update initial_result_count based on threadflows 174 | if 'total_threadflows' in initial_metrics: 175 | report_data["initial_result_count"] = initial_metrics['total_threadflows'] 176 | 177 | if final_metrics: 178 | report_data.update({ 179 | "final_true_positive": final_metrics['true_positive_count'], 180 | "final_false_positive": final_metrics['false_positive_count'], 181 | "final_unknown_result": final_metrics['unknown_result_count'], 182 | "final_true_positive_rate": final_metrics['true_positive_rate'], 183 | "final_false_positive_rate": final_metrics['false_positive_rate'], 184 | "final_good_result": final_metrics['good_result_count'], 185 | "final_bad_result": final_metrics['bad_result_count'], 186 | "final_total_threadflows": final_metrics['total_threadflows'] 187 | }) 188 | # Update final_result_count based on threadflows 189 | if 'total_threadflows' in final_metrics: 190 | report_data["final_result_count"] = final_metrics['total_threadflows'] 191 | 192 | # Update iteration result counts based on eval results 193 | if 'iterations' in report_data: 194 | for iteration in report_data['iterations']: 195 | # Find corresponding eval result for this iteration 196 | for result in eval_results: 197 | if f'iteration_{iteration["iteration"]}' in result['sarif_file']: 198 | iteration['result_count'] = result['metrics']['total_threadflows'] 199 | if 'validation' in iteration: 200 | iteration['validation']['current_count'] = result['metrics']['total_threadflows'] 201 | break 202 | 203 | # Calculate 
improvement; this needs both initial and final metrics 204 | if initial_metrics and final_metrics: 205 | initial_tp = initial_metrics['true_positive_count'] 206 | final_tp = final_metrics['true_positive_count'] 207 | initial_fp = initial_metrics['false_positive_count'] 208 | final_fp = final_metrics['false_positive_count'] 209 | initial_count = initial_metrics.get('total_threadflows', 0) 210 | final_count = final_metrics.get('total_threadflows', 0) 211 | 212 | tp_improvement = final_tp - initial_tp 213 | fp_improvement = final_fp - initial_fp 214 | 215 | report_data["true_positive_improvement"] = { 216 | "absolute": tp_improvement, 217 | "percentage": (tp_improvement / initial_tp * 100) if initial_tp > 0 else 0.0 218 | } 219 | 220 | report_data["false_positive_improvement"] = { 221 | "absolute": fp_improvement, 222 | "percentage": (fp_improvement / initial_fp * 100) if initial_fp > 0 else 0.0 223 | } 224 | 225 | # Update overall improvement 226 | report_data["overall_improvement"] = { 227 | "absolute": final_count - initial_count, 228 | "percentage": ((final_count - initial_count) / initial_count * 100) if initial_count > 0 else (0.0 if final_count == 0 else 100.0) 229 | } 230 | 231 | # Save updated report 232 | with open(final_report_path, 'w') as f: 233 | json.dump(report_data, f, indent=2) 234 | 235 | print(f"Updated: {final_report_path}") 236 | 237 | except Exception as e: 238 | print(f"Error updating final report {final_report_path}: {str(e)}") 239 | 240 | 241 | def main(): 242 | parser = argparse.ArgumentParser(description='Run QL Workflow for modifying CodeQL queries') 243 | parser.add_argument('--cwe', type=int, 244 | help='CWE number to process') 245 | parser.add_argument('--all', action='store_true', 246 | help='Process all available CWEs') 247 | parser.add_argument('--max-iterations', type=int, default=10, 248 | help='Maximum number of iterations per QL file (default: 10)') 249 | parser.add_argument('--output-dir', type=str, 250 | default=os.path.join(SCRIPT_DIR, 'qlworkspace'), 251 | help='Output directory for results and logs') 252 | parser.add_argument('--mode', type=str, default='agent', 253 | help='Execution mode (default: agent)') 254 | parser.add_argument('--query', type=str, 255 | help='Specific query name to run (e.g., TaintedAllocationSize)') 256 | parser.add_argument('--eval-only', action='store_true', 257 | help='Only run evaluation on existing SARIF files without running queries') 258 | 259 | args = parser.parse_args() 260 | 261 | # Validate arguments 262 | if not args.all and not args.cwe: 263 | parser.error('Either --cwe or --all must be specified') 264 | if args.all and args.cwe: 265 | parser.error('Cannot specify both --cwe and --all options') 266 | if args.query and args.all: 267 | parser.error('--query can only be used with --cwe, not with --all') 268 | # Handle eval-only mode 269 | if args.eval_only: 270 | if args.all: 271 | # Run evaluation for all CWEs 272 | print("Running evaluation-only mode for all CWEs") 273 | 274 | # Find all CWE directories 275 | cwe_dirs = glob.glob(os.path.join(args.output_dir, "CWE-*")) 276 | cwe_numbers = set() 277 | 278 | for cwe_dir in cwe_dirs: 279 | # Extract CWE number from directory name 280 | dirname = os.path.basename(cwe_dir) 281 | if dirname.startswith("CWE-") and "_" in dirname: 282 | cwe_num = dirname.split("-")[1].split("_")[0] 283 | if cwe_num.isdigit(): 284 | cwe_numbers.add(int(cwe_num)) 285 | 286 | # Run evaluation for each unique CWE 287 | for cwe_num in sorted(cwe_numbers): 288 | print(f"\n{'='*60}") 289 | print(f"Processing CWE-{cwe_num}") 290 | print(f"{'='*60}") 291 | 
run_evaluation_only(cwe_num, args.output_dir, None) 292 | 293 | return 294 | elif args.cwe: 295 | print(f"Running evaluation-only mode for CWE-{args.cwe}") 296 | if args.query: 297 | print(f"Filtering for query: {args.query}") 298 | 299 | run_evaluation_only(args.cwe, args.output_dir, args.query) 300 | return 301 | else: 302 | parser.error('--eval-only requires either --cwe or --all to be specified') 303 | 304 | # Create context 305 | context = QLWorkflowContext( 306 | max_iterations=args.max_iterations, 307 | output_dir=args.output_dir, 308 | specific_cwe=args.cwe, 309 | process_all_cwes=args.all, 310 | specific_query=args.query 311 | ) 312 | 313 | # Create and run the state machine 314 | print(f"Starting QL Workflow...") 315 | if args.all: 316 | print(f"Processing: All CWEs") 317 | else: 318 | print(f"CWE: {args.cwe}") 319 | if args.query: 320 | print(f"Query: {args.query}") 321 | print(f"Max iterations: {args.max_iterations}") 322 | print(f"Output directory: {args.output_dir}") 323 | print(f"Mode: {args.mode} (using BaseMachine agent mode)") 324 | 325 | # Set config path 326 | config_path = os.path.join(SCRIPT_DIR, '.config/config.json') 327 | 328 | machine = StateMachine( 329 | state_definitions=state_definitions, 330 | initial_state='GetCommonCWEs', 331 | context=context, 332 | mode=args.mode, 333 | config_path=config_path 334 | ) 335 | 336 | try: 337 | # Run the workflow 338 | machine.process() 339 | print("\nQL Workflow completed successfully!") 340 | 341 | except RuntimeError as e: 342 | if "Claude AI usage limit reached" in str(e): 343 | print("\n[STOPPED] Claude AI usage limit reached. Pipeline stopped gracefully.") 344 | print("The workflow was interrupted due to API rate limits.") 345 | print(f"Partial results saved to: {args.output_dir}") 346 | # Don't re-raise - exit gracefully 347 | return 348 | else: 349 | print(f"\nRuntime error in QL Workflow: {str(e)}") 350 | raise 351 | except Exception as e: 352 | print(f"\nError running QL Workflow: {str(e)}") 353 | raise 354 | 355 | 356 | if __name__ == "__main__": 357 | main() -------------------------------------------------------------------------------- /start_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Start script for QL-Relax Docker environment 3 | 4 | set -e 5 | 6 | echo "Setting up QL-Relax Docker environment..." 7 | 8 | # Check if Docker is installed 9 | if ! command -v docker &> /dev/null; then 10 | echo "Error: Docker is not installed. Please install Docker first." 11 | exit 1 12 | fi 13 | 14 | # Build the image 15 | echo "Building QL-Relax image..." 16 | docker build -t ql-relax:latest . 17 | 18 | # Check if container already exists 19 | if docker ps -a --format '{{.Names}}' | grep -q '^ql-relax-container$'; then 20 | echo "Container 'ql-relax-container' already exists" 21 | # Start it if it's stopped 22 | docker start ql-relax-container 23 | else 24 | # Create and run new container 25 | echo "Creating new container 'ql-relax-container'..." 26 | docker run -d \ 27 | --name ql-relax-container \ 28 | -v "$(pwd)":/workspace \ 29 | -v "$(pwd)/juliet-test-suite-c":/workspace/juliet-test-suite-c \ 30 | ql-relax:latest \ 31 | tail -f /dev/null 32 | fi 33 | 34 | echo "" 35 | echo "QL-Relax Docker environment is ready!" 
36 | echo "Container name: ql-relax-container" 37 | echo "" 38 | echo "You can now run QL-Relax commands:" 39 | echo " python3 run_juliet.py --cwe 190" 40 | echo "" 41 | echo "Or enter the container:" 42 | echo " docker exec -it ql-relax-container bash" --------------------------------------------------------------------------------
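Example session (a minimal usage sketch assembled from the entry points above, not a verbatim part of the repo: the CWE number 190 mirrors the example echoed by start_docker.sh, the sketch assumes the container's working directory is the mounted /workspace, and the explicit --workspace/--output-dir values are illustrative overrides of the hard-coded /hdd2/QL-Relax defaults in draw/plot_initial_vs_final.py):

  # Build the image and start the long-running container
  ./start_docker.sh

  # Broaden the CodeQL queries for one CWE inside the container
  docker exec -it ql-relax-container python3 run_ql_workflow.py --cwe 190 --max-iterations 10

  # Re-score the SARIF files produced by the run, without re-executing any queries
  docker exec -it ql-relax-container python3 run_ql_workflow.py --cwe 190 --eval-only

  # Generate the comparison table and charts from the populated workspace
  python3 draw/plot_initial_vs_final.py --workspace ./qlworkspace --output-dir draw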