├── .dockerignore
├── .env.example
├── .gitattributes
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── crawled_pages.sql
├── knowledge_graphs
│   ├── ai_hallucination_detector.py
│   ├── ai_script_analyzer.py
│   ├── hallucination_reporter.py
│   ├── knowledge_graph_validator.py
│   ├── parse_repo_into_neo4j.py
│   ├── query_knowledge_graph.py
│   └── test_script.py
├── pyproject.toml
├── src
│   ├── crawl4ai_mcp.py
│   └── utils.py
└── uv.lock

/.dockerignore:
--------------------------------------------------------------------------------
1 | crawl4ai_mcp.egg-info
2 | __pycache__
3 | .venv
4 | .env
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 | # The transport for the MCP server - either 'sse' or 'stdio' (defaults to sse if left empty)
2 | TRANSPORT=
3 | 
4 | # Host to bind to if using sse as the transport (leave empty if using stdio)
5 | # Set this to 0.0.0.0 if using Docker, otherwise set to localhost (if using uv)
6 | HOST=
7 | 
8 | # Port to listen on if using sse as the transport (leave empty if using stdio)
9 | PORT=
10 | 
11 | # Get your OpenAI API Key by following these instructions -
12 | # https://help.openai.com/en/articles/4936850-where-do-i-find-my-openai-api-key
13 | # This is for the embedding model - text-embedding-3-small will be used
14 | OPENAI_API_KEY=
15 | 
16 | # The LLM you want to use for summaries and contextual embeddings
17 | # Generally this is a very cheap and fast LLM like gpt-4.1-nano
18 | MODEL_CHOICE=
19 | 
20 | # RAG strategies - set these to "true" or "false" (default to "false")
21 | # USE_CONTEXTUAL_EMBEDDINGS: Enhances embeddings with contextual information for better retrieval
22 | USE_CONTEXTUAL_EMBEDDINGS=false
23 | 
24 | # USE_HYBRID_SEARCH: Combines vector similarity search with keyword search for better results
25 | USE_HYBRID_SEARCH=false
26 | 
27 | # USE_AGENTIC_RAG: Enables code example extraction, storage, and specialized code search functionality
28 | USE_AGENTIC_RAG=false
29 | 
30 | # USE_RERANKING: Applies cross-encoder reranking to improve search result relevance
31 | USE_RERANKING=false
32 | 
33 | # USE_KNOWLEDGE_GRAPH: Enables AI hallucination detection and repository parsing tools using Neo4j
34 | # If you set this to true, you must also set the Neo4j environment variables below.
35 | USE_KNOWLEDGE_GRAPH=false
36 | 
37 | # For the Supabase vector database, set your Supabase URL and Service Key.
38 | # Get your SUPABASE_URL from the API section of your Supabase project settings -
39 | # https://supabase.com/dashboard/project/<your project ID>/settings/api
40 | SUPABASE_URL=
41 | 
42 | # Get your SUPABASE_SERVICE_KEY from the API section of your Supabase project settings -
43 | # https://supabase.com/dashboard/project/<your project ID>/settings/api
44 | # On this page it is called the service_role secret. 
45 | SUPABASE_SERVICE_KEY= 46 | 47 | # Neo4j Configuration for Knowledge Graph Tools 48 | # These are required for the AI hallucination detection and repository parsing tools 49 | # Leave empty to disable knowledge graph functionality 50 | 51 | # Neo4j connection URI - use bolt://localhost:7687 for local, neo4j:// for cloud instances 52 | # IMPORTANT: If running the MCP server through Docker, change localhost to host.docker.internal 53 | NEO4J_URI=bolt://localhost:7687 54 | 55 | # Neo4j username (usually 'neo4j' for default installations) 56 | NEO4J_USER=neo4j 57 | 58 | # Neo4j password for your database instance 59 | NEO4J_PASSWORD= -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | .venv 3 | __pycache__ 4 | crawl4ai_mcp.egg-info 5 | repos 6 | .claude 7 | test_script_hallucination* -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.12-slim 2 | 3 | ARG PORT=8051 4 | 5 | WORKDIR /app 6 | 7 | # Install uv 8 | RUN pip install uv 9 | 10 | # Copy the MCP server files 11 | COPY . . 12 | 13 | # Install packages directly to the system (no virtual environment) 14 | # Combining commands to reduce Docker layers 15 | RUN uv pip install --system -e . && \ 16 | crawl4ai-setup 17 | 18 | EXPOSE ${PORT} 19 | 20 | # Command to run the MCP server 21 | CMD ["python", "src/crawl4ai_mcp.py"] 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Cole Medin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | <h1 align="center">Crawl4AI RAG MCP Server</h1>
2 | 
3 | <p align="center">
4 | <em>Web Crawling and RAG Capabilities for AI Agents and AI Coding Assistants</em>
5 | </p>
6 | 
7 | A powerful implementation of the [Model Context Protocol (MCP)](https://modelcontextprotocol.io) integrated with [Crawl4AI](https://crawl4ai.com) and [Supabase](https://supabase.com/) for providing AI agents and AI coding assistants with advanced web crawling and RAG capabilities.
8 | 
9 | With this MCP server, you can <b>scrape anything</b> and then <b>use that knowledge anywhere</b> for RAG.
10 | 
11 | The primary goal is to bring this MCP server into [Archon](https://github.com/coleam00/Archon) as I evolve it to be more of a knowledge engine for AI coding assistants to build AI agents. This first version of the Crawl4AI/RAG MCP server will be improved upon greatly soon, especially making it more configurable so you can use different embedding models and run everything locally with Ollama.
12 | 
13 | Consider this GitHub repository a testbed, which is why I haven't been actively addressing issues and pull requests yet. I certainly will though as I bring this into Archon V2!
14 | 
15 | ## Overview
16 | 
17 | This MCP server provides tools that enable AI agents to crawl websites, store content in a vector database (Supabase), and perform RAG over the crawled content. It follows the best practices for building MCP servers based on the [Mem0 MCP server template](https://github.com/coleam00/mcp-mem0/) I provided on my channel previously.
18 | 
19 | The server includes several advanced RAG strategies that can be enabled to enhance retrieval quality:
20 | - **Contextual Embeddings** for enriched semantic understanding
21 | - **Hybrid Search** combining vector and keyword search
22 | - **Agentic RAG** for specialized code example extraction
23 | - **Reranking** for improved result relevance using cross-encoder models
24 | - **Knowledge Graph** for AI hallucination detection and repository code analysis
25 | 
26 | See the [Configuration section](#configuration) below for details on how to enable and configure these strategies.
27 | 
28 | ## Vision
29 | 
30 | The Crawl4AI RAG MCP server is just the beginning. Here's where we're headed:
31 | 
32 | 1. **Integration with Archon**: Building this system directly into [Archon](https://github.com/coleam00/Archon) to create a comprehensive knowledge engine for AI coding assistants to build better AI agents.
33 | 
34 | 2. **Multiple Embedding Models**: Expanding beyond OpenAI to support a variety of embedding models, including the ability to run everything locally with Ollama for complete control and privacy.
35 | 
36 | 3. **Advanced RAG Strategies**: Implementing sophisticated retrieval techniques like contextual retrieval, late chunking, and others to move beyond basic "naive lookups" and significantly enhance the power and precision of the RAG system, especially as it integrates with Archon.
37 | 
38 | 4. **Enhanced Chunking Strategy**: Implementing a Context 7-inspired chunking approach that focuses on examples and creates distinct, semantically meaningful sections for each chunk, improving retrieval precision.
39 | 
40 | 5. 
**Performance Optimization**: Increasing crawling and indexing speed to make it more realistic to "quickly" index new documentation to then leverage it within the same prompt in an AI coding assistant. 41 | 42 | ## Features 43 | 44 | - **Smart URL Detection**: Automatically detects and handles different URL types (regular webpages, sitemaps, text files) 45 | - **Recursive Crawling**: Follows internal links to discover content 46 | - **Parallel Processing**: Efficiently crawls multiple pages simultaneously 47 | - **Content Chunking**: Intelligently splits content by headers and size for better processing 48 | - **Vector Search**: Performs RAG over crawled content, optionally filtering by data source for precision 49 | - **Source Retrieval**: Retrieve sources available for filtering to guide the RAG process 50 | 51 | ## Tools 52 | 53 | The server provides essential web crawling and search tools: 54 | 55 | ### Core Tools (Always Available) 56 | 57 | 1. **`crawl_single_page`**: Quickly crawl a single web page and store its content in the vector database 58 | 2. **`smart_crawl_url`**: Intelligently crawl a full website based on the type of URL provided (sitemap, llms-full.txt, or a regular webpage that needs to be crawled recursively) 59 | 3. **`get_available_sources`**: Get a list of all available sources (domains) in the database 60 | 4. **`perform_rag_query`**: Search for relevant content using semantic search with optional source filtering 61 | 62 | ### Conditional Tools 63 | 64 | 5. **`search_code_examples`** (requires `USE_AGENTIC_RAG=true`): Search specifically for code examples and their summaries from crawled documentation. This tool provides targeted code snippet retrieval for AI coding assistants. 65 | 66 | ### Knowledge Graph Tools (requires `USE_KNOWLEDGE_GRAPH=true`, see below) 67 | 68 | 6. **`parse_github_repository`**: Parse a GitHub repository into a Neo4j knowledge graph, extracting classes, methods, functions, and their relationships for hallucination detection 69 | 7. **`check_ai_script_hallucinations`**: Analyze Python scripts for AI hallucinations by validating imports, method calls, and class usage against the knowledge graph 70 | 8. **`query_knowledge_graph`**: Explore and query the Neo4j knowledge graph with commands like `repos`, `classes`, `methods`, and custom Cypher queries 71 | 72 | ## Prerequisites 73 | 74 | - [Docker/Docker Desktop](https://www.docker.com/products/docker-desktop/) if running the MCP server as a container (recommended) 75 | - [Python 3.12+](https://www.python.org/downloads/) if running the MCP server directly through uv 76 | - [Supabase](https://supabase.com/) (database for RAG) 77 | - [OpenAI API key](https://platform.openai.com/api-keys) (for generating embeddings) 78 | - [Neo4j](https://neo4j.com/) (optional, for knowledge graph functionality) - see [Knowledge Graph Setup](#knowledge-graph-setup) section 79 | 80 | ## Installation 81 | 82 | ### Using Docker (Recommended) 83 | 84 | 1. Clone this repository: 85 | ```bash 86 | git clone https://github.com/coleam00/mcp-crawl4ai-rag.git 87 | cd mcp-crawl4ai-rag 88 | ``` 89 | 90 | 2. Build the Docker image: 91 | ```bash 92 | docker build -t mcp/crawl4ai-rag --build-arg PORT=8051 . 93 | ``` 94 | 95 | 3. Create a `.env` file based on the configuration section below 96 | 97 | ### Using uv directly (no Docker) 98 | 99 | 1. Clone this repository: 100 | ```bash 101 | git clone https://github.com/coleam00/mcp-crawl4ai-rag.git 102 | cd mcp-crawl4ai-rag 103 | ``` 104 | 105 | 2. 
Install uv if you don't have it: 106 | ```bash 107 | pip install uv 108 | ``` 109 | 110 | 3. Create and activate a virtual environment: 111 | ```bash 112 | uv venv 113 | .venv\Scripts\activate 114 | # on Mac/Linux: source .venv/bin/activate 115 | ``` 116 | 117 | 4. Install dependencies: 118 | ```bash 119 | uv pip install -e . 120 | crawl4ai-setup 121 | ``` 122 | 123 | 5. Create a `.env` file based on the configuration section below 124 | 125 | ## Database Setup 126 | 127 | Before running the server, you need to set up the database with the pgvector extension: 128 | 129 | 1. Go to the SQL Editor in your Supabase dashboard (create a new project first if necessary) 130 | 131 | 2. Create a new query and paste the contents of `crawled_pages.sql` 132 | 133 | 3. Run the query to create the necessary tables and functions 134 | 135 | ## Knowledge Graph Setup (Optional) 136 | 137 | To enable AI hallucination detection and repository analysis features, you need to set up Neo4j. 138 | 139 | Also, the knowledge graph implementation isn't fully compatible with Docker yet, so I would recommend right now running directly through uv if you want to use the hallucination detection within the MCP server! 140 | 141 | For installing Neo4j: 142 | 143 | ### Local AI Package (Recommended) 144 | 145 | The easiest way to get Neo4j running locally is with the [Local AI Package](https://github.com/coleam00/local-ai-packaged) - a curated collection of local AI services including Neo4j: 146 | 147 | 1. **Clone the Local AI Package**: 148 | ```bash 149 | git clone https://github.com/coleam00/local-ai-packaged.git 150 | cd local-ai-packaged 151 | ``` 152 | 153 | 2. **Start Neo4j**: 154 | Follow the instructions in the Local AI Package repository to start Neo4j with Docker Compose 155 | 156 | 3. **Default connection details**: 157 | - URI: `bolt://localhost:7687` 158 | - Username: `neo4j` 159 | - Password: Check the Local AI Package documentation for the default password 160 | 161 | ### Manual Neo4j Installation 162 | 163 | Alternatively, install Neo4j directly: 164 | 165 | 1. **Install Neo4j Desktop**: Download from [neo4j.com/download](https://neo4j.com/download/) 166 | 167 | 2. **Create a new database**: 168 | - Open Neo4j Desktop 169 | - Create a new project and database 170 | - Set a password for the `neo4j` user 171 | - Start the database 172 | 173 | 3. 
**Note your connection details**: 174 | - URI: `bolt://localhost:7687` (default) 175 | - Username: `neo4j` (default) 176 | - Password: Whatever you set during creation 177 | 178 | ## Configuration 179 | 180 | Create a `.env` file in the project root with the following variables: 181 | 182 | ``` 183 | # MCP Server Configuration 184 | HOST=0.0.0.0 185 | PORT=8051 186 | TRANSPORT=sse 187 | 188 | # OpenAI API Configuration 189 | OPENAI_API_KEY=your_openai_api_key 190 | 191 | # LLM for summaries and contextual embeddings 192 | MODEL_CHOICE=gpt-4.1-nano 193 | 194 | # RAG Strategies (set to "true" or "false", default to "false") 195 | USE_CONTEXTUAL_EMBEDDINGS=false 196 | USE_HYBRID_SEARCH=false 197 | USE_AGENTIC_RAG=false 198 | USE_RERANKING=false 199 | USE_KNOWLEDGE_GRAPH=false 200 | 201 | # Supabase Configuration 202 | SUPABASE_URL=your_supabase_project_url 203 | SUPABASE_SERVICE_KEY=your_supabase_service_key 204 | 205 | # Neo4j Configuration (required for knowledge graph functionality) 206 | NEO4J_URI=bolt://localhost:7687 207 | NEO4J_USER=neo4j 208 | NEO4J_PASSWORD=your_neo4j_password 209 | ``` 210 | 211 | ### RAG Strategy Options 212 | 213 | The Crawl4AI RAG MCP server supports four powerful RAG strategies that can be enabled independently: 214 | 215 | #### 1. **USE_CONTEXTUAL_EMBEDDINGS** 216 | When enabled, this strategy enhances each chunk's embedding with additional context from the entire document. The system passes both the full document and the specific chunk to an LLM (configured via `MODEL_CHOICE`) to generate enriched context that gets embedded alongside the chunk content. 217 | 218 | - **When to use**: Enable this when you need high-precision retrieval where context matters, such as technical documentation where terms might have different meanings in different sections. 219 | - **Trade-offs**: Slower indexing due to LLM calls for each chunk, but significantly better retrieval accuracy. 220 | - **Cost**: Additional LLM API calls during indexing. 221 | 222 | #### 2. **USE_HYBRID_SEARCH** 223 | Combines traditional keyword search with semantic vector search to provide more comprehensive results. The system performs both searches in parallel and intelligently merges results, prioritizing documents that appear in both result sets. 224 | 225 | - **When to use**: Enable this when users might search using specific technical terms, function names, or when exact keyword matches are important alongside semantic understanding. 226 | - **Trade-offs**: Slightly slower search queries but more robust results, especially for technical content. 227 | - **Cost**: No additional API costs, just computational overhead. 228 | 229 | #### 3. **USE_AGENTIC_RAG** 230 | Enables specialized code example extraction and storage. When crawling documentation, the system identifies code blocks (≥300 characters), extracts them with surrounding context, generates summaries, and stores them in a separate vector database table specifically designed for code search. 231 | 232 | - **When to use**: Essential for AI coding assistants that need to find specific code examples, implementation patterns, or usage examples from documentation. 233 | - **Trade-offs**: Significantly slower crawling due to code extraction and summarization, requires more storage space. 234 | - **Cost**: Additional LLM API calls for summarizing each code example. 235 | - **Benefits**: Provides a dedicated `search_code_examples` tool that AI agents can use to find specific code implementations. 236 | 237 | #### 4. 
**USE_RERANKING** 238 | Applies cross-encoder reranking to search results after initial retrieval. Uses a lightweight cross-encoder model (`cross-encoder/ms-marco-MiniLM-L-6-v2`) to score each result against the original query, then reorders results by relevance. 239 | 240 | - **When to use**: Enable this when search precision is critical and you need the most relevant results at the top. Particularly useful for complex queries where semantic similarity alone might not capture query intent. 241 | - **Trade-offs**: Adds ~100-200ms to search queries depending on result count, but significantly improves result ordering. 242 | - **Cost**: No additional API costs - uses a local model that runs on CPU. 243 | - **Benefits**: Better result relevance, especially for complex queries. Works with both regular RAG search and code example search. 244 | 245 | #### 5. **USE_KNOWLEDGE_GRAPH** 246 | Enables AI hallucination detection and repository analysis using Neo4j knowledge graphs. When enabled, the system can parse GitHub repositories into a graph database and validate AI-generated code against real repository structures. (NOT fully compatible with Docker yet, I'd recommend running through uv) 247 | 248 | - **When to use**: Enable this for AI coding assistants that need to validate generated code against real implementations, or when you want to detect when AI models hallucinate non-existent methods, classes, or incorrect usage patterns. 249 | - **Trade-offs**: Requires Neo4j setup and additional dependencies. Repository parsing can be slow for large codebases, and validation requires repositories to be pre-indexed. 250 | - **Cost**: No additional API costs for validation, but requires Neo4j infrastructure (can use free local installation or cloud AuraDB). 251 | - **Benefits**: Provides three powerful tools: `parse_github_repository` for indexing codebases, `check_ai_script_hallucinations` for validating AI-generated code, and `query_knowledge_graph` for exploring indexed repositories. 252 | 253 | You can now tell the AI coding assistant to add a Python GitHub repository to the knowledge graph like: 254 | 255 | "Add https://github.com/pydantic/pydantic-ai.git to the knowledge graph" 256 | 257 | Make sure the repo URL ends with .git. 
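
If you want to peek at what the parser stored, here is a minimal, illustrative sketch using the Neo4j Python driver and the node/relationship types described in the [Knowledge Graph Architecture](#knowledge-graph-architecture) section below. It assumes the connection details from your `.env` and that repository nodes carry a `name` property (property names are assumptions, not a documented API) - the `query_knowledge_graph` tool gives you the same kind of exploration without writing any code:

```python
from neo4j import GraphDatabase

# Illustrative sketch only - adjust the URI/credentials to match your .env values.
driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "your_neo4j_password"))

with driver.session() as session:
    # List the classes indexed for a parsed repository (repo name/property are assumptions).
    result = session.run(
        """
        MATCH (r:Repository {name: $repo})-[:CONTAINS]->(:File)-[:DEFINES]->(c:Class)
        RETURN c.name AS class_name
        ORDER BY class_name
        """,
        repo="pydantic-ai",
    )
    for record in result:
        print(record["class_name"])

driver.close()
```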
258 | 259 | You can also have the AI coding assistant check for hallucinations with scripts it just created, or you can manually run the command: 260 | 261 | ``` 262 | python knowledge_graphs/ai_hallucination_detector.py [full path to your script to analyze] 263 | ``` 264 | 265 | ### Recommended Configurations 266 | 267 | **For general documentation RAG:** 268 | ``` 269 | USE_CONTEXTUAL_EMBEDDINGS=false 270 | USE_HYBRID_SEARCH=true 271 | USE_AGENTIC_RAG=false 272 | USE_RERANKING=true 273 | ``` 274 | 275 | **For AI coding assistant with code examples:** 276 | ``` 277 | USE_CONTEXTUAL_EMBEDDINGS=true 278 | USE_HYBRID_SEARCH=true 279 | USE_AGENTIC_RAG=true 280 | USE_RERANKING=true 281 | USE_KNOWLEDGE_GRAPH=false 282 | ``` 283 | 284 | **For AI coding assistant with hallucination detection:** 285 | ``` 286 | USE_CONTEXTUAL_EMBEDDINGS=true 287 | USE_HYBRID_SEARCH=true 288 | USE_AGENTIC_RAG=true 289 | USE_RERANKING=true 290 | USE_KNOWLEDGE_GRAPH=true 291 | ``` 292 | 293 | **For fast, basic RAG:** 294 | ``` 295 | USE_CONTEXTUAL_EMBEDDINGS=false 296 | USE_HYBRID_SEARCH=true 297 | USE_AGENTIC_RAG=false 298 | USE_RERANKING=false 299 | USE_KNOWLEDGE_GRAPH=false 300 | ``` 301 | 302 | ## Running the Server 303 | 304 | ### Using Docker 305 | 306 | ```bash 307 | docker run --env-file .env -p 8051:8051 mcp/crawl4ai-rag 308 | ``` 309 | 310 | ### Using Python 311 | 312 | ```bash 313 | uv run src/crawl4ai_mcp.py 314 | ``` 315 | 316 | The server will start and listen on the configured host and port. 317 | 318 | ## Integration with MCP Clients 319 | 320 | ### SSE Configuration 321 | 322 | Once you have the server running with SSE transport, you can connect to it using this configuration: 323 | 324 | ```json 325 | { 326 | "mcpServers": { 327 | "crawl4ai-rag": { 328 | "transport": "sse", 329 | "url": "http://localhost:8051/sse" 330 | } 331 | } 332 | } 333 | ``` 334 | 335 | > **Note for Windsurf users**: Use `serverUrl` instead of `url` in your configuration: 336 | > ```json 337 | > { 338 | > "mcpServers": { 339 | > "crawl4ai-rag": { 340 | > "transport": "sse", 341 | > "serverUrl": "http://localhost:8051/sse" 342 | > } 343 | > } 344 | > } 345 | > ``` 346 | > 347 | > **Note for Docker users**: Use `host.docker.internal` instead of `localhost` if your client is running in a different container. This will apply if you are using this MCP server within n8n! 
348 | 349 | > **Note for Claude Code users**: 350 | ``` 351 | claude mcp add-json crawl4ai-rag '{"type":"http","url":"http://localhost:8051/sse"}' --scope user 352 | ``` 353 | 354 | ### Stdio Configuration 355 | 356 | Add this server to your MCP configuration for Claude Desktop, Windsurf, or any other MCP client: 357 | 358 | ```json 359 | { 360 | "mcpServers": { 361 | "crawl4ai-rag": { 362 | "command": "python", 363 | "args": ["path/to/crawl4ai-mcp/src/crawl4ai_mcp.py"], 364 | "env": { 365 | "TRANSPORT": "stdio", 366 | "OPENAI_API_KEY": "your_openai_api_key", 367 | "SUPABASE_URL": "your_supabase_url", 368 | "SUPABASE_SERVICE_KEY": "your_supabase_service_key", 369 | "USE_KNOWLEDGE_GRAPH": "false", 370 | "NEO4J_URI": "bolt://localhost:7687", 371 | "NEO4J_USER": "neo4j", 372 | "NEO4J_PASSWORD": "your_neo4j_password" 373 | } 374 | } 375 | } 376 | } 377 | ``` 378 | 379 | ### Docker with Stdio Configuration 380 | 381 | ```json 382 | { 383 | "mcpServers": { 384 | "crawl4ai-rag": { 385 | "command": "docker", 386 | "args": ["run", "--rm", "-i", 387 | "-e", "TRANSPORT", 388 | "-e", "OPENAI_API_KEY", 389 | "-e", "SUPABASE_URL", 390 | "-e", "SUPABASE_SERVICE_KEY", 391 | "-e", "USE_KNOWLEDGE_GRAPH", 392 | "-e", "NEO4J_URI", 393 | "-e", "NEO4J_USER", 394 | "-e", "NEO4J_PASSWORD", 395 | "mcp/crawl4ai"], 396 | "env": { 397 | "TRANSPORT": "stdio", 398 | "OPENAI_API_KEY": "your_openai_api_key", 399 | "SUPABASE_URL": "your_supabase_url", 400 | "SUPABASE_SERVICE_KEY": "your_supabase_service_key", 401 | "USE_KNOWLEDGE_GRAPH": "false", 402 | "NEO4J_URI": "bolt://localhost:7687", 403 | "NEO4J_USER": "neo4j", 404 | "NEO4J_PASSWORD": "your_neo4j_password" 405 | } 406 | } 407 | } 408 | } 409 | ``` 410 | 411 | ## Knowledge Graph Architecture 412 | 413 | The knowledge graph system stores repository code structure in Neo4j with the following components: 414 | 415 | ### Core Components (`knowledge_graphs/` folder): 416 | 417 | - **`parse_repo_into_neo4j.py`**: Clones and analyzes GitHub repositories, extracting Python classes, methods, functions, and imports into Neo4j nodes and relationships 418 | - **`ai_script_analyzer.py`**: Parses Python scripts using AST to extract imports, class instantiations, method calls, and function usage 419 | - **`knowledge_graph_validator.py`**: Validates AI-generated code against the knowledge graph to detect hallucinations (non-existent methods, incorrect parameters, etc.) 420 | - **`hallucination_reporter.py`**: Generates comprehensive reports about detected hallucinations with confidence scores and recommendations 421 | - **`query_knowledge_graph.py`**: Interactive CLI tool for exploring the knowledge graph (functionality now integrated into MCP tools) 422 | 423 | ### Knowledge Graph Schema: 424 | 425 | The Neo4j database stores code structure as: 426 | 427 | **Nodes:** 428 | - `Repository`: GitHub repositories 429 | - `File`: Python files within repositories 430 | - `Class`: Python classes with methods and attributes 431 | - `Method`: Class methods with parameter information 432 | - `Function`: Standalone functions 433 | - `Attribute`: Class attributes 434 | 435 | **Relationships:** 436 | - `Repository` -[:CONTAINS]-> `File` 437 | - `File` -[:DEFINES]-> `Class` 438 | - `File` -[:DEFINES]-> `Function` 439 | - `Class` -[:HAS_METHOD]-> `Method` 440 | - `Class` -[:HAS_ATTRIBUTE]-> `Attribute` 441 | 442 | ### Workflow: 443 | 444 | 1. **Repository Parsing**: Use `parse_github_repository` tool to clone and analyze open-source repositories 445 | 2. 
**Code Validation**: Use `check_ai_script_hallucinations` tool to validate AI-generated Python scripts
446 | 3. **Knowledge Exploration**: Use `query_knowledge_graph` tool to explore available repositories, classes, and methods
447 | 
448 | ## Building Your Own Server
449 | 
450 | This implementation provides a foundation for building more complex MCP servers with web crawling capabilities. To build your own:
451 | 
452 | 1. Add your own tools by creating methods with the `@mcp.tool()` decorator
453 | 2. Create your own lifespan function to add your own dependencies
454 | 3. Modify the `utils.py` file for any helper functions you need
455 | 4. Extend the crawling capabilities by adding more specialized crawlers
456 | 
--------------------------------------------------------------------------------
/crawled_pages.sql:
--------------------------------------------------------------------------------
1 | -- Enable the pgvector extension
2 | create extension if not exists vector;
3 | 
4 | -- Drop tables if they exist (to allow rerunning the script)
5 | drop table if exists crawled_pages;
6 | drop table if exists code_examples;
7 | drop table if exists sources;
8 | 
9 | -- Create the sources table
10 | create table sources (
11 |     source_id text primary key,
12 |     summary text,
13 |     total_word_count integer default 0,
14 |     created_at timestamp with time zone default timezone('utc'::text, now()) not null,
15 |     updated_at timestamp with time zone default timezone('utc'::text, now()) not null
16 | );
17 | 
18 | -- Create the documentation chunks table
19 | create table crawled_pages (
20 |     id bigserial primary key,
21 |     url varchar not null,
22 |     chunk_number integer not null,
23 |     content text not null,
24 |     metadata jsonb not null default '{}'::jsonb,
25 |     source_id text not null,
26 |     embedding vector(1536), -- OpenAI embeddings are 1536 dimensions
27 |     created_at timestamp with time zone default timezone('utc'::text, now()) not null,
28 | 
29 |     -- Add a unique constraint to prevent duplicate chunks for the same URL
30 |     unique(url, chunk_number),
31 | 
32 |     -- Add foreign key constraint to sources table
33 |     foreign key (source_id) references sources(source_id)
34 | );
35 | 
36 | -- Create an index for better vector similarity search performance
37 | create index on crawled_pages using ivfflat (embedding vector_cosine_ops);
38 | 
39 | -- Create an index on metadata for faster filtering
40 | create index idx_crawled_pages_metadata on crawled_pages using gin (metadata);
41 | 
42 | -- Create an index on source_id for faster filtering
43 | CREATE INDEX idx_crawled_pages_source_id ON crawled_pages (source_id);
44 | 
45 | -- Create a function to search for documentation chunks
46 | create or replace function match_crawled_pages (
47 |   query_embedding vector(1536),
48 |   match_count int default 10,
49 |   filter jsonb DEFAULT '{}'::jsonb,
50 |   source_filter text DEFAULT NULL
51 | ) returns table (
52 |   id bigint,
53 |   url varchar,
54 |   chunk_number integer,
55 |   content text,
56 |   metadata jsonb,
57 |   source_id text,
58 |   similarity float
59 | )
60 | language plpgsql
61 | as $$
62 | #variable_conflict use_column
63 | begin
64 |   return query
65 |   select
66 |     id,
67 |     url,
68 |     chunk_number,
69 |     content,
70 |     metadata,
71 |     source_id,
72 |     1 - (crawled_pages.embedding <=> query_embedding) as similarity
73 |   from crawled_pages
74 |   where metadata @> filter
75 |     AND (source_filter IS NULL OR source_id = source_filter)
76 |   order by crawled_pages.embedding <=> query_embedding
77 |   limit match_count;
78 | end;
79 | $$;
80 | 
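-- Example call (illustrative only, not part of the schema setup): the MCP server invokes
-- this function through Supabase's RPC interface from Python, but it can also be run
-- directly in the SQL editor. The embedding literal below is a stand-in - a real call
-- passes a full 1536-dimension query embedding - and 'docs.example.com' is a hypothetical
-- source_id.
--
-- select id, url, chunk_number, similarity
-- from match_crawled_pages(
--   '[0.01, -0.02, ...]'::vector(1536),  -- query embedding (truncated placeholder)
--   5,                                   -- match_count
--   '{}'::jsonb,                         -- metadata filter
--   'docs.example.com'                   -- source_filter (NULL to search all sources)
-- );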
81 | -- Enable RLS on the crawled_pages table
82 | alter table crawled_pages enable row level security;
83 | 
84 | -- Create a policy that allows anyone to read crawled_pages
85 | create policy "Allow public read access to crawled_pages"
86 |   on crawled_pages
87 |   for select
88 |   to public
89 |   using (true);
90 | 
91 | -- Enable RLS on the sources table
92 | alter table sources enable row level security;
93 | 
94 | -- Create a policy that allows anyone to read sources
95 | create policy "Allow public read access to sources"
96 |   on sources
97 |   for select
98 |   to public
99 |   using (true);
100 | 
101 | -- Create the code_examples table
102 | create table code_examples (
103 |     id bigserial primary key,
104 |     url varchar not null,
105 |     chunk_number integer not null,
106 |     content text not null, -- The code example content
107 |     summary text not null, -- Summary of the code example
108 |     metadata jsonb not null default '{}'::jsonb,
109 |     source_id text not null,
110 |     embedding vector(1536), -- OpenAI embeddings are 1536 dimensions
111 |     created_at timestamp with time zone default timezone('utc'::text, now()) not null,
112 | 
113 |     -- Add a unique constraint to prevent duplicate chunks for the same URL
114 |     unique(url, chunk_number),
115 | 
116 |     -- Add foreign key constraint to sources table
117 |     foreign key (source_id) references sources(source_id)
118 | );
119 | 
120 | -- Create an index for better vector similarity search performance
121 | create index on code_examples using ivfflat (embedding vector_cosine_ops);
122 | 
123 | -- Create an index on metadata for faster filtering
124 | create index idx_code_examples_metadata on code_examples using gin (metadata);
125 | 
126 | -- Create an index on source_id for faster filtering
127 | CREATE INDEX idx_code_examples_source_id ON code_examples (source_id);
128 | 
129 | -- Create a function to search for code examples
130 | create or replace function match_code_examples (
131 |   query_embedding vector(1536),
132 |   match_count int default 10,
133 |   filter jsonb DEFAULT '{}'::jsonb,
134 |   source_filter text DEFAULT NULL
135 | ) returns table (
136 |   id bigint,
137 |   url varchar,
138 |   chunk_number integer,
139 |   content text,
140 |   summary text,
141 |   metadata jsonb,
142 |   source_id text,
143 |   similarity float
144 | )
145 | language plpgsql
146 | as $$
147 | #variable_conflict use_column
148 | begin
149 |   return query
150 |   select
151 |     id,
152 |     url,
153 |     chunk_number,
154 |     content,
155 |     summary,
156 |     metadata,
157 |     source_id,
158 |     1 - (code_examples.embedding <=> query_embedding) as similarity
159 |   from code_examples
160 |   where metadata @> filter
161 |     AND (source_filter IS NULL OR source_id = source_filter)
162 |   order by code_examples.embedding <=> query_embedding
163 |   limit match_count;
164 | end;
165 | $$;
166 | 
167 | -- Enable RLS on the code_examples table
168 | alter table code_examples enable row level security;
169 | 
170 | -- Create a policy that allows anyone to read code_examples
171 | create policy "Allow public read access to code_examples"
172 |   on code_examples
173 |   for select
174 |   to public
175 |   using (true);
--------------------------------------------------------------------------------
/knowledge_graphs/ai_hallucination_detector.py:
--------------------------------------------------------------------------------
1 | """
2 | AI Hallucination Detector
3 | 
4 | Main orchestrator for detecting AI coding assistant hallucinations in Python scripts.
5 | Combines AST analysis, knowledge graph validation, and comprehensive reporting. 
6 | """ 7 | 8 | import asyncio 9 | import argparse 10 | import logging 11 | import os 12 | import sys 13 | from pathlib import Path 14 | from typing import Optional, List 15 | 16 | from dotenv import load_dotenv 17 | 18 | from ai_script_analyzer import AIScriptAnalyzer, analyze_ai_script 19 | from knowledge_graph_validator import KnowledgeGraphValidator 20 | from hallucination_reporter import HallucinationReporter 21 | 22 | # Configure logging 23 | logging.basicConfig( 24 | level=logging.INFO, 25 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 26 | datefmt='%Y-%m-%d %H:%M:%S' 27 | ) 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class AIHallucinationDetector: 32 | """Main detector class that orchestrates the entire process""" 33 | 34 | def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str): 35 | self.validator = KnowledgeGraphValidator(neo4j_uri, neo4j_user, neo4j_password) 36 | self.reporter = HallucinationReporter() 37 | self.analyzer = AIScriptAnalyzer() 38 | 39 | async def initialize(self): 40 | """Initialize connections and components""" 41 | await self.validator.initialize() 42 | logger.info("AI Hallucination Detector initialized successfully") 43 | 44 | async def close(self): 45 | """Close connections""" 46 | await self.validator.close() 47 | 48 | async def detect_hallucinations(self, script_path: str, 49 | output_dir: Optional[str] = None, 50 | save_json: bool = True, 51 | save_markdown: bool = True, 52 | print_summary: bool = True) -> dict: 53 | """ 54 | Main detection function that analyzes a script and generates reports 55 | 56 | Args: 57 | script_path: Path to the AI-generated Python script 58 | output_dir: Directory to save reports (defaults to script directory) 59 | save_json: Whether to save JSON report 60 | save_markdown: Whether to save Markdown report 61 | print_summary: Whether to print summary to console 62 | 63 | Returns: 64 | Complete validation report as dictionary 65 | """ 66 | logger.info(f"Starting hallucination detection for: {script_path}") 67 | 68 | # Validate input 69 | if not os.path.exists(script_path): 70 | raise FileNotFoundError(f"Script not found: {script_path}") 71 | 72 | if not script_path.endswith('.py'): 73 | raise ValueError("Only Python (.py) files are supported") 74 | 75 | # Set output directory 76 | if output_dir is None: 77 | output_dir = str(Path(script_path).parent) 78 | 79 | os.makedirs(output_dir, exist_ok=True) 80 | 81 | try: 82 | # Step 1: Analyze the script using AST 83 | logger.info("Step 1: Analyzing script structure...") 84 | analysis_result = self.analyzer.analyze_script(script_path) 85 | 86 | if analysis_result.errors: 87 | logger.warning(f"Analysis warnings: {analysis_result.errors}") 88 | 89 | logger.info(f"Found: {len(analysis_result.imports)} imports, " 90 | f"{len(analysis_result.class_instantiations)} class instantiations, " 91 | f"{len(analysis_result.method_calls)} method calls, " 92 | f"{len(analysis_result.function_calls)} function calls, " 93 | f"{len(analysis_result.attribute_accesses)} attribute accesses") 94 | 95 | # Step 2: Validate against knowledge graph 96 | logger.info("Step 2: Validating against knowledge graph...") 97 | validation_result = await self.validator.validate_script(analysis_result) 98 | 99 | logger.info(f"Validation complete. 
Overall confidence: {validation_result.overall_confidence:.1%}") 100 | 101 | # Step 3: Generate comprehensive report 102 | logger.info("Step 3: Generating reports...") 103 | report = self.reporter.generate_comprehensive_report(validation_result) 104 | 105 | # Step 4: Save reports 106 | script_name = Path(script_path).stem 107 | 108 | if save_json: 109 | json_path = os.path.join(output_dir, f"{script_name}_hallucination_report.json") 110 | self.reporter.save_json_report(report, json_path) 111 | 112 | if save_markdown: 113 | md_path = os.path.join(output_dir, f"{script_name}_hallucination_report.md") 114 | self.reporter.save_markdown_report(report, md_path) 115 | 116 | # Step 5: Print summary 117 | if print_summary: 118 | self.reporter.print_summary(report) 119 | 120 | logger.info("Hallucination detection completed successfully") 121 | return report 122 | 123 | except Exception as e: 124 | logger.error(f"Error during hallucination detection: {str(e)}") 125 | raise 126 | 127 | async def batch_detect(self, script_paths: List[str], 128 | output_dir: Optional[str] = None) -> List[dict]: 129 | """ 130 | Detect hallucinations in multiple scripts 131 | 132 | Args: 133 | script_paths: List of paths to Python scripts 134 | output_dir: Directory to save all reports 135 | 136 | Returns: 137 | List of validation reports 138 | """ 139 | logger.info(f"Starting batch detection for {len(script_paths)} scripts") 140 | 141 | results = [] 142 | for i, script_path in enumerate(script_paths, 1): 143 | logger.info(f"Processing script {i}/{len(script_paths)}: {script_path}") 144 | 145 | try: 146 | result = await self.detect_hallucinations( 147 | script_path=script_path, 148 | output_dir=output_dir, 149 | print_summary=False # Don't print individual summaries in batch mode 150 | ) 151 | results.append(result) 152 | 153 | except Exception as e: 154 | logger.error(f"Failed to process {script_path}: {str(e)}") 155 | # Continue with other scripts 156 | continue 157 | 158 | # Print batch summary 159 | self._print_batch_summary(results) 160 | 161 | return results 162 | 163 | def _print_batch_summary(self, results: List[dict]): 164 | """Print summary of batch processing results""" 165 | if not results: 166 | print("No scripts were successfully processed.") 167 | return 168 | 169 | print("\n" + "="*80) 170 | print("🚀 BATCH HALLUCINATION DETECTION SUMMARY") 171 | print("="*80) 172 | 173 | total_scripts = len(results) 174 | total_validations = sum(r['validation_summary']['total_validations'] for r in results) 175 | total_valid = sum(r['validation_summary']['valid_count'] for r in results) 176 | total_invalid = sum(r['validation_summary']['invalid_count'] for r in results) 177 | total_not_found = sum(r['validation_summary']['not_found_count'] for r in results) 178 | total_hallucinations = sum(len(r['hallucinations_detected']) for r in results) 179 | 180 | avg_confidence = sum(r['validation_summary']['overall_confidence'] for r in results) / total_scripts 181 | 182 | print(f"Scripts Processed: {total_scripts}") 183 | print(f"Total Validations: {total_validations}") 184 | print(f"Average Confidence: {avg_confidence:.1%}") 185 | print(f"Total Hallucinations: {total_hallucinations}") 186 | 187 | print(f"\nAggregated Results:") 188 | print(f" ✅ Valid: {total_valid} ({total_valid/total_validations:.1%})") 189 | print(f" ❌ Invalid: {total_invalid} ({total_invalid/total_validations:.1%})") 190 | print(f" 🔍 Not Found: {total_not_found} ({total_not_found/total_validations:.1%})") 191 | 192 | # Show worst performing scripts 193 | 
print(f"\n🚨 Scripts with Most Hallucinations:") 194 | sorted_results = sorted(results, key=lambda x: len(x['hallucinations_detected']), reverse=True) 195 | for result in sorted_results[:5]: 196 | script_name = Path(result['analysis_metadata']['script_path']).name 197 | hall_count = len(result['hallucinations_detected']) 198 | confidence = result['validation_summary']['overall_confidence'] 199 | print(f" - {script_name}: {hall_count} hallucinations ({confidence:.1%} confidence)") 200 | 201 | print("="*80) 202 | 203 | 204 | async def main(): 205 | """Command-line interface for the AI Hallucination Detector""" 206 | parser = argparse.ArgumentParser( 207 | description="Detect AI coding assistant hallucinations in Python scripts", 208 | formatter_class=argparse.RawDescriptionHelpFormatter, 209 | epilog=""" 210 | Examples: 211 | # Analyze single script 212 | python ai_hallucination_detector.py script.py 213 | 214 | # Analyze multiple scripts 215 | python ai_hallucination_detector.py script1.py script2.py script3.py 216 | 217 | # Specify output directory 218 | python ai_hallucination_detector.py script.py --output-dir reports/ 219 | 220 | # Skip markdown report 221 | python ai_hallucination_detector.py script.py --no-markdown 222 | """ 223 | ) 224 | 225 | parser.add_argument( 226 | 'scripts', 227 | nargs='+', 228 | help='Python script(s) to analyze for hallucinations' 229 | ) 230 | 231 | parser.add_argument( 232 | '--output-dir', 233 | help='Directory to save reports (defaults to script directory)' 234 | ) 235 | 236 | parser.add_argument( 237 | '--no-json', 238 | action='store_true', 239 | help='Skip JSON report generation' 240 | ) 241 | 242 | parser.add_argument( 243 | '--no-markdown', 244 | action='store_true', 245 | help='Skip Markdown report generation' 246 | ) 247 | 248 | parser.add_argument( 249 | '--no-summary', 250 | action='store_true', 251 | help='Skip printing summary to console' 252 | ) 253 | 254 | parser.add_argument( 255 | '--neo4j-uri', 256 | default=None, 257 | help='Neo4j URI (default: from environment NEO4J_URI)' 258 | ) 259 | 260 | parser.add_argument( 261 | '--neo4j-user', 262 | default=None, 263 | help='Neo4j username (default: from environment NEO4J_USER)' 264 | ) 265 | 266 | parser.add_argument( 267 | '--neo4j-password', 268 | default=None, 269 | help='Neo4j password (default: from environment NEO4J_PASSWORD)' 270 | ) 271 | 272 | parser.add_argument( 273 | '--verbose', 274 | action='store_true', 275 | help='Enable verbose logging' 276 | ) 277 | 278 | args = parser.parse_args() 279 | 280 | if args.verbose: 281 | logging.getLogger().setLevel(logging.INFO) 282 | # Only enable debug for our modules, not neo4j 283 | logging.getLogger('neo4j').setLevel(logging.WARNING) 284 | logging.getLogger('neo4j.pool').setLevel(logging.WARNING) 285 | logging.getLogger('neo4j.io').setLevel(logging.WARNING) 286 | 287 | # Load environment variables 288 | load_dotenv() 289 | 290 | # Get Neo4j credentials 291 | neo4j_uri = args.neo4j_uri or os.environ.get('NEO4J_URI', 'bolt://localhost:7687') 292 | neo4j_user = args.neo4j_user or os.environ.get('NEO4J_USER', 'neo4j') 293 | neo4j_password = args.neo4j_password or os.environ.get('NEO4J_PASSWORD', 'password') 294 | 295 | if not neo4j_password or neo4j_password == 'password': 296 | logger.error("Please set NEO4J_PASSWORD environment variable or use --neo4j-password") 297 | sys.exit(1) 298 | 299 | # Initialize detector 300 | detector = AIHallucinationDetector(neo4j_uri, neo4j_user, neo4j_password) 301 | 302 | try: 303 | await detector.initialize() 304 
| 305 | # Process scripts 306 | if len(args.scripts) == 1: 307 | # Single script mode 308 | await detector.detect_hallucinations( 309 | script_path=args.scripts[0], 310 | output_dir=args.output_dir, 311 | save_json=not args.no_json, 312 | save_markdown=not args.no_markdown, 313 | print_summary=not args.no_summary 314 | ) 315 | else: 316 | # Batch mode 317 | await detector.batch_detect( 318 | script_paths=args.scripts, 319 | output_dir=args.output_dir 320 | ) 321 | 322 | except KeyboardInterrupt: 323 | logger.info("Detection interrupted by user") 324 | sys.exit(1) 325 | 326 | except Exception as e: 327 | logger.error(f"Detection failed: {str(e)}") 328 | sys.exit(1) 329 | 330 | finally: 331 | await detector.close() 332 | 333 | 334 | if __name__ == "__main__": 335 | asyncio.run(main()) -------------------------------------------------------------------------------- /knowledge_graphs/ai_script_analyzer.py: -------------------------------------------------------------------------------- 1 | """ 2 | AI Script Analyzer 3 | 4 | Parses Python scripts generated by AI coding assistants using AST to extract: 5 | - Import statements and their usage 6 | - Class instantiations and method calls 7 | - Function calls with parameters 8 | - Attribute access patterns 9 | - Variable type tracking 10 | """ 11 | 12 | import ast 13 | import logging 14 | from pathlib import Path 15 | from typing import Dict, List, Set, Any, Optional, Tuple 16 | from dataclasses import dataclass, field 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | @dataclass 22 | class ImportInfo: 23 | """Information about an import statement""" 24 | module: str 25 | name: str 26 | alias: Optional[str] = None 27 | is_from_import: bool = False 28 | line_number: int = 0 29 | 30 | 31 | @dataclass 32 | class MethodCall: 33 | """Information about a method call""" 34 | object_name: str 35 | method_name: str 36 | args: List[str] 37 | kwargs: Dict[str, str] 38 | line_number: int 39 | object_type: Optional[str] = None # Inferred class type 40 | 41 | 42 | @dataclass 43 | class AttributeAccess: 44 | """Information about attribute access""" 45 | object_name: str 46 | attribute_name: str 47 | line_number: int 48 | object_type: Optional[str] = None # Inferred class type 49 | 50 | 51 | @dataclass 52 | class FunctionCall: 53 | """Information about a function call""" 54 | function_name: str 55 | args: List[str] 56 | kwargs: Dict[str, str] 57 | line_number: int 58 | full_name: Optional[str] = None # Module.function_name 59 | 60 | 61 | @dataclass 62 | class ClassInstantiation: 63 | """Information about class instantiation""" 64 | variable_name: str 65 | class_name: str 66 | args: List[str] 67 | kwargs: Dict[str, str] 68 | line_number: int 69 | full_class_name: Optional[str] = None # Module.ClassName 70 | 71 | 72 | @dataclass 73 | class AnalysisResult: 74 | """Complete analysis results for a Python script""" 75 | file_path: str 76 | imports: List[ImportInfo] = field(default_factory=list) 77 | class_instantiations: List[ClassInstantiation] = field(default_factory=list) 78 | method_calls: List[MethodCall] = field(default_factory=list) 79 | attribute_accesses: List[AttributeAccess] = field(default_factory=list) 80 | function_calls: List[FunctionCall] = field(default_factory=list) 81 | variable_types: Dict[str, str] = field(default_factory=dict) # variable_name -> class_type 82 | errors: List[str] = field(default_factory=list) 83 | 84 | 85 | class AIScriptAnalyzer: 86 | """Analyzes AI-generated Python scripts for validation against knowledge graph""" 87 | 
88 | def __init__(self): 89 | self.import_map: Dict[str, str] = {} # alias -> actual_module_name 90 | self.variable_types: Dict[str, str] = {} # variable_name -> class_type 91 | self.context_manager_vars: Dict[str, Tuple[int, int, str]] = {} # var_name -> (start_line, end_line, type) 92 | 93 | def analyze_script(self, script_path: str) -> AnalysisResult: 94 | """Analyze a Python script and extract all relevant information""" 95 | try: 96 | with open(script_path, 'r', encoding='utf-8') as f: 97 | content = f.read() 98 | 99 | tree = ast.parse(content) 100 | result = AnalysisResult(file_path=script_path) 101 | 102 | # Reset state for new analysis 103 | self.import_map.clear() 104 | self.variable_types.clear() 105 | self.context_manager_vars.clear() 106 | 107 | # Track processed nodes to avoid duplicates 108 | self.processed_calls = set() 109 | self.method_call_attributes = set() 110 | 111 | # First pass: collect imports and build import map 112 | for node in ast.walk(tree): 113 | if isinstance(node, (ast.Import, ast.ImportFrom)): 114 | self._extract_imports(node, result) 115 | 116 | # Second pass: analyze usage patterns 117 | for node in ast.walk(tree): 118 | self._analyze_node(node, result) 119 | 120 | # Set inferred types on method calls and attribute accesses 121 | self._infer_object_types(result) 122 | 123 | result.variable_types = self.variable_types.copy() 124 | 125 | return result 126 | 127 | except Exception as e: 128 | error_msg = f"Failed to analyze script {script_path}: {str(e)}" 129 | logger.error(error_msg) 130 | result = AnalysisResult(file_path=script_path) 131 | result.errors.append(error_msg) 132 | return result 133 | 134 | def _extract_imports(self, node: ast.AST, result: AnalysisResult): 135 | """Extract import information and build import mapping""" 136 | line_num = getattr(node, 'lineno', 0) 137 | 138 | if isinstance(node, ast.Import): 139 | for alias in node.names: 140 | import_name = alias.name 141 | alias_name = alias.asname or import_name 142 | 143 | result.imports.append(ImportInfo( 144 | module=import_name, 145 | name=import_name, 146 | alias=alias.asname, 147 | is_from_import=False, 148 | line_number=line_num 149 | )) 150 | 151 | self.import_map[alias_name] = import_name 152 | 153 | elif isinstance(node, ast.ImportFrom): 154 | module = node.module or "" 155 | for alias in node.names: 156 | import_name = alias.name 157 | alias_name = alias.asname or import_name 158 | 159 | result.imports.append(ImportInfo( 160 | module=module, 161 | name=import_name, 162 | alias=alias.asname, 163 | is_from_import=True, 164 | line_number=line_num 165 | )) 166 | 167 | # Map alias to full module.name 168 | if module: 169 | full_name = f"{module}.{import_name}" 170 | self.import_map[alias_name] = full_name 171 | else: 172 | self.import_map[alias_name] = import_name 173 | 174 | def _analyze_node(self, node: ast.AST, result: AnalysisResult): 175 | """Analyze individual AST nodes for usage patterns""" 176 | line_num = getattr(node, 'lineno', 0) 177 | 178 | # Assignments (class instantiations and method call results) 179 | if isinstance(node, ast.Assign): 180 | if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): 181 | if isinstance(node.value, ast.Call): 182 | # Check if it's a class instantiation or method call 183 | if isinstance(node.value.func, ast.Name): 184 | # Direct function/class call 185 | self._extract_class_instantiation(node, result) 186 | # Mark this call as processed to avoid duplicate processing 187 | self.processed_calls.add(id(node.value)) 188 | elif 
isinstance(node.value.func, ast.Attribute): 189 | # Method call - track the variable assignment for type inference 190 | var_name = node.targets[0].id 191 | self._track_method_result_assignment(node.value, var_name) 192 | # Still process the method call 193 | self._extract_method_call(node.value, result) 194 | self.processed_calls.add(id(node.value)) 195 | 196 | # AsyncWith statements (context managers) 197 | elif isinstance(node, ast.AsyncWith): 198 | self._handle_async_with(node, result) 199 | elif isinstance(node, ast.With): 200 | self._handle_with(node, result) 201 | 202 | # Method calls and function calls 203 | elif isinstance(node, ast.Call): 204 | # Skip if this call was already processed as part of an assignment 205 | if id(node) in self.processed_calls: 206 | return 207 | 208 | if isinstance(node.func, ast.Attribute): 209 | self._extract_method_call(node, result) 210 | # Mark this attribute as used in method call to avoid duplicate processing 211 | self.method_call_attributes.add(id(node.func)) 212 | elif isinstance(node.func, ast.Name): 213 | # Check if this is likely a class instantiation (based on imported classes) 214 | func_name = node.func.id 215 | full_name = self._resolve_full_name(func_name) 216 | 217 | # If this is a known imported class, treat as class instantiation 218 | if self._is_likely_class_instantiation(func_name, full_name): 219 | self._extract_nested_class_instantiation(node, result) 220 | else: 221 | self._extract_function_call(node, result) 222 | 223 | # Attribute access (not in call context) 224 | elif isinstance(node, ast.Attribute): 225 | # Skip if this attribute was already processed as part of a method call 226 | if id(node) in self.method_call_attributes: 227 | return 228 | self._extract_attribute_access(node, result) 229 | 230 | def _extract_class_instantiation(self, node: ast.Assign, result: AnalysisResult): 231 | """Extract class instantiation from assignment""" 232 | target = node.targets[0] 233 | call = node.value 234 | line_num = getattr(node, 'lineno', 0) 235 | 236 | if isinstance(target, ast.Name) and isinstance(call, ast.Call): 237 | var_name = target.id 238 | class_name = self._get_name_from_call(call.func) 239 | 240 | if class_name: 241 | args = [self._get_arg_representation(arg) for arg in call.args] 242 | kwargs = { 243 | kw.arg: self._get_arg_representation(kw.value) 244 | for kw in call.keywords if kw.arg 245 | } 246 | 247 | # Resolve full class name using import map 248 | full_class_name = self._resolve_full_name(class_name) 249 | 250 | instantiation = ClassInstantiation( 251 | variable_name=var_name, 252 | class_name=class_name, 253 | args=args, 254 | kwargs=kwargs, 255 | line_number=line_num, 256 | full_class_name=full_class_name 257 | ) 258 | 259 | result.class_instantiations.append(instantiation) 260 | 261 | # Track variable type for later method call analysis 262 | self.variable_types[var_name] = full_class_name or class_name 263 | 264 | def _extract_method_call(self, node: ast.Call, result: AnalysisResult): 265 | """Extract method call information""" 266 | if isinstance(node.func, ast.Attribute): 267 | line_num = getattr(node, 'lineno', 0) 268 | 269 | # Get object and method names 270 | obj_name = self._get_name_from_node(node.func.value) 271 | method_name = node.func.attr 272 | 273 | if obj_name and method_name: 274 | args = [self._get_arg_representation(arg) for arg in node.args] 275 | kwargs = { 276 | kw.arg: self._get_arg_representation(kw.value) 277 | for kw in node.keywords if kw.arg 278 | } 279 | 280 | method_call = 
MethodCall( 281 | object_name=obj_name, 282 | method_name=method_name, 283 | args=args, 284 | kwargs=kwargs, 285 | line_number=line_num, 286 | object_type=self.variable_types.get(obj_name) 287 | ) 288 | 289 | result.method_calls.append(method_call) 290 | 291 | def _extract_function_call(self, node: ast.Call, result: AnalysisResult): 292 | """Extract function call information""" 293 | if isinstance(node.func, ast.Name): 294 | line_num = getattr(node, 'lineno', 0) 295 | func_name = node.func.id 296 | 297 | args = [self._get_arg_representation(arg) for arg in node.args] 298 | kwargs = { 299 | kw.arg: self._get_arg_representation(kw.value) 300 | for kw in node.keywords if kw.arg 301 | } 302 | 303 | # Resolve full function name using import map 304 | full_func_name = self._resolve_full_name(func_name) 305 | 306 | function_call = FunctionCall( 307 | function_name=func_name, 308 | args=args, 309 | kwargs=kwargs, 310 | line_number=line_num, 311 | full_name=full_func_name 312 | ) 313 | 314 | result.function_calls.append(function_call) 315 | 316 | def _extract_attribute_access(self, node: ast.Attribute, result: AnalysisResult): 317 | """Extract attribute access information""" 318 | line_num = getattr(node, 'lineno', 0) 319 | 320 | obj_name = self._get_name_from_node(node.value) 321 | attr_name = node.attr 322 | 323 | if obj_name and attr_name: 324 | attribute_access = AttributeAccess( 325 | object_name=obj_name, 326 | attribute_name=attr_name, 327 | line_number=line_num, 328 | object_type=self.variable_types.get(obj_name) 329 | ) 330 | 331 | result.attribute_accesses.append(attribute_access) 332 | 333 | def _infer_object_types(self, result: AnalysisResult): 334 | """Update object types for method calls and attribute accesses""" 335 | for method_call in result.method_calls: 336 | if not method_call.object_type: 337 | # First check context manager variables 338 | obj_type = self._get_context_aware_type(method_call.object_name, method_call.line_number) 339 | if obj_type: 340 | method_call.object_type = obj_type 341 | else: 342 | method_call.object_type = self.variable_types.get(method_call.object_name) 343 | 344 | for attr_access in result.attribute_accesses: 345 | if not attr_access.object_type: 346 | # First check context manager variables 347 | obj_type = self._get_context_aware_type(attr_access.object_name, attr_access.line_number) 348 | if obj_type: 349 | attr_access.object_type = obj_type 350 | else: 351 | attr_access.object_type = self.variable_types.get(attr_access.object_name) 352 | 353 | def _get_context_aware_type(self, var_name: str, line_number: int) -> Optional[str]: 354 | """Get the type of a variable considering its context (e.g., async with scope)""" 355 | if var_name in self.context_manager_vars: 356 | start_line, end_line, var_type = self.context_manager_vars[var_name] 357 | if start_line <= line_number <= end_line: 358 | return var_type 359 | return None 360 | 361 | def _get_name_from_call(self, node: ast.AST) -> Optional[str]: 362 | """Get the name from a call node (for class instantiation)""" 363 | if isinstance(node, ast.Name): 364 | return node.id 365 | elif isinstance(node, ast.Attribute): 366 | value_name = self._get_name_from_node(node.value) 367 | if value_name: 368 | return f"{value_name}.{node.attr}" 369 | return None 370 | 371 | def _get_name_from_node(self, node: ast.AST) -> Optional[str]: 372 | """Get string representation of a node (for object names)""" 373 | if isinstance(node, ast.Name): 374 | return node.id 375 | elif isinstance(node, ast.Attribute): 376 | 
value_name = self._get_name_from_node(node.value) 377 | if value_name: 378 | return f"{value_name}.{node.attr}" 379 | return None 380 | 381 | def _get_arg_representation(self, node: ast.AST) -> str: 382 | """Get string representation of an argument""" 383 | if isinstance(node, ast.Constant): 384 | return repr(node.value) 385 | elif isinstance(node, ast.Name): 386 | return node.id 387 | elif isinstance(node, ast.Attribute): 388 | return self._get_name_from_node(node) or "<?>" 389 | elif isinstance(node, ast.Call): 390 | func_name = self._get_name_from_call(node.func) 391 | return f"{func_name}(...)" if func_name else "call(...)" 392 | else: 393 | return f"<{type(node).__name__}>" 394 | 395 | def _is_likely_class_instantiation(self, func_name: str, full_name: Optional[str]) -> bool: 396 | """Determine if a function call is likely a class instantiation""" 397 | # Check if it's a known imported class (classes typically start with uppercase) 398 | if func_name and func_name[0].isupper(): 399 | return True 400 | 401 | # Check if the full name suggests a class (contains known class patterns) 402 | if full_name: 403 | # Common class patterns in module names 404 | class_patterns = [ 405 | 'Model', 'Provider', 'Client', 'Agent', 'Manager', 'Handler', 406 | 'Builder', 'Factory', 'Service', 'Controller', 'Processor' 407 | ] 408 | return any(pattern in full_name for pattern in class_patterns) 409 | 410 | return False 411 | 412 | def _extract_nested_class_instantiation(self, node: ast.Call, result: AnalysisResult): 413 | """Extract class instantiation that's not in direct assignment (e.g., as parameter)""" 414 | line_num = getattr(node, 'lineno', 0) 415 | 416 | if isinstance(node.func, ast.Name): 417 | class_name = node.func.id 418 | 419 | args = [self._get_arg_representation(arg) for arg in node.args] 420 | kwargs = { 421 | kw.arg: self._get_arg_representation(kw.value) 422 | for kw in node.keywords if kw.arg 423 | } 424 | 425 | # Resolve full class name using import map 426 | full_class_name = self._resolve_full_name(class_name) 427 | 428 | # Use a synthetic variable name since this isn't assigned to a variable 429 | var_name = f"<{class_name.lower()}_instance>" 430 | 431 | instantiation = ClassInstantiation( 432 | variable_name=var_name, 433 | class_name=class_name, 434 | args=args, 435 | kwargs=kwargs, 436 | line_number=line_num, 437 | full_class_name=full_class_name 438 | ) 439 | 440 | result.class_instantiations.append(instantiation) 441 | 442 | def _track_method_result_assignment(self, call_node: ast.Call, var_name: str): 443 | """Track when a variable is assigned the result of a method call""" 444 | if isinstance(call_node.func, ast.Attribute): 445 | # For now, we'll use a generic type hint for method results 446 | # In a more sophisticated system, we could look up the return type 447 | self.variable_types[var_name] = "method_result" 448 | 449 | def _handle_async_with(self, node: ast.AsyncWith, result: AnalysisResult): 450 | """Handle async with statements and track context manager variables""" 451 | for item in node.items: 452 | if item.optional_vars and isinstance(item.optional_vars, ast.Name): 453 | var_name = item.optional_vars.id 454 | 455 | # If the context manager is a method call, track the result type 456 | if isinstance(item.context_expr, ast.Call) and isinstance(item.context_expr.func, ast.Attribute): 457 | # Extract and process the method call 458 | self._extract_method_call(item.context_expr, result) 459 | self.processed_calls.add(id(item.context_expr)) 460 | 461 | # Track context 
manager scope for pydantic_ai run_stream calls 462 | obj_name = self._get_name_from_node(item.context_expr.func.value) 463 | method_name = item.context_expr.func.attr 464 | 465 | if (obj_name and obj_name in self.variable_types and 466 | 'pydantic_ai' in str(self.variable_types[obj_name]) and 467 | method_name == 'run_stream'): 468 | 469 | # Calculate the scope of this async with block 470 | start_line = getattr(node, 'lineno', 0) 471 | end_line = getattr(node, 'end_lineno', start_line + 50) # fallback estimate 472 | 473 | # For run_stream, the return type is specifically StreamedRunResult 474 | # This is the actual return type, not a generic placeholder 475 | self.context_manager_vars[var_name] = (start_line, end_line, "pydantic_ai.StreamedRunResult") 476 | 477 | def _handle_with(self, node: ast.With, result: AnalysisResult): 478 | """Handle regular with statements and track context manager variables""" 479 | for item in node.items: 480 | if item.optional_vars and isinstance(item.optional_vars, ast.Name): 481 | var_name = item.optional_vars.id 482 | 483 | # If the context manager is a method call, track the result type 484 | if isinstance(item.context_expr, ast.Call) and isinstance(item.context_expr.func, ast.Attribute): 485 | # Extract and process the method call 486 | self._extract_method_call(item.context_expr, result) 487 | self.processed_calls.add(id(item.context_expr)) 488 | 489 | # Track basic type information 490 | self.variable_types[var_name] = "context_manager_result" 491 | 492 | def _resolve_full_name(self, name: str) -> Optional[str]: 493 | """Resolve a name to its full module.name using import map""" 494 | # Check if it's a direct import mapping 495 | if name in self.import_map: 496 | return self.import_map[name] 497 | 498 | # Check if it's a dotted name with first part in import map 499 | parts = name.split('.') 500 | if len(parts) > 1 and parts[0] in self.import_map: 501 | base_module = self.import_map[parts[0]] 502 | return f"{base_module}.{'.'.join(parts[1:])}" 503 | 504 | return None 505 | 506 | 507 | def analyze_ai_script(script_path: str) -> AnalysisResult: 508 | """Convenience function to analyze a single AI-generated script""" 509 | analyzer = AIScriptAnalyzer() 510 | return analyzer.analyze_script(script_path) 511 | 512 | 513 | if __name__ == "__main__": 514 | # Example usage 515 | import sys 516 | 517 | if len(sys.argv) != 2: 518 | print("Usage: python ai_script_analyzer.py <script_path>") 519 | sys.exit(1) 520 | 521 | script_path = sys.argv[1] 522 | result = analyze_ai_script(script_path) 523 | 524 | print(f"Analysis Results for: {result.file_path}") 525 | print(f"Imports: {len(result.imports)}") 526 | print(f"Class Instantiations: {len(result.class_instantiations)}") 527 | print(f"Method Calls: {len(result.method_calls)}") 528 | print(f"Function Calls: {len(result.function_calls)}") 529 | print(f"Attribute Accesses: {len(result.attribute_accesses)}") 530 | 531 | if result.errors: 532 | print(f"Errors: {result.errors}") -------------------------------------------------------------------------------- /knowledge_graphs/hallucination_reporter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hallucination Reporter 3 | 4 | Generates comprehensive reports about AI coding assistant hallucinations 5 | detected in Python scripts. Supports multiple output formats. 
6 | """ 7 | 8 | import json 9 | import logging 10 | from datetime import datetime, timezone 11 | from pathlib import Path 12 | from typing import Dict, List, Any, Optional 13 | 14 | from knowledge_graph_validator import ( 15 | ScriptValidationResult, ValidationStatus, ValidationResult 16 | ) 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class HallucinationReporter: 22 | """Generates reports about detected hallucinations""" 23 | 24 | def __init__(self): 25 | self.report_timestamp = datetime.now(timezone.utc) 26 | 27 | def generate_comprehensive_report(self, validation_result: ScriptValidationResult) -> Dict[str, Any]: 28 | """Generate a comprehensive report in JSON format""" 29 | 30 | # Categorize validations by status (knowledge graph items only) 31 | valid_items = [] 32 | invalid_items = [] 33 | uncertain_items = [] 34 | not_found_items = [] 35 | 36 | # Process imports (only knowledge graph ones) 37 | for val in validation_result.import_validations: 38 | if not val.validation.details.get('in_knowledge_graph', False): 39 | continue # Skip external libraries 40 | item = { 41 | 'type': 'IMPORT', 42 | 'name': val.import_info.module, 43 | 'line': val.import_info.line_number, 44 | 'status': val.validation.status.value, 45 | 'confidence': val.validation.confidence, 46 | 'message': val.validation.message, 47 | 'details': { 48 | 'is_from_import': val.import_info.is_from_import, 49 | 'alias': val.import_info.alias, 50 | 'available_classes': val.available_classes, 51 | 'available_functions': val.available_functions 52 | } 53 | } 54 | self._categorize_item(item, val.validation.status, valid_items, invalid_items, uncertain_items, not_found_items) 55 | 56 | # Process classes (only knowledge graph ones) 57 | for val in validation_result.class_validations: 58 | class_name = val.class_instantiation.full_class_name or val.class_instantiation.class_name 59 | if not self._is_from_knowledge_graph(class_name, validation_result): 60 | continue # Skip external classes 61 | item = { 62 | 'type': 'CLASS_INSTANTIATION', 63 | 'name': val.class_instantiation.class_name, 64 | 'full_name': val.class_instantiation.full_class_name, 65 | 'variable': val.class_instantiation.variable_name, 66 | 'line': val.class_instantiation.line_number, 67 | 'status': val.validation.status.value, 68 | 'confidence': val.validation.confidence, 69 | 'message': val.validation.message, 70 | 'details': { 71 | 'args_provided': val.class_instantiation.args, 72 | 'kwargs_provided': list(val.class_instantiation.kwargs.keys()), 73 | 'constructor_params': val.constructor_params, 74 | 'parameter_validation': self._serialize_validation_result(val.parameter_validation) if val.parameter_validation else None 75 | } 76 | } 77 | self._categorize_item(item, val.validation.status, valid_items, invalid_items, uncertain_items, not_found_items) 78 | 79 | # Track reported items to avoid duplicates 80 | reported_items = set() 81 | 82 | # Process methods (only knowledge graph ones) 83 | for val in validation_result.method_validations: 84 | if not (val.method_call.object_type and self._is_from_knowledge_graph(val.method_call.object_type, validation_result)): 85 | continue # Skip external methods 86 | 87 | # Create unique key to avoid duplicates 88 | key = (val.method_call.line_number, val.method_call.method_name, val.method_call.object_type) 89 | if key not in reported_items: 90 | reported_items.add(key) 91 | item = { 92 | 'type': 'METHOD_CALL', 93 | 'name': val.method_call.method_name, 94 | 'object': val.method_call.object_name, 95 | 'object_type': 
val.method_call.object_type, 96 | 'line': val.method_call.line_number, 97 | 'status': val.validation.status.value, 98 | 'confidence': val.validation.confidence, 99 | 'message': val.validation.message, 100 | 'details': { 101 | 'args_provided': val.method_call.args, 102 | 'kwargs_provided': list(val.method_call.kwargs.keys()), 103 | 'expected_params': val.expected_params, 104 | 'parameter_validation': self._serialize_validation_result(val.parameter_validation) if val.parameter_validation else None, 105 | 'suggestions': val.validation.suggestions 106 | } 107 | } 108 | self._categorize_item(item, val.validation.status, valid_items, invalid_items, uncertain_items, not_found_items) 109 | 110 | # Process attributes (only knowledge graph ones) - but skip if already reported as method 111 | for val in validation_result.attribute_validations: 112 | if not (val.attribute_access.object_type and self._is_from_knowledge_graph(val.attribute_access.object_type, validation_result)): 113 | continue # Skip external attributes 114 | 115 | # Create unique key - if this was already reported as a method, skip it 116 | key = (val.attribute_access.line_number, val.attribute_access.attribute_name, val.attribute_access.object_type) 117 | if key not in reported_items: 118 | reported_items.add(key) 119 | item = { 120 | 'type': 'ATTRIBUTE_ACCESS', 121 | 'name': val.attribute_access.attribute_name, 122 | 'object': val.attribute_access.object_name, 123 | 'object_type': val.attribute_access.object_type, 124 | 'line': val.attribute_access.line_number, 125 | 'status': val.validation.status.value, 126 | 'confidence': val.validation.confidence, 127 | 'message': val.validation.message, 128 | 'details': { 129 | 'expected_type': val.expected_type 130 | } 131 | } 132 | self._categorize_item(item, val.validation.status, valid_items, invalid_items, uncertain_items, not_found_items) 133 | 134 | # Process functions (only knowledge graph ones) 135 | for val in validation_result.function_validations: 136 | if not (val.function_call.full_name and self._is_from_knowledge_graph(val.function_call.full_name, validation_result)): 137 | continue # Skip external functions 138 | item = { 139 | 'type': 'FUNCTION_CALL', 140 | 'name': val.function_call.function_name, 141 | 'full_name': val.function_call.full_name, 142 | 'line': val.function_call.line_number, 143 | 'status': val.validation.status.value, 144 | 'confidence': val.validation.confidence, 145 | 'message': val.validation.message, 146 | 'details': { 147 | 'args_provided': val.function_call.args, 148 | 'kwargs_provided': list(val.function_call.kwargs.keys()), 149 | 'expected_params': val.expected_params, 150 | 'parameter_validation': self._serialize_validation_result(val.parameter_validation) if val.parameter_validation else None 151 | } 152 | } 153 | self._categorize_item(item, val.validation.status, valid_items, invalid_items, uncertain_items, not_found_items) 154 | 155 | # Create library summary 156 | library_summary = self._create_library_summary(validation_result) 157 | 158 | # Generate report 159 | report = { 160 | 'analysis_metadata': { 161 | 'script_path': validation_result.script_path, 162 | 'analysis_timestamp': self.report_timestamp.isoformat(), 163 | 'total_imports': len(validation_result.import_validations), 164 | 'total_classes': len(validation_result.class_validations), 165 | 'total_methods': len(validation_result.method_validations), 166 | 'total_attributes': len(validation_result.attribute_validations), 167 | 'total_functions': len(validation_result.function_validations) 
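                # Note: these totals count every usage detected in the script (external
                # libraries included); the validation_summary below only covers items that
                # were matched against the knowledge graph.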
168 | }, 169 | 'validation_summary': { 170 | 'overall_confidence': validation_result.overall_confidence, 171 | 'total_validations': len(valid_items) + len(invalid_items) + len(uncertain_items) + len(not_found_items), 172 | 'valid_count': len(valid_items), 173 | 'invalid_count': len(invalid_items), 174 | 'uncertain_count': len(uncertain_items), 175 | 'not_found_count': len(not_found_items), 176 | 'hallucination_rate': len(invalid_items + not_found_items) / max(1, len(valid_items) + len(invalid_items) + len(not_found_items)) 177 | }, 178 | 'libraries_analyzed': library_summary, 179 | 'validation_details': { 180 | 'valid_items': valid_items, 181 | 'invalid_items': invalid_items, 182 | 'uncertain_items': uncertain_items, 183 | 'not_found_items': not_found_items 184 | }, 185 | 'hallucinations_detected': validation_result.hallucinations_detected, 186 | 'recommendations': self._generate_recommendations(validation_result) 187 | } 188 | 189 | return report 190 | 191 | def _is_from_knowledge_graph(self, item_name: str, validation_result) -> bool: 192 | """Check if an item is from a knowledge graph module""" 193 | if not item_name: 194 | return False 195 | 196 | # Get knowledge graph modules from import validations 197 | kg_modules = set() 198 | for val in validation_result.import_validations: 199 | if val.validation.details.get('in_knowledge_graph', False): 200 | kg_modules.add(val.import_info.module) 201 | if '.' in val.import_info.module: 202 | kg_modules.add(val.import_info.module.split('.')[0]) 203 | 204 | # Check if the item belongs to any knowledge graph module 205 | if '.' in item_name: 206 | base_module = item_name.split('.')[0] 207 | return base_module in kg_modules 208 | 209 | return any(item_name in module or module.endswith(item_name) for module in kg_modules) 210 | 211 | def _serialize_validation_result(self, validation_result) -> Dict[str, Any]: 212 | """Convert ValidationResult to JSON-serializable dictionary""" 213 | if validation_result is None: 214 | return None 215 | 216 | return { 217 | 'status': validation_result.status.value, 218 | 'confidence': validation_result.confidence, 219 | 'message': validation_result.message, 220 | 'details': validation_result.details, 221 | 'suggestions': validation_result.suggestions 222 | } 223 | 224 | def _categorize_item(self, item: Dict[str, Any], status: ValidationStatus, 225 | valid_items: List, invalid_items: List, uncertain_items: List, not_found_items: List): 226 | """Categorize validation item by status""" 227 | if status == ValidationStatus.VALID: 228 | valid_items.append(item) 229 | elif status == ValidationStatus.INVALID: 230 | invalid_items.append(item) 231 | elif status == ValidationStatus.UNCERTAIN: 232 | uncertain_items.append(item) 233 | elif status == ValidationStatus.NOT_FOUND: 234 | not_found_items.append(item) 235 | 236 | def _create_library_summary(self, validation_result: ScriptValidationResult) -> List[Dict[str, Any]]: 237 | """Create summary of libraries analyzed""" 238 | library_stats = {} 239 | 240 | # Aggregate stats by library/module 241 | for val in validation_result.import_validations: 242 | module = val.import_info.module 243 | if module not in library_stats: 244 | library_stats[module] = { 245 | 'module_name': module, 246 | 'import_status': val.validation.status.value, 247 | 'import_confidence': val.validation.confidence, 248 | 'classes_used': [], 249 | 'methods_called': [], 250 | 'attributes_accessed': [], 251 | 'functions_called': [] 252 | } 253 | 254 | # Add class usage 255 | for val in 
validation_result.class_validations: 256 | class_name = val.class_instantiation.class_name 257 | full_name = val.class_instantiation.full_class_name 258 | 259 | # Try to match to library 260 | if full_name: 261 | parts = full_name.split('.') 262 | if len(parts) > 1: 263 | module = '.'.join(parts[:-1]) 264 | if module in library_stats: 265 | library_stats[module]['classes_used'].append({ 266 | 'class_name': class_name, 267 | 'status': val.validation.status.value, 268 | 'confidence': val.validation.confidence 269 | }) 270 | 271 | # Add method usage 272 | for val in validation_result.method_validations: 273 | method_name = val.method_call.method_name 274 | object_type = val.method_call.object_type 275 | 276 | if object_type: 277 | parts = object_type.split('.') 278 | if len(parts) > 1: 279 | module = '.'.join(parts[:-1]) 280 | if module in library_stats: 281 | library_stats[module]['methods_called'].append({ 282 | 'method_name': method_name, 283 | 'class_name': parts[-1], 284 | 'status': val.validation.status.value, 285 | 'confidence': val.validation.confidence 286 | }) 287 | 288 | # Add attribute usage 289 | for val in validation_result.attribute_validations: 290 | attr_name = val.attribute_access.attribute_name 291 | object_type = val.attribute_access.object_type 292 | 293 | if object_type: 294 | parts = object_type.split('.') 295 | if len(parts) > 1: 296 | module = '.'.join(parts[:-1]) 297 | if module in library_stats: 298 | library_stats[module]['attributes_accessed'].append({ 299 | 'attribute_name': attr_name, 300 | 'class_name': parts[-1], 301 | 'status': val.validation.status.value, 302 | 'confidence': val.validation.confidence 303 | }) 304 | 305 | # Add function usage 306 | for val in validation_result.function_validations: 307 | func_name = val.function_call.function_name 308 | full_name = val.function_call.full_name 309 | 310 | if full_name: 311 | parts = full_name.split('.') 312 | if len(parts) > 1: 313 | module = '.'.join(parts[:-1]) 314 | if module in library_stats: 315 | library_stats[module]['functions_called'].append({ 316 | 'function_name': func_name, 317 | 'status': val.validation.status.value, 318 | 'confidence': val.validation.confidence 319 | }) 320 | 321 | return list(library_stats.values()) 322 | 323 | def _generate_recommendations(self, validation_result: ScriptValidationResult) -> List[str]: 324 | """Generate recommendations based on validation results""" 325 | recommendations = [] 326 | 327 | # Only count actual hallucinations (from knowledge graph libraries) 328 | kg_hallucinations = [h for h in validation_result.hallucinations_detected] 329 | 330 | if kg_hallucinations: 331 | method_issues = [h for h in kg_hallucinations if h['type'] == 'METHOD_NOT_FOUND'] 332 | attr_issues = [h for h in kg_hallucinations if h['type'] == 'ATTRIBUTE_NOT_FOUND'] 333 | param_issues = [h for h in kg_hallucinations if h['type'] == 'INVALID_PARAMETERS'] 334 | 335 | if method_issues: 336 | recommendations.append( 337 | f"Found {len(method_issues)} non-existent methods in knowledge graph libraries. " 338 | "Consider checking the official documentation for correct method names." 339 | ) 340 | 341 | if attr_issues: 342 | recommendations.append( 343 | f"Found {len(attr_issues)} non-existent attributes in knowledge graph libraries. " 344 | "Verify attribute names against the class documentation." 345 | ) 346 | 347 | if param_issues: 348 | recommendations.append( 349 | f"Found {len(param_issues)} parameter mismatches in knowledge graph libraries. 
" 350 | "Check function signatures for correct parameter names and types." 351 | ) 352 | else: 353 | recommendations.append( 354 | "No hallucinations detected in knowledge graph libraries. " 355 | "External library usage appears to be working as expected." 356 | ) 357 | 358 | if validation_result.overall_confidence < 0.7: 359 | recommendations.append( 360 | "Overall confidence is moderate. Most validations were for external libraries not in the knowledge graph." 361 | ) 362 | 363 | return recommendations 364 | 365 | def save_json_report(self, report: Dict[str, Any], output_path: str): 366 | """Save report as JSON file""" 367 | with open(output_path, 'w', encoding='utf-8') as f: 368 | json.dump(report, f, indent=2, ensure_ascii=False) 369 | 370 | logger.info(f"JSON report saved to: {output_path}") 371 | 372 | def save_markdown_report(self, report: Dict[str, Any], output_path: str): 373 | """Save report as Markdown file""" 374 | md_content = self._generate_markdown_content(report) 375 | 376 | with open(output_path, 'w', encoding='utf-8') as f: 377 | f.write(md_content) 378 | 379 | logger.info(f"Markdown report saved to: {output_path}") 380 | 381 | def _generate_markdown_content(self, report: Dict[str, Any]) -> str: 382 | """Generate Markdown content from report""" 383 | md = [] 384 | 385 | # Header 386 | md.append("# AI Hallucination Detection Report") 387 | md.append("") 388 | md.append(f"**Script:** `{report['analysis_metadata']['script_path']}`") 389 | md.append(f"**Analysis Date:** {report['analysis_metadata']['analysis_timestamp']}") 390 | md.append(f"**Overall Confidence:** {report['validation_summary']['overall_confidence']:.2%}") 391 | md.append("") 392 | 393 | # Summary 394 | summary = report['validation_summary'] 395 | md.append("## Summary") 396 | md.append("") 397 | md.append(f"- **Total Validations:** {summary['total_validations']}") 398 | md.append(f"- **Valid:** {summary['valid_count']} ({summary['valid_count']/summary['total_validations']:.1%})") 399 | md.append(f"- **Invalid:** {summary['invalid_count']} ({summary['invalid_count']/summary['total_validations']:.1%})") 400 | md.append(f"- **Not Found:** {summary['not_found_count']} ({summary['not_found_count']/summary['total_validations']:.1%})") 401 | md.append(f"- **Uncertain:** {summary['uncertain_count']} ({summary['uncertain_count']/summary['total_validations']:.1%})") 402 | md.append(f"- **Hallucination Rate:** {summary['hallucination_rate']:.1%}") 403 | md.append("") 404 | 405 | # Hallucinations 406 | if report['hallucinations_detected']: 407 | md.append("## 🚨 Hallucinations Detected") 408 | md.append("") 409 | for i, hallucination in enumerate(report['hallucinations_detected'], 1): 410 | md.append(f"### {i}. 
{hallucination['type'].replace('_', ' ').title()}") 411 | md.append(f"**Location:** {hallucination['location']}") 412 | md.append(f"**Description:** {hallucination['description']}") 413 | if hallucination.get('suggestion'): 414 | md.append(f"**Suggestion:** {hallucination['suggestion']}") 415 | md.append("") 416 | 417 | # Libraries 418 | if report['libraries_analyzed']: 419 | md.append("## 📚 Libraries Analyzed") 420 | md.append("") 421 | for lib in report['libraries_analyzed']: 422 | md.append(f"### {lib['module_name']}") 423 | md.append(f"**Import Status:** {lib['import_status']}") 424 | md.append(f"**Import Confidence:** {lib['import_confidence']:.2%}") 425 | 426 | if lib['classes_used']: 427 | md.append("**Classes Used:**") 428 | for cls in lib['classes_used']: 429 | status_emoji = "✅" if cls['status'] == 'VALID' else "❌" 430 | md.append(f" - {status_emoji} `{cls['class_name']}` ({cls['confidence']:.1%})") 431 | 432 | if lib['methods_called']: 433 | md.append("**Methods Called:**") 434 | for method in lib['methods_called']: 435 | status_emoji = "✅" if method['status'] == 'VALID' else "❌" 436 | md.append(f" - {status_emoji} `{method['class_name']}.{method['method_name']}()` ({method['confidence']:.1%})") 437 | 438 | if lib['attributes_accessed']: 439 | md.append("**Attributes Accessed:**") 440 | for attr in lib['attributes_accessed']: 441 | status_emoji = "✅" if attr['status'] == 'VALID' else "❌" 442 | md.append(f" - {status_emoji} `{attr['class_name']}.{attr['attribute_name']}` ({attr['confidence']:.1%})") 443 | 444 | if lib['functions_called']: 445 | md.append("**Functions Called:**") 446 | for func in lib['functions_called']: 447 | status_emoji = "✅" if func['status'] == 'VALID' else "❌" 448 | md.append(f" - {status_emoji} `{func['function_name']}()` ({func['confidence']:.1%})") 449 | 450 | md.append("") 451 | 452 | # Recommendations 453 | if report['recommendations']: 454 | md.append("## 💡 Recommendations") 455 | md.append("") 456 | for rec in report['recommendations']: 457 | md.append(f"- {rec}") 458 | md.append("") 459 | 460 | # Detailed Results 461 | md.append("## 📋 Detailed Validation Results") 462 | md.append("") 463 | 464 | # Invalid items 465 | invalid_items = report['validation_details']['invalid_items'] 466 | if invalid_items: 467 | md.append("### ❌ Invalid Items") 468 | md.append("") 469 | for item in invalid_items: 470 | md.append(f"- **{item['type']}** `{item['name']}` (Line {item['line']}) - {item['message']}") 471 | md.append("") 472 | 473 | # Not found items 474 | not_found_items = report['validation_details']['not_found_items'] 475 | if not_found_items: 476 | md.append("### 🔍 Not Found Items") 477 | md.append("") 478 | for item in not_found_items: 479 | md.append(f"- **{item['type']}** `{item['name']}` (Line {item['line']}) - {item['message']}") 480 | md.append("") 481 | 482 | # Valid items (sample) 483 | valid_items = report['validation_details']['valid_items'] 484 | if valid_items: 485 | md.append("### ✅ Valid Items (Sample)") 486 | md.append("") 487 | for item in valid_items[:10]: # Show first 10 488 | md.append(f"- **{item['type']}** `{item['name']}` (Line {item['line']}) - {item['message']}") 489 | if len(valid_items) > 10: 490 | md.append(f"- ... 
and {len(valid_items) - 10} more valid items") 491 | md.append("") 492 | 493 | return "\n".join(md) 494 | 495 | def print_summary(self, report: Dict[str, Any]): 496 | """Print a concise summary to console""" 497 | print("\n" + "="*80) 498 | print("🤖 AI HALLUCINATION DETECTION REPORT") 499 | print("="*80) 500 | 501 | print(f"Script: {report['analysis_metadata']['script_path']}") 502 | print(f"Overall Confidence: {report['validation_summary']['overall_confidence']:.1%}") 503 | 504 | summary = report['validation_summary'] 505 | print(f"\nValidation Results:") 506 | print(f" ✅ Valid: {summary['valid_count']}") 507 | print(f" ❌ Invalid: {summary['invalid_count']}") 508 | print(f" 🔍 Not Found: {summary['not_found_count']}") 509 | print(f" ❓ Uncertain: {summary['uncertain_count']}") 510 | print(f" 📊 Hallucination Rate: {summary['hallucination_rate']:.1%}") 511 | 512 | if report['hallucinations_detected']: 513 | print(f"\n🚨 {len(report['hallucinations_detected'])} Hallucinations Detected:") 514 | for hall in report['hallucinations_detected'][:5]: # Show first 5 515 | print(f" - {hall['type'].replace('_', ' ').title()} at {hall['location']}") 516 | print(f" {hall['description']}") 517 | 518 | if report['recommendations']: 519 | print(f"\n💡 Recommendations:") 520 | for rec in report['recommendations'][:3]: # Show first 3 521 | print(f" - {rec}") 522 | 523 | print("="*80) -------------------------------------------------------------------------------- /knowledge_graphs/parse_repo_into_neo4j.py: -------------------------------------------------------------------------------- 1 | """ 2 | Direct Neo4j GitHub Code Repository Extractor 3 | 4 | Creates nodes and relationships directly in Neo4j without Graphiti: 5 | - File nodes 6 | - Class nodes 7 | - Method nodes 8 | - Function nodes 9 | - Import relationships 10 | 11 | Bypasses all LLM processing for maximum speed. 
12 | """ 13 | 14 | import asyncio 15 | import logging 16 | import os 17 | import subprocess 18 | import shutil 19 | from datetime import datetime, timezone 20 | from pathlib import Path 21 | from typing import List, Optional, Dict, Any, Set 22 | import ast 23 | 24 | from dotenv import load_dotenv 25 | from neo4j import AsyncGraphDatabase 26 | 27 | # Configure logging 28 | logging.basicConfig( 29 | level=logging.INFO, 30 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 31 | datefmt='%Y-%m-%d %H:%M:%S', 32 | ) 33 | logger = logging.getLogger(__name__) 34 | 35 | 36 | class Neo4jCodeAnalyzer: 37 | """Analyzes code for direct Neo4j insertion""" 38 | 39 | def __init__(self): 40 | # External modules to ignore 41 | self.external_modules = { 42 | # Python standard library 43 | 'os', 'sys', 'json', 'logging', 'datetime', 'pathlib', 'typing', 'collections', 44 | 'asyncio', 'subprocess', 'ast', 're', 'string', 'urllib', 'http', 'email', 45 | 'time', 'uuid', 'hashlib', 'base64', 'itertools', 'functools', 'operator', 46 | 'contextlib', 'copy', 'pickle', 'tempfile', 'shutil', 'glob', 'fnmatch', 47 | 'io', 'codecs', 'locale', 'platform', 'socket', 'ssl', 'threading', 'queue', 48 | 'multiprocessing', 'concurrent', 'warnings', 'traceback', 'inspect', 49 | 'importlib', 'pkgutil', 'types', 'weakref', 'gc', 'dataclasses', 'enum', 50 | 'abc', 'numbers', 'decimal', 'fractions', 'math', 'cmath', 'random', 'statistics', 51 | 52 | # Common third-party libraries 53 | 'requests', 'urllib3', 'httpx', 'aiohttp', 'flask', 'django', 'fastapi', 54 | 'pydantic', 'sqlalchemy', 'alembic', 'psycopg2', 'pymongo', 'redis', 55 | 'celery', 'pytest', 'unittest', 'mock', 'faker', 'factory', 'hypothesis', 56 | 'numpy', 'pandas', 'matplotlib', 'seaborn', 'scipy', 'sklearn', 'torch', 57 | 'tensorflow', 'keras', 'opencv', 'pillow', 'boto3', 'botocore', 'azure', 58 | 'google', 'openai', 'anthropic', 'langchain', 'transformers', 'huggingface_hub', 59 | 'click', 'typer', 'rich', 'colorama', 'tqdm', 'python-dotenv', 'pyyaml', 60 | 'toml', 'configargparse', 'marshmallow', 'attrs', 'dataclasses-json', 61 | 'jsonschema', 'cerberus', 'voluptuous', 'schema', 'jinja2', 'mako', 62 | 'cryptography', 'bcrypt', 'passlib', 'jwt', 'authlib', 'oauthlib' 63 | } 64 | 65 | def analyze_python_file(self, file_path: Path, repo_root: Path, project_modules: Set[str]) -> Dict[str, Any]: 66 | """Extract structure for direct Neo4j insertion""" 67 | try: 68 | with open(file_path, 'r', encoding='utf-8') as f: 69 | content = f.read() 70 | 71 | tree = ast.parse(content) 72 | relative_path = str(file_path.relative_to(repo_root)) 73 | module_name = self._get_importable_module_name(file_path, repo_root, relative_path) 74 | 75 | # Extract structure 76 | classes = [] 77 | functions = [] 78 | imports = [] 79 | 80 | for node in ast.walk(tree): 81 | if isinstance(node, ast.ClassDef): 82 | # Extract class with its methods and attributes 83 | methods = [] 84 | attributes = [] 85 | 86 | for item in node.body: 87 | if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): 88 | if not item.name.startswith('_'): # Public methods only 89 | # Extract comprehensive parameter info 90 | params = self._extract_function_parameters(item) 91 | 92 | # Get return type annotation 93 | return_type = self._get_name(item.returns) if item.returns else 'Any' 94 | 95 | # Create detailed parameter list for Neo4j storage 96 | params_detailed = [] 97 | for p in params: 98 | param_str = f"{p['name']}:{p['type']}" 99 | if p['optional'] and p['default'] is not None: 100 | param_str += 
f"={p['default']}" 101 | elif p['optional']: 102 | param_str += "=None" 103 | if p['kind'] != 'positional': 104 | param_str = f"[{p['kind']}] {param_str}" 105 | params_detailed.append(param_str) 106 | 107 | methods.append({ 108 | 'name': item.name, 109 | 'params': params, # Full parameter objects 110 | 'params_detailed': params_detailed, # Detailed string format 111 | 'return_type': return_type, 112 | 'args': [arg.arg for arg in item.args.args if arg.arg != 'self'] # Keep for backwards compatibility 113 | }) 114 | elif isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name): 115 | # Type annotated attributes 116 | if not item.target.id.startswith('_'): 117 | attributes.append({ 118 | 'name': item.target.id, 119 | 'type': self._get_name(item.annotation) if item.annotation else 'Any' 120 | }) 121 | 122 | classes.append({ 123 | 'name': node.name, 124 | 'full_name': f"{module_name}.{node.name}", 125 | 'methods': methods, 126 | 'attributes': attributes 127 | }) 128 | 129 | elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): 130 | # Only top-level functions 131 | if not any(node in cls_node.body for cls_node in ast.walk(tree) if isinstance(cls_node, ast.ClassDef)): 132 | if not node.name.startswith('_'): 133 | # Extract comprehensive parameter info 134 | params = self._extract_function_parameters(node) 135 | 136 | # Get return type annotation 137 | return_type = self._get_name(node.returns) if node.returns else 'Any' 138 | 139 | # Create detailed parameter list for Neo4j storage 140 | params_detailed = [] 141 | for p in params: 142 | param_str = f"{p['name']}:{p['type']}" 143 | if p['optional'] and p['default'] is not None: 144 | param_str += f"={p['default']}" 145 | elif p['optional']: 146 | param_str += "=None" 147 | if p['kind'] != 'positional': 148 | param_str = f"[{p['kind']}] {param_str}" 149 | params_detailed.append(param_str) 150 | 151 | # Simple format for backwards compatibility 152 | params_list = [f"{p['name']}:{p['type']}" for p in params] 153 | 154 | functions.append({ 155 | 'name': node.name, 156 | 'full_name': f"{module_name}.{node.name}", 157 | 'params': params, # Full parameter objects 158 | 'params_detailed': params_detailed, # Detailed string format 159 | 'params_list': params_list, # Simple string format for backwards compatibility 160 | 'return_type': return_type, 161 | 'args': [arg.arg for arg in node.args.args] # Keep for backwards compatibility 162 | }) 163 | 164 | elif isinstance(node, (ast.Import, ast.ImportFrom)): 165 | # Track internal imports only 166 | if isinstance(node, ast.Import): 167 | for alias in node.names: 168 | if self._is_likely_internal(alias.name, project_modules): 169 | imports.append(alias.name) 170 | elif isinstance(node, ast.ImportFrom) and node.module: 171 | if (node.module.startswith('.') or self._is_likely_internal(node.module, project_modules)): 172 | imports.append(node.module) 173 | 174 | return { 175 | 'module_name': module_name, 176 | 'file_path': relative_path, 177 | 'classes': classes, 178 | 'functions': functions, 179 | 'imports': list(set(imports)), # Remove duplicates 180 | 'line_count': len(content.splitlines()) 181 | } 182 | 183 | except Exception as e: 184 | logger.warning(f"Could not analyze {file_path}: {e}") 185 | return None 186 | 187 | def _is_likely_internal(self, import_name: str, project_modules: Set[str]) -> bool: 188 | """Check if an import is likely internal to the project""" 189 | if not import_name: 190 | return False 191 | 192 | # Relative imports are definitely internal 193 | if 
import_name.startswith('.'): 194 | return True 195 | 196 | # Check if it's a known external module 197 | base_module = import_name.split('.')[0] 198 | if base_module in self.external_modules: 199 | return False 200 | 201 | # Check if it matches any project module 202 | for project_module in project_modules: 203 | if import_name.startswith(project_module): 204 | return True 205 | 206 | # If it's not obviously external, consider it internal 207 | if (not any(ext in base_module.lower() for ext in ['test', 'mock', 'fake']) and 208 | not base_module.startswith('_') and 209 | len(base_module) > 2): 210 | return True 211 | 212 | return False 213 | 214 | def _get_importable_module_name(self, file_path: Path, repo_root: Path, relative_path: str) -> str: 215 | """Determine the actual importable module name for a Python file""" 216 | # Start with the default: convert file path to module path 217 | default_module = relative_path.replace('/', '.').replace('\\', '.').replace('.py', '') 218 | 219 | # Common patterns to detect the actual package root 220 | path_parts = Path(relative_path).parts 221 | 222 | # Look for common package indicators 223 | package_roots = [] 224 | 225 | # Check each directory level for __init__.py to find package boundaries 226 | current_path = repo_root 227 | for i, part in enumerate(path_parts[:-1]): # Exclude the .py file itself 228 | current_path = current_path / part 229 | if (current_path / '__init__.py').exists(): 230 | # This is a package directory, mark it as a potential root 231 | package_roots.append(i) 232 | 233 | if package_roots: 234 | # Use the first (outermost) package as the root 235 | package_start = package_roots[0] 236 | module_parts = path_parts[package_start:] 237 | module_name = '.'.join(module_parts).replace('.py', '') 238 | return module_name 239 | 240 | # Fallback: look for common Python project structures 241 | # Skip common non-package directories 242 | skip_dirs = {'src', 'lib', 'source', 'python', 'pkg', 'packages'} 243 | 244 | # Find the first directory that's not in skip_dirs 245 | filtered_parts = [] 246 | for part in path_parts: 247 | if part.lower() not in skip_dirs or filtered_parts: # Once we start including, include everything 248 | filtered_parts.append(part) 249 | 250 | if filtered_parts: 251 | module_name = '.'.join(filtered_parts).replace('.py', '') 252 | return module_name 253 | 254 | # Final fallback: use the default 255 | return default_module 256 | 257 | def _extract_function_parameters(self, func_node): 258 | """Comprehensive parameter extraction from function definition""" 259 | params = [] 260 | 261 | # Regular positional arguments 262 | for i, arg in enumerate(func_node.args.args): 263 | if arg.arg == 'self': 264 | continue 265 | 266 | param_info = { 267 | 'name': arg.arg, 268 | 'type': self._get_name(arg.annotation) if arg.annotation else 'Any', 269 | 'kind': 'positional', 270 | 'optional': False, 271 | 'default': None 272 | } 273 | 274 | # Check if this argument has a default value 275 | defaults_start = len(func_node.args.args) - len(func_node.args.defaults) 276 | if i >= defaults_start: 277 | default_idx = i - defaults_start 278 | if default_idx < len(func_node.args.defaults): 279 | param_info['optional'] = True 280 | param_info['default'] = self._get_default_value(func_node.args.defaults[default_idx]) 281 | 282 | params.append(param_info) 283 | 284 | # *args parameter 285 | if func_node.args.vararg: 286 | params.append({ 287 | 'name': f"*{func_node.args.vararg.arg}", 288 | 'type': 
self._get_name(func_node.args.vararg.annotation) if func_node.args.vararg.annotation else 'Any', 289 | 'kind': 'var_positional', 290 | 'optional': True, 291 | 'default': None 292 | }) 293 | 294 | # Keyword-only arguments (after *) 295 | for i, arg in enumerate(func_node.args.kwonlyargs): 296 | param_info = { 297 | 'name': arg.arg, 298 | 'type': self._get_name(arg.annotation) if arg.annotation else 'Any', 299 | 'kind': 'keyword_only', 300 | 'optional': True, # All kwonly args are optional unless explicitly required 301 | 'default': None 302 | } 303 | 304 | # Check for default value 305 | if i < len(func_node.args.kw_defaults) and func_node.args.kw_defaults[i] is not None: 306 | param_info['default'] = self._get_default_value(func_node.args.kw_defaults[i]) 307 | else: 308 | param_info['optional'] = False # No default = required kwonly arg 309 | 310 | params.append(param_info) 311 | 312 | # **kwargs parameter 313 | if func_node.args.kwarg: 314 | params.append({ 315 | 'name': f"**{func_node.args.kwarg.arg}", 316 | 'type': self._get_name(func_node.args.kwarg.annotation) if func_node.args.kwarg.annotation else 'Dict[str, Any]', 317 | 'kind': 'var_keyword', 318 | 'optional': True, 319 | 'default': None 320 | }) 321 | 322 | return params 323 | 324 | def _get_default_value(self, default_node): 325 | """Extract default value from AST node""" 326 | try: 327 | if isinstance(default_node, ast.Constant): 328 | return repr(default_node.value) 329 | elif isinstance(default_node, ast.Name): 330 | return default_node.id 331 | elif isinstance(default_node, ast.Attribute): 332 | return self._get_name(default_node) 333 | elif isinstance(default_node, ast.List): 334 | return "[]" 335 | elif isinstance(default_node, ast.Dict): 336 | return "{}" 337 | else: 338 | return "..." 339 | except Exception: 340 | return "..." 341 | 342 | def _get_name(self, node): 343 | """Extract name from AST node, handling complex types safely""" 344 | if node is None: 345 | return "Any" 346 | 347 | try: 348 | if isinstance(node, ast.Name): 349 | return node.id 350 | elif isinstance(node, ast.Attribute): 351 | if hasattr(node, 'value'): 352 | return f"{self._get_name(node.value)}.{node.attr}" 353 | else: 354 | return node.attr 355 | elif isinstance(node, ast.Subscript): 356 | # Handle List[Type], Dict[K,V], etc. 
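                    # Illustrative results, assuming Python 3.9+ where the subscript slice is the bare node:
                    #   List[str]       -> "List[str]"
                    #   Dict[str, Any]  -> "Dict[str, Any]"
                    #   Optional[Foo]   -> "Optional[Foo]"
                    # Older AST forms that wrap the slice fall back to the generic handling below (e.g. "List[Any]").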
357 | base = self._get_name(node.value) 358 | if hasattr(node, 'slice'): 359 | if isinstance(node.slice, ast.Name): 360 | return f"{base}[{node.slice.id}]" 361 | elif isinstance(node.slice, ast.Tuple): 362 | elts = [self._get_name(elt) for elt in node.slice.elts] 363 | return f"{base}[{', '.join(elts)}]" 364 | elif isinstance(node.slice, ast.Constant): 365 | return f"{base}[{repr(node.slice.value)}]" 366 | elif isinstance(node.slice, ast.Attribute): 367 | return f"{base}[{self._get_name(node.slice)}]" 368 | elif isinstance(node.slice, ast.Subscript): 369 | return f"{base}[{self._get_name(node.slice)}]" 370 | else: 371 | # Try to get the name of the slice, fallback to Any if it fails 372 | try: 373 | slice_name = self._get_name(node.slice) 374 | return f"{base}[{slice_name}]" 375 | except: 376 | return f"{base}[Any]" 377 | return base 378 | elif isinstance(node, ast.Constant): 379 | return str(node.value) 380 | elif isinstance(node, ast.Str): # Python < 3.8 381 | return f'"{node.s}"' 382 | elif isinstance(node, ast.Tuple): 383 | elts = [self._get_name(elt) for elt in node.elts] 384 | return f"({', '.join(elts)})" 385 | elif isinstance(node, ast.List): 386 | elts = [self._get_name(elt) for elt in node.elts] 387 | return f"[{', '.join(elts)}]" 388 | else: 389 | # Fallback for complex types - return a simple string representation 390 | return "Any" 391 | except Exception: 392 | # If anything goes wrong, return a safe default 393 | return "Any" 394 | 395 | 396 | class DirectNeo4jExtractor: 397 | """Creates nodes and relationships directly in Neo4j""" 398 | 399 | def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str): 400 | self.neo4j_uri = neo4j_uri 401 | self.neo4j_user = neo4j_user 402 | self.neo4j_password = neo4j_password 403 | self.driver = None 404 | self.analyzer = Neo4jCodeAnalyzer() 405 | 406 | async def initialize(self): 407 | """Initialize Neo4j connection""" 408 | logger.info("Initializing Neo4j connection...") 409 | self.driver = AsyncGraphDatabase.driver( 410 | self.neo4j_uri, 411 | auth=(self.neo4j_user, self.neo4j_password) 412 | ) 413 | 414 | # Clear existing data 415 | # logger.info("Clearing existing data...") 416 | # async with self.driver.session() as session: 417 | # await session.run("MATCH (n) DETACH DELETE n") 418 | 419 | # Create constraints and indexes 420 | logger.info("Creating constraints and indexes...") 421 | async with self.driver.session() as session: 422 | # Create constraints - using MERGE-friendly approach 423 | await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (f:File) REQUIRE f.path IS UNIQUE") 424 | await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (c:Class) REQUIRE c.full_name IS UNIQUE") 425 | # Remove unique constraints for methods/attributes since they can be duplicated across classes 426 | # await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (m:Method) REQUIRE m.full_name IS UNIQUE") 427 | # await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (f:Function) REQUIRE f.full_name IS UNIQUE") 428 | # await session.run("CREATE CONSTRAINT IF NOT EXISTS FOR (a:Attribute) REQUIRE a.full_name IS UNIQUE") 429 | 430 | # Create indexes for performance 431 | await session.run("CREATE INDEX IF NOT EXISTS FOR (f:File) ON (f.name)") 432 | await session.run("CREATE INDEX IF NOT EXISTS FOR (c:Class) ON (c.name)") 433 | await session.run("CREATE INDEX IF NOT EXISTS FOR (m:Method) ON (m.name)") 434 | 435 | logger.info("Neo4j initialized successfully") 436 | 437 | async def clear_repository_data(self, repo_name: str): 438 | """Clear all 
data for a specific repository""" 439 | logger.info(f"Clearing existing data for repository: {repo_name}") 440 | async with self.driver.session() as session: 441 | # Delete in specific order to avoid constraint issues 442 | 443 | # 1. Delete methods and attributes (they depend on classes) 444 | await session.run(""" 445 | MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_METHOD]->(m:Method) 446 | DETACH DELETE m 447 | """, repo_name=repo_name) 448 | 449 | await session.run(""" 450 | MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class)-[:HAS_ATTRIBUTE]->(a:Attribute) 451 | DETACH DELETE a 452 | """, repo_name=repo_name) 453 | 454 | # 2. Delete functions (they depend on files) 455 | await session.run(""" 456 | MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(func:Function) 457 | DETACH DELETE func 458 | """, repo_name=repo_name) 459 | 460 | # 3. Delete classes (they depend on files) 461 | await session.run(""" 462 | MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class) 463 | DETACH DELETE c 464 | """, repo_name=repo_name) 465 | 466 | # 4. Delete files (they depend on repository) 467 | await session.run(""" 468 | MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File) 469 | DETACH DELETE f 470 | """, repo_name=repo_name) 471 | 472 | # 5. Finally delete the repository 473 | await session.run(""" 474 | MATCH (r:Repository {name: $repo_name}) 475 | DETACH DELETE r 476 | """, repo_name=repo_name) 477 | 478 | logger.info(f"Cleared data for repository: {repo_name}") 479 | 480 | async def close(self): 481 | """Close Neo4j connection""" 482 | if self.driver: 483 | await self.driver.close() 484 | 485 | def clone_repo(self, repo_url: str, target_dir: str) -> str: 486 | """Clone repository with shallow clone""" 487 | logger.info(f"Cloning repository to: {target_dir}") 488 | if os.path.exists(target_dir): 489 | logger.info(f"Removing existing directory: {target_dir}") 490 | try: 491 | def handle_remove_readonly(func, path, exc): 492 | try: 493 | if os.path.exists(path): 494 | os.chmod(path, 0o777) 495 | func(path) 496 | except PermissionError: 497 | logger.warning(f"Could not remove {path} - file in use, skipping") 498 | pass 499 | shutil.rmtree(target_dir, onerror=handle_remove_readonly) 500 | except Exception as e: 501 | logger.warning(f"Could not fully remove {target_dir}: {e}. 
Proceeding anyway...") 502 | 503 | logger.info(f"Running git clone from {repo_url}") 504 | subprocess.run(['git', 'clone', '--depth', '1', repo_url, target_dir], check=True) 505 | logger.info("Repository cloned successfully") 506 | return target_dir 507 | 508 | def get_python_files(self, repo_path: str) -> List[Path]: 509 | """Get Python files, focusing on main source directories""" 510 | python_files = [] 511 | exclude_dirs = { 512 | 'tests', 'test', '__pycache__', '.git', 'venv', 'env', 513 | 'node_modules', 'build', 'dist', '.pytest_cache', 'docs', 514 | 'examples', 'example', 'demo', 'benchmark' 515 | } 516 | 517 | for root, dirs, files in os.walk(repo_path): 518 | dirs[:] = [d for d in dirs if d not in exclude_dirs and not d.startswith('.')] 519 | 520 | for file in files: 521 | if file.endswith('.py') and not file.startswith('test_'): 522 | file_path = Path(root) / file 523 | if (file_path.stat().st_size < 500_000 and 524 | file not in ['setup.py', 'conftest.py']): 525 | python_files.append(file_path) 526 | 527 | return python_files 528 | 529 | async def analyze_repository(self, repo_url: str, temp_dir: str = None): 530 | """Analyze repository and create nodes/relationships in Neo4j""" 531 | repo_name = repo_url.split('/')[-1].replace('.git', '') 532 | logger.info(f"Analyzing repository: {repo_name}") 533 | 534 | # Clear existing data for this repository before re-processing 535 | await self.clear_repository_data(repo_name) 536 | 537 | # Set default temp_dir to repos folder at script level 538 | if temp_dir is None: 539 | script_dir = Path(__file__).parent 540 | temp_dir = str(script_dir / "repos" / repo_name) 541 | 542 | # Clone and analyze 543 | repo_path = Path(self.clone_repo(repo_url, temp_dir)) 544 | 545 | try: 546 | logger.info("Getting Python files...") 547 | python_files = self.get_python_files(str(repo_path)) 548 | logger.info(f"Found {len(python_files)} Python files to analyze") 549 | 550 | # First pass: identify project modules 551 | logger.info("Identifying project modules...") 552 | project_modules = set() 553 | for file_path in python_files: 554 | relative_path = str(file_path.relative_to(repo_path)) 555 | module_parts = relative_path.replace('/', '.').replace('.py', '').split('.') 556 | if len(module_parts) > 0 and not module_parts[0].startswith('.'): 557 | project_modules.add(module_parts[0]) 558 | 559 | logger.info(f"Identified project modules: {sorted(project_modules)}") 560 | 561 | # Second pass: analyze files and collect data 562 | logger.info("Analyzing Python files...") 563 | modules_data = [] 564 | for i, file_path in enumerate(python_files): 565 | if i % 20 == 0: 566 | logger.info(f"Analyzing file {i+1}/{len(python_files)}: {file_path.name}") 567 | 568 | analysis = self.analyzer.analyze_python_file(file_path, repo_path, project_modules) 569 | if analysis: 570 | modules_data.append(analysis) 571 | 572 | logger.info(f"Found {len(modules_data)} files with content") 573 | 574 | # Create nodes and relationships in Neo4j 575 | logger.info("Creating nodes and relationships in Neo4j...") 576 | await self._create_graph(repo_name, modules_data) 577 | 578 | # Print summary 579 | total_classes = sum(len(mod['classes']) for mod in modules_data) 580 | total_methods = sum(len(cls['methods']) for mod in modules_data for cls in mod['classes']) 581 | total_functions = sum(len(mod['functions']) for mod in modules_data) 582 | total_imports = sum(len(mod['imports']) for mod in modules_data) 583 | 584 | print(f"\\n=== Direct Neo4j Repository Analysis for {repo_name} ===") 585 | 
print(f"Files processed: {len(modules_data)}") 586 | print(f"Classes created: {total_classes}") 587 | print(f"Methods created: {total_methods}") 588 | print(f"Functions created: {total_functions}") 589 | print(f"Import relationships: {total_imports}") 590 | 591 | logger.info(f"Successfully created Neo4j graph for {repo_name}") 592 | 593 | finally: 594 | if os.path.exists(temp_dir): 595 | logger.info(f"Cleaning up temporary directory: {temp_dir}") 596 | try: 597 | def handle_remove_readonly(func, path, exc): 598 | try: 599 | if os.path.exists(path): 600 | os.chmod(path, 0o777) 601 | func(path) 602 | except PermissionError: 603 | logger.warning(f"Could not remove {path} - file in use, skipping") 604 | pass 605 | 606 | shutil.rmtree(temp_dir, onerror=handle_remove_readonly) 607 | logger.info("Cleanup completed") 608 | except Exception as e: 609 | logger.warning(f"Cleanup failed: {e}. Directory may remain at {temp_dir}") 610 | # Don't fail the whole process due to cleanup issues 611 | 612 | async def _create_graph(self, repo_name: str, modules_data: List[Dict]): 613 | """Create all nodes and relationships in Neo4j""" 614 | 615 | async with self.driver.session() as session: 616 | # Create Repository node 617 | await session.run( 618 | "CREATE (r:Repository {name: $repo_name, created_at: datetime()})", 619 | repo_name=repo_name 620 | ) 621 | 622 | nodes_created = 0 623 | relationships_created = 0 624 | 625 | for i, mod in enumerate(modules_data): 626 | # 1. Create File node 627 | await session.run(""" 628 | CREATE (f:File { 629 | name: $name, 630 | path: $path, 631 | module_name: $module_name, 632 | line_count: $line_count, 633 | created_at: datetime() 634 | }) 635 | """, 636 | name=mod['file_path'].split('/')[-1], 637 | path=mod['file_path'], 638 | module_name=mod['module_name'], 639 | line_count=mod['line_count'] 640 | ) 641 | nodes_created += 1 642 | 643 | # 2. Connect File to Repository 644 | await session.run(""" 645 | MATCH (r:Repository {name: $repo_name}) 646 | MATCH (f:File {path: $file_path}) 647 | CREATE (r)-[:CONTAINS]->(f) 648 | """, repo_name=repo_name, file_path=mod['file_path']) 649 | relationships_created += 1 650 | 651 | # 3. Create Class nodes and relationships 652 | for cls in mod['classes']: 653 | # Create Class node using MERGE to avoid duplicates 654 | await session.run(""" 655 | MERGE (c:Class {full_name: $full_name}) 656 | ON CREATE SET c.name = $name, c.created_at = datetime() 657 | """, name=cls['name'], full_name=cls['full_name']) 658 | nodes_created += 1 659 | 660 | # Connect File to Class 661 | await session.run(""" 662 | MATCH (f:File {path: $file_path}) 663 | MATCH (c:Class {full_name: $class_full_name}) 664 | MERGE (f)-[:DEFINES]->(c) 665 | """, file_path=mod['file_path'], class_full_name=cls['full_name']) 666 | relationships_created += 1 667 | 668 | # 4. 
Create Method nodes - use MERGE to avoid duplicates 669 | for method in cls['methods']: 670 | method_full_name = f"{cls['full_name']}.{method['name']}" 671 | # Create method with unique ID to avoid conflicts 672 | method_id = f"{cls['full_name']}::{method['name']}" 673 | 674 | await session.run(""" 675 | MERGE (m:Method {method_id: $method_id}) 676 | ON CREATE SET m.name = $name, 677 | m.full_name = $full_name, 678 | m.args = $args, 679 | m.params_list = $params_list, 680 | m.params_detailed = $params_detailed, 681 | m.return_type = $return_type, 682 | m.created_at = datetime() 683 | """, 684 | name=method['name'], 685 | full_name=method_full_name, 686 | method_id=method_id, 687 | args=method['args'], 688 | params_list=[f"{p['name']}:{p['type']}" for p in method['params']], # Simple format 689 | params_detailed=method.get('params_detailed', []), # Detailed format 690 | return_type=method['return_type'] 691 | ) 692 | nodes_created += 1 693 | 694 | # Connect Class to Method 695 | await session.run(""" 696 | MATCH (c:Class {full_name: $class_full_name}) 697 | MATCH (m:Method {method_id: $method_id}) 698 | MERGE (c)-[:HAS_METHOD]->(m) 699 | """, 700 | class_full_name=cls['full_name'], 701 | method_id=method_id 702 | ) 703 | relationships_created += 1 704 | 705 | # 5. Create Attribute nodes - use MERGE to avoid duplicates 706 | for attr in cls['attributes']: 707 | attr_full_name = f"{cls['full_name']}.{attr['name']}" 708 | # Create attribute with unique ID to avoid conflicts 709 | attr_id = f"{cls['full_name']}::{attr['name']}" 710 | await session.run(""" 711 | MERGE (a:Attribute {attr_id: $attr_id}) 712 | ON CREATE SET a.name = $name, 713 | a.full_name = $full_name, 714 | a.type = $type, 715 | a.created_at = datetime() 716 | """, 717 | name=attr['name'], 718 | full_name=attr_full_name, 719 | attr_id=attr_id, 720 | type=attr['type'] 721 | ) 722 | nodes_created += 1 723 | 724 | # Connect Class to Attribute 725 | await session.run(""" 726 | MATCH (c:Class {full_name: $class_full_name}) 727 | MATCH (a:Attribute {attr_id: $attr_id}) 728 | MERGE (c)-[:HAS_ATTRIBUTE]->(a) 729 | """, 730 | class_full_name=cls['full_name'], 731 | attr_id=attr_id 732 | ) 733 | relationships_created += 1 734 | 735 | # 6. Create Function nodes (top-level) - use MERGE to avoid duplicates 736 | for func in mod['functions']: 737 | func_id = f"{mod['file_path']}::{func['name']}" 738 | await session.run(""" 739 | MERGE (f:Function {func_id: $func_id}) 740 | ON CREATE SET f.name = $name, 741 | f.full_name = $full_name, 742 | f.args = $args, 743 | f.params_list = $params_list, 744 | f.params_detailed = $params_detailed, 745 | f.return_type = $return_type, 746 | f.created_at = datetime() 747 | """, 748 | name=func['name'], 749 | full_name=func['full_name'], 750 | func_id=func_id, 751 | args=func['args'], 752 | params_list=func.get('params_list', []), # Simple format for backwards compatibility 753 | params_detailed=func.get('params_detailed', []), # Detailed format 754 | return_type=func['return_type'] 755 | ) 756 | nodes_created += 1 757 | 758 | # Connect File to Function 759 | await session.run(""" 760 | MATCH (file:File {path: $file_path}) 761 | MATCH (func:Function {func_id: $func_id}) 762 | MERGE (file)-[:DEFINES]->(func) 763 | """, file_path=mod['file_path'], func_id=func_id) 764 | relationships_created += 1 765 | 766 | # 7. 
Create Import relationships 767 | for import_name in mod['imports']: 768 | # Try to find the target file 769 | await session.run(""" 770 | MATCH (source:File {path: $source_path}) 771 | OPTIONAL MATCH (target:File) 772 | WHERE target.module_name = $import_name OR target.module_name STARTS WITH $import_name 773 | WITH source, target 774 | WHERE target IS NOT NULL 775 | MERGE (source)-[:IMPORTS]->(target) 776 | """, source_path=mod['file_path'], import_name=import_name) 777 | relationships_created += 1 778 | 779 | if (i + 1) % 10 == 0: 780 | logger.info(f"Processed {i + 1}/{len(modules_data)} files...") 781 | 782 | logger.info(f"Created {nodes_created} nodes and {relationships_created} relationships") 783 | 784 | async def search_graph(self, query_type: str, **kwargs): 785 | """Search the Neo4j graph directly""" 786 | async with self.driver.session() as session: 787 | if query_type == "files_importing": 788 | target = kwargs.get('target') 789 | result = await session.run(""" 790 | MATCH (source:File)-[:IMPORTS]->(target:File) 791 | WHERE target.module_name CONTAINS $target 792 | RETURN source.path as file, target.module_name as imports 793 | """, target=target) 794 | return [{"file": record["file"], "imports": record["imports"]} async for record in result] 795 | 796 | elif query_type == "classes_in_file": 797 | file_path = kwargs.get('file_path') 798 | result = await session.run(""" 799 | MATCH (f:File {path: $file_path})-[:DEFINES]->(c:Class) 800 | RETURN c.name as class_name, c.full_name as full_name 801 | """, file_path=file_path) 802 | return [{"class_name": record["class_name"], "full_name": record["full_name"]} async for record in result] 803 | 804 | elif query_type == "methods_of_class": 805 | class_name = kwargs.get('class_name') 806 | result = await session.run(""" 807 | MATCH (c:Class)-[:HAS_METHOD]->(m:Method) 808 | WHERE c.name CONTAINS $class_name OR c.full_name CONTAINS $class_name 809 | RETURN m.name as method_name, m.args as args 810 | """, class_name=class_name) 811 | return [{"method_name": record["method_name"], "args": record["args"]} async for record in result] 812 | 813 | 814 | async def main(): 815 | """Example usage""" 816 | load_dotenv() 817 | 818 | neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687') 819 | neo4j_user = os.environ.get('NEO4J_USER', 'neo4j') 820 | neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password') 821 | 822 | extractor = DirectNeo4jExtractor(neo4j_uri, neo4j_user, neo4j_password) 823 | 824 | try: 825 | await extractor.initialize() 826 | 827 | # Analyze repository - direct Neo4j, no LLM processing! 828 | # repo_url = "https://github.com/pydantic/pydantic-ai.git" 829 | repo_url = "https://github.com/getzep/graphiti.git" 830 | await extractor.analyze_repository(repo_url) 831 | 832 | # Direct graph queries 833 | print("\\n=== Direct Neo4j Queries ===") 834 | 835 | # Which files import from models? 836 | results = await extractor.search_graph("files_importing", target="models") 837 | print(f"\\nFiles importing from 'models': {len(results)}") 838 | for result in results[:3]: 839 | print(f"- {result['file']} imports {result['imports']}") 840 | 841 | # What classes are in a specific file? 842 | results = await extractor.search_graph("classes_in_file", file_path="pydantic_ai/models/openai.py") 843 | print(f"\\nClasses in openai.py: {len(results)}") 844 | for result in results: 845 | print(f"- {result['class_name']}") 846 | 847 | # What methods does OpenAIModel have? 
848 | results = await extractor.search_graph("methods_of_class", class_name="OpenAIModel") 849 | print(f"\\nMethods of OpenAIModel: {len(results)}") 850 | for result in results[:5]: 851 | print(f"- {result['method_name']}({', '.join(result['args'])})") 852 | 853 | finally: 854 | await extractor.close() 855 | 856 | 857 | if __name__ == "__main__": 858 | asyncio.run(main()) -------------------------------------------------------------------------------- /knowledge_graphs/query_knowledge_graph.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Knowledge Graph Query Tool 4 | 5 | Interactive script to explore what's actually stored in your Neo4j knowledge graph. 6 | Useful for debugging hallucination detection and understanding graph contents. 7 | """ 8 | 9 | import asyncio 10 | import os 11 | from dotenv import load_dotenv 12 | from neo4j import AsyncGraphDatabase 13 | from typing import List, Dict, Any 14 | import argparse 15 | 16 | 17 | class KnowledgeGraphQuerier: 18 | """Interactive tool to query the knowledge graph""" 19 | 20 | def __init__(self, neo4j_uri: str, neo4j_user: str, neo4j_password: str): 21 | self.neo4j_uri = neo4j_uri 22 | self.neo4j_user = neo4j_user 23 | self.neo4j_password = neo4j_password 24 | self.driver = None 25 | 26 | async def initialize(self): 27 | """Initialize Neo4j connection""" 28 | self.driver = AsyncGraphDatabase.driver( 29 | self.neo4j_uri, 30 | auth=(self.neo4j_user, self.neo4j_password) 31 | ) 32 | print("🔗 Connected to Neo4j knowledge graph") 33 | 34 | async def close(self): 35 | """Close Neo4j connection""" 36 | if self.driver: 37 | await self.driver.close() 38 | 39 | async def list_repositories(self): 40 | """List all repositories in the knowledge graph""" 41 | print("\n📚 Repositories in Knowledge Graph:") 42 | print("=" * 50) 43 | 44 | async with self.driver.session() as session: 45 | query = "MATCH (r:Repository) RETURN r.name as name ORDER BY r.name" 46 | result = await session.run(query) 47 | 48 | repos = [] 49 | async for record in result: 50 | repos.append(record['name']) 51 | 52 | if repos: 53 | for i, repo in enumerate(repos, 1): 54 | print(f"{i}. 
{repo}") 55 | else: 56 | print("No repositories found in knowledge graph.") 57 | 58 | return repos 59 | 60 | async def explore_repository(self, repo_name: str): 61 | """Get overview of a specific repository""" 62 | print(f"\n🔍 Exploring Repository: {repo_name}") 63 | print("=" * 60) 64 | 65 | async with self.driver.session() as session: 66 | # Get file count 67 | files_query = """ 68 | MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File) 69 | RETURN count(f) as file_count 70 | """ 71 | result = await session.run(files_query, repo_name=repo_name) 72 | file_count = (await result.single())['file_count'] 73 | 74 | # Get class count 75 | classes_query = """ 76 | MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class) 77 | RETURN count(DISTINCT c) as class_count 78 | """ 79 | result = await session.run(classes_query, repo_name=repo_name) 80 | class_count = (await result.single())['class_count'] 81 | 82 | # Get function count 83 | functions_query = """ 84 | MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(func:Function) 85 | RETURN count(DISTINCT func) as function_count 86 | """ 87 | result = await session.run(functions_query, repo_name=repo_name) 88 | function_count = (await result.single())['function_count'] 89 | 90 | print(f"📄 Files: {file_count}") 91 | print(f"🏗️ Classes: {class_count}") 92 | print(f"⚙️ Functions: {function_count}") 93 | 94 | async def list_classes(self, repo_name: str = None, limit: int = 20): 95 | """List classes in the knowledge graph""" 96 | title = f"Classes in {repo_name}" if repo_name else "All Classes" 97 | print(f"\n🏗️ {title} (limit {limit}):") 98 | print("=" * 50) 99 | 100 | async with self.driver.session() as session: 101 | if repo_name: 102 | query = """ 103 | MATCH (r:Repository {name: $repo_name})-[:CONTAINS]->(f:File)-[:DEFINES]->(c:Class) 104 | RETURN c.name as name, c.full_name as full_name 105 | ORDER BY c.name 106 | LIMIT $limit 107 | """ 108 | result = await session.run(query, repo_name=repo_name, limit=limit) 109 | else: 110 | query = """ 111 | MATCH (c:Class) 112 | RETURN c.name as name, c.full_name as full_name 113 | ORDER BY c.name 114 | LIMIT $limit 115 | """ 116 | result = await session.run(query, limit=limit) 117 | 118 | classes = [] 119 | async for record in result: 120 | classes.append({ 121 | 'name': record['name'], 122 | 'full_name': record['full_name'] 123 | }) 124 | 125 | if classes: 126 | for i, cls in enumerate(classes, 1): 127 | print(f"{i:2d}. 
{cls['name']} ({cls['full_name']})") 128 | else: 129 | print("No classes found.") 130 | 131 | return classes 132 | 133 | async def explore_class(self, class_name: str): 134 | """Get detailed information about a specific class""" 135 | print(f"\n🔍 Exploring Class: {class_name}") 136 | print("=" * 60) 137 | 138 | async with self.driver.session() as session: 139 | # Find the class 140 | class_query = """ 141 | MATCH (c:Class) 142 | WHERE c.name = $class_name OR c.full_name = $class_name 143 | RETURN c.name as name, c.full_name as full_name 144 | LIMIT 1 145 | """ 146 | result = await session.run(class_query, class_name=class_name) 147 | class_record = await result.single() 148 | 149 | if not class_record: 150 | print(f"❌ Class '{class_name}' not found in knowledge graph.") 151 | return None 152 | 153 | actual_name = class_record['name'] 154 | full_name = class_record['full_name'] 155 | 156 | print(f"📋 Name: {actual_name}") 157 | print(f"📋 Full Name: {full_name}") 158 | 159 | # Get methods 160 | methods_query = """ 161 | MATCH (c:Class)-[:HAS_METHOD]->(m:Method) 162 | WHERE c.name = $class_name OR c.full_name = $class_name 163 | RETURN m.name as name, m.params_list as params_list, m.params_detailed as params_detailed, m.return_type as return_type 164 | ORDER BY m.name 165 | """ 166 | result = await session.run(methods_query, class_name=class_name) 167 | 168 | methods = [] 169 | async for record in result: 170 | methods.append({ 171 | 'name': record['name'], 172 | 'params_list': record['params_list'] or [], 173 | 'params_detailed': record['params_detailed'] or [], 174 | 'return_type': record['return_type'] or 'Any' 175 | }) 176 | 177 | if methods: 178 | print(f"\n⚙️ Methods ({len(methods)}):") 179 | for i, method in enumerate(methods, 1): 180 | # Use detailed params if available, fall back to simple params 181 | params_to_show = method['params_detailed'] or method['params_list'] 182 | params = ', '.join(params_to_show) if params_to_show else '' 183 | print(f"{i:2d}. {method['name']}({params}) -> {method['return_type']}") 184 | else: 185 | print("\n⚙️ No methods found.") 186 | 187 | # Get attributes 188 | attributes_query = """ 189 | MATCH (c:Class)-[:HAS_ATTRIBUTE]->(a:Attribute) 190 | WHERE c.name = $class_name OR c.full_name = $class_name 191 | RETURN a.name as name, a.type as type 192 | ORDER BY a.name 193 | """ 194 | result = await session.run(attributes_query, class_name=class_name) 195 | 196 | attributes = [] 197 | async for record in result: 198 | attributes.append({ 199 | 'name': record['name'], 200 | 'type': record['type'] or 'Any' 201 | }) 202 | 203 | if attributes: 204 | print(f"\n📋 Attributes ({len(attributes)}):") 205 | for i, attr in enumerate(attributes, 1): 206 | print(f"{i:2d}. 
{attr['name']}: {attr['type']}") 207 | else: 208 | print("\n📋 No attributes found.") 209 | 210 | return {'methods': methods, 'attributes': attributes} 211 | 212 | async def search_method(self, method_name: str, class_name: str = None): 213 | """Search for methods by name""" 214 | title = f"Method '{method_name}'" 215 | if class_name: 216 | title += f" in class '{class_name}'" 217 | 218 | print(f"\n🔍 Searching for {title}:") 219 | print("=" * 60) 220 | 221 | async with self.driver.session() as session: 222 | if class_name: 223 | query = """ 224 | MATCH (c:Class)-[:HAS_METHOD]->(m:Method) 225 | WHERE (c.name = $class_name OR c.full_name = $class_name) 226 | AND m.name = $method_name 227 | RETURN c.name as class_name, c.full_name as class_full_name, 228 | m.name as method_name, m.params_list as params_list, 229 | m.return_type as return_type, m.args as args 230 | """ 231 | result = await session.run(query, class_name=class_name, method_name=method_name) 232 | else: 233 | query = """ 234 | MATCH (c:Class)-[:HAS_METHOD]->(m:Method) 235 | WHERE m.name = $method_name 236 | RETURN c.name as class_name, c.full_name as class_full_name, 237 | m.name as method_name, m.params_list as params_list, 238 | m.return_type as return_type, m.args as args 239 | ORDER BY c.name 240 | """ 241 | result = await session.run(query, method_name=method_name) 242 | 243 | methods = [] 244 | async for record in result: 245 | methods.append({ 246 | 'class_name': record['class_name'], 247 | 'class_full_name': record['class_full_name'], 248 | 'method_name': record['method_name'], 249 | 'params_list': record['params_list'] or [], 250 | 'return_type': record['return_type'] or 'Any', 251 | 'args': record['args'] or [] 252 | }) 253 | 254 | if methods: 255 | for i, method in enumerate(methods, 1): 256 | params = ', '.join(method['params_list']) if method['params_list'] else '' 257 | print(f"{i}. {method['class_full_name']}.{method['method_name']}({params}) -> {method['return_type']}") 258 | if method['args']: 259 | print(f" Legacy args: {method['args']}") 260 | else: 261 | print(f"❌ Method '{method_name}' not found.") 262 | 263 | return methods 264 | 265 | async def run_custom_query(self, query: str): 266 | """Run a custom Cypher query""" 267 | print(f"\n🔍 Running Custom Query:") 268 | print("=" * 60) 269 | print(f"Query: {query}") 270 | print("-" * 60) 271 | 272 | async with self.driver.session() as session: 273 | try: 274 | result = await session.run(query) 275 | 276 | records = [] 277 | async for record in result: 278 | records.append(dict(record)) 279 | 280 | if records: 281 | for i, record in enumerate(records, 1): 282 | print(f"{i:2d}. {record}") 283 | if i >= 20: # Limit output 284 | print(f"... 
and {len(records) - 20} more records") 285 | break 286 | else: 287 | print("No results found.") 288 | 289 | return records 290 | 291 | except Exception as e: 292 | print(f"❌ Query error: {str(e)}") 293 | return None 294 | 295 | 296 | async def interactive_mode(querier: KnowledgeGraphQuerier): 297 | """Interactive exploration mode""" 298 | print("\n🚀 Welcome to Knowledge Graph Explorer!") 299 | print("Available commands:") 300 | print(" repos - List all repositories") 301 | print(" explore <repo> - Explore a specific repository") 302 | print(" classes [repo] - List classes (optionally in specific repo)") 303 | print(" class <name> - Explore a specific class") 304 | print(" method <name> [class] - Search for method") 305 | print(" query <cypher> - Run custom Cypher query") 306 | print(" quit - Exit") 307 | print() 308 | 309 | while True: 310 | try: 311 | command = input("🔍 > ").strip() 312 | 313 | if not command: 314 | continue 315 | elif command == "quit": 316 | break 317 | elif command == "repos": 318 | await querier.list_repositories() 319 | elif command.startswith("explore "): 320 | repo_name = command[8:].strip() 321 | await querier.explore_repository(repo_name) 322 | elif command == "classes": 323 | await querier.list_classes() 324 | elif command.startswith("classes "): 325 | repo_name = command[8:].strip() 326 | await querier.list_classes(repo_name) 327 | elif command.startswith("class "): 328 | class_name = command[6:].strip() 329 | await querier.explore_class(class_name) 330 | elif command.startswith("method "): 331 | parts = command[7:].strip().split() 332 | if len(parts) >= 2: 333 | await querier.search_method(parts[0], parts[1]) 334 | else: 335 | await querier.search_method(parts[0]) 336 | elif command.startswith("query "): 337 | query = command[6:].strip() 338 | await querier.run_custom_query(query) 339 | else: 340 | print("❌ Unknown command. 
Type 'quit' to exit.") 341 | 342 | except KeyboardInterrupt: 343 | print("\n👋 Goodbye!") 344 | break 345 | except Exception as e: 346 | print(f"❌ Error: {str(e)}") 347 | 348 | 349 | async def main(): 350 | """Main function with CLI argument support""" 351 | parser = argparse.ArgumentParser(description="Query the knowledge graph") 352 | parser.add_argument('--repos', action='store_true', help='List repositories') 353 | parser.add_argument('--classes', metavar='REPO', nargs='?', const='', help='List classes') 354 | parser.add_argument('--explore', metavar='REPO', help='Explore repository') 355 | parser.add_argument('--class', dest='class_name', metavar='NAME', help='Explore class') 356 | parser.add_argument('--method', nargs='+', metavar=('NAME', 'CLASS'), help='Search method') 357 | parser.add_argument('--query', metavar='CYPHER', help='Run custom query') 358 | parser.add_argument('--interactive', action='store_true', help='Interactive mode') 359 | 360 | args = parser.parse_args() 361 | 362 | # Load environment 363 | load_dotenv() 364 | neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687') 365 | neo4j_user = os.environ.get('NEO4J_USER', 'neo4j') 366 | neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password') 367 | 368 | querier = KnowledgeGraphQuerier(neo4j_uri, neo4j_user, neo4j_password) 369 | 370 | try: 371 | await querier.initialize() 372 | 373 | # Execute commands based on arguments 374 | if args.repos: 375 | await querier.list_repositories() 376 | elif args.classes is not None: 377 | await querier.list_classes(args.classes if args.classes else None) 378 | elif args.explore: 379 | await querier.explore_repository(args.explore) 380 | elif args.class_name: 381 | await querier.explore_class(args.class_name) 382 | elif args.method: 383 | if len(args.method) >= 2: 384 | await querier.search_method(args.method[0], args.method[1]) 385 | else: 386 | await querier.search_method(args.method[0]) 387 | elif args.query: 388 | await querier.run_custom_query(args.query) 389 | elif args.interactive or len(sys.argv) == 1: 390 | await interactive_mode(querier) 391 | else: 392 | parser.print_help() 393 | 394 | finally: 395 | await querier.close() 396 | 397 | 398 | if __name__ == "__main__": 399 | import sys 400 | asyncio.run(main()) -------------------------------------------------------------------------------- /knowledge_graphs/test_script.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Dict, List, Optional 3 | from dataclasses import dataclass 4 | from pydantic import BaseModel, Field 5 | from dotenv import load_dotenv 6 | from rich.markdown import Markdown 7 | from rich.console import Console 8 | from rich.live import Live 9 | import asyncio 10 | import os 11 | 12 | from pydantic_ai.providers.openai import OpenAIProvider 13 | from pydantic_ai.models.openai import OpenAIModel 14 | from pydantic_ai import Agent, RunContext 15 | from graphiti_core import Graphiti 16 | 17 | load_dotenv() 18 | 19 | # ========== Define dependencies ========== 20 | @dataclass 21 | class GraphitiDependencies: 22 | """Dependencies for the Graphiti agent.""" 23 | graphiti_client: Graphiti 24 | 25 | # ========== Helper function to get model configuration ========== 26 | def get_model(): 27 | """Configure and return the LLM model to use.""" 28 | model_choice = os.getenv('MODEL_CHOICE', 'gpt-4.1-mini') 29 | api_key = os.getenv('OPENAI_API_KEY', 'no-api-key-provided') 30 | 31 | return OpenAIModel(model_choice, 
provider=OpenAIProvider(api_key=api_key)) 32 | 33 | # ========== Create the Graphiti agent ========== 34 | graphiti_agent = Agent( 35 | get_model(), 36 | system_prompt="""You are a helpful assistant with access to a knowledge graph filled with temporal data about LLMs. 37 | When the user asks you a question, use your search tool to query the knowledge graph and then answer honestly. 38 | Be willing to admit when you didn't find the information necessary to answer the question.""", 39 | deps_type=GraphitiDependencies 40 | ) 41 | 42 | # ========== Define a result model for Graphiti search ========== 43 | class GraphitiSearchResult(BaseModel): 44 | """Model representing a search result from Graphiti.""" 45 | uuid: str = Field(description="The unique identifier for this fact") 46 | fact: str = Field(description="The factual statement retrieved from the knowledge graph") 47 | valid_at: Optional[str] = Field(None, description="When this fact became valid (if known)") 48 | invalid_at: Optional[str] = Field(None, description="When this fact became invalid (if known)") 49 | source_node_uuid: Optional[str] = Field(None, description="UUID of the source node") 50 | 51 | # ========== Graphiti search tool ========== 52 | @graphiti_agent.tool 53 | async def search_graphiti(ctx: RunContext[GraphitiDependencies], query: str) -> List[GraphitiSearchResult]: 54 | """Search the Graphiti knowledge graph with the given query. 55 | 56 | Args: 57 | ctx: The run context containing dependencies 58 | query: The search query to find information in the knowledge graph 59 | 60 | Returns: 61 | A list of search results containing facts that match the query 62 | """ 63 | # Access the Graphiti client from dependencies 64 | graphiti = ctx.deps.graphiti_client 65 | 66 | try: 67 | # Perform the search 68 | results = await graphiti.search(query) 69 | 70 | # Format the results 71 | formatted_results = [] 72 | for result in results: 73 | formatted_result = GraphitiSearchResult( 74 | uuid=result.uuid, 75 | fact=result.fact, 76 | source_node_uuid=result.source_node_uuid if hasattr(result, 'source_node_uuid') else None 77 | ) 78 | 79 | # Add temporal information if available 80 | if hasattr(result, 'valid_at') and result.valid_at: 81 | formatted_result.valid_at = str(result.valid_at) 82 | if hasattr(result, 'invalid_at') and result.invalid_at: 83 | formatted_result.invalid_at = str(result.invalid_at) 84 | 85 | formatted_results.append(formatted_result) 86 | 87 | return formatted_results 88 | except Exception as e: 89 | # Log the error but don't close the connection since it's managed by the dependency 90 | print(f"Error searching Graphiti: {str(e)}") 91 | raise 92 | 93 | # ========== Main execution function ========== 94 | async def main(): 95 | """Run the Graphiti agent with user queries.""" 96 | print("Graphiti Agent - Powered by Pydantic AI, Graphiti, and Neo4j") 97 | print("Enter 'exit' to quit the program.") 98 | 99 | # Neo4j connection parameters 100 | neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687') 101 | neo4j_user = os.environ.get('NEO4J_USER', 'neo4j') 102 | neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password') 103 | 104 | # Initialize Graphiti with Neo4j connection 105 | graphiti_client = Graphiti(neo4j_uri, neo4j_user, neo4j_password) 106 | 107 | # Initialize the graph database with graphiti's indices if needed 108 | try: 109 | await graphiti_client.build_indices_and_constraints() 110 | print("Graphiti indices built successfully.") 111 | except Exception as e: 112 | print(f"Note: {str(e)}") 113 | 
print("Continuing with existing indices...") 114 | 115 | console = Console() 116 | messages = [] 117 | 118 | try: 119 | while True: 120 | # Get user input 121 | user_input = input("\n[You] ") 122 | 123 | # Check if user wants to exit 124 | if user_input.lower() in ['exit', 'quit', 'bye', 'goodbye']: 125 | print("Goodbye!") 126 | break 127 | 128 | try: 129 | # Process the user input and output the response 130 | print("\n[Assistant]") 131 | with Live('', console=console, vertical_overflow='visible') as live: 132 | # Pass the Graphiti client as a dependency 133 | deps = GraphitiDependencies(graphiti_client=graphiti_client) 134 | 135 | async with graphiti_agent.run_a_stream( 136 | user_input, message_history=messages, deps=deps 137 | ) as result: 138 | curr_message = "" 139 | async for message in result.stream_text(delta=True): 140 | curr_message += message 141 | live.update(Markdown(curr_message)) 142 | 143 | # Add the new messages to the chat history 144 | messages.extend(result.all_messages()) 145 | 146 | except Exception as e: 147 | print(f"\n[Error] An error occurred: {str(e)}") 148 | finally: 149 | # Close the Graphiti connection when done 150 | await graphiti_client.close() 151 | print("\nGraphiti connection closed.") 152 | 153 | if __name__ == "__main__": 154 | try: 155 | asyncio.run(main()) 156 | except KeyboardInterrupt: 157 | print("\nProgram terminated by user.") 158 | except Exception as e: 159 | print(f"\nUnexpected error: {str(e)}") 160 | raise 161 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "crawl4ai-mcp" 3 | version = "0.1.0" 4 | description = "MCP server for integrating web crawling and RAG into AI agents and AI coding assistants" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "crawl4ai==0.6.2", 9 | "mcp==1.7.1", 10 | "supabase==2.15.1", 11 | "openai==1.71.0", 12 | "dotenv==0.9.9", 13 | "sentence-transformers>=4.1.0", 14 | "neo4j>=5.28.1", 15 | ] 16 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for the Crawl4AI MCP server. 3 | """ 4 | import os 5 | import concurrent.futures 6 | from typing import List, Dict, Any, Optional, Tuple 7 | import json 8 | from supabase import create_client, Client 9 | from urllib.parse import urlparse 10 | import openai 11 | import re 12 | import time 13 | 14 | # Load OpenAI API key for embeddings 15 | openai.api_key = os.getenv("OPENAI_API_KEY") 16 | 17 | def get_supabase_client() -> Client: 18 | """ 19 | Get a Supabase client with the URL and key from environment variables. 20 | 21 | Returns: 22 | Supabase client instance 23 | """ 24 | url = os.getenv("SUPABASE_URL") 25 | key = os.getenv("SUPABASE_SERVICE_KEY") 26 | 27 | if not url or not key: 28 | raise ValueError("SUPABASE_URL and SUPABASE_SERVICE_KEY must be set in environment variables") 29 | 30 | return create_client(url, key) 31 | 32 | def create_embeddings_batch(texts: List[str]) -> List[List[float]]: 33 | """ 34 | Create embeddings for multiple texts in a single API call. 
35 | 36 | Args: 37 | texts: List of texts to create embeddings for 38 | 39 | Returns: 40 | List of embeddings (each embedding is a list of floats) 41 | """ 42 | if not texts: 43 | return [] 44 | 45 | max_retries = 3 46 | retry_delay = 1.0 # Start with 1 second delay 47 | 48 | for retry in range(max_retries): 49 | try: 50 | response = openai.embeddings.create( 51 | model="text-embedding-3-small", # Hardcoding embedding model for now, will change this later to be more dynamic 52 | input=texts 53 | ) 54 | return [item.embedding for item in response.data] 55 | except Exception as e: 56 | if retry < max_retries - 1: 57 | print(f"Error creating batch embeddings (attempt {retry + 1}/{max_retries}): {e}") 58 | print(f"Retrying in {retry_delay} seconds...") 59 | time.sleep(retry_delay) 60 | retry_delay *= 2 # Exponential backoff 61 | else: 62 | print(f"Failed to create batch embeddings after {max_retries} attempts: {e}") 63 | # Try creating embeddings one by one as fallback 64 | print("Attempting to create embeddings individually...") 65 | embeddings = [] 66 | successful_count = 0 67 | 68 | for i, text in enumerate(texts): 69 | try: 70 | individual_response = openai.embeddings.create( 71 | model="text-embedding-3-small", 72 | input=[text] 73 | ) 74 | embeddings.append(individual_response.data[0].embedding) 75 | successful_count += 1 76 | except Exception as individual_error: 77 | print(f"Failed to create embedding for text {i}: {individual_error}") 78 | # Add zero embedding as fallback 79 | embeddings.append([0.0] * 1536) 80 | 81 | print(f"Successfully created {successful_count}/{len(texts)} embeddings individually") 82 | return embeddings 83 | 84 | def create_embedding(text: str) -> List[float]: 85 | """ 86 | Create an embedding for a single text using OpenAI's API. 87 | 88 | Args: 89 | text: Text to create an embedding for 90 | 91 | Returns: 92 | List of floats representing the embedding 93 | """ 94 | try: 95 | embeddings = create_embeddings_batch([text]) 96 | return embeddings[0] if embeddings else [0.0] * 1536 97 | except Exception as e: 98 | print(f"Error creating embedding: {e}") 99 | # Return empty embedding if there's an error 100 | return [0.0] * 1536 101 | 102 | def generate_contextual_embedding(full_document: str, chunk: str) -> Tuple[str, bool]: 103 | """ 104 | Generate contextual information for a chunk within a document to improve retrieval. 105 | 106 | Args: 107 | full_document: The complete document text 108 | chunk: The specific chunk of text to generate context for 109 | 110 | Returns: 111 | Tuple containing: 112 | - The contextual text that situates the chunk within the document 113 | - Boolean indicating if contextual embedding was performed 114 | """ 115 | model_choice = os.getenv("MODEL_CHOICE") 116 | 117 | try: 118 | # Create the prompt for generating contextual information 119 | prompt = f"""<document> 120 | {full_document[:25000]} 121 | </document> 122 | Here is the chunk we want to situate within the whole document 123 | <chunk> 124 | {chunk} 125 | </chunk> 126 | Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk. 
Answer only with the succinct context and nothing else.""" 127 | 128 | # Call the OpenAI API to generate contextual information 129 | response = openai.chat.completions.create( 130 | model=model_choice, 131 | messages=[ 132 | {"role": "system", "content": "You are a helpful assistant that provides concise contextual information."}, 133 | {"role": "user", "content": prompt} 134 | ], 135 | temperature=0.3, 136 | max_tokens=200 137 | ) 138 | 139 | # Extract the generated context 140 | context = response.choices[0].message.content.strip() 141 | 142 | # Combine the context with the original chunk 143 | contextual_text = f"{context}\n---\n{chunk}" 144 | 145 | return contextual_text, True 146 | 147 | except Exception as e: 148 | print(f"Error generating contextual embedding: {e}. Using original chunk instead.") 149 | return chunk, False 150 | 151 | def process_chunk_with_context(args): 152 | """ 153 | Process a single chunk with contextual embedding. 154 | This function is designed to be used with concurrent.futures. 155 | 156 | Args: 157 | args: Tuple containing (url, content, full_document) 158 | 159 | Returns: 160 | Tuple containing: 161 | - The contextual text that situates the chunk within the document 162 | - Boolean indicating if contextual embedding was performed 163 | """ 164 | url, content, full_document = args 165 | return generate_contextual_embedding(full_document, content) 166 | 167 | def add_documents_to_supabase( 168 | client: Client, 169 | urls: List[str], 170 | chunk_numbers: List[int], 171 | contents: List[str], 172 | metadatas: List[Dict[str, Any]], 173 | url_to_full_document: Dict[str, str], 174 | batch_size: int = 20 175 | ) -> None: 176 | """ 177 | Add documents to the Supabase crawled_pages table in batches. 178 | Deletes existing records with the same URLs before inserting to prevent duplicates. 179 | 180 | Args: 181 | client: Supabase client 182 | urls: List of URLs 183 | chunk_numbers: List of chunk numbers 184 | contents: List of document contents 185 | metadatas: List of document metadata 186 | url_to_full_document: Dictionary mapping URLs to their full document content 187 | batch_size: Size of each batch for insertion 188 | """ 189 | # Get unique URLs to delete existing records 190 | unique_urls = list(set(urls)) 191 | 192 | # Delete existing records for these URLs in a single operation 193 | try: 194 | if unique_urls: 195 | # Use the .in_() filter to delete all records with matching URLs 196 | client.table("crawled_pages").delete().in_("url", unique_urls).execute() 197 | except Exception as e: 198 | print(f"Batch delete failed: {e}. 
Trying one-by-one deletion as fallback.") 199 | # Fallback: delete records one by one 200 | for url in unique_urls: 201 | try: 202 | client.table("crawled_pages").delete().eq("url", url).execute() 203 | except Exception as inner_e: 204 | print(f"Error deleting record for URL {url}: {inner_e}") 205 | # Continue with the next URL even if one fails 206 | 207 | # Check if MODEL_CHOICE is set for contextual embeddings 208 | use_contextual_embeddings = os.getenv("USE_CONTEXTUAL_EMBEDDINGS", "false") == "true" 209 | print(f"\n\nUse contextual embeddings: {use_contextual_embeddings}\n\n") 210 | 211 | # Process in batches to avoid memory issues 212 | for i in range(0, len(contents), batch_size): 213 | batch_end = min(i + batch_size, len(contents)) 214 | 215 | # Get batch slices 216 | batch_urls = urls[i:batch_end] 217 | batch_chunk_numbers = chunk_numbers[i:batch_end] 218 | batch_contents = contents[i:batch_end] 219 | batch_metadatas = metadatas[i:batch_end] 220 | 221 | # Apply contextual embedding to each chunk if MODEL_CHOICE is set 222 | if use_contextual_embeddings: 223 | # Prepare arguments for parallel processing 224 | process_args = [] 225 | for j, content in enumerate(batch_contents): 226 | url = batch_urls[j] 227 | full_document = url_to_full_document.get(url, "") 228 | process_args.append((url, content, full_document)) 229 | 230 | # Process in parallel using ThreadPoolExecutor 231 | contextual_contents = [] 232 | with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: 233 | # Submit all tasks and collect results 234 | future_to_idx = {executor.submit(process_chunk_with_context, arg): idx 235 | for idx, arg in enumerate(process_args)} 236 | 237 | # Process results as they complete 238 | for future in concurrent.futures.as_completed(future_to_idx): 239 | idx = future_to_idx[future] 240 | try: 241 | result, success = future.result() 242 | contextual_contents.append(result) 243 | if success: 244 | batch_metadatas[idx]["contextual_embedding"] = True 245 | except Exception as e: 246 | print(f"Error processing chunk {idx}: {e}") 247 | # Use original content as fallback 248 | contextual_contents.append(batch_contents[idx]) 249 | 250 | # Sort results back into original order if needed 251 | if len(contextual_contents) != len(batch_contents): 252 | print(f"Warning: Expected {len(batch_contents)} results but got {len(contextual_contents)}") 253 | # Use original contents as fallback 254 | contextual_contents = batch_contents 255 | else: 256 | # If not using contextual embeddings, use original contents 257 | contextual_contents = batch_contents 258 | 259 | # Create embeddings for the entire batch at once 260 | batch_embeddings = create_embeddings_batch(contextual_contents) 261 | 262 | batch_data = [] 263 | for j in range(len(contextual_contents)): 264 | # Extract metadata fields 265 | chunk_size = len(contextual_contents[j]) 266 | 267 | # Extract source_id from URL 268 | parsed_url = urlparse(batch_urls[j]) 269 | source_id = parsed_url.netloc or parsed_url.path 270 | 271 | # Prepare data for insertion 272 | data = { 273 | "url": batch_urls[j], 274 | "chunk_number": batch_chunk_numbers[j], 275 | "content": contextual_contents[j], # Store original content 276 | "metadata": { 277 | "chunk_size": chunk_size, 278 | **batch_metadatas[j] 279 | }, 280 | "source_id": source_id, # Add source_id field 281 | "embedding": batch_embeddings[j] # Use embedding from contextual content 282 | } 283 | 284 | batch_data.append(data) 285 | 286 | # Insert batch into Supabase with retry logic 287 | max_retries 
= 3 288 | retry_delay = 1.0 # Start with 1 second delay 289 | 290 | for retry in range(max_retries): 291 | try: 292 | client.table("crawled_pages").insert(batch_data).execute() 293 | # Success - break out of retry loop 294 | break 295 | except Exception as e: 296 | if retry < max_retries - 1: 297 | print(f"Error inserting batch into Supabase (attempt {retry + 1}/{max_retries}): {e}") 298 | print(f"Retrying in {retry_delay} seconds...") 299 | time.sleep(retry_delay) 300 | retry_delay *= 2 # Exponential backoff 301 | else: 302 | # Final attempt failed 303 | print(f"Failed to insert batch after {max_retries} attempts: {e}") 304 | # Optionally, try inserting records one by one as a last resort 305 | print("Attempting to insert records individually...") 306 | successful_inserts = 0 307 | for record in batch_data: 308 | try: 309 | client.table("crawled_pages").insert(record).execute() 310 | successful_inserts += 1 311 | except Exception as individual_error: 312 | print(f"Failed to insert individual record for URL {record['url']}: {individual_error}") 313 | 314 | if successful_inserts > 0: 315 | print(f"Successfully inserted {successful_inserts}/{len(batch_data)} records individually") 316 | 317 | def search_documents( 318 | client: Client, 319 | query: str, 320 | match_count: int = 10, 321 | filter_metadata: Optional[Dict[str, Any]] = None 322 | ) -> List[Dict[str, Any]]: 323 | """ 324 | Search for documents in Supabase using vector similarity. 325 | 326 | Args: 327 | client: Supabase client 328 | query: Query text 329 | match_count: Maximum number of results to return 330 | filter_metadata: Optional metadata filter 331 | 332 | Returns: 333 | List of matching documents 334 | """ 335 | # Create embedding for the query 336 | query_embedding = create_embedding(query) 337 | 338 | # Execute the search using the match_crawled_pages function 339 | try: 340 | # Only include filter parameter if filter_metadata is provided and not empty 341 | params = { 342 | 'query_embedding': query_embedding, 343 | 'match_count': match_count 344 | } 345 | 346 | # Only add the filter if it's actually provided and not empty 347 | if filter_metadata: 348 | params['filter'] = filter_metadata # Pass the dictionary directly, not JSON-encoded 349 | 350 | result = client.rpc('match_crawled_pages', params).execute() 351 | 352 | return result.data 353 | except Exception as e: 354 | print(f"Error searching documents: {e}") 355 | return [] 356 | 357 | 358 | def extract_code_blocks(markdown_content: str, min_length: int = 1000) -> List[Dict[str, Any]]: 359 | """ 360 | Extract code blocks from markdown content along with context. 
361 | 362 | Args: 363 | markdown_content: The markdown content to extract code blocks from 364 | min_length: Minimum length of code blocks to extract (default: 1000 characters) 365 | 366 | Returns: 367 | List of dictionaries containing code blocks and their context 368 | """ 369 | code_blocks = [] 370 | 371 | # Skip if content starts with triple backticks (edge case for files wrapped in backticks) 372 | content = markdown_content.strip() 373 | start_offset = 0 374 | if content.startswith('```'): 375 | # Skip the first triple backticks 376 | start_offset = 3 377 | print("Skipping initial triple backticks") 378 | 379 | # Find all occurrences of triple backticks 380 | backtick_positions = [] 381 | pos = start_offset 382 | while True: 383 | pos = markdown_content.find('```', pos) 384 | if pos == -1: 385 | break 386 | backtick_positions.append(pos) 387 | pos += 3 388 | 389 | # Process pairs of backticks 390 | i = 0 391 | while i < len(backtick_positions) - 1: 392 | start_pos = backtick_positions[i] 393 | end_pos = backtick_positions[i + 1] 394 | 395 | # Extract the content between backticks 396 | code_section = markdown_content[start_pos+3:end_pos] 397 | 398 | # Check if there's a language specifier on the first line 399 | lines = code_section.split('\n', 1) 400 | if len(lines) > 1: 401 | # Check if first line is a language specifier (no spaces, common language names) 402 | first_line = lines[0].strip() 403 | if first_line and not ' ' in first_line and len(first_line) < 20: 404 | language = first_line 405 | code_content = lines[1].strip() if len(lines) > 1 else "" 406 | else: 407 | language = "" 408 | code_content = code_section.strip() 409 | else: 410 | language = "" 411 | code_content = code_section.strip() 412 | 413 | # Skip if code block is too short 414 | if len(code_content) < min_length: 415 | i += 2 # Move to next pair 416 | continue 417 | 418 | # Extract context before (1000 chars) 419 | context_start = max(0, start_pos - 1000) 420 | context_before = markdown_content[context_start:start_pos].strip() 421 | 422 | # Extract context after (1000 chars) 423 | context_end = min(len(markdown_content), end_pos + 3 + 1000) 424 | context_after = markdown_content[end_pos + 3:context_end].strip() 425 | 426 | code_blocks.append({ 427 | 'code': code_content, 428 | 'language': language, 429 | 'context_before': context_before, 430 | 'context_after': context_after, 431 | 'full_context': f"{context_before}\n\n{code_content}\n\n{context_after}" 432 | }) 433 | 434 | # Move to next pair (skip the closing backtick we just processed) 435 | i += 2 436 | 437 | return code_blocks 438 | 439 | 440 | def generate_code_example_summary(code: str, context_before: str, context_after: str) -> str: 441 | """ 442 | Generate a summary for a code example using its surrounding context. 
443 | 444 | Args: 445 | code: The code example 446 | context_before: Context before the code 447 | context_after: Context after the code 448 | 449 | Returns: 450 | A summary of what the code example demonstrates 451 | """ 452 | model_choice = os.getenv("MODEL_CHOICE") 453 | 454 | # Create the prompt 455 | prompt = f"""<context_before> 456 | {context_before[-500:] if len(context_before) > 500 else context_before} 457 | </context_before> 458 | 459 | <code_example> 460 | {code[:1500] if len(code) > 1500 else code} 461 | </code_example> 462 | 463 | <context_after> 464 | {context_after[:500] if len(context_after) > 500 else context_after} 465 | </context_after> 466 | 467 | Based on the code example and its surrounding context, provide a concise summary (2-3 sentences) that describes what this code example demonstrates and its purpose. Focus on the practical application and key concepts illustrated. 468 | """ 469 | 470 | try: 471 | response = openai.chat.completions.create( 472 | model=model_choice, 473 | messages=[ 474 | {"role": "system", "content": "You are a helpful assistant that provides concise code example summaries."}, 475 | {"role": "user", "content": prompt} 476 | ], 477 | temperature=0.3, 478 | max_tokens=100 479 | ) 480 | 481 | return response.choices[0].message.content.strip() 482 | 483 | except Exception as e: 484 | print(f"Error generating code example summary: {e}") 485 | return "Code example for demonstration purposes." 486 | 487 | 488 | def add_code_examples_to_supabase( 489 | client: Client, 490 | urls: List[str], 491 | chunk_numbers: List[int], 492 | code_examples: List[str], 493 | summaries: List[str], 494 | metadatas: List[Dict[str, Any]], 495 | batch_size: int = 20 496 | ): 497 | """ 498 | Add code examples to the Supabase code_examples table in batches. 
499 | 500 | Args: 501 | client: Supabase client 502 | urls: List of URLs 503 | chunk_numbers: List of chunk numbers 504 | code_examples: List of code example contents 505 | summaries: List of code example summaries 506 | metadatas: List of metadata dictionaries 507 | batch_size: Size of each batch for insertion 508 | """ 509 | if not urls: 510 | return 511 | 512 | # Delete existing records for these URLs 513 | unique_urls = list(set(urls)) 514 | for url in unique_urls: 515 | try: 516 | client.table('code_examples').delete().eq('url', url).execute() 517 | except Exception as e: 518 | print(f"Error deleting existing code examples for {url}: {e}") 519 | 520 | # Process in batches 521 | total_items = len(urls) 522 | for i in range(0, total_items, batch_size): 523 | batch_end = min(i + batch_size, total_items) 524 | batch_texts = [] 525 | 526 | # Create combined texts for embedding (code + summary) 527 | for j in range(i, batch_end): 528 | combined_text = f"{code_examples[j]}\n\nSummary: {summaries[j]}" 529 | batch_texts.append(combined_text) 530 | 531 | # Create embeddings for the batch 532 | embeddings = create_embeddings_batch(batch_texts) 533 | 534 | # Check if embeddings are valid (not all zeros) 535 | valid_embeddings = [] 536 | for embedding in embeddings: 537 | if embedding and not all(v == 0.0 for v in embedding): 538 | valid_embeddings.append(embedding) 539 | else: 540 | print(f"Warning: Zero or invalid embedding detected, creating new one...") 541 | # Try to create a single embedding as fallback 542 | single_embedding = create_embedding(batch_texts[len(valid_embeddings)]) 543 | valid_embeddings.append(single_embedding) 544 | 545 | # Prepare batch data 546 | batch_data = [] 547 | for j, embedding in enumerate(valid_embeddings): 548 | idx = i + j 549 | 550 | # Extract source_id from URL 551 | parsed_url = urlparse(urls[idx]) 552 | source_id = parsed_url.netloc or parsed_url.path 553 | 554 | batch_data.append({ 555 | 'url': urls[idx], 556 | 'chunk_number': chunk_numbers[idx], 557 | 'content': code_examples[idx], 558 | 'summary': summaries[idx], 559 | 'metadata': metadatas[idx], # Store as JSON object, not string 560 | 'source_id': source_id, 561 | 'embedding': embedding 562 | }) 563 | 564 | # Insert batch into Supabase with retry logic 565 | max_retries = 3 566 | retry_delay = 1.0 # Start with 1 second delay 567 | 568 | for retry in range(max_retries): 569 | try: 570 | client.table('code_examples').insert(batch_data).execute() 571 | # Success - break out of retry loop 572 | break 573 | except Exception as e: 574 | if retry < max_retries - 1: 575 | print(f"Error inserting batch into Supabase (attempt {retry + 1}/{max_retries}): {e}") 576 | print(f"Retrying in {retry_delay} seconds...") 577 | time.sleep(retry_delay) 578 | retry_delay *= 2 # Exponential backoff 579 | else: 580 | # Final attempt failed 581 | print(f"Failed to insert batch after {max_retries} attempts: {e}") 582 | # Optionally, try inserting records one by one as a last resort 583 | print("Attempting to insert records individually...") 584 | successful_inserts = 0 585 | for record in batch_data: 586 | try: 587 | client.table('code_examples').insert(record).execute() 588 | successful_inserts += 1 589 | except Exception as individual_error: 590 | print(f"Failed to insert individual record for URL {record['url']}: {individual_error}") 591 | 592 | if successful_inserts > 0: 593 | print(f"Successfully inserted {successful_inserts}/{len(batch_data)} records individually") 594 | print(f"Inserted batch {i//batch_size + 1} of 
{(total_items + batch_size - 1)//batch_size} code examples") 595 | 596 | 597 | def update_source_info(client: Client, source_id: str, summary: str, word_count: int): 598 | """ 599 | Update or insert source information in the sources table. 600 | 601 | Args: 602 | client: Supabase client 603 | source_id: The source ID (domain) 604 | summary: Summary of the source 605 | word_count: Total word count for the source 606 | """ 607 | try: 608 | # Try to update existing source 609 | result = client.table('sources').update({ 610 | 'summary': summary, 611 | 'total_word_count': word_count, 612 | 'updated_at': 'now()' 613 | }).eq('source_id', source_id).execute() 614 | 615 | # If no rows were updated, insert new source 616 | if not result.data: 617 | client.table('sources').insert({ 618 | 'source_id': source_id, 619 | 'summary': summary, 620 | 'total_word_count': word_count 621 | }).execute() 622 | print(f"Created new source: {source_id}") 623 | else: 624 | print(f"Updated source: {source_id}") 625 | 626 | except Exception as e: 627 | print(f"Error updating source {source_id}: {e}") 628 | 629 | 630 | def extract_source_summary(source_id: str, content: str, max_length: int = 500) -> str: 631 | """ 632 | Extract a summary for a source from its content using an LLM. 633 | 634 | This function uses the OpenAI API to generate a concise summary of the source content. 635 | 636 | Args: 637 | source_id: The source ID (domain) 638 | content: The content to extract a summary from 639 | max_length: Maximum length of the summary 640 | 641 | Returns: 642 | A summary string 643 | """ 644 | # Default summary if we can't extract anything meaningful 645 | default_summary = f"Content from {source_id}" 646 | 647 | if not content or len(content.strip()) == 0: 648 | return default_summary 649 | 650 | # Get the model choice from environment variables 651 | model_choice = os.getenv("MODEL_CHOICE") 652 | 653 | # Limit content length to avoid token limits 654 | truncated_content = content[:25000] if len(content) > 25000 else content 655 | 656 | # Create the prompt for generating the summary 657 | prompt = f"""<source_content> 658 | {truncated_content} 659 | </source_content> 660 | 661 | The above content is from the documentation for '{source_id}'. Please provide a concise summary (3-5 sentences) that describes what this library/tool/framework is about. The summary should help understand what the library/tool/framework accomplishes and the purpose. 662 | """ 663 | 664 | try: 665 | # Call the OpenAI API to generate the summary 666 | response = openai.chat.completions.create( 667 | model=model_choice, 668 | messages=[ 669 | {"role": "system", "content": "You are a helpful assistant that provides concise library/tool/framework summaries."}, 670 | {"role": "user", "content": prompt} 671 | ], 672 | temperature=0.3, 673 | max_tokens=150 674 | ) 675 | 676 | # Extract the generated summary 677 | summary = response.choices[0].message.content.strip() 678 | 679 | # Ensure the summary is not too long 680 | if len(summary) > max_length: 681 | summary = summary[:max_length] + "..." 682 | 683 | return summary 684 | 685 | except Exception as e: 686 | print(f"Error generating summary with LLM for {source_id}: {e}. 
Using default summary.") 687 | return default_summary 688 | 689 | 690 | def search_code_examples( 691 | client: Client, 692 | query: str, 693 | match_count: int = 10, 694 | filter_metadata: Optional[Dict[str, Any]] = None, 695 | source_id: Optional[str] = None 696 | ) -> List[Dict[str, Any]]: 697 | """ 698 | Search for code examples in Supabase using vector similarity. 699 | 700 | Args: 701 | client: Supabase client 702 | query: Query text 703 | match_count: Maximum number of results to return 704 | filter_metadata: Optional metadata filter 705 | source_id: Optional source ID to filter results 706 | 707 | Returns: 708 | List of matching code examples 709 | """ 710 | # Create a more descriptive query for better embedding match 711 | # Since code examples are embedded with their summaries, we should make the query more descriptive 712 | enhanced_query = f"Code example for {query}\n\nSummary: Example code showing {query}" 713 | 714 | # Create embedding for the enhanced query 715 | query_embedding = create_embedding(enhanced_query) 716 | 717 | # Execute the search using the match_code_examples function 718 | try: 719 | # Only include filter parameter if filter_metadata is provided and not empty 720 | params = { 721 | 'query_embedding': query_embedding, 722 | 'match_count': match_count 723 | } 724 | 725 | # Only add the filter if it's actually provided and not empty 726 | if filter_metadata: 727 | params['filter'] = filter_metadata 728 | 729 | # Add source filter if provided 730 | if source_id: 731 | params['source_filter'] = source_id 732 | 733 | result = client.rpc('match_code_examples', params).execute() 734 | 735 | return result.data 736 | except Exception as e: 737 | print(f"Error searching code examples: {e}") 738 | return [] --------------------------------------------------------------------------------