├── cognidb
│   ├── security
│   │   ├── __init__.py
│   │   ├── query_parser.py
│   │   ├── access_control.py
│   │   ├── sanitizer.py
│   │   └── validator.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── settings.py
│   │   ├── secrets.py
│   │   └── loader.py
│   ├── drivers
│   │   ├── __init__.py
│   │   ├── base_driver.py
│   │   ├── mysql_driver.py
│   │   └── postgres_driver.py
│   ├── ai
│   │   ├── __init__.py
│   │   ├── llm_manager.py
│   │   ├── prompt_builder.py
│   │   ├── query_generator.py
│   │   ├── cost_tracker.py
│   │   └── providers.py
│   └── core
│       ├── __init__.py
│       ├── exceptions.py
│       ├── interfaces.py
│       └── query_intent.py
├── requirements.txt
├── .gitignore
├── setup.py
├── cognidb.example.yaml
├── git_commit_script.bat
├── git_commit_script.ps1
├── examples
│   └── basic_usage.py
├── Readme.md
└── __init__.py

/cognidb/security/__init__.py:
--------------------------------------------------------------------------------
1 | """Security module for CogniDB."""
2 | 
3 | from .validator import QuerySecurityValidator
4 | from .sanitizer import InputSanitizer
5 | from .query_parser import SQLQueryParser
6 | from .access_control import AccessController
7 | 
8 | __all__ = [
9 |     'QuerySecurityValidator',
10 |     'InputSanitizer',
11 |     'SQLQueryParser',
12 |     'AccessController'
13 | ]
--------------------------------------------------------------------------------
/cognidb/config/__init__.py:
--------------------------------------------------------------------------------
1 | """Configuration module for CogniDB."""
2 | 
3 | from .settings import Settings, DatabaseConfig, LLMConfig, CacheConfig, SecurityConfig
4 | from .secrets import SecretsManager
5 | from .loader import ConfigLoader
6 | 
7 | __all__ = [
8 |     'Settings',
9 |     'DatabaseConfig',
10 |     'LLMConfig',
11 |     'CacheConfig',
12 |     'SecurityConfig',
13 |     'SecretsManager',
14 |     'ConfigLoader'
15 | ]
--------------------------------------------------------------------------------
/cognidb/drivers/__init__.py:
--------------------------------------------------------------------------------
1 | """Database drivers module."""
2 | 
3 | # Only the MySQL and PostgreSQL drivers ship in this tree; the remaining
4 | # drivers are imported defensively so the package still imports without them.
5 | from .mysql_driver import MySQLDriver
6 | from .postgres_driver import PostgreSQLDriver
7 | 
8 | __all__ = ['MySQLDriver', 'PostgreSQLDriver']
9 | 
10 | try:
11 |     from .mongodb_driver import MongoDBDriver
12 |     __all__.append('MongoDBDriver')
13 | except ImportError:
14 |     pass
15 | 
16 | try:
17 |     from .dynamodb_driver import DynamoDBDriver
18 |     __all__.append('DynamoDBDriver')
19 | except ImportError:
20 |     pass
21 | 
22 | try:
23 |     from .sqlite_driver import SQLiteDriver
24 |     __all__.append('SQLiteDriver')
25 | except ImportError:
26 |     pass
--------------------------------------------------------------------------------
/cognidb/ai/__init__.py:
--------------------------------------------------------------------------------
1 | """AI and LLM integration module."""
2 | 
3 | from .llm_manager import LLMManager
4 | from .prompt_builder import PromptBuilder
5 | from .query_generator import QueryGenerator
6 | from .cost_tracker import CostTracker
7 | from .providers import OpenAIProvider, AnthropicProvider, AzureOpenAIProvider
8 | 
9 | __all__ = [
10 |     'LLMManager',
11 |     'PromptBuilder',
12 |     'QueryGenerator',
13 |     'CostTracker',
14 |     'OpenAIProvider',
15 |     'AnthropicProvider',
16 |     'AzureOpenAIProvider'
17 | ]
--------------------------------------------------------------------------------
/cognidb/core/__init__.py:
--------------------------------------------------------------------------------
1 | """Core abstractions for CogniDB."""
2 | 
3 | from .query_intent import QueryIntent, QueryType, JoinCondition, Aggregation
4 | from .interfaces import (
5 |     DatabaseDriver,
6 |     QueryTranslator,
7 |     SecurityValidator,
8 |     ResultNormalizer,
9 |     CacheProvider
10 | )
11 | from .exceptions import (
12 |     CogniDBError,
13 |     SecurityError,
14 | 
TranslationError, 15 | ExecutionError, 16 | ValidationError 17 | ) 18 | 19 | __all__ = [ 20 | 'QueryIntent', 21 | 'QueryType', 22 | 'JoinCondition', 23 | 'Aggregation', 24 | 'DatabaseDriver', 25 | 'QueryTranslator', 26 | 'SecurityValidator', 27 | 'ResultNormalizer', 28 | 'CacheProvider', 29 | 'CogniDBError', 30 | 'SecurityError', 31 | 'TranslationError', 32 | 'ExecutionError', 33 | 'ValidationError' 34 | ] -------------------------------------------------------------------------------- /cognidb/core/exceptions.py: -------------------------------------------------------------------------------- 1 | """Custom exceptions for CogniDB.""" 2 | 3 | 4 | class CogniDBError(Exception): 5 | """Base exception for all CogniDB errors.""" 6 | 7 | def __init__(self, message: str, details: dict = None): 8 | super().__init__(message) 9 | self.message = message 10 | self.details = details or {} 11 | 12 | 13 | class SecurityError(CogniDBError): 14 | """Raised when a security violation is detected.""" 15 | pass 16 | 17 | 18 | class ValidationError(CogniDBError): 19 | """Raised when validation fails.""" 20 | pass 21 | 22 | 23 | class TranslationError(CogniDBError): 24 | """Raised when query translation fails.""" 25 | pass 26 | 27 | 28 | class ExecutionError(CogniDBError): 29 | """Raised when query execution fails.""" 30 | pass 31 | 32 | 33 | class ConnectionError(CogniDBError): 34 | """Raised when database connection fails.""" 35 | pass 36 | 37 | 38 | class SchemaError(CogniDBError): 39 | """Raised when schema-related operations fail.""" 40 | pass 41 | 42 | 43 | class ConfigurationError(CogniDBError): 44 | """Raised when configuration is invalid.""" 45 | pass 46 | 47 | 48 | class CacheError(CogniDBError): 49 | """Raised when cache operations fail.""" 50 | pass 51 | 52 | 53 | class RateLimitError(CogniDBError): 54 | """Raised when rate limits are exceeded.""" 55 | 56 | def __init__(self, message: str, retry_after: int = None, details: dict = None): 57 | super().__init__(message, details) 58 | self.retry_after = retry_after -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Core database drivers 2 | mysql-connector-python>=8.0.33 3 | psycopg2-binary>=2.9.9 4 | pymongo>=4.6.0 5 | boto3>=1.28.0 # For AWS DynamoDB 6 | 7 | # SQL parsing and validation 8 | sqlparse>=0.4.4 9 | 10 | # LLM providers 11 | openai>=1.0.0 12 | anthropic>=0.8.0 13 | transformers>=4.35.0 # For HuggingFace models 14 | torch>=2.0.0 # For local models 15 | llama-cpp-python>=0.2.0 # For llama.cpp models (optional) 16 | 17 | # Security and encryption 18 | cryptography>=41.0.0 19 | 20 | # Configuration 21 | pyyaml>=6.0.1 22 | 23 | # Caching 24 | redis>=5.0.0 # For Redis cache (optional) 25 | 26 | # Cloud secrets management (optional) 27 | hvac>=1.2.0 # HashiCorp Vault 28 | azure-identity>=1.14.0 # Azure Key Vault 29 | azure-keyvault-secrets>=4.7.0 # Azure Key Vault 30 | 31 | # Testing 32 | pytest>=7.4.0 33 | pytest-cov>=4.1.0 34 | pytest-asyncio>=0.21.0 35 | pytest-mock>=3.11.0 36 | 37 | # Development tools 38 | black>=23.0.0 39 | flake8>=6.0.0 40 | mypy>=1.5.0 41 | pre-commit>=3.3.0 42 | 43 | # Documentation 44 | sphinx>=7.0.0 45 | sphinx-rtd-theme>=1.3.0 46 | 47 | # Monitoring and metrics (optional) 48 | prometheus-client>=0.18.0 49 | opentelemetry-api>=1.20.0 50 | opentelemetry-sdk>=1.20.0 51 | opentelemetry-instrumentation>=0.41b0 52 | 53 | # Data processing 54 | pandas>=2.0.0 # For result formatting 55 | 
numpy>=1.24.0 # For numerical operations 56 | 57 | # Web framework (if adding API) 58 | fastapi>=0.104.0 # Optional, for REST API 59 | uvicorn>=0.24.0 # ASGI server 60 | 61 | # Utilities 62 | python-dotenv>=1.0.0 # For .env file support 63 | click>=8.1.0 # For CLI interface 64 | rich>=13.6.0 # For beautiful terminal output 65 | tabulate>=0.9.0 # For table formatting -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | docs/_build/doctrees 74 | docs/_build/html 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # celery beat schedule file 90 | celerybeat-schedule 91 | celerybeat.pid 92 | 93 | # SageMath parsed files 94 | *.sage.py 95 | 96 | # Environments 97 | .env 98 | .venv 99 | env/ 100 | venv/ 101 | ENV/ 102 | env.bak/ 103 | venv.bak/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | .dmypy.json 118 | dmypy.json 119 | 120 | # Pyre type checker 121 | .pyre/ 122 | 123 | # pytype static type analyzer 124 | .pytype/ 125 | 126 | # Cython debug symbols 127 | cython_debug/ -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup configuration for CogniDB.""" 2 | 3 | from setuptools import setup, find_packages 4 | from pathlib import Path 5 | 6 | # Read the README file 7 | this_directory = Path(__file__).parent 8 | long_description = (this_directory / "Readme.md").read_text() 9 | 10 | # Read requirements 11 | requirements = (this_directory / "requirements.txt").read_text().splitlines() 12 | 13 | # Separate optional requirements 14 | core_requirements = [] 15 | optional_requirements = { 16 | 'llama': ['llama-cpp-python>=0.2.0'], 17 | 'azure': ['azure-identity>=1.14.0', 'azure-keyvault-secrets>=4.7.0'], 18 | 'vault': ['hvac>=1.2.0'], 19 | 'redis': ['redis>=5.0.0'], 20 | 'api': ['fastapi>=0.104.0', 'uvicorn>=0.24.0'], 21 
| 'dev': [
22 |         'pytest>=7.4.0',
23 |         'pytest-cov>=4.1.0',
24 |         'pytest-asyncio>=0.21.0',
25 |         'pytest-mock>=3.11.0',
26 |         'black>=23.0.0',
27 |         'flake8>=6.0.0',
28 |         'mypy>=1.5.0',
29 |         'pre-commit>=3.3.0'
30 |     ],
31 |     'docs': [
32 |         'sphinx>=7.0.0',
33 |         'sphinx-rtd-theme>=1.3.0'
34 |     ]
35 | }
36 | 
37 | # Filter core requirements (exclude optional ones)
38 | for line in requirements:
39 |     if line and not line.startswith('#'):
40 |         # Skip optional, dev, test, docs, and monitoring dependencies
41 |         if not any(opt in line.lower() for opt in ['llama', 'azure', 'hvac', 'redis', 'fastapi', 'uvicorn', 'pytest', 'black', 'flake8', 'mypy', 'pre-commit', 'sphinx', 'prometheus', 'opentelemetry', 'transformers', 'torch']):
42 |             core_requirements.append(line)
43 | 
44 | setup(
45 |     name="cognidb",
46 |     version="2.0.0",
47 |     author="CogniDB Team",
48 |     author_email="team@cognidb.io",
49 |     description="Secure Natural Language Database Interface",
50 |     long_description=long_description,
51 |     long_description_content_type="text/markdown",
52 |     url="https://github.com/adrienckr/cognidb",
53 |     packages=find_packages(),
54 |     classifiers=[
55 |         "Development Status :: 4 - Beta",
56 |         "Intended Audience :: Developers",
57 |         "Topic :: Database",
58 |         "Topic :: Software Development :: Libraries :: Python Modules",
59 |         "License :: OSI Approved :: MIT License",
60 |         "Programming Language :: Python :: 3",
61 |         "Programming Language :: Python :: 3.8",
62 |         "Programming Language :: Python :: 3.9",
63 |         "Programming Language :: Python :: 3.10",
64 |         "Programming Language :: Python :: 3.11",
65 |         "Programming Language :: Python :: 3.12",
66 |     ],
67 |     python_requires=">=3.8",
68 |     install_requires=core_requirements,
69 |     extras_require={
70 |         **optional_requirements,
71 |         'all': sum(optional_requirements.values(), [])
72 |     },
73 |     entry_points={
74 |         'console_scripts': [
75 |             'cognidb=cognidb.cli:main',
76 |         ],
77 |     },
78 |     include_package_data=True,
79 |     package_data={
80 |         'cognidb': [
81 |             'cognidb.example.yaml',
82 |             'examples/*.py',
83 |         ],
84 |     },
85 |     keywords='database sql natural-language nlp ai llm security',
86 |     project_urls={
87 |         'Bug Reports': 'https://github.com/adrienckr/cognidb/issues',
88 |         'Source': 'https://github.com/adrienckr/cognidb',
89 |         'Documentation': 'https://cognidb.readthedocs.io',
90 |     },
91 | )
--------------------------------------------------------------------------------
/cognidb.example.yaml:
--------------------------------------------------------------------------------
1 | # CogniDB Configuration Example
2 | # Copy this to cognidb.yaml and update with your settings
3 | 
4 | # Application settings
5 | app_name: CogniDB
6 | environment: production
7 | debug: false
8 | log_level: INFO
9 | 
10 | # Database configuration
11 | database:
12 |   type: postgresql  # Options: mysql, postgresql, mongodb, dynamodb, sqlite
13 |   host: localhost
14 |   port: 5432
15 |   database: your_database
16 |   username: your_username
17 |   password: ${DB_PASSWORD}  # Use environment variable
18 | 
19 |   # Connection pool settings
20 |   pool_size: 5
21 |   max_overflow: 10
22 |   pool_timeout: 30
23 |   pool_recycle: 3600
24 | 
25 |   # SSL/TLS settings
26 |   ssl_enabled: true
27 |   ssl_ca_cert: null
28 |   ssl_client_cert: null
29 |   ssl_client_key: null
30 | 
31 |   # Query settings
32 |   query_timeout: 30  # seconds
33 |   max_result_size: 10000  # rows
34 | 
35 | # LLM configuration
36 | llm:
37 |   provider: openai  # Options: openai, anthropic, azure_openai, huggingface, local
38 |   api_key: ${LLM_API_KEY}  # Use environment variable
39 | 
40 |   # Model settings
41 |   model_name: gpt-4
42 |   temperature: 0.1
43 |   max_tokens: 1000
44 |   timeout: 30
45 | 
46 |   # Cost control
47 | 
max_tokens_per_query: 2000 48 | max_queries_per_minute: 60 49 | max_cost_per_day: 100.0 50 | 51 | # Few-shot examples for better SQL generation 52 | few_shot_examples: 53 | - query: "Show me all users who registered last month" 54 | sql: "SELECT * FROM users WHERE created_at >= DATE_TRUNC('month', CURRENT_DATE - INTERVAL '1 month') AND created_at < DATE_TRUNC('month', CURRENT_DATE)" 55 | 56 | - query: "What's the total revenue by product category?" 57 | sql: "SELECT p.category, SUM(o.amount) as total_revenue FROM orders o JOIN products p ON o.product_id = p.id GROUP BY p.category ORDER BY total_revenue DESC" 58 | 59 | - query: "Find customers who haven't made a purchase in the last 90 days" 60 | sql: "SELECT c.* FROM customers c WHERE c.id NOT IN (SELECT DISTINCT customer_id FROM orders WHERE order_date > CURRENT_DATE - INTERVAL '90 days')" 61 | 62 | # Cache configuration 63 | cache: 64 | provider: in_memory # Options: in_memory, redis, memcached, disk 65 | 66 | # TTL settings (in seconds) 67 | query_result_ttl: 3600 # 1 hour 68 | schema_ttl: 86400 # 24 hours 69 | llm_response_ttl: 7200 # 2 hours 70 | 71 | # Redis settings (if using Redis) 72 | redis_host: localhost 73 | redis_port: 6379 74 | redis_password: ${REDIS_PASSWORD} 75 | redis_db: 0 76 | redis_ssl: false 77 | 78 | # Security configuration 79 | security: 80 | # Query validation 81 | allow_only_select: true 82 | max_query_complexity: 10 83 | allow_subqueries: false 84 | allow_unions: false 85 | 86 | # Rate limiting 87 | enable_rate_limiting: true 88 | rate_limit_per_minute: 100 89 | rate_limit_per_hour: 1000 90 | 91 | # Access control 92 | enable_access_control: true 93 | default_user_permissions: ["SELECT"] 94 | require_authentication: false 95 | 96 | # Audit logging 97 | enable_audit_logging: true 98 | audit_log_path: ~/.cognidb/audit.log 99 | log_query_results: false 100 | 101 | # Encryption 102 | encrypt_cache: true 103 | encrypt_logs: true 104 | encryption_key: ${ENCRYPTION_KEY} 105 | 106 | # Network security 107 | allowed_ip_ranges: [] 108 | require_ssl: true 109 | 110 | # Feature flags 111 | enable_natural_language: true 112 | enable_query_explanation: true 113 | enable_query_optimization: true 114 | enable_auto_indexing: false 115 | 116 | # Monitoring 117 | enable_metrics: true 118 | metrics_port: 9090 119 | enable_tracing: true 120 | tracing_endpoint: null -------------------------------------------------------------------------------- /git_commit_script.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | REM Script to create multiple logical commits for CogniDB v2.0.0 3 | REM Each commit adds related files with appropriate messages 4 | 5 | echo Starting CogniDB v2.0.0 commit process... 6 | echo. 7 | 8 | REM Commit 1: Core abstractions and interfaces 9 | echo Creating commit 1: Core abstractions... 10 | git add cognidb/core/__init__.py 11 | git add cognidb/core/exceptions.py 12 | git add cognidb/core/interfaces.py 13 | git add cognidb/core/query_intent.py 14 | git commit -m "feat(core): Add core abstractions and interfaces" -m "" -m "- Implement QueryIntent for database-agnostic query representation" -m "- Define abstract interfaces for drivers, translators, and validators" -m "- Add comprehensive exception hierarchy" -m "- Create foundation for modular architecture" 15 | echo. 16 | 17 | REM Commit 2: Security layer 18 | echo Creating commit 2: Security layer... 
19 | git add cognidb/security/__init__.py 20 | git add cognidb/security/validator.py 21 | git add cognidb/security/sanitizer.py 22 | git add cognidb/security/query_parser.py 23 | git add cognidb/security/access_control.py 24 | git commit -m "feat(security): Implement comprehensive security layer" -m "" -m "- Add multi-layer query validation to prevent SQL injection" -m "- Implement input sanitization for all user inputs" -m "- Create SQL query parser for security analysis" -m "- Add access control with table/column/row-level permissions" -m "- Enforce parameterized queries throughout" 25 | echo. 26 | 27 | REM Commit 3: Configuration management 28 | echo Creating commit 3: Configuration system... 29 | git add cognidb/config/__init__.py 30 | git add cognidb/config/settings.py 31 | git add cognidb/config/secrets.py 32 | git add cognidb/config/loader.py 33 | git add cognidb.example.yaml 34 | git commit -m "feat(config): Add flexible configuration management" -m "" -m "- Implement settings with dataclasses for type safety" -m "- Add secrets manager supporting multiple providers" -m "- Create config loader with YAML/JSON/env support" -m "- Add example configuration file" -m "- Support for environment variable interpolation" 35 | echo. 36 | 37 | REM Commit 4: AI/LLM integration 38 | echo Creating commit 4: AI/LLM integration... 39 | git add cognidb/ai/__init__.py 40 | git add cognidb/ai/llm_manager.py 41 | git add cognidb/ai/providers.py 42 | git add cognidb/ai/prompt_builder.py 43 | git add cognidb/ai/query_generator.py 44 | git add cognidb/ai/cost_tracker.py 45 | git commit -m "feat(ai): Implement modern LLM integration" -m "" -m "- Add multi-provider support (OpenAI, Anthropic, Azure, etc)" -m "- Implement cost tracking with daily limits" -m "- Create advanced prompt builder with few-shot learning" -m "- Add query generation with optimization suggestions" -m "- Include response caching to reduce costs" -m "- Support streaming and fallback providers" 46 | echo. 47 | 48 | REM Commit 5: Database drivers 49 | echo Creating commit 5: Secure database drivers... 50 | git add cognidb/drivers/__init__.py 51 | git add cognidb/drivers/base_driver.py 52 | git add cognidb/drivers/mysql_driver.py 53 | git add cognidb/drivers/postgres_driver.py 54 | git commit -m "feat(drivers): Add secure database drivers" -m "" -m "- Implement base driver with common functionality" -m "- Add MySQL driver with connection pooling and SSL support" -m "- Add PostgreSQL driver with prepared statements" -m "- Use parameterized queries exclusively" -m "- Include timeout management and result limiting" -m "- Add schema caching for performance" 55 | echo. 56 | 57 | REM Commit 6: Main CogniDB class 58 | echo Creating commit 6: Main CogniDB interface... 59 | git add __init__.py 60 | git commit -m "feat: Implement main CogniDB class with new architecture" -m "" -m "- Create unified interface for natural language queries" -m "- Integrate all components (security, AI, drivers, config)" -m "- Add context manager support" -m "- Include query optimization and suggestions" -m "- Implement audit logging" -m "- Add comprehensive error handling" 61 | echo. 62 | 63 | REM Commit 7: Documentation and examples 64 | echo Creating commit 7: Documentation and examples... 
65 | git add Readme.md 66 | git add examples/basic_usage.py 67 | git commit -m "docs: Update documentation for v2.0.0" -m "" -m "- Rewrite README with security-first approach" -m "- Add comprehensive usage examples" -m "- Document all features and configuration options" -m "- Include security best practices" -m "- Add performance tips" -m "- Update badges and project description" 68 | echo. 69 | 70 | REM Commit 8: Package setup and requirements 71 | echo Creating commit 8: Package configuration... 72 | git add setup.py 73 | git add requirements.txt 74 | git commit -m "build: Update package configuration for v2.0.0" -m "" -m "- Update requirements with all dependencies" -m "- Configure setup.py with optional extras" -m "- Add proper classifiers and metadata" -m "- Include console script entry point" -m "- Separate core and optional dependencies" 75 | echo. 76 | 77 | REM Commit 9: Remove old insecure files 78 | echo Creating commit 9: Clean up old implementation... 79 | git add -A 80 | git commit -m "refactor: Remove old insecure implementation" -m "" -m "- Remove vulnerable SQL string interpolation code" -m "- Delete unused modules" -m "- Remove redundant wrappers" -m "- Clean up old database implementations" -m "- Remove insecure query validator" -m "" -m "BREAKING CHANGE: Complete API redesign for v2.0.0" 81 | echo. 82 | 83 | echo All commits created successfully! 84 | echo. 85 | echo Repository status: 86 | git status 87 | echo. 88 | echo Commit history: 89 | git log --oneline -n 9 90 | echo. 91 | pause -------------------------------------------------------------------------------- /git_commit_script.ps1: -------------------------------------------------------------------------------- 1 | # PowerShell script to create multiple logical commits for CogniDB v2.0.0 2 | # Each commit adds related files with appropriate messages 3 | 4 | Write-Host "Starting CogniDB v2.0.0 commit process..." -ForegroundColor Green 5 | Write-Host "" 6 | 7 | # Commit 1: Core abstractions and interfaces 8 | Write-Host "Creating commit 1: Core abstractions..." -ForegroundColor Yellow 9 | git add cognidb/core/__init__.py 10 | git add cognidb/core/exceptions.py 11 | git add cognidb/core/interfaces.py 12 | git add cognidb/core/query_intent.py 13 | 14 | $commit1Message = @" 15 | feat(core): Add core abstractions and interfaces 16 | 17 | - Implement QueryIntent for database-agnostic query representation 18 | - Define abstract interfaces for drivers, translators, and validators 19 | - Add comprehensive exception hierarchy 20 | - Create foundation for modular architecture 21 | "@ 22 | git commit -m $commit1Message 23 | Write-Host "" 24 | 25 | # Commit 2: Security layer 26 | Write-Host "Creating commit 2: Security layer..." -ForegroundColor Yellow 27 | git add cognidb/security/__init__.py 28 | git add cognidb/security/validator.py 29 | git add cognidb/security/sanitizer.py 30 | git add cognidb/security/query_parser.py 31 | git add cognidb/security/access_control.py 32 | 33 | $commit2Message = @" 34 | feat(security): Implement comprehensive security layer 35 | 36 | - Add multi-layer query validation to prevent SQL injection 37 | - Implement input sanitization for all user inputs 38 | - Create SQL query parser for security analysis 39 | - Add access control with table/column/row-level permissions 40 | - Enforce parameterized queries throughout 41 | "@ 42 | git commit -m $commit2Message 43 | Write-Host "" 44 | 45 | # Commit 3: Configuration management 46 | Write-Host "Creating commit 3: Configuration system..." 
-ForegroundColor Yellow 47 | git add cognidb/config/__init__.py 48 | git add cognidb/config/settings.py 49 | git add cognidb/config/secrets.py 50 | git add cognidb/config/loader.py 51 | git add cognidb.example.yaml 52 | 53 | $commit3Message = @" 54 | feat(config): Add flexible configuration management 55 | 56 | - Implement settings with dataclasses for type safety 57 | - Add secrets manager supporting multiple providers (env, file, AWS, Vault) 58 | - Create config loader with YAML/JSON/env support 59 | - Add example configuration file 60 | - Support for environment variable interpolation 61 | "@ 62 | git commit -m $commit3Message 63 | Write-Host "" 64 | 65 | # Commit 4: AI/LLM integration 66 | Write-Host "Creating commit 4: AI/LLM integration..." -ForegroundColor Yellow 67 | git add cognidb/ai/__init__.py 68 | git add cognidb/ai/llm_manager.py 69 | git add cognidb/ai/providers.py 70 | git add cognidb/ai/prompt_builder.py 71 | git add cognidb/ai/query_generator.py 72 | git add cognidb/ai/cost_tracker.py 73 | 74 | $commit4Message = @" 75 | feat(ai): Implement modern LLM integration 76 | 77 | - Add multi-provider support (OpenAI, Anthropic, Azure, HuggingFace, Local) 78 | - Implement cost tracking with daily limits 79 | - Create advanced prompt builder with few-shot learning 80 | - Add query generation with optimization suggestions 81 | - Include response caching to reduce costs 82 | - Support streaming and fallback providers 83 | "@ 84 | git commit -m $commit4Message 85 | Write-Host "" 86 | 87 | # Commit 5: Database drivers 88 | Write-Host "Creating commit 5: Secure database drivers..." -ForegroundColor Yellow 89 | git add cognidb/drivers/__init__.py 90 | git add cognidb/drivers/base_driver.py 91 | git add cognidb/drivers/mysql_driver.py 92 | git add cognidb/drivers/postgres_driver.py 93 | 94 | $commit5Message = @" 95 | feat(drivers): Add secure database drivers 96 | 97 | - Implement base driver with common functionality 98 | - Add MySQL driver with connection pooling and SSL support 99 | - Add PostgreSQL driver with prepared statements 100 | - Use parameterized queries exclusively 101 | - Include timeout management and result limiting 102 | - Add schema caching for performance 103 | "@ 104 | git commit -m $commit5Message 105 | Write-Host "" 106 | 107 | # Commit 6: Main CogniDB class 108 | Write-Host "Creating commit 6: Main CogniDB interface..." -ForegroundColor Yellow 109 | git add __init__.py 110 | 111 | $commit6Message = @" 112 | feat: Implement main CogniDB class with new architecture 113 | 114 | - Create unified interface for natural language queries 115 | - Integrate all components (security, AI, drivers, config) 116 | - Add context manager support 117 | - Include query optimization and suggestions 118 | - Implement audit logging 119 | - Add comprehensive error handling 120 | "@ 121 | git commit -m $commit6Message 122 | Write-Host "" 123 | 124 | # Commit 7: Documentation and examples 125 | Write-Host "Creating commit 7: Documentation and examples..." 
-ForegroundColor Yellow 126 | git add Readme.md 127 | git add examples/basic_usage.py 128 | 129 | $commit7Message = @" 130 | docs: Update documentation for v2.0.0 131 | 132 | - Rewrite README with security-first approach 133 | - Add comprehensive usage examples 134 | - Document all features and configuration options 135 | - Include security best practices 136 | - Add performance tips 137 | - Update badges and project description 138 | "@ 139 | git commit -m $commit7Message 140 | Write-Host "" 141 | 142 | # Commit 8: Package setup and requirements 143 | Write-Host "Creating commit 8: Package configuration..." -ForegroundColor Yellow 144 | git add setup.py 145 | git add requirements.txt 146 | 147 | $commit8Message = @" 148 | build: Update package configuration for v2.0.0 149 | 150 | - Update requirements with all dependencies 151 | - Configure setup.py with optional extras 152 | - Add proper classifiers and metadata 153 | - Include console script entry point 154 | - Separate core and optional dependencies 155 | "@ 156 | git commit -m $commit8Message 157 | Write-Host "" 158 | 159 | # Commit 9: Remove old insecure files 160 | Write-Host "Creating commit 9: Clean up old implementation..." -ForegroundColor Yellow 161 | git add -A 162 | 163 | $commit9Message = @" 164 | refactor: Remove old insecure implementation 165 | 166 | - Remove vulnerable SQL string interpolation code 167 | - Delete unused modules (clarification_handler, user_input_processor) 168 | - Remove redundant wrappers (db_connection, schema_fetcher) 169 | - Clean up old database implementations 170 | - Remove insecure query validator 171 | - Delete analysis report 172 | 173 | BREAKING CHANGE: Complete API redesign for v2.0.0 174 | "@ 175 | git commit -m $commit9Message 176 | Write-Host "" 177 | 178 | Write-Host "All commits created successfully!" -ForegroundColor Green 179 | Write-Host "" 180 | Write-Host "Repository status:" -ForegroundColor Cyan 181 | git status 182 | Write-Host "" 183 | Write-Host "Commit history:" -ForegroundColor Cyan 184 | git log --oneline -n 9 185 | Write-Host "" 186 | Write-Host "Press any key to continue..." 187 | $null = $Host.UI.RawUI.ReadKey("NoEcho,IncludeKeyDown") -------------------------------------------------------------------------------- /cognidb/core/interfaces.py: -------------------------------------------------------------------------------- 1 | """Core interfaces for CogniDB components.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Dict, List, Any, Optional, Tuple 5 | from .query_intent import QueryIntent 6 | 7 | 8 | class DatabaseDriver(ABC): 9 | """Abstract base class for database drivers.""" 10 | 11 | @abstractmethod 12 | def connect(self) -> None: 13 | """Establish connection to the database.""" 14 | pass 15 | 16 | @abstractmethod 17 | def disconnect(self) -> None: 18 | """Close the database connection.""" 19 | pass 20 | 21 | @abstractmethod 22 | def execute_native_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: 23 | """ 24 | Execute a native query with parameters. 25 | 26 | Args: 27 | query: Native query string with parameter placeholders 28 | params: Parameter values for the query 29 | 30 | Returns: 31 | List of result rows as dictionaries 32 | """ 33 | pass 34 | 35 | @abstractmethod 36 | def fetch_schema(self) -> Dict[str, Dict[str, str]]: 37 | """ 38 | Fetch database schema. 39 | 40 | Returns: 41 | Dictionary mapping table names to column info: 42 | { 43 | 'table_name': { 44 | 'column_name': 'data_type', 45 | ... 
46 | }, 47 | ... 48 | } 49 | """ 50 | pass 51 | 52 | @abstractmethod 53 | def validate_table_name(self, table_name: str) -> bool: 54 | """Validate that a table name exists and is safe.""" 55 | pass 56 | 57 | @abstractmethod 58 | def validate_column_name(self, table_name: str, column_name: str) -> bool: 59 | """Validate that a column exists in the table.""" 60 | pass 61 | 62 | @abstractmethod 63 | def get_connection_info(self) -> Dict[str, Any]: 64 | """Get connection information (for debugging, minus secrets).""" 65 | pass 66 | 67 | @property 68 | @abstractmethod 69 | def supports_transactions(self) -> bool: 70 | """Whether this driver supports transactions.""" 71 | pass 72 | 73 | @property 74 | @abstractmethod 75 | def supports_schemas(self) -> bool: 76 | """Whether this database supports schemas/namespaces.""" 77 | pass 78 | 79 | 80 | class QueryTranslator(ABC): 81 | """Abstract base class for query translators.""" 82 | 83 | @abstractmethod 84 | def translate(self, query_intent: QueryIntent) -> Tuple[str, Dict[str, Any]]: 85 | """ 86 | Translate a QueryIntent into a native query. 87 | 88 | Args: 89 | query_intent: The query intent to translate 90 | 91 | Returns: 92 | Tuple of (query_string, parameters_dict) 93 | """ 94 | pass 95 | 96 | @abstractmethod 97 | def validate_intent(self, query_intent: QueryIntent) -> List[str]: 98 | """ 99 | Validate that the query intent can be translated. 100 | 101 | Returns: 102 | List of validation errors (empty if valid) 103 | """ 104 | pass 105 | 106 | @property 107 | @abstractmethod 108 | def supported_features(self) -> Dict[str, bool]: 109 | """ 110 | Return supported features for this translator. 111 | 112 | Example: 113 | { 114 | 'joins': True, 115 | 'subqueries': False, 116 | 'window_functions': True, 117 | 'cte': False 118 | } 119 | """ 120 | pass 121 | 122 | 123 | class SecurityValidator(ABC): 124 | """Abstract base class for security validators.""" 125 | 126 | @abstractmethod 127 | def validate_query_intent(self, query_intent: QueryIntent) -> Tuple[bool, Optional[str]]: 128 | """ 129 | Validate query intent for security issues. 130 | 131 | Returns: 132 | Tuple of (is_valid, error_message) 133 | """ 134 | pass 135 | 136 | @abstractmethod 137 | def validate_native_query(self, query: str) -> Tuple[bool, Optional[str]]: 138 | """ 139 | Validate native query for security issues. 140 | 141 | Returns: 142 | Tuple of (is_valid, error_message) 143 | """ 144 | pass 145 | 146 | @abstractmethod 147 | def sanitize_identifier(self, identifier: str) -> str: 148 | """Sanitize a table/column identifier.""" 149 | pass 150 | 151 | @abstractmethod 152 | def sanitize_value(self, value: Any) -> Any: 153 | """Sanitize a parameter value.""" 154 | pass 155 | 156 | @property 157 | @abstractmethod 158 | def allowed_operations(self) -> List[str]: 159 | """List of allowed query operations.""" 160 | pass 161 | 162 | 163 | class ResultNormalizer(ABC): 164 | """Abstract base class for result normalizers.""" 165 | 166 | @abstractmethod 167 | def normalize(self, raw_results: Any) -> List[Dict[str, Any]]: 168 | """ 169 | Normalize database-specific results to standard format. 170 | 171 | Args: 172 | raw_results: Raw results from database driver 173 | 174 | Returns: 175 | List of dictionaries with consistent structure 176 | """ 177 | pass 178 | 179 | @abstractmethod 180 | def format_for_output(self, 181 | normalized_results: List[Dict[str, Any]], 182 | output_format: str = 'json') -> Any: 183 | """ 184 | Format normalized results for final output. 
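The 'dataframe' format assumes pandas is installed (requirements.txt lists it for result formatting).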
185 | 186 | Args: 187 | normalized_results: Normalized result set 188 | output_format: One of 'json', 'csv', 'table', 'dataframe' 189 | 190 | Returns: 191 | Formatted results 192 | """ 193 | pass 194 | 195 | 196 | class CacheProvider(ABC): 197 | """Abstract base class for cache providers.""" 198 | 199 | @abstractmethod 200 | def get(self, key: str) -> Optional[Any]: 201 | """Retrieve value from cache.""" 202 | pass 203 | 204 | @abstractmethod 205 | def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool: 206 | """ 207 | Store value in cache. 208 | 209 | Args: 210 | key: Cache key 211 | value: Value to cache 212 | ttl: Time to live in seconds (None = no expiration) 213 | 214 | Returns: 215 | Success status 216 | """ 217 | pass 218 | 219 | @abstractmethod 220 | def delete(self, key: str) -> bool: 221 | """Delete value from cache.""" 222 | pass 223 | 224 | @abstractmethod 225 | def clear(self) -> bool: 226 | """Clear all cached values.""" 227 | pass 228 | 229 | @abstractmethod 230 | def get_stats(self) -> Dict[str, Any]: 231 | """Get cache statistics.""" 232 | pass -------------------------------------------------------------------------------- /examples/basic_usage.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic usage examples for CogniDB. 3 | 4 | This demonstrates how to use CogniDB for natural language database queries 5 | with various configurations and features. 6 | """ 7 | 8 | import os 9 | from cognidb import CogniDB, create_cognidb 10 | 11 | 12 | def basic_example(): 13 | """Basic usage with environment variables.""" 14 | # Set environment variables (in production, use .env file or secrets manager) 15 | os.environ['DB_TYPE'] = 'postgresql' 16 | os.environ['DB_HOST'] = 'localhost' 17 | os.environ['DB_PORT'] = '5432' 18 | os.environ['DB_NAME'] = 'mydb' 19 | os.environ['DB_USER'] = 'myuser' 20 | os.environ['DB_PASSWORD'] = 'mypassword' 21 | os.environ['LLM_API_KEY'] = 'your-openai-api-key' 22 | 23 | # Create CogniDB instance 24 | with CogniDB() as db: 25 | # Simple query 26 | result = db.query("Show me all customers who made a purchase last month") 27 | 28 | if result['success']: 29 | print(f"Generated SQL: {result['sql']}") 30 | print(f"Found {result['row_count']} results") 31 | for row in result['results'][:5]: # Show first 5 32 | print(row) 33 | else: 34 | print(f"Error: {result['error']}") 35 | 36 | 37 | def config_file_example(): 38 | """Usage with configuration file.""" 39 | # Use configuration file 40 | with CogniDB(config_file='cognidb.yaml') as db: 41 | # Query with explanation 42 | result = db.query( 43 | "What are the top 5 products by revenue?", 44 | explain=True 45 | ) 46 | 47 | if result['success']: 48 | print(f"SQL: {result['sql']}") 49 | print(f"\nExplanation: {result['explanation']}") 50 | print(f"\nResults:") 51 | for row in result['results']: 52 | print(f" {row}") 53 | 54 | 55 | def advanced_features_example(): 56 | """Demonstrate advanced features.""" 57 | db = create_cognidb( 58 | database={ 59 | 'type': 'postgresql', 60 | 'host': 'localhost', 61 | 'database': 'analytics_db' 62 | }, 63 | llm={ 64 | 'provider': 'openai', 65 | 'model_name': 'gpt-4', 66 | 'temperature': 0.1 67 | } 68 | ) 69 | 70 | try: 71 | # 1. Query suggestions 72 | suggestions = db.suggest_queries("customers who") 73 | print("Query suggestions:") 74 | for suggestion in suggestions: 75 | print(f" - {suggestion}") 76 | 77 | # 2. 
Query optimization
78 |         sql = "SELECT * FROM orders WHERE customer_id IN (SELECT id FROM customers WHERE country = 'USA')"
79 |         optimization = db.optimize_query(sql)
80 | 
81 |         if optimization['success']:
82 |             print(f"\nOriginal: {optimization['original_query']}")
83 |             print(f"Optimized: {optimization['optimized_query']}")
84 |             print(f"Explanation: {optimization['explanation']}")
85 | 
86 |         # 3. Schema inspection
87 |         schema = db.get_schema('customers')
88 |         print(f"\nCustomers table schema:")
89 |         for column, dtype in schema['customers'].items():
90 |             print(f"  {column}: {dtype}")
91 | 
92 |         # 4. Usage statistics
93 |         stats = db.get_usage_stats()
94 |         print(f"\nUsage stats:")
95 |         print(f"  Total cost: ${stats['total_cost']:.2f}")
96 |         print(f"  Requests today: {stats['request_count']}")
97 | 
98 |     finally:
99 |         db.close()
100 | 
101 | 
102 | def multi_database_example():
103 |     """Example with different database types."""
104 |     databases = [
105 |         {
106 |             'type': 'mysql',
107 |             'config': {
108 |                 'host': 'mysql.example.com',
109 |                 'database': 'app_db'
110 |             }
111 |         },
112 |         {
113 |             'type': 'postgresql',
114 |             'config': {
115 |                 'host': 'postgres.example.com',
116 |                 'database': 'analytics_db'
117 |             }
118 |         },
119 |         {
120 |             'type': 'sqlite',
121 |             'config': {
122 |                 'database': '/path/to/local.db'
123 |             }
124 |         }
125 |     ]
126 | 
127 |     for db_info in databases:
128 |         print(f"\nQuerying {db_info['type']} database:")
129 | 
130 |         with create_cognidb(database={'type': db_info['type'], **db_info['config']}) as db:
131 |             result = db.query("Count the total number of records in the main table")
132 | 
133 |             if result['success']:
134 |                 print(f"  SQL: {result['sql']}")
135 |                 print(f"  Result: {result['results']}")
136 | 
137 | 
138 | def security_example():
139 |     """Demonstrate security features."""
140 |     # Configure with strict security
141 |     db = create_cognidb(
142 |         security={
143 |             'allow_only_select': True,
144 |             'max_query_complexity': 5,
145 |             'allow_subqueries': False,
146 |             'enable_rate_limiting': True,
147 |             'rate_limit_per_minute': 10
148 |         }
149 |     )
150 | 
151 |     # These queries will be validated
152 |     safe_queries = [
153 |         "Show me all active users",
154 |         "What's the average order value?",
155 |         "List products with low stock"
156 |     ]
157 | 
158 |     # These will be rejected
159 |     unsafe_queries = [
160 |         "Delete all users where active = false",  # Not SELECT
161 |         "Update products set price = price * 1.1",  # Not SELECT
162 |         "'; DROP TABLE users; --",  # SQL injection attempt
163 |     ]
164 | 
165 |     print("Testing safe queries:")
166 |     for query in safe_queries:
167 |         result = db.query(query)
168 |         print(f"  '{query}' - Success: {result['success']}")
169 | 
170 |     print("\nTesting unsafe queries (should be rejected):")
171 |     for query in unsafe_queries:
172 |         result = db.query(query)
173 |         print(f"  '{query}' - Success: {result['success']}, Error: {result.get('error', '')[:50]}...")
174 | 
175 |     db.close()
176 | 
177 | 
178 | def custom_llm_example():
179 |     """Example with custom LLM configuration."""
180 |     # Use Anthropic Claude
181 |     db_claude = create_cognidb(
182 |         llm={
183 |             'provider': 'anthropic',
184 |             'api_key': os.environ.get('ANTHROPIC_API_KEY'),
185 |             'model_name': 'claude-3-sonnet',
186 |             'temperature': 0.0
187 |         }
188 |     )
189 | 
190 |     # Use local model
191 |     db_local = create_cognidb(
192 |         llm={
193 |             'provider': 'local',
194 |             'local_model_path': '/path/to/model.gguf',
195 |             'max_tokens': 500
196 |         }
197 |     )
198 | 
199 |     # Use Azure OpenAI
200 |     db_azure = create_cognidb(
201 |         llm={
202 |             'provider': 'azure_openai',
203 |             'api_key': 
os.environ.get('AZURE_OPENAI_KEY'), 204 | 'azure_endpoint': 'https://myorg.openai.azure.com/', 205 | 'azure_deployment': 'gpt-4-deployment' 206 | } 207 | ) 208 | 209 | 210 | if __name__ == "__main__": 211 | # Run examples 212 | print("=== Basic Example ===") 213 | basic_example() 214 | 215 | print("\n=== Config File Example ===") 216 | config_file_example() 217 | 218 | print("\n=== Advanced Features ===") 219 | advanced_features_example() 220 | 221 | print("\n=== Security Example ===") 222 | security_example() -------------------------------------------------------------------------------- /cognidb/drivers/base_driver.py: -------------------------------------------------------------------------------- 1 | """Base implementation for database drivers with common functionality.""" 2 | 3 | import time 4 | import logging 5 | from typing import Dict, List, Any, Optional, Tuple 6 | from contextlib import contextmanager 7 | from abc import abstractmethod 8 | from ..core.interfaces import DatabaseDriver 9 | from ..core.exceptions import ConnectionError, ExecutionError, SchemaError 10 | from ..security.sanitizer import InputSanitizer 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class BaseDriver(DatabaseDriver): 16 | """ 17 | Base database driver with common functionality. 18 | 19 | Provides: 20 | - Connection management 21 | - Query execution with timeouts 22 | - Schema caching 23 | - Security validation 24 | - Error handling 25 | """ 26 | 27 | def __init__(self, config: Dict[str, Any]): 28 | """ 29 | Initialize base driver. 30 | 31 | Args: 32 | config: Database configuration 33 | """ 34 | self.config = config 35 | self.connection = None 36 | self.schema_cache = None 37 | self.schema_cache_time = 0 38 | self.schema_cache_ttl = 3600 # 1 hour 39 | self.sanitizer = InputSanitizer() 40 | self._connection_time = None 41 | 42 | @contextmanager 43 | def transaction(self): 44 | """Context manager for database transactions.""" 45 | if not self.supports_transactions: 46 | yield 47 | return 48 | 49 | try: 50 | self._begin_transaction() 51 | yield 52 | self._commit_transaction() 53 | except Exception as e: 54 | self._rollback_transaction() 55 | raise ExecutionError(f"Transaction failed: {str(e)}") 56 | 57 | def execute_native_query(self, 58 | query: str, 59 | params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: 60 | """ 61 | Execute a native query with parameters. 62 | 63 | Args: 64 | query: Native query string with parameter placeholders 65 | params: Parameter values for the query 66 | 67 | Returns: 68 | List of result rows as dictionaries 69 | """ 70 | if not self.connection: 71 | raise ConnectionError("Not connected to database") 72 | 73 | # Log query for debugging (without params for security) 74 | logger.debug(f"Executing query: {query[:100]}...") 75 | 76 | start_time = time.time() 77 | 78 | try: 79 | # Execute with timeout 80 | results = self._execute_with_timeout(query, params) 81 | 82 | # Log execution time 83 | execution_time = time.time() - start_time 84 | logger.info(f"Query executed in {execution_time:.2f}s") 85 | 86 | return results 87 | 88 | except Exception as e: 89 | logger.error(f"Query execution failed: {str(e)}") 90 | raise ExecutionError(f"Query execution failed: {str(e)}") 91 | 92 | def fetch_schema(self) -> Dict[str, Dict[str, str]]: 93 | """ 94 | Fetch database schema with caching. 
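The result is cached for schema_cache_ttl seconds (one hour by default); call invalidate_schema_cache() to force a refresh.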
95 | 96 | Returns: 97 | Dictionary mapping table names to column info 98 | """ 99 | # Check cache 100 | if self.schema_cache and (time.time() - self.schema_cache_time) < self.schema_cache_ttl: 101 | logger.debug("Using cached schema") 102 | return self.schema_cache 103 | 104 | logger.info("Fetching database schema") 105 | 106 | try: 107 | schema = self._fetch_schema_impl() 108 | 109 | # Cache the schema 110 | self.schema_cache = schema 111 | self.schema_cache_time = time.time() 112 | 113 | logger.info(f"Schema fetched: {len(schema)} tables") 114 | return schema 115 | 116 | except Exception as e: 117 | logger.error(f"Failed to fetch schema: {str(e)}") 118 | raise SchemaError(f"Failed to fetch schema: {str(e)}") 119 | 120 | def validate_table_name(self, table_name: str) -> bool: 121 | """Validate that a table name exists and is safe.""" 122 | # Sanitize the table name 123 | try: 124 | sanitized = self.sanitizer.sanitize_identifier(table_name) 125 | except ValueError: 126 | return False 127 | 128 | # Check if table exists in schema 129 | schema = self.fetch_schema() 130 | return sanitized in schema 131 | 132 | def validate_column_name(self, table_name: str, column_name: str) -> bool: 133 | """Validate that a column exists in the table.""" 134 | # Validate table first 135 | if not self.validate_table_name(table_name): 136 | return False 137 | 138 | # Sanitize column name 139 | try: 140 | sanitized_column = self.sanitizer.sanitize_identifier(column_name) 141 | except ValueError: 142 | return False 143 | 144 | # Check if column exists 145 | schema = self.fetch_schema() 146 | table_columns = schema.get(table_name, {}) 147 | return sanitized_column in table_columns 148 | 149 | def get_connection_info(self) -> Dict[str, Any]: 150 | """Get connection information (for debugging, minus secrets).""" 151 | info = { 152 | 'driver': self.__class__.__name__, 153 | 'host': self.config.get('host', 'N/A'), 154 | 'port': self.config.get('port', 'N/A'), 155 | 'database': self.config.get('database', 'N/A'), 156 | 'connected': self.connection is not None, 157 | 'connection_time': self._connection_time, 158 | 'schema_tables': len(self.schema_cache) if self.schema_cache else 0 159 | } 160 | 161 | # Add driver-specific info 162 | info.update(self._get_driver_info()) 163 | 164 | return info 165 | 166 | def invalidate_schema_cache(self): 167 | """Invalidate the schema cache.""" 168 | self.schema_cache = None 169 | self.schema_cache_time = 0 170 | logger.info("Schema cache invalidated") 171 | 172 | def ping(self) -> bool: 173 | """Check if connection is alive.""" 174 | try: 175 | # Simple query to test connection 176 | self.execute_native_query("SELECT 1") 177 | return True 178 | except Exception: 179 | return False 180 | 181 | def reconnect(self): 182 | """Reconnect to the database.""" 183 | logger.info("Attempting to reconnect") 184 | self.disconnect() 185 | self.connect() 186 | 187 | # Abstract methods to be implemented by subclasses 188 | 189 | @abstractmethod 190 | def _create_connection(self): 191 | """Create the actual database connection.""" 192 | pass 193 | 194 | @abstractmethod 195 | def _close_connection(self): 196 | """Close the database connection.""" 197 | pass 198 | 199 | @abstractmethod 200 | def _execute_with_timeout(self, query: str, params: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]: 201 | """Execute query with timeout (implementation specific).""" 202 | pass 203 | 204 | @abstractmethod 205 | def _fetch_schema_impl(self) -> Dict[str, Dict[str, str]]: 206 | """Fetch schema implementation.""" 
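# Implementations must return {table_name: {column_name: data_type}},
# the same shape that fetch_schema() caches and returns.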
207 | pass 208 | 209 | @abstractmethod 210 | def _begin_transaction(self): 211 | """Begin a transaction.""" 212 | pass 213 | 214 | @abstractmethod 215 | def _commit_transaction(self): 216 | """Commit a transaction.""" 217 | pass 218 | 219 | @abstractmethod 220 | def _rollback_transaction(self): 221 | """Rollback a transaction.""" 222 | pass 223 | 224 | @abstractmethod 225 | def _get_driver_info(self) -> Dict[str, Any]: 226 | """Get driver-specific information.""" 227 | pass -------------------------------------------------------------------------------- /cognidb/security/query_parser.py: -------------------------------------------------------------------------------- 1 | """SQL query parser for security validation.""" 2 | 3 | import re 4 | from typing import Dict, Any, List, Optional 5 | import sqlparse 6 | from sqlparse.sql import IdentifierList, Identifier, Token 7 | from sqlparse.tokens import Keyword, DML 8 | 9 | 10 | class SQLQueryParser: 11 | """ 12 | SQL query parser for security analysis. 13 | 14 | Parses SQL queries to extract structure and identify 15 | potential security issues. 16 | """ 17 | 18 | def __init__(self): 19 | """Initialize the parser.""" 20 | self.parsed_cache = {} 21 | 22 | def parse(self, query: str) -> Dict[str, Any]: 23 | """ 24 | Parse SQL query and extract security-relevant information. 25 | 26 | Args: 27 | query: SQL query string 28 | 29 | Returns: 30 | Dictionary with parsed query information: 31 | { 32 | 'type': 'SELECT', 33 | 'tables': ['users', 'orders'], 34 | 'columns': ['id', 'name'], 35 | 'has_subquery': False, 36 | 'has_union': False, 37 | 'has_join': True, 38 | 'has_where': True, 39 | 'complexity': 5 40 | } 41 | """ 42 | # Clean and normalize query 43 | query = query.strip() 44 | 45 | # Check cache 46 | cache_key = hash(query) 47 | if cache_key in self.parsed_cache: 48 | return self.parsed_cache[cache_key] 49 | 50 | # Parse with sqlparse 51 | parsed = sqlparse.parse(query)[0] 52 | 53 | result = { 54 | 'type': self._get_query_type(parsed), 55 | 'tables': self._extract_tables(parsed), 56 | 'columns': self._extract_columns(parsed), 57 | 'has_subquery': self._has_subquery(parsed), 58 | 'has_union': self._has_union(query), 59 | 'has_join': self._has_join(parsed), 60 | 'has_where': self._has_where(parsed), 61 | 'has_having': self._has_having(parsed), 62 | 'has_order_by': self._has_order_by(parsed), 63 | 'has_group_by': self._has_group_by(parsed), 64 | 'complexity': self._calculate_complexity(parsed) 65 | } 66 | 67 | # Cache result 68 | self.parsed_cache[cache_key] = result 69 | 70 | return result 71 | 72 | def _get_query_type(self, parsed) -> str: 73 | """Extract the main query type.""" 74 | for token in parsed.tokens: 75 | if token.ttype is DML: 76 | return token.value.upper() 77 | return "UNKNOWN" 78 | 79 | def _extract_tables(self, parsed) -> List[str]: 80 | """Extract table names from the query.""" 81 | tables = [] 82 | from_seen = False 83 | 84 | for token in parsed.tokens: 85 | if from_seen: 86 | if isinstance(token, IdentifierList): 87 | for identifier in token.get_identifiers(): 88 | tables.append(self._get_name(identifier)) 89 | elif isinstance(token, Identifier): 90 | tables.append(self._get_name(token)) 91 | elif token.ttype is None: 92 | tables.append(token.value) 93 | 94 | if token.ttype is Keyword and token.value.upper() == 'FROM': 95 | from_seen = True 96 | elif token.ttype is Keyword and token.value.upper() in ('WHERE', 'GROUP', 'ORDER', 'HAVING'): 97 | from_seen = False 98 | 99 | return [t.strip() for t in tables if t.strip()] 100 | 101 | 
def _extract_columns(self, parsed) -> List[str]:
102 |         """Extract column names from SELECT clause."""
103 |         columns = []
104 |         select_seen = False
105 | 
106 |         for token in parsed.tokens:
107 |             if select_seen and token.ttype is Keyword:
108 |                 break
109 | 
110 |             if select_seen:
111 |                 if isinstance(token, IdentifierList):
112 |                     for identifier in token.get_identifiers():
113 |                         columns.append(self._get_name(identifier))
114 |                 elif isinstance(token, Identifier):
115 |                     columns.append(self._get_name(token))
116 |                 elif token.ttype is None and token.value not in (',', ' '):
117 |                     columns.append(token.value)
118 | 
119 |             if token.ttype is DML and token.value.upper() == 'SELECT':
120 |                 select_seen = True
121 | 
122 |         return [c.strip() for c in columns if c.strip()]
123 | 
124 |     def _get_name(self, identifier) -> str:
125 |         """Get the name from an identifier, falling back to its raw text."""
126 |         if hasattr(identifier, 'get_name'):
127 |             # get_name() returns None for unnamed tokens such as '*'
128 |             return identifier.get_name() or str(identifier)
129 |         return str(identifier)
130 | 
131 |     def _has_subquery(self, parsed) -> bool:
132 |         """Check if query contains subqueries."""
133 |         query_str = str(parsed).upper()
134 |         # Simple case-insensitive check for nested SELECT
135 |         return query_str.count('SELECT') > 1
136 | 
137 |     def _has_union(self, query: str) -> bool:
138 |         """Check if query contains UNION."""
139 |         return bool(re.search(r'\bUNION\b', query, re.IGNORECASE))
140 | 
141 |     def _has_join(self, parsed) -> bool:
142 |         """Check if query contains JOIN."""
143 |         for token in parsed.tokens:
144 |             if token.ttype is Keyword and 'JOIN' in token.value.upper():
145 |                 return True
146 |         return False
147 | 
148 |     def _has_where(self, parsed) -> bool:
149 |         """Check if query has WHERE clause."""
150 |         for token in parsed.tokens:
151 |             if token.ttype is Keyword and token.value.upper() == 'WHERE':
152 |                 return True
153 |         return False
154 | 
155 |     def _has_having(self, parsed) -> bool:
156 |         """Check if query has HAVING clause."""
157 |         for token in parsed.tokens:
158 |             if token.ttype is Keyword and token.value.upper() == 'HAVING':
159 |                 return True
160 |         return False
161 | 
162 |     def _has_order_by(self, parsed) -> bool:
163 |         """Check if query has ORDER BY clause."""
164 |         query_str = str(parsed).upper()
165 |         return 'ORDER BY' in query_str
166 | 
167 |     def _has_group_by(self, parsed) -> bool:
168 |         """Check if query has GROUP BY clause."""
169 |         query_str = str(parsed).upper()
170 |         return 'GROUP BY' in query_str
171 | 
172 |     def _calculate_complexity(self, parsed) -> int:
173 |         """
174 |         Calculate query complexity score.
175 | 
176 |         Higher scores indicate more complex queries that might
177 |         need additional scrutiny.
178 |         """
179 |         score = 1  # Base score
180 | 
181 |         # Add complexity for various features
182 |         if self._has_subquery(parsed):
183 |             score += 3
184 |         if self._has_union(str(parsed)):
185 |             score += 2
186 |         if self._has_join(parsed):
187 |             score += 2
188 |         if self._has_where(parsed):
189 |             score += 1
190 |         if self._has_group_by(parsed):
191 |             score += 2
192 |         if self._has_having(parsed):
193 |             score += 2
194 |         if self._has_order_by(parsed):
195 |             score += 1
196 | 
197 |         # Add complexity for number of tables
198 |         tables = self._extract_tables(parsed)
199 |         if len(tables) > 1:
200 |             score += len(tables) - 1
201 | 
202 |         return score
203 | 
204 |     def validate_structure(self, query: str) -> Optional[str]:
205 |         """
206 |         Validate query structure and return error if invalid.
207 | 
208 |         Args:
209 |             query: SQL query string
210 | 
211 |         Returns:
212 |             Error message if invalid, None if valid
213 |         """
214 |         try:
215 |             parsed = sqlparse.parse(query)
216 |             if not parsed:
217 |                 return "Empty or invalid query"
218 | 
219 |             # Check for multiple statements
220 |             if len(parsed) > 1:
221 |                 return "Multiple statements not allowed"
222 | 
223 |             # Get query type
224 |             query_type = self._get_query_type(parsed[0])
225 |             if query_type == "UNKNOWN":
226 |                 return "Unknown query type"
227 | 
228 |             return None
229 | 
230 |         except Exception as e:
231 |             return f"Query parsing error: {str(e)}"
--------------------------------------------------------------------------------
/cognidb/security/access_control.py:
--------------------------------------------------------------------------------
1 | """Access control and permissions management."""
2 | 
3 | from typing import Any, Dict, List, Set, Optional
4 | from dataclasses import dataclass, field
5 | from enum import Enum, auto
6 | from ..core.exceptions import SecurityError
7 | 
8 | 
9 | class Permission(Enum):
10 |     """Database permissions."""
11 |     SELECT = auto()
12 |     INSERT = auto()
13 |     UPDATE = auto()
14 |     DELETE = auto()
15 |     CREATE = auto()
16 |     DROP = auto()
17 |     ALTER = auto()
18 |     EXECUTE = auto()
19 | 
20 | 
21 | @dataclass
22 | class TablePermissions:
23 |     """Permissions for a specific table."""
24 |     table_name: str
25 |     allowed_operations: Set[Permission] = field(default_factory=set)
26 |     allowed_columns: Optional[Set[str]] = None  # None means all columns
27 |     row_filter: Optional[str] = None  # SQL condition for row-level security
28 | 
29 |     def can_access_column(self, column: str) -> bool:
30 |         """Check if column access is allowed."""
31 |         if self.allowed_columns is None:
32 |             return True
33 |         return column in self.allowed_columns
34 | 
35 |     def can_perform_operation(self, operation: Permission) -> bool:
36 |         """Check if operation is allowed."""
37 |         return operation in self.allowed_operations
38 | 
39 | 
40 | @dataclass
41 | class UserPermissions:
42 |     """User's database permissions."""
43 |     user_id: str
44 |     is_admin: bool = False
45 |     table_permissions: Dict[str, TablePermissions] = field(default_factory=dict)
46 |     global_permissions: Set[Permission] = field(default_factory=set)
47 |     max_rows_per_query: int = 10000
48 |     max_execution_time: int = 30  # seconds
49 |     allowed_schemas: Set[str] = field(default_factory=set)
50 | 
51 |     def add_table_permission(self, table_perm: TablePermissions):
52 |         """Add permissions for a table."""
53 |         self.table_permissions[table_perm.table_name] = table_perm
54 | 
55 |     def can_access_table(self, table: str) -> bool:
56 |         """Check if user can access table."""
57 |         if self.is_admin:
58 |             return True
59 |         return table in self.table_permissions
60 | 
61 |     def can_perform_operation_on_table(self, table: str, operation: Permission) -> bool:
62 |         """Check if user can perform operation on table."""
63 |         if self.is_admin:
64 |             return True
65 |         if table not in self.table_permissions:
66 |             return False
67 |         return self.table_permissions[table].can_perform_operation(operation)
68 | 
69 | 
70 | class AccessController:
71 |     """
72 |     Controls access to database resources.
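Unknown users fall back to a conservative read-only default profile (SELECT only, 1,000 rows, 10-second timeout).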
73 | 74 | Implements: 75 | - Table-level permissions 76 | - Column-level permissions 77 | - Row-level security 78 | - Operation restrictions 79 | - Resource limits 80 | """ 81 | 82 | def __init__(self): 83 | """Initialize access controller.""" 84 | self.users: Dict[str, UserPermissions] = {} 85 | self.default_permissions = UserPermissions( 86 | user_id="default", 87 | global_permissions={Permission.SELECT}, 88 | max_rows_per_query=1000, 89 | max_execution_time=10 90 | ) 91 | 92 | def add_user(self, user_permissions: UserPermissions): 93 | """Add user with permissions.""" 94 | self.users[user_permissions.user_id] = user_permissions 95 | 96 | def get_user_permissions(self, user_id: str) -> UserPermissions: 97 | """Get user permissions or default if not found.""" 98 | return self.users.get(user_id, self.default_permissions) 99 | 100 | def check_table_access(self, user_id: str, tables: List[str]) -> None: 101 | """ 102 | Check if user can access all tables. 103 | 104 | Raises: 105 | SecurityError: If access denied 106 | """ 107 | permissions = self.get_user_permissions(user_id) 108 | 109 | for table in tables: 110 | if not permissions.can_access_table(table): 111 | raise SecurityError(f"Access denied to table: {table}") 112 | 113 | def check_column_access(self, user_id: str, table: str, columns: List[str]) -> None: 114 | """ 115 | Check if user can access columns. 116 | 117 | Raises: 118 | SecurityError: If access denied 119 | """ 120 | permissions = self.get_user_permissions(user_id) 121 | 122 | if not permissions.can_access_table(table): 123 | raise SecurityError(f"Access denied to table: {table}") 124 | 125 | if permissions.is_admin: 126 | return 127 | 128 | if table in permissions.table_permissions: 129 | table_perm = permissions.table_permissions[table] 130 | for column in columns: 131 | if not table_perm.can_access_column(column): 132 | raise SecurityError(f"Access denied to column: {table}.{column}") 133 | 134 | def check_operation(self, user_id: str, operation: Permission, tables: List[str]) -> None: 135 | """ 136 | Check if user can perform operation. 
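
        Note: global permissions are checked first and short-circuit the
        table loop, so a user holding an operation in global_permissions
        passes this check for any table.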
137 | 138 | Raises: 139 | SecurityError: If operation not allowed 140 | """ 141 | permissions = self.get_user_permissions(user_id) 142 | 143 | # Check global permissions first 144 | if operation in permissions.global_permissions: 145 | return 146 | 147 | # Check table-specific permissions 148 | for table in tables: 149 | if not permissions.can_perform_operation_on_table(table, operation): 150 | raise SecurityError( 151 | f"Operation {operation.name} not allowed on table: {table}" 152 | ) 153 | 154 | def get_row_filters(self, user_id: str, table: str) -> Optional[str]: 155 | """Get row-level security filters for table.""" 156 | permissions = self.get_user_permissions(user_id) 157 | 158 | if permissions.is_admin: 159 | return None 160 | 161 | if table in permissions.table_permissions: 162 | return permissions.table_permissions[table].row_filter 163 | 164 | return None 165 | 166 | def get_resource_limits(self, user_id: str) -> Dict[str, int]: 167 | """Get resource limits for user.""" 168 | permissions = self.get_user_permissions(user_id) 169 | 170 | return { 171 | 'max_rows': permissions.max_rows_per_query, 172 | 'max_execution_time': permissions.max_execution_time 173 | } 174 | 175 | def create_read_only_user(self, user_id: str, allowed_tables: List[str]) -> UserPermissions: 176 | """Create a read-only user with access to specific tables.""" 177 | user = UserPermissions( 178 | user_id=user_id, 179 | global_permissions={Permission.SELECT}, 180 | max_rows_per_query=5000, 181 | max_execution_time=20 182 | ) 183 | 184 | for table in allowed_tables: 185 | user.add_table_permission( 186 | TablePermissions( 187 | table_name=table, 188 | allowed_operations={Permission.SELECT} 189 | ) 190 | ) 191 | 192 | self.add_user(user) 193 | return user 194 | 195 | def create_restricted_user(self, 196 | user_id: str, 197 | table_permissions_dict: Dict[str, Dict[str, Any]]) -> UserPermissions: 198 | """ 199 | Create user with specific table permissions. 
200 | 201 | Args: 202 | user_id: User identifier 203 | table_permissions_dict: Dictionary of table permissions 204 | { 205 | 'users': { 206 | 'operations': ['SELECT'], 207 | 'columns': ['id', 'name', 'email'], 208 | 'row_filter': "department = 'sales'" 209 | } 210 | } 211 | """ 212 | user = UserPermissions(user_id=user_id) 213 | 214 | for table, perms in table_permissions_dict.items(): 215 | operations = { 216 | Permission[op] for op in perms.get('operations', ['SELECT']) 217 | } 218 | 219 | table_perm = TablePermissions( 220 | table_name=table, 221 | allowed_operations=operations, 222 | allowed_columns=set(perms['columns']) if 'columns' in perms else None, 223 | row_filter=perms.get('row_filter') 224 | ) 225 | 226 | user.add_table_permission(table_perm) 227 | 228 | self.add_user(user) 229 | return user -------------------------------------------------------------------------------- /cognidb/config/settings.py: -------------------------------------------------------------------------------- 1 | """Settings and configuration classes.""" 2 | 3 | from dataclasses import dataclass, field 4 | from typing import Dict, List, Optional, Any 5 | from enum import Enum, auto 6 | import os 7 | from pathlib import Path 8 | 9 | 10 | class DatabaseType(Enum): 11 | """Supported database types.""" 12 | MYSQL = "mysql" 13 | POSTGRESQL = "postgresql" 14 | MONGODB = "mongodb" 15 | DYNAMODB = "dynamodb" 16 | SQLITE = "sqlite" 17 | 18 | 19 | class LLMProvider(Enum): 20 | """Supported LLM providers.""" 21 | OPENAI = "openai" 22 | ANTHROPIC = "anthropic" 23 | AZURE_OPENAI = "azure_openai" 24 | LOCAL = "local" 25 | HUGGINGFACE = "huggingface" 26 | 27 | 28 | class CacheProvider(Enum): 29 | """Supported cache providers.""" 30 | IN_MEMORY = "in_memory" 31 | REDIS = "redis" 32 | MEMCACHED = "memcached" 33 | DISK = "disk" 34 | 35 | 36 | @dataclass 37 | class DatabaseConfig: 38 | """Database configuration.""" 39 | type: DatabaseType 40 | host: str 41 | port: int 42 | database: str 43 | username: Optional[str] = None 44 | password: Optional[str] = None 45 | 46 | # Connection pool settings 47 | pool_size: int = 5 48 | max_overflow: int = 10 49 | pool_timeout: int = 30 50 | pool_recycle: int = 3600 51 | 52 | # SSL/TLS settings 53 | ssl_enabled: bool = False 54 | ssl_ca_cert: Optional[str] = None 55 | ssl_client_cert: Optional[str] = None 56 | ssl_client_key: Optional[str] = None 57 | 58 | # Query settings 59 | query_timeout: int = 30 # seconds 60 | max_result_size: int = 10000 # rows 61 | 62 | # Additional options 63 | options: Dict[str, Any] = field(default_factory=dict) 64 | 65 | def get_connection_string(self) -> str: 66 | """Generate connection string (without password).""" 67 | if self.type == DatabaseType.SQLITE: 68 | return f"sqlite:///{self.database}" 69 | 70 | auth = "" 71 | if self.username: 72 | auth = f"{self.username}:***@" 73 | 74 | return f"{self.type.value}://{auth}{self.host}:{self.port}/{self.database}" 75 | 76 | 77 | @dataclass 78 | class LLMConfig: 79 | """LLM configuration.""" 80 | provider: LLMProvider 81 | api_key: Optional[str] = None 82 | 83 | # Model settings 84 | model_name: str = "gpt-4" 85 | temperature: float = 0.1 86 | max_tokens: int = 1000 87 | timeout: int = 30 88 | 89 | # Cost control 90 | max_tokens_per_query: int = 2000 91 | max_queries_per_minute: int = 60 92 | max_cost_per_day: float = 100.0 93 | 94 | # Prompt settings 95 | system_prompt: Optional[str] = None 96 | few_shot_examples: List[Dict[str, str]] = field(default_factory=list) 97 | 98 | # Provider-specific settings 99 | 
azure_endpoint: Optional[str] = None 100 | azure_deployment: Optional[str] = None 101 | huggingface_model_id: Optional[str] = None 102 | local_model_path: Optional[str] = None 103 | 104 | # Advanced settings 105 | enable_function_calling: bool = True 106 | enable_streaming: bool = False 107 | retry_attempts: int = 3 108 | retry_delay: float = 1.0 109 | 110 | 111 | @dataclass 112 | class CacheConfig: 113 | """Cache configuration.""" 114 | provider: CacheProvider 115 | 116 | # TTL settings (in seconds) 117 | query_result_ttl: int = 3600 # 1 hour 118 | schema_ttl: int = 86400 # 24 hours 119 | llm_response_ttl: int = 7200 # 2 hours 120 | 121 | # Size limits 122 | max_cache_size_mb: int = 100 123 | max_entry_size_mb: int = 10 124 | eviction_policy: str = "lru" # lru, lfu, ttl 125 | 126 | # Redis settings 127 | redis_host: str = "localhost" 128 | redis_port: int = 6379 129 | redis_password: Optional[str] = None 130 | redis_db: int = 0 131 | redis_ssl: bool = False 132 | 133 | # Disk cache settings 134 | disk_cache_path: str = str(Path.home() / ".cognidb" / "cache") 135 | 136 | # Performance settings 137 | enable_compression: bool = True 138 | enable_async_writes: bool = True 139 | 140 | 141 | @dataclass 142 | class SecurityConfig: 143 | """Security configuration.""" 144 | # Query validation 145 | allow_only_select: bool = True 146 | max_query_complexity: int = 10 147 | allow_subqueries: bool = False 148 | allow_unions: bool = False 149 | 150 | # Rate limiting 151 | enable_rate_limiting: bool = True 152 | rate_limit_per_minute: int = 100 153 | rate_limit_per_hour: int = 1000 154 | 155 | # Access control 156 | enable_access_control: bool = True 157 | default_user_permissions: List[str] = field(default_factory=lambda: ["SELECT"]) 158 | require_authentication: bool = False 159 | 160 | # Audit logging 161 | enable_audit_logging: bool = True 162 | audit_log_path: str = str(Path.home() / ".cognidb" / "audit.log") 163 | log_query_results: bool = False 164 | 165 | # Encryption 166 | encrypt_cache: bool = True 167 | encrypt_logs: bool = True 168 | encryption_key: Optional[str] = None # Should be loaded from secrets 169 | 170 | # Network security 171 | allowed_ip_ranges: List[str] = field(default_factory=list) 172 | require_ssl: bool = True 173 | 174 | 175 | @dataclass 176 | class Settings: 177 | """Main settings container.""" 178 | # Core configurations 179 | database: DatabaseConfig 180 | llm: LLMConfig 181 | cache: CacheConfig 182 | security: SecurityConfig 183 | 184 | # Application settings 185 | app_name: str = "CogniDB" 186 | environment: str = "production" 187 | debug: bool = False 188 | log_level: str = "INFO" 189 | 190 | # Paths 191 | data_dir: str = str(Path.home() / ".cognidb") 192 | log_dir: str = str(Path.home() / ".cognidb" / "logs") 193 | 194 | # Feature flags 195 | enable_natural_language: bool = True 196 | enable_query_explanation: bool = True 197 | enable_query_optimization: bool = True 198 | enable_auto_indexing: bool = False 199 | 200 | # Monitoring 201 | enable_metrics: bool = True 202 | metrics_port: int = 9090 203 | enable_tracing: bool = True 204 | tracing_endpoint: Optional[str] = None 205 | 206 | @classmethod 207 | def from_env(cls) -> 'Settings': 208 | """Create settings from environment variables.""" 209 | return cls( 210 | database=DatabaseConfig( 211 | type=DatabaseType(os.getenv('DB_TYPE', 'postgresql')), 212 | host=os.getenv('DB_HOST', 'localhost'), 213 | port=int(os.getenv('DB_PORT', '5432')), 214 | database=os.getenv('DB_NAME', 'cognidb'), 215 | 
username=os.getenv('DB_USER'), 216 | password=os.getenv('DB_PASSWORD') 217 | ), 218 | llm=LLMConfig( 219 | provider=LLMProvider(os.getenv('LLM_PROVIDER', 'openai')), 220 | api_key=os.getenv('LLM_API_KEY'), 221 | model_name=os.getenv('LLM_MODEL', 'gpt-4') 222 | ), 223 | cache=CacheConfig( 224 | provider=CacheProvider(os.getenv('CACHE_PROVIDER', 'in_memory')) 225 | ), 226 | security=SecurityConfig( 227 | allow_only_select=os.getenv('SECURITY_SELECT_ONLY', 'true').lower() == 'true' 228 | ), 229 | environment=os.getenv('ENVIRONMENT', 'production'), 230 | debug=os.getenv('DEBUG', 'false').lower() == 'true' 231 | ) 232 | 233 | def validate(self) -> List[str]: 234 | """Validate settings and return list of errors.""" 235 | errors = [] 236 | 237 | # Database validation 238 | if not self.database.host: 239 | errors.append("Database host is required") 240 | if self.database.port <= 0 or self.database.port > 65535: 241 | errors.append("Invalid database port") 242 | 243 | # LLM validation 244 | if self.llm.provider != LLMProvider.LOCAL and not self.llm.api_key: 245 | errors.append("LLM API key is required for non-local providers") 246 | if self.llm.temperature < 0 or self.llm.temperature > 2: 247 | errors.append("LLM temperature must be between 0 and 2") 248 | 249 | # Security validation 250 | if self.security.encrypt_cache and not self.security.encryption_key: 251 | errors.append("Encryption key required when encryption is enabled") 252 | 253 | # Path validation 254 | for path_attr in ['data_dir', 'log_dir']: 255 | path = getattr(self, path_attr) 256 | try: 257 | Path(path).mkdir(parents=True, exist_ok=True) 258 | except Exception as e: 259 | errors.append(f"Cannot create {path_attr}: {e}") 260 | 261 | return errors -------------------------------------------------------------------------------- /cognidb/core/query_intent.py: -------------------------------------------------------------------------------- 1 | """Query intent representation - database agnostic query structure.""" 2 | 3 | from dataclasses import dataclass, field 4 | from enum import Enum, auto 5 | from typing import List, Optional, Dict, Any, Union 6 | 7 | 8 | class QueryType(Enum): 9 | """Supported query types.""" 10 | SELECT = auto() 11 | AGGREGATE = auto() 12 | COUNT = auto() 13 | DISTINCT = auto() 14 | 15 | 16 | class ComparisonOperator(Enum): 17 | """Comparison operators for conditions.""" 18 | EQ = "=" 19 | NE = "!=" 20 | GT = ">" 21 | GTE = ">=" 22 | LT = "<" 23 | LTE = "<=" 24 | IN = "IN" 25 | NOT_IN = "NOT IN" 26 | LIKE = "LIKE" 27 | NOT_LIKE = "NOT LIKE" 28 | IS_NULL = "IS NULL" 29 | IS_NOT_NULL = "IS NOT NULL" 30 | BETWEEN = "BETWEEN" 31 | 32 | 33 | class LogicalOperator(Enum): 34 | """Logical operators for combining conditions.""" 35 | AND = "AND" 36 | OR = "OR" 37 | 38 | 39 | class AggregateFunction(Enum): 40 | """Supported aggregate functions.""" 41 | SUM = "SUM" 42 | AVG = "AVG" 43 | COUNT = "COUNT" 44 | MIN = "MIN" 45 | MAX = "MAX" 46 | GROUP_CONCAT = "GROUP_CONCAT" 47 | 48 | 49 | class JoinType(Enum): 50 | """Types of joins.""" 51 | INNER = "INNER" 52 | LEFT = "LEFT" 53 | RIGHT = "RIGHT" 54 | FULL = "FULL" 55 | 56 | 57 | @dataclass 58 | class Column: 59 | """Represents a column reference.""" 60 | name: str 61 | table: Optional[str] = None 62 | alias: Optional[str] = None 63 | 64 | def __str__(self) -> str: 65 | if self.table: 66 | return f"{self.table}.{self.name}" 67 | return self.name 68 | 69 | 70 | @dataclass 71 | class Condition: 72 | """Represents a query condition.""" 73 | column: Column 74 | operator: 
ComparisonOperator 75 | value: Any 76 | 77 | def __post_init__(self): 78 | """Validate condition parameters.""" 79 | if self.operator == ComparisonOperator.BETWEEN: 80 | if not isinstance(self.value, (list, tuple)) or len(self.value) != 2: 81 | raise ValueError("BETWEEN operator requires a list/tuple of two values") 82 | elif self.operator in (ComparisonOperator.IN, ComparisonOperator.NOT_IN): 83 | if not isinstance(self.value, (list, tuple, set)): 84 | raise ValueError(f"{self.operator.value} operator requires a list/tuple/set") 85 | 86 | 87 | @dataclass 88 | class ConditionGroup: 89 | """Group of conditions with logical operator.""" 90 | conditions: List[Union[Condition, 'ConditionGroup']] 91 | operator: LogicalOperator = LogicalOperator.AND 92 | 93 | def add_condition(self, condition: Union[Condition, 'ConditionGroup']): 94 | """Add a condition to the group.""" 95 | self.conditions.append(condition) 96 | 97 | 98 | @dataclass 99 | class JoinCondition: 100 | """Represents a join between tables.""" 101 | join_type: JoinType 102 | left_table: str 103 | right_table: str 104 | left_column: str 105 | right_column: str 106 | additional_conditions: Optional[ConditionGroup] = None 107 | 108 | 109 | @dataclass 110 | class Aggregation: 111 | """Represents an aggregation operation.""" 112 | function: AggregateFunction 113 | column: Column 114 | alias: Optional[str] = None 115 | 116 | def __str__(self) -> str: 117 | return f"{self.function.value}({self.column})" 118 | 119 | 120 | @dataclass 121 | class OrderBy: 122 | """Represents ordering specification.""" 123 | column: Column 124 | ascending: bool = True 125 | 126 | 127 | @dataclass 128 | class QueryIntent: 129 | """ 130 | Database-agnostic representation of a query intent. 131 | 132 | This is the core abstraction that allows CogniDB to work with 133 | multiple database types by translating this intent into 134 | database-specific queries. 
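
    An illustrative intent, built only from types defined in this module:

        intent = QueryIntent(query_type=QueryType.SELECT, tables=["orders"])
        intent.add_condition(
            Condition(Column("status"), ComparisonOperator.EQ, "shipped")
        )
        intent.set_limit(10)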
135 | """ 136 | query_type: QueryType 137 | tables: List[str] 138 | columns: List[Column] = field(default_factory=list) 139 | conditions: Optional[ConditionGroup] = None 140 | joins: List[JoinCondition] = field(default_factory=list) 141 | aggregations: List[Aggregation] = field(default_factory=list) 142 | group_by: List[Column] = field(default_factory=list) 143 | having: Optional[ConditionGroup] = None 144 | order_by: List[OrderBy] = field(default_factory=list) 145 | limit: Optional[int] = None 146 | offset: Optional[int] = None 147 | distinct: bool = False 148 | 149 | # Metadata for optimization and caching 150 | natural_language_query: Optional[str] = None 151 | estimated_cost: Optional[float] = None 152 | cache_ttl: Optional[int] = None # seconds 153 | 154 | def __post_init__(self): 155 | """Validate query intent.""" 156 | if not self.tables: 157 | raise ValueError("At least one table must be specified") 158 | 159 | if self.query_type == QueryType.SELECT and not self.columns: 160 | # Default to all columns if none specified 161 | self.columns = [Column("*")] 162 | 163 | if self.aggregations and not self.group_by: 164 | # Check if all columns are aggregated 165 | non_aggregated = [ 166 | col for col in self.columns 167 | if not any(agg.column.name == col.name for agg in self.aggregations) 168 | ] 169 | if non_aggregated: 170 | raise ValueError( 171 | "Non-aggregated columns must be in GROUP BY clause" 172 | ) 173 | 174 | if self.having and not self.group_by: 175 | raise ValueError("HAVING clause requires GROUP BY") 176 | 177 | def add_column(self, column: Union[str, Column]): 178 | """Add a column to select.""" 179 | if isinstance(column, str): 180 | column = Column(column) 181 | self.columns.append(column) 182 | 183 | def add_condition(self, condition: Condition): 184 | """Add a WHERE condition.""" 185 | if not self.conditions: 186 | self.conditions = ConditionGroup([]) 187 | self.conditions.add_condition(condition) 188 | 189 | def add_join(self, join: JoinCondition): 190 | """Add a join condition.""" 191 | self.joins.append(join) 192 | 193 | def add_aggregation(self, aggregation: Aggregation): 194 | """Add an aggregation.""" 195 | self.aggregations.append(aggregation) 196 | self.query_type = QueryType.AGGREGATE 197 | 198 | def set_limit(self, limit: int, offset: int = 0): 199 | """Set result limit and offset.""" 200 | if limit <= 0: 201 | raise ValueError("Limit must be positive") 202 | if offset < 0: 203 | raise ValueError("Offset must be non-negative") 204 | self.limit = limit 205 | self.offset = offset 206 | 207 | def to_dict(self) -> Dict[str, Any]: 208 | """Convert to dictionary for serialization.""" 209 | return { 210 | 'query_type': self.query_type.name, 211 | 'tables': self.tables, 212 | 'columns': [str(col) for col in self.columns], 213 | 'conditions': self._condition_group_to_dict(self.conditions) if self.conditions else None, 214 | 'joins': [self._join_to_dict(j) for j in self.joins], 215 | 'aggregations': [str(agg) for agg in self.aggregations], 216 | 'group_by': [str(col) for col in self.group_by], 217 | 'having': self._condition_group_to_dict(self.having) if self.having else None, 218 | 'order_by': [{'column': str(ob.column), 'asc': ob.ascending} for ob in self.order_by], 219 | 'limit': self.limit, 220 | 'offset': self.offset, 221 | 'distinct': self.distinct, 222 | 'natural_language_query': self.natural_language_query 223 | } 224 | 225 | def _condition_group_to_dict(self, group: ConditionGroup) -> Dict[str, Any]: 226 | """Convert condition group to dict.""" 227 | return { 
228 | 'operator': group.operator.value, 229 | 'conditions': [ 230 | self._condition_to_dict(c) if isinstance(c, Condition) 231 | else self._condition_group_to_dict(c) 232 | for c in group.conditions 233 | ] 234 | } 235 | 236 | def _condition_to_dict(self, condition: Condition) -> Dict[str, Any]: 237 | """Convert condition to dict.""" 238 | return { 239 | 'column': str(condition.column), 240 | 'operator': condition.operator.value, 241 | 'value': condition.value 242 | } 243 | 244 | def _join_to_dict(self, join: JoinCondition) -> Dict[str, Any]: 245 | """Convert join to dict.""" 246 | return { 247 | 'type': join.join_type.value, 248 | 'left_table': join.left_table, 249 | 'right_table': join.right_table, 250 | 'left_column': join.left_column, 251 | 'right_column': join.right_column 252 | } -------------------------------------------------------------------------------- /cognidb/security/sanitizer.py: -------------------------------------------------------------------------------- 1 | """Input sanitization utilities.""" 2 | 3 | import re 4 | import html 5 | from typing import Any, Dict, List, Union 6 | 7 | 8 | class InputSanitizer: 9 | """ 10 | Comprehensive input sanitizer for all user inputs. 11 | 12 | Provides multiple sanitization strategies: 13 | 1. SQL identifiers (tables, columns) 14 | 2. String values 15 | 3. Numeric values 16 | 4. Natural language queries 17 | """ 18 | 19 | # Characters allowed in natural language queries 20 | ALLOWED_NL_CHARS = re.compile(r'[^a-zA-Z0-9\s\-_.,!?\'"\(\)%$#@]') 21 | 22 | # Maximum lengths for various inputs 23 | MAX_NATURAL_LANGUAGE_LENGTH = 500 24 | MAX_IDENTIFIER_LENGTH = 64 25 | MAX_STRING_VALUE_LENGTH = 1000 26 | 27 | @staticmethod 28 | def sanitize_natural_language(query: str) -> str: 29 | """ 30 | Sanitize natural language query. 31 | 32 | Args: 33 | query: Raw natural language query 34 | 35 | Returns: 36 | Sanitized query safe for LLM processing 37 | """ 38 | if not query: 39 | return "" 40 | 41 | # Truncate if too long 42 | query = query[:InputSanitizer.MAX_NATURAL_LANGUAGE_LENGTH] 43 | 44 | # Remove potentially harmful characters while preserving readability 45 | query = InputSanitizer.ALLOWED_NL_CHARS.sub(' ', query) 46 | 47 | # Normalize whitespace 48 | query = ' '.join(query.split()) 49 | 50 | # HTML escape for additional safety 51 | query = html.escape(query, quote=False) 52 | 53 | return query.strip() 54 | 55 | @staticmethod 56 | def sanitize_identifier(identifier: str) -> str: 57 | """ 58 | Sanitize database identifier (table/column name). 59 | 60 | Args: 61 | identifier: Raw identifier 62 | 63 | Returns: 64 | Sanitized identifier 65 | 66 | Raises: 67 | ValueError: If identifier cannot be sanitized 68 | """ 69 | if not identifier: 70 | raise ValueError("Identifier cannot be empty") 71 | 72 | # Remove any quotes or special characters 73 | identifier = re.sub(r'[^a-zA-Z0-9_]', '', identifier) 74 | 75 | # Ensure it starts with a letter or underscore 76 | if not re.match(r'^[a-zA-Z_]', identifier): 77 | identifier = f"_{identifier}" 78 | 79 | # Truncate if too long 80 | identifier = identifier[:InputSanitizer.MAX_IDENTIFIER_LENGTH] 81 | 82 | if not identifier: 83 | raise ValueError("Identifier contains no valid characters") 84 | 85 | return identifier 86 | 87 | @staticmethod 88 | def sanitize_string_value(value: str, allow_wildcards: bool = False) -> str: 89 | """ 90 | Sanitize string value for use in queries. 
91 | 92 | Args: 93 | value: Raw string value 94 | allow_wildcards: Whether to allow SQL wildcards (% and _) 95 | 96 | Returns: 97 | Sanitized string value 98 | """ 99 | if not isinstance(value, str): 100 | return str(value) 101 | 102 | # Truncate if too long 103 | value = value[:InputSanitizer.MAX_STRING_VALUE_LENGTH] 104 | 105 | # Remove null bytes 106 | value = value.replace('\x00', '') 107 | 108 | # Handle SQL wildcards 109 | if not allow_wildcards: 110 | value = value.replace('%', '\\%').replace('_', '\\_') 111 | 112 | # Note: Actual SQL escaping should be done by parameterized queries 113 | # This is just an additional safety layer 114 | 115 | return value 116 | 117 | @staticmethod 118 | def sanitize_numeric_value(value: Union[int, float, str]) -> Union[int, float, None]: 119 | """ 120 | Sanitize numeric value. 121 | 122 | Args: 123 | value: Raw numeric value 124 | 125 | Returns: 126 | Sanitized numeric value or None if invalid 127 | """ 128 | if isinstance(value, (int, float)): 129 | return value 130 | 131 | if isinstance(value, str): 132 | try: 133 | # Try to parse as float first 134 | if '.' in value: 135 | return float(value) 136 | else: 137 | return int(value) 138 | except ValueError: 139 | return None 140 | 141 | return None 142 | 143 | @staticmethod 144 | def sanitize_list_value(values: List[Any], sanitize_func=None) -> List[Any]: 145 | """ 146 | Sanitize a list of values. 147 | 148 | Args: 149 | values: List of raw values 150 | sanitize_func: Function to apply to each value 151 | 152 | Returns: 153 | List of sanitized values 154 | """ 155 | if not isinstance(values, (list, tuple, set)): 156 | raise ValueError("Input must be a list, tuple, or set") 157 | 158 | if sanitize_func is None: 159 | sanitize_func = InputSanitizer.sanitize_string_value 160 | 161 | sanitized = [] 162 | for value in values: 163 | try: 164 | sanitized_value = sanitize_func(value) 165 | if sanitized_value is not None: 166 | sanitized.append(sanitized_value) 167 | except Exception: 168 | # Skip invalid values 169 | continue 170 | 171 | return sanitized 172 | 173 | @staticmethod 174 | def sanitize_dict_value(data: Dict[str, Any]) -> Dict[str, Any]: 175 | """ 176 | Recursively sanitize dictionary values. 177 | 178 | Args: 179 | data: Dictionary with raw values 180 | 181 | Returns: 182 | Dictionary with sanitized values 183 | """ 184 | if not isinstance(data, dict): 185 | raise ValueError("Input must be a dictionary") 186 | 187 | sanitized = {} 188 | for key, value in data.items(): 189 | # Sanitize key 190 | safe_key = InputSanitizer.sanitize_identifier(key) 191 | 192 | # Sanitize value based on type 193 | if isinstance(value, str): 194 | sanitized[safe_key] = InputSanitizer.sanitize_string_value(value) 195 | elif isinstance(value, (int, float)): 196 | sanitized[safe_key] = InputSanitizer.sanitize_numeric_value(value) 197 | elif isinstance(value, (list, tuple, set)): 198 | sanitized[safe_key] = InputSanitizer.sanitize_list_value(value) 199 | elif isinstance(value, dict): 200 | sanitized[safe_key] = InputSanitizer.sanitize_dict_value(value) 201 | elif value is None: 202 | sanitized[safe_key] = None 203 | else: 204 | # Convert to string and sanitize 205 | sanitized[safe_key] = InputSanitizer.sanitize_string_value(str(value)) 206 | 207 | return sanitized 208 | 209 | @staticmethod 210 | def escape_like_pattern(pattern: str) -> str: 211 | """ 212 | Escape special characters in LIKE patterns. 
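
        Example: escape_like_pattern("50%_off") returns "50\%\_off", so the
        literal percent sign and underscore survive a LIKE comparison.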
213 | 214 | Args: 215 | pattern: Raw LIKE pattern 216 | 217 | Returns: 218 | Escaped pattern 219 | """ 220 | # Escape LIKE special characters 221 | pattern = pattern.replace('\\', '\\\\') 222 | pattern = pattern.replace('%', '\\%') 223 | pattern = pattern.replace('_', '\\_') 224 | return pattern 225 | 226 | @staticmethod 227 | def validate_and_sanitize_limit(limit: Any) -> int: 228 | """ 229 | Validate and sanitize LIMIT value. 230 | 231 | Args: 232 | limit: Raw limit value 233 | 234 | Returns: 235 | Sanitized limit value 236 | 237 | Raises: 238 | ValueError: If limit is invalid 239 | """ 240 | try: 241 | limit = int(limit) 242 | if limit < 1: 243 | raise ValueError("Limit must be positive") 244 | if limit > 10000: # Reasonable maximum 245 | return 10000 246 | return limit 247 | except (TypeError, ValueError): 248 | raise ValueError("Invalid limit value") 249 | 250 | @staticmethod 251 | def validate_and_sanitize_offset(offset: Any) -> int: 252 | """ 253 | Validate and sanitize OFFSET value. 254 | 255 | Args: 256 | offset: Raw offset value 257 | 258 | Returns: 259 | Sanitized offset value 260 | 261 | Raises: 262 | ValueError: If offset is invalid 263 | """ 264 | try: 265 | offset = int(offset) 266 | if offset < 0: 267 | raise ValueError("Offset must be non-negative") 268 | if offset > 1000000: # Reasonable maximum 269 | raise ValueError("Offset too large") 270 | return offset 271 | except (TypeError, ValueError): 272 | raise ValueError("Invalid offset value") -------------------------------------------------------------------------------- /cognidb/drivers/mysql_driver.py: -------------------------------------------------------------------------------- 1 | """Secure MySQL driver implementation.""" 2 | 3 | import time 4 | import logging 5 | from typing import Dict, List, Any, Optional 6 | import mysql.connector 7 | from mysql.connector import pooling, Error 8 | from .base_driver import BaseDriver 9 | from ..core.exceptions import ConnectionError, ExecutionError 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class MySQLDriver(BaseDriver): 15 | """ 16 | MySQL database driver with security enhancements. 
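
    Construction sketch (config keys mirror DatabaseConfig; host and
    credentials are placeholders):

        driver = MySQLDriver({'host': 'db.internal', 'database': 'mydb',
                              'username': 'app', 'password': 'secret'})
        driver.connect()
        ...  # run queries through the BaseDriver interface
        driver.disconnect()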
17 | 18 | Features: 19 | - Connection pooling 20 | - Parameterized queries only 21 | - SSL/TLS support 22 | - Query timeout enforcement 23 | - Automatic reconnection 24 | """ 25 | 26 | def __init__(self, config: Dict[str, Any]): 27 | """Initialize MySQL driver.""" 28 | super().__init__(config) 29 | self.pool = None 30 | 31 | def connect(self) -> None: 32 | """Establish connection to MySQL database.""" 33 | try: 34 | # Prepare connection config 35 | pool_config = { 36 | 'pool_name': 'cognidb_mysql_pool', 37 | 'pool_size': self.config.get('pool_size', 5), 38 | 'host': self.config['host'], 39 | 'port': self.config.get('port', 3306), 40 | 'database': self.config['database'], 41 | 'user': self.config.get('username'), 42 | 'password': self.config.get('password'), 43 | 'autocommit': False, 44 | 'raise_on_warnings': True, 45 | 'sql_mode': 'TRADITIONAL', 46 | 'time_zone': '+00:00', 47 | 'connect_timeout': self.config.get('connection_timeout', 10) 48 | } 49 | 50 | # SSL configuration 51 | if self.config.get('ssl_enabled'): 52 | ssl_config = {} 53 | if self.config.get('ssl_ca_cert'): 54 | ssl_config['ca'] = self.config['ssl_ca_cert'] 55 | if self.config.get('ssl_client_cert'): 56 | ssl_config['cert'] = self.config['ssl_client_cert'] 57 | if self.config.get('ssl_client_key'): 58 | ssl_config['key'] = self.config['ssl_client_key'] 59 | pool_config['ssl_disabled'] = False 60 | pool_config['ssl_verify_cert'] = True 61 | pool_config['ssl_verify_identity'] = True 62 | pool_config.update(ssl_config) 63 | 64 | # Create connection pool 65 | self.pool = pooling.MySQLConnectionPool(**pool_config) 66 | 67 | # Test connection 68 | self.connection = self.pool.get_connection() 69 | self._connection_time = time.time() 70 | 71 | logger.info(f"Connected to MySQL database: {self.config['database']}") 72 | 73 | except Error as e: 74 | logger.error(f"MySQL connection failed: {str(e)}") 75 | raise ConnectionError(f"Failed to connect to MySQL: {str(e)}") 76 | 77 | def disconnect(self) -> None: 78 | """Close the database connection.""" 79 | if self.connection: 80 | try: 81 | self.connection.close() 82 | logger.info("Disconnected from MySQL database") 83 | except Error as e: 84 | logger.error(f"Error closing connection: {str(e)}") 85 | finally: 86 | self.connection = None 87 | self._connection_time = None 88 | 89 | def _create_connection(self): 90 | """Get connection from pool.""" 91 | if not self.pool: 92 | raise ConnectionError("Connection pool not initialized") 93 | return self.pool.get_connection() 94 | 95 | def _close_connection(self): 96 | """Return connection to pool.""" 97 | if self.connection: 98 | self.connection.close() 99 | 100 | def _execute_with_timeout(self, 101 | query: str, 102 | params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: 103 | """Execute query with timeout.""" 104 | cursor = None 105 | 106 | try: 107 | # Get connection from pool if needed 108 | if not self.connection or not self.connection.is_connected(): 109 | self.connection = self._create_connection() 110 | 111 | # Set query timeout 112 | timeout = self.config.get('query_timeout', 30) 113 | self.connection.cmd_query(f"SET SESSION MAX_EXECUTION_TIME={timeout * 1000}") 114 | 115 | # Create cursor 116 | cursor = self.connection.cursor(dictionary=True, buffered=True) 117 | 118 | # Execute query with parameters 119 | if params: 120 | # Convert dict params to list for MySQL 121 | cursor.execute(query, list(params.values())) 122 | else: 123 | cursor.execute(query) 124 | 125 | # Fetch results 126 | if cursor.description: 127 | results = 
cursor.fetchall() 128 | 129 | # Apply result size limit 130 | max_results = self.config.get('max_result_size', 10000) 131 | if len(results) > max_results: 132 | logger.warning(f"Result truncated from {len(results)} to {max_results} rows") 133 | results = results[:max_results] 134 | 135 | return results 136 | else: 137 | # For non-SELECT queries 138 | self.connection.commit() 139 | return [{'affected_rows': cursor.rowcount}] 140 | 141 | except Error as e: 142 | if self.connection: 143 | self.connection.rollback() 144 | raise ExecutionError(f"Query execution failed: {str(e)}") 145 | finally: 146 | if cursor: 147 | cursor.close() 148 | 149 | def _fetch_schema_impl(self) -> Dict[str, Dict[str, str]]: 150 | """Fetch MySQL schema using INFORMATION_SCHEMA.""" 151 | query = """ 152 | SELECT 153 | TABLE_NAME, 154 | COLUMN_NAME, 155 | DATA_TYPE, 156 | IS_NULLABLE, 157 | COLUMN_KEY, 158 | COLUMN_DEFAULT, 159 | EXTRA 160 | FROM INFORMATION_SCHEMA.COLUMNS 161 | WHERE TABLE_SCHEMA = %s 162 | ORDER BY TABLE_NAME, ORDINAL_POSITION 163 | """ 164 | 165 | cursor = None 166 | try: 167 | cursor = self.connection.cursor(dictionary=True) 168 | cursor.execute(query, (self.config['database'],)) 169 | 170 | schema = {} 171 | for row in cursor: 172 | table_name = row['TABLE_NAME'] 173 | if table_name not in schema: 174 | schema[table_name] = {} 175 | 176 | # Build column info 177 | col_type = row['DATA_TYPE'] 178 | if row['IS_NULLABLE'] == 'NO': 179 | col_type += ' NOT NULL' 180 | if row['COLUMN_KEY'] == 'PRI': 181 | col_type += ' PRIMARY KEY' 182 | if row['EXTRA']: 183 | col_type += f" {row['EXTRA']}" 184 | 185 | schema[table_name][row['COLUMN_NAME']] = col_type 186 | 187 | # Fetch indexes 188 | self._fetch_indexes(schema, cursor) 189 | 190 | return schema 191 | 192 | finally: 193 | if cursor: 194 | cursor.close() 195 | 196 | def _fetch_indexes(self, schema: Dict[str, Dict[str, str]], cursor): 197 | """Fetch index information.""" 198 | query = """ 199 | SELECT 200 | TABLE_NAME, 201 | INDEX_NAME, 202 | GROUP_CONCAT(COLUMN_NAME ORDER BY SEQ_IN_INDEX) as COLUMNS 203 | FROM INFORMATION_SCHEMA.STATISTICS 204 | WHERE TABLE_SCHEMA = %s AND INDEX_NAME != 'PRIMARY' 205 | GROUP BY TABLE_NAME, INDEX_NAME 206 | """ 207 | 208 | cursor.execute(query, (self.config['database'],)) 209 | 210 | for row in cursor: 211 | table_name = row['TABLE_NAME'] 212 | if table_name in schema: 213 | index_key = f"{table_name}_indexes" 214 | if index_key not in schema: 215 | schema[index_key] = [] 216 | schema[index_key].append(f"{row['INDEX_NAME']} ({row['COLUMNS']})") 217 | 218 | def _begin_transaction(self): 219 | """Begin a transaction.""" 220 | self.connection.start_transaction() 221 | 222 | def _commit_transaction(self): 223 | """Commit a transaction.""" 224 | self.connection.commit() 225 | 226 | def _rollback_transaction(self): 227 | """Rollback a transaction.""" 228 | self.connection.rollback() 229 | 230 | def _get_driver_info(self) -> Dict[str, Any]: 231 | """Get MySQL-specific information.""" 232 | info = { 233 | 'server_version': None, 234 | 'connection_id': None, 235 | 'character_set': None 236 | } 237 | 238 | if self.connection and self.connection.is_connected(): 239 | try: 240 | cursor = self.connection.cursor() 241 | cursor.execute("SELECT VERSION(), CONNECTION_ID(), @@character_set_database") 242 | result = cursor.fetchone() 243 | info['server_version'] = result[0] 244 | info['connection_id'] = result[1] 245 | info['character_set'] = result[2] 246 | cursor.close() 247 | except Exception: 248 | pass 249 | 250 | return info 251 | 252 
| @property
253 |     def supports_transactions(self) -> bool:
254 |         """MySQL supports transactions."""
255 |         return True
256 | 
257 |     @property
258 |     def supports_schemas(self) -> bool:
259 |         """MySQL supports schemas (databases)."""
260 |         return True
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # CogniDB
2 | 
3 | A secure, production-ready natural language database interface that empowers users to query databases using plain English while maintaining enterprise-grade security and performance.
4 | 5 | --- 6 | 7 | 8 | 9 | ## Features 10 | 11 | ### Core Capabilities
12 | - 🗣️ Natural Language Querying: Ask questions in plain English
13 | - 🔍 Intelligent SQL Generation: Context-aware query generation with schema understanding
14 | - 🛡️ Enterprise Security: Multi-layer security validation and sanitization
15 | - 🚀 High Performance: Query caching, connection pooling, and optimization
16 | - 📊 Multi-Database Support: MySQL, PostgreSQL, MongoDB, DynamoDB, SQLite
17 | - 💰 Cost Control: LLM usage tracking with configurable limits
18 | - 📈 Query Optimization: AI-powered query performance suggestions 19 | 20 | ### Security Features
21 | - SQL Injection Prevention: Parameterized queries and comprehensive validation
22 | - Access Control: Table and column-level permissions
23 | - Rate Limiting: Configurable request limits
24 | - Audit Logging: Complete query audit trail
25 | - Encryption: At-rest and in-transit encryption support 26 | 27 | ### AI/LLM Features
28 | - Multi-Provider Support: OpenAI, Anthropic, Azure, HuggingFace, Local models
29 | - Cost Tracking: Real-time usage and cost monitoring
30 | - Smart Caching: Reduce costs with intelligent response caching
31 | - Few-Shot Learning: Improve accuracy with custom examples
32 | 
33 | ---
34 | 
35 | ## Quick Start
36 | 
37 | ### Installation
38 | 
39 | 1. Install the core package:
40 | pip install cognidb
41 | 
42 | 2. With all optional dependencies:
43 | pip install cognidb[all]
44 | 
45 | 3. With specific features:
46 | pip install cognidb[redis,azure] 47 | 48 | ### Basic Usage 49 | 50 | ```python 51 | from cognidb import create_cognidb 52 | 53 | # Initialize with configuration 54 | db = create_cognidb( 55 | database={ 56 | 'type': 'postgresql', 57 | 'host': 'localhost', 58 | 'database': 'mydb', 59 | 'username': 'user', 60 | 'password': 'pass' 61 | }, 62 | llm={ 63 | 'provider': 'openai', 64 | 'api_key': 'your-api-key' 65 | } 66 | ) 67 | 68 | # Query in natural language 69 | result = db.query("Show me the top 10 customers by total purchase amount") 70 | 71 | if result['success']: 72 | print(f"SQL: {result['sql']}") 73 | print(f"Results: {result['results']}") 74 | 75 | # Always close when done 76 | db.close() 77 | ``` 78 | 79 | ### Using Context Manager 80 | 81 | ```python 82 | from cognidb import CogniDB 83 | 84 | # Automatically handles connection cleanup 85 | with CogniDB(config_file='cognidb.yaml') as db: 86 | result = db.query( 87 | "What were the sales trends last quarter?", 88 | explain=True # Get explanation of the query 89 | ) 90 | 91 | if result['success']: 92 | print(f"Explanation: {result['explanation']}") 93 | for row in result['results']: 94 | print(row) 95 | ``` 96 | 97 | --- 98 | 99 | ## Configuration 100 | 101 | ### Environment Variables 102 | 103 | ```bash 104 | # Database settings 105 | export DB_TYPE=postgresql 106 | export DB_HOST=localhost 107 | export DB_PORT=5432 108 | export DB_NAME=mydb 109 | export DB_USER=dbuser 110 | export DB_PASSWORD=secure_password 111 | 112 | # LLM settings 113 | export LLM_PROVIDER=openai 114 | export LLM_API_KEY=your_api_key 115 | export LLM_MODEL=gpt-4 116 | 117 | # Optional: Use configuration file instead 118 | export COGNIDB_CONFIG=/path/to/cognidb.yaml 119 | ``` 120 | 121 | ### Configuration File (YAML) 122 | 123 | Create a cognidb.yaml file: 124 | 125 | ```yaml 126 | database: 127 | type: postgresql 128 | host: localhost 129 | port: 5432 130 | database: analytics_db 131 | username: ${DB_USER} # Use environment variable 132 | password: ${DB_PASSWORD} 133 | 134 | # Connection settings 135 | pool_size: 5 136 | query_timeout: 30 137 | ssl_enabled: true 138 | 139 | llm: 140 | provider: openai 141 | api_key: ${LLM_API_KEY} 142 | model_name: gpt-4 143 | temperature: 0.1 144 | max_cost_per_day: 100.0 145 | 146 | # Improve accuracy with examples 147 | few_shot_examples: 148 | - query: "Show total sales by month" 149 | sql: "SELECT DATE_TRUNC('month', order_date) as month, SUM(amount) as total FROM orders GROUP BY month ORDER BY month" 150 | 151 | security: 152 | allow_only_select: true 153 | enable_rate_limiting: true 154 | rate_limit_per_minute: 100 155 | enable_audit_logging: true 156 | ``` 157 | 158 | See cognidb.example.yaml for a complete configuration example. 
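
The environment-driven path can be sanity-checked directly. A minimal sketch, assuming `Settings` is importable from `cognidb.config`:

```python
from cognidb.config import Settings

settings = Settings.from_env()

errors = settings.validate()
if errors:
    raise SystemExit("Invalid configuration:\n" + "\n".join(errors))

# The password is masked in the generated connection string
print(settings.database.get_connection_string())
```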
159 | 160 | --- 161 | 162 | ## Advanced Features 163 | 164 | ### Query Optimization 165 | 166 | ```python 167 | # Get optimization suggestions 168 | sql = "SELECT * FROM orders WHERE customer_id IN (SELECT id FROM customers WHERE country = 'USA')" 169 | optimization = db.optimize_query(sql) 170 | 171 | print(f"Original: {optimization['original_query']}") 172 | print(f"Optimized: {optimization['optimized_query']}") 173 | print(f"Explanation: {optimization['explanation']}") 174 | ``` 175 | 176 | ### Query Suggestions 177 | 178 | ```python 179 | # Get AI-powered query suggestions 180 | suggestions = db.suggest_queries("customers who haven't") 181 | for suggestion in suggestions: 182 | print(f"- {suggestion}") 183 | # Output: 184 | # - customers who haven't made a purchase in the last 30 days 185 | # - customers who haven't updated their profile 186 | # - customers who haven't verified their email 187 | ``` 188 | 189 | ### Access Control 190 | 191 | ```python 192 | from cognidb.security import AccessController 193 | 194 | # Set up user permissions 195 | access = AccessController() 196 | access.create_restricted_user( 197 | user_id="analyst_1", 198 | table_permissions={ 199 | 'customers': { 200 | 'operations': ['SELECT'], 201 | 'columns': ['id', 'name', 'email', 'country'], 202 | 'row_filter': "country = 'USA'" # Row-level security 203 | } 204 | } 205 | ) 206 | 207 | # Query with user context 208 | result = db.query( 209 | "Show me all customer emails", 210 | user_id="analyst_1" # Will only see US customers 211 | ) 212 | ``` 213 | 214 | ### Cost Tracking 215 | 216 | ```python 217 | # Monitor LLM usage and costs 218 | stats = db.get_usage_stats() 219 | print(f"Total cost today: ${stats['daily_cost']:.2f}") 220 | print(f"Remaining budget: ${stats['remaining_budget']:.2f}") 221 | print(f"Queries today: {stats['request_count']}") 222 | 223 | # Export usage report 224 | report = db.export_usage_report( 225 | start_date='2024-01-01', 226 | end_date='2024-01-31', 227 | format='csv' 228 | ) 229 | ``` 230 | 231 | --- 232 | 233 | ## Architecture 234 | 235 | CogniDB uses a modular, secure architecture: 236 | 237 | ``` 238 | ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ 239 | │ User Input │────▶│ Security Layer │────▶│ LLM Manager │ 240 | └─────────────────┘ └─────────────────┘ └─────────────────┘ 241 | │ │ 242 | ▼ ▼ 243 | ┌─────────────────┐ ┌─────────────────┐ 244 | │ Query Validator│ │ Query Generator │ 245 | └─────────────────┘ └─────────────────┘ 246 | │ │ 247 | ▼ ▼ 248 | ┌─────────────────┐ ┌─────────────────┐ 249 | │ Database Driver │────▶│ Result Cache │ 250 | └─────────────────┘ └─────────────────┘ 251 | ``` 252 | 253 | ### Key Components
254 | 1. Security Layer: Multi-stage validation and sanitization
255 | 2. LLM Manager: Handles all AI interactions with fallback support
256 | 3. Query Generator: Converts natural language to SQL with schema awareness
257 | 4. Database Drivers: Secure, parameterized database connections
258 | 5. Cache Layer: Reduces costs and improves performance 259 | 260 | --- 261 | 262 | ## Security Best Practices
263 | 1. Never expose credentials: Use environment variables or secrets managers
264 | 2. Enable SSL/TLS: Always use encrypted connections
265 | 3. Restrict permissions: Use read-only database users when possible
266 | 4. Monitor usage: Enable audit logging and review regularly
267 | 5. Update regularly: Keep CogniDB and dependencies up to date 268 | 269 | --- 270 | 271 | ## Testing 272 | 273 | ```bash 274 | # Run all tests 275 | pytest 276 | 277 | # Run with coverage 278 | pytest --cov=cognidb 279 | 280 | # Run security tests only 281 | pytest tests/security/ 282 | 283 | # Run integration tests 284 | pytest tests/integration/ --db-host=localhost 285 | ``` 286 | 287 | --- 288 | 289 | ## Performance Tips
290 | 1. Use connection pooling: Enabled by default for better performance
291 | 2. Enable caching: Reduces LLM costs and improves response time
292 | 3. Optimize schemas: Add appropriate indexes based on query patterns
293 | 4. Use prepared statements: For frequently executed queries
294 | 5. Monitor query performance: Use the optimization feature regularly 295 | 296 | --- 297 | 298 | ## Contributing 299 | 300 | We welcome contributions! Please see CONTRIBUTING.md for guidelines. 301 | 302 | ### Development Setup 303 | 304 | ```bash 305 | # Clone the repository 306 | git clone https://github.com/boxed-dev/cognidb 307 | cd cognidb 308 | 309 | # Install in development mode 310 | pip install -e .[dev] 311 | 312 | # Run pre-commit hooks 313 | pre-commit install 314 | 315 | # Run tests 316 | pytest 317 | ``` 318 | 319 | --- 320 | 321 | ## License 322 | 323 | This project is licensed under the MIT License - see LICENSE for details. 324 | 325 | --- 326 | 327 | ## Acknowledgments
328 | - OpenAI, Anthropic, and the open-source LLM community
329 | - Contributors to SQLParse, psycopg2, and other dependencies
330 | - The CogniDB community for feedback and contributions 331 | 332 | --- 333 | 334 | Built with ❤️ for data democratization 335 | 336 |
337 | -------------------------------------------------------------------------------- /cognidb/ai/llm_manager.py: -------------------------------------------------------------------------------- 1 | """LLM manager with multi-provider support and cost tracking.""" 2 | 3 | import time 4 | from typing import Dict, Any, Optional, List, Union 5 | from abc import ABC, abstractmethod 6 | from dataclasses import dataclass 7 | from ..config.settings import LLMConfig, LLMProvider 8 | from ..core.exceptions import CogniDBError, RateLimitError 9 | from .cost_tracker import CostTracker 10 | from .providers import ( 11 | OpenAIProvider, 12 | AnthropicProvider, 13 | AzureOpenAIProvider, 14 | HuggingFaceProvider, 15 | LocalProvider 16 | ) 17 | 18 | 19 | @dataclass 20 | class LLMResponse: 21 | """LLM response container.""" 22 | content: str 23 | model: str 24 | provider: str 25 | usage: Dict[str, int] 26 | cost: float 27 | latency: float 28 | cached: bool = False 29 | 30 | 31 | class LLMManager: 32 | """ 33 | Manages LLM interactions with multiple providers. 34 | 35 | Features: 36 | - Multi-provider support with fallback 37 | - Cost tracking and limits 38 | - Rate limiting 39 | - Response caching 40 | - Token usage tracking 41 | """ 42 | 43 | def __init__(self, config: LLMConfig, cache_provider=None): 44 | """ 45 | Initialize LLM manager. 46 | 47 | Args: 48 | config: LLM configuration 49 | cache_provider: Optional cache provider for response caching 50 | """ 51 | self.config = config 52 | self.cache = cache_provider 53 | self.cost_tracker = CostTracker(max_daily_cost=config.max_cost_per_day) 54 | 55 | # Initialize primary provider 56 | self.primary_provider = self._create_provider(config.provider, config) 57 | 58 | # Initialize fallback providers 59 | self.fallback_providers = [] 60 | if config.provider != LLMProvider.OPENAI: 61 | self.fallback_providers.append( 62 | self._create_provider(LLMProvider.OPENAI, config) 63 | ) 64 | 65 | # Rate limiting 66 | self._request_times: List[float] = [] 67 | self._last_request_time = 0 68 | 69 | def generate(self, 70 | prompt: str, 71 | system_prompt: Optional[str] = None, 72 | max_tokens: Optional[int] = None, 73 | temperature: Optional[float] = None, 74 | use_cache: bool = True) -> LLMResponse: 75 | """ 76 | Generate response from LLM. 
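
        Cached responses are returned when an identical prompt/config hash
        is found; otherwise providers are tried in order (primary first,
        then fallbacks) until one succeeds.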
77 | 
78 |         Args:
79 |             prompt: User prompt
80 |             system_prompt: System prompt (overrides config)
81 |             max_tokens: Maximum tokens (overrides config)
82 |             temperature: Temperature (overrides config)
83 |             use_cache: Whether to use cached responses
84 | 
85 |         Returns:
86 |             LLM response
87 | 
88 |         Raises:
89 |             RateLimitError: If rate limit exceeded
90 |             CogniDBError: If generation fails
91 |         """
92 |         # Check rate limits
93 |         self._check_rate_limit()
94 | 
95 |         # Check cost limits
96 |         if self.cost_tracker.is_limit_exceeded():
97 |             raise CogniDBError("Daily cost limit exceeded")
98 | 
99 |         # Check cache
100 |         if use_cache and self.cache:
101 |             cache_key = self._generate_cache_key(prompt, system_prompt)
102 |             cached_response = self.cache.get(cache_key)
103 |             if cached_response:
104 |                 cached_response['cached'] = True
105 |                 return LLMResponse(**cached_response)
106 | 
107 |         # Prepare parameters
108 |         params = {
109 |             'prompt': prompt,
110 |             'system_prompt': system_prompt or self.config.system_prompt,
111 |             'max_tokens': max_tokens or self.config.max_tokens,
112 |             'temperature': temperature if temperature is not None else self.config.temperature
113 |         }
114 | 
115 |         # Try primary provider
116 |         start_time = time.time()
117 |         response = None
118 |         last_error = None
119 | 
120 |         for provider in [self.primary_provider] + self.fallback_providers:
121 |             try:
122 |                 response = provider.generate(**params)
123 |                 response['provider'] = provider.name
124 |                 response['latency'] = time.time() - start_time
125 |                 break
126 |             except Exception as e:
127 |                 last_error = e
128 |                 continue
129 | 
130 |         if response is None:
131 |             raise CogniDBError(f"All LLM providers failed: {last_error}")
132 | 
133 |         # Track cost
134 |         cost = self._calculate_cost(response['usage'], response['model'])
135 |         response['cost'] = cost
136 |         self.cost_tracker.track_usage(cost, response['usage'])
137 | 
138 |         # Cache response
139 |         if use_cache and self.cache:
140 |             self.cache.set(
141 |                 cache_key,
142 |                 response,
143 |                 ttl=getattr(self.config, 'llm_response_ttl', 7200)  # TTL belongs to CacheConfig; default 2h
144 |             )
145 | 
146 |         # Update rate limiting
147 |         self._request_times.append(time.time())
148 | 
149 |         return LLMResponse(**response)
150 | 
151 |     def generate_with_examples(self,
152 |                                prompt: str,
153 |                                examples: List[Dict[str, str]],
154 |                                **kwargs) -> LLMResponse:
155 |         """
156 |         Generate response with few-shot examples.
157 | 
158 |         Args:
159 |             prompt: User prompt
160 |             examples: List of input/output examples
161 |             **kwargs: Additional generation parameters
162 | 
163 |         Returns:
164 |             LLM response
165 |         """
166 |         # Build prompt with examples
167 |         formatted_prompt = self._format_with_examples(prompt, examples)
168 |         return self.generate(formatted_prompt, **kwargs)
169 | 
170 |     def stream_generate(self,
171 |                         prompt: str,
172 |                         callback,
173 |                         **kwargs):
174 |         """
175 |         Stream generation with callback.
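
        e.g. stream_generate(prompt, callback=lambda token: print(token, end=""))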
176 | 177 | Args: 178 | prompt: User prompt 179 | callback: Function called with each token 180 | **kwargs: Additional generation parameters 181 | """ 182 | if not self.config.enable_streaming: 183 | raise CogniDBError("Streaming not enabled in configuration") 184 | 185 | # Check if provider supports streaming 186 | if not hasattr(self.primary_provider, 'stream_generate'): 187 | raise CogniDBError("Provider does not support streaming") 188 | 189 | # Stream from provider 190 | self.primary_provider.stream_generate(prompt, callback, **kwargs) 191 | 192 | def get_usage_stats(self) -> Dict[str, Any]: 193 | """Get usage statistics.""" 194 | return { 195 | 'total_cost': self.cost_tracker.get_total_cost(), 196 | 'daily_cost': self.cost_tracker.get_daily_cost(), 197 | 'token_usage': self.cost_tracker.get_token_usage(), 198 | 'request_count': len(self._request_times), 199 | 'cache_stats': self.cache.get_stats() if self.cache else None 200 | } 201 | 202 | def _create_provider(self, provider_type: LLMProvider, config: LLMConfig): 203 | """Create LLM provider instance.""" 204 | if provider_type == LLMProvider.OPENAI: 205 | return OpenAIProvider(config) 206 | elif provider_type == LLMProvider.ANTHROPIC: 207 | return AnthropicProvider(config) 208 | elif provider_type == LLMProvider.AZURE_OPENAI: 209 | return AzureOpenAIProvider(config) 210 | elif provider_type == LLMProvider.HUGGINGFACE: 211 | return HuggingFaceProvider(config) 212 | elif provider_type == LLMProvider.LOCAL: 213 | return LocalProvider(config) 214 | else: 215 | raise ValueError(f"Unknown provider: {provider_type}") 216 | 217 | def _check_rate_limit(self) -> None: 218 | """Check and enforce rate limits.""" 219 | current_time = time.time() 220 | 221 | # Clean old request times 222 | minute_ago = current_time - 60 223 | self._request_times = [ 224 | t for t in self._request_times if t > minute_ago 225 | ] 226 | 227 | # Check rate limit 228 | if len(self._request_times) >= self.config.max_queries_per_minute: 229 | retry_after = 60 - (current_time - self._request_times[0]) 230 | raise RateLimitError( 231 | f"Rate limit exceeded ({self.config.max_queries_per_minute}/min)", 232 | retry_after=int(retry_after) 233 | ) 234 | 235 | def _generate_cache_key(self, prompt: str, system_prompt: Optional[str]) -> str: 236 | """Generate cache key for prompt.""" 237 | import hashlib 238 | 239 | key_parts = [ 240 | self.config.provider.value, 241 | self.config.model_name, 242 | str(self.config.temperature), 243 | str(self.config.max_tokens), 244 | system_prompt or self.config.system_prompt or "", 245 | prompt 246 | ] 247 | 248 | key_string = "|".join(key_parts) 249 | return f"llm:{hashlib.sha256(key_string.encode()).hexdigest()}" 250 | 251 | def _calculate_cost(self, usage: Dict[str, int], model: str) -> float: 252 | """Calculate cost based on token usage.""" 253 | # Model pricing (per 1K tokens) 254 | pricing = { 255 | # OpenAI 256 | 'gpt-4': {'input': 0.03, 'output': 0.06}, 257 | 'gpt-4-turbo': {'input': 0.01, 'output': 0.03}, 258 | 'gpt-3.5-turbo': {'input': 0.0005, 'output': 0.0015}, 259 | # Anthropic 260 | 'claude-3-opus': {'input': 0.015, 'output': 0.075}, 261 | 'claude-3-sonnet': {'input': 0.003, 'output': 0.015}, 262 | 'claude-3-haiku': {'input': 0.00025, 'output': 0.00125}, 263 | # Default 264 | 'default': {'input': 0.001, 'output': 0.002} 265 | } 266 | 267 | model_pricing = pricing.get(model, pricing['default']) 268 | 269 | input_cost = (usage.get('prompt_tokens', 0) / 1000) * model_pricing['input'] 270 | output_cost = 
(usage.get('completion_tokens', 0) / 1000) * model_pricing['output'] 271 | 272 | return input_cost + output_cost 273 | 274 | def _format_with_examples(self, prompt: str, examples: List[Dict[str, str]]) -> str: 275 | """Format prompt with few-shot examples.""" 276 | formatted_examples = [] 277 | 278 | for example in examples: 279 | formatted_examples.append( 280 | f"Input: {example['input']}\nOutput: {example['output']}" 281 | ) 282 | 283 | examples_text = "\n\n".join(formatted_examples) 284 | 285 | return f"""Here are some examples: 286 | 287 | {examples_text} 288 | 289 | Now, for the following input: 290 | Input: {prompt} 291 | Output:""" -------------------------------------------------------------------------------- /cognidb/security/validator.py: -------------------------------------------------------------------------------- 1 | """Security validator implementation.""" 2 | 3 | import re 4 | from typing import Tuple, Optional, List, Set 5 | from ..core.interfaces import SecurityValidator 6 | from ..core.query_intent import QueryIntent, QueryType 7 | from ..core.exceptions import SecurityError 8 | from .query_parser import SQLQueryParser 9 | 10 | 11 | class QuerySecurityValidator(SecurityValidator): 12 | """ 13 | Comprehensive security validator for queries. 14 | 15 | Implements multiple layers of security: 16 | 1. Query intent validation 17 | 2. Native query validation 18 | 3. Identifier sanitization 19 | 4. Value sanitization 20 | """ 21 | 22 | # Dangerous SQL keywords that should never appear 23 | FORBIDDEN_KEYWORDS = { 24 | 'DROP', 'DELETE', 'TRUNCATE', 'UPDATE', 'INSERT', 'ALTER', 25 | 'CREATE', 'REPLACE', 'RENAME', 'GRANT', 'REVOKE', 'EXECUTE', 26 | 'EXEC', 'CALL', 'MERGE', 'LOCK', 'UNLOCK' 27 | } 28 | 29 | # Patterns that might indicate SQL injection 30 | SQL_INJECTION_PATTERNS = [ 31 | r';\s*--', # Statement termination followed by comment 32 | r';\s*\/\*', # Statement termination followed by comment 33 | r'UNION\s+SELECT', # UNION-based injection 34 | r'OR\s+1\s*=\s*1', # Classic SQL injection 35 | r'OR\s+\'1\'\s*=\s*\'1\'', # Classic SQL injection with quotes 36 | r'WAITFOR\s+DELAY', # Time-based injection 37 | r'BENCHMARK\s*\(', # MySQL time-based injection 38 | r'PG_SLEEP\s*\(', # PostgreSQL time-based injection 39 | r'LOAD_FILE\s*\(', # File system access 40 | r'INTO\s+OUTFILE', # File system write 41 | r'xp_cmdshell', # SQL Server command execution 42 | ] 43 | 44 | # Valid identifier pattern (alphanumeric + underscore) 45 | VALID_IDENTIFIER_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') 46 | 47 | # Maximum identifier length 48 | MAX_IDENTIFIER_LENGTH = 64 49 | 50 | def __init__(self, 51 | allowed_operations: Optional[List[str]] = None, 52 | max_query_complexity: int = 10, 53 | allow_subqueries: bool = False): 54 | """ 55 | Initialize security validator. 56 | 57 | Args: 58 | allowed_operations: List of allowed query types (default: SELECT only) 59 | max_query_complexity: Maximum allowed query complexity score 60 | allow_subqueries: Whether to allow subqueries 61 | """ 62 | self._allowed_operations = allowed_operations or ['SELECT'] 63 | self.max_query_complexity = max_query_complexity 64 | self.allow_subqueries = allow_subqueries 65 | self.parser = SQLQueryParser() 66 | 67 | @property 68 | def allowed_operations(self) -> List[str]: 69 | """List of allowed query operations.""" 70 | return self._allowed_operations 71 | 72 | def validate_query_intent(self, query_intent: QueryIntent) -> Tuple[bool, Optional[str]]: 73 | """ 74 | Validate query intent for security issues. 
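
        Returns a (passed, error_message) tuple rather than raising, so
        callers can surface the first failure message directly.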
75 | 76 | Checks: 77 | 1. Query type is allowed 78 | 2. Table/column names are valid 79 | 3. Query complexity is within limits 80 | 4. No forbidden patterns in conditions 81 | """ 82 | # Check query type 83 | if query_intent.query_type.name not in self.allowed_operations: 84 | return False, f"Query type {query_intent.query_type.name} is not allowed" 85 | 86 | # Validate table names 87 | for table in query_intent.tables: 88 | if not self._is_valid_identifier(table): 89 | return False, f"Invalid table name: {table}" 90 | 91 | # Validate column names 92 | for column in query_intent.columns: 93 | if column.name != "*" and not self._is_valid_identifier(column.name): 94 | return False, f"Invalid column name: {column.name}" 95 | if column.table and not self._is_valid_identifier(column.table): 96 | return False, f"Invalid table reference in column: {column.table}" 97 | 98 | # Check query complexity 99 | complexity = self._calculate_complexity(query_intent) 100 | if complexity > self.max_query_complexity: 101 | return False, f"Query too complex (score: {complexity}, max: {self.max_query_complexity})" 102 | 103 | # Validate conditions 104 | if query_intent.conditions: 105 | valid, error = self._validate_conditions(query_intent.conditions) 106 | if not valid: 107 | return False, error 108 | 109 | # Validate joins 110 | for join in query_intent.joins: 111 | if not self._is_valid_identifier(join.left_table): 112 | return False, f"Invalid table in join: {join.left_table}" 113 | if not self._is_valid_identifier(join.right_table): 114 | return False, f"Invalid table in join: {join.right_table}" 115 | if not self._is_valid_identifier(join.left_column): 116 | return False, f"Invalid column in join: {join.left_column}" 117 | if not self._is_valid_identifier(join.right_column): 118 | return False, f"Invalid column in join: {join.right_column}" 119 | 120 | return True, None 121 | 122 | def validate_native_query(self, query: str) -> Tuple[bool, Optional[str]]: 123 | """ 124 | Validate native SQL query for security issues. 125 | 126 | Performs comprehensive security checks including: 127 | 1. Forbidden keyword detection 128 | 2. SQL injection pattern matching 129 | 3. Query parsing and analysis 130 | """ 131 | # Normalize query for analysis 132 | normalized_query = query.upper().strip() 133 | 134 | # Check for forbidden keywords 135 | for keyword in self.FORBIDDEN_KEYWORDS: 136 | if re.search(rf'\b{keyword}\b', normalized_query): 137 | return False, f"Forbidden keyword detected: {keyword}" 138 | 139 | # Check for SQL injection patterns 140 | for pattern in self.SQL_INJECTION_PATTERNS: 141 | if re.search(pattern, normalized_query, re.IGNORECASE): 142 | return False, f"Potential SQL injection pattern detected" 143 | 144 | # Parse and validate query structure 145 | try: 146 | parsed = self.parser.parse(query) 147 | if parsed['type'] not in self.allowed_operations: 148 | return False, f"Query type {parsed['type']} is not allowed" 149 | 150 | # Additional checks based on parsed structure 151 | if not self.allow_subqueries and parsed.get('has_subquery'): 152 | return False, "Subqueries are not allowed" 153 | 154 | except Exception as e: 155 | return False, f"Query parsing failed: {str(e)}" 156 | 157 | return True, None 158 | 159 | def sanitize_identifier(self, identifier: str) -> str: 160 | """ 161 | Sanitize a table/column identifier. 
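        Example (behaviour follows directly from the checks below):

            validator.sanitize_identifier('"user_accounts"')    # -> 'user_accounts'
            validator.sanitize_identifier('users; DROP TABLE')  # raises SecurityError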
162 | 
163 |         Args:
164 |             identifier: The identifier to sanitize
165 | 
166 |         Returns:
167 |             Sanitized identifier
168 | 
169 |         Raises:
170 |             SecurityError: If identifier cannot be sanitized safely
171 |         """
172 |         # Remove any quotes
173 |         identifier = identifier.strip().strip('"\'`[]')
174 | 
175 |         # Validate
176 |         if not self._is_valid_identifier(identifier):
177 |             raise SecurityError(f"Invalid identifier: {identifier}")
178 | 
179 |         return identifier
180 | 
181 |     def sanitize_value(self, value: Any) -> Any:
182 |         """
183 |         Sanitize a parameter value.
184 | 
185 |         Args:
186 |             value: The value to sanitize
187 | 
188 |         Returns:
189 |             Sanitized value
190 |         """
191 |         if value is None:
192 |             return None
193 | 
194 |         if isinstance(value, str):
195 |             # Remove any SQL comment indicators
196 |             value = re.sub(r'--.*$', '', value, flags=re.MULTILINE)
197 |             value = re.sub(r'/\*.*?\*/', '', value, flags=re.DOTALL)
198 | 
199 |             # Escape special characters
200 |             # Note: Actual escaping should be done by the database driver
201 |             # This is just an additional safety layer
202 |             value = value.replace('\x00', '')  # Remove null bytes
203 | 
204 |         elif isinstance(value, (list, tuple)):
205 |             # Recursively sanitize collections
206 |             return type(value)(self.sanitize_value(v) for v in value)
207 | 
208 |         elif isinstance(value, dict):
209 |             # Recursively sanitize dictionaries
210 |             return {k: self.sanitize_value(v) for k, v in value.items()}
211 | 
212 |         return value
213 | 
214 |     def _is_valid_identifier(self, identifier: str) -> bool:
215 |         """Check if an identifier is valid."""
216 |         if not identifier or len(identifier) > self.MAX_IDENTIFIER_LENGTH:
217 |             return False
218 |         return bool(self.VALID_IDENTIFIER_PATTERN.match(identifier))
219 | 
220 |     def _calculate_complexity(self, query_intent: QueryIntent) -> int:
221 |         """
222 |         Calculate query complexity score.
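        Worked example: a query over two tables with one join, two simple
        conditions and a GROUP BY scores 2 + (1 * 2) + 2 + 1 = 7 against the
        factor weights below.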
223 | 224 | Factors: 225 | - Number of tables 226 | - Number of joins 227 | - Number of conditions 228 | - Aggregations 229 | - Subqueries (if parsed) 230 | """ 231 | score = 0 232 | 233 | # Base score for tables 234 | score += len(query_intent.tables) 235 | 236 | # Joins add complexity 237 | score += len(query_intent.joins) * 2 238 | 239 | # Conditions add complexity 240 | if query_intent.conditions: 241 | score += self._count_conditions(query_intent.conditions) 242 | 243 | # Aggregations add complexity 244 | score += len(query_intent.aggregations) 245 | 246 | # Group by adds complexity 247 | if query_intent.group_by: 248 | score += 1 249 | 250 | # Having clause adds complexity 251 | if query_intent.having: 252 | score += 2 253 | 254 | return score 255 | 256 | def _count_conditions(self, condition_group) -> int: 257 | """Recursively count conditions in a group.""" 258 | count = 0 259 | for condition in condition_group.conditions: 260 | if hasattr(condition, 'conditions'): # It's a group 261 | count += self._count_conditions(condition) 262 | else: 263 | count += 1 264 | return count 265 | 266 | def _validate_conditions(self, condition_group) -> Tuple[bool, Optional[str]]: 267 | """Validate conditions in a condition group.""" 268 | for condition in condition_group.conditions: 269 | if hasattr(condition, 'conditions'): # It's a group 270 | valid, error = self._validate_conditions(condition) 271 | if not valid: 272 | return False, error 273 | else: 274 | # Validate column name 275 | if not self._is_valid_identifier(condition.column.name): 276 | return False, f"Invalid column in condition: {condition.column.name}" 277 | 278 | # Validate value isn't attempting injection 279 | if isinstance(condition.value, str): 280 | if any(keyword in condition.value.upper() for keyword in self.FORBIDDEN_KEYWORDS): 281 | return False, f"Forbidden keyword in condition value" 282 | 283 | return True, None -------------------------------------------------------------------------------- /cognidb/ai/prompt_builder.py: -------------------------------------------------------------------------------- 1 | """Advanced prompt builder for SQL generation.""" 2 | 3 | from typing import Dict, List, Any, Optional 4 | from ..core.query_intent import QueryIntent 5 | 6 | 7 | class PromptBuilder: 8 | """ 9 | Builds optimized prompts for SQL generation. 10 | 11 | Features: 12 | - Schema-aware prompts 13 | - Few-shot examples 14 | - Database-specific hints 15 | - Query optimization suggestions 16 | """ 17 | 18 | # Database-specific SQL dialects 19 | SQL_DIALECTS = { 20 | 'mysql': { 21 | 'limit': 'LIMIT {limit}', 22 | 'string_concat': "CONCAT({args})", 23 | 'current_date': 'CURDATE()', 24 | 'date_format': "DATE_FORMAT({date}, '{format}')" 25 | }, 26 | 'postgresql': { 27 | 'limit': 'LIMIT {limit}', 28 | 'string_concat': "{args}", # Uses || 29 | 'current_date': 'CURRENT_DATE', 30 | 'date_format': "TO_CHAR({date}, '{format}')" 31 | }, 32 | 'sqlite': { 33 | 'limit': 'LIMIT {limit}', 34 | 'string_concat': "{args}", # Uses || 35 | 'current_date': "DATE('now')", 36 | 'date_format': "STRFTIME('{format}', {date})" 37 | } 38 | } 39 | 40 | def __init__(self, database_type: str = 'postgresql'): 41 | """ 42 | Initialize prompt builder. 
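        Example:

            builder = PromptBuilder('mysql')    # uses the MySQL dialect table
            builder = PromptBuilder('oracle')   # unknown types fall back to postgresql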
43 | 44 | Args: 45 | database_type: Type of database (mysql, postgresql, sqlite) 46 | """ 47 | self.database_type = database_type 48 | self.dialect = self.SQL_DIALECTS.get(database_type, self.SQL_DIALECTS['postgresql']) 49 | 50 | def build_sql_generation_prompt(self, 51 | natural_language_query: str, 52 | schema: Dict[str, Dict[str, str]], 53 | examples: Optional[List[Dict[str, str]]] = None, 54 | context: Optional[Dict[str, Any]] = None) -> str: 55 | """ 56 | Build prompt for SQL generation. 57 | 58 | Args: 59 | natural_language_query: User's natural language query 60 | schema: Database schema 61 | examples: Optional few-shot examples 62 | context: Optional context (user preferences, constraints) 63 | 64 | Returns: 65 | Optimized prompt for LLM 66 | """ 67 | # Build schema description 68 | schema_desc = self._build_schema_description(schema) 69 | 70 | # Build examples section 71 | examples_section = "" 72 | if examples: 73 | examples_section = self._build_examples_section(examples) 74 | 75 | # Build context hints 76 | context_hints = self._build_context_hints(context or {}) 77 | 78 | # Construct the prompt 79 | prompt = f"""You are an expert SQL query generator for {self.database_type} databases. 80 | 81 | Database Schema: 82 | {schema_desc} 83 | 84 | {context_hints} 85 | 86 | Important Instructions: 87 | 1. Generate ONLY the SQL query, no explanations or markdown 88 | 2. Use proper {self.database_type} syntax and functions 89 | 3. Always use table aliases for clarity 90 | 4. Include appropriate JOINs when querying multiple tables 91 | 5. Use parameterized placeholders (?) for any user-provided values 92 | 6. Ensure the query is optimized for performance 93 | 7. Handle NULL values appropriately 94 | 8. Use appropriate data type conversions when needed 95 | 96 | {examples_section} 97 | 98 | Now generate a SQL query for the following request: 99 | User Query: {natural_language_query} 100 | 101 | SQL Query:""" 102 | 103 | return prompt 104 | 105 | def build_query_explanation_prompt(self, 106 | sql_query: str, 107 | schema: Dict[str, Dict[str, str]]) -> str: 108 | """ 109 | Build prompt for explaining SQL query. 110 | 111 | Args: 112 | sql_query: SQL query to explain 113 | schema: Database schema 114 | 115 | Returns: 116 | Prompt for explanation 117 | """ 118 | schema_desc = self._build_schema_description(schema) 119 | 120 | return f"""Explain the following SQL query in simple terms: 121 | 122 | Database Schema: 123 | {schema_desc} 124 | 125 | SQL Query: 126 | {sql_query} 127 | 128 | Provide a clear, concise explanation of: 129 | 1. What the query does 130 | 2. Which tables and columns it uses 131 | 3. Any filters or conditions applied 132 | 4. The expected result format 133 | 134 | Explanation:""" 135 | 136 | def build_optimization_prompt(self, 137 | sql_query: str, 138 | schema: Dict[str, Dict[str, str]], 139 | performance_stats: Optional[Dict[str, Any]] = None) -> str: 140 | """ 141 | Build prompt for query optimization suggestions. 
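        Illustrative call (a sketch; schema is any mapping of the shape
        documented below, and the stat keys mirror _format_performance_stats):

            prompt = builder.build_optimization_prompt(
                "SELECT * FROM orders WHERE status = 'open'",
                schema,
                performance_stats={'execution_time': 1200, 'rows_examined': 500000},
            )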
142 | 143 | Args: 144 | sql_query: SQL query to optimize 145 | schema: Database schema with index information 146 | performance_stats: Optional query performance statistics 147 | 148 | Returns: 149 | Prompt for optimization 150 | """ 151 | schema_desc = self._build_schema_description(schema, include_indexes=True) 152 | 153 | perf_section = "" 154 | if performance_stats: 155 | perf_section = f"\nPerformance Stats:\n{self._format_performance_stats(performance_stats)}" 156 | 157 | return f"""Analyze and optimize the following SQL query: 158 | 159 | Database Schema: 160 | {schema_desc} 161 | 162 | Current Query: 163 | {sql_query} 164 | {perf_section} 165 | 166 | Provide optimization suggestions considering: 167 | 1. Index usage 168 | 2. JOIN optimization 169 | 3. Subquery elimination 170 | 4. Filtering efficiency 171 | 5. Result set size reduction 172 | 173 | Optimized Query and Explanation:""" 174 | 175 | def build_intent_to_sql_prompt(self, 176 | query_intent: QueryIntent, 177 | database_type: str) -> str: 178 | """ 179 | Build prompt to convert QueryIntent to SQL. 180 | 181 | Args: 182 | query_intent: Parsed query intent 183 | database_type: Target database type 184 | 185 | Returns: 186 | Prompt for SQL generation 187 | """ 188 | intent_desc = self._describe_query_intent(query_intent) 189 | 190 | return f"""Convert the following query specification to {database_type} SQL: 191 | 192 | Query Specification: 193 | {intent_desc} 194 | 195 | Generate a properly formatted {database_type} SQL query that implements this specification. 196 | Use appropriate syntax and functions for {database_type}. 197 | 198 | SQL Query:""" 199 | 200 | def _build_schema_description(self, 201 | schema: Dict[str, Dict[str, str]], 202 | include_indexes: bool = False) -> str: 203 | """Build formatted schema description.""" 204 | lines = [] 205 | 206 | for table_name, columns in schema.items(): 207 | lines.append(f"Table: {table_name}") 208 | 209 | for col_name, col_type in columns.items(): 210 | lines.append(f" - {col_name}: {col_type}") 211 | 212 | if include_indexes and f"{table_name}_indexes" in schema: 213 | lines.append(" Indexes:") 214 | for index in schema[f"{table_name}_indexes"]: 215 | lines.append(f" - {index}") 216 | 217 | lines.append("") 218 | 219 | return "\n".join(lines) 220 | 221 | def _build_examples_section(self, examples: List[Dict[str, str]]) -> str: 222 | """Build few-shot examples section.""" 223 | if not examples: 224 | return "" 225 | 226 | lines = ["Examples of similar queries:\n"] 227 | 228 | for i, example in enumerate(examples, 1): 229 | lines.append(f"Example {i}:") 230 | lines.append(f"User Query: {example['query']}") 231 | lines.append(f"SQL: {example['sql']}") 232 | lines.append("") 233 | 234 | return "\n".join(lines) 235 | 236 | def _build_context_hints(self, context: Dict[str, Any]) -> str: 237 | """Build context-specific hints.""" 238 | hints = [] 239 | 240 | if context.get('timezone'): 241 | hints.append(f"Timezone: {context['timezone']} (adjust date/time queries accordingly)") 242 | 243 | if context.get('date_format'): 244 | hints.append(f"Preferred date format: {context['date_format']}") 245 | 246 | if context.get('limit_default'): 247 | hints.append(f"Default result limit: {context['limit_default']}") 248 | 249 | if context.get('case_sensitive'): 250 | hints.append("Use case-sensitive string comparisons") 251 | 252 | if context.get('exclude_deleted'): 253 | hints.append("Exclude soft-deleted records (check for deleted_at IS NULL)") 254 | 255 | if hints: 256 | return "Context:\n" + 
"\n".join(f"- {hint}" for hint in hints) + "\n" 257 | 258 | return "" 259 | 260 | def _describe_query_intent(self, query_intent: QueryIntent) -> str: 261 | """Convert QueryIntent to human-readable description.""" 262 | lines = [] 263 | 264 | lines.append(f"Query Type: {query_intent.query_type.name}") 265 | lines.append(f"Tables: {', '.join(query_intent.tables)}") 266 | 267 | if query_intent.columns: 268 | cols = [str(col) for col in query_intent.columns] 269 | lines.append(f"Columns: {', '.join(cols)}") 270 | 271 | if query_intent.joins: 272 | lines.append("Joins:") 273 | for join in query_intent.joins: 274 | lines.append( 275 | f" - {join.join_type.value} JOIN {join.right_table} " 276 | f"ON {join.left_table}.{join.left_column} = " 277 | f"{join.right_table}.{join.right_column}" 278 | ) 279 | 280 | if query_intent.conditions: 281 | lines.append("Conditions: [Complex condition group]") 282 | 283 | if query_intent.group_by: 284 | cols = [str(col) for col in query_intent.group_by] 285 | lines.append(f"Group By: {', '.join(cols)}") 286 | 287 | if query_intent.order_by: 288 | order_specs = [] 289 | for order in query_intent.order_by: 290 | direction = "ASC" if order.ascending else "DESC" 291 | order_specs.append(f"{order.column} {direction}") 292 | lines.append(f"Order By: {', '.join(order_specs)}") 293 | 294 | if query_intent.limit: 295 | lines.append(f"Limit: {query_intent.limit}") 296 | if query_intent.offset: 297 | lines.append(f"Offset: {query_intent.offset}") 298 | 299 | return "\n".join(lines) 300 | 301 | def _format_performance_stats(self, stats: Dict[str, Any]) -> str: 302 | """Format performance statistics.""" 303 | lines = [] 304 | 305 | if 'execution_time' in stats: 306 | lines.append(f"- Execution Time: {stats['execution_time']}ms") 307 | 308 | if 'rows_examined' in stats: 309 | lines.append(f"- Rows Examined: {stats['rows_examined']}") 310 | 311 | if 'rows_returned' in stats: 312 | lines.append(f"- Rows Returned: {stats['rows_returned']}") 313 | 314 | if 'index_used' in stats: 315 | lines.append(f"- Index Used: {stats['index_used']}") 316 | 317 | return "\n".join(lines) -------------------------------------------------------------------------------- /cognidb/ai/query_generator.py: -------------------------------------------------------------------------------- 1 | """Query generator using LLM with advanced features.""" 2 | 3 | import re 4 | from typing import Dict, Any, List, Optional, Tuple 5 | from ..core.query_intent import QueryIntent, QueryType, Column, Condition, ComparisonOperator 6 | from ..core.exceptions import TranslationError 7 | from ..security.sanitizer import InputSanitizer 8 | from .llm_manager import LLMManager 9 | from .prompt_builder import PromptBuilder 10 | 11 | 12 | class QueryGenerator: 13 | """ 14 | Generates SQL queries from natural language using LLM. 15 | 16 | Features: 17 | - Natural language to SQL conversion 18 | - Query intent parsing 19 | - Schema-aware generation 20 | - Query validation and correction 21 | - Caching for repeated queries 22 | """ 23 | 24 | def __init__(self, 25 | llm_manager: LLMManager, 26 | database_type: str = 'postgresql'): 27 | """ 28 | Initialize query generator. 
29 | 30 | Args: 31 | llm_manager: LLM manager instance 32 | database_type: Type of database 33 | """ 34 | self.llm_manager = llm_manager 35 | self.database_type = database_type 36 | self.prompt_builder = PromptBuilder(database_type) 37 | self.sanitizer = InputSanitizer() 38 | 39 | def generate_sql(self, 40 | natural_language_query: str, 41 | schema: Dict[str, Dict[str, str]], 42 | examples: Optional[List[Dict[str, str]]] = None, 43 | context: Optional[Dict[str, Any]] = None) -> str: 44 | """ 45 | Generate SQL from natural language query. 46 | 47 | Args: 48 | natural_language_query: User's query in natural language 49 | schema: Database schema 50 | examples: Optional few-shot examples 51 | context: Optional context information 52 | 53 | Returns: 54 | Generated SQL query 55 | 56 | Raises: 57 | TranslationError: If SQL generation fails 58 | """ 59 | # Sanitize input 60 | sanitized_query = self.sanitizer.sanitize_natural_language(natural_language_query) 61 | 62 | # Build prompt 63 | prompt = self.prompt_builder.build_sql_generation_prompt( 64 | sanitized_query, 65 | schema, 66 | examples, 67 | context 68 | ) 69 | 70 | # Generate SQL 71 | try: 72 | response = self.llm_manager.generate(prompt) 73 | sql_query = self._extract_sql(response.content) 74 | 75 | # Validate basic SQL structure 76 | if not self._is_valid_sql(sql_query): 77 | raise TranslationError("Generated SQL is invalid") 78 | 79 | return sql_query 80 | 81 | except Exception as e: 82 | raise TranslationError(f"Failed to generate SQL: {str(e)}") 83 | 84 | def parse_to_intent(self, 85 | natural_language_query: str, 86 | schema: Dict[str, Dict[str, str]]) -> QueryIntent: 87 | """ 88 | Parse natural language to QueryIntent. 89 | 90 | Args: 91 | natural_language_query: User's query 92 | schema: Database schema 93 | 94 | Returns: 95 | Parsed QueryIntent 96 | """ 97 | # First generate SQL 98 | sql_query = self.generate_sql(natural_language_query, schema) 99 | 100 | # Then parse SQL to QueryIntent 101 | return self._parse_sql_to_intent(sql_query, schema) 102 | 103 | def explain_query(self, 104 | sql_query: str, 105 | schema: Dict[str, Dict[str, str]]) -> str: 106 | """ 107 | Generate natural language explanation of SQL query. 108 | 109 | Args: 110 | sql_query: SQL query to explain 111 | schema: Database schema 112 | 113 | Returns: 114 | Natural language explanation 115 | """ 116 | prompt = self.prompt_builder.build_query_explanation_prompt(sql_query, schema) 117 | 118 | try: 119 | response = self.llm_manager.generate(prompt, temperature=0.3) 120 | return response.content 121 | except Exception as e: 122 | return f"Could not generate explanation: {str(e)}" 123 | 124 | def optimize_query(self, 125 | sql_query: str, 126 | schema: Dict[str, Dict[str, str]], 127 | performance_stats: Optional[Dict[str, Any]] = None) -> Tuple[str, str]: 128 | """ 129 | Generate optimized version of SQL query. 
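        Example (slow_sql and schema are placeholders):

            optimized_sql, explanation = generator.optimize_query(slow_sql, schema)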
130 | 131 | Args: 132 | sql_query: SQL query to optimize 133 | schema: Database schema with indexes 134 | performance_stats: Optional performance statistics 135 | 136 | Returns: 137 | Tuple of (optimized_query, explanation) 138 | """ 139 | prompt = self.prompt_builder.build_optimization_prompt( 140 | sql_query, 141 | schema, 142 | performance_stats 143 | ) 144 | 145 | try: 146 | response = self.llm_manager.generate(prompt, temperature=0.2) 147 | 148 | # Extract optimized query and explanation 149 | content = response.content 150 | if "```sql" in content: 151 | # Extract SQL from markdown 152 | sql_match = re.search(r'```sql\n(.*?)\n```', content, re.DOTALL) 153 | if sql_match: 154 | optimized_query = sql_match.group(1).strip() 155 | explanation = content.replace(sql_match.group(0), '').strip() 156 | return optimized_query, explanation 157 | 158 | # Try to split by common patterns 159 | lines = content.split('\n') 160 | sql_lines = [] 161 | explanation_lines = [] 162 | in_sql = True 163 | 164 | for line in lines: 165 | if line.strip() and not line.startswith(('--', '#', '//')): 166 | if in_sql and any(keyword in line.upper() for keyword in 167 | ['EXPLANATION:', 'CHANGES:', 'OPTIMIZATION:']): 168 | in_sql = False 169 | 170 | if in_sql: 171 | sql_lines.append(line) 172 | else: 173 | explanation_lines.append(line) 174 | 175 | optimized_query = '\n'.join(sql_lines).strip() 176 | explanation = '\n'.join(explanation_lines).strip() 177 | 178 | return optimized_query, explanation 179 | 180 | except Exception as e: 181 | return sql_query, f"Could not optimize: {str(e)}" 182 | 183 | def suggest_queries(self, 184 | partial_query: str, 185 | schema: Dict[str, Dict[str, str]], 186 | num_suggestions: int = 3) -> List[str]: 187 | """ 188 | Generate query suggestions based on partial input. 189 | 190 | Args: 191 | partial_query: Partial natural language query 192 | schema: Database schema 193 | num_suggestions: Number of suggestions to generate 194 | 195 | Returns: 196 | List of suggested queries 197 | """ 198 | prompt = f"""Based on the database schema and partial query, suggest {num_suggestions} complete queries. 199 | 200 | Database Schema: 201 | {self.prompt_builder._build_schema_description(schema)} 202 | 203 | Partial Query: {partial_query} 204 | 205 | Generate {num_suggestions} relevant query suggestions that complete or expand on the partial query. 206 | Format each suggestion on a new line starting with "- ". 
207 | 208 | Suggestions:""" 209 | 210 | try: 211 | response = self.llm_manager.generate(prompt, temperature=0.7) 212 | 213 | # Extract suggestions 214 | suggestions = [] 215 | for line in response.content.split('\n'): 216 | if line.strip().startswith('- '): 217 | suggestion = line.strip()[2:].strip() 218 | if suggestion: 219 | suggestions.append(suggestion) 220 | 221 | return suggestions[:num_suggestions] 222 | 223 | except Exception: 224 | return [] 225 | 226 | def _extract_sql(self, llm_response: str) -> str: 227 | """Extract SQL query from LLM response.""" 228 | # Remove markdown code blocks if present 229 | if "```sql" in llm_response: 230 | match = re.search(r'```sql\n(.*?)\n```', llm_response, re.DOTALL) 231 | if match: 232 | llm_response = match.group(1) 233 | elif "```" in llm_response: 234 | match = re.search(r'```\n(.*?)\n```', llm_response, re.DOTALL) 235 | if match: 236 | llm_response = match.group(1) 237 | 238 | # Clean up the response 239 | sql_query = llm_response.strip() 240 | 241 | # Remove any leading/trailing quotes 242 | if sql_query.startswith('"') and sql_query.endswith('"'): 243 | sql_query = sql_query[1:-1] 244 | elif sql_query.startswith("'") and sql_query.endswith("'"): 245 | sql_query = sql_query[1:-1] 246 | 247 | # Ensure it ends with semicolon 248 | if not sql_query.endswith(';'): 249 | sql_query += ';' 250 | 251 | return sql_query 252 | 253 | def _is_valid_sql(self, sql_query: str) -> bool: 254 | """Basic SQL validation.""" 255 | if not sql_query or not sql_query.strip(): 256 | return False 257 | 258 | # Check for basic SQL structure 259 | sql_upper = sql_query.upper() 260 | valid_starts = ['SELECT', 'WITH', 'SHOW', 'DESCRIBE', 'EXPLAIN'] 261 | 262 | return any(sql_upper.strip().startswith(start) for start in valid_starts) 263 | 264 | def _parse_sql_to_intent(self, sql_query: str, schema: Dict[str, Dict[str, str]]) -> QueryIntent: 265 | """ 266 | Parse SQL query to QueryIntent (simplified version). 267 | 268 | This is a basic implementation. In production, you'd want 269 | a full SQL parser. 
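        Illustrative behaviour of the regex approach:

            intent = self._parse_sql_to_intent("SELECT id, name AS n FROM users;", {})
            # intent.tables  -> ['users']
            # intent.columns -> [Column('id'), Column('name', alias='n')]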
270 | """ 271 | # Extract tables (basic regex approach) 272 | tables = [] 273 | from_match = re.search(r'FROM\s+(\w+)', sql_query, re.IGNORECASE) 274 | if from_match: 275 | tables.append(from_match.group(1)) 276 | 277 | # Extract columns (basic approach) 278 | columns = [] 279 | select_match = re.search(r'SELECT\s+(.*?)\s+FROM', sql_query, re.IGNORECASE | re.DOTALL) 280 | if select_match: 281 | column_str = select_match.group(1) 282 | if column_str.strip() == '*': 283 | columns = [Column('*')] 284 | else: 285 | # Simple split by comma (doesn't handle complex cases) 286 | for col in column_str.split(','): 287 | col = col.strip() 288 | if ' AS ' in col.upper(): 289 | parts = re.split(r'\s+AS\s+', col, flags=re.IGNORECASE) 290 | columns.append(Column(parts[0].strip(), alias=parts[1].strip())) 291 | else: 292 | columns.append(Column(col)) 293 | 294 | # Create basic QueryIntent 295 | intent = QueryIntent( 296 | query_type=QueryType.SELECT, 297 | tables=tables, 298 | columns=columns, 299 | natural_language_query=sql_query 300 | ) 301 | 302 | return intent -------------------------------------------------------------------------------- /cognidb/ai/cost_tracker.py: -------------------------------------------------------------------------------- 1 | """Cost tracking for LLM usage.""" 2 | 3 | import time 4 | from datetime import datetime, timedelta 5 | from typing import Dict, Any, List, Optional 6 | from collections import defaultdict 7 | import json 8 | from pathlib import Path 9 | 10 | 11 | class CostTracker: 12 | """ 13 | Tracks LLM usage and costs. 14 | 15 | Features: 16 | - Token usage tracking 17 | - Cost calculation and limits 18 | - Daily/monthly aggregation 19 | - Persistent storage 20 | """ 21 | 22 | def __init__(self, 23 | max_daily_cost: float = 100.0, 24 | storage_path: Optional[str] = None): 25 | """ 26 | Initialize cost tracker. 27 | 28 | Args: 29 | max_daily_cost: Maximum allowed daily cost 30 | storage_path: Path to store usage data 31 | """ 32 | self.max_daily_cost = max_daily_cost 33 | self.storage_path = storage_path or str( 34 | Path.home() / '.cognidb' / 'usage.json' 35 | ) 36 | 37 | # Usage data structure 38 | self.usage_data = defaultdict(lambda: { 39 | 'requests': 0, 40 | 'tokens': {'prompt': 0, 'completion': 0, 'total': 0}, 41 | 'cost': 0.0, 42 | 'models': defaultdict(int) 43 | }) 44 | 45 | # Load existing data 46 | self._load_usage_data() 47 | 48 | def track_usage(self, 49 | cost: float, 50 | token_usage: Dict[str, int], 51 | model: Optional[str] = None) -> None: 52 | """ 53 | Track usage for a request. 54 | 55 | Args: 56 | cost: Cost of the request 57 | token_usage: Token usage statistics 58 | model: Model used 59 | """ 60 | today = datetime.now().strftime('%Y-%m-%d') 61 | 62 | # Update daily stats 63 | self.usage_data[today]['requests'] += 1 64 | self.usage_data[today]['cost'] += cost 65 | self.usage_data[today]['tokens']['prompt'] += token_usage.get('prompt_tokens', 0) 66 | self.usage_data[today]['tokens']['completion'] += token_usage.get('completion_tokens', 0) 67 | self.usage_data[today]['tokens']['total'] += token_usage.get('total_tokens', 0) 68 | 69 | if model: 70 | self.usage_data[today]['models'][model] += 1 71 | 72 | # Save data 73 | self._save_usage_data() 74 | 75 | def get_daily_cost(self, date: Optional[str] = None) -> float: 76 | """ 77 | Get cost for a specific day. 
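        Example (illustrative figures):

            tracker = CostTracker(max_daily_cost=10.0)
            tracker.track_usage(
                cost=0.002,
                token_usage={'prompt_tokens': 120, 'completion_tokens': 40,
                             'total_tokens': 160},
                model='gpt-4',
            )
            tracker.get_daily_cost()  # -> 0.002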
78 | 79 | Args: 80 | date: Date in YYYY-MM-DD format (default: today) 81 | 82 | Returns: 83 | Daily cost 84 | """ 85 | if date is None: 86 | date = datetime.now().strftime('%Y-%m-%d') 87 | 88 | return self.usage_data.get(date, {}).get('cost', 0.0) 89 | 90 | def get_monthly_cost(self, year: int, month: int) -> float: 91 | """ 92 | Get cost for a specific month. 93 | 94 | Args: 95 | year: Year 96 | month: Month (1-12) 97 | 98 | Returns: 99 | Monthly cost 100 | """ 101 | total_cost = 0.0 102 | month_str = f"{year:04d}-{month:02d}" 103 | 104 | for date, data in self.usage_data.items(): 105 | if date.startswith(month_str): 106 | total_cost += data.get('cost', 0.0) 107 | 108 | return total_cost 109 | 110 | def get_total_cost(self) -> float: 111 | """Get total cost across all time.""" 112 | return sum(data.get('cost', 0.0) for data in self.usage_data.values()) 113 | 114 | def get_token_usage(self, date: Optional[str] = None) -> Dict[str, int]: 115 | """ 116 | Get token usage for a specific day. 117 | 118 | Args: 119 | date: Date in YYYY-MM-DD format (default: today) 120 | 121 | Returns: 122 | Token usage statistics 123 | """ 124 | if date is None: 125 | date = datetime.now().strftime('%Y-%m-%d') 126 | 127 | return self.usage_data.get(date, {}).get('tokens', { 128 | 'prompt': 0, 129 | 'completion': 0, 130 | 'total': 0 131 | }) 132 | 133 | def is_limit_exceeded(self, date: Optional[str] = None) -> bool: 134 | """ 135 | Check if daily cost limit is exceeded. 136 | 137 | Args: 138 | date: Date to check (default: today) 139 | 140 | Returns: 141 | True if limit exceeded 142 | """ 143 | daily_cost = self.get_daily_cost(date) 144 | return daily_cost >= self.max_daily_cost 145 | 146 | def get_remaining_budget(self, date: Optional[str] = None) -> float: 147 | """ 148 | Get remaining budget for the day. 149 | 150 | Args: 151 | date: Date to check (default: today) 152 | 153 | Returns: 154 | Remaining budget 155 | """ 156 | daily_cost = self.get_daily_cost(date) 157 | return max(0, self.max_daily_cost - daily_cost) 158 | 159 | def get_usage_summary(self, days: int = 7) -> Dict[str, Any]: 160 | """ 161 | Get usage summary for recent days. 
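        Example (keys follow the structure built below):

            summary = tracker.get_usage_summary(days=7)
            summary['total_cost']          # summed over the window
            summary['daily_breakdown']     # one dict per day with usage data
            summary['average_daily_cost']  # total_cost / days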
162 | 163 | Args: 164 | days: Number of days to include 165 | 166 | Returns: 167 | Usage summary 168 | """ 169 | end_date = datetime.now() 170 | start_date = end_date - timedelta(days=days-1) 171 | 172 | summary = { 173 | 'period': f"{start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}", 174 | 'total_cost': 0.0, 175 | 'total_requests': 0, 176 | 'total_tokens': 0, 177 | 'daily_breakdown': [], 178 | 'model_usage': defaultdict(int) 179 | } 180 | 181 | current_date = start_date 182 | while current_date <= end_date: 183 | date_str = current_date.strftime('%Y-%m-%d') 184 | if date_str in self.usage_data: 185 | data = self.usage_data[date_str] 186 | summary['total_cost'] += data['cost'] 187 | summary['total_requests'] += data['requests'] 188 | summary['total_tokens'] += data['tokens']['total'] 189 | 190 | summary['daily_breakdown'].append({ 191 | 'date': date_str, 192 | 'cost': data['cost'], 193 | 'requests': data['requests'], 194 | 'tokens': data['tokens']['total'] 195 | }) 196 | 197 | for model, count in data['models'].items(): 198 | summary['model_usage'][model] += count 199 | 200 | current_date += timedelta(days=1) 201 | 202 | summary['model_usage'] = dict(summary['model_usage']) 203 | summary['average_daily_cost'] = summary['total_cost'] / days 204 | summary['average_cost_per_request'] = ( 205 | summary['total_cost'] / summary['total_requests'] 206 | if summary['total_requests'] > 0 else 0 207 | ) 208 | 209 | return summary 210 | 211 | def cleanup_old_data(self, days_to_keep: int = 90) -> None: 212 | """ 213 | Remove usage data older than specified days. 214 | 215 | Args: 216 | days_to_keep: Number of days of data to keep 217 | """ 218 | cutoff_date = datetime.now() - timedelta(days=days_to_keep) 219 | cutoff_str = cutoff_date.strftime('%Y-%m-%d') 220 | 221 | dates_to_remove = [ 222 | date for date in self.usage_data.keys() 223 | if date < cutoff_str 224 | ] 225 | 226 | for date in dates_to_remove: 227 | del self.usage_data[date] 228 | 229 | if dates_to_remove: 230 | self._save_usage_data() 231 | 232 | def export_usage_report(self, 233 | start_date: str, 234 | end_date: str, 235 | format: str = 'json') -> str: 236 | """ 237 | Export usage report for date range. 
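        Example:

            csv_text = tracker.export_usage_report('2024-01-01', '2024-01-31',
                                                   format='csv')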
238 | 239 | Args: 240 | start_date: Start date (YYYY-MM-DD) 241 | end_date: End date (YYYY-MM-DD) 242 | format: Export format (json, csv) 243 | 244 | Returns: 245 | Formatted report 246 | """ 247 | report_data = [] 248 | 249 | current = datetime.strptime(start_date, '%Y-%m-%d') 250 | end = datetime.strptime(end_date, '%Y-%m-%d') 251 | 252 | while current <= end: 253 | date_str = current.strftime('%Y-%m-%d') 254 | if date_str in self.usage_data: 255 | data = self.usage_data[date_str] 256 | report_data.append({ 257 | 'date': date_str, 258 | 'requests': data['requests'], 259 | 'cost': data['cost'], 260 | 'prompt_tokens': data['tokens']['prompt'], 261 | 'completion_tokens': data['tokens']['completion'], 262 | 'total_tokens': data['tokens']['total'], 263 | 'models': dict(data['models']) 264 | }) 265 | current += timedelta(days=1) 266 | 267 | if format == 'json': 268 | return json.dumps(report_data, indent=2) 269 | elif format == 'csv': 270 | import csv 271 | import io 272 | 273 | output = io.StringIO() 274 | if report_data: 275 | writer = csv.DictWriter( 276 | output, 277 | fieldnames=['date', 'requests', 'cost', 'prompt_tokens', 278 | 'completion_tokens', 'total_tokens'] 279 | ) 280 | writer.writeheader() 281 | for row in report_data: 282 | # Flatten models field 283 | row_copy = row.copy() 284 | del row_copy['models'] 285 | writer.writerow(row_copy) 286 | 287 | return output.getvalue() 288 | else: 289 | raise ValueError(f"Unsupported format: {format}") 290 | 291 | def _load_usage_data(self) -> None: 292 | """Load usage data from storage.""" 293 | try: 294 | if Path(self.storage_path).exists(): 295 | with open(self.storage_path, 'r') as f: 296 | data = json.load(f) 297 | # Convert to defaultdict structure 298 | for date, usage in data.items(): 299 | self.usage_data[date] = usage 300 | if 'models' in usage: 301 | self.usage_data[date]['models'] = defaultdict( 302 | int, usage['models'] 303 | ) 304 | except Exception: 305 | # If loading fails, start fresh 306 | pass 307 | 308 | def _save_usage_data(self) -> None: 309 | """Save usage data to storage.""" 310 | try: 311 | # Ensure directory exists 312 | Path(self.storage_path).parent.mkdir(parents=True, exist_ok=True) 313 | 314 | # Convert defaultdict to regular dict for JSON serialization 315 | data_to_save = {} 316 | for date, usage in self.usage_data.items(): 317 | data_to_save[date] = { 318 | 'requests': usage['requests'], 319 | 'tokens': usage['tokens'], 320 | 'cost': usage['cost'], 321 | 'models': dict(usage['models']) 322 | } 323 | 324 | with open(self.storage_path, 'w') as f: 325 | json.dump(data_to_save, f, indent=2) 326 | except Exception: 327 | # Log error but don't fail the request 328 | pass -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | CogniDB - Secure Natural Language Database Interface 3 | 4 | A production-ready natural language to SQL interface with comprehensive 5 | security, multi-database support, and intelligent query generation. 
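Quick start (illustrative; the config file and query are placeholders):

    from cognidb import CogniDB

    with CogniDB(config_file='cognidb.yaml') as db:
        result = db.query("How many users signed up last week?")
        if result['success']:
            print(result['sql'])
            print(result['results'])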
6 | """ 7 | 8 | __version__ = "2.0.0" 9 | __author__ = "CogniDB Team" 10 | 11 | import logging 12 | from typing import Dict, Any, Optional, List, Union 13 | from pathlib import Path 14 | 15 | # Core imports 16 | from .cognidb.core.exceptions import CogniDBError 17 | from .cognidb.config import ConfigLoader, Settings, DatabaseType 18 | from .cognidb.security import QuerySecurityValidator, AccessController, InputSanitizer 19 | from .cognidb.ai import LLMManager, QueryGenerator 20 | from .cognidb.drivers import ( 21 | MySQLDriver, 22 | PostgreSQLDriver, 23 | MongoDBDriver, 24 | DynamoDBDriver, 25 | SQLiteDriver 26 | ) 27 | 28 | # Setup logging 29 | logging.basicConfig( 30 | level=logging.INFO, 31 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' 32 | ) 33 | logger = logging.getLogger(__name__) 34 | 35 | 36 | class CogniDB: 37 | """ 38 | Main CogniDB interface for natural language database queries. 39 | 40 | Features: 41 | - Natural language to SQL conversion 42 | - Multi-database support (MySQL, PostgreSQL, MongoDB, DynamoDB) 43 | - Comprehensive security validation 44 | - Query optimization and caching 45 | - Cost tracking and limits 46 | - Audit logging 47 | """ 48 | 49 | def __init__(self, 50 | config_file: Optional[str] = None, 51 | **kwargs): 52 | """ 53 | Initialize CogniDB. 54 | 55 | Args: 56 | config_file: Path to configuration file 57 | **kwargs: Override configuration values 58 | """ 59 | # Load configuration 60 | self.config_loader = ConfigLoader(config_file) 61 | self.settings = self.config_loader.load() 62 | 63 | # Apply any overrides 64 | self._apply_config_overrides(kwargs) 65 | 66 | # Initialize components 67 | self._init_driver() 68 | self._init_security() 69 | self._init_ai() 70 | self._init_cache() 71 | 72 | # Connect to database 73 | self.driver.connect() 74 | 75 | # Cache schema 76 | self.schema = self.driver.fetch_schema() 77 | 78 | logger.info("CogniDB initialized successfully") 79 | 80 | def query(self, 81 | natural_language_query: str, 82 | user_id: Optional[str] = None, 83 | explain: bool = False) -> Dict[str, Any]: 84 | """ 85 | Execute a natural language query. 
86 | 87 | Args: 88 | natural_language_query: Query in natural language 89 | user_id: Optional user ID for access control 90 | explain: Whether to include query explanation 91 | 92 | Returns: 93 | Dictionary with results and metadata 94 | """ 95 | try: 96 | # Sanitize input 97 | sanitized_query = self.input_sanitizer.sanitize_natural_language( 98 | natural_language_query 99 | ) 100 | 101 | # Check access permissions 102 | if self.settings.security.enable_access_control and user_id: 103 | # This would integrate with your access control system 104 | pass 105 | 106 | # Generate SQL from natural language 107 | sql_query = self.query_generator.generate_sql( 108 | sanitized_query, 109 | self.schema, 110 | examples=self.settings.llm.few_shot_examples 111 | ) 112 | 113 | # Validate generated SQL 114 | is_valid, error = self.security_validator.validate_native_query(sql_query) 115 | if not is_valid: 116 | raise CogniDBError(f"Security validation failed: {error}") 117 | 118 | # Execute query 119 | results = self.driver.execute_native_query(sql_query) 120 | 121 | # Prepare response 122 | response = { 123 | 'success': True, 124 | 'query': natural_language_query, 125 | 'sql': sql_query, 126 | 'results': results, 127 | 'row_count': len(results), 128 | 'execution_time': None # Would be tracked by driver 129 | } 130 | 131 | # Add explanation if requested 132 | if explain: 133 | response['explanation'] = self.query_generator.explain_query( 134 | sql_query, 135 | self.schema 136 | ) 137 | 138 | # Log query for audit 139 | self._audit_log(user_id, natural_language_query, sql_query, True) 140 | 141 | return response 142 | 143 | except Exception as e: 144 | logger.error(f"Query execution failed: {str(e)}") 145 | self._audit_log(user_id, natural_language_query, None, False, str(e)) 146 | 147 | return { 148 | 'success': False, 149 | 'query': natural_language_query, 150 | 'error': str(e) 151 | } 152 | 153 | def optimize_query(self, sql_query: str) -> Dict[str, Any]: 154 | """ 155 | Get optimization suggestions for a SQL query. 156 | 157 | Args: 158 | sql_query: SQL query to optimize 159 | 160 | Returns: 161 | Dictionary with optimized query and explanation 162 | """ 163 | try: 164 | optimized_sql, explanation = self.query_generator.optimize_query( 165 | sql_query, 166 | self.schema 167 | ) 168 | 169 | return { 170 | 'success': True, 171 | 'original_query': sql_query, 172 | 'optimized_query': optimized_sql, 173 | 'explanation': explanation 174 | } 175 | 176 | except Exception as e: 177 | logger.error(f"Query optimization failed: {str(e)}") 178 | return { 179 | 'success': False, 180 | 'error': str(e) 181 | } 182 | 183 | def suggest_queries(self, partial_query: str) -> List[str]: 184 | """ 185 | Get query suggestions based on partial input. 186 | 187 | Args: 188 | partial_query: Partial natural language query 189 | 190 | Returns: 191 | List of suggested queries 192 | """ 193 | try: 194 | return self.query_generator.suggest_queries( 195 | partial_query, 196 | self.schema 197 | ) 198 | except Exception as e: 199 | logger.error(f"Failed to generate suggestions: {str(e)}") 200 | return [] 201 | 202 | def get_schema(self, table_name: Optional[str] = None) -> Dict[str, Any]: 203 | """ 204 | Get database schema information. 
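        Example:

            db.get_schema()         # full schema: {table: {column: type}}
            db.get_schema('users')  # just 'users' (maps to {} if unknown)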
205 | 206 | Args: 207 | table_name: Optional specific table name 208 | 209 | Returns: 210 | Schema information 211 | """ 212 | if table_name: 213 | return { 214 | table_name: self.schema.get(table_name, {}) 215 | } 216 | return self.schema 217 | 218 | def get_usage_stats(self) -> Dict[str, Any]: 219 | """Get usage statistics including costs.""" 220 | return self.llm_manager.get_usage_stats() 221 | 222 | def close(self): 223 | """Close all connections and cleanup resources.""" 224 | try: 225 | if hasattr(self, 'driver'): 226 | self.driver.disconnect() 227 | logger.info("CogniDB closed successfully") 228 | except Exception as e: 229 | logger.error(f"Error closing CogniDB: {str(e)}") 230 | 231 | def __enter__(self): 232 | """Context manager entry.""" 233 | return self 234 | 235 | def __exit__(self, exc_type, exc_val, exc_tb): 236 | """Context manager exit.""" 237 | self.close() 238 | 239 | # Private methods 240 | 241 | def _init_driver(self): 242 | """Initialize database driver.""" 243 | driver_map = { 244 | DatabaseType.MYSQL: MySQLDriver, 245 | DatabaseType.POSTGRESQL: PostgreSQLDriver, 246 | DatabaseType.MONGODB: MongoDBDriver, 247 | DatabaseType.DYNAMODB: DynamoDBDriver, 248 | DatabaseType.SQLITE: SQLiteDriver 249 | } 250 | 251 | driver_class = driver_map.get(self.settings.database.type) 252 | if not driver_class: 253 | raise CogniDBError(f"Unsupported database type: {self.settings.database.type}") 254 | 255 | # Convert settings to driver config 256 | driver_config = { 257 | 'host': self.settings.database.host, 258 | 'port': self.settings.database.port, 259 | 'database': self.settings.database.database, 260 | 'username': self.settings.database.username, 261 | 'password': self.settings.database.password, 262 | 'ssl_enabled': self.settings.database.ssl_enabled, 263 | 'query_timeout': self.settings.database.query_timeout, 264 | 'max_result_size': self.settings.database.max_result_size, 265 | **self.settings.database.options 266 | } 267 | 268 | self.driver = driver_class(driver_config) 269 | 270 | def _init_security(self): 271 | """Initialize security components.""" 272 | self.security_validator = QuerySecurityValidator( 273 | allowed_operations=['SELECT'], 274 | max_query_complexity=self.settings.security.max_query_complexity, 275 | allow_subqueries=self.settings.security.allow_subqueries 276 | ) 277 | 278 | self.access_controller = AccessController() 279 | self.input_sanitizer = InputSanitizer() 280 | 281 | def _init_ai(self): 282 | """Initialize AI components.""" 283 | # Initialize LLM manager 284 | self.llm_manager = LLMManager( 285 | self.settings.llm, 286 | cache_provider=None # Will be set after cache init 287 | ) 288 | 289 | # Initialize query generator 290 | self.query_generator = QueryGenerator( 291 | self.llm_manager, 292 | database_type=self.settings.database.type.value 293 | ) 294 | 295 | def _init_cache(self): 296 | """Initialize caching layer.""" 297 | # For now, using in-memory cache from LLM manager 298 | # In production, would initialize Redis/Memcached here 299 | pass 300 | 301 | def _apply_config_overrides(self, overrides: Dict[str, Any]): 302 | """Apply configuration overrides.""" 303 | for key, value in overrides.items(): 304 | if hasattr(self.settings, key): 305 | setattr(self.settings, key, value) 306 | 307 | def _audit_log(self, 308 | user_id: Optional[str], 309 | natural_language_query: str, 310 | sql_query: Optional[str], 311 | success: bool, 312 | error: Optional[str] = None): 313 | """Log query for audit trail.""" 314 | if not 
self.settings.security.enable_audit_logging: 315 | return 316 | 317 | # In production, this would write to a proper audit log 318 | logger.info( 319 | f"AUDIT: user={user_id}, query={natural_language_query[:50]}..., " 320 | f"success={success}, error={error}" 321 | ) 322 | 323 | 324 | # Convenience function for quick usage 325 | def create_cognidb(**kwargs) -> CogniDB: 326 | """ 327 | Create a CogniDB instance with configuration. 328 | 329 | Args: 330 | **kwargs: Configuration parameters 331 | 332 | Returns: 333 | CogniDB instance 334 | """ 335 | return CogniDB(**kwargs) -------------------------------------------------------------------------------- /cognidb/config/secrets.py: -------------------------------------------------------------------------------- 1 | """Secrets management for sensitive configuration.""" 2 | 3 | import os 4 | import json 5 | import base64 6 | from typing import Dict, Any, Optional 7 | from pathlib import Path 8 | from cryptography.fernet import Fernet 9 | from cryptography.hazmat.primitives import hashes 10 | from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC 11 | from ..core.exceptions import ConfigurationError 12 | 13 | 14 | class SecretsManager: 15 | """ 16 | Secure secrets management. 17 | 18 | Supports: 19 | - Environment variables 20 | - Encrypted file storage 21 | - AWS Secrets Manager 22 | - HashiCorp Vault 23 | - Azure Key Vault 24 | """ 25 | 26 | def __init__(self, provider: str = "env", **kwargs): 27 | """ 28 | Initialize secrets manager. 29 | 30 | Args: 31 | provider: One of 'env', 'file', 'aws', 'vault', 'azure' 32 | **kwargs: Provider-specific configuration 33 | """ 34 | self.provider = provider 35 | self.config = kwargs 36 | self._cache: Dict[str, Any] = {} 37 | self._cipher: Optional[Fernet] = None 38 | 39 | if provider == "file": 40 | self._init_file_provider() 41 | elif provider == "aws": 42 | self._init_aws_provider() 43 | elif provider == "vault": 44 | self._init_vault_provider() 45 | elif provider == "azure": 46 | self._init_azure_provider() 47 | 48 | def get_secret(self, key: str, default: Any = None) -> Any: 49 | """ 50 | Retrieve a secret value. 51 | 52 | Args: 53 | key: Secret key 54 | default: Default value if not found 55 | 56 | Returns: 57 | Secret value 58 | """ 59 | # Check cache first 60 | if key in self._cache: 61 | return self._cache[key] 62 | 63 | value = None 64 | 65 | if self.provider == "env": 66 | value = os.getenv(key, default) 67 | elif self.provider == "file": 68 | value = self._get_file_secret(key, default) 69 | elif self.provider == "aws": 70 | value = self._get_aws_secret(key, default) 71 | elif self.provider == "vault": 72 | value = self._get_vault_secret(key, default) 73 | elif self.provider == "azure": 74 | value = self._get_azure_secret(key, default) 75 | 76 | # Cache the value 77 | if value is not None: 78 | self._cache[key] = value 79 | 80 | return value 81 | 82 | def set_secret(self, key: str, value: Any) -> None: 83 | """ 84 | Store a secret value. 
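        Example (env provider; writes through to os.environ):

            secrets = SecretsManager(provider='env')
            secrets.set_secret('DB_PASSWORD', 's3cret')
            secrets.get_secret('DB_PASSWORD')  # -> 's3cret'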
85 | 86 | Args: 87 | key: Secret key 88 | value: Secret value 89 | """ 90 | if self.provider == "env": 91 | os.environ[key] = str(value) 92 | elif self.provider == "file": 93 | self._set_file_secret(key, value) 94 | elif self.provider == "aws": 95 | self._set_aws_secret(key, value) 96 | elif self.provider == "vault": 97 | self._set_vault_secret(key, value) 98 | elif self.provider == "azure": 99 | self._set_azure_secret(key, value) 100 | 101 | # Update cache 102 | self._cache[key] = value 103 | 104 | def delete_secret(self, key: str) -> None: 105 | """Delete a secret.""" 106 | if self.provider == "env": 107 | os.environ.pop(key, None) 108 | elif self.provider == "file": 109 | self._delete_file_secret(key) 110 | elif self.provider == "aws": 111 | self._delete_aws_secret(key) 112 | elif self.provider == "vault": 113 | self._delete_vault_secret(key) 114 | elif self.provider == "azure": 115 | self._delete_azure_secret(key) 116 | 117 | # Remove from cache 118 | self._cache.pop(key, None) 119 | 120 | def clear_cache(self) -> None: 121 | """Clear the secrets cache.""" 122 | self._cache.clear() 123 | 124 | # File-based provider methods 125 | 126 | def _init_file_provider(self) -> None: 127 | """Initialize file-based secrets provider.""" 128 | secrets_file = self.config.get('secrets_file', 129 | str(Path.home() / '.cognidb' / 'secrets.enc')) 130 | master_password = self.config.get('master_password') 131 | 132 | if not master_password: 133 | master_password = os.getenv('COGNIDB_MASTER_PASSWORD') 134 | if not master_password: 135 | raise ConfigurationError( 136 | "Master password required for file-based secrets" 137 | ) 138 | 139 | # Create directory if needed 140 | Path(secrets_file).parent.mkdir(parents=True, exist_ok=True) 141 | 142 | # Generate encryption key from password 143 | kdf = PBKDF2HMAC( 144 | algorithm=hashes.SHA256(), 145 | length=32, 146 | salt=b'cognidb_salt', # In production, use random salt 147 | iterations=100000, 148 | ) 149 | key = base64.urlsafe_b64encode(kdf.derive(master_password.encode())) 150 | self._cipher = Fernet(key) 151 | 152 | self.secrets_file = secrets_file 153 | self._load_secrets_file() 154 | 155 | def _load_secrets_file(self) -> None: 156 | """Load secrets from encrypted file.""" 157 | if not Path(self.secrets_file).exists(): 158 | self._secrets_data = {} 159 | return 160 | 161 | try: 162 | with open(self.secrets_file, 'rb') as f: 163 | encrypted_data = f.read() 164 | 165 | decrypted_data = self._cipher.decrypt(encrypted_data) 166 | self._secrets_data = json.loads(decrypted_data.decode()) 167 | except Exception as e: 168 | raise ConfigurationError(f"Failed to load secrets file: {e}") 169 | 170 | def _save_secrets_file(self) -> None: 171 | """Save secrets to encrypted file.""" 172 | try: 173 | data = json.dumps(self._secrets_data).encode() 174 | encrypted_data = self._cipher.encrypt(data) 175 | 176 | with open(self.secrets_file, 'wb') as f: 177 | f.write(encrypted_data) 178 | except Exception as e: 179 | raise ConfigurationError(f"Failed to save secrets file: {e}") 180 | 181 | def _get_file_secret(self, key: str, default: Any) -> Any: 182 | """Get secret from file.""" 183 | return self._secrets_data.get(key, default) 184 | 185 | def _set_file_secret(self, key: str, value: Any) -> None: 186 | """Set secret in file.""" 187 | self._secrets_data[key] = value 188 | self._save_secrets_file() 189 | 190 | def _delete_file_secret(self, key: str) -> None: 191 | """Delete secret from file.""" 192 | self._secrets_data.pop(key, None) 193 | self._save_secrets_file() 194 | 195 | 
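    # Illustrative file-provider setup (a sketch; the path and password are
    # placeholders, and COGNIDB_MASTER_PASSWORD may replace master_password):
    #
    #     secrets = SecretsManager(
    #         provider='file',
    #         secrets_file='/tmp/secrets.enc',
    #         master_password='change-me',
    #     )
    #     secrets.set_secret('LLM_API_KEY', 'sk-placeholder')
    #     secrets.get_secret('LLM_API_KEY')  # read back from the encrypted store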
# AWS Secrets Manager provider methods 196 | 197 | def _init_aws_provider(self) -> None: 198 | """Initialize AWS Secrets Manager provider.""" 199 | try: 200 | import boto3 201 | self._aws_client = boto3.client( 202 | 'secretsmanager', 203 | region_name=self.config.get('region', 'us-east-1') 204 | ) 205 | except ImportError: 206 | raise ConfigurationError( 207 | "boto3 required for AWS Secrets Manager. Install with: pip install boto3" 208 | ) 209 | 210 | def _get_aws_secret(self, key: str, default: Any) -> Any: 211 | """Get secret from AWS Secrets Manager.""" 212 | try: 213 | response = self._aws_client.get_secret_value(SecretId=key) 214 | if 'SecretString' in response: 215 | secret = response['SecretString'] 216 | try: 217 | return json.loads(secret) 218 | except json.JSONDecodeError: 219 | return secret 220 | else: 221 | return base64.b64decode(response['SecretBinary']) 222 | except self._aws_client.exceptions.ResourceNotFoundException: 223 | return default 224 | except Exception as e: 225 | raise ConfigurationError(f"Failed to get AWS secret: {e}") 226 | 227 | def _set_aws_secret(self, key: str, value: Any) -> None: 228 | """Set secret in AWS Secrets Manager.""" 229 | try: 230 | secret_string = json.dumps(value) if not isinstance(value, str) else value 231 | try: 232 | self._aws_client.update_secret( 233 | SecretId=key, 234 | SecretString=secret_string 235 | ) 236 | except self._aws_client.exceptions.ResourceNotFoundException: 237 | self._aws_client.create_secret( 238 | Name=key, 239 | SecretString=secret_string 240 | ) 241 | except Exception as e: 242 | raise ConfigurationError(f"Failed to set AWS secret: {e}") 243 | 244 | def _delete_aws_secret(self, key: str) -> None: 245 | """Delete secret from AWS Secrets Manager.""" 246 | try: 247 | self._aws_client.delete_secret( 248 | SecretId=key, 249 | ForceDeleteWithoutRecovery=True 250 | ) 251 | except Exception as e: 252 | raise ConfigurationError(f"Failed to delete AWS secret: {e}") 253 | 254 | # HashiCorp Vault provider methods 255 | 256 | def _init_vault_provider(self) -> None: 257 | """Initialize HashiCorp Vault provider.""" 258 | try: 259 | import hvac 260 | self._vault_client = hvac.Client( 261 | url=self.config.get('url', 'http://localhost:8200'), 262 | token=self.config.get('token') or os.getenv('VAULT_TOKEN') 263 | ) 264 | if not self._vault_client.is_authenticated(): 265 | raise ConfigurationError("Vault authentication failed") 266 | except ImportError: 267 | raise ConfigurationError( 268 | "hvac required for HashiCorp Vault. 
Install with: pip install hvac" 269 | ) 270 | 271 | def _get_vault_secret(self, key: str, default: Any) -> Any: 272 | """Get secret from HashiCorp Vault.""" 273 | try: 274 | mount_point = self.config.get('mount_point', 'secret') 275 | response = self._vault_client.secrets.kv.v2.read_secret_version( 276 | path=key, 277 | mount_point=mount_point 278 | ) 279 | return response['data']['data'] 280 | except Exception: 281 | return default 282 | 283 | def _set_vault_secret(self, key: str, value: Any) -> None: 284 | """Set secret in HashiCorp Vault.""" 285 | try: 286 | mount_point = self.config.get('mount_point', 'secret') 287 | self._vault_client.secrets.kv.v2.create_or_update_secret( 288 | path=key, 289 | secret=dict(value=value) if not isinstance(value, dict) else value, 290 | mount_point=mount_point 291 | ) 292 | except Exception as e: 293 | raise ConfigurationError(f"Failed to set Vault secret: {e}") 294 | 295 | def _delete_vault_secret(self, key: str) -> None: 296 | """Delete secret from HashiCorp Vault.""" 297 | try: 298 | mount_point = self.config.get('mount_point', 'secret') 299 | self._vault_client.secrets.kv.v2.delete_metadata_and_all_versions( 300 | path=key, 301 | mount_point=mount_point 302 | ) 303 | except Exception as e: 304 | raise ConfigurationError(f"Failed to delete Vault secret: {e}") 305 | 306 | # Azure Key Vault provider methods 307 | 308 | def _init_azure_provider(self) -> None: 309 | """Initialize Azure Key Vault provider.""" 310 | try: 311 | from azure.keyvault.secrets import SecretClient 312 | from azure.identity import DefaultAzureCredential 313 | 314 | vault_url = self.config.get('vault_url') 315 | if not vault_url: 316 | raise ConfigurationError("vault_url required for Azure Key Vault") 317 | 318 | credential = DefaultAzureCredential() 319 | self._azure_client = SecretClient( 320 | vault_url=vault_url, 321 | credential=credential 322 | ) 323 | except ImportError: 324 | raise ConfigurationError( 325 | "azure-keyvault-secrets required. 
Install with: " 326 | "pip install azure-keyvault-secrets azure-identity" 327 | ) 328 | 329 | def _get_azure_secret(self, key: str, default: Any) -> Any: 330 | """Get secret from Azure Key Vault.""" 331 | try: 332 | secret = self._azure_client.get_secret(key) 333 | return secret.value 334 | except Exception: 335 | return default 336 | 337 | def _set_azure_secret(self, key: str, value: Any) -> None: 338 | """Set secret in Azure Key Vault.""" 339 | try: 340 | value_str = json.dumps(value) if not isinstance(value, str) else value 341 | self._azure_client.set_secret(key, value_str) 342 | except Exception as e: 343 | raise ConfigurationError(f"Failed to set Azure secret: {e}") 344 | 345 | def _delete_azure_secret(self, key: str) -> None: 346 | """Delete secret from Azure Key Vault.""" 347 | try: 348 | poller = self._azure_client.begin_delete_secret(key) 349 | poller.wait() 350 | except Exception as e: 351 | raise ConfigurationError(f"Failed to delete Azure secret: {e}") -------------------------------------------------------------------------------- /cognidb/config/loader.py: -------------------------------------------------------------------------------- 1 | """Configuration loader with multiple source support.""" 2 | 3 | import os 4 | import json 5 | import yaml 6 | from pathlib import Path 7 | from typing import Dict, Any, Optional, List 8 | from .settings import Settings, DatabaseConfig, LLMConfig, CacheConfig, SecurityConfig 9 | from .settings import DatabaseType, LLMProvider, CacheProvider 10 | from .secrets import SecretsManager 11 | from ..core.exceptions import ConfigurationError 12 | 13 | 14 | class ConfigLoader: 15 | """ 16 | Load configuration from multiple sources. 17 | 18 | Priority order (highest to lowest): 19 | 1. Environment variables 20 | 2. Config file (JSON/YAML) 21 | 3. Defaults 22 | """ 23 | 24 | def __init__(self, 25 | config_file: Optional[str] = None, 26 | secrets_manager: Optional[SecretsManager] = None): 27 | """ 28 | Initialize config loader. 29 | 30 | Args: 31 | config_file: Path to configuration file 32 | secrets_manager: Secrets manager instance 33 | """ 34 | self.config_file = config_file or self._find_config_file() 35 | self.secrets_manager = secrets_manager or SecretsManager() 36 | self._config_data: Dict[str, Any] = {} 37 | 38 | def load(self) -> Settings: 39 | """ 40 | Load configuration from all sources. 
--------------------------------------------------------------------------------
/cognidb/config/loader.py:
--------------------------------------------------------------------------------
"""Configuration loader with multiple source support."""

import os
import json
import yaml
from pathlib import Path
from typing import Dict, Any, Optional
from .settings import Settings, DatabaseConfig, LLMConfig, CacheConfig, SecurityConfig
from .settings import DatabaseType, LLMProvider, CacheProvider
from .secrets import SecretsManager
from ..core.exceptions import ConfigurationError


class ConfigLoader:
    """
    Load configuration from multiple sources.

    Priority order (highest to lowest):
    1. Environment variables
    2. Config file (JSON/YAML)
    3. Defaults
    """

    def __init__(self,
                 config_file: Optional[str] = None,
                 secrets_manager: Optional[SecretsManager] = None):
        """
        Initialize config loader.

        Args:
            config_file: Path to configuration file
            secrets_manager: Secrets manager instance
        """
        self.config_file = config_file or self._find_config_file()
        self.secrets_manager = secrets_manager or SecretsManager()
        self._config_data: Dict[str, Any] = {}

    def load(self) -> Settings:
        """
        Load configuration from all sources.

        Returns:
            Settings object
        """
        # Load from file if exists
        if self.config_file and Path(self.config_file).exists():
            self._load_from_file()

        # Override with environment variables
        self._load_from_env()

        # Load secrets
        self._load_secrets()

        # Create settings object
        settings = self._create_settings()

        # Validate
        errors = settings.validate()
        if errors:
            raise ConfigurationError(f"Configuration errors: {', '.join(errors)}")

        return settings
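To make the precedence concrete, a minimal usage sketch (file name and values are illustrative):

# Env vars override file values; secrets then fill anything still missing.
import os
from cognidb.config import ConfigLoader

os.environ['DB_HOST'] = 'db.internal'   # beats any database.host in the file
settings = ConfigLoader(config_file='cognidb.yaml').load()
assert settings.database.host == 'db.internal'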
    def _find_config_file(self) -> Optional[str]:
        """Find configuration file in standard locations."""
        # Check environment variable
        if 'COGNIDB_CONFIG' in os.environ:
            return os.environ['COGNIDB_CONFIG']

        # Check standard locations
        locations = [
            'cognidb.yaml',
            'cognidb.yml',
            'cognidb.json',
            '.cognidb.yaml',
            '.cognidb.yml',
            '.cognidb.json',
            str(Path.home() / '.cognidb' / 'config.yaml'),
            str(Path.home() / '.cognidb' / 'config.yml'),
            str(Path.home() / '.cognidb' / 'config.json'),
            '/etc/cognidb/config.yaml',
            '/etc/cognidb/config.yml',
            '/etc/cognidb/config.json',
        ]

        for location in locations:
            if Path(location).exists():
                return location

        return None

    def _load_from_file(self) -> None:
        """Load configuration from file."""
        try:
            with open(self.config_file, 'r') as f:
                if self.config_file.endswith(('.yaml', '.yml')):
                    # safe_load returns None for an empty file; normalize to {}
                    self._config_data = yaml.safe_load(f) or {}
                else:
                    self._config_data = json.load(f)
        except Exception as e:
            raise ConfigurationError(f"Failed to load config file: {e}")

    def _load_from_env(self) -> None:
        """Load configuration from environment variables."""
        # Database settings
        if 'DB_TYPE' in os.environ:
            self._set_nested('database.type', os.environ['DB_TYPE'])
        if 'DB_HOST' in os.environ:
            self._set_nested('database.host', os.environ['DB_HOST'])
        if 'DB_PORT' in os.environ:
            try:
                self._set_nested('database.port', int(os.environ['DB_PORT']))
            except ValueError:
                raise ConfigurationError(
                    f"DB_PORT must be an integer, got {os.environ['DB_PORT']!r}"
                )
        if 'DB_NAME' in os.environ:
            self._set_nested('database.database', os.environ['DB_NAME'])
        if 'DB_USER' in os.environ:
            self._set_nested('database.username', os.environ['DB_USER'])
        if 'DB_PASSWORD' in os.environ:
            self._set_nested('database.password', os.environ['DB_PASSWORD'])

        # LLM settings
        if 'LLM_PROVIDER' in os.environ:
            self._set_nested('llm.provider', os.environ['LLM_PROVIDER'])
        if 'LLM_API_KEY' in os.environ:
            self._set_nested('llm.api_key', os.environ['LLM_API_KEY'])
        if 'LLM_MODEL' in os.environ:
            self._set_nested('llm.model_name', os.environ['LLM_MODEL'])

        # Cache settings
        if 'CACHE_PROVIDER' in os.environ:
            self._set_nested('cache.provider', os.environ['CACHE_PROVIDER'])

        # Security settings
        if 'SECURITY_SELECT_ONLY' in os.environ:
            self._set_nested('security.allow_only_select',
                             os.environ['SECURITY_SELECT_ONLY'].lower() == 'true')

        # Application settings
        if 'ENVIRONMENT' in os.environ:
            self._set_nested('environment', os.environ['ENVIRONMENT'])
        if 'DEBUG' in os.environ:
            self._set_nested('debug', os.environ['DEBUG'].lower() == 'true')
        if 'LOG_LEVEL' in os.environ:
            self._set_nested('log_level', os.environ['LOG_LEVEL'])

    def _load_secrets(self) -> None:
        """Load secrets from secrets manager."""
        # Database password
        if not self._get_nested('database.password'):
            password = self.secrets_manager.get_secret('DB_PASSWORD')
            if password:
                self._set_nested('database.password', password)

        # LLM API key
        if not self._get_nested('llm.api_key'):
            api_key = self.secrets_manager.get_secret('LLM_API_KEY')
            if api_key:
                self._set_nested('llm.api_key', api_key)

        # Redis password
        if not self._get_nested('cache.redis_password'):
            redis_password = self.secrets_manager.get_secret('REDIS_PASSWORD')
            if redis_password:
                self._set_nested('cache.redis_password', redis_password)

        # Encryption key
        if not self._get_nested('security.encryption_key'):
            encryption_key = self.secrets_manager.get_secret('ENCRYPTION_KEY')
            if encryption_key:
                self._set_nested('security.encryption_key', encryption_key)
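For reference, the nested shape that _create_settings (below) consumes, written as a Python dict with hypothetical values; each top-level key mirrors a section read there:

example_config = {
    'database': {'type': 'postgresql', 'host': 'localhost', 'database': 'shop'},
    'llm': {'provider': 'openai', 'model_name': 'gpt-4'},
    'cache': {'provider': 'in_memory'},
    'security': {'allow_only_select': True},
}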
    def _create_settings(self) -> Settings:
        """Create Settings object from loaded configuration."""
        # Database configuration
        db_config = self._get_nested('database', {})
        database = DatabaseConfig(
            type=DatabaseType(db_config.get('type', 'postgresql')),
            host=db_config.get('host', 'localhost'),
            port=db_config.get('port', 5432),
            database=db_config.get('database', 'cognidb'),
            username=db_config.get('username'),
            password=db_config.get('password'),
            pool_size=db_config.get('pool_size', 5),
            max_overflow=db_config.get('max_overflow', 10),
            pool_timeout=db_config.get('pool_timeout', 30),
            pool_recycle=db_config.get('pool_recycle', 3600),
            ssl_enabled=db_config.get('ssl_enabled', False),
            ssl_ca_cert=db_config.get('ssl_ca_cert'),
            ssl_client_cert=db_config.get('ssl_client_cert'),
            ssl_client_key=db_config.get('ssl_client_key'),
            query_timeout=db_config.get('query_timeout', 30),
            max_result_size=db_config.get('max_result_size', 10000),
            options=db_config.get('options', {})
        )

        # LLM configuration
        llm_config = self._get_nested('llm', {})
        llm = LLMConfig(
            provider=LLMProvider(llm_config.get('provider', 'openai')),
            api_key=llm_config.get('api_key'),
            model_name=llm_config.get('model_name', 'gpt-4'),
            temperature=llm_config.get('temperature', 0.1),
            max_tokens=llm_config.get('max_tokens', 1000),
            timeout=llm_config.get('timeout', 30),
            max_tokens_per_query=llm_config.get('max_tokens_per_query', 2000),
            max_queries_per_minute=llm_config.get('max_queries_per_minute', 60),
            max_cost_per_day=llm_config.get('max_cost_per_day', 100.0),
            system_prompt=llm_config.get('system_prompt'),
            few_shot_examples=llm_config.get('few_shot_examples', []),
            azure_endpoint=llm_config.get('azure_endpoint'),
            azure_deployment=llm_config.get('azure_deployment'),
            huggingface_model_id=llm_config.get('huggingface_model_id'),
            local_model_path=llm_config.get('local_model_path'),
            enable_function_calling=llm_config.get('enable_function_calling', True),
            enable_streaming=llm_config.get('enable_streaming', False),
            retry_attempts=llm_config.get('retry_attempts', 3),
            retry_delay=llm_config.get('retry_delay', 1.0)
        )

        # Cache configuration
        cache_config = self._get_nested('cache', {})
        cache = CacheConfig(
            provider=CacheProvider(cache_config.get('provider', 'in_memory')),
            query_result_ttl=cache_config.get('query_result_ttl', 3600),
            schema_ttl=cache_config.get('schema_ttl', 86400),
            llm_response_ttl=cache_config.get('llm_response_ttl', 7200),
            max_cache_size_mb=cache_config.get('max_cache_size_mb', 100),
            max_entry_size_mb=cache_config.get('max_entry_size_mb', 10),
            eviction_policy=cache_config.get('eviction_policy', 'lru'),
            redis_host=cache_config.get('redis_host', 'localhost'),
            redis_port=cache_config.get('redis_port', 6379),
            redis_password=cache_config.get('redis_password'),
            redis_db=cache_config.get('redis_db', 0),
            redis_ssl=cache_config.get('redis_ssl', False),
            disk_cache_path=cache_config.get('disk_cache_path',
                                             str(Path.home() / '.cognidb' / 'cache')),
            enable_compression=cache_config.get('enable_compression', True),
            enable_async_writes=cache_config.get('enable_async_writes', True)
        )

        # Security configuration
        security_config = self._get_nested('security', {})
        security = SecurityConfig(
            allow_only_select=security_config.get('allow_only_select', True),
            max_query_complexity=security_config.get('max_query_complexity', 10),
            allow_subqueries=security_config.get('allow_subqueries', False),
            allow_unions=security_config.get('allow_unions', False),
            enable_rate_limiting=security_config.get('enable_rate_limiting', True),
            rate_limit_per_minute=security_config.get('rate_limit_per_minute', 100),
            rate_limit_per_hour=security_config.get('rate_limit_per_hour', 1000),
            enable_access_control=security_config.get('enable_access_control', True),
            default_user_permissions=security_config.get('default_user_permissions', ['SELECT']),
            require_authentication=security_config.get('require_authentication', False),
            enable_audit_logging=security_config.get('enable_audit_logging', True),
            audit_log_path=security_config.get('audit_log_path',
                                               str(Path.home() / '.cognidb' / 'audit.log')),
            log_query_results=security_config.get('log_query_results', False),
            encrypt_cache=security_config.get('encrypt_cache', True),
            encrypt_logs=security_config.get('encrypt_logs', True),
            encryption_key=security_config.get('encryption_key'),
            allowed_ip_ranges=security_config.get('allowed_ip_ranges', []),
            require_ssl=security_config.get('require_ssl', True)
        )

        # Create settings
        return Settings(
            database=database,
            llm=llm,
            cache=cache,
            security=security,
            app_name=self._get_nested('app_name', 'CogniDB'),
            environment=self._get_nested('environment', 'production'),
            debug=self._get_nested('debug', False),
            log_level=self._get_nested('log_level', 'INFO'),
            data_dir=self._get_nested('data_dir', str(Path.home() / '.cognidb')),
            log_dir=self._get_nested('log_dir', str(Path.home() / '.cognidb' / 'logs')),
            enable_natural_language=self._get_nested('enable_natural_language', True),
            enable_query_explanation=self._get_nested('enable_query_explanation', True),
            enable_query_optimization=self._get_nested('enable_query_optimization', True),
            enable_auto_indexing=self._get_nested('enable_auto_indexing', False),
            enable_metrics=self._get_nested('enable_metrics', True),
            metrics_port=self._get_nested('metrics_port', 9090),
            enable_tracing=self._get_nested('enable_tracing', True),
            tracing_endpoint=self._get_nested('tracing_endpoint')
        )
    def _get_nested(self, path: str, default: Any = None) -> Any:
        """Get nested configuration value.

        Example: _get_nested('database.host') walks
        self._config_data['database']['host'] and returns the
        default if any segment is missing.
        """
        keys = path.split('.')
        value = self._config_data

        for key in keys:
            if isinstance(value, dict) and key in value:
                value = value[key]
            else:
                return default

        return value

    def _set_nested(self, path: str, value: Any) -> None:
        """Set nested configuration value."""
        keys = path.split('.')
        target = self._config_data

        for key in keys[:-1]:
            # Replace non-dict intermediates so a scalar loaded earlier
            # cannot break a deeper override such as 'database.port'.
            if not isinstance(target.get(key), dict):
                target[key] = {}
            target = target[key]

        target[keys[-1]] = value
--------------------------------------------------------------------------------
/cognidb/drivers/postgres_driver.py:
--------------------------------------------------------------------------------
"""Secure PostgreSQL driver implementation."""

import time
import logging
from typing import Dict, List, Any, Optional
import psycopg2
from psycopg2 import pool, sql, extras, OperationalError, DatabaseError
from psycopg2.extensions import ISOLATION_LEVEL_READ_COMMITTED
from .base_driver import BaseDriver
from ..core.exceptions import ConnectionError, ExecutionError

logger = logging.getLogger(__name__)


class PostgreSQLDriver(BaseDriver):
    """
    PostgreSQL database driver with security enhancements.

    Features:
    - Connection pooling with pgbouncer support
    - Parameterized queries with proper escaping
    - SSL/TLS enforcement
    - Statement timeout enforcement
    - Prepared statements
    - EXPLAIN ANALYZE integration
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize PostgreSQL driver."""
        super().__init__(config)
        self.pool = None
        self._prepared_statements = {}

    def connect(self) -> None:
        """Establish connection to PostgreSQL database."""
        try:
            # Prepare connection config
            conn_params = {
                'host': self.config['host'],
                'port': self.config.get('port', 5432),
                'database': self.config['database'],
                'user': self.config.get('username'),
                'password': self.config.get('password'),
                'connect_timeout': self.config.get('connection_timeout', 10),
                'application_name': 'CogniDB',
                'options': f"-c statement_timeout={self.config.get('query_timeout', 30)}s"
            }

            # SSL configuration
            if self.config.get('ssl_enabled', True):
                conn_params['sslmode'] = 'require'
                if self.config.get('ssl_ca_cert'):
                    conn_params['sslrootcert'] = self.config['ssl_ca_cert']
                if self.config.get('ssl_client_cert'):
                    conn_params['sslcert'] = self.config['ssl_client_cert']
                if self.config.get('ssl_client_key'):
                    conn_params['sslkey'] = self.config['ssl_client_key']

            # Create connection pool
            self.pool = psycopg2.pool.ThreadedConnectionPool(
                minconn=1,
                maxconn=self.config.get('pool_size', 5),
                **conn_params
            )

            # Test connection
            self.connection = self.pool.getconn()
            self.connection.set_isolation_level(ISOLATION_LEVEL_READ_COMMITTED)
            self._connection_time = time.time()

            # Set additional session parameters
            with self.connection.cursor() as cursor:
                cursor.execute("SET TIME ZONE 'UTC'")
                cursor.execute("SET lock_timeout = '5s'")
                cursor.execute("SET idle_in_transaction_session_timeout = '60s'")

            self.connection.commit()

            logger.info(f"Connected to PostgreSQL database: {self.config['database']}")

        except (OperationalError, DatabaseError) as e:
            logger.error(f"PostgreSQL connection failed: {str(e)}")
            raise ConnectionError(f"Failed to connect to PostgreSQL: {str(e)}")
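A minimal connection sketch for the driver above (host and credentials are placeholders; query execution goes through the BaseDriver API, which is defined elsewhere):

from cognidb.drivers import PostgreSQLDriver

driver = PostgreSQLDriver({
    'host': 'localhost',
    'database': 'cognidb',
    'username': 'readonly',
    'password': 'example-password',
    'ssl_enabled': True,     # sslmode=require
    'query_timeout': 30,     # enforced via statement_timeout
})
driver.connect()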
connection failed: {str(e)}") 83 | raise ConnectionError(f"Failed to connect to PostgreSQL: {str(e)}") 84 | 85 | def disconnect(self) -> None: 86 | """Close the database connection.""" 87 | if self.connection: 88 | try: 89 | # Clear prepared statements 90 | for stmt_name in self._prepared_statements: 91 | try: 92 | with self.connection.cursor() as cursor: 93 | cursor.execute(f"DEALLOCATE {stmt_name}") 94 | except Exception: 95 | pass 96 | 97 | self._prepared_statements.clear() 98 | 99 | # Return connection to pool 100 | if self.pool: 101 | self.pool.putconn(self.connection) 102 | 103 | logger.info("Disconnected from PostgreSQL database") 104 | 105 | except Exception as e: 106 | logger.error(f"Error closing connection: {str(e)}") 107 | finally: 108 | self.connection = None 109 | self._connection_time = None 110 | 111 | # Close the pool 112 | if self.pool: 113 | self.pool.closeall() 114 | self.pool = None 115 | 116 | def _create_connection(self): 117 | """Get connection from pool.""" 118 | if not self.pool: 119 | raise ConnectionError("Connection pool not initialized") 120 | return self.pool.getconn() 121 | 122 | def _close_connection(self): 123 | """Return connection to pool.""" 124 | if self.connection and self.pool: 125 | self.pool.putconn(self.connection) 126 | 127 | def _execute_with_timeout(self, 128 | query: str, 129 | params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: 130 | """Execute query with timeout and proper parameterization.""" 131 | cursor = None 132 | 133 | try: 134 | # Ensure we have a valid connection 135 | if not self.connection or self.connection.closed: 136 | self.connection = self._create_connection() 137 | 138 | # Create cursor with RealDictCursor for dict results 139 | cursor = self.connection.cursor(cursor_factory=extras.RealDictCursor) 140 | 141 | # Execute query with parameters 142 | if params: 143 | # Use psycopg2's parameter substitution 144 | cursor.execute(query, params) 145 | else: 146 | cursor.execute(query) 147 | 148 | # Handle results 149 | if cursor.description: 150 | results = cursor.fetchall() 151 | 152 | # Convert RealDictRow to regular dict 153 | results = [dict(row) for row in results] 154 | 155 | # Apply result size limit 156 | max_results = self.config.get('max_result_size', 10000) 157 | if len(results) > max_results: 158 | logger.warning(f"Result truncated from {len(results)} to {max_results} rows") 159 | results = results[:max_results] 160 | 161 | return results 162 | else: 163 | # For non-SELECT queries 164 | self.connection.commit() 165 | return [{'affected_rows': cursor.rowcount}] 166 | 167 | except (OperationalError, DatabaseError) as e: 168 | if self.connection: 169 | self.connection.rollback() 170 | raise ExecutionError(f"Query execution failed: {str(e)}") 171 | finally: 172 | if cursor: 173 | cursor.close() 174 | 175 | def execute_prepared(self, 176 | name: str, 177 | query: str, 178 | params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: 179 | """Execute a prepared statement for better performance.""" 180 | cursor = None 181 | 182 | try: 183 | cursor = self.connection.cursor(cursor_factory=extras.RealDictCursor) 184 | 185 | # Prepare statement if not already prepared 186 | if name not in self._prepared_statements: 187 | cursor.execute(f"PREPARE {name} AS {query}") 188 | self._prepared_statements[name] = query 189 | 190 | # Execute prepared statement 191 | if params: 192 | execute_query = sql.SQL("EXECUTE {} ({})").format( 193 | sql.Identifier(name), 194 | sql.SQL(', ').join(sql.Placeholder() * len(params)) 195 | 
    def explain_query(self, query: str, analyze: bool = False) -> Dict[str, Any]:
        """Get query execution plan.

        Note: analyze=True actually runs the query, so use it only
        with read-only statements.
        """
        explain_query = f"EXPLAIN {'ANALYZE ' if analyze else ''}{query}"

        try:
            results = self._execute_with_timeout(explain_query)
            return {
                'plan': results,
                'query': query
            }
        except Exception as e:
            raise ExecutionError(f"Failed to explain query: {str(e)}")

    def _fetch_schema_impl(self) -> Dict[str, Dict[str, str]]:
        """Fetch PostgreSQL schema using information_schema."""
        query = """
            SELECT
                t.table_name,
                c.column_name,
                c.data_type,
                c.character_maximum_length,
                c.numeric_precision,
                c.numeric_scale,
                c.is_nullable,
                c.column_default,
                tc.constraint_type
            FROM information_schema.tables t
            JOIN information_schema.columns c
                ON t.table_schema = c.table_schema
                AND t.table_name = c.table_name
            LEFT JOIN information_schema.key_column_usage kcu
                ON c.table_schema = kcu.table_schema
                AND c.table_name = kcu.table_name
                AND c.column_name = kcu.column_name
            LEFT JOIN information_schema.table_constraints tc
                ON kcu.constraint_schema = tc.constraint_schema
                AND kcu.constraint_name = tc.constraint_name
            WHERE t.table_schema = 'public'
                AND t.table_type = 'BASE TABLE'
            ORDER BY t.table_name, c.ordinal_position
        """

        cursor = None
        try:
            cursor = self.connection.cursor(cursor_factory=extras.RealDictCursor)
            cursor.execute(query)

            schema = {}
            for row in cursor:
                table_name = row['table_name']
                if table_name not in schema:
                    schema[table_name] = {}

                # Build column type
                col_type = row['data_type']
                if row['character_maximum_length']:
                    col_type += f"({row['character_maximum_length']})"
                elif row['numeric_precision'] and row['data_type'] in ('numeric', 'decimal'):
                    # Only decorate arbitrary-precision types; integer types
                    # also report a numeric_precision but take no modifier.
                    col_type += f"({row['numeric_precision']}"
                    if row['numeric_scale']:
                        col_type += f",{row['numeric_scale']}"
                    col_type += ")"

                if row['is_nullable'] == 'NO':
                    col_type += ' NOT NULL'
                if row['constraint_type'] == 'PRIMARY KEY':
                    col_type += ' PRIMARY KEY'
                if row['column_default']:
                    col_type += ' DEFAULT'

                schema[table_name][row['column_name']] = col_type

            # Fetch indexes
            self._fetch_indexes(schema, cursor)

            return schema

        finally:
            if cursor:
                cursor.close()

    def _fetch_indexes(self, schema: Dict[str, Dict[str, str]], cursor):
        """Fetch index information."""
        query = """
            SELECT
                tablename,
                indexname,
                indexdef
            FROM pg_indexes
            WHERE schemaname = 'public'
                AND indexname NOT LIKE '%_pkey'
            ORDER BY tablename, indexname
        """

        cursor.execute(query)

        for row in cursor:
            table_name = row['tablename']
            if table_name in schema:
                # Index names are collected as a list under '<table>_indexes',
                # alongside the per-table column mappings.
                index_key = f"{table_name}_indexes"
                if index_key not in schema:
                    schema[index_key] = []
                schema[index_key].append(row['indexname'])
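Sketch of the mapping this produces for a hypothetical users table; note the index list sits beside the per-table column dicts under a '<table>_indexes' key:

# Illustrative _fetch_schema_impl output:
schema = {
    'users': {
        'id': 'integer NOT NULL PRIMARY KEY',
        'email': 'character varying(255) NOT NULL',
        'balance': 'numeric(10,2)',
    },
    'users_indexes': ['idx_users_email'],   # added by _fetch_indexes
}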
f"{table_name}_indexes" 315 | if index_key not in schema: 316 | schema[index_key] = [] 317 | schema[index_key].append(f"{row['indexname']}") 318 | 319 | def _begin_transaction(self): 320 | """Begin a transaction.""" 321 | # PostgreSQL starts transaction automatically 322 | pass 323 | 324 | def _commit_transaction(self): 325 | """Commit a transaction.""" 326 | self.connection.commit() 327 | 328 | def _rollback_transaction(self): 329 | """Rollback a transaction.""" 330 | self.connection.rollback() 331 | 332 | def _get_driver_info(self) -> Dict[str, Any]: 333 | """Get PostgreSQL-specific information.""" 334 | info = { 335 | 'server_version': None, 336 | 'connection_id': None, 337 | 'current_schema': None, 338 | 'encoding': None 339 | } 340 | 341 | if self.connection and not self.connection.closed: 342 | try: 343 | cursor = self.connection.cursor() 344 | cursor.execute(""" 345 | SELECT 346 | version(), 347 | pg_backend_pid(), 348 | current_schema(), 349 | pg_encoding_to_char(encoding) 350 | FROM pg_database 351 | WHERE datname = current_database() 352 | """) 353 | result = cursor.fetchone() 354 | info['server_version'] = result[0] 355 | info['connection_id'] = result[1] 356 | info['current_schema'] = result[2] 357 | info['encoding'] = result[3] 358 | cursor.close() 359 | except Exception: 360 | pass 361 | 362 | return info 363 | 364 | @property 365 | def supports_transactions(self) -> bool: 366 | """PostgreSQL supports transactions.""" 367 | return True 368 | 369 | @property 370 | def supports_schemas(self) -> bool: 371 | """PostgreSQL supports schemas.""" 372 | return True --------------------------------------------------------------------------------