├── src ├── agent.py ├── __init__.py ├── tools │ ├── __init__.py │ ├── README.md │ ├── test_tools.py │ ├── examples │ │ ├── math_example.py │ │ ├── puppeteer_example.py │ │ ├── simple_math_client.py │ │ ├── math_custom_adapter.py │ │ └── playwright_example.py │ ├── simple_test.py │ ├── mcp_README.md │ ├── math_server.py │ ├── test_mcp_tools.py │ └── registry.py ├── test_agent_architecture.py └── configuration.py ├── log-browser.yml ├── log-server.yml ├── models ├── __init__.py ├── research.py └── file_analysis.py ├── routers ├── __init__.py └── database.py ├── services └── __init__.py ├── assets ├── edr_ppl.png ├── benchmarks.png ├── edr-logo.png └── leaderboard.png ├── ai-research-assistant ├── postcss.config.js ├── public │ ├── sfr_logo.jpeg │ └── index.html ├── .gitignore ├── src │ ├── index.js │ ├── components │ │ ├── ResearchItem.js │ │ ├── CodeWithVisualization.js │ │ ├── Navbar.js │ │ ├── LoadingIndicator.js │ │ └── CodeSnippetViewer.js │ ├── index.css │ └── App.js ├── tailwind.config.js └── package.json ├── Tech_Report__Enterprise_Deep_Research.pdf ├── CODEOWNERS ├── e2b.Dockerfile ├── langgraph.json ├── mcp_agent.secrets.yaml ├── replit.nix ├── SECURITY.md ├── e2b.toml ├── package.json ├── AI_ETHICS.md ├── math_server.py ├── benchmarks ├── run_research.sh ├── process_drb.py └── README.md ├── pyproject.toml ├── model_test.py ├── .gitignore ├── requirements.txt ├── .env.sample ├── graph_test.py ├── test_graph.py ├── test_agents.py ├── math_client_new.py ├── session_store.py ├── math_client_langgraph.py ├── test_visualization.py ├── CODE_OF_CONDUCT.md ├── app.py ├── CONTRIBUTING.md ├── test_specialized_searches.py ├── test_unified_query.py └── test_benchmark.py /src/agent.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /log-browser.yml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /log-server.yml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # models package -------------------------------------------------------------------------------- /routers/__init__.py: -------------------------------------------------------------------------------- 1 | # routers package -------------------------------------------------------------------------------- /services/__init__.py: -------------------------------------------------------------------------------- 1 | # services package -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # Make src directory a Python package 3 | -------------------------------------------------------------------------------- /assets/edr_ppl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SalesforceAIResearch/enterprise-deep-research/HEAD/assets/edr_ppl.png -------------------------------------------------------------------------------- /assets/benchmarks.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SalesforceAIResearch/enterprise-deep-research/HEAD/assets/benchmarks.png
--------------------------------------------------------------------------------
/assets/edr-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SalesforceAIResearch/enterprise-deep-research/HEAD/assets/edr-logo.png
--------------------------------------------------------------------------------
/assets/leaderboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SalesforceAIResearch/enterprise-deep-research/HEAD/assets/leaderboard.png
--------------------------------------------------------------------------------
/ai-research-assistant/postcss.config.js:
--------------------------------------------------------------------------------
1 | module.exports = {
2 |   plugins: {
3 |     tailwindcss: {},
4 |     autoprefixer: {},
5 |   },
6 | };
--------------------------------------------------------------------------------
/Tech_Report__Enterprise_Deep_Research.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SalesforceAIResearch/enterprise-deep-research/HEAD/Tech_Report__Enterprise_Deep_Research.pdf
--------------------------------------------------------------------------------
/ai-research-assistant/public/sfr_logo.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SalesforceAIResearch/enterprise-deep-research/HEAD/ai-research-assistant/public/sfr_logo.jpeg
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Comment line immediately above ownership line is reserved for related information. Please be careful while editing.
2 | #ECCN:Open Source
3 | #GUSINFO:Open Source,Open Source Workflow
--------------------------------------------------------------------------------
/e2b.Dockerfile:
--------------------------------------------------------------------------------
1 | # You can use most Debian-based base images
2 | # FROM ubuntu:22.04
3 | FROM e2bdev/code-interpreter:latest
4 | # Install dependencies and customize sandbox
5 |
6 | # Install some Python packages
7 | RUN pip install cowsay
8 |
--------------------------------------------------------------------------------
/langgraph.json:
--------------------------------------------------------------------------------
1 | {
2 |   "dockerfile_lines": [],
3 |   "graphs": {
4 |     "frank_deep_researcher": "./src/graph.py:graph"
5 |   },
6 |   "python_version": "3.11",
7 |   "env": "./.env",
8 |   "dependencies": [
9 |     "."
10 |   ]
11 | }
--------------------------------------------------------------------------------
/ai-research-assistant/public/index.html:
--------------------------------------------------------------------------------
1 | <!DOCTYPE html> <!-- markup reconstructed; only the title text survived tag stripping, so a standard CRA template is assumed -->
2 | <html lang="en">
3 |   <head>
4 |     <meta charset="utf-8" />
5 |     <meta name="viewport" content="width=device-width, initial-scale=1" />
6 |     <title>AI Research Assistant</title>
7 |   </head>
8 |   <body>
9 |     <div id="root"></div>
10 |   </body>
11 | </html>
--------------------------------------------------------------------------------
/mcp_agent.secrets.yaml:
--------------------------------------------------------------------------------
1 | # MCP Secrets Configuration
2 | # Fill in your API keys below
3 |
4 | # Anthropic API Key (for Claude)
5 | anthropic:
6 |   api_key: "your_anthropic_api_key_here"
7 |
8 | # OpenAI API Key
9 | openai:
10 |   api_key: "your_openai_api_key_here"
11 |
12 | # GitHub API Key (for GitHub MCP server)
13 | github:
14 |   api_key: "your_github_api_key_here"
15 |
--------------------------------------------------------------------------------
/replit.nix:
--------------------------------------------------------------------------------
1 | {pkgs}: {
2 |   deps = [
3 |     pkgs.tk
4 |     pkgs.tcl
5 |     pkgs.qhull
6 |     pkgs.pkg-config
7 |     pkgs.gtk3
8 |     pkgs.gobject-introspection
9 |     pkgs.ghostscript
10 |     pkgs.freetype
11 |     pkgs.ffmpeg-full
12 |     pkgs.cairo
13 |     pkgs.libxcrypt
14 |     pkgs.glibcLocales
15 |     pkgs.bash
16 |     pkgs.rustc
17 |     pkgs.libiconv
18 |     pkgs.cargo
19 |   ];
20 | }
21 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | ## Security
2 |
3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
4 | as soon as it is discovered. This library limits its runtime dependencies in
5 | order to reduce the total cost of ownership as much as possible, but all consumers
6 | should remain vigilant and have their security stakeholders review all third-party
7 | products (3PP) like this one and their dependencies.
--------------------------------------------------------------------------------
/ai-research-assistant/.gitignore:
--------------------------------------------------------------------------------
1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2 |
3 | # dependencies
4 | /node_modules
5 | /.pnp
6 | .pnp.js
7 |
8 | # testing
9 | /coverage
10 |
11 | # production
12 | /build
13 |
14 | # misc
15 | .DS_Store
16 | .env.local
17 | .env.development.local
18 | .env.test.local
19 | .env.production.local
20 |
21 | npm-debug.log*
22 | yarn-debug.log*
23 | yarn-error.log*
24 |
--------------------------------------------------------------------------------
/e2b.toml:
--------------------------------------------------------------------------------
1 | # This is a config for E2B sandbox template.
2 | # You can use template ID (xe0uinj2n1ufrgmmksbu) to create a sandbox:
3 |
4 | # Python SDK
5 | # from e2b import Sandbox, AsyncSandbox
6 | # sandbox = Sandbox("xe0uinj2n1ufrgmmksbu") # Sync sandbox
7 | # sandbox = await AsyncSandbox.create("xe0uinj2n1ufrgmmksbu") # Async sandbox
8 |
9 | # JS SDK
10 | # import { Sandbox } from 'e2b'
11 | # const sandbox = await Sandbox.create('xe0uinj2n1ufrgmmksbu')
12 |
13 | team_id = "f6bd0df0-7308-4e18-87e2-02e452fa8e83"
14 | memory_mb = 4_096
15 | cpu_count = 4
16 | start_cmd = "/root/.jupyter/start-up.sh"
17 | dockerfile = "e2b.Dockerfile"
18 | template_id = "xe0uinj2n1ufrgmmksbu"
--------------------------------------------------------------------------------
/ai-research-assistant/src/index.js:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import ReactDOM from 'react-dom/client';
3 | import './index.css'; // Main Tailwind CSS
4 | import App from './App'; // Assuming App.js is in the same directory
5 | // import reportWebVitals from './reportWebVitals'; // Optional
6 |
7 | const root = ReactDOM.createRoot(document.getElementById('root'));
8 | root.render(
9 |   <React.StrictMode> {/* JSX reconstructed; the tags were stripped in the original dump */}
10 |     <App />
11 |   </React.StrictMode>
12 | );
13 |
14 | // If you want to start measuring performance in your app, pass a function
15 | // to log results (for example: reportWebVitals(console.log))
16 | // or send to an analytics endpoint. Learn more: https://bit.ly/CRA-vitals
17 | // reportWebVitals();
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 |   "name": "workspace",
3 |   "version": "1.0.0",
4 |   "description": "This project provides a flexible system for integrating various LLM providers (OpenAI, Anthropic, Groq) with the E2B code interpreter.",
5 |   "main": "index.js",
6 |   "scripts": {
7 |     "test": "echo \"Error: no test specified\" && exit 1"
8 |   },
9 |   "repository": {
10 |     "type": "git",
11 |     "url": "git+https://github.com/frankyanwang/deep_research.git"
12 |   },
13 |   "keywords": [],
14 |   "author": "",
15 |   "license": "ISC",
16 |   "type": "commonjs",
17 |   "bugs": {
18 |     "url": "https://github.com/frankyanwang/deep_research/issues"
19 |   },
20 |   "homepage": "https://github.com/frankyanwang/deep_research#readme",
21 |   "dependencies": {
22 |     "@e2b/cli": "^1.3.2",
23 |     "rehype-raw": "^7.0.0",
24 |     "serve": "^14.2.4"
25 |   }
26 | }
27 |
--------------------------------------------------------------------------------
/AI_ETHICS.md:
--------------------------------------------------------------------------------
1 | ## Ethics disclaimer for Salesforce AI models, data, code
2 |
3 | This release is for research purposes only in support of an academic
4 | paper. Our models, datasets, and code are not specifically designed or
5 | evaluated for all downstream purposes. We strongly recommend users
6 | evaluate and address potential concerns related to accuracy, safety, and
7 | fairness before deploying this model. We encourage users to consider the
8 | common limitations of AI, comply with applicable laws, and leverage best
9 | practices when selecting use cases, particularly for high-risk scenarios
10 | where errors or misuse could significantly impact people’s lives, rights,
11 | or safety.
For further guidance on use cases, refer to our standard 12 | [AUP](https://www.salesforce.com/content/dam/web/en_us/www/documents/legal/Agreements/policies/ExternalFacing_Services_Policy.pdf) 13 | and [AI AUP](https://www.salesforce.com/content/dam/web/en_us/www/documents/legal/Agreements/policies/ai-acceptable-use-policy.pdf). -------------------------------------------------------------------------------- /math_server.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple MCP server that provides math tools. 3 | """ 4 | import sys 5 | from mcp.server.fastmcp import FastMCP 6 | 7 | mcp = FastMCP("MathTools") 8 | 9 | @mcp.tool() 10 | def add(a: int, b: int) -> int: 11 | """Add two numbers.""" 12 | print(f"Frank: Adding {a} and {b}", file=sys.stderr) 13 | return a + b 14 | 15 | @mcp.tool() 16 | def subtract(a: int, b: int) -> int: 17 | """Subtract b from a.""" 18 | print(f"Frank: Subtracting {b} from {a}", file=sys.stderr) 19 | return a - b 20 | 21 | @mcp.tool() 22 | def multiply(a: int, b: int) -> int: 23 | """Multiply two numbers.""" 24 | print(f"Frank: Multiplying {a} and {b}", file=sys.stderr) 25 | return a * b 26 | 27 | @mcp.tool() 28 | def divide(a: float, b: float) -> float: 29 | """Divide a by b. Returns an error if b is 0.""" 30 | print(f"Frank: Dividing {a} by {b}", file=sys.stderr) 31 | if b == 0: 32 | raise ValueError("Cannot divide by zero") 33 | return a / b 34 | 35 | if __name__ == "__main__": 36 | mcp.run(transport="stdio") 37 | -------------------------------------------------------------------------------- /benchmarks/run_research.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "🚀 Starting run_research.sh script..." 4 | echo "📅 Start time: $(date)" 5 | echo "👤 User: $(whoami)" 6 | echo "📁 Working directory: $(pwd)" 7 | echo "===============================================" 8 | 9 | LOGS_DIR="logs" 10 | mkdir -p $LOGS_DIR 11 | 12 | ## Simple test 13 | python -u run_research.py "what is ai?" 
\ 14 | --max-loops 2 \ 15 | --output sample_result.json > $LOGS_DIR/traj.log 2>&1 & 16 | 17 | ## DeepResearch Bench (DRB) 18 | # python -u run_research_concurrent.py \ 19 | # --benchmark drb \ 20 | # --input /Users/akshara.prabhakar/Documents/deep_research/benchmarks/deep_research_bench/data/prompt_data/query.jsonl \ 21 | # --output_dir drb_trajectories \ 22 | # --collect-traj \ 23 | # --task_ids 81 > $LOGS_DIR/drb_traj1.log 2>&1 & 24 | 25 | ## DeepConsult 26 | # python -u run_research_concurrent.py \ 27 | # --benchmark deepconsult \ 28 | # --input ydc-deep-research-evals/datasets/DeepConsult/queries.csv \ 29 | # --limit 1 \ 30 | # --output_dir deepconsult_trajectories > $LOGS_DIR/deepconsult_traj.log 2>&1 & 31 | -------------------------------------------------------------------------------- /ai-research-assistant/tailwind.config.js: -------------------------------------------------------------------------------- 1 | /** @type {import('tailwindcss').Config} */ 2 | module.exports = { 3 | content: [ 4 | "./src/**/*.{js,jsx,ts,tsx}", 5 | "./public/index.html" 6 | ], 7 | theme: { 8 | extend: { 9 | colors: { 10 | bg: '#FFFFFF', 11 | surface: '#F7F8F8', 12 | text: '#0F172A', 13 | accent: '#DA7756', 14 | border: '#E5E7EB', 15 | }, 16 | fontFamily: { 17 | sans: ['Inter', 'sans-serif'], 18 | }, 19 | borderRadius: { 20 | DEFAULT: '6px', 21 | md: '6px', // Explicitly set md if you use rounded-md often 22 | }, 23 | spacing: { 24 | // Example: Define specific spacing values if needed, otherwise rely on defaults scaled appropriately. 25 | // You might want to align defaults closer to a 24px grid, e.g., 6: '1.5rem' (24px) 26 | '6': '1.5rem', // 24px 27 | }, 28 | boxShadow: { 29 | sm: '0 1px 2px 0 rgb(0 0 0 / 0.05)', // Keep or adjust default shadow-sm 30 | } 31 | }, 32 | }, 33 | plugins: [], 34 | } -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "e2b-research-agent" 3 | version = "0.1.0" 4 | description = "Research Agent using LangGraph" 5 | authors = [ 6 | {name = "User", email = "user@example.com"}, 7 | ] 8 | readme = "README.md" 9 | requires-python = ">=3.11" 10 | dependencies = [ 11 | "python-dotenv>=1.0.0", 12 | "langchain>=0.3.9", 13 | "langchain_openai>=0.2.11", 14 | "langchain_anthropic>=0.3.10", 15 | "langchain_groq>=0.0.3", 16 | "groq>=0.4.2", 17 | "openai>=1.6.0", 18 | "anthropic>=0.49.0", 19 | "e2b_code_interpreter>=1.1.1", 20 | "tenacity>=8.5.0", 21 | "pandas==2.2.1", 22 | "matplotlib==3.8.3", 23 | "seaborn==0.13.2", 24 | "scikit-learn==1.6.1", 25 | "langgraph>=0.3.16", 26 | "tavily-python>=0.5.0", 27 | "langchain-core>=0.3.45", 28 | "requests>=2.32.0", 29 | "typing-extensions>=4.0.0", 30 | "pydantic>=2.5.2", 31 | "tiktoken>=0.5.1", 32 | "mcp>=1.4.1", 33 | "langchain-mcp-adapters>=0.0.6" 34 | ] 35 | 36 | [build-system] 37 | requires = ["setuptools>=61.0"] 38 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /model_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | from src.configuration import Configuration 4 | 5 | # Clear any existing env vars 6 | if "LLM_MODEL" in os.environ: 7 | del os.environ["LLM_MODEL"] 8 | if "LLM_PROVIDER" in os.environ: 9 | del os.environ["LLM_PROVIDER"] 10 | 11 | # Load environment variables from .env file 12 | load_dotenv(override=True) 13 | 14 | # 
Print the raw environment variables 15 | print(f"Env var LLM_MODEL: {os.environ.get('LLM_MODEL')}") 16 | print(f"Env var LLM_PROVIDER: {os.environ.get('LLM_PROVIDER')}") 17 | 18 | # Get the configuration 19 | config = Configuration() 20 | print(f"\nDefault Configuration:") 21 | print(f"LLM provider: {config.llm_provider}") 22 | print(f"LLM model: {config.llm_model}") 23 | 24 | # Test with explicit env vars 25 | os.environ["LLM_PROVIDER"] = "openai" 26 | os.environ["LLM_MODEL"] = "" # Empty to test the fallback 27 | 28 | # Get a new configuration instance 29 | config = Configuration() 30 | print(f"\nWith explicit provider and empty model:") 31 | print(f"LLM provider: {config.llm_provider}") 32 | print(f"LLM model: {config.llm_model}") -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # Unit test / coverage reports 28 | htmlcov/ 29 | .tox/ 30 | .coverage 31 | .coverage.* 32 | .cache 33 | nosetests.xml 34 | coverage.xml 35 | *.cover 36 | .hypothesis/ 37 | 38 | # Environments 39 | .env 40 | .venv 41 | env/ 42 | edr/ 43 | venv/ 44 | ENV/ 45 | env.bak/ 46 | venv.bak/ 47 | .envnew/ 48 | 49 | # IDE specific files 50 | .idea/ 51 | .vscode/ 52 | *.swp 53 | *.swo 54 | 55 | # Project specific 56 | *.png 57 | *.jpg 58 | *.jpeg 59 | *.csv 60 | *.pdf 61 | *.log 62 | !assets/edr_ppl.png 63 | !assets/leaderboard.png 64 | !assets/benchmarks.png 65 | !ai-research-assistant/public/sfr_logo.jpeg 66 | uploads/ 67 | visualizations/ 68 | steering_sessions.json 69 | backend_logs* 70 | 71 | 72 | # LangGraph Studio 73 | .langgraph/ 74 | .langgraph-studio/ 75 | .langgraph-studio-cache/ 76 | .langgraph_api 77 | 78 | .claude/ -------------------------------------------------------------------------------- /ai-research-assistant/src/components/ResearchItem.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ActivityItem from './ActivityItem'; 3 | 4 | function ResearchItem({ item, isActive, onClick }) { // isActive might not be needed now, onClick passed down 5 | const { id, title, timestamp, content, activityText, enrichedData, type } = item; 6 | 7 | // Determine the main activity text for this item 8 | // Prioritize activityText, then title, then content 9 | const mainActivityText = activityText || title || content || 'Processing...'; 10 | 11 | // Accept index as a prop if passed from parent (default 0) 12 | // Accept itemType as item.type or 'default' 13 | const itemType = type || 'default'; 14 | const itemIndex = typeof item.index === 'number' ? item.index : 0; 15 | 16 | // Always render using ActivityItem structure for consistency 17 | return ( 18 |
19 |     // Markup reconstructed: the original tags were stripped during extraction;
20 |     // ActivityItem prop names are assumed from the variables prepared above.
21 |     <ActivityItem
22 |       text={mainActivityText}
23 |       timestamp={timestamp}
24 |       type={itemType} index={itemIndex}
25 |       enrichedData={enrichedData} onClick={onClick}
26 |     />
27 | ); 28 | } 29 | 30 | export default ResearchItem; -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python-dotenv>=1.0.0 2 | langchain-core>=0.3.78,<1.0.0 3 | langchain>=0.3.9 4 | langchain_openai>=0.2.11 5 | langchain_anthropic>=0.3.10 6 | langchain_groq>=0.0.3 7 | langchain_google_genai>=2.1.3 8 | langchain_google_vertexai>=2.0.0,<3.0.0 9 | groq>=0.4.2 10 | openai>=1.6.0 11 | anthropic>=0.49.0 12 | e2b-code-interpreter==1.2.0b1 13 | tenacity<9.0.0 14 | pandas==2.2.1 15 | matplotlib==3.8.3 16 | seaborn==0.13.2 17 | scikit-learn==1.6.1 18 | langgraph==0.3.22 19 | langgraph-checkpoint==2.0.23 20 | tavily-python>=0.5.0 21 | requests>=2.32.0 22 | typing-extensions>=4.0.0 23 | pydantic>=2.5.2 24 | tiktoken>=0.5.1 25 | fastapi>=0.110.0 26 | uvicorn>=0.28.0 27 | starlette>=0.36.0 28 | langsmith 29 | langchain-mcp-adapters>=0.0.6 30 | mcp>=1.4.1 31 | numpy<2.0.0 32 | plotly>=5.0.0 33 | sse-starlette>=2.2.1 34 | python-multipart>=0.0.6 35 | nest-asyncio>=1.5.0 36 | 37 | # File processing dependencies 38 | PyPDF2>=3.0.1 39 | pdfplumber>=0.9.0 40 | python-docx>=0.8.11 41 | openpyxl>=3.1.2 42 | Pillow>=10.0.1 43 | pytesseract>=0.3.10 44 | opencv-python>=4.8.1.78 45 | python-magic>=0.4.27 46 | 47 | # Audio/Video processing 48 | SpeechRecognition>=3.10.0 49 | moviepy>=1.0.3 50 | whisper>=1.1.10 51 | 52 | # Additional utilities for file handling 53 | aiofiles>=23.2.1 54 | -------------------------------------------------------------------------------- /.env.sample: -------------------------------------------------------------------------------- 1 | # TODO: Get your E2B API key from https://e2b.dev/docs/getting-started/api-key 2 | E2B_API_KEY="" 3 | ANTHROPIC_API_KEY="" 4 | FIRECRAWL_API_KEY="" 5 | SEARCH_API=tavily 6 | TAVILY_API_KEY="" 7 | LANGCHAIN_API_KEY="" 8 | GOOGLE_CLOUD_PROJECT="" 9 | OPENAI_API_KEY="" 10 | SAMBNOVA_API_KEY="" 11 | # TODO: Get your Groq API key from https://console.groq.com/keys 12 | GROQ_API_KEY="" 13 | 14 | JINA_API_KEY="" 15 | 16 | MAX_WEB_RESEARCH_LOOPS=3 17 | FETCH_FULL_PAGE=True 18 | 19 | LLM_PROVIDER=google 20 | LLM_MODEL=gemini-2.5-pro 21 | 22 | ## LLM Providers ## 23 | 24 | # LLM_PROVIDER=anthropic 25 | # LLM_MODEL=claude-3-7-sonnet 26 | # LLM_MODEL=claude-3-7-sonnet-thinking 27 | 28 | # LLM_PROVIDER=openai 29 | # LLM_MODEL=gpt-4.1 30 | # LLM_MODEL=o3-mini 31 | # LLM_MODEL=o3-mini-reasoning 32 | 33 | # LLM_PROVIDER=groq 34 | # LLM_MODEL=deepseek-r1-distill-llama-70b 35 | 36 | # LLM_PROVIDER=sambnova 37 | # LLM_MODEL=DeepSeek-V3-0324 38 | 39 | ## Activity generation configuration ## 40 | ENABLE_ACTIVITY_GENERATION=true 41 | ACTIVITY_VERBOSITY=medium 42 | ACTIVITY_LLM_PROVIDER=google 43 | ACTIVITY_LLM_MODEL=gemini-2.5-flash 44 | 45 | LANGCHAIN_TRACING_V2=true 46 | LANGCHAIN_ENDPOINT="" 47 | LANGCHAIN_PROJECT="" 48 | 49 | SCRAPYBARA_API_KEY="" -------------------------------------------------------------------------------- /src/tools/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tools package for the research agent. 3 | 4 | This package contains the implementation of tools used by the research agent, 5 | including search tools, tool registry, and tool executor. 
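
A minimal usage sketch (it mirrors the pattern documented in
src/tools/README.md and exercised in test_tools.py; the default tools are
assumed to be registered):

    from src.tools import SearchToolRegistry, ToolExecutor

    registry = SearchToolRegistry()
    executor = ToolExecutor(registry)
    result = executor.execute_tool_sync("general_search", {"query": "..."})
    sources = result.get("formatted_sources", [])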
6 | """ 7 | 8 | from src.tools.search_tools import ( 9 | GeneralSearchTool, 10 | AcademicSearchTool, 11 | GithubSearchTool, 12 | LinkedinSearchTool 13 | ) 14 | from src.tools.registry import SearchToolRegistry 15 | from src.tools.executor import ToolExecutor 16 | from src.tools.tool_schema import ( 17 | SEARCH_TOOL_FUNCTIONS, 18 | TOPIC_DECOMPOSITION_FUNCTION, 19 | GeneralSearchToolSchema, 20 | AcademicSearchToolSchema, 21 | GithubSearchToolSchema, 22 | LinkedinSearchToolSchema, 23 | SimpleTopicResponse, 24 | ComplexTopicResponse, 25 | Subtopic 26 | ) 27 | from src.tools.mcp_tools import MCPToolProvider, MCPToolManager 28 | 29 | __all__ = [ 30 | 'GeneralSearchTool', 31 | 'AcademicSearchTool', 32 | 'GithubSearchTool', 33 | 'LinkedinSearchTool', 34 | 'SearchToolRegistry', 35 | 'ToolExecutor', 36 | 'SEARCH_TOOL_FUNCTIONS', 37 | 'TOPIC_DECOMPOSITION_FUNCTION', 38 | 'GeneralSearchToolSchema', 39 | 'AcademicSearchToolSchema', 40 | 'GithubSearchToolSchema', 41 | 'LinkedinSearchToolSchema', 42 | 'SimpleTopicResponse', 43 | 'ComplexTopicResponse', 44 | 'Subtopic', 45 | 'MCPToolProvider', 46 | 'MCPToolManager' 47 | ] -------------------------------------------------------------------------------- /graph_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | import importlib 4 | import sys 5 | 6 | # Clear out any cached module 7 | if 'src.graph' in sys.modules: 8 | del sys.modules['src.graph'] 9 | if 'src.configuration' in sys.modules: 10 | del sys.modules['src.configuration'] 11 | 12 | # Clear any existing env vars 13 | if "LLM_MODEL" in os.environ: 14 | del os.environ["LLM_MODEL"] 15 | 16 | # Load environment variables from .env file with override 17 | load_dotenv(override=True) 18 | 19 | # Print environment variables 20 | print(f"Environment variable LLM_MODEL: {os.environ.get('LLM_MODEL')}") 21 | 22 | # Import the modules 23 | from src.configuration import Configuration 24 | 25 | # Create a configuration and print the model 26 | config = Configuration() 27 | print(f"Configuration LLM model: {config.llm_model}") 28 | 29 | # Test with no environment variable 30 | del os.environ["LLM_MODEL"] 31 | config = Configuration() 32 | print(f"Configuration LLM model with no env var: {config.llm_model}") 33 | 34 | # Import the graph module 35 | from src import graph 36 | 37 | # Print the model used in a graph function 38 | print(f"\nTesting graph module behavior:") 39 | config = {"configurable": {"llm_model": ""}} 40 | from src.graph import initial_query 41 | try: 42 | # This will fail but we just want to see what model it tries to use 43 | initial_query({"research_topic": "test"}, config) 44 | except Exception as e: 45 | print(f"Expected error: {str(e)[:100]}...") -------------------------------------------------------------------------------- /ai-research-assistant/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ai-research-assistant", 3 | "version": "0.1.0", 4 | "private": true, 5 | "dependencies": { 6 | "@emotion/react": "^11.14.0", 7 | "@emotion/styled": "^11.14.0", 8 | "@headlessui/react": "^2.2.0", 9 | "@mui/icons-material": "^7.0.1", 10 | "@mui/material": "^7.0.1", 11 | "@tippyjs/react": "^4.2.6", 12 | "date-fns": "^4.1.0", 13 | "docx": "^9.3.0", 14 | "file-saver": "^2.0.5", 15 | "html2pdf.js": "^0.10.3", 16 | "lucide-react": "^0.503.0", 17 | "react": "^18.2.0", 18 | "react-dom": "^18.2.0", 19 | "react-icons": "^5.5.0", 20 | "react-markdown": 
"^10.1.0", 21 | "react-syntax-highlighter": "^15.6.1", 22 | "rehype-raw": "^7.0.0", 23 | "remark-gfm": "^4.0.1", 24 | "styled-components": "^6.1.17", 25 | "tailwindcss": "^3.3.5" 26 | }, 27 | "scripts": { 28 | "start": "PORT=3001 react-scripts start", 29 | "build": "react-scripts build", 30 | "test": "react-scripts test", 31 | "eject": "react-scripts eject" 32 | }, 33 | "eslintConfig": { 34 | "extends": [ 35 | "react-app", 36 | "react-app/jest" 37 | ] 38 | }, 39 | "browserslist": { 40 | "production": [ 41 | ">0.2%", 42 | "not dead", 43 | "not op_mini all" 44 | ], 45 | "development": [ 46 | "last 1 chrome version", 47 | "last 1 firefox version", 48 | "last 1 safari version" 49 | ] 50 | }, 51 | "devDependencies": { 52 | "autoprefixer": "^10.4.21", 53 | "postcss": "^8.5.3", 54 | "react-scripts": "5.0.1" 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /test_graph.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | from src.graph import graph 4 | from src.configuration import Configuration, LLMProvider, SearchAPI 5 | 6 | # Load environment variables from .env file 7 | load_dotenv() 8 | 9 | # Debug: print SEARCH_API value from environment 10 | print(f"Debug - SEARCH_API from env: '{os.getenv('SEARCH_API')}'") 11 | 12 | def test_graph(): 13 | # Configure the graph 14 | # Get MAX_WEB_RESEARCH_LOOPS with better error handling 15 | max_loops_str = os.getenv("MAX_WEB_RESEARCH_LOOPS", "10") 16 | max_loops = int(max_loops_str) if max_loops_str.strip() else 10 17 | 18 | config = { 19 | "configurable": { 20 | "llm_provider": os.getenv("LLM_PROVIDER", "openai"), 21 | "llm_model": os.getenv("LLM_MODEL", "o3-mini"), 22 | "search_api": os.getenv("SEARCH_API", "tavily"), 23 | "max_web_research_loops": max_loops 24 | } 25 | } 26 | 27 | # Define a research topic 28 | research_topic = "Salesforce acquisitions last 10 years. What are companies, who are the founders of those companies, whether if those founders are still working for Salesforce. Please list them as a table" 29 | 30 | print(f"\n{'='*80}") 31 | print(f"Starting research on: {research_topic}") 32 | print(f"Using LLM provider: {config['configurable']['llm_provider']}") 33 | print(f"Using LLM model: {config['configurable']['llm_model']}") 34 | print(f"Using search API: {config['configurable']['search_api']}") 35 | print(f"Max research loops: {config['configurable']['max_web_research_loops']}") 36 | print(f"{'='*80}\n") 37 | 38 | # Run the graph 39 | # Add recursion_limit to the config (outside of 'configurable') 40 | config["recursion_limit"] = 50 41 | result = graph.invoke({"research_topic": research_topic}, config=config) 42 | 43 | print(f"\n{'='*80}") 44 | print("--- Research Complete ---") 45 | print(f"{'='*80}\n") 46 | print(result.get("running_summary", "No summary generated")) 47 | 48 | if __name__ == "__main__": 49 | test_graph() -------------------------------------------------------------------------------- /src/tools/README.md: -------------------------------------------------------------------------------- 1 | # Tool Calling Mechanism for Research Agent 2 | 3 | This directory contains the implementation of a tool calling mechanism for the research agent. The implementation follows the LangChain tool calling pattern and provides a standardized interface for search tools. 
4 | 5 | ## Components 6 | 7 | ### Search Tools 8 | 9 | - `GeneralSearchTool`: A tool for general web search 10 | - `AcademicSearchTool`: A tool for academic and scholarly search 11 | - `GithubSearchTool`: A tool for GitHub and code-related search 12 | - `LinkedinSearchTool`: A tool for LinkedIn and professional profile search 13 | 14 | Each tool follows a standardized interface and provides a consistent output format. 15 | 16 | ### Tool Registry 17 | 18 | The `SearchToolRegistry` provides a central registry for search tools. It allows tools to be registered, retrieved, and managed in a central location. 19 | 20 | ### Tool Executor 21 | 22 | The `ToolExecutor` is responsible for executing tools based on their name and parameters. It handles both synchronous and asynchronous tools. 23 | 24 | ## Usage 25 | 26 | To use the tool calling mechanism, you need to: 27 | 28 | 1. Create a registry with `registry = SearchToolRegistry(config)` 29 | 2. Create an executor with `executor = ToolExecutor(registry, config)` 30 | 3. Execute a tool with `result = executor.execute_tool_sync("tool_name", {"param": "value"})` 31 | 32 | ## Testing 33 | 34 | You can test the implementation with: 35 | 36 | - `test_tools.py`: Tests the search tools, registry, and executor 37 | - `simple_test.py`: Tests the research agent with the new tool calling mechanism 38 | 39 | ## Extension 40 | 41 | To add a new search tool: 42 | 43 | 1. Create a new tool class that inherits from `BaseTool` 44 | 2. Implement the `_run` method to execute the tool 45 | 3. Register the tool with the registry using `registry.register_tool(new_tool)` 46 | 47 | ## Example 48 | 49 | ```python 50 | # Create registry and executor 51 | registry = SearchToolRegistry(config) 52 | executor = ToolExecutor(registry, config) 53 | 54 | # Execute a tool 55 | result = executor.execute_tool_sync( 56 | "general_search", 57 | {"query": "python langgraph framework"} 58 | ) 59 | 60 | # Process results 61 | formatted_sources = result.get("formatted_sources", []) 62 | search_string = result.get("search_string", "") 63 | tools_used = result.get("tools", []) 64 | domains = result.get("domains", []) -------------------------------------------------------------------------------- /src/tools/test_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test script for the tool calling mechanism. 3 | 4 | This script tests the search tools, registry, and executor implementations. 
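
Run it from the project root as a module so that the `src` package resolves
(an assumption about the usual layout; adjust if yours differs):

    python -m src.tools.test_tools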
5 | """ 6 | 7 | import logging 8 | import sys 9 | import json 10 | from src.tools import ( 11 | SearchToolRegistry, 12 | ToolExecutor, 13 | GeneralSearchTool, 14 | AcademicSearchTool, 15 | GithubSearchTool, 16 | LinkedinSearchTool 17 | ) 18 | 19 | # Configure logging 20 | logging.basicConfig( 21 | level=logging.INFO, 22 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 23 | handlers=[logging.StreamHandler(sys.stdout)] 24 | ) 25 | 26 | logger = logging.getLogger("test_tools") 27 | 28 | def test_tool_registry(): 29 | """Test the tool registry implementation.""" 30 | logger.info("Testing tool registry...") 31 | 32 | # Create a registry 33 | registry = SearchToolRegistry() 34 | 35 | # Get tool descriptions 36 | descriptions = registry.get_tool_descriptions() 37 | logger.info(f"Registered tools: {json.dumps(descriptions, indent=2)}") 38 | 39 | # Get a specific tool 40 | tool = registry.get_tool("general_search") 41 | logger.info(f"Retrieved tool: {tool.name} - {tool.description}") 42 | 43 | logger.info("Tool registry test completed successfully.") 44 | 45 | def test_tool_executor(): 46 | """Test the tool executor implementation.""" 47 | logger.info("Testing tool executor...") 48 | 49 | # Create a registry and executor 50 | registry = SearchToolRegistry() 51 | executor = ToolExecutor(registry) 52 | 53 | # Execute a tool 54 | query = "python langgraph framework" 55 | logger.info(f"Executing general_search tool with query: {query}") 56 | 57 | result = executor.execute_tool_sync("general_search", {"query": query}) 58 | 59 | # Print results 60 | logger.info(f"Tool execution result - Formatted sources count: {len(result.get('formatted_sources', []))}") 61 | logger.info(f"Tool execution result - Search string length: {len(result.get('search_string', ''))}") 62 | logger.info(f"Tool execution result - Tools used: {result.get('tools', [])}") 63 | logger.info(f"Tool execution result - Domains count: {len(result.get('domains', []))}") 64 | 65 | logger.info("Tool executor test completed successfully.") 66 | 67 | if __name__ == "__main__": 68 | logger.info("Starting tool tests...") 69 | 70 | # Run tests 71 | test_tool_registry() 72 | print("\n" + "="*80 + "\n") 73 | test_tool_executor() 74 | 75 | logger.info("All tests completed.") -------------------------------------------------------------------------------- /src/tools/examples/math_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example script demonstrating how to use the Math MCP server with our tool registry. 3 | 4 | This example shows how to: 5 | 1. Start the Math MCP server 6 | 2. Register it with our tool registry 7 | 3. 
Execute some math operations using the tools 8 | 9 | Prerequisites: 10 | - The math_server.py file should be in the src/tools directory 11 | """ 12 | import asyncio 13 | import sys 14 | import os 15 | from pathlib import Path 16 | 17 | # Add the project root to the Python path 18 | project_root = Path(__file__).parent.parent.parent.parent 19 | sys.path.append(str(project_root)) 20 | 21 | from src.tools.registry import SearchToolRegistry 22 | from src.tools.executor import ToolExecutor 23 | from src.tools.mcp_tools import MCPToolManager 24 | 25 | async def main(): 26 | # Create a tool registry 27 | registry = SearchToolRegistry() 28 | executor = ToolExecutor(registry) 29 | 30 | # Create an MCP tool manager 31 | mcp_manager = MCPToolManager(registry) 32 | 33 | try: 34 | # Register the Math MCP server as a stdio subprocess 35 | math_server_path = Path(__file__).parent.parent / "math_server.py" 36 | 37 | if not math_server_path.exists(): 38 | print(f"Error: {math_server_path} not found") 39 | return 40 | 41 | print(f"Starting Math MCP server from: {math_server_path}") 42 | 43 | tools = await mcp_manager.register_stdio_server( 44 | name="math", 45 | command=sys.executable, 46 | args=[str(math_server_path)] 47 | ) 48 | 49 | print(f"Registered {len(tools)} tools from Math MCP server:") 50 | for tool in tools: 51 | print(f" - {tool.name}: {tool.description}") 52 | params = [f"{p.name}: {p.type}" for p in tool.parameters] 53 | print(f" Parameters: {', '.join(params)}") 54 | 55 | # Execute some math operations 56 | operations = [ 57 | ("add", {"a": 10, "b": 5}), 58 | ("subtract", {"a": 10, "b": 5}), 59 | ("multiply", {"a": 10, "b": 5}), 60 | ("divide", {"a": 10, "b": 5}) 61 | ] 62 | 63 | for op_name, params in operations: 64 | print(f"\nExecuting {op_name} with parameters: {params}") 65 | result = await executor.execute_tool( 66 | f"mcp.math.{op_name}", 67 | params 68 | ) 69 | print(f"Result: {result}") 70 | 71 | finally: 72 | # Close all MCP connections 73 | await mcp_manager.close_all() 74 | 75 | if __name__ == "__main__": 76 | asyncio.run(main()) -------------------------------------------------------------------------------- /test_agents.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Simple test script for the agent architecture. 
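
main() exits early unless OPENAI_API_KEY is set, so a typical invocation
looks like this (the key value is a placeholder):

    export OPENAI_API_KEY=sk-...
    python test_agents.py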
4 | """ 5 | 6 | import os 7 | import sys 8 | import json 9 | import logging 10 | 11 | # Configure logging 12 | logging.basicConfig( 13 | level=logging.INFO, 14 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' 15 | ) 16 | logger = logging.getLogger(__name__) 17 | 18 | # Add src directory to Python path 19 | sys.path.append(".") 20 | 21 | # Create a simple test state class 22 | class TestState: 23 | def __init__(self, topic): 24 | self.research_topic = topic 25 | self.search_query = None 26 | self.knowledge_gap = "" 27 | self.research_loop_count = 0 28 | self.config = None 29 | 30 | def main(): 31 | """Test the agent architecture directly.""" 32 | 33 | # Check if OpenAI API key is set 34 | api_key = os.environ.get("OPENAI_API_KEY") 35 | if not api_key: 36 | logger.error("OPENAI_API_KEY environment variable is not set") 37 | return {"success": False, "error": "OPENAI_API_KEY environment variable is not set"} 38 | 39 | # Import the agent classes 40 | from src.agent_architecture import MasterResearchAgent 41 | 42 | # Create a test state with a simple research topic 43 | test_topic = "The impact of artificial intelligence on healthcare" 44 | state = TestState(test_topic) 45 | 46 | # Initialize the master agent 47 | master_agent = MasterResearchAgent() 48 | 49 | try: 50 | # Execute research 51 | logger.info(f"Starting research on topic: {test_topic}") 52 | results = master_agent.execute_research(state) 53 | 54 | # Print summary of results 55 | topic_complexity = results.get("research_results", {}).get("topic_complexity", "unknown") 56 | sources_count = len(results.get("sources_gathered", [])) 57 | tools_used = ", ".join(results.get("tools", [])) 58 | 59 | logger.info(f"Research complete!") 60 | logger.info(f"Topic complexity: {topic_complexity}") 61 | logger.info(f"Sources found: {sources_count}") 62 | logger.info(f"Tools used: {tools_used}") 63 | 64 | return { 65 | "success": True, 66 | "topic_complexity": topic_complexity, 67 | "sources_count": sources_count, 68 | "tools_used": results.get("tools", []) 69 | } 70 | 71 | except Exception as e: 72 | logger.error(f"Error testing agent architecture: {str(e)}") 73 | import traceback 74 | logger.error(traceback.format_exc()) 75 | return {"success": False, "error": str(e)} 76 | 77 | if __name__ == "__main__": 78 | result = main() 79 | 80 | # Print results in a formatted way 81 | print("\n" + "="*50) 82 | print("TEST RESULT:") 83 | print(json.dumps(result, indent=2)) 84 | print("="*50) -------------------------------------------------------------------------------- /benchmarks/process_drb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Convert individual JSON research reports to JSONL format for DeepResearchBench evaluation. 
4 | 5 | Usage: 6 | python process_drb.py --input-dir /path/to/json/files --model-name your_model_name 7 | """ 8 | 9 | import json 10 | import os 11 | import argparse 12 | 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser( 16 | description="Convert individual JSON research reports to JSONL format for DeepResearchBench evaluation", 17 | formatter_class=argparse.RawDescriptionHelpFormatter, 18 | ) 19 | 20 | parser.add_argument( 21 | "--input-dir", 22 | "-i", 23 | type=str, 24 | required=True, 25 | help="Directory containing individual JSON report files", 26 | ) 27 | 28 | parser.add_argument( 29 | "--model-name", 30 | "-m", 31 | type=str, 32 | required=True, 33 | help="Model name for the output JSONL filename", 34 | ) 35 | 36 | args = parser.parse_args() 37 | 38 | # Validate input directory exists 39 | if not os.path.exists(args.input_dir): 40 | print(f"Error: Input directory '{args.input_dir}' does not exist") 41 | return 1 42 | 43 | # Collect all JSON files in the directory 44 | json_files = [f for f in os.listdir(args.input_dir) if f.endswith(".json")] 45 | 46 | if not json_files: 47 | print(f"Error: No JSON files found in '{args.input_dir}'") 48 | return 1 49 | 50 | print(f"Found {len(json_files)} JSON files in '{args.input_dir}'") 51 | 52 | all_reports = [] 53 | for file in json_files: 54 | file_path = os.path.join(args.input_dir, file) 55 | try: 56 | with open(file_path, "r", encoding="utf-8") as f: 57 | data = json.load(f) 58 | 59 | # Extract required fields 60 | report = { 61 | "id": data["id"], 62 | "prompt": data["prompt"], 63 | "article": data["article"], 64 | } 65 | all_reports.append(report) 66 | 67 | except (json.JSONDecodeError, KeyError) as e: 68 | print(f"Warning: Error processing file '{file}': {e}") 69 | continue 70 | except Exception as e: 71 | print(f"Warning: Unexpected error processing file '{file}': {e}") 72 | continue 73 | 74 | if not all_reports: 75 | print("Error: No valid reports were processed") 76 | return 1 77 | 78 | # Sort all reports by id 79 | all_reports.sort(key=lambda x: x["id"]) 80 | 81 | with open(f"deep_research_bench/data/test_data/raw_data/{args.model_name}.jsonl", "w", encoding="utf-8") as f: 82 | for report in all_reports: 83 | f.write(json.dumps(report, ensure_ascii=False) + "\n") 84 | 85 | print(f"Successfully processed {len(all_reports)} reports") 86 | print(f"Output saved to: {args.model_name}.jsonl") 87 | 88 | return 0 89 | 90 | 91 | if __name__ == "__main__": 92 | exit_code = main() 93 | exit(exit_code) 94 | -------------------------------------------------------------------------------- /src/test_agent_architecture.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test script for agent architecture. 3 | 4 | This script helps validate that the agent architecture is working correctly 5 | by executing a simple research request and comparing results with the 6 | original implementation. 
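
As with test_agents.py, OPENAI_API_KEY must be set or the __main__ block
exits immediately. Running it as a module from the project root keeps the
`src.*` imports resolvable (assumed layout):

    export OPENAI_API_KEY=sk-...
    python -m src.test_agent_architecture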
7 | """ 8 | 9 | import os 10 | import sys 11 | import json 12 | import logging 13 | 14 | # Configure logging 15 | logging.basicConfig( 16 | level=logging.INFO, 17 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' 18 | ) 19 | logger = logging.getLogger(__name__) 20 | 21 | # Create a simple summary state class for testing 22 | class TestSummaryState: 23 | def __init__(self, research_topic, knowledge_gap="", research_loop_count=0): 24 | self.research_topic = research_topic 25 | self.search_query = None # Use research_topic 26 | self.knowledge_gap = knowledge_gap 27 | self.research_loop_count = research_loop_count 28 | self.config = None # Use default config 29 | 30 | def test_agent_architecture(): 31 | """Test the new agent architecture with a simple query.""" 32 | try: 33 | # Import the necessary functions and classes 34 | logger.info("Importing agent architecture components...") 35 | from src.agent_architecture import MasterResearchAgent 36 | 37 | # Create a test state 38 | test_topic = "The impact of artificial intelligence on healthcare" 39 | test_state = TestSummaryState(test_topic) 40 | 41 | # Test the agent architecture directly 42 | logger.info(f"Testing agent architecture with topic: {test_topic}") 43 | master_agent = MasterResearchAgent() 44 | results = master_agent.execute_research(test_state) 45 | 46 | # Log results summary 47 | topic_complexity = results.get("research_results", {}).get("topic_complexity", "unknown") 48 | sources_count = len(results.get("sources_gathered", [])) 49 | tools_used = results.get("tools", []) 50 | 51 | logger.info(f"Research completed with topic_complexity: {topic_complexity}") 52 | logger.info(f"Found {sources_count} sources using tools: {', '.join(tools_used)}") 53 | 54 | return { 55 | "success": True, 56 | "topic_complexity": topic_complexity, 57 | "sources_count": sources_count, 58 | "tools_used": tools_used 59 | } 60 | 61 | except Exception as e: 62 | logger.error(f"Error in test: {str(e)}") 63 | import traceback 64 | logger.error(traceback.format_exc()) 65 | return {"success": False, "error": str(e)} 66 | 67 | if __name__ == "__main__": 68 | # Check if OpenAI API key is set 69 | if not os.environ.get("OPENAI_API_KEY"): 70 | logger.error("OPENAI_API_KEY environment variable is not set") 71 | sys.exit(1) 72 | 73 | # Run the test 74 | result = test_agent_architecture() 75 | 76 | # Print result 77 | print("\n" + "="*50) 78 | print("TEST RESULT:") 79 | print(json.dumps(result, indent=2)) 80 | print("="*50) -------------------------------------------------------------------------------- /ai-research-assistant/src/components/CodeWithVisualization.js: -------------------------------------------------------------------------------- 1 | import React, { useEffect } from 'react'; 2 | import CodeSnippetViewer from './CodeSnippetViewer'; 3 | 4 | /** 5 | * Component that displays a code snippet with its associated visualization. 6 | * This creates a clear cause-effect relationship between code and its output. 7 | */ 8 | function CodeWithVisualization({ snippet }) { 9 | useEffect(() => { 10 | // Log component initialization with detailed snippet info 11 | console.debug('[CodeWithVisualization] Rendering with snippet:', { 12 | hasCode: !!snippet?.code, 13 | codeLength: snippet?.code?.length, 14 | language: snippet?.language, 15 | hasVisualization: !!snippet?.visualization, 16 | visualizationType: snippet?.visualization?.src ? 'src' : snippet?.visualization?.data ? 
'data' : 'none', 17 | hasDescription: !!snippet?.visualization?.description 18 | }); 19 | }, [snippet]); 20 | 21 | // Input validation with detailed error logging 22 | if (!snippet || typeof snippet !== 'object') { 23 | console.error('[CodeWithVisualization] Invalid snippet provided:', snippet); 24 | return null; 25 | } 26 | 27 | const { code, language, visualization } = snippet; 28 | 29 | // Check that we have code - the essential part 30 | if (!code) { 31 | console.warn('[CodeWithVisualization] Snippet missing code:', snippet); 32 | return null; 33 | } 34 | 35 | // Detailed visualization validation 36 | const hasValidVisualization = visualization && 37 | (visualization.src || (visualization.data && visualization.format)); 38 | 39 | if (visualization && !hasValidVisualization) { 40 | console.warn('[CodeWithVisualization] Visualization data is invalid:', visualization); 41 | } 42 | 43 | return ( 44 |
45 |     <div> {/* Markup reconstructed from the surviving handler code; class names and some img attributes are assumed */}
46 |       {/* Code snippet first */}
47 |       <CodeSnippetViewer code={code} language={language} />
48 |
49 |       {/* Visualization below the code if available and valid */}
50 |       {hasValidVisualization && (
51 |         <div>
52 |           <img
53 |             src={visualization.src || `data:image/${visualization.format};base64,${visualization.data}`}
54 |             alt={visualization.description || 'Visualization'}
55 |             onError={(e) => {
56 |               console.error(`[CodeWithVisualization] Error loading visualization:`, e);
57 |               console.error(`[CodeWithVisualization] Failed visualization data:`, {
58 |                 srcLength: visualization.src?.length,
59 |                 dataLength: visualization.data?.length,
60 |                 format: visualization.format,
61 |                 description: visualization.description
62 |               });
63 |               e.target.style.display = 'none';
64 |             }}
65 |             onLoad={() => console.debug('[CodeWithVisualization] Visualization loaded successfully')}
66 |           />
67 |           {visualization.description && (
68 |             <div>
69 |               {visualization.description}
70 |             </div>
71 |           )}
72 |         </div>
73 |       )}
74 |     </div>
75 | ); 76 | } 77 | 78 | export default CodeWithVisualization; 79 | -------------------------------------------------------------------------------- /math_client_new.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example client that uses langchain-mcp-adapters to connect to the math_server 3 | and uses the tools with a LangGraph agent. 4 | """ 5 | import asyncio 6 | import os 7 | import sys 8 | from pathlib import Path 9 | 10 | from langchain_openai import ChatOpenAI 11 | from langchain.agents import AgentExecutor 12 | from langchain.agents import AgentType 13 | from langchain.agents.initialize import initialize_agent 14 | from mcp import ClientSession, StdioServerParameters 15 | from mcp.client.stdio import stdio_client 16 | 17 | from langchain_mcp_adapters.tools import load_mcp_tools 18 | 19 | # Check if OpenAI API key is set 20 | OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "***REMOVED***") 21 | if not OPENAI_API_KEY: 22 | print("Warning: OPENAI_API_KEY environment variable not set.") 23 | print("You can set it with: export OPENAI_API_KEY=your_api_key_here") 24 | 25 | async def main(): 26 | # Get the current directory 27 | current_dir = Path(__file__).parent.absolute() 28 | 29 | # Path to the math_server.py file 30 | server_path = current_dir / "math_server.py" 31 | 32 | # Create server parameters for stdio connection 33 | server_params = StdioServerParameters( 34 | command=sys.executable, # Use the current Python interpreter 35 | args=[str(server_path)], 36 | ) 37 | 38 | # Initialize the chat model 39 | model = ChatOpenAI( 40 | model="gpt-4o", 41 | temperature=0, 42 | api_key=OPENAI_API_KEY, # Use the API key we got from the environment 43 | ) 44 | 45 | print("Starting MCP client session...") 46 | 47 | # Connect to the MCP server 48 | async with stdio_client(server_params) as (read, write): 49 | async with ClientSession(read, write) as session: 50 | # Initialize the connection 51 | await session.initialize() 52 | 53 | # Get the list of available tools 54 | tools = await load_mcp_tools(session) 55 | 56 | # Print available tools 57 | print(f"Loaded {len(tools)} tools from MCP server:") 58 | for tool in tools: 59 | print(f" - {tool.name}: {tool.description}") 60 | 61 | # Initialize a structured agent that properly handles JSON schema tools 62 | agent_executor = initialize_agent( 63 | tools, 64 | model, 65 | agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, 66 | verbose=True 67 | ) 68 | 69 | # Define some example queries to run 70 | queries = [ 71 | "What is 25 plus 17?", 72 | "If I have 100 and subtract 28, what do I get?", 73 | "What is 13 multiplied by 5?", 74 | "What is 120 divided by 4?", 75 | "If I add 10 to 20, then multiply by 3, what's the result?" 
76 | ] 77 | 78 | # Run the agent for each query 79 | for query in queries: 80 | print("\n" + "="*50) 81 | print(f"Query: {query}") 82 | print("="*50) 83 | 84 | try: 85 | # Invoke the agent 86 | agent_response = await agent_executor.ainvoke({"input": query}) 87 | 88 | # Print the agent's response 89 | print("\nAgent Response:") 90 | print(agent_response["output"]) 91 | except Exception as e: 92 | print(f"Error processing query: {e}") 93 | 94 | if __name__ == "__main__": 95 | asyncio.run(main()) 96 | -------------------------------------------------------------------------------- /src/tools/simple_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Test script for the research agent with tool calling. 4 | This script serves as a simple test to ensure the research agent works correctly. 5 | """ 6 | 7 | import os 8 | import sys 9 | import logging 10 | import json 11 | from datetime import datetime, timedelta 12 | 13 | # Configure logging 14 | logging.basicConfig( 15 | level=logging.INFO, 16 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 17 | handlers=[logging.StreamHandler()] 18 | ) 19 | 20 | logger = logging.getLogger("research_agent_test") 21 | 22 | # Import needed modules 23 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) 24 | from src.graph import research_agent 25 | from src.state import SummaryState 26 | 27 | def run_test(): 28 | """Run a simple test of the research agent.""" 29 | logger.info("Starting research agent test...") 30 | logger.info("Testing research agent with tool calling...") 31 | 32 | # Create a summary state 33 | summary_state = SummaryState( 34 | research_topic="Latest advancements in large language models", 35 | research_loop_count=0 36 | ) 37 | 38 | # Call the research agent 39 | logger.info(f"Calling research agent with topic: {summary_state.research_topic}") 40 | 41 | # Create a simple configuration 42 | config = { 43 | "callbacks": { 44 | "on_event": lambda event_type, data: logger.info(f"Event: {event_type} - {json.dumps(data)}") 45 | } 46 | } 47 | 48 | # Execute the agent 49 | result = research_agent(summary_state, config) 50 | 51 | # Log the results 52 | logger.info("Research completed. Results summary:") 53 | logger.info(f"- Sources count: {len(result.get('formatted_sources', []))}") 54 | logger.info(f"- Tools used: {result.get('tools', [])}") 55 | logger.info(f"- Domains count: {len(result.get('domains', []))}") 56 | 57 | logger.info("Test completed.") 58 | 59 | def run_consolidated_test(): 60 | """Run a test of the consolidated tool calling approach.""" 61 | logger.info("Starting consolidated tool calling test...") 62 | 63 | # Create a summary state with a biographical query to test various tool types 64 | summary_state = SummaryState( 65 | research_topic="Who is Caiming Xiong and what are his contributions to AI research?", 66 | research_loop_count=0 67 | ) 68 | 69 | # Call the research agent 70 | logger.info(f"Calling research agent with biographical topic: {summary_state.research_topic}") 71 | 72 | # Create a simple configuration 73 | config = { 74 | "callbacks": { 75 | "on_event": lambda event_type, data: logger.info(f"Event: {event_type} - {json.dumps(data, default=str)}") 76 | } 77 | } 78 | 79 | # Execute the agent 80 | result = research_agent(summary_state, config) 81 | 82 | # Log the results 83 | logger.info("Biographical research completed. 
Results summary:") 84 | logger.info(f"- Sources count: {len(result.get('formatted_sources', []))}") 85 | tools = result.get('tools', []) 86 | logger.info(f"- Tools used: {tools}") 87 | logger.info(f"- Domains count: {len(result.get('domains', []))}") 88 | 89 | # Verify that the results include appropriate tools 90 | if 'linkedin_search' in tools: 91 | logger.info("✅ Test passed: linkedin_search tool was used for biographical research") 92 | else: 93 | logger.warning("❌ Test warning: linkedin_search tool was not used for biographical research") 94 | 95 | if 'academic_search' in tools: 96 | logger.info("✅ Test passed: academic_search tool was used for contributions research") 97 | else: 98 | logger.warning("❌ Test warning: academic_search tool was not used for contributions research") 99 | 100 | logger.info("Consolidated tool calling test completed.") 101 | 102 | if __name__ == "__main__": 103 | logger.info("Running research agent tests...") 104 | 105 | # Run both tests 106 | run_test() 107 | run_consolidated_test() -------------------------------------------------------------------------------- /models/research.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import Optional, Dict, Any, List 3 | 4 | 5 | class ResearchRequest(BaseModel): 6 | """ 7 | Model for the research request. 8 | """ 9 | 10 | query: str = Field(..., description="The research query or topic to investigate") 11 | extra_effort: bool = Field( 12 | False, description="Whether to perform more extensive research (more loops)" 13 | ) 14 | minimum_effort: bool = Field( 15 | False, description="Whether to force minimum (1 loop) research" 16 | ) 17 | streaming: bool = Field(False, description="Whether to stream the response") 18 | provider: Optional[str] = Field( 19 | None, 20 | description="The LLM provider to use (e.g., 'openai', 'google', 'anthropic')", 21 | ) 22 | model: Optional[str] = Field( 23 | None, 24 | description="The specific model to use (e.g., 'o3-mini', 'gemini-2.5-pro')", 25 | ) 26 | benchmark_mode: bool = Field( 27 | False, description="Whether to run in benchmark Q&A mode for testing accuracy" 28 | ) 29 | uploaded_data_content: Optional[str] = Field( 30 | None, description="Content of the uploaded external data source" 31 | ) 32 | uploaded_files: Optional[List[str]] = Field( 33 | None, description="List of uploaded file IDs to include in research" 34 | ) 35 | steering_enabled: bool = Field( 36 | False, description="Whether to enable real-time steering functionality" 37 | ) 38 | database_info: Optional[List[Dict[str, Any]]] = Field( 39 | None, description="Information about uploaded databases for text2sql functionality" 40 | ) 41 | 42 | 43 | class ResearchResponse(BaseModel): 44 | """ 45 | Model for the research response. 
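
    Only running_summary, research_complete, and research_loop_count are
    required; every other field falls back to the default declared below.
    A minimal construction (illustrative values):

        ResearchResponse(
            running_summary="...",
            research_complete=True,
            research_loop_count=2,
        )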
46 | """ 47 | 48 | running_summary: str = Field(..., description="The comprehensive research summary") 49 | research_complete: bool = Field( 50 | ..., description="Whether the research process is complete" 51 | ) 52 | research_loop_count: int = Field( 53 | ..., description="Number of research loops performed" 54 | ) 55 | sources_gathered: List[str] = Field( 56 | default_factory=list, description="List of sources used in research" 57 | ) 58 | web_research_results: List[Dict[str, Any]] = Field( 59 | default_factory=list, description="Raw web research results" 60 | ) 61 | source_citations: Dict[str, Dict[str, str]] = Field( 62 | default_factory=dict, description="Source citations mapping" 63 | ) 64 | benchmark_mode: bool = Field( 65 | default=False, description="Whether ran in benchmark Q&A mode" 66 | ) 67 | benchmark_result: Optional[Dict[str, Any]] = Field( 68 | default=None, description="Results from benchmark testing" 69 | ) 70 | visualizations: List[Dict[str, Any]] = Field( 71 | default_factory=list, description="Generated visualizations" 72 | ) 73 | base64_encoded_images: List[Dict[str, Any]] = Field( 74 | default_factory=list, description="Base64 encoded images" 75 | ) 76 | visualization_paths: List[str] = Field( 77 | default_factory=list, description="Paths to visualization files" 78 | ) 79 | code_snippets: List[Dict[str, Any]] = Field( 80 | default_factory=list, description="Generated code snippets" 81 | ) 82 | uploaded_knowledge: Optional[str] = Field( 83 | None, description="User-provided external knowledge" 84 | ) 85 | analyzed_files: List[Dict[str, Any]] = Field( 86 | default_factory=list, description="Analysis results from uploaded files" 87 | ) 88 | 89 | 90 | class ResearchEvent(BaseModel): 91 | """ 92 | Model for research events during streaming. 93 | """ 94 | 95 | event_type: str = Field(..., description="Type of the event") 96 | data: Dict[str, Any] = Field(..., description="Event data") 97 | timestamp: Optional[str] = Field(None, description="Event timestamp") 98 | 99 | 100 | class StreamResponse(BaseModel): 101 | """ 102 | Model for streaming response. 
103 | """ 104 | 105 | stream_url: str = Field(..., description="URL to connect for streaming updates") 106 | message: str = Field(..., description="Status message") 107 | -------------------------------------------------------------------------------- /session_store.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Simple file-based session store for steering functionality 4 | 5 | This allows sessions to be shared between processes 6 | """ 7 | 8 | import json 9 | import os 10 | import time 11 | from pathlib import Path 12 | from typing import Dict, Any, Optional 13 | import logging 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class SessionStore: 19 | """Simple file-based session store""" 20 | 21 | def __init__(self, store_file: str = "steering_sessions.json"): 22 | self.store_file = Path(store_file) 23 | self.sessions: Dict[str, Dict[str, Any]] = {} 24 | self.last_load_time = 0 25 | self.load_sessions() 26 | 27 | def load_sessions(self): 28 | """Load sessions from file""" 29 | try: 30 | if self.store_file.exists(): 31 | with open(self.store_file, "r") as f: 32 | data = json.load(f) 33 | self.sessions = data.get("sessions", {}) 34 | self.last_load_time = time.time() 35 | logger.info(f"[SESSION_STORE] Loaded {len(self.sessions)} sessions") 36 | except Exception as e: 37 | logger.warning(f"[SESSION_STORE] Error loading sessions: {e}") 38 | self.sessions = {} 39 | 40 | def save_sessions(self): 41 | """Save sessions to file""" 42 | try: 43 | data = {"sessions": self.sessions, "timestamp": time.time()} 44 | with open(self.store_file, "w") as f: 45 | json.dump(data, f, indent=2, default=str) 46 | logger.info(f"[SESSION_STORE] Saved {len(self.sessions)} sessions") 47 | except Exception as e: 48 | logger.error(f"[SESSION_STORE] Error saving sessions: {e}") 49 | 50 | def add_session(self, session_id: str, session_info: Dict[str, Any]): 51 | """Add a session to the store""" 52 | # Convert state object to serializable format 53 | serializable_info = self._make_serializable(session_info) 54 | self.sessions[session_id] = serializable_info 55 | self.save_sessions() 56 | logger.info(f"[SESSION_STORE] Added session {session_id}") 57 | 58 | def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: 59 | """Get a session from the store""" 60 | # Reload if file has been updated 61 | if self.store_file.exists(): 62 | file_mtime = self.store_file.stat().st_mtime 63 | if file_mtime > self.last_load_time: 64 | self.load_sessions() 65 | 66 | return self.sessions.get(session_id) 67 | 68 | def get_all_sessions(self) -> Dict[str, Dict[str, Any]]: 69 | """Get all sessions""" 70 | # Reload if file has been updated 71 | if self.store_file.exists(): 72 | file_mtime = self.store_file.stat().st_mtime 73 | if file_mtime > self.last_load_time: 74 | self.load_sessions() 75 | 76 | return self.sessions 77 | 78 | def remove_session(self, session_id: str): 79 | """Remove a session from the store""" 80 | if session_id in self.sessions: 81 | del self.sessions[session_id] 82 | self.save_sessions() 83 | logger.info(f"[SESSION_STORE] Removed session {session_id}") 84 | 85 | def _make_serializable(self, obj: Any) -> Any: 86 | """Make object serializable""" 87 | if hasattr(obj, "__dict__"): 88 | result = {} 89 | for key, value in obj.__dict__.items(): 90 | if not key.startswith("_") and not callable(value): 91 | try: 92 | result[key] = self._make_serializable(value) 93 | except: 94 | result[key] = str(value) 95 | return result 96 | elif 
isinstance(obj, dict): 97 | return {k: self._make_serializable(v) for k, v in obj.items()} 98 | elif isinstance(obj, list): 99 | return [self._make_serializable(item) for item in obj] 100 | elif isinstance(obj, (str, int, float, bool, type(None))): 101 | return obj 102 | else: 103 | return str(obj) 104 | 105 | 106 | # Global session store instance 107 | session_store = SessionStore() 108 | -------------------------------------------------------------------------------- /math_client_langgraph.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example client that uses langchain-mcp-adapters to connect to the math_server 3 | and uses the LangChain structured agent with the tools. 4 | """ 5 | import asyncio 6 | import os 7 | import sys 8 | from pathlib import Path 9 | 10 | from langchain_openai import ChatOpenAI 11 | from mcp import ClientSession, StdioServerParameters 12 | from mcp.client.stdio import stdio_client 13 | from langchain_mcp_adapters.tools import load_mcp_tools 14 | 15 | # Import structured agent from LangChain 16 | from langchain.agents import AgentExecutor, AgentType, initialize_agent 17 | 18 | # Check if OpenAI API key is set 19 | OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") 20 | if not OPENAI_API_KEY: 21 | print("Warning: OPENAI_API_KEY environment variable not set.") 22 | print("You can set it with: export OPENAI_API_KEY=your_api_key_here") 23 | 24 | async def main(): 25 | # Get the current directory 26 | current_dir = Path(__file__).parent.absolute() 27 | 28 | # Path to the math_server.py file 29 | server_path = current_dir / "math_server.py" 30 | 31 | if not server_path.exists(): 32 | print(f"Error: {server_path} not found. Make sure math_server.py is in the same directory.") 33 | return 34 | 35 | # Create server parameters for stdio connection 36 | server_params = StdioServerParameters( 37 | command=sys.executable, # Use the current Python interpreter 38 | args=[str(server_path)], 39 | ) 40 | 41 | # Initialize the chat model 42 | model = ChatOpenAI( 43 | model="gpt-4o", 44 | temperature=0, 45 | api_key=OPENAI_API_KEY, 46 | ) 47 | 48 | print("Starting MCP client session...") 49 | 50 | # Connect to the MCP server 51 | async with stdio_client(server_params) as (read, write): 52 | async with ClientSession(read, write) as session: 53 | # Initialize the connection 54 | await session.initialize() 55 | 56 | # Get the list of available tools 57 | tools = await load_mcp_tools(session) 58 | 59 | # Print available tools 60 | print(f"Loaded {len(tools)} tools from MCP server:") 61 | for tool in tools: 62 | print(f" - {tool.name}: {tool.description}") 63 | if hasattr(tool, 'args_schema'): 64 | if hasattr(tool.args_schema, 'schema'): 65 | print(f" Parameters: {tool.args_schema.schema()}") 66 | elif isinstance(tool.args_schema, dict): 67 | print(f" Parameters: {tool.args_schema}") 68 | else: 69 | print(f" Parameters: {type(tool.args_schema)}") 70 | else: 71 | print(" Parameters: None") 72 | 73 | # Initialize a structured agent that properly handles JSON schema tools 74 | agent_executor = initialize_agent( 75 | tools, 76 | model, 77 | agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION, 78 | verbose=True 79 | ) 80 | 81 | # Define some example queries to run 82 | queries = [ 83 | "What is 25 plus 17?", 84 | "If I have 100 and subtract 28, what do I get?", 85 | "What is 13 multiplied by 5?", 86 | "What is 120 divided by 4?", 87 | "If I add 10 to 20, then multiply by 3, what's the result?" 
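        # The final query requires chaining two tool calls (add, then multiply),
        # which exercises the structured agent's multi-step tool planning.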
88 | ] 89 | 90 | # Run the agent for each query 91 | for query in queries: 92 | print("\n" + "="*50) 93 | print(f"Query: {query}") 94 | print("="*50) 95 | 96 | try: 97 | # Invoke the agent 98 | agent_response = await agent_executor.ainvoke({"input": query}) 99 | 100 | # Print the agent's response 101 | print("\nAgent Response:") 102 | print(agent_response["output"]) 103 | except Exception as e: 104 | print(f"Error processing query: {e}") 105 | import traceback 106 | traceback.print_exc() 107 | 108 | if __name__ == "__main__": 109 | asyncio.run(main()) 110 | -------------------------------------------------------------------------------- /test_visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test script for the visualization agent functionality. 3 | 4 | This script tests the visualization agent in isolation to verify that 5 | it correctly analyzes research content, generates visualization code, 6 | and executes it using the E2B sandbox. 7 | """ 8 | 9 | import asyncio 10 | import os 11 | import json 12 | import logging 13 | from src.visualization_agent import VisualizationAgent 14 | 15 | # Configure logging 16 | logging.basicConfig( 17 | level=logging.INFO, 18 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' 19 | ) 20 | 21 | # Sample search result (mimics the output from the SearchAgent) 22 | SAMPLE_SEARCH_RESULT = { 23 | "formatted_sources": [ 24 | "1. Title: Global Smartphone Market Share 2023 (Source: https://example.com/smartphone-stats)", 25 | "Apple's iPhone captured 25% of the global smartphone market in 2023, while Samsung maintained its lead with 30%. Xiaomi came in third with 15%, followed by Oppo at 10% and Vivo at 8%. Other manufacturers combined for the remaining 12% market share.", 26 | "2. Title: Smartphone Unit Sales 2022-2023 (Source: https://example.com/smartphone-sales)", 27 | "Global smartphone shipments decreased by 3% in 2023 compared to 2022, with a total of 1.2 billion units shipped. Apple shipped 300 million units, Samsung 360 million, Xiaomi 180 million, Oppo 120 million, and Vivo 96 million." 
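    # Synthetic fixture values: the shares and unit counts above are invented
    # (note the example.com sources) so that pie/bar charts are an obvious
    # visualization target for the agent.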
28 | ], 29 | "search_string": "smartphone market share 2023", 30 | "subtask": { 31 | "type": "search", 32 | "name": "Smartphone Market Analysis", 33 | "query": "smartphone market share 2023", 34 | "tool": "general_search", 35 | "aspect": "market share statistics" 36 | } 37 | } 38 | 39 | async def test_visualization_agent(): 40 | """Test the VisualizationAgent class.""" 41 | print("\n=== Testing VisualizationAgent ===\n") 42 | 43 | # Initialize the visualization agent 44 | agent = VisualizationAgent() 45 | 46 | # Test determine_visualization_needs 47 | print("\n--- Testing determine_visualization_needs ---\n") 48 | viz_needs = await agent.determine_visualization_needs(SAMPLE_SEARCH_RESULT) 49 | 50 | if viz_needs: 51 | print(f"Visualization needed: {viz_needs.get('visualization_needed')}") 52 | print(f"Rationale: {viz_needs.get('rationale')}") 53 | print("\nVisualization types:") 54 | for viz_type in viz_needs.get("visualization_types", []): 55 | print(f"- {viz_type.get('type')}: {viz_type.get('description')}") 56 | print(f" Data requirements: {viz_type.get('data_requirements')}") 57 | else: 58 | print("Failed to determine visualization needs") 59 | return 60 | 61 | # Skip the rest if visualization is not needed 62 | if not viz_needs.get("visualization_needed", False): 63 | print("Visualization not needed for this content") 64 | return 65 | 66 | # Test generate_visualization_code 67 | print("\n--- Testing generate_visualization_code ---\n") 68 | code_data = await agent.generate_visualization_code(SAMPLE_SEARCH_RESULT, viz_needs) 69 | 70 | if code_data: 71 | print(f"Generated code for {len(code_data.get('visualization_types', []))} visualization types") 72 | print("\nCode preview (first 300 characters):") 73 | code_preview = code_data.get("code", "")[:300] + "..." 
if code_data.get("code") else "No code generated" 74 | print(code_preview) 75 | else: 76 | print("Failed to generate visualization code") 77 | return 78 | 79 | # Test execute_visualization_code 80 | print("\n--- Testing execute_visualization_code ---\n") 81 | viz_results = await agent.execute_visualization_code(code_data) 82 | 83 | if viz_results: 84 | if "error" in viz_results: 85 | print(f"Error executing code: {viz_results.get('error')}") 86 | else: 87 | print(f"Generated {len(viz_results.get('results', []))} visualization files") 88 | 89 | # Print visualization info 90 | for viz in viz_results.get("results", []): 91 | print(f"- {viz.get('type')}: {viz.get('filename')}") 92 | print(f" Path: {viz.get('filepath')}") 93 | 94 | # Check if file exists 95 | if os.path.exists(viz.get("filepath", "")): 96 | print(f" File exists: {os.path.getsize(viz.get('filepath', ''))} bytes") 97 | else: 98 | print(f" File does not exist") 99 | else: 100 | print("Failed to execute visualization code") 101 | return 102 | 103 | # # Test end-to-end execution 104 | # print("\n--- Testing end-to-end execution ---\n") 105 | # result = await agent.execute(SAMPLE_SEARCH_RESULT) 106 | 107 | # if result: 108 | # if "error" in result: 109 | # print(f"Error in end-to-end execution: {result.get('error')}") 110 | # else: 111 | # print("End-to-end execution successful") 112 | # print(f"Visualization needs: {result.get('visualization_needs', {}).get('visualization_needed')}") 113 | # print(f"Generated {len(result.get('results', []))} visualization files") 114 | # else: 115 | # print("End-to-end execution returned None") 116 | 117 | if __name__ == "__main__": 118 | # Run the async test 119 | asyncio.run(test_visualization_agent()) -------------------------------------------------------------------------------- /src/tools/examples/puppeteer_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example script demonstrating how to use the Puppeteer MCP server with our tool registry. 3 | 4 | This example shows how to: 5 | 1. Start the Puppeteer MCP server (assuming it's installed) 6 | 2. Register it with our tool registry 7 | 3. 
Execute some web browsing commands using the tools 8 | 9 | Prerequisites: 10 | - Install the MCP Puppeteer server: 11 | npm install -g @modelcontextprotocol/server-puppeteer 12 | 13 | - Run the server: 14 | npx -y @modelcontextprotocol/server-puppeteer 15 | 16 | - Alternatively, you can run it without installing: 17 | npx -y @modelcontextprotocol/server-puppeteer 18 | """ 19 | import asyncio 20 | import sys 21 | import os 22 | import shutil 23 | from pathlib import Path 24 | 25 | # Add the project root to the Python path 26 | project_root = Path(__file__).parent.parent.parent.parent 27 | sys.path.append(str(project_root)) 28 | 29 | from src.tools.registry import SearchToolRegistry 30 | from src.tools.executor import ToolExecutor 31 | from src.tools.mcp_tools import MCPToolManager 32 | 33 | async def main(): 34 | # Create a tool registry 35 | registry = SearchToolRegistry() 36 | executor = ToolExecutor(registry) 37 | 38 | # Create an MCP tool manager 39 | mcp_manager = MCPToolManager(registry) 40 | 41 | try: 42 | # Start the Puppeteer MCP server as a subprocess 43 | print("Registering Puppeteer MCP server via stdio...") 44 | tools = await mcp_manager.register_stdio_server( 45 | name="puppeteer", 46 | command="npx", 47 | args=["-y", "@modelcontextprotocol/server-puppeteer"] 48 | ) 49 | 50 | print(f"Registered {len(tools)} tools from Puppeteer MCP server:") 51 | for tool in tools: 52 | print(f" - {tool.name}: {tool.description}") 53 | 54 | # Example: Execute some web browsing commands 55 | 56 | # First, navigate to a website 57 | print("\nNavigating to a website...") 58 | result = await executor.execute_tool( 59 | "mcp.puppeteer.puppeteer_navigate", 60 | { 61 | "url": "https://example.com", 62 | "launchOptions": { 63 | "headless": "new", 64 | "args": ["--no-sandbox"], 65 | }, 66 | "allowDangerous": True 67 | }, 68 | config={} 69 | ) 70 | print(f"Navigation result: {result}") 71 | 72 | # Take a screenshot 73 | print("\nTaking screenshot...") 74 | screenshots_dir = os.path.join(project_root, "screenshots") 75 | os.makedirs(screenshots_dir, exist_ok=True) 76 | 77 | try: 78 | result = await executor.execute_tool( 79 | "mcp.puppeteer.puppeteer_screenshot", 80 | { 81 | "name": "example_screenshot", 82 | "width": 1024, 83 | "height": 768 84 | }, 85 | config={} 86 | ) 87 | print(f"Screenshot command executed. 
Result: {result}") 88 | 89 | # Process the result which should contain base64 image data 90 | if isinstance(result, dict) and 'content' in result: 91 | for content in result['content']: 92 | if content.get('type') == 'image' and content.get('data'): 93 | # Save the base64 data 94 | screenshot_path = os.path.join(screenshots_dir, "example_screenshot.png") 95 | import base64 96 | img_data = base64.b64decode(content['data']) 97 | with open(screenshot_path, 'wb') as f: 98 | f.write(img_data) 99 | print(f"Screenshot saved to: {screenshot_path}") 100 | break 101 | except Exception as e: 102 | print(f"Error taking screenshot: {e}") 103 | import traceback 104 | traceback.print_exc() 105 | 106 | # Get basic page information 107 | print("\nGetting page information...") 108 | try: 109 | result = await executor.execute_tool( 110 | "mcp.puppeteer.puppeteer_evaluate", 111 | { 112 | "pageFunction": "() => document.title" 113 | }, 114 | config={} 115 | ) 116 | print(f"Page title: {result}") 117 | 118 | # Get more page information with a separate call 119 | result = await executor.execute_tool( 120 | "mcp.puppeteer.puppeteer_evaluate", 121 | { 122 | "pageFunction": "() => ({ url: window.location.href, h1: document.querySelector('h1')?.textContent })" 123 | }, 124 | config={} 125 | ) 126 | print(f"Additional page info: {result}") 127 | except Exception as e: 128 | print(f"Error getting page information: {e}") 129 | import traceback 130 | traceback.print_exc() 131 | 132 | finally: 133 | # Close all MCP connections 134 | print("\nClosing MCP connections...") 135 | await mcp_manager.close_all() 136 | print("Puppeteer MCP Server closed") 137 | 138 | if __name__ == "__main__": 139 | asyncio.run(main()) -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Salesforce Open Source Community Code of Conduct 2 | 3 | ## About the Code of Conduct 4 | 5 | Equality is a core value at Salesforce. We believe a diverse and inclusive 6 | community fosters innovation and creativity, and are committed to building a 7 | culture where everyone feels included. 8 | 9 | Salesforce open-source projects are committed to providing a friendly, safe, and 10 | welcoming environment for all, regardless of gender identity and expression, 11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality, 12 | race, age, religion, level of experience, education, socioeconomic status, or 13 | other similar personal characteristics. 14 | 15 | The goal of this code of conduct is to specify a baseline standard of behavior so 16 | that people with different social values and communication styles can work 17 | together effectively, productively, and respectfully in our open source community. 18 | It also establishes a mechanism for reporting issues and resolving conflicts. 19 | 20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior 21 | in a Salesforce open-source project may be reported by contacting the Salesforce 22 | Open Source Conduct Committee at ossconduct@salesforce.com. 
23 | 24 | ## Our Pledge 25 | 26 | In the interest of fostering an open and welcoming environment, we as 27 | contributors and maintainers pledge to making participation in our project and 28 | our community a harassment-free experience for everyone, regardless of gender 29 | identity and expression, sexual orientation, disability, physical appearance, 30 | body size, ethnicity, nationality, race, age, religion, level of experience, education, 31 | socioeconomic status, or other similar personal characteristics. 32 | 33 | ## Our Standards 34 | 35 | Examples of behavior that contributes to creating a positive environment 36 | include: 37 | 38 | * Using welcoming and inclusive language 39 | * Being respectful of differing viewpoints and experiences 40 | * Gracefully accepting constructive criticism 41 | * Focusing on what is best for the community 42 | * Showing empathy toward other community members 43 | 44 | Examples of unacceptable behavior by participants include: 45 | 46 | * The use of sexualized language or imagery and unwelcome sexual attention or 47 | advances 48 | * Personal attacks, insulting/derogatory comments, or trolling 49 | * Public or private harassment 50 | * Publishing, or threatening to publish, others' private information—such as 51 | a physical or electronic address—without explicit permission 52 | * Other conduct which could reasonably be considered inappropriate in a 53 | professional setting 54 | * Advocating for or encouraging any of the above behaviors 55 | 56 | ## Our Responsibilities 57 | 58 | Project maintainers are responsible for clarifying the standards of acceptable 59 | behavior and are expected to take appropriate and fair corrective action in 60 | response to any instances of unacceptable behavior. 61 | 62 | Project maintainers have the right and responsibility to remove, edit, or 63 | reject comments, commits, code, wiki edits, issues, and other contributions 64 | that are not aligned with this Code of Conduct, or to ban temporarily or 65 | permanently any contributor for other behaviors that they deem inappropriate, 66 | threatening, offensive, or harmful. 67 | 68 | ## Scope 69 | 70 | This Code of Conduct applies both within project spaces and in public spaces 71 | when an individual is representing the project or its community. Examples of 72 | representing a project or community include using an official project email 73 | address, posting via an official social media account, or acting as an appointed 74 | representative at an online or offline event. Representation of a project may be 75 | further defined and clarified by project maintainers. 76 | 77 | ## Enforcement 78 | 79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 80 | reported by contacting the Salesforce Open Source Conduct Committee 81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated 82 | and will result in a response that is deemed necessary and appropriate to the 83 | circumstances. The committee is obligated to maintain confidentiality with 84 | regard to the reporter of an incident. Further details of specific enforcement 85 | policies may be posted separately. 86 | 87 | Project maintainers who do not follow or enforce the Code of Conduct in good 88 | faith may face temporary or permanent repercussions as determined by other 89 | members of the project's leadership and the Salesforce Open Source Conduct 90 | Committee. 
91 | 92 | ## Attribution 93 | 94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home], 95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. 96 | It includes adaptions and additions from [Go Community Code of Conduct][golang-coc], 97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc]. 98 | 99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us]. 100 | 101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/) 102 | [golang-coc]: https://golang.org/conduct 103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md 104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/ 105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/ -------------------------------------------------------------------------------- /src/tools/examples/simple_math_client.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simplified client to communicate with our custom math server. 3 | 4 | This example shows how to: 5 | 1. Start the Math server as a subprocess 6 | 2. Communicate with it directly 7 | 3. Execute math operations 8 | 9 | Prerequisites: 10 | - The math_server.py file should be in the src/tools directory 11 | """ 12 | import asyncio 13 | import json 14 | import subprocess 15 | import sys 16 | from pathlib import Path 17 | from typing import Dict, Any, List, Optional, AsyncGenerator 18 | 19 | # Add the project root to the Python path 20 | project_root = Path(__file__).parent.parent.parent.parent 21 | sys.path.append(str(project_root)) 22 | 23 | class SimpleMathClient: 24 | """A simple client for the math server.""" 25 | 26 | def __init__(self, server_path: str): 27 | """Initialize the client. 28 | 29 | Args: 30 | server_path: Path to the math server script 31 | """ 32 | self.server_path = server_path 33 | self.process = None 34 | 35 | async def start(self): 36 | """Start the math server process.""" 37 | # Start the server process 38 | self.process = await asyncio.create_subprocess_exec( 39 | sys.executable, 40 | self.server_path, 41 | stdin=subprocess.PIPE, 42 | stdout=subprocess.PIPE, 43 | stderr=subprocess.PIPE 44 | ) 45 | 46 | # Initialize the server 47 | response = await self._send_message({ 48 | "type": "request", 49 | "method": "initialize" 50 | }) 51 | 52 | print(f"Server initialized: {response}") 53 | 54 | # Get the list of available tools 55 | response = await self._send_message({ 56 | "type": "request", 57 | "method": "getTools" 58 | }) 59 | 60 | self.tools = response.get("tools", []) 61 | print(f"Available tools: {len(self.tools)}") 62 | for tool in self.tools: 63 | params = [f"{name}: {param['type']}" for name, param in tool.get('parameters', {}).items()] 64 | print(f" - {tool['name']}: {tool.get('description', '')}") 65 | print(f" Parameters: {', '.join(params)}") 66 | 67 | return self.tools 68 | 69 | async def execute_tool(self, tool_name: str, parameters: Dict[str, Any]) -> Any: 70 | """Execute a tool. 
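        Sends an ``executeTool`` request to the server; raises ``ValueError``
        if the response contains an error.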
71 | 72 | Args: 73 | tool_name: Name of the tool to execute 74 | parameters: Parameters to pass to the tool 75 | 76 | Returns: 77 | The result of the operation 78 | """ 79 | response = await self._send_message({ 80 | "type": "request", 81 | "method": "executeTool", 82 | "tool": tool_name, 83 | "parameters": parameters 84 | }) 85 | 86 | if "error" in response: 87 | raise ValueError(response["error"]) 88 | 89 | return response.get("result") 90 | 91 | async def _send_message(self, message: Dict[str, Any]) -> Dict[str, Any]: 92 | """Send a message to the server and wait for a response. 93 | 94 | Args: 95 | message: The message to send 96 | 97 | Returns: 98 | The response from the server 99 | """ 100 | if not self.process: 101 | raise ValueError("Server not started") 102 | 103 | # Send the message 104 | self.process.stdin.write(json.dumps(message).encode() + b"\n") 105 | await self.process.stdin.drain() 106 | 107 | # Read the response 108 | response_line = await self.process.stdout.readline() 109 | response = json.loads(response_line.decode()) 110 | 111 | return response 112 | 113 | async def close(self): 114 | """Close the connection to the server.""" 115 | if self.process: 116 | # Close pipes 117 | if self.process.stdin: 118 | self.process.stdin.close() 119 | await self.process.stdin.wait_closed() 120 | 121 | # Terminate the process 122 | self.process.terminate() 123 | await self.process.wait() 124 | self.process = None 125 | 126 | async def main(): 127 | # Get the path to the math server 128 | math_server_path = Path(__file__).parent.parent / "math_server.py" 129 | 130 | if not math_server_path.exists(): 131 | print(f"Error: Math server not found at {math_server_path}") 132 | return 133 | 134 | # Create the client 135 | client = SimpleMathClient(str(math_server_path)) 136 | 137 | try: 138 | # Start the server 139 | await client.start() 140 | 141 | # Execute some math operations 142 | operations = [ 143 | ("add", {"a": 10, "b": 5}), 144 | ("subtract", {"a": 10, "b": 5}), 145 | ("multiply", {"a": 10, "b": 5}), 146 | ("divide", {"a": 10, "b": 5}) 147 | ] 148 | 149 | for op_name, params in operations: 150 | print(f"\nExecuting {op_name} with parameters: {params}") 151 | result = await client.execute_tool(op_name, params) 152 | print(f"Result: {result}") 153 | 154 | finally: 155 | # Close the connection 156 | await client.close() 157 | 158 | if __name__ == "__main__": 159 | asyncio.run(main()) -------------------------------------------------------------------------------- /src/tools/examples/math_custom_adapter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example demonstrating how to integrate a custom MCP client with our tool registry. 3 | 4 | This script shows how to: 5 | 1. Create a custom MCP client for our math server 6 | 2. Adapt it to work with our tool registry 7 | 3. 
Execute math operations using our tool executor 8 | 9 | Prerequisites: 10 | - The math_server.py file should be in the src/tools directory 11 | """ 12 | import asyncio 13 | import sys 14 | from pathlib import Path 15 | from typing import Dict, Any, List 16 | 17 | # Add the project root to the Python path 18 | project_root = Path(__file__).parent.parent.parent.parent 19 | sys.path.append(str(project_root)) 20 | 21 | from src.tools.registry import SearchToolRegistry 22 | from src.tools.executor import ToolExecutor 23 | from src.tools.tool_schema import Tool, ToolParameter, ToolParameterType 24 | from src.tools.examples.simple_math_client import SimpleMathClient 25 | 26 | class MathServerAdapter: 27 | """Adapter for the math server that works with our tool registry.""" 28 | 29 | def __init__(self, server_path: str): 30 | """Initialize the adapter. 31 | 32 | Args: 33 | server_path: Path to the math server script 34 | """ 35 | self.server_path = server_path 36 | self.client = None 37 | self.tools = [] 38 | 39 | async def start(self) -> List[Tool]: 40 | """Start the math server and register tools with the registry. 41 | 42 | Returns: 43 | List of tools that were registered 44 | """ 45 | # Create and start the client 46 | self.client = SimpleMathClient(self.server_path) 47 | tool_definitions = await self.client.start() 48 | 49 | # Convert MCP tools to our Tool format 50 | self.tools = [] 51 | for tool_def in tool_definitions: 52 | # Extract tool information 53 | name = tool_def.get("name") 54 | description = tool_def.get("description", "") 55 | 56 | # Extract parameter information 57 | parameters = [] 58 | for param_name, param_info in tool_def.get("parameters", {}).items(): 59 | # Determine parameter type 60 | param_type = ToolParameterType.STRING 61 | if param_info.get("type") == "number": 62 | param_type = ToolParameterType.NUMBER 63 | elif param_info.get("type") == "boolean": 64 | param_type = ToolParameterType.BOOLEAN 65 | 66 | # Create parameter object 67 | parameters.append(ToolParameter( 68 | name=param_name, 69 | type=param_type, 70 | required=True, 71 | description=param_info.get("description", "") 72 | )) 73 | 74 | # Create a tool wrapper function that captures the current name 75 | tool_name = name # Create a variable to capture in the closure 76 | 77 | # Define a specific async function for each tool to ensure closure captures correctly 78 | async def make_tool_function(tool_name): 79 | async def execute_wrapper(**kwargs): 80 | return await self.client.execute_tool(tool_name, kwargs) 81 | return execute_wrapper 82 | 83 | # Create our Tool object with the dynamic function 84 | tool = Tool( 85 | name=f"math.{name}", 86 | description=description, 87 | parameters=parameters, 88 | function=await make_tool_function(name) 89 | ) 90 | 91 | self.tools.append(tool) 92 | 93 | return self.tools 94 | 95 | async def close(self): 96 | """Close the connection to the server.""" 97 | if self.client: 98 | await self.client.close() 99 | self.client = None 100 | 101 | async def main(): 102 | # Create a tool registry 103 | registry = SearchToolRegistry() 104 | executor = ToolExecutor(registry) 105 | 106 | # Get the path to the math server 107 | math_server_path = Path(__file__).parent.parent / "math_server.py" 108 | 109 | if not math_server_path.exists(): 110 | print(f"Error: Math server not found at {math_server_path}") 111 | return 112 | 113 | # Create the adapter 114 | adapter = MathServerAdapter(str(math_server_path)) 115 | 116 | try: 117 | # Start the adapter and register tools 118 | tools = await 
adapter.start() 119 | 120 | # Register tools with the registry 121 | for tool in tools: 122 | registry.register_tool(tool) 123 | 124 | print(f"Registered {len(tools)} tools with the registry:") 125 | for tool in tools: 126 | print(f" - {tool.name}: {tool.description}") 127 | params = [f"{p.name}: {p.type}" for p in tool.parameters] 128 | print(f" Parameters: {', '.join(params)}") 129 | 130 | # Execute some math operations 131 | operations = [ 132 | ("math.add", {"a": 10, "b": 5}), 133 | ("math.subtract", {"a": 10, "b": 5}), 134 | ("math.multiply", {"a": 10, "b": 5}), 135 | ("math.divide", {"a": 10, "b": 5}) 136 | ] 137 | 138 | for op_name, params in operations: 139 | print(f"\nExecuting {op_name} with parameters: {params}") 140 | result = await executor.execute_tool(op_name, params) 141 | print(f"Result: {result}") 142 | 143 | finally: 144 | # Close the adapter 145 | await adapter.close() 146 | 147 | if __name__ == "__main__": 148 | asyncio.run(main()) -------------------------------------------------------------------------------- /src/tools/mcp_README.md: -------------------------------------------------------------------------------- 1 | # MCP (Model Context Protocol) Integration 2 | 3 | This module provides integration with the Model Context Protocol (MCP), allowing you to connect to MCP servers and use their tools within our tool registry system. 4 | 5 | ## What is MCP? 6 | 7 | The Model Context Protocol (MCP) is an open protocol that standardizes how applications provide context to LLMs. Think of MCP like a USB-C port for AI applications - it provides a standardized way to connect AI models to different data sources and tools. 8 | 9 | MCP solves the "N×M problem" where N represents different LLMs and M represents various tools and data sources. Without standardization, each combination requires custom integration work. MCP standardizes the interface, allowing any MCP-compatible tool to work with any MCP-compatible AI model or application. 
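To make this concrete, the bundled `math_server.py` example in this directory speaks a simplified JSON-lines dialect of the same idea (`initialize`, `getTools`, `executeTool`). It is a stand-in for the official MCP wire format, not the real protocol, but a minimal sketch that drives it by hand over stdio shows how small the standardized surface is:

```python
import json
import subprocess
import sys

# Launch the example server; it reads one JSON message per line on stdin
# and writes one JSON response per line on stdout.
proc = subprocess.Popen(
    [sys.executable, "math_server.py"],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    text=True,
)

def call(message):
    """Send one JSON message and read back one JSON response."""
    proc.stdin.write(json.dumps(message) + "\n")
    proc.stdin.flush()
    return json.loads(proc.stdout.readline())

print(call({"type": "request", "method": "initialize"}))
print(call({"type": "request", "method": "getTools"}))
print(call({
    "type": "request",
    "method": "executeTool",
    "tool": "add",
    "parameters": {"a": 5, "b": 3},
}))  # -> {"type": "response", "result": 8.0}

proc.terminate()
```

The manager and adapter examples below wrap exactly this request/response cycle behind the registry's `Tool` interface, so callers never touch the transport directly.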
10 | 11 | ## Features 12 | 13 | - Connect to MCP servers using either HTTP or stdio transport 14 | - Load tools from MCP servers and convert them to our Tool format 15 | - Register MCP tools with our tool registry 16 | - Manage multiple MCP connections 17 | 18 | ## Installation 19 | 20 | To use this module, you need to install the following packages: 21 | 22 | ```bash 23 | pip install mcp 24 | pip install langchain-mcp-adapters 25 | ``` 26 | 27 | ## Usage 28 | 29 | ### Connecting to an MCP Server via HTTP 30 | 31 | ```python 32 | import asyncio 33 | from src.tools.registry import ToolRegistry 34 | from src.tools.mcp_tools import MCPToolManager 35 | 36 | async def main(): 37 | # Create a tool registry 38 | registry = ToolRegistry() 39 | 40 | # Create an MCP tool manager 41 | mcp_manager = MCPToolManager(registry) 42 | 43 | # Connect to an MCP server over HTTP 44 | tools = await mcp_manager.register_http_server( 45 | name="my_server", 46 | base_url="http://localhost:3000" 47 | ) 48 | 49 | print(f"Registered {len(tools)} tools from MCP server") 50 | 51 | # Don't forget to close connections when done 52 | await mcp_manager.close_all() 53 | 54 | if __name__ == "__main__": 55 | asyncio.run(main()) 56 | ``` 57 | 58 | ### Connecting to an MCP Server via Stdio 59 | 60 | ```python 61 | import asyncio 62 | from src.tools.registry import ToolRegistry 63 | from src.tools.mcp_tools import MCPToolManager 64 | 65 | async def main(): 66 | # Create a tool registry 67 | registry = ToolRegistry() 68 | 69 | # Create an MCP tool manager 70 | mcp_manager = MCPToolManager(registry) 71 | 72 | # Start an MCP server as a subprocess and connect to it 73 | tools = await mcp_manager.register_stdio_server( 74 | name="math_tools", 75 | command="python", 76 | args=["math_server.py"] 77 | ) 78 | 79 | print(f"Registered {len(tools)} tools from MCP server") 80 | 81 | # Don't forget to close connections when done 82 | await mcp_manager.close_all() 83 | 84 | if __name__ == "__main__": 85 | asyncio.run(main()) 86 | ``` 87 | 88 | ### Using MCP Tools with the Tool Executor 89 | 90 | ```python 91 | import asyncio 92 | from src.tools.registry import ToolRegistry 93 | from src.tools.executor import ToolExecutor 94 | from src.tools.mcp_tools import MCPToolManager 95 | 96 | async def main(): 97 | # Create a tool registry and executor 98 | registry = ToolRegistry() 99 | executor = ToolExecutor(registry) 100 | 101 | # Create an MCP tool manager 102 | mcp_manager = MCPToolManager(registry) 103 | 104 | # Connect to a math MCP server 105 | await mcp_manager.register_stdio_server( 106 | name="math", 107 | command="python", 108 | args=["math_server.py"] 109 | ) 110 | 111 | # Execute a tool 112 | result = await executor.execute_tool( 113 | "mcp.math.add", 114 | {"a": 5, "b": 3} 115 | ) 116 | 117 | print(f"Result: {result}") # Output: Result: 8 118 | 119 | # Close all MCP connections 120 | await mcp_manager.close_all() 121 | 122 | if __name__ == "__main__": 123 | asyncio.run(main()) 124 | ``` 125 | 126 | ## Available MCP Servers 127 | 128 | There are several MCP servers available that you can use with this integration: 129 | 130 | 1. **Puppeteer MCP Server**: Provides web browsing capabilities 131 | - GitHub: https://github.com/modelcontextprotocol/servers/tree/main/src/puppeteer 132 | - Installation: `npm install -g @modelcontextprotocol/server-puppeteer` 133 | - Usage: `npx -y @modelcontextprotocol/server-puppeteer` 134 | 135 | 2. 
**Filesystem MCP Server**: File system operations 136 | - GitHub: https://github.com/modelcontextprotocol/servers/tree/main/src/filesystem 137 | - Installation: `npm install -g @modelcontextprotocol/server-filesystem` 138 | - Usage: `npx -y @modelcontextprotocol/server-filesystem` 139 | 140 | 3. **Fetch MCP Server**: HTTP requests and responses 141 | - GitHub: https://github.com/modelcontextprotocol/servers/tree/main/src/fetch 142 | - Installation: `npm install -g @modelcontextprotocol/server-fetch` 143 | - Usage: `npx -y @modelcontextprotocol/server-fetch` 144 | 145 | 4. **Math MCP Server**: Simple example providing math operations 146 | - See the `math_server.py` example in this repository 147 | 148 | 5. **Other MCP Servers**: More servers are available in the [MCP servers repository](https://github.com/modelcontextprotocol/servers/tree/main/src) 149 | 150 | ## Creating Your Own MCP Server 151 | 152 | You can create your own MCP server using the `fastmcp` library. Here's a simple example: 153 | 154 | ```python 155 | from mcp.server.fastmcp import FastMCP 156 | 157 | mcp = FastMCP("MyTools") 158 | 159 | @mcp.tool() 160 | def greet(name: str) -> str: 161 | """Greet a person by name.""" 162 | return f"Hello, {name}!" 163 | 164 | if __name__ == "__main__": 165 | mcp.run(transport="stdio") # or "http" 166 | ``` 167 | 168 | For more information, see the [Model Context Protocol documentation](https://modelcontextprotocol.io/). -------------------------------------------------------------------------------- /ai-research-assistant/src/components/Navbar.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | 3 | const Navbar = ({ 4 | isResearching, 5 | onShowProgress, 6 | onShowReport, 7 | hasReport, 8 | isProgressOpen, 9 | isReportOpen 10 | }) => { 11 | return ( 12 | 84 | ); 85 | }; 86 | 87 | export default Navbar; 88 | -------------------------------------------------------------------------------- /models/file_analysis.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import Optional, Dict, Any, List 3 | from datetime import datetime 4 | from enum import Enum 5 | 6 | class FileStatus(str, Enum): 7 | UPLOADED = "uploaded" 8 | PROCESSING = "processing" 9 | COMPLETED = "completed" 10 | FAILED = "failed" 11 | 12 | class AnalysisType(str, Enum): 13 | QUICK = "quick" 14 | COMPREHENSIVE = "comprehensive" 15 | CUSTOM = "custom" 16 | 17 | class FileUploadResponse(BaseModel): 18 | file_id: str = Field(..., description="Unique identifier for the uploaded file") 19 | filename: str = Field(..., description="Sanitized filename") 20 | original_name: str = Field(..., description="Original filename as uploaded") 21 | file_type: str = Field(..., description="File extension/type") 22 | file_size: int = Field(..., description="File size in bytes") 23 | status: FileStatus = Field(..., description="Current processing status") 24 | upload_timestamp: datetime = Field(..., description="When the file was uploaded") 25 | analysis_eta: Optional[str] = Field(None, description="Estimated time for analysis completion") 26 | 27 | class FileAnalysisRequest(BaseModel): 28 | analysis_type: AnalysisType = Field(default=AnalysisType.COMPREHENSIVE, description="Type of analysis to perform") 29 | custom_prompt: Optional[str] = Field(None, description="Custom prompt for analysis if using custom type") 30 | 31 | class FileAnalysisResponse(BaseModel): 32 | file_id: str = Field(..., 
description="Unique identifier for the analyzed file") 33 | status: FileStatus = Field(..., description="Analysis status") 34 | content_description: str = Field(..., description="Detailed description of the file content") 35 | analysis_timestamp: datetime = Field(..., description="When the analysis was completed") 36 | metadata: Dict[str, Any] = Field(default_factory=dict, description="File-specific metadata and insights") 37 | processing_time: float = Field(..., description="Time taken for analysis in seconds") 38 | error_message: Optional[str] = Field(None, description="Error message if analysis failed") 39 | 40 | class FileMetadata(BaseModel): 41 | file_id: str 42 | filename: str 43 | original_name: str 44 | file_type: str 45 | file_size: int 46 | upload_timestamp: datetime 47 | file_path: str 48 | status: FileStatus 49 | 50 | class ContentInsights(BaseModel): 51 | """Structured insights extracted from file content""" 52 | key_topics: List[str] = Field(default_factory=list, description="Main topics identified in the content") 53 | entities: List[str] = Field(default_factory=list, description="Named entities found in the content") 54 | summary: str = Field("", description="Brief summary of the content") 55 | language: Optional[str] = Field(None, description="Detected language of the content") 56 | sentiment: Optional[str] = Field(None, description="Overall sentiment if applicable") 57 | confidence_score: float = Field(0.0, description="Confidence score for the analysis (0-1)") 58 | 59 | class DocumentMetadata(BaseModel): 60 | """Metadata specific to document files""" 61 | page_count: Optional[int] = None 62 | word_count: Optional[int] = None 63 | character_count: Optional[int] = None 64 | author: Optional[str] = None 65 | creation_date: Optional[datetime] = None 66 | modification_date: Optional[datetime] = None 67 | 68 | class ImageMetadata(BaseModel): 69 | """Metadata specific to image files""" 70 | width: Optional[int] = None 71 | height: Optional[int] = None 72 | format: Optional[str] = None 73 | has_text: bool = False 74 | detected_objects: List[str] = Field(default_factory=list) 75 | color_palette: List[str] = Field(default_factory=list) 76 | 77 | class DataFileMetadata(BaseModel): 78 | """Metadata specific to structured data files (CSV, Excel, etc.)""" 79 | row_count: Optional[int] = None 80 | column_count: Optional[int] = None 81 | columns: List[str] = Field(default_factory=list) 82 | data_types: Dict[str, str] = Field(default_factory=dict) 83 | missing_values: Dict[str, int] = Field(default_factory=dict) 84 | sample_data: Optional[Dict[str, Any]] = None 85 | 86 | class AudioVideoMetadata(BaseModel): 87 | """Metadata specific to audio and video files""" 88 | duration: Optional[float] = None # in seconds 89 | format: Optional[str] = None 90 | bitrate: Optional[int] = None 91 | sample_rate: Optional[int] = None 92 | channels: Optional[int] = None 93 | has_speech: bool = False 94 | transcript_available: bool = False 95 | 96 | class FileAnalysisStatus(BaseModel): 97 | """Detailed status information for file analysis""" 98 | file_id: str 99 | status: FileStatus 100 | progress_percentage: float = Field(0.0, description="Analysis progress (0-100)") 101 | current_stage: str = Field("", description="Current processing stage") 102 | started_at: Optional[datetime] = None 103 | estimated_completion: Optional[datetime] = None 104 | error_details: Optional[str] = None 105 | 106 | class BatchAnalysisRequest(BaseModel): 107 | """Request model for batch file analysis""" 108 | file_ids: List[str] = 
Field(..., description="List of file IDs to analyze") 109 | analysis_type: AnalysisType = Field(default=AnalysisType.COMPREHENSIVE) 110 | custom_prompt: Optional[str] = None 111 | parallel_processing: bool = Field(True, description="Whether to process files in parallel") 112 | 113 | class BatchAnalysisResponse(BaseModel): 114 | """Response model for batch file analysis""" 115 | batch_id: str = Field(..., description="Unique identifier for the batch operation") 116 | total_files: int = Field(..., description="Total number of files in the batch") 117 | status: str = Field(..., description="Overall batch status") 118 | individual_results: List[FileAnalysisResponse] = Field(default_factory=list) 119 | batch_started_at: datetime 120 | estimated_completion: Optional[datetime] = None -------------------------------------------------------------------------------- /src/tools/math_server.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple MCP server that provides math tools. 3 | """ 4 | import sys 5 | import json 6 | from typing import Dict, Any, List, Optional, Union 7 | 8 | # Simple MCP server that doesn't rely on external libraries 9 | class SimpleMathMCP: 10 | """A simple MCP server that provides math tools.""" 11 | 12 | def __init__(self): 13 | """Initialize the MCP server.""" 14 | self.tools = { 15 | "add": { 16 | "name": "add", 17 | "description": "Add two numbers.", 18 | "parameters": { 19 | "a": {"type": "number", "description": "First number"}, 20 | "b": {"type": "number", "description": "Second number"} 21 | } 22 | }, 23 | "subtract": { 24 | "name": "subtract", 25 | "description": "Subtract b from a.", 26 | "parameters": { 27 | "a": {"type": "number", "description": "First number"}, 28 | "b": {"type": "number", "description": "Number to subtract"} 29 | } 30 | }, 31 | "multiply": { 32 | "name": "multiply", 33 | "description": "Multiply two numbers.", 34 | "parameters": { 35 | "a": {"type": "number", "description": "First number"}, 36 | "b": {"type": "number", "description": "Second number"} 37 | } 38 | }, 39 | "divide": { 40 | "name": "divide", 41 | "description": "Divide a by b. Returns an error if b is 0.", 42 | "parameters": { 43 | "a": {"type": "number", "description": "Dividend"}, 44 | "b": {"type": "number", "description": "Divisor"} 45 | } 46 | } 47 | } 48 | 49 | def handle_message(self, message: Dict[str, Any]) -> Dict[str, Any]: 50 | """Handle an MCP message. 
51 | 52 | Args: 53 | message: The MCP message 54 | 55 | Returns: 56 | The response message 57 | """ 58 | if message.get("type") == "request": 59 | if message.get("method") == "initialize": 60 | return self._handle_initialize() 61 | elif message.get("method") == "getTools": 62 | return self._handle_get_tools() 63 | elif message.get("method") == "executeTool": 64 | return self._handle_execute_tool(message) 65 | else: 66 | return {"type": "error", "message": f"Unknown method: {message.get('method')}"} 67 | else: 68 | return {"type": "error", "message": f"Unknown message type: {message.get('type')}"} 69 | 70 | def _handle_initialize(self) -> Dict[str, Any]: 71 | """Handle an initialize request.""" 72 | return { 73 | "type": "response", 74 | "serverInfo": { 75 | "name": "MathTools", 76 | "version": "1.0.0" 77 | } 78 | } 79 | 80 | def _handle_get_tools(self) -> Dict[str, Any]: 81 | """Handle a getTools request.""" 82 | return { 83 | "type": "response", 84 | "tools": list(self.tools.values()) 85 | } 86 | 87 | def _handle_execute_tool(self, message: Dict[str, Any]) -> Dict[str, Any]: 88 | """Handle an executeTool request. 89 | 90 | Args: 91 | message: The MCP message 92 | 93 | Returns: 94 | The response message 95 | """ 96 | tool_name = message.get("tool") 97 | params = message.get("parameters", {}) 98 | 99 | if tool_name not in self.tools: 100 | return {"type": "error", "message": f"Unknown tool: {tool_name}"} 101 | 102 | try: 103 | # Convert parameters to numbers 104 | a = float(params.get("a", 0)) 105 | b = float(params.get("b", 0)) 106 | 107 | # Execute the appropriate operation 108 | if tool_name == "add": 109 | result = a + b 110 | elif tool_name == "subtract": 111 | result = a - b 112 | elif tool_name == "multiply": 113 | result = a * b 114 | elif tool_name == "divide": 115 | if b == 0: 116 | return {"type": "error", "message": "Cannot divide by zero"} 117 | result = a / b 118 | else: 119 | return {"type": "error", "message": f"Unknown operation: {tool_name}"} 120 | 121 | # Return the result 122 | return { 123 | "type": "response", 124 | "result": result 125 | } 126 | 127 | except Exception as e: 128 | return {"type": "error", "message": str(e)} 129 | 130 | def main(): 131 | """Main entry point.""" 132 | server = SimpleMathMCP() 133 | 134 | # Read messages from stdin and write responses to stdout 135 | while True: 136 | try: 137 | line = input() 138 | if not line: 139 | continue 140 | 141 | # Parse the message 142 | message = json.loads(line) 143 | 144 | # Handle the message 145 | response = server.handle_message(message) 146 | 147 | # Write the response 148 | print(json.dumps(response)) 149 | sys.stdout.flush() 150 | 151 | except EOFError: 152 | # End of input 153 | break 154 | except Exception as e: 155 | # Write an error response 156 | error_response = {"type": "error", "message": str(e)} 157 | print(json.dumps(error_response)) 158 | sys.stdout.flush() 159 | 160 | if __name__ == "__main__": 161 | main() 162 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | import logging 4 | from dotenv import load_dotenv 5 | from fastapi import FastAPI, Depends, HTTPException 6 | from fastapi.middleware.cors import CORSMiddleware 7 | from fastapi.staticfiles import StaticFiles 8 | from fastapi.responses import FileResponse, JSONResponse 9 | from typing import Optional 10 | 11 | # Application version 12 | VERSION = "v0.6.5" 13 | 14 | # Configure 
logging - write to both console and file 15 | logging.basicConfig( 16 | level=logging.INFO, 17 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 18 | handlers=[ 19 | logging.FileHandler("backend_logs.txt"), # Write to file 20 | logging.StreamHandler(), # Also show in console 21 | ], 22 | ) 23 | logger = logging.getLogger(__name__) 24 | 25 | # Create FastAPI app 26 | app = FastAPI( 27 | title="Deep Research API", 28 | description="An API for performing deep research on topics using LLMs and web search", 29 | version=VERSION, 30 | ) 31 | 32 | # Load environment variables from .env and Secrets 33 | load_dotenv(override=True) 34 | 35 | # Set default values for critical configuration 36 | llm_provider = os.environ.get("LLM_PROVIDER", "openai") 37 | llm_model = os.environ.get("LLM_MODEL", "o3-mini") 38 | max_loops = os.environ.get("MAX_WEB_RESEARCH_LOOPS", "20") 39 | 40 | # Debug information 41 | # print("\n=== Environment Variables Debug ===") 42 | # print(f"REPL_ENVIRONMENT: {os.environ.get('REPL_ENVIRONMENT')}") 43 | # 44 | # print(f"All environment variables: {dict(os.environ)}") 45 | 46 | llm_provider = os.environ.get("LLM_PROVIDER") 47 | llm_model = os.environ.get("LLM_MODEL") 48 | max_loops = os.environ.get("MAX_WEB_RESEARCH_LOOPS") 49 | 50 | print("\n=== LLM Configuration ===") 51 | print(f"LLM_PROVIDER: {llm_provider}") 52 | print(f"LLM_MODEL: {llm_model}") 53 | print(f"MAX_WEB_RESEARCH_LOOPS: {max_loops}") 54 | print("==============================\n") 55 | 56 | 57 | # Add error handler 58 | @app.exception_handler(Exception) 59 | async def general_exception_handler(request, exc): 60 | logger.error(f"Unhandled error: {str(exc)}", exc_info=True) 61 | return JSONResponse( 62 | status_code=500, content={"detail": "Internal server error", "error": str(exc)} 63 | ) 64 | 65 | 66 | # Try to import our routers 67 | try: 68 | from routers.research import router as research_router 69 | from routers.file_analysis import router as file_analysis_router 70 | from routers.database import router as database_router 71 | 72 | logger.info("Successfully imported routers") 73 | except ImportError as e: 74 | logger.error(f"Error importing routers: {e}") 75 | raise 76 | 77 | # Configure CORS and cache control 78 | from fastapi.middleware.trustedhost import TrustedHostMiddleware 79 | from fastapi.responses import Response 80 | from starlette.middleware.base import BaseHTTPMiddleware 81 | 82 | 83 | class CacheControlMiddleware(BaseHTTPMiddleware): 84 | async def dispatch(self, request, call_next): 85 | response = await call_next(request) 86 | response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate" 87 | response.headers["Pragma"] = "no-cache" 88 | response.headers["Expires"] = "0" 89 | return response 90 | 91 | 92 | app.add_middleware(CacheControlMiddleware) 93 | app.add_middleware( 94 | CORSMiddleware, 95 | allow_origins=["*"], 96 | allow_credentials=True, 97 | allow_methods=["*"], 98 | allow_headers=["*"], 99 | ) 100 | 101 | # Include routers 102 | app.include_router(research_router) 103 | app.include_router(file_analysis_router) 104 | app.include_router(database_router, prefix="/api/database") 105 | 106 | 107 | # Include simple steering router 108 | try: 109 | from routers.simple_steering_api import router as simple_steering_router 110 | 111 | app.include_router(simple_steering_router) 112 | logger.info("✅ Simple steering API enabled") 113 | except ImportError as e: 114 | logger.warning(f"⚠️ Simple steering API not available: {e}") 115 | 116 | 117 | # Mount the React build 
directory 118 | app.mount( 119 | "/static", 120 | StaticFiles(directory="ai-research-assistant/build/static"), 121 | name="static", 122 | ) 123 | app.mount( 124 | "/", StaticFiles(directory="ai-research-assistant/build", html=True), name="root" 125 | ) 126 | 127 | 128 | @app.get("/") 129 | async def root(): 130 | """Root endpoint that returns basic API information.""" 131 | return { 132 | "message": "Deep Research API is running", 133 | "version": VERSION, 134 | "endpoints": { 135 | "POST /deep-research": "Perform deep research on a topic with optional steering", 136 | "POST /api/files/upload": "Upload and analyze files", 137 | "GET /api/files/{file_id}/analysis": "Get file analysis results", 138 | "POST /api/database/upload": "Upload database files for text2sql", 139 | "GET /api/database/list": "List uploaded databases", 140 | "GET /api/database/{database_id}/schema": "Get database schema", 141 | "POST /api/database/query": "Execute text2sql queries", 142 | "DELETE /api/database/{database_id}": "Delete uploaded database", 143 | "POST /steering/message": "Send steering messages during research", 144 | "GET /steering/plan/{session_id}": "Get current research plan", 145 | }, 146 | "documentation": "/docs", 147 | } 148 | 149 | 150 | @app.get("/{path:path}") 151 | async def serve_react(path: str): 152 | """Catch-all route for React app""" 153 | if path.startswith("api/") or path.startswith("steering/"): 154 | raise HTTPException(status_code=404, detail="API route not found") 155 | return FileResponse("ai-research-assistant/build/index.html") 156 | 157 | 158 | if __name__ == "__main__": 159 | import uvicorn 160 | import logging 161 | 162 | logging.basicConfig(level=logging.INFO) 163 | uvicorn.run( 164 | "app:app", host="0.0.0.0", port=8000, reload=True, log_level="info", workers=1 165 | ) 166 | -------------------------------------------------------------------------------- /ai-research-assistant/src/components/LoadingIndicator.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | 3 | /** 4 | * LoadingIndicator component displays an animated progress indicator 5 | * @param {Object} props - Component props 6 | * @param {string} props.type - Type of indicator: 'bar', 'spinner', 'dots', 'pulse' (default: 'spinner') 7 | * @param {string} props.size - Size of the indicator: 'small', 'medium', 'large' (default: 'medium') 8 | * @param {string} props.color - Primary color of the indicator (default: '#1a5fb4') 9 | * @param {string} props.text - Optional text to display with the indicator 10 | * @param {boolean} props.fullScreen - Whether to display the indicator in fullscreen overlay 11 | * @param {number} props.progress - Progress value (0-100) for bar type indicators 12 | */ 13 | function LoadingIndicator({ 14 | type = 'spinner', 15 | size = 'medium', 16 | color = '#1a5fb4', 17 | text = '', 18 | fullScreen = false, 19 | progress = -1 20 | }) { 21 | // Size mappings in pixels 22 | const sizeMap = { 23 | small: { container: 16, spinner: 16, bar: 4, dots: 8, pulse: 8 }, 24 | medium: { container: 24, spinner: 24, bar: 6, dots: 10, pulse: 12 }, 25 | large: { container: 40, spinner: 40, bar: 8, dots: 14, pulse: 18 } 26 | }; 27 | 28 | const selectedSize = sizeMap[size] || sizeMap.medium; 29 | 30 | // Base styles 31 | const styles = { 32 | container: { 33 | display: 'flex', 34 | flexDirection: 'column', 35 | alignItems: 'center', 36 | justifyContent: 'center', 37 | fontFamily: '-apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, 
Arial, sans-serif', 38 | }, 39 | fullScreenOverlay: { 40 | position: 'fixed', 41 | top: 0, 42 | left: 0, 43 | width: '100%', 44 | height: '100%', 45 | backgroundColor: 'rgba(255, 255, 255, 0.8)', 46 | zIndex: 9999, 47 | display: 'flex', 48 | alignItems: 'center', 49 | justifyContent: 'center', 50 | }, 51 | text: { 52 | marginTop: 10, 53 | fontSize: size === 'small' ? 12 : size === 'large' ? 16 : 14, 54 | color: '#333', 55 | }, 56 | // Spinner styles 57 | spinner: { 58 | width: selectedSize.spinner, 59 | height: selectedSize.spinner, 60 | border: `${selectedSize.spinner / 8}px solid rgba(0, 0, 0, 0.1)`, 61 | borderRadius: '50%', 62 | borderTop: `${selectedSize.spinner / 8}px solid ${color}`, 63 | animation: 'spin 1s linear infinite', 64 | }, 65 | // Bar styles 66 | barContainer: { 67 | width: selectedSize.container * 5, 68 | height: selectedSize.bar, 69 | backgroundColor: 'rgba(0, 0, 0, 0.1)', 70 | borderRadius: selectedSize.bar / 2, 71 | overflow: 'hidden', 72 | }, 73 | bar: { 74 | height: '100%', 75 | backgroundColor: color, 76 | borderRadius: selectedSize.bar / 2, 77 | transition: 'width 0.3s ease', 78 | width: progress >= 0 && progress <= 100 ? `${progress}%` : '0%', 79 | animation: progress < 0 ? 'barIndeterminate 2s ease-in-out infinite' : 'none', 80 | }, 81 | // Dots styles 82 | dotsContainer: { 83 | display: 'flex', 84 | alignItems: 'center', 85 | justifyContent: 'center', 86 | gap: selectedSize.dots / 2, 87 | }, 88 | dot: { 89 | width: selectedSize.dots, 90 | height: selectedSize.dots, 91 | borderRadius: '50%', 92 | backgroundColor: color, 93 | }, 94 | // Pulse styles 95 | pulse: { 96 | width: selectedSize.pulse, 97 | height: selectedSize.pulse, 98 | borderRadius: '50%', 99 | backgroundColor: color, 100 | animation: 'pulse 1.5s ease-in-out infinite', 101 | } 102 | }; 103 | 104 | // Keyframes are added as a 166 |
<style> element so the spin / pulse / barIndeterminate animations can run

  const keyframes = `
    @keyframes spin { to { transform: rotate(360deg); } }
    @keyframes pulse { 0%, 100% { opacity: 1; transform: scale(1); } 50% { opacity: 0.5; transform: scale(0.8); } }
    @keyframes barIndeterminate { 0% { width: 0%; margin-left: 0%; } 50% { width: 60%; margin-left: 20%; } 100% { width: 0%; margin-left: 100%; } }
  `;

  // Pick the indicator variant based on the `type` prop
  const renderLoadingIndicator = () => {
    switch (type) {
      case 'bar':
        return (
          <div style={styles.barContainer}>
            <div style={styles.bar} />
          </div>
        );
      case 'dots':
        return (
          <div style={styles.dotsContainer}>
            {[0, 1, 2].map((i) => (
              <div
                key={i}
                style={{ ...styles.dot, animation: `pulse 1.5s ease-in-out ${i * 0.2}s infinite` }}
              />
            ))}
          </div>
        );
      case 'pulse':
        return <div style={styles.pulse} />;
      case 'spinner':
      default:
        return <div style={styles.spinner} />;
    }
  };

  return (
    <div style={fullScreen ? styles.fullScreenOverlay : undefined}>
      <style>{keyframes}</style>
      <div style={styles.container}>
        {renderLoadingIndicator()}
        {text && <div style={styles.text}>{text}</div>}
      </div>
    </div>
172 | 173 | ); 174 | } 175 | 176 | export default LoadingIndicator; 177 | -------------------------------------------------------------------------------- /src/tools/test_mcp_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for the MCP tools integration. 3 | """ 4 | import asyncio 5 | import pytest 6 | import sys 7 | from pathlib import Path 8 | from unittest.mock import AsyncMock, MagicMock, patch 9 | 10 | # Add the project root to the Python path 11 | project_root = Path(__file__).parent.parent.parent 12 | if str(project_root) not in sys.path: 13 | sys.path.append(str(project_root)) 14 | 15 | from src.tools.mcp_tools import MCPToolProvider, MCPToolManager 16 | from src.tools.registry import ToolRegistry 17 | from src.tools.tool_schema import Tool, ToolParameter, ToolParameterType 18 | 19 | # Create a simple test MCP server script for testing 20 | TEST_SERVER_SCRIPT = """ 21 | from mcp.server.fastmcp import FastMCP 22 | 23 | mcp = FastMCP("TestTools") 24 | 25 | @mcp.tool() 26 | def echo(message: str) -> str: 27 | \"\"\"Echo back the message.\"\"\" 28 | return message 29 | 30 | @mcp.tool() 31 | def add(a: int, b: int) -> int: 32 | \"\"\"Add two numbers.\"\"\" 33 | return a + b 34 | 35 | if __name__ == "__main__": 36 | import sys 37 | if len(sys.argv) > 1 and sys.argv[1] == "--http": 38 | mcp.run(transport="http") 39 | else: 40 | mcp.run(transport="stdio") 41 | """ 42 | 43 | @pytest.fixture 44 | def test_server_path(tmp_path): 45 | """Create a temporary test server script.""" 46 | server_path = tmp_path / "test_server.py" 47 | server_path.write_text(TEST_SERVER_SCRIPT) 48 | return server_path 49 | 50 | @pytest.fixture 51 | def tool_registry(): 52 | """Create a tool registry for testing.""" 53 | return ToolRegistry() 54 | 55 | @pytest.mark.asyncio 56 | async def test_mcp_tool_provider_stdio(test_server_path): 57 | """Test that we can connect to an MCP server using stdio transport.""" 58 | provider = MCPToolProvider("test") 59 | 60 | # Mock the ClientSession 61 | with patch("src.tools.mcp_tools.stdio_client") as mock_stdio_client, \ 62 | patch("src.tools.mcp_tools.ClientSession") as mock_session_class, \ 63 | patch("src.tools.mcp_tools.load_mcp_tools") as mock_load_mcp_tools: 64 | 65 | # Mock the stdio client 66 | mock_read = AsyncMock() 67 | mock_write = AsyncMock() 68 | mock_stdio_client.return_value.__aenter__.return_value = (mock_read, mock_write) 69 | 70 | # Mock the session 71 | mock_session = AsyncMock() 72 | mock_session_class.return_value.__aenter__.return_value = mock_session 73 | 74 | # Mock the load_mcp_tools function 75 | mock_tool = MagicMock() 76 | mock_tool.name = "echo" 77 | mock_tool.description = "Echo back the message." 
78 | mock_tool._run = lambda **kwargs: kwargs["message"] 79 | mock_tool.args_schema.schema.return_value = { 80 | "properties": { 81 | "message": { 82 | "type": "string", 83 | "description": "The message to echo back" 84 | } 85 | }, 86 | "required": ["message"] 87 | } 88 | mock_load_mcp_tools.return_value = [mock_tool] 89 | 90 | # Connect to the server 91 | await provider.connect_stdio( 92 | command=sys.executable, 93 | args=[str(test_server_path)] 94 | ) 95 | 96 | # Check that the session was initialized 97 | mock_session.initialize.assert_called_once() 98 | 99 | # Load the tools 100 | tools = await provider.load_tools() 101 | 102 | # Check that we got the expected tools 103 | assert len(tools) == 1 104 | assert tools[0].name == "mcp.test.echo" 105 | assert tools[0].description == "Echo back the message." 106 | assert len(tools[0].parameters) == 1 107 | assert tools[0].parameters[0].name == "message" 108 | assert tools[0].parameters[0].type == ToolParameterType.STRING 109 | assert tools[0].parameters[0].required == True 110 | 111 | # Close the connection 112 | await provider.close() 113 | 114 | @pytest.mark.asyncio 115 | async def test_mcp_tool_manager(tool_registry): 116 | """Test the MCP tool manager.""" 117 | manager = MCPToolManager(tool_registry) 118 | 119 | # Mock the MCPToolProvider 120 | with patch("src.tools.mcp_tools.MCPToolProvider") as mock_provider_class: 121 | # Create a mock provider 122 | mock_provider = AsyncMock() 123 | mock_provider.name = "test" 124 | 125 | # Create a mock tool 126 | mock_tool = Tool( 127 | name="mcp.test.echo", 128 | description="Echo back the message.", 129 | parameters=[ 130 | ToolParameter( 131 | name="message", 132 | type=ToolParameterType.STRING, 133 | required=True, 134 | description="The message to echo back" 135 | ) 136 | ], 137 | function=lambda **kwargs: kwargs["message"] 138 | ) 139 | 140 | # Set up the mock provider 141 | mock_provider.load_tools.return_value = [mock_tool] 142 | mock_provider_class.return_value = mock_provider 143 | 144 | # Register a server 145 | tools = await manager.register_stdio_server( 146 | name="test", 147 | command=sys.executable, 148 | args=["test_server.py"] 149 | ) 150 | 151 | # Check that the provider was created and used 152 | mock_provider_class.assert_called_once_with("test") 153 | mock_provider.connect_stdio.assert_called_once() 154 | mock_provider.load_tools.assert_called_once() 155 | 156 | # Check that we got the expected tools 157 | assert len(tools) == 1 158 | assert tools[0].name == "mcp.test.echo" 159 | 160 | # Check that the tool was registered 161 | assert tool_registry.get_tool("mcp.test.echo") is not None 162 | 163 | # Close the connections 164 | await manager.close_all() 165 | mock_provider.close.assert_called_once() 166 | 167 | if __name__ == "__main__": 168 | pytest.main(["-xvs", __file__]) -------------------------------------------------------------------------------- /benchmarks/README.md: -------------------------------------------------------------------------------- 1 | # 📊 Benchmarking Guide 2 | 3 | This guide demonstrates how to evaluate the Enterprise Deep Research Agent using various benchmarks and evaluation frameworks. 4 | 5 | ## 🚀 Quick Start 6 | 7 | ### Prerequisites 8 | 9 | Complete the [main installation setup](../README.md) first, then configure your environment for benchmarking. 
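
For a fresh checkout, the environment setup is just a dependency install plus a local config file. A minimal sketch, assuming the repository's standard `requirements.txt` and `.env.sample` and run from the repository root:

```bash
pip install -r requirements.txt   # install project dependencies
cp .env.sample .env               # start your local config from the sample
# then fill in the keys listed in the recommended configuration below
```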
10 | 11 | ### 🔧 Recommended Configuration 12 | 13 | ```bash 14 | # .env file settings for optimal benchmarking 15 | LLM_PROVIDER=google 16 | LLM_MODEL=gemini-2.5-pro 17 | GOOGLE_CLOUD_PROJECT=your-project-id 18 | TAVILY_API_KEY=your-tavily-key 19 | MAX_WEB_RESEARCH_LOOPS=5 20 | 21 | # Optional: LangSmith for tracing (not required) 22 | # LANGCHAIN_API_KEY=your-key 23 | # LANGCHAIN_TRACING_V2=true 24 | # LANGCHAIN_PROJECT=your-project 25 | ``` 26 | 27 | --- 28 | 29 | ## 📋 Evaluation Modes 30 | 31 | ### 🔄 Sequential Processing 32 | Process queries one at a time using `run_research.py`: 33 | 34 | ```bash 35 | python run_research.py "Your research query" \ 36 | --provider google \ 37 | --model gemini-2.5-pro \ 38 | --max-loops 2 \ 39 | --output result.json 40 | ``` 41 | 42 | ### ⚡ Concurrent Processing 43 | Process multiple queries in parallel using `run_research_concurrent.py`: 44 | 45 | ```bash 46 | python run_research_concurrent.py \ 47 | --benchmark drb \ 48 | --max_concurrent 4 \ 49 | --provider google \ 50 | --model gemini-2.5-pro \ 51 | --max_loops 5 52 | ``` 53 | 54 | **Optional: Enable trajectory collection for detailed execution traces** 55 | ```bash 56 | python run_research_concurrent.py \ 57 | --benchmark drb \ 58 | --max_concurrent 4 \ 59 | --collect-traj # Saves detailed trajectory data for analysis 60 | ``` 61 | 62 | --- 63 | 64 | ## 🎯 Supported Benchmarks 65 | 66 | > 💡 **Default Paths**: The scripts automatically use default input/output paths for each benchmark. You can override with `--input` and `--output_dir` flags. 67 | 68 | ### 1. DeepResearchBench (DRB) 69 | 70 | Comprehensive research evaluation with 100 PhD-curated diverse queries. 71 | 72 | **Setup:** 73 | ```bash 74 | cd benchmarks 75 | git clone https://github.com/Ayanami0730/deep_research_bench.git 76 | ``` 77 | 78 | To run DeepResearchBench evaluation: 79 | 80 | **Step 1: Generate responses for all 100 queries** 81 | ```bash 82 | python run_research_concurrent.py \ 83 | --benchmark drb \ 84 | --max_concurrent 4 \ 85 | --provider google \ 86 | --model gemini-2.5-pro \ 87 | --max_loops 5 88 | ``` 89 | 90 | > 💡 **Tip**: Add `--collect-traj` to save detailed execution traces for debugging or analysis. 91 | 92 | **Step 2: Convert to benchmark format** 93 | ```bash 94 | python process_drb.py \ 95 | --input-dir deep_research_bench/data/test_data/raw_data/edr_reports_gemini \ 96 | --model-name edr_gemini 97 | ``` 98 | 99 | > 📝 **Note**: 100 | > - The processed report will be saved to `deep_research_bench/data/test_data/raw_data/edr_gemini.jsonl` 101 | > - Add your model name (eg. `edr_gemini`) to `TARGET_MODELS` in `run_benchmark.sh` inside `deep_research_bench` 102 | 103 | **Step 3: Run DeepResearchBench evaluation** 104 | ```bash 105 | cd deep_research_bench 106 | # Set up Gemini and Jina API keys for LLM evaluation and web scraping 107 | export GEMINI_API_KEY="your_gemini_api_key_here" 108 | export JINA_API_KEY="your_jina_api_key_here" 109 | bash run_benchmark.sh 110 | ``` 111 | 112 | > 🎉 **Results**: The evaluation results will be written to `deep_research_bench/results/` 113 | 114 | --- 115 | 116 | ### 2. DeepConsult 117 | 118 | Multi-perspective research evaluation with diverse query types. 
119 | 120 | **Setup:** 121 | 122 | Clone the DeepConsult repo and follow the [installation steps](https://github.com/Su-Sea/ydc-deep-research-evals?tab=readme-ov-file#installation): 123 | ```bash 124 | git clone https://github.com/Su-Sea/ydc-deep-research-evals.git 125 | ``` 126 | 127 | To run DeepConsult evaluation: 128 | 129 | **Step 1: Process DeepConsult CSV queries** 130 | ```bash 131 | python run_research_concurrent.py \ 132 | --benchmark deepconsult \ 133 | --max_concurrent 4 \ 134 | --max_loops 10 \ 135 | --provider google \ 136 | --model gemini-2.5-pro 137 | ``` 138 | 139 | **Step 2: Create responses CSV for evaluation** 140 | ```bash 141 | python process_deepconsult.py \ 142 | --queries-file /path/to/queries.csv \ 143 | --baseline-file /path/to/baseline_responses.csv \ 144 | --reports-dir /path/to/generated_reports \ 145 | --output-file /path/to/custom_output.csv 146 | ``` 147 | 148 | > 📋 **This script combines**: 149 | > - Questions from the original `queries.csv` 150 | > - Baseline answers from existing responses 151 | > - Your generated candidate answers from the JSON files 152 | > - **Output**: `responses_EDR_vs_ARI_YYYY-MM-DD.csv` 153 | 154 | **Step 3: Run pairwise evaluation** 155 | ```bash 156 | cd benchmarks/ydc-deep-research-evals/evals 157 | export OPENAI_API_KEY="your_openai_key_here" 158 | python deep_research_pairwise_evals.py \ 159 | --input-data /path/to/csv/previous/step \ 160 | --output-dir results \ 161 | --model gpt-4.1-2025-04-14 \ 162 | --num-workers 4 \ 163 | --metric-num-workers 3 \ 164 | --metric-num-trials 3 165 | ``` 166 | 167 | --- 168 | 169 | ## 📈 Monitoring and Debugging 170 | 171 | ### 🔍 Real-time Progress Monitoring 172 | 173 | The concurrent script provides **detailed progress tracking**: 174 | 175 | - ⏱️ **Live progress updates** every 10 seconds showing completion rate and ETA 176 | - 📊 **Individual task logging** with timing and performance metrics 177 | - 📋 **Comprehensive summary** with success/failure statistics 178 | 179 | #### 💻 Example Output 180 | ```bash 181 | 🚀 Starting concurrent processing of 100 tasks 182 | 📊 Max workers: 4 183 | ⏱️ Rate limit delay: 1.0s 184 | 🤖 Using google/gemini-2.5-pro 185 | 186 | 📈 Progress: 45/100 completed, 2 failed, 8 in progress, ETA: 12.3min 187 | [Task 23] ✅ SUCCESS - Completed in 45.67s 188 | [Task 23] 📊 Metrics: 3 loops, 12 sources, 8,234 chars 189 | [Task 23] 📈 Throughput: 180 chars/second 190 | ``` 191 | 192 | --- 193 | 194 | ## 🐛 Troubleshooting 195 | 196 | ### ⚠️ Common Issues 197 | 198 | **Rate Limiting:** 199 | ```bash 200 | # Increase rate limit delay 201 | --rate-limit 2.0 202 | 203 | # Reduce concurrent workers 204 | --max-workers 2 205 | ``` 206 | 207 | **API Errors:** 208 | - ✅ Verify all API keys are correctly set 209 | - ✅ Check API quotas and billing 210 | - ✅ Ensure proper network connectivity 211 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guide For Enterprise Deep Research 2 | 3 | This page lists the operational governance model of this project, as well as the recommendations and requirements for how to best contribute to *Enterprise Deep Research*. We strive to obey these as best as possible. As always, thanks for contributing – we hope these guidelines make it easier and shed some light on our approach and processes. 
4 |
5 | # Governance Model
6 | > Pick the most appropriate one
7 |
8 | ## Community Based
9 |
10 | The intent and goal of open sourcing this project is to increase the contributor and user base. The governance model is one where new project leads (`admins`) will be added to the project based on their contributions and efforts, a so-called "do-ocracy" or "meritocracy" similar to that used by all Apache Software Foundation projects.
11 |
12 | > or
13 |
14 | ## Salesforce Sponsored
15 |
16 | The intent and goal of open sourcing this project is to increase the contributor and user base. However, only Salesforce employees will be given `admin` rights and will be the final arbiters of which contributions are accepted.
17 |
18 | > or
19 |
20 | ## Published but not supported
21 |
22 | The intent and goal of open sourcing this project is that it may contain useful or interesting code/concepts that we wish to share with the larger open source community. Although occasional work may be done on it, we will not be looking for or soliciting contributions.
23 |
24 | # Getting started
25 |
26 | Please join the community by:
27 | - Opening issues and discussions on our [GitHub repository](https://github.com/SalesforceAIResearch/enterprise-deep-research/)
28 | - Following our research updates and announcements
29 |
30 | Also please make sure to take a look at the project documentation in the README files to understand our current capabilities and future directions.
31 |
32 |
33 | # Issues, requests & ideas
34 |
35 | Use the GitHub Issues page to submit issues, enhancement requests and discuss ideas.
36 |
37 | ### Bug Reports and Fixes
38 | - If you find a bug, please search for it in the [Issues](https://github.com/SalesforceAIResearch/enterprise-deep-research/issues), and if it isn't already tracked,
39 |   [create a new issue](https://github.com/SalesforceAIResearch/enterprise-deep-research/issues/new). Fill out the "Bug Report" section of the issue template. Even if an Issue is closed, feel free to comment and add details; it will still
40 |   be reviewed.
41 | - Issues that have already been identified as a bug (note: able to reproduce) will be labelled `bug`.
42 | - If you'd like to submit a fix for a bug, [send a Pull Request](#creating_a_pull_request) and mention the Issue number.
43 | - Include tests that isolate the bug and verify that it was fixed.
44 |
45 | ### New Features
46 | - If you'd like to add new functionality to this project, describe the problem you want to solve in a [new Issue](https://github.com/SalesforceAIResearch/enterprise-deep-research/issues/new).
47 | - Issues that have been identified as a feature request will be labelled `enhancement`.
48 | - If you'd like to implement the new feature, please wait for feedback from the project
49 |   maintainers before spending too much time writing the code. In some cases, `enhancement`s may
50 |   not align well with the project objectives at the time.
51 |
52 | ### Tests, Documentation, Miscellaneous
53 | - If you'd like to improve the tests, make the documentation clearer, offer an
54 |   alternative implementation of something that may have advantages over the way it's currently
55 |   done, or propose any other change, we would be happy to hear about it!
56 | - If it's a trivial change, go ahead and [send a Pull Request](#creating_a_pull_request) with the changes you have in mind.
57 | - If not, [open an Issue](https://github.com/SalesforceAIResearch/enterprise-deep-research/issues/new) to discuss the idea first.
58 |
59 | If you're new to our project and looking for some way to make your first contribution, look for
60 | Issues labelled `good first contribution`.
61 |
62 | # Contribution Checklist
63 |
64 | - [x] Clean, simple, well styled code
65 | - [x] Commits should be atomic and messages must be descriptive. Related issues should be mentioned by Issue number.
66 | - [x] Comments
67 |   - Module-level & function-level comments.
68 |   - Comments on complex blocks of code or algorithms (include references to sources).
69 | - [x] Tests
70 |   - The test suite, if provided, must be complete and pass
71 |   - Increase code coverage, not vice versa.
72 | - [x] Dependencies
73 |   - Minimize number of dependencies.
74 |   - Prefer Apache 2.0, BSD3, MIT, ISC and MPL licenses.
75 | - [x] Reviews
76 |   - Changes must be approved via peer code review
77 |
78 | # Creating a Pull Request
79 |
80 | 1. **Ensure the bug/feature was not already reported** by searching on GitHub under Issues. If none exists, create a new issue so that other contributors can keep track of what you are trying to add/fix and offer suggestions (or let you know if there is already an effort in progress).
81 | 2. **Fork** the repo on GitHub.
82 | 3. **Clone** the forked repo to your machine.
83 | 4. **Create** a new branch to contain your work (e.g. `git checkout -b fix-issue-11`)
84 | 5. **Commit** changes to your own branch.
85 | 6. **Push** your work back up to your fork. (e.g. `git push origin fix-issue-11`)
86 | 7. **Submit** a Pull Request against the `main` branch and refer to the issue(s) you are fixing. Try not to pollute your pull request with unintended changes. Keep it simple and small.
87 | 8. **Sign** the Salesforce CLA (you will be prompted to do so when submitting the Pull Request)
88 |
89 | > **NOTE**: Be sure to [sync your fork](https://help.github.com/articles/syncing-a-fork/) before making a pull request.
90 |
91 | # Contributor License Agreement ("CLA")
92 | In order to accept your pull request, we need you to submit a CLA. You only need
93 | to do this once to work on any of Salesforce's open source projects.
94 |
95 | Complete your CLA here: <https://cla.salesforce.com/sign-cla>
96 |
97 | # Issues
98 | We use GitHub issues to track public bugs. Please ensure your description is
99 | clear and has sufficient instructions to be able to reproduce the issue.
100 |
101 | # Code of Conduct
102 | Please follow our [Code of Conduct](CODE_OF_CONDUCT.md).
103 |
104 | # License
105 | By contributing your code, you agree to license your contribution under the terms of our project [LICENSE](LICENSE.txt) and to sign the [Salesforce CLA](https://cla.salesforce.com/sign-cla)
--------------------------------------------------------------------------------
/test_specialized_searches.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Test script for specialized search functions:
5 | - linkedin_search
6 | - github_search
7 | - academic_search
8 |
9 | This script demonstrates how each specialized search function uses different
10 | parameters for the Tavily API, and shows the results from each search.
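
Example invocations (the query strings are only illustrations; the
--type/-t, --query/-q and --verbose/-v flags are parsed by main() below):

    python test_specialized_searches.py --type github --query "open source RAG framework python"
    python test_specialized_searches.py -t academic -q "large language models in healthcare research" -v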
11 | """ 12 | 13 | import os 14 | import json 15 | import argparse 16 | import logging 17 | from dotenv import load_dotenv 18 | from src.utils import linkedin_search, github_search, academic_search 19 | 20 | # Configure logging to show the search parameters 21 | logging.basicConfig( 22 | level=logging.INFO, 23 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' 24 | ) 25 | 26 | def run_linkedin_search(query, top_k=3, min_score=0.4): 27 | """Run a LinkedIn search and display results""" 28 | print("\n" + "="*80) 29 | print(f"LINKEDIN SEARCH: {query}") 30 | print("="*80) 31 | 32 | print("\nRunning LinkedIn search with the following constraints:") 33 | print(f"- Minimum score threshold: {min_score}") 34 | print(f"- Maximum results: {top_k}") 35 | print(f"- Domain restriction: linkedin.com") 36 | print(f"- Raw content: Only for highest scoring result\n") 37 | 38 | results = linkedin_search( 39 | query=query, 40 | include_raw_content=True, 41 | top_k=top_k, 42 | min_score=min_score 43 | ) 44 | 45 | print(f"\nFound {len(results.get('results', []))} LinkedIn results with score >= {min_score}") 46 | 47 | for i, result in enumerate(results.get('results', [])): 48 | print(f"\nResult #{i+1}: score={result.get('score'):.4f}") 49 | print(f"Title: {result.get('title')}") 50 | print(f"URL: {result.get('url')}") 51 | print(f"Has raw content: {result.get('raw_content') is not None}") 52 | print("-" * 40) 53 | 54 | return results 55 | 56 | def run_github_search(query, top_k=5, min_score=0.6): 57 | """Run a GitHub search and display results""" 58 | print("\n" + "="*80) 59 | print(f"GITHUB SEARCH: {query}") 60 | print("="*80) 61 | 62 | print("\nRunning GitHub search with the following constraints:") 63 | print(f"- Minimum score threshold: {min_score}") 64 | print(f"- Maximum results: {top_k}") 65 | print(f"- Domain restriction: github.com") 66 | print(f"- Score boost: 20% for actual repositories (/blob/ or /tree/ in URL)") 67 | print(f"- Raw content: Only kept for repositories with code content\n") 68 | 69 | results = github_search( 70 | query=query, 71 | include_raw_content=True, 72 | top_k=top_k, 73 | min_score=min_score 74 | ) 75 | 76 | print(f"\nFound {len(results.get('results', []))} GitHub results with score >= {min_score}") 77 | 78 | for i, result in enumerate(results.get('results', [])): 79 | print(f"\nResult #{i+1}: score={result.get('score'):.4f}") 80 | print(f"Title: {result.get('title')}") 81 | print(f"URL: {result.get('url')}") 82 | print(f"Has raw content: {result.get('raw_content') is not None}") 83 | print("-" * 40) 84 | 85 | return results 86 | 87 | def run_academic_search(query, top_k=5, min_score=0.65, recent_years=5): 88 | """Run an academic search and display results""" 89 | print("\n" + "="*80) 90 | print(f"ACADEMIC SEARCH: {query}") 91 | print("="*80) 92 | 93 | print("\nRunning Academic search with the following constraints:") 94 | print(f"- Minimum score threshold: {min_score}") 95 | print(f"- Maximum results: {top_k}") 96 | print(f"- Domain restriction: Academic domains (arxiv.org, scholar.google.com, etc.)") 97 | print(f"- Score boost: 15% for papers from the last {recent_years} years") 98 | print(f"- Score boost: Up to 30% for academic indicators (doi, abstract, etc.)") 99 | print(f"- Search depth: Advanced (for more comprehensive academic results)\n") 100 | 101 | results = academic_search( 102 | query=query, 103 | include_raw_content=True, 104 | top_k=top_k, 105 | min_score=min_score, 106 | recent_years=recent_years 107 | ) 108 | 109 | print(f"\nFound 
{len(results.get('results', []))} academic results with score >= {min_score}") 110 | 111 | for i, result in enumerate(results.get('results', [])): 112 | print(f"\nResult #{i+1}: score={result.get('score'):.4f}") 113 | print(f"Title: {result.get('title')}") 114 | print(f"URL: {result.get('url')}") 115 | print(f"Has raw content: {result.get('raw_content') is not None}") 116 | print("-" * 40) 117 | 118 | return results 119 | 120 | def main(): 121 | """Main function to demonstrate all specialized search functions""" 122 | 123 | # Load environment variables from .env file if it exists 124 | load_dotenv() 125 | 126 | # Parse command-line arguments 127 | parser = argparse.ArgumentParser(description='Test specialized search functions') 128 | parser.add_argument('--type', '-t', choices=['linkedin', 'github', 'academic', 'all'], 129 | default='all', help='Type of search to run') 130 | parser.add_argument('--query', '-q', type=str, 131 | help='Search query to execute') 132 | parser.add_argument('--verbose', '-v', action='store_true', 133 | help='Show full search results including raw content') 134 | args = parser.parse_args() 135 | 136 | # Set default queries for each search type if not provided 137 | linkedin_query = args.query or "frank wang salesforce, tell me about his background in details" 138 | github_query = args.query or "open source RAG framework python" 139 | academic_query = args.query or "large language models in healthcare research" 140 | 141 | # Run the requested search type(s) 142 | results = {} 143 | 144 | if args.type == 'all' or args.type == 'linkedin': 145 | results['linkedin'] = run_linkedin_search(linkedin_query) 146 | 147 | if args.type == 'all' or args.type == 'github': 148 | results['github'] = run_github_search(github_query) 149 | 150 | if args.type == 'all' or args.type == 'academic': 151 | results['academic'] = run_academic_search(academic_query) 152 | 153 | # If verbose mode, print the full JSON results 154 | if args.verbose: 155 | print("\n\nFULL SEARCH RESULTS (JSON):") 156 | print(json.dumps(results, indent=2)) 157 | 158 | print("\nAll searches completed.") 159 | 160 | if __name__ == "__main__": 161 | main() -------------------------------------------------------------------------------- /test_unified_query.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | from src.graph import create_fresh_graph 4 | from src.configuration import Configuration, LLMProvider, SearchAPI 5 | from src.state import SummaryStateInput 6 | from typing import Dict, Any 7 | 8 | def test_unified_query(max_loops: int = 3, recursion_limit: int = 100) -> Dict[str, Any]: 9 | """ 10 | Test the unified query planning and parallel search capabilities on a complex topic 11 | 12 | Args: 13 | max_loops: Maximum number of research loops to perform 14 | recursion_limit: Maximum recursion limit for the graph 15 | 16 | Returns: 17 | Dictionary containing the research results 18 | """ 19 | # Create a configuration object - using a dict to match expected format 20 | config = { 21 | "llm_provider": "openai", 22 | "llm_model": "gpt-4o-mini", # Alternatively, use an available model 23 | "search_api": "tavily", # Use available search API 24 | "include_raw_content": True, 25 | "max_web_research_loops": max_loops, 26 | "recursion_limit": recursion_limit 27 | } 28 | 29 | # Create a fresh graph 30 | graph = create_fresh_graph() 31 | 32 | # Set the complex research topic (on Agentic RAG systems) 33 | research_topic = """ 34 | Agentic RAG Systems - 
Architecture, Benefits, and Implementation. 35 | 36 | I'm interested in understanding the architectural components, advanced techniques, 37 | benefits, and implementation approaches for Agentic RAG systems that enhance 38 | traditional Retrieval-Augmented Generation with agent-based capabilities. 39 | 40 | Please cover core components, architectural patterns, strategic retrieval planning, 41 | multi-hop reasoning, adaptive retrieval, self-improvement mechanisms, and notable 42 | implementations. 43 | """ 44 | 45 | # Create input object 46 | input_obj = SummaryStateInput(research_topic=research_topic) 47 | 48 | # Add tracing callback to debug state transitions 49 | def trace_state_changes(state, node_name=None): 50 | if node_name == "research_agent" and hasattr(state, "subtopic_queries"): 51 | print(f"\n[DEBUG] After research_agent node, state.subtopic_queries: {state.subtopic_queries}") 52 | 53 | try: 54 | # Run the graph with the input 55 | print(f"\n[TEST] Starting unified query planning test with max_loops={max_loops}, recursion_limit={recursion_limit}...") 56 | 57 | # Configure tracing in the config 58 | config["callbacks"] = {"on_node_run": trace_state_changes} 59 | 60 | # Run the graph with the input 61 | results = graph.invoke( 62 | input_obj, 63 | {"configurable": config} 64 | ) 65 | 66 | # Print the final summary 67 | print("\n[TEST] ========== RESEARCH RESULTS ==========\n") 68 | print(f"[TEST] Final summary length: {len(results['running_summary']) if 'running_summary' in results else 0} characters") 69 | print(f"[TEST] Research loops completed: {results.get('research_loop_count', 0)}") 70 | 71 | # Print additional debug info 72 | print(f"[TEST] Has subtopic_queries: {'subtopic_queries' in results}") 73 | if 'subtopic_queries' in results: 74 | print(f"[TEST] Number of subtopic queries: {len(results['subtopic_queries'])}") 75 | print(f"[TEST] Subtopic queries: {results['subtopic_queries']}") 76 | 77 | # Print the first 500 characters of the summary as a preview 78 | if 'running_summary' in results and results['running_summary']: 79 | print("\n[TEST] ----- Summary Preview (first 500 chars) -----") 80 | print(results['running_summary'][:500]) 81 | print("...\n") 82 | 83 | print("[TEST] ----- Source Citations -----") 84 | if results.get('source_citations'): 85 | for num, src in sorted(results['source_citations'].items()): 86 | print(f"[{num}] {src.get('title', 'No title')} : {src.get('url', 'No URL')}") 87 | else: 88 | print("[TEST] No source citations found") 89 | 90 | return results 91 | 92 | except Exception as e: 93 | print(f"\n[TEST] ERROR: {str(e)}") 94 | print("[TEST] Test failed but we can still analyze what happened up to this point.") 95 | return {"error": str(e), "research_topic": research_topic} 96 | 97 | if __name__ == "__main__": 98 | # Run the test with minimal loops and a high recursion limit to avoid errors 99 | # We'll focus on the initial query planning and parallel search 100 | results = test_unified_query(max_loops=0, recursion_limit=200) 101 | 102 | # Save output to file even if we had an error 103 | output_file = "unified_query_results.md" 104 | 105 | try: 106 | with open(output_file, "w") as f: 107 | f.write("# Unified Query Planning Research Results\n\n") 108 | f.write("## Research Topic\n") 109 | f.write(f"{results.get('research_topic', 'Topic not available')}\n\n") 110 | 111 | if "error" in results: 112 | f.write(f"## Error Encountered\n\n") 113 | f.write(f"{results['error']}\n\n") 114 | 115 | if "running_summary" in results: 116 | f.write("## Full Research 
Summary\n\n") 117 | f.write(results['running_summary']) 118 | f.write("\n\n") 119 | 120 | if results.get('source_citations'): 121 | f.write("## Source Citations\n\n") 122 | for num, src in sorted(results['source_citations'].items()): 123 | f.write(f"[{num}] {src.get('title', 'No title')} : {src.get('url', 'No URL')}\n") 124 | else: 125 | f.write("## Source Citations\n\nNo source citations found\n") 126 | 127 | print(f"\n[TEST] Results saved to '{output_file}'") 128 | except Exception as e: 129 | print(f"\n[TEST] Error saving results: {str(e)}") 130 | 131 | # Save a debug file with full results as JSON for inspection 132 | try: 133 | debug_results = {k: v for k, v in results.items() if k != "running_summary"} 134 | if "running_summary" in results: 135 | debug_results["summary_length"] = len(results["running_summary"]) 136 | 137 | with open("unified_query_debug.json", "w") as f: 138 | json.dump(debug_results, f, indent=2, default=str) 139 | 140 | print("[TEST] Debug information saved to 'unified_query_debug.json'") 141 | except Exception as e: 142 | print(f"[TEST] Error saving debug info: {str(e)}") -------------------------------------------------------------------------------- /ai-research-assistant/src/components/CodeSnippetViewer.js: -------------------------------------------------------------------------------- 1 | import React, { useState } from 'react'; 2 | import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'; 3 | import { tomorrow, vs } from 'react-syntax-highlighter/dist/esm/styles/prism'; 4 | 5 | function CodeSnippetViewer({ snippet, initialCollapsed = true }) { 6 | const [codeState, setCodeState] = useState({ 7 | collapsed: initialCollapsed, // Use the prop to determine initial state 8 | wrap: false, 9 | darkTheme: true, 10 | copied: false, 11 | }); 12 | 13 | // Validation check 14 | if (!snippet || typeof snippet !== 'object' || !snippet.code) { 15 | console.warn("CodeSnippetViewer received invalid snippet data:", snippet); 16 | return null; 17 | } 18 | 19 | // Extract necessary props 20 | const { language, code } = snippet; 21 | 22 | // Determine what code to display based on collapsed state 23 | const getDisplayCode = () => { 24 | if (!codeState.collapsed) return code; 25 | 26 | const lines = code.split('\n'); 27 | if (lines.length <= 15) return code; 28 | 29 | return lines.slice(0, 15).join('\n') + '\n// ...'; 30 | }; 31 | 32 | // Check if this is a special case of minimal content 33 | const isMinimalContent = code.trim().length < 15 || code.trim().split(/\s+/).length <= 2; 34 | const singleWord = code.trim().split(/\s+/).length === 1; 35 | 36 | // Determine the language to use for syntax highlighting 37 | let syntaxLanguage = language; 38 | if (!language && isMinimalContent) { 39 | syntaxLanguage = 'plaintext'; 40 | } 41 | 42 | 43 | // Code block toolbar toggle handlers 44 | const toggleCollapse = () => { 45 | setCodeState(prev => ({ ...prev, collapsed: !prev.collapsed })); 46 | }; 47 | 48 | const toggleWrap = () => { 49 | setCodeState(prev => ({ ...prev, wrap: !prev.wrap })); 50 | }; 51 | 52 | const toggleTheme = () => { 53 | setCodeState(prev => ({ ...prev, darkTheme: !prev.darkTheme })); 54 | }; 55 | 56 | const handleCodeCopy = () => { 57 | // Always copy the full code regardless of display state 58 | navigator.clipboard.writeText(code) 59 | .then(() => { 60 | setCodeState(prev => ({ ...prev, copied: true })); 61 | setTimeout(() => { 62 | setCodeState(prev => ({ ...prev, copied: false })); 63 | }, 2000); 64 | }) 65 | .catch(err => { 66 | 
console.error("Failed to copy code: ", err); 67 | }); 68 | }; 69 | 70 | // Define CSS styles for the component 71 | const styles = { 72 | container: { 73 | maxWidth: '670px', 74 | width: '100%', 75 | overflow: 'hidden', 76 | borderRadius: '4px', 77 | border: '1px solid #e2e8f0', 78 | backgroundColor: '#f8fafc' 79 | }, 80 | toolbar: { 81 | display: 'flex', 82 | justifyContent: 'space-between', 83 | alignItems: 'center', 84 | padding: '4px 10px', 85 | borderBottom: '1px solid #e2e8f0', 86 | backgroundColor: '#f1f5f9' 87 | }, 88 | toolbarLeft: { 89 | display: 'flex', 90 | alignItems: 'center' 91 | }, 92 | toolbarRight: { 93 | display: 'flex', 94 | gap: '8px' 95 | }, 96 | language: { 97 | fontSize: '12px', 98 | color: '#64748b', 99 | fontFamily: 'monospace' 100 | }, 101 | button: { 102 | padding: '2px 8px', 103 | fontSize: '12px', 104 | border: '1px solid #cbd5e1', 105 | borderRadius: '3px', 106 | backgroundColor: 'white', 107 | color: '#64748b', 108 | cursor: 'pointer' 109 | }, 110 | activeButton: { 111 | backgroundColor: '#e2e8f0' 112 | }, 113 | successButton: { 114 | backgroundColor: '#bbf7d0', 115 | borderColor: '#86efac', 116 | color: '#166534' 117 | } 118 | }; 119 | 120 | return ( 121 |
<div style={styles.container}>
      <div style={styles.toolbar}>
        <div style={styles.toolbarLeft}>
          {language && <span style={styles.language}>{language}</span>}
        </div>
        <div style={styles.toolbarRight}>
          <button
            style={{ ...styles.button, ...(codeState.collapsed ? {} : styles.activeButton) }}
            onClick={toggleCollapse}
            title={codeState.collapsed ? 'Expand code' : 'Collapse code'}
          >
            {codeState.collapsed ? 'Expand' : 'Collapse'}
          </button>
          <button
            style={{ ...styles.button, ...(codeState.wrap ? styles.activeButton : {}) }}
            onClick={toggleWrap}
            title="Toggle line wrapping"
          >
            Wrap
          </button>
          <button
            style={{ ...styles.button, ...(codeState.darkTheme ? styles.activeButton : {}) }}
            onClick={toggleTheme}
            title="Toggle light/dark theme"
          >
            Theme
          </button>
          <button
            style={{ ...styles.button, ...(codeState.copied ? styles.successButton : {}) }}
            onClick={handleCodeCopy}
            title="Copy code to clipboard"
          >
            {codeState.copied ? 'Copied!' : 'Copy'}
          </button>
        </div>
      </div>
      <SyntaxHighlighter
        language={syntaxLanguage || 'plaintext'}
        style={codeState.darkTheme ? tomorrow : vs}
        wrapLongLines={codeState.wrap}
        customStyle={{ margin: 0, fontSize: '0.8rem' }}
      >
        {getDisplayCode()}
      </SyntaxHighlighter>
    </div>
195 | ); 196 | } 197 | 198 | export default CodeSnippetViewer; -------------------------------------------------------------------------------- /src/tools/registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tool registry for managing and accessing search tools. 3 | 4 | This module provides a registry system for search tools used by the research agent. 5 | It allows tools to be registered, retrieved, and managed in a central location. 6 | """ 7 | 8 | import logging 9 | import traceback 10 | from typing import Dict, List, Optional, Any 11 | 12 | from src.tools.search_tools import ( 13 | GeneralSearchTool, 14 | AcademicSearchTool, 15 | GithubSearchTool, 16 | LinkedinSearchTool 17 | ) 18 | from src.tools.text2sql_tool import Text2SQLTool 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | class SearchToolRegistry: 23 | """Registry for search tools that can be used by the research agent.""" 24 | 25 | def __init__(self, config=None): 26 | """ 27 | Initialize the search tool registry. 28 | 29 | Args: 30 | config: Configuration object to pass to tools 31 | """ 32 | logger.info(f"[SearchToolRegistry.__init__] Initializing registry with config type: {type(config).__name__ if config else 'None'}") 33 | self.config = config 34 | self.tools = {} 35 | self._register_default_tools() 36 | 37 | def _register_default_tools(self): 38 | """Register the default set of search tools.""" 39 | logger.info("Registering default search tools") 40 | 41 | try: 42 | # Initialize tools with config 43 | logger.info(f"[SearchToolRegistry._register_default_tools] Creating tool instances with config") 44 | 45 | general_search = GeneralSearchTool() 46 | general_search.config = self.config 47 | logger.info(f"[SearchToolRegistry._register_default_tools] Created GeneralSearchTool instance") 48 | 49 | github_search = GithubSearchTool() 50 | github_search.config = self.config 51 | logger.info(f"[SearchToolRegistry._register_default_tools] Created GithubSearchTool instance") 52 | 53 | academic_search = AcademicSearchTool() 54 | academic_search.config = self.config 55 | logger.info(f"[SearchToolRegistry._register_default_tools] Created AcademicSearchTool instance") 56 | 57 | linkedin_search = LinkedinSearchTool() 58 | linkedin_search.config = self.config 59 | logger.info(f"[SearchToolRegistry._register_default_tools] Created LinkedinSearchTool instance") 60 | 61 | text2sql = Text2SQLTool() 62 | text2sql.config = self.config 63 | logger.info(f"[SearchToolRegistry._register_default_tools] Created Text2SQLTool instance") 64 | 65 | # Add tools to registry 66 | logger.info(f"[SearchToolRegistry._register_default_tools] Registering tools in registry") 67 | self.register_tool(general_search) 68 | self.register_tool(github_search) 69 | self.register_tool(academic_search) 70 | self.register_tool(linkedin_search) 71 | self.register_tool(text2sql) 72 | 73 | logger.info(f"Registered {len(self.tools)} default search tools") 74 | logger.info(f"[SearchToolRegistry._register_default_tools] Registered tools: {list(self.tools.keys())}") 75 | except Exception as e: 76 | logger.error(f"Error registering default tools: {str(e)}") 77 | logger.error(f"[SearchToolRegistry._register_default_tools] Traceback: {traceback.format_exc()}") 78 | # Re-raise the exception to help with debugging 79 | raise 80 | 81 | def register_tool(self, tool): 82 | """ 83 | Register a new tool with the registry. 
84 | 85 | Args: 86 | tool: The tool to register 87 | """ 88 | logger.info(f"[SearchToolRegistry.register_tool] Registering tool: {tool.name}, class: {tool.__class__.__name__}") 89 | # Ensure the tool has the config attribute 90 | if hasattr(tool, 'config') and tool.config is None: 91 | tool.config = self.config 92 | logger.info(f"[SearchToolRegistry.register_tool] Set config on tool: {tool.name}") 93 | self.tools[tool.name] = tool 94 | logger.info(f"[SearchToolRegistry.register_tool] Tool {tool.name} successfully registered") 95 | 96 | def get_tool(self, tool_name): 97 | """ 98 | Get a tool by name. 99 | 100 | Args: 101 | tool_name: The name of the tool to retrieve 102 | 103 | Returns: 104 | The requested tool or None if not found 105 | """ 106 | logger.info(f"[SearchToolRegistry.get_tool] Retrieving tool: {tool_name}") 107 | tool = self.tools.get(tool_name) 108 | if not tool: 109 | logger.warning(f"Tool not found: {tool_name}") 110 | logger.info(f"[SearchToolRegistry.get_tool] Available tools: {list(self.tools.keys())}") 111 | return None 112 | 113 | logger.info(f"[SearchToolRegistry.get_tool] Retrieved tool {tool_name}, class: {tool.__class__.__name__}") 114 | return tool 115 | 116 | def get_all_tools(self): 117 | """ 118 | Get all registered tools. 119 | 120 | Returns: 121 | List of all registered tools 122 | """ 123 | logger.info(f"[SearchToolRegistry.get_all_tools] Returning all {len(self.tools)} registered tools") 124 | return list(self.tools.values()) 125 | 126 | def get_tool_description(self, tool_name): 127 | """ 128 | Get the description of a tool. 129 | 130 | Args: 131 | tool_name: The name of the tool 132 | 133 | Returns: 134 | The description of the tool or None if the tool is not found 135 | """ 136 | logger.info(f"[SearchToolRegistry.get_tool_description] Getting description for tool: {tool_name}") 137 | tool = self.get_tool(tool_name) 138 | if tool and hasattr(tool, 'description'): 139 | logger.info(f"[SearchToolRegistry.get_tool_description] Found description for {tool_name}") 140 | return tool.description 141 | logger.info(f"[SearchToolRegistry.get_tool_description] No description found for {tool_name}") 142 | return None 143 | 144 | def get_all_tool_descriptions(self): 145 | """ 146 | Get descriptions of all registered tools. 
147 | 148 | Returns: 149 | Dict mapping tool names to their descriptions 150 | """ 151 | logger.info(f"[SearchToolRegistry.get_all_tool_descriptions] Getting descriptions for all tools") 152 | return { 153 | name: tool.description if hasattr(tool, 'description') else "No description" 154 | for name, tool in self.tools.items() 155 | } -------------------------------------------------------------------------------- /src/configuration.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, fields 3 | from typing import Any, Optional 4 | 5 | from langchain_core.runnables import RunnableConfig 6 | from dataclasses import dataclass 7 | 8 | from enum import Enum 9 | 10 | 11 | class SearchAPI(Enum): 12 | TAVILY = "tavily" 13 | 14 | 15 | class LLMProvider(Enum): 16 | OPENAI = "openai" 17 | ANTHROPIC = "anthropic" 18 | GROQ = "groq" 19 | GOOGLE = "google" 20 | 21 | 22 | class ActivityVerbosity(Enum): 23 | NONE = "none" # No activity generation 24 | LOW = "low" # Minimal activities, only for major steps 25 | MEDIUM = "medium" # Moderate detail, for important transitions 26 | HIGH = "high" # Detailed activities for most state changes 27 | 28 | 29 | class Configuration: 30 | """The configurable fields for the research assistant.""" 31 | 32 | def __init__(self, **kwargs): 33 | # Initialize from kwargs or environment variables 34 | self._max_web_research_loops = kwargs.get("max_web_research_loops") 35 | self._search_api = kwargs.get("search_api") 36 | self._fetch_full_page = kwargs.get("fetch_full_page") 37 | self._include_raw_content = kwargs.get("include_raw_content") 38 | self._llm_provider = kwargs.get("llm_provider") 39 | self._llm_model = kwargs.get("llm_model") 40 | 41 | # Set activity generation defaults to avoid needing .env settings 42 | self._enable_activity_generation = kwargs.get( 43 | "enable_activity_generation", True 44 | ) 45 | self._activity_verbosity = kwargs.get("activity_verbosity", "medium") 46 | self._activity_llm_provider = kwargs.get("activity_llm_provider", "openai") 47 | self._activity_llm_model = kwargs.get("activity_llm_model", "o3-mini") 48 | 49 | @property 50 | def max_web_research_loops(self) -> int: 51 | # Get the value from environment variables at runtime, not during class definition 52 | if self._max_web_research_loops is not None: 53 | return self._max_web_research_loops 54 | 55 | env_value = os.environ.get("MAX_WEB_RESEARCH_LOOPS") 56 | print(f"Reading MAX_WEB_RESEARCH_LOOPS from environment: {env_value}") 57 | return int(env_value or "10") 58 | 59 | """ 60 | Maximum number of web research loops to perform before finalizing. 61 | This helps prevent hitting the graph recursion limit (default 25). 
62 | 63 | Recommended values: 64 | - For simple research topics: 5-8 65 | - For complex research topics: 8-15 66 | - Use values >15 with caution as you may hit recursion limits 67 | """ 68 | 69 | @property 70 | def search_api(self): 71 | if self._search_api is not None: 72 | return self._search_api 73 | return SearchAPI(os.environ.get("SEARCH_API") or "tavily") 74 | 75 | @property 76 | def fetch_full_page(self) -> bool: 77 | if self._fetch_full_page is not None: 78 | return self._fetch_full_page 79 | return (os.environ.get("FETCH_FULL_PAGE") or "False").lower() in ( 80 | "true", 81 | "1", 82 | "t", 83 | ) 84 | 85 | @property 86 | def include_raw_content(self) -> bool: 87 | if self._include_raw_content is not None: 88 | return self._include_raw_content 89 | return (os.environ.get("INCLUDE_RAW_CONTENT") or "True").lower() in ( 90 | "true", 91 | "1", 92 | "t", 93 | ) 94 | 95 | # LLM configuration 96 | @property 97 | def llm_provider(self): 98 | if self._llm_provider is not None: 99 | return self._llm_provider 100 | return LLMProvider(os.environ.get("LLM_PROVIDER") or "google") 101 | 102 | @property 103 | def llm_model(self) -> str: 104 | if self._llm_model is not None: 105 | return self._llm_model 106 | 107 | provider_str = os.environ.get("LLM_PROVIDER") 108 | return os.environ.get("LLM_MODEL") or ( 109 | "o3-mini" 110 | if provider_str == "openai" 111 | else ( 112 | "claude-3-7-sonnet" 113 | if provider_str == "anthropic" 114 | else ( 115 | "llama-3.3-70b-versatile" 116 | if provider_str == "groq" 117 | else ( 118 | "gemini-2.5-pro" 119 | if provider_str == "google" 120 | else "gemini-2.5-pro" 121 | ) 122 | ) 123 | ) 124 | ) # Default to Gemini 2.5 Pro 125 | 126 | # Activity generation configuration 127 | @property 128 | def enable_activity_generation(self) -> bool: 129 | """Whether to enable the generation of detailed activity descriptions.""" 130 | if self._enable_activity_generation is not None: 131 | return self._enable_activity_generation 132 | return (os.environ.get("ENABLE_ACTIVITY_GENERATION") or "True").lower() in ( 133 | "true", 134 | "1", 135 | "t", 136 | ) 137 | 138 | @property 139 | def activity_verbosity(self) -> ActivityVerbosity: 140 | """The level of detail for generated activities.""" 141 | if self._activity_verbosity is not None: 142 | return self._activity_verbosity 143 | verbosity_str = os.environ.get("ACTIVITY_VERBOSITY") or "medium" 144 | return ActivityVerbosity(verbosity_str.lower()) 145 | 146 | @property 147 | def activity_llm_provider(self) -> LLMProvider: 148 | """The LLM provider to use for activity generation.""" 149 | if self._activity_llm_provider is not None: 150 | return self._activity_llm_provider 151 | provider_str = os.environ.get("ACTIVITY_LLM_PROVIDER") or "openai" 152 | return LLMProvider(provider_str.lower()) 153 | 154 | @property 155 | def activity_llm_model(self) -> str: 156 | """The LLM model to use for activity generation.""" 157 | if self._activity_llm_model is not None: 158 | return self._activity_llm_model 159 | return os.environ.get("ACTIVITY_LLM_MODEL") or "o3-mini" 160 | 161 | @classmethod 162 | def from_runnable_config( 163 | cls, config: Optional[RunnableConfig] = None 164 | ) -> "Configuration": 165 | """Create a Configuration instance from a RunnableConfig.""" 166 | configurable = ( 167 | config["configurable"] if config and "configurable" in config else {} 168 | ) 169 | 170 | # Config properties to check 171 | properties = [ 172 | "max_web_research_loops", 173 | "search_api", 174 | "fetch_full_page", 175 | "include_raw_content", 176 | 
"llm_provider", 177 | "llm_model", 178 | "enable_activity_generation", 179 | "activity_verbosity", 180 | "activity_llm_provider", 181 | "activity_llm_model", 182 | ] 183 | 184 | values = {} 185 | for prop in properties: 186 | # Get from configurable or environment 187 | env_value = os.environ.get(prop.upper()) 188 | config_value = configurable.get(prop) 189 | 190 | # Use configurable value first, then environment value 191 | if config_value is not None: 192 | values[prop] = config_value 193 | elif env_value is not None: 194 | values[prop] = env_value 195 | 196 | # Create new Configuration instance with values 197 | return cls(**values) 198 | -------------------------------------------------------------------------------- /routers/database.py: -------------------------------------------------------------------------------- 1 | """ 2 | Database API Router 3 | 4 | This module provides API endpoints for database upload, management, and text2sql functionality. 5 | """ 6 | 7 | from fastapi import APIRouter, HTTPException, UploadFile, File, Form 8 | from typing import List, Dict, Any, Optional 9 | import logging 10 | import json 11 | from pydantic import BaseModel, Field 12 | 13 | from src.tools.text2sql_tool import Text2SQLTool 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | # Create router 18 | router = APIRouter(tags=["Database"]) 19 | 20 | # Global Text2SQL tool instance 21 | text2sql_tool = Text2SQLTool() 22 | 23 | # Pydantic models for request/response 24 | class DatabaseUploadResponse(BaseModel): 25 | database_id: str 26 | filename: str 27 | file_type: str 28 | tables: List[str] 29 | message: str 30 | 31 | class DatabaseListResponse(BaseModel): 32 | databases: List[Dict[str, Any]] 33 | 34 | class DatabaseSchemaResponse(BaseModel): 35 | database_id: str 36 | filename: str 37 | database_schema: Dict[str, Any] = Field(alias="schema") 38 | 39 | class Text2SQLRequest(BaseModel): 40 | query: str 41 | database_id: Optional[str] = None 42 | 43 | class Text2SQLResponse(BaseModel): 44 | query: str 45 | sql: Optional[str] = None 46 | results: Optional[Dict[str, Any]] = None 47 | error: Optional[str] = None 48 | database: str 49 | executed_at: str 50 | 51 | @router.post( 52 | "/upload", 53 | response_model=DatabaseUploadResponse, 54 | summary="Upload a database file", 55 | description="Upload SQLite or CSV files for text2sql querying" 56 | ) 57 | async def upload_database( 58 | file: UploadFile = File(...), 59 | file_type: Optional[str] = Form(None) 60 | ): 61 | """ 62 | Upload a database file (SQLite or CSV) for text2sql functionality. 63 | 64 | Args: 65 | file: The database file to upload 66 | file_type: Optional file type override (sqlite, csv) 67 | 68 | Returns: 69 | Database upload response with ID and metadata 70 | """ 71 | try: 72 | # Validate file type 73 | allowed_types = ['.db', '.sqlite', '.sqlite3', '.csv', '.json'] 74 | file_extension = '.' + file.filename.split('.')[-1].lower() 75 | 76 | if file_extension not in allowed_types: 77 | raise HTTPException( 78 | status_code=400, 79 | detail=f"Unsupported file type. 
Allowed types: {', '.join(allowed_types)}" 80 | ) 81 | 82 | # Read file content 83 | file_content = await file.read() 84 | 85 | # Upload to Text2SQL tool 86 | database_id = text2sql_tool.upload_database( 87 | file_content=file_content, 88 | filename=file.filename, 89 | file_type=file_type 90 | ) 91 | 92 | # Get database info 93 | db_info = text2sql_tool.databases[database_id] 94 | 95 | logger.info(f"Successfully uploaded database: {file.filename} (ID: {database_id})") 96 | 97 | return DatabaseUploadResponse( 98 | database_id=database_id, 99 | filename=file.filename, 100 | file_type=db_info['file_type'], 101 | tables=db_info['metadata']['tables'], 102 | message=f"Database {file.filename} uploaded successfully" 103 | ) 104 | 105 | except Exception as e: 106 | logger.error(f"Error uploading database: {e}") 107 | raise HTTPException(status_code=500, detail=str(e)) 108 | 109 | @router.get( 110 | "/list", 111 | response_model=DatabaseListResponse, 112 | summary="List uploaded databases", 113 | description="Get a list of all uploaded databases" 114 | ) 115 | async def list_databases(): 116 | """ 117 | Get a list of all uploaded databases. 118 | 119 | Returns: 120 | List of database information 121 | """ 122 | try: 123 | databases = text2sql_tool.list_databases() 124 | 125 | return DatabaseListResponse(databases=databases) 126 | 127 | except Exception as e: 128 | logger.error(f"Error listing databases: {e}") 129 | raise HTTPException(status_code=500, detail=str(e)) 130 | 131 | @router.get( 132 | "/{database_id}/schema", 133 | response_model=DatabaseSchemaResponse, 134 | summary="Get database schema", 135 | description="Get detailed schema information for a specific database" 136 | ) 137 | async def get_database_schema(database_id: str): 138 | """ 139 | Get schema information for a specific database. 140 | 141 | Args: 142 | database_id: The ID of the database 143 | 144 | Returns: 145 | Database schema information 146 | """ 147 | try: 148 | if database_id not in text2sql_tool.databases: 149 | raise HTTPException(status_code=404, detail="Database not found") 150 | 151 | db_info = text2sql_tool.databases[database_id] 152 | schema = text2sql_tool.get_database_schema(database_id) 153 | 154 | return DatabaseSchemaResponse( 155 | database_id=database_id, 156 | filename=db_info['filename'], 157 | database_schema=schema 158 | ) 159 | 160 | except HTTPException: 161 | raise 162 | except Exception as e: 163 | logger.error(f"Error getting database schema: {e}") 164 | raise HTTPException(status_code=500, detail=str(e)) 165 | 166 | @router.post( 167 | "/query", 168 | response_model=Text2SQLResponse, 169 | summary="Execute text2sql query", 170 | description="Convert natural language to SQL and execute against uploaded databases" 171 | ) 172 | async def execute_text2sql(request: Text2SQLRequest): 173 | """ 174 | Execute a text2sql query against uploaded databases. 
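
    The request body is JSON matching ``Text2SQLRequest``; a sketch of a call,
    where the database_id is a placeholder for an ID returned by
    ``POST /api/database/upload``:

        {"query": "How many rows does each table contain?", "database_id": "db_123"}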
175 | 176 | Args: 177 | request: Text2SQL request with query and optional database_id 178 | 179 | Returns: 180 | Query results with SQL and data 181 | """ 182 | try: 183 | # Execute the query 184 | result = text2sql_tool.query_database( 185 | db_id=request.database_id, 186 | natural_language_query=request.query 187 | ) 188 | 189 | return Text2SQLResponse(**result) 190 | 191 | except Exception as e: 192 | logger.error(f"Error executing text2sql query: {e}") 193 | return Text2SQLResponse( 194 | query=request.query, 195 | error=str(e), 196 | database=request.database_id or "Unknown", 197 | executed_at=text2sql_tool._get_current_time() 198 | ) 199 | 200 | @router.delete( 201 | "/{database_id}", 202 | summary="Delete database", 203 | description="Delete an uploaded database and its files" 204 | ) 205 | async def delete_database(database_id: str): 206 | """ 207 | Delete an uploaded database. 208 | 209 | Args: 210 | database_id: The ID of the database to delete 211 | 212 | Returns: 213 | Success message 214 | """ 215 | try: 216 | if database_id not in text2sql_tool.databases: 217 | raise HTTPException(status_code=404, detail="Database not found") 218 | 219 | success = text2sql_tool.delete_database(database_id) 220 | 221 | if success: 222 | return {"message": f"Database {database_id} deleted successfully"} 223 | else: 224 | raise HTTPException(status_code=500, detail="Failed to delete database") 225 | 226 | except HTTPException: 227 | raise 228 | except Exception as e: 229 | logger.error(f"Error deleting database: {e}") 230 | raise HTTPException(status_code=500, detail=str(e)) 231 | 232 | # Add a method to get current time for the tool 233 | def _get_current_time(): 234 | from datetime import datetime 235 | return datetime.now().isoformat() 236 | 237 | # Monkey patch the method into the tool 238 | Text2SQLTool._get_current_time = staticmethod(_get_current_time) 239 | -------------------------------------------------------------------------------- /ai-research-assistant/src/index.css: -------------------------------------------------------------------------------- 1 | @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap'); 2 | 3 | @tailwind base; 4 | @tailwind components; 5 | @tailwind utilities; 6 | 7 | @layer base { 8 | html { 9 | font-family: 'Inter', ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif; 10 | } 11 | } 12 | 13 | body { 14 | font-family: 'Inter', ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif; 15 | font-weight: 400; 16 | letter-spacing: normal; 17 | color: theme('colors.text'); 18 | background-color: theme('colors.bg'); 19 | } 20 | 21 | .mono { 22 | font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; 23 | } 24 | 25 | .subtle-shadow { 26 | box-shadow: 0 1px 3px rgba(0, 0, 0, 0.05); 27 | } 28 | 29 | .checklist-item::before { 30 | content: ''; 31 | display: inline-block; 32 | width: 16px; 33 | height: 16px; 34 | margin-right: 8px; 35 | border: 1px solid #000; 36 | border-radius: 2px; 37 | vertical-align: middle; 38 | } 39 | 40 | .checklist-item.completed::before { 41 | background-color: #000; 42 | background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 24 24' stroke='white'%3E%3Cpath stroke-linecap='round' stroke-linejoin='round' stroke-width='2' d='M5 13l4 4L19 7' /%3E%3C/svg%3E"); 43 | background-size: 


--------------------------------------------------------------------------------
/ai-research-assistant/src/index.css:
--------------------------------------------------------------------------------
  1 | @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
  2 | 
  3 | @tailwind base;
  4 | @tailwind components;
  5 | @tailwind utilities;
  6 | 
  7 | @layer base {
  8 |   html {
  9 |     font-family: 'Inter', ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
 10 |   }
 11 | }
 12 | 
 13 | body {
 14 |   font-family: 'Inter', ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
 15 |   font-weight: 400;
 16 |   letter-spacing: normal;
 17 |   color: theme('colors.text');
 18 |   background-color: theme('colors.bg');
 19 | }
 20 | 
 21 | .mono {
 22 |   font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
 23 | }
 24 | 
 25 | .subtle-shadow {
 26 |   box-shadow: 0 1px 3px rgba(0, 0, 0, 0.05);
 27 | }
 28 | 
 29 | .checklist-item::before {
 30 |   content: '';
 31 |   display: inline-block;
 32 |   width: 16px;
 33 |   height: 16px;
 34 |   margin-right: 8px;
 35 |   border: 1px solid #000;
 36 |   border-radius: 2px;
 37 |   vertical-align: middle;
 38 | }
 39 | 
 40 | .checklist-item.completed::before {
 41 |   background-color: #000;
 42 |   background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' fill='none' viewBox='0 0 24 24' stroke='white'%3E%3Cpath stroke-linecap='round' stroke-linejoin='round' stroke-width='2' d='M5 13l4 4L19 7' /%3E%3C/svg%3E");
 43 |   background-size: 12px;
 44 |   background-position: center;
 45 |   background-repeat: no-repeat;
 46 | }
 47 | 
 48 | .animate-pulse {
 49 |   animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
 50 | }
 51 | 
 52 | @keyframes pulse {
 53 |   0%, 100% {
 54 |     opacity: 1;
 55 |   }
 56 |   50% {
 57 |     opacity: 0.7;
 58 |   }
 59 | }
 60 | 
 61 | .typing-indicator::after {
 62 |   content: '...';
 63 |   animation: typing 1.5s infinite;
 64 | }
 65 | 
 66 | @keyframes typing {
 67 |   0%, 100% { content: ''; }
 68 |   25% { content: '.'; }
 69 |   50% { content: '..'; }
 70 |   75% { content: '...'; }
 71 | }
 72 | 
 73 | .research-item {
 74 |   cursor: pointer;
 75 |   transition: background-color 0.2s;
 76 | }
 77 | 
 78 | .research-item:hover {
 79 |   background-color: rgba(0, 0, 0, 0.03);
 80 | }
 81 | 
 82 | .research-item.active {
 83 |   background-color: rgba(0, 0, 0, 0.05);
 84 |   border-left-color: #000;
 85 |   border-left-width: 3px;
 86 | }
 87 | 
 88 | /* Table of contents styling */
 89 | .prose ul li ul {
 90 |   margin-top: 0.25rem;
 91 |   margin-bottom: 0.25rem;
 92 | }
 93 | 
 94 | .prose ul li {
 95 |   position: relative;
 96 | }
 97 | 
 98 | /* For nested list indentation */
 99 | .prose ul li ul {
100 |   margin-left: 1.5rem;
101 | }
102 | 
103 | /* Special styling for research report */
104 | h1 {
105 |   font-size: 1.75rem;
106 |   font-weight: 700;
107 |   margin-top: 2rem;
108 |   margin-bottom: 1rem;
109 |   border-bottom: 1px solid #e5e7eb;
110 |   padding-bottom: 0.5rem;
111 | }
112 | 
113 | h2 {
114 |   font-size: 1.5rem;
115 |   font-weight: 600;
116 |   margin-top: 1.5rem;
117 |   margin-bottom: 0.75rem;
118 | }
119 | 
120 | /* Special styling for cover page */
121 | h1 + p {
122 |   font-size: 1.1rem;
123 |   line-height: 1.7;
124 | }
125 | 
126 | /* Table of contents specific styling */
127 | h2:has(+ ul) + ul {
128 |   border-left: 2px solid #e5e7eb;
129 |   padding-left: 1rem;
130 | }
131 | 
132 | /* Adding page numbers alignment for TOC items */
133 | .prose ul li p:last-child {
134 |   margin-left: auto;
135 |   font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
136 | }
137 | 
138 | /* Code blocks styling */
139 | pre {
140 |   margin: 1rem 0;
141 |   overflow-x: auto;
142 |   border-radius: 0.375rem;
143 |   font-size: 0.8rem !important; /* Applied to <pre> directly */
144 | }
145 | 
146 | /* General code tag styling (applies to inline and block if not overridden) */
147 | code {
148 |   font-family: 'Menlo', 'Monaco', 'Courier New', monospace !important;
149 |   font-size: 0.85em !important; /* Relative to parent, usually <p> or <pre> */
150 | }
151 | 
152 | /* Inline code specific styling (rendered by FinalReport.js with .final-report-inline-code) */
153 | .final-report-inline-code {
154 |   background-color: rgba(27,31,35,.07) !important; /* GitHub-like subtle gray */
155 |   padding: .2em .4em !important;
156 |   margin: 0 !important;
157 |   font-size: 85% !important; /* Relative to its parent (e.g., <p>) */
158 |   border-radius: 3px !important; /* Softer radius */
159 |   font-family: SFMono-Regular, Consolas, 'Liberation Mono', Menlo, Courier, monospace !important;
160 |   color: #1f2328 !important; /* Darker text for better readability */
161 |   white-space: nowrap !important; /* Keep inline code on one line */
162 |   display: inline !important; /* Ensure it behaves as inline */
163 |   vertical-align: baseline !important;
164 | }
165 | 
166 | .dark .final-report-inline-code {
167 |   background-color: rgba(110,118,129,0.4) !important;
168 |   color: #c9d1d9 !important;
169 | }
170 | 
171 | /* Minimal code block styling (rendered by FinalReport.js with .final-report-minimal-block) */
172 | .final-report-minimal-block {
173 |   display: block !important;
174 |   padding: 0.2em 0.5em !important;
175 |   margin: 0.5em 0 !important;
176 |   background-color: #f0f2f5 !important;
177 |   border: 1px solid #d9dcdf !important;
178 |   border-radius: 4px !important;
179 |   font-family: SFMono-Regular, Consolas, 'Liberation Mono', Menlo, Courier, monospace !important;
180 |   font-size: 0.8em !important; /* Explicit size for the block */
181 |   white-space: pre !important;
182 |   overflow-x: auto !important;
183 |   color: #303133 !important;
184 |   line-height: 1.4 !important;
185 |   min-height: auto !important;
186 | }
187 | .final-report-minimal-block code { /* Styling for inside minimal block */
188 |   font-family: inherit !important; /* Inherit from .final-report-minimal-block */
189 |   background-color: transparent !important;
190 |   padding: 0 !important;
191 |   font-size: inherit !important; /* Inherit size from pre */
192 |   color: inherit !important;
193 | }
194 | .dark .final-report-minimal-block {
195 |   background-color: #2c2c2c !important;
196 |   border-color: #444 !important;
197 |   color: #c9d1d9 !important;
198 | }
199 | 
200 | /* Empty code block styling (rendered by FinalReport.js with .final-report-empty-block) */
201 | .final-report-empty-block {
202 |   display: block !important;
203 |   padding: 0.5em 0.8em !important;
204 |   margin: 0.5em 0 !important;
205 |   background-color: #f9f9f9 !important;
206 |   border: 1px dashed #ccc !important;
207 |   border-radius: 4px !important;
208 |   font-family: SFMono-Regular, Consolas, 'Liberation Mono', Menlo, Courier, monospace !important;
209 |   font-size: 0.75em !important;
210 |   color: #777 !important;
211 |   min-height: 2em !important;
212 |   line-height: 1 !important;
213 |   text-align: center !important;
214 | }
215 | .final-report-empty-block code { /* Styling for inside empty block */
216 |   font-family: inherit !important;
217 |   background-color: transparent !important;
218 |   padding: 0 !important;
219 |   font-size: inherit !important;
220 |   color: inherit !important;
221 | }
222 | .dark .final-report-empty-block {
223 |   background-color: #252525 !important;
224 |   border-color: #555 !important;
225 |   color: #888 !important;
226 | }
227 | 
228 | /* Ensure SyntaxHighlighter's <pre> doesn't add extra margins or borders */
229 | .code-block-content pre {
230 |   margin: 0 !important;
231 |   border-radius: 0 !important;
232 |   border: none !important;
233 | }
234 | 
235 | /* Ensure keywords in code blocks are properly styled by SyntaxHighlighter */
236 | /* These are examples, actual classes depend on SyntaxHighlighter's theme */
237 | .token.keyword,
238 | .token.function,
239 | .token.builtin,
240 | .token.class-name {
241 |   /* color: #d73a49; Example color, theme will override */
242 | }
243 | 
244 | /* Ensure all text within SyntaxHighlighter uses monospace */
245 | .prism-code .token { /* General token class for Prism */
246 |   font-family: 'Menlo', 'Monaco', 'Courier New', monospace !important;
247 | }


--------------------------------------------------------------------------------
/test_benchmark.py:
--------------------------------------------------------------------------------
  1 | import asyncio
  2 | import logging
  3 | from src.graph import create_fresh_graph
  4 | from src.state import SummaryState
  5 | from models.research import ResearchRequest
  6 | from src.configuration import Configuration
  7 | 
  8 | # Set up logging
  9 | logging.basicConfig(level=logging.INFO)
 10 | logger = logging.getLogger(__name__)
 11 | 
 12 | async def test_benchmark_question():
 13 |     print("\n=== Starting Benchmark Test ===")
 14 |     
 15 |     # Define the benchmark question
 16 |     question = """An African author tragically passed away in a tragic road accident. As a child, he'd wanted to be a police officer. He lectured at a private university from 2018 until his death. In 2018, this author spoke about writing stories that have no sell by date in an interview. One of his books was selected to be a compulsory school reading in an African country in 2017. Which years did this author work as a probation officer?"""
 17 |     expected_answer = "1988-96"
 18 |     
 19 |     print(f"Question: {question}")
 20 |     print(f"Expected Answer: {expected_answer}")
 21 |     print("=== Running Research ===\n")
 22 | 
 23 |     # Create configuration
 24 |     config = {
 25 |         "configurable": {
 26 |             "thread_id": "benchmark_test",
 27 |             "llm_provider": "google",
 28 |             "llm_model": "gemini-2.5-pro",
 29 |             "max_web_research_loops": 10
 30 |         },
 31 |         "recursion_limit": 100, # Set recursion limit to 100
 32 |     }    
 33 | 
 34 |     # config = {
 35 |     #     "configurable": {
 36 |     #         "thread_id": "benchmark_test",
 37 |     #         "llm_provider": "anthropic",
 38 |     #         "llm_model": "claude-3-5-sonnet",
 39 |     #         "max_web_research_loops": 10
 40 |     #     },
 41 |     #     "recursion_limit": 100, # Set recursion limit to 100
 42 |     # }  
 43 |     
 44 |     # Create initial state with benchmark mode explicitly enabled
 45 |     initial_state = SummaryState(
 46 |         research_topic=question,
 47 |         benchmark_mode=True,  # Explicitly enable benchmark mode
 48 |         extra_effort=True,    # Enable thorough research
 49 |         research_loop_count=0,
 50 |         running_summary="",
 51 |         search_query="",
 52 |         research_complete=False,
 53 |         knowledge_gap="",
 54 |         search_results_empty=False,
 55 |         selected_search_tool="general_search",
 56 |         sources_gathered=[],
 57 |         web_research_results=[],
 58 |         source_citations={},
 59 |         research_plan={},     # Initialize with empty research_plan to avoid AttributeError
 60 |         previous_answers=[],  # Initialize empty previous_answers list
 61 |         reflection_history=[], # Initialize empty reflection_history list
 62 |         llm_provider=config["configurable"]["llm_provider"],
 63 |         llm_model=config["configurable"]["llm_model"],
 64 |         config={
 65 |             "benchmark": {
 66 |                 "expected_answer": expected_answer,
 67 |                 "confidence_threshold": 0.8
 68 |             },
 69 |             "configurable": config["configurable"]
 70 |         }
 71 |     )
 72 | 
 73 |     # Log the initial state
 74 |     logger.info(f"Initial state benchmark_mode: {initial_state.benchmark_mode}")
 75 |     logger.info(f"Initial state config: {initial_state.config}")
 76 | 
 77 |     # Create graph and run research
 78 |     graph = create_fresh_graph()
 79 |     
 80 |     # Debug log before invoking graph
 81 |     print("\nState before graph invocation:")
 82 |     print(f"benchmark_mode: {initial_state.benchmark_mode}")
 83 |     print(f"research_topic: {initial_state.research_topic}")
 84 |     
 85 |     try:
 86 |         final_state = await graph.ainvoke(
 87 |             initial_state,
 88 |             config=config
 89 |         )
 90 |         
 91 |         # Debug log after graph execution
 92 |         print("\nState after graph execution:")
 93 |         print(f"benchmark_mode: {getattr(final_state, 'benchmark_mode', False)}")
 94 |         print(f"research_complete: {getattr(final_state, 'research_complete', False)}")
 95 | 
 96 |         # Debugging - print all available keys in the state
 97 |         print("\nDetailed state inspection:")
 98 |         if isinstance(final_state, dict):
 99 |             print("State is a dictionary with keys:", list(final_state.keys()))
100 |         elif hasattr(final_state, '__dict__'):
101 |             print("State attributes:", list(final_state.__dict__.keys()))
102 |         else:
103 |             print(f"State type: {type(final_state)}")
104 |             
105 |         # Try to find benchmark_result
106 |         benchmark_result = None
107 |         
108 |         # Check if state is a dictionary (new way state is returned)
109 |         if isinstance(final_state, dict):
110 |             if 'benchmark_result' in final_state:
111 |                 benchmark_result = final_state['benchmark_result']
112 |                 print("Found benchmark_result in dictionary state")
113 |             else:
114 |                 print("benchmark_result not found in dictionary state")
115 |                 
116 |         # Check if state is an object (old way state is returned)
117 |         elif hasattr(final_state, 'benchmark_result') and final_state.benchmark_result is not None:
118 |             benchmark_result = final_state.benchmark_result
119 |             print("Found benchmark_result in object state")
120 |         elif hasattr(final_state, '__dict__') and 'benchmark_result' in final_state.__dict__:
121 |             benchmark_result = final_state.__dict__['benchmark_result']
122 |             print("Found benchmark_result in object __dict__")
123 |         else:
124 |             print("benchmark_result not found in any state form")
125 |             
126 |         # Additional fallback check - look in the final answer
127 |         if not benchmark_result and hasattr(final_state, 'previous_answers') and final_state.previous_answers:
128 |             print("Checking previous_answers for final result")
129 |             # The last answer in previous_answers might contain our answer
130 |             last_answer = final_state.previous_answers[-1]
131 |             print(f"Last answer: {last_answer}")
132 | 
133 |         print("\n=== Results ===")
134 |         
135 |         if benchmark_result:
136 |             answer = benchmark_result.get('answer', 'No answer generated')
137 |             confidence = benchmark_result.get('confidence', 0.0)
138 |             sources = benchmark_result.get('sources', [])
139 |             
140 |             # Verify the answer
141 |             is_correct = expected_answer.lower().strip() in answer.lower().strip()
142 |             
143 |             print(f"Generated Answer: {answer}")
144 |             print(f"Expected Answer: {expected_answer}")
145 |             print(f"Correct: {is_correct}")
146 |             print(f"Confidence: {confidence}")
147 |             print(f"Sources: {sources}")
148 |             
149 |             # Print detailed analysis
150 |             if not is_correct:
151 |                 print("\nAnalysis:")
152 |                 print("- Generated answer differs from expected answer")
153 |                 print(f"- Confidence level: {'High' if confidence > 0.8 else 'Medium' if confidence > 0.5 else 'Low'}")
154 |                 if not sources:
155 |                     print("- No sources were cited to support the answer")
156 |         else:
157 |             print("No benchmark result found in final state")
158 |             print("Final Summary:", getattr(final_state, 'running_summary', 'No summary available'))
159 |             print("Research Complete:", getattr(final_state, 'research_complete', False))
160 |             
161 |             # Check if previous_answers exists and has entries
162 |             previous_answers = getattr(final_state, 'previous_answers', [])
163 |             if previous_answers:
164 |                 print("\nFound answers in previous_answers:")
165 |                 for i, answer in enumerate(previous_answers):
166 |                     print(f"\nAnswer {i+1}:")
167 |                     print(f"- Answer: {answer.get('answer', 'No answer')}")
168 |                     print(f"- Confidence: {answer.get('confidence', 0.0)}")
169 |                     print(f"- Sources: {answer.get('sources', [])}")
170 |     
171 |     except Exception as e:
172 |         print(f"\nERROR: Exception encountered during execution: {e}")
173 |         import traceback
174 |         print(traceback.format_exc())
175 | 
176 | if __name__ == "__main__":
177 |     asyncio.run(test_benchmark_question()) 
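
For reference, the exact-substring check above (`expected_answer.lower().strip() in answer.lower().strip()`) is brittle for answers like "1988-96": an en dash or extra spacing in the generated answer would fail it. A small normalization sketch, not part of the repo, that would make the comparison more forgiving:

    import re

    def normalize_answer(text: str) -> str:
        """Lowercase, unify dash variants, and collapse whitespace."""
        text = text.lower().strip()
        text = re.sub(r"[\u2012\u2013\u2014\u2212]", "-", text)  # dash variants -> "-"
        text = re.sub(r"\s*-\s*", "-", text)  # "1988 - 96" -> "1988-96"
        return re.sub(r"\s+", " ", text)

    # is_correct = normalize_answer(expected_answer) in normalize_answer(answer)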


--------------------------------------------------------------------------------
/src/tools/examples/playwright_example.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Example script demonstrating how to use the Playwright MCP server with our tool registry.
  3 | 
  4 | This example shows how to:
  5 | 1. Start the Playwright MCP server (via stdio)
  6 | 2. Register it with our tool registry
  7 | 3. Execute some web browsing commands using the Playwright tools
  8 | 
  9 | Prerequisites:
 10 | - Requires Node.js and npx to be installed.
 11 | - The Playwright MCP server will be downloaded via npx if not already present.
 12 | """
 13 | import asyncio
 14 | import sys
 15 | import os
 16 | from pathlib import Path
 17 | 
 18 | # Add the project root to the Python path
 19 | project_root = Path(__file__).parent.parent.parent.parent
 20 | sys.path.append(str(project_root))
 21 | 
 22 | from src.tools.registry import SearchToolRegistry
 23 | from src.tools.executor import ToolExecutor
 24 | from src.tools.mcp_tools import MCPToolManager
 25 | 
 26 | async def main():
 27 |     # Create a tool registry
 28 |     registry = SearchToolRegistry()
 29 |     executor = ToolExecutor(registry)
 30 | 
 31 |     # Create an MCP tool manager
 32 |     mcp_manager = MCPToolManager(registry)
 33 | 
 34 |     try:
 35 |         # Start the Playwright MCP server as a subprocess using stdio
 36 |         # Using headless mode by default
 37 |         print("Registering Playwright MCP server via stdio (headless)...")
 38 |         tools = await mcp_manager.register_stdio_server(
 39 |             name="playwright",
 40 |             command="npx",
 41 |             args=["-y", "@playwright/mcp@latest", "--headless"]
 42 |             # args=["-y", "@playwright/mcp@latest"]
 43 |         )
 44 | 
 45 |         print(f"Registered {len(tools)} tools from Playwright MCP server:")
 46 |         for tool in tools:
 47 |             print(f"  - {tool.name}: {tool.description}")
 48 | 
 49 |         # Example: Execute some web browsing commands
 50 | 
 51 |         # Navigate to a website
 52 |         print("\nNavigating to https://xplorestaging.ieee.org/author/37086453282...")
 53 |         nav_result = await executor.execute_tool(
 54 |             "mcp.playwright.browser_navigate",
 55 |             {"url": "https://xplorestaging.ieee.org/author/37086453282"},
 56 |             config={}
 57 |         )
 58 |         print(f"Navigation result: {nav_result}")
 59 | 
 60 |         # Take a screenshot after navigating (using PDF as workaround)
 61 |         print("\nTaking screenshot (using PDF as workaround)...")
 62 |         screenshots_dir = os.path.join(project_root, "screenshots")
 63 |         output_file_path = os.path.join(screenshots_dir, "example_com.pdf")
 64 |         
 65 |         try:
 66 |             pdf_result = await executor.execute_tool(
 67 |                 "mcp.playwright.browser_pdf_save",
 68 |                 {"path": "example_com.pdf"},  # This will be ignored, but we include it anyway
 69 |                 config={}
 70 |             )
 71 |             print(f"PDF save command executed. Result: {pdf_result}")
 72 |             
 73 |             # Extract the actual path from the result
 74 |             if isinstance(pdf_result, str) and pdf_result.startswith("Saved as "):
 75 |                 actual_path = pdf_result.replace("Saved as ", "").strip()
 76 |                 print(f"Actual PDF file location: {actual_path}")
 77 |                 
 78 |                 # Copy the file to our screenshots directory
 79 |                 import shutil
 80 |                 os.makedirs(screenshots_dir, exist_ok=True)  # Ensure screenshots directory exists
 81 |                 shutil.copy2(actual_path, output_file_path)
 82 |                 
 83 |                 if os.path.exists(output_file_path):
 84 |                     print(f"Successfully saved visual capture to: {output_file_path}")
 85 |                 else:
 86 |                     print(f"Failed to save visual capture to: {output_file_path}")
 87 |             else:
 88 |                 print(f"Unexpected PDF result format: {pdf_result}")
 89 |         except Exception as e:
 90 |             print(f"Error capturing visual: {e}")
 91 |             import traceback
 92 |             traceback.print_exc()
 93 |             
 94 |         # ALSO attempt to take an actual screenshot (as a second approach)
 95 |         print("\nAlso attempting to take a PNG screenshot...")
 96 |         output_screenshot_path = os.path.join(screenshots_dir, "linkedin_profile.png")
 97 |         
 98 |         try:
 99 |             screenshot_result = await executor.execute_tool(
100 |                 "mcp.playwright.browser_take_screenshot",
101 |                 {"path": "linkedin_profile.png"},
102 |                 config={}
103 |             )
104 |             print(f"Screenshot command executed. Result: {screenshot_result}")
105 |             
106 |             # Check if Playwright saved the screenshot somewhere and reported the location
107 |             if isinstance(screenshot_result, str) and "Saved as" in screenshot_result:
108 |                 # Extract the actual path from the result
109 |                 actual_path = screenshot_result.replace("Saved as", "").strip()
110 |                 print(f"Actual screenshot file location: {actual_path}")
111 |                 
112 |                 # Copy the file to our screenshots directory
113 |                 import shutil
114 |                 shutil.copy2(actual_path, output_screenshot_path)
115 |                 
116 |                 if os.path.exists(output_screenshot_path):
117 |                     print(f"Successfully copied screenshot to: {output_screenshot_path}")
118 |                 else:
119 |                     print(f"Failed to copy screenshot to: {output_screenshot_path}")
120 |             else:
121 |                 # In case it worked but with a different result format
122 |                 print(f"Screenshot result doesn't contain path info. Looking for files...")
123 |                 import glob
124 |                 # Try to find any recent png files in typical temp directories
125 |                 tmp_files = glob.glob("/tmp/*.png") + glob.glob("/var/tmp/*.png")
126 |                 # Sort by creation time, newest first
127 |                 tmp_files.sort(key=lambda x: os.path.getctime(x), reverse=True)
128 |                 
129 |                 if tmp_files:
130 |                     newest_file = tmp_files[0]
131 |                     print(f"Found possible screenshot at: {newest_file}")
132 |                     import shutil
133 |                     shutil.copy2(newest_file, output_screenshot_path)
134 |                     if os.path.exists(output_screenshot_path):
135 |                         print(f"Successfully copied most recent PNG to: {output_screenshot_path}")
136 |                     else:
137 |                         print(f"Failed to copy most recent PNG to: {output_screenshot_path}")
138 |                 else:
139 |                     print("No recent PNG files found in temp directories")
140 |         except Exception as e:
141 |             print(f"Error processing screenshot: {e}")
142 |             import traceback
143 |             traceback.print_exc()
144 | 
145 |         # Take an accessibility snapshot - Temporarily commented out for debugging
146 |         # print("\nTaking accessibility snapshot...")
147 |         # snapshot_result = await executor.execute_tool(
148 |         #     "mcp.playwright.browser_snapshot",
149 |         #     {},
150 |         #     config={}
151 |         # )
152 |         # # Snapshot result can be large, print only a part or confirmation
153 |         # if isinstance(snapshot_result, dict) and 'snapshot' in snapshot_result:
154 |         #      print(f"Snapshot captured successfully (first 200 chars):\n{str(snapshot_result['snapshot'])[:200]}...")
155 |         # else:
156 |         #      print(f"Snapshot result: {snapshot_result}")
157 | 
158 | 
159 |         # Example: Typing into a non-existent element (will likely fail, demonstrating error handling)
160 |         # On example.com, there isn't an obvious input field without more complex interaction
161 |         # print("\nAttempting to type (expected to fail)...")
162 |         # try:
163 |         #     type_result = await executor.execute_tool(
164 |         #         "mcp.playwright.browser_type",
165 |         #         {"ref": "input#search", "element": "search input", "text": "hello world"}
166 |         #     )
167 |         #     print(f"Type result: {type_result}")
168 |         # except Exception as e:
169 |         #     print(f"Typing failed as expected: {e}")
170 | 
171 |     except Exception as e:
172 |         print(f"An error occurred: {e}")
173 |     finally:
174 |         # Close all MCP connections
175 |         print("\nClosing MCP connections...")
176 |         await mcp_manager.close_all()
177 |         print("Connections closed.")
178 | 
179 | if __name__ == "__main__":
180 |     asyncio.run(main()) 
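
The two "Saved as ..." branches above share the same extract-and-copy logic; a small helper sketch (not in the repo) that both could call:

    import os
    import re
    import shutil

    def copy_reported_artifact(result: str, dest_path: str) -> bool:
        """Copy a file whose location an MCP tool reported as 'Saved as <path>'."""
        match = re.search(r"Saved as\s+(\S.*)", str(result))
        if not match:
            return False
        src = match.group(1).strip()
        dest_dir = os.path.dirname(dest_path)
        if dest_dir:
            os.makedirs(dest_dir, exist_ok=True)  # ensure destination directory
        shutil.copy2(src, dest_path)
        return os.path.exists(dest_path)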


--------------------------------------------------------------------------------
/ai-research-assistant/src/App.js:
--------------------------------------------------------------------------------
  1 | import React, { useState, useCallback } from 'react';
  2 | import Navbar from './components/Navbar';
  3 | import ResearchPanel from './components/ResearchPanel';
  4 | import DetailsPanel from './components/DetailsPanel'; // Now handles both item details and report content
  5 | import './App.css'; // For global styles
  6 | 
  7 | function App() {
  8 |   const [isResearching, setIsResearching] = useState(false);
  9 |   const [currentQuery, setCurrentQuery] = useState('');
 10 |   const [extraEffort, setExtraEffort] = useState(false);
 11 |   const [minimumEffort, setMinimumEffort] = useState(false);
 12 |   const [benchmarkMode, setBenchmarkMode] = useState(false);
 13 |   const [modelProvider, setModelProvider] = useState('google'); // Default provider
 14 |   const [modelName, setModelName] = useState('gemini-2.5-pro'); // Default model
 15 |   const [uploadedFileContent, setUploadedFileContent] = useState(null); // Added state for uploaded file content
 16 |   const [databaseInfo, setDatabaseInfo] = useState(null); // Added state for database info
 17 | 
 18 |   const [isDetailsPanelOpen, setIsDetailsPanelOpen] = useState(false);
 19 |   const [detailsPanelContentType, setDetailsPanelContentType] = useState(null); // 'item' or 'report'
 20 |   const [detailsPanelContentData, setDetailsPanelContentData] = useState(null);
 21 | 
 22 |   // Steering state
 23 |   const [currentTodoPlan, setCurrentTodoPlan] = useState("");
 24 |   const [todoPlanVersion, setTodoPlanVersion] = useState(0);
 25 | 
 26 |   const handleBeginResearch = useCallback((query, extra, minimum, benchmark, modelConfig, fileContent, databaseInfo) => { // Added fileContent and databaseInfo
 27 |     setCurrentQuery(query);
 28 |     setExtraEffort(extra);
 29 |     setMinimumEffort(minimum);
 30 |     setBenchmarkMode(benchmark);
 31 |     if (modelConfig) {
 32 |       setModelProvider(modelConfig.provider);
 33 |       setModelName(modelConfig.model);
 34 |     }
 35 | 
 36 |     setUploadedFileContent(fileContent); // Set uploaded file content
 37 |     // Store database info for the research agent
 38 |     if (databaseInfo && databaseInfo.length > 0) {
 39 |       console.log('Database info passed to research agent:', databaseInfo);
 40 |       setDatabaseInfo(databaseInfo); // Store database info in state
 41 |     }
 42 |     setIsResearching(true);
 43 |     setIsDetailsPanelOpen(false); // Close details panel when new research starts
 44 |     // Wait for animation to complete before resetting content
 45 |     setTimeout(() => {
 46 |       setDetailsPanelContentType(null);
 47 |       setDetailsPanelContentData(null);
 48 |     }, 300);
 49 |   }, []); // callback reads only its arguments and stable setters
 50 | 
 51 |   const handleShowItemDetails = useCallback((item) => {
 52 |     // Set data first, then trigger animation
 53 |     setDetailsPanelContentData(item);
 54 |     setDetailsPanelContentType('item');
 55 |     // Small delay to ensure data is set before animation starts
 56 |     setTimeout(() => {
 57 |       setIsDetailsPanelOpen(true);
 58 |     }, 10);
 59 |   }, []);
 60 | 
 61 |   const handleShowReportDetails = useCallback((reportContent) => {
 62 |     // Set data first, then trigger animation
 63 |     setDetailsPanelContentData(reportContent);
 64 |     setDetailsPanelContentType('report');
 65 |     // Small delay to ensure data is set before animation starts
 66 |     setTimeout(() => {
 67 |       setIsDetailsPanelOpen(true);
 68 |     }, 10);
 69 |   }, []);
 70 | 
 71 |   const handleCloseDetailsPanel = useCallback(() => {
 72 |     setIsDetailsPanelOpen(false);
 73 |     // Wait for animation to complete before clearing data
 74 |     setTimeout(() => {
 75 |       setDetailsPanelContentData(null);
 76 |       setDetailsPanelContentType(null);
 77 |     }, 300); // Match the transition timing in CSS
 78 |   }, []);
 79 | 
 80 |   // This callback is used by ResearchPanel to inform App.js that a report is ready.
 81 |   const [finalReportData, setFinalReportData] = useState(null);
 82 |   const handleReportGenerated = useCallback((report) => {
 83 |     setFinalReportData(report); // Store report data
 84 |     // Optionally, automatically open the report:
 85 |     // handleShowReportDetails(report);
 86 |   }, []);
 87 | 
 88 |   const handleStopResearch = useCallback(() => {
 89 |     console.log('Stopping research from App.js');
 90 |     setIsResearching(false);
 91 | 
 92 |     // Clear all research-related state
 93 |     setFinalReportData(null);
 94 |     setCurrentTodoPlan("");
 95 |     setTodoPlanVersion(0); // Reset version counter
 96 | 
 97 |     // Close details panel if open
 98 |     if (isDetailsPanelOpen) {
 99 |       setIsDetailsPanelOpen(false);
100 |       setTimeout(() => {
101 |         setDetailsPanelContentType(null);
102 |         setDetailsPanelContentData(null);
103 |       }, 300);
104 |     }
105 |   }, [isDetailsPanelOpen]);
106 | 
107 |   const handleTodoPlanUpdate = useCallback((todoPlan) => {
108 |     if (todoPlan !== currentTodoPlan) {
109 |       setCurrentTodoPlan(todoPlan);
110 |       setTodoPlanVersion(prev => prev + 1);
111 |     }
112 | 
113 |     if (todoPlan && !isDetailsPanelOpen) {
114 |       setDetailsPanelContentType('todo');
115 |       setDetailsPanelContentData(todoPlan);
116 |       setTimeout(() => {
117 |         setIsDetailsPanelOpen(true);
118 |       }, 10);
119 |     } else if (detailsPanelContentType === 'todo' && isDetailsPanelOpen) {
120 |       setDetailsPanelContentData(todoPlan);
121 |     } else if (!isDetailsPanelOpen && detailsPanelContentType === 'todo') {
122 |       setDetailsPanelContentData(todoPlan);
123 |     }
124 |   }, [isDetailsPanelOpen, detailsPanelContentType, currentTodoPlan]);
125 | 
126 |   const handleToggleProgress = useCallback(() => {
127 |     if (detailsPanelContentType === 'todo' && isDetailsPanelOpen) {
128 |       handleCloseDetailsPanel();
129 |     } else if (currentTodoPlan) {
130 |       handleTodoPlanUpdate(currentTodoPlan);
131 |     }
132 |   }, [detailsPanelContentType, isDetailsPanelOpen, currentTodoPlan, handleCloseDetailsPanel, handleTodoPlanUpdate]);
133 | 
134 |   const handleToggleReport = useCallback(() => {
135 |     if (detailsPanelContentType === 'report' && isDetailsPanelOpen) {
136 |       handleCloseDetailsPanel();
137 |     } else if (finalReportData) {
138 |       handleShowReportDetails(finalReportData);
139 |     }
140 |   }, [detailsPanelContentType, isDetailsPanelOpen, finalReportData, handleCloseDetailsPanel, handleShowReportDetails]);
141 | 
142 |   return (
      [Lines 143-187: JSX markup stripped during extraction. Judging from the
      imports and handlers above, it rendered the app layout: a <Navbar>
      (lines 144-152), a main content area with <ResearchPanel> (lines 155-172),
      and a <DetailsPanel> (lines 173-186), wired to the state and callbacks
      defined in this component.]
188 |   );
189 | }
190 | 
191 | export default App;


--------------------------------------------------------------------------------