.
├── cognidb
│   ├── security
│   │   ├── __init__.py
│   │   ├── query_parser.py
│   │   ├── access_control.py
│   │   ├── sanitizer.py
│   │   └── validator.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── settings.py
│   │   ├── secrets.py
│   │   └── loader.py
│   ├── drivers
│   │   ├── __init__.py
│   │   ├── base_driver.py
│   │   ├── mysql_driver.py
│   │   └── postgres_driver.py
│   ├── ai
│   │   ├── __init__.py
│   │   ├── llm_manager.py
│   │   ├── prompt_builder.py
│   │   ├── query_generator.py
│   │   └── cost_tracker.py
│   └── core
│       ├── __init__.py
│       ├── exceptions.py
│       ├── interfaces.py
│       └── query_intent.py
├── requirements.txt
├── .gitignore
├── setup.py
├── cognidb.example.yaml
├── git_commit_script.bat
├── git_commit_script.ps1
├── examples
│   └── basic_usage.py
├── Readme.md
└── __init__.py
/cognidb/security/__init__.py:
--------------------------------------------------------------------------------
"""Security module for CogniDB."""

from .validator import QuerySecurityValidator
from .sanitizer import InputSanitizer
from .query_parser import SQLQueryParser
from .access_control import AccessController

# Public surface of the security package: validation, sanitization,
# parsing and access control, re-exported for convenient imports.
__all__ = [
    'QuerySecurityValidator', 'InputSanitizer',
    'SQLQueryParser', 'AccessController',
]
--------------------------------------------------------------------------------
/cognidb/config/__init__.py:
--------------------------------------------------------------------------------
"""Configuration module for CogniDB."""

from .settings import (
    Settings,
    DatabaseConfig,
    LLMConfig,
    CacheConfig,
    SecurityConfig,
)
from .secrets import SecretsManager
from .loader import ConfigLoader

# Names re-exported as the configuration package's public API.
__all__ = [
    'Settings', 'DatabaseConfig', 'LLMConfig', 'CacheConfig',
    'SecurityConfig', 'SecretsManager', 'ConfigLoader',
]
--------------------------------------------------------------------------------
/cognidb/drivers/__init__.py:
--------------------------------------------------------------------------------
"""Database drivers module.

MySQL and PostgreSQL drivers are always available.  The MongoDB,
DynamoDB and SQLite drivers are optional: their modules are not part of
the core package (only ``base_driver``, ``mysql_driver`` and
``postgres_driver`` ship with it), so importing them unconditionally
would make ``import cognidb.drivers`` fail with ImportError.  Each
optional driver is therefore imported defensively and exported only
when present.
"""

from .mysql_driver import MySQLDriver
from .postgres_driver import PostgreSQLDriver

# Core drivers are always exported.
__all__ = ['MySQLDriver', 'PostgreSQLDriver']

# Optional drivers: export each one only if its module (and any
# third-party dependency it needs, e.g. pymongo/boto3) is available.
try:
    from .mongodb_driver import MongoDBDriver
    __all__.append('MongoDBDriver')
except ImportError:  # module or its dependencies are not installed
    pass

try:
    from .dynamodb_driver import DynamoDBDriver
    __all__.append('DynamoDBDriver')
except ImportError:
    pass

try:
    from .sqlite_driver import SQLiteDriver
    __all__.append('SQLiteDriver')
except ImportError:
    pass
--------------------------------------------------------------------------------
/cognidb/ai/__init__.py:
--------------------------------------------------------------------------------
"""AI and LLM integration module."""

from .llm_manager import LLMManager
from .prompt_builder import PromptBuilder
from .query_generator import QueryGenerator
from .cost_tracker import CostTracker
from .providers import (
    OpenAIProvider,
    AnthropicProvider,
    AzureOpenAIProvider,
)

# Public surface of the ai package: orchestration, prompting, query
# generation, cost tracking, and the concrete LLM providers.
__all__ = [
    'LLMManager', 'PromptBuilder', 'QueryGenerator', 'CostTracker',
    'OpenAIProvider', 'AnthropicProvider', 'AzureOpenAIProvider',
]
--------------------------------------------------------------------------------
/cognidb/core/__init__.py:
--------------------------------------------------------------------------------
"""Core abstractions for CogniDB."""

from .query_intent import QueryIntent, QueryType, JoinCondition, Aggregation
from .interfaces import (
    DatabaseDriver,
    QueryTranslator,
    SecurityValidator,
    ResultNormalizer,
    CacheProvider
)
# Re-export the FULL exception hierarchy defined in .exceptions, not just
# a subset, so callers can catch e.g. RateLimitError or SchemaError via
# ``from cognidb.core import ...`` like they can the other exceptions.
from .exceptions import (
    CogniDBError,
    SecurityError,
    TranslationError,
    ExecutionError,
    ValidationError,
    ConnectionError,  # NOTE: shadows the builtin when star-imported
    SchemaError,
    ConfigurationError,
    CacheError,
    RateLimitError
)

__all__ = [
    # Query model
    'QueryIntent',
    'QueryType',
    'JoinCondition',
    'Aggregation',
    # Component interfaces
    'DatabaseDriver',
    'QueryTranslator',
    'SecurityValidator',
    'ResultNormalizer',
    'CacheProvider',
    # Exceptions
    'CogniDBError',
    'SecurityError',
    'TranslationError',
    'ExecutionError',
    'ValidationError',
    'ConnectionError',
    'SchemaError',
    'ConfigurationError',
    'CacheError',
    'RateLimitError'
]
--------------------------------------------------------------------------------
/cognidb/core/exceptions.py:
--------------------------------------------------------------------------------
"""Custom exceptions for CogniDB."""

from typing import Any, Dict, Optional


class CogniDBError(Exception):
    """Base exception for all CogniDB errors.

    Args:
        message: Human-readable error description.
        details: Optional structured context about the failure; stored as
            an empty dict when not provided, so callers can always read
            ``err.details`` without a None check.
    """

    def __init__(self, message: str, details: Optional[Dict[str, Any]] = None):
        super().__init__(message)
        self.message = message
        # A fresh dict per instance when no details are given — never a
        # shared mutable default.
        self.details = details if details is not None else {}


class SecurityError(CogniDBError):
    """Raised when a security violation is detected."""


class ValidationError(CogniDBError):
    """Raised when validation fails."""


class TranslationError(CogniDBError):
    """Raised when query translation fails."""


class ExecutionError(CogniDBError):
    """Raised when query execution fails."""


class ConnectionError(CogniDBError):
    """Raised when database connection fails.

    NOTE: keeps the historical name even though it shadows the builtin
    ``ConnectionError`` when imported unqualified; renaming would break
    existing callers.
    """


class SchemaError(CogniDBError):
    """Raised when schema-related operations fail."""


class ConfigurationError(CogniDBError):
    """Raised when configuration is invalid."""


class CacheError(CogniDBError):
    """Raised when cache operations fail."""


class RateLimitError(CogniDBError):
    """Raised when rate limits are exceeded.

    Args:
        message: Human-readable error description.
        retry_after: Seconds the caller should wait before retrying, if
            the limiter knows; None otherwise.
        details: Optional structured context (see CogniDBError).
    """

    def __init__(self, message: str, retry_after: Optional[int] = None,
                 details: Optional[Dict[str, Any]] = None):
        super().__init__(message, details)
        self.retry_after = retry_after
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Core database drivers
2 | mysql-connector-python>=8.0.33
3 | psycopg2-binary>=2.9.9
4 | pymongo>=4.6.0
5 | boto3>=1.28.0 # For AWS DynamoDB
6 |
7 | # SQL parsing and validation
8 | sqlparse>=0.4.4
9 |
10 | # LLM providers
11 | openai>=1.0.0
12 | anthropic>=0.8.0
13 | transformers>=4.35.0 # For HuggingFace models
14 | torch>=2.0.0 # For local models
15 | llama-cpp-python>=0.2.0 # For llama.cpp models (optional)
16 |
17 | # Security and encryption
18 | cryptography>=41.0.0
19 |
20 | # Configuration
21 | pyyaml>=6.0.1
22 |
23 | # Caching
24 | redis>=5.0.0 # For Redis cache (optional)
25 |
26 | # Cloud secrets management (optional)
27 | hvac>=1.2.0 # HashiCorp Vault
28 | azure-identity>=1.14.0 # Azure Key Vault
29 | azure-keyvault-secrets>=4.7.0 # Azure Key Vault
30 |
31 | # Testing
32 | pytest>=7.4.0
33 | pytest-cov>=4.1.0
34 | pytest-asyncio>=0.21.0
35 | pytest-mock>=3.11.0
36 |
37 | # Development tools
38 | black>=23.0.0
39 | flake8>=6.0.0
40 | mypy>=1.5.0
41 | pre-commit>=3.3.0
42 |
43 | # Documentation
44 | sphinx>=7.0.0
45 | sphinx-rtd-theme>=1.3.0
46 |
47 | # Monitoring and metrics (optional)
48 | prometheus-client>=0.18.0
49 | opentelemetry-api>=1.20.0
50 | opentelemetry-sdk>=1.20.0
51 | opentelemetry-instrumentation>=0.41b0
52 |
53 | # Data processing
54 | pandas>=2.0.0 # For result formatting
55 | numpy>=1.24.0 # For numerical operations
56 |
57 | # Web framework (if adding API)
58 | fastapi>=0.104.0 # Optional, for REST API
59 | uvicorn>=0.24.0 # ASGI server
60 |
61 | # Utilities
62 | python-dotenv>=1.0.0 # For .env file support
63 | click>=8.1.0 # For CLI interface
64 | rich>=13.6.0 # For beautiful terminal output
65 | tabulate>=0.9.0 # For table formatting
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 | docs/_build/doctrees
74 | docs/_build/html
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # celery beat schedule file
90 | celerybeat-schedule
91 | celerybeat.pid
92 |
93 | # SageMath parsed files
94 | *.sage.py
95 |
96 | # Environments
97 | .env
98 | .venv
99 | env/
100 | venv/
101 | ENV/
102 | env.bak/
103 | venv.bak/
104 |
105 | # Spyder project settings
106 | .spyderproject
107 | .spyproject
108 |
109 | # Rope project settings
110 | .ropeproject
111 |
112 | # mkdocs documentation
113 | /site
114 |
115 | # mypy
116 | .mypy_cache/
117 | .dmypy.json
118 | dmypy.json
119 |
120 | # Pyre type checker
121 | .pyre/
122 |
123 | # pytype static type analyzer
124 | .pytype/
125 |
126 | # Cython debug symbols
127 | cython_debug/
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""Setup configuration for CogniDB."""

import re

from setuptools import setup, find_packages
from pathlib import Path

# Read the README.  An explicit encoding is required: read_text()
# otherwise uses the platform default, which breaks on Windows when the
# file contains non-ASCII characters.
this_directory = Path(__file__).parent
long_description = (this_directory / "Readme.md").read_text(encoding="utf-8")

# Read requirements
requirements = (this_directory / "requirements.txt").read_text(encoding="utf-8").splitlines()

# Optional feature groups exposed via ``pip install cognidb[<extra>]``.
optional_requirements = {
    'llama': ['llama-cpp-python>=0.2.0'],
    'azure': ['azure-identity>=1.14.0', 'azure-keyvault-secrets>=4.7.0'],
    'vault': ['hvac>=1.2.0'],
    'redis': ['redis>=5.0.0'],
    'api': ['fastapi>=0.104.0', 'uvicorn>=0.24.0'],
    'dev': [
        'pytest>=7.4.0',
        'pytest-cov>=4.1.0',
        'pytest-asyncio>=0.21.0',
        'pytest-mock>=3.11.0',
        'black>=23.0.0',
        'flake8>=6.0.0',
        'mypy>=1.5.0',
        'pre-commit>=3.3.0'
    ],
    'docs': [
        'sphinx>=7.0.0',
        'sphinx-rtd-theme>=1.3.0'
    ]
}


def _dist_name(requirement: str) -> str:
    """Return the lowercase distribution name of a requirement string.

    E.g. ``'boto3>=1.28.0'`` -> ``'boto3'``.
    """
    return re.split(r'[<>=~!;\s\[]', requirement.strip(), maxsplit=1)[0].lower()


# A package that belongs to an extra must not also be a core dependency.
# Derive the exclusion set from optional_requirements itself instead of a
# hand-written keyword list (the old list missed uvicorn, black, flake8,
# mypy and pre-commit, so dev tools leaked into install_requires).
_optional_names = {
    _dist_name(req)
    for extra_reqs in optional_requirements.values()
    for req in extra_reqs
}

core_requirements = []
for line in requirements:
    # Strip inline comments ('boto3>=1.28.0  # For AWS DynamoDB') and
    # whitespace: trailing comment text is not a valid PEP 508
    # requirement and setuptools would reject it in install_requires.
    spec = line.split('#', 1)[0].strip()
    if spec and _dist_name(spec) not in _optional_names:
        core_requirements.append(spec)

setup(
    name="cognidb",
    version="2.0.0",
    author="CogniDB Team",
    author_email="team@cognidb.io",
    description="Secure Natural Language Database Interface",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/adrienckr/cognidb",
    packages=find_packages(),
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Topic :: Database",
        "Topic :: Software Development :: Libraries :: Python Modules",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
    ],
    python_requires=">=3.8",
    install_requires=core_requirements,
    extras_require={
        **optional_requirements,
        # 'all' is the concatenation of every extra.
        'all': sum(optional_requirements.values(), [])
    },
    entry_points={
        'console_scripts': [
            'cognidb=cognidb.cli:main',
        ],
    },
    include_package_data=True,
    package_data={
        'cognidb': [
            'cognidb.example.yaml',
            'examples/*.py',
        ],
    },
    keywords='database sql natural-language nlp ai llm security',
    project_urls={
        'Bug Reports': 'https://github.com/adrienckr/cognidb/issues',
        'Source': 'https://github.com/adrienckr/cognidb',
        'Documentation': 'https://cognidb.readthedocs.io',
    },
)
--------------------------------------------------------------------------------
/cognidb.example.yaml:
--------------------------------------------------------------------------------
1 | # CogniDB Configuration Example
2 | # Copy this to cognidb.yaml and update with your settings
3 |
4 | # Application settings
5 | app_name: CogniDB
6 | environment: production
7 | debug: false
8 | log_level: INFO
9 |
10 | # Database configuration
11 | database:
12 | type: postgresql # Options: mysql, postgresql, mongodb, dynamodb, sqlite
13 | host: localhost
14 | port: 5432
15 | database: your_database
16 | username: your_username
17 | password: ${DB_PASSWORD} # Use environment variable
18 |
19 | # Connection pool settings
20 | pool_size: 5
21 | max_overflow: 10
22 | pool_timeout: 30
23 | pool_recycle: 3600
24 |
25 | # SSL/TLS settings
26 | ssl_enabled: true
27 | ssl_ca_cert: null
28 | ssl_client_cert: null
29 | ssl_client_key: null
30 |
31 | # Query settings
32 | query_timeout: 30 # seconds
33 | max_result_size: 10000 # rows
34 |
35 | # LLM configuration
36 | llm:
37 | provider: openai # Options: openai, anthropic, azure_openai, huggingface, local
38 | api_key: ${LLM_API_KEY} # Use environment variable
39 |
40 | # Model settings
41 | model_name: gpt-4
42 | temperature: 0.1
43 | max_tokens: 1000
44 | timeout: 30
45 |
46 | # Cost control
47 | max_tokens_per_query: 2000
48 | max_queries_per_minute: 60
49 | max_cost_per_day: 100.0
50 |
51 | # Few-shot examples for better SQL generation
52 | few_shot_examples:
53 | - query: "Show me all users who registered last month"
54 | sql: "SELECT * FROM users WHERE created_at >= DATE_TRUNC('month', CURRENT_DATE - INTERVAL '1 month') AND created_at < DATE_TRUNC('month', CURRENT_DATE)"
55 |
56 | - query: "What's the total revenue by product category?"
57 | sql: "SELECT p.category, SUM(o.amount) as total_revenue FROM orders o JOIN products p ON o.product_id = p.id GROUP BY p.category ORDER BY total_revenue DESC"
58 |
59 | - query: "Find customers who haven't made a purchase in the last 90 days"
60 | sql: "SELECT c.* FROM customers c WHERE c.id NOT IN (SELECT DISTINCT customer_id FROM orders WHERE order_date > CURRENT_DATE - INTERVAL '90 days')"
61 |
62 | # Cache configuration
63 | cache:
64 | provider: in_memory # Options: in_memory, redis, memcached, disk
65 |
66 | # TTL settings (in seconds)
67 | query_result_ttl: 3600 # 1 hour
68 | schema_ttl: 86400 # 24 hours
69 | llm_response_ttl: 7200 # 2 hours
70 |
71 | # Redis settings (if using Redis)
72 | redis_host: localhost
73 | redis_port: 6379
74 | redis_password: ${REDIS_PASSWORD}
75 | redis_db: 0
76 | redis_ssl: false
77 |
78 | # Security configuration
79 | security:
80 | # Query validation
81 | allow_only_select: true
82 | max_query_complexity: 10
83 | allow_subqueries: false
84 | allow_unions: false
85 |
86 | # Rate limiting
87 | enable_rate_limiting: true
88 | rate_limit_per_minute: 100
89 | rate_limit_per_hour: 1000
90 |
91 | # Access control
92 | enable_access_control: true
93 | default_user_permissions: ["SELECT"]
94 | require_authentication: false
95 |
96 | # Audit logging
97 | enable_audit_logging: true
98 | audit_log_path: ~/.cognidb/audit.log
99 | log_query_results: false
100 |
101 | # Encryption
102 | encrypt_cache: true
103 | encrypt_logs: true
104 | encryption_key: ${ENCRYPTION_KEY}
105 |
106 | # Network security
107 | allowed_ip_ranges: []
108 | require_ssl: true
109 |
110 | # Feature flags
111 | enable_natural_language: true
112 | enable_query_explanation: true
113 | enable_query_optimization: true
114 | enable_auto_indexing: false
115 |
116 | # Monitoring
117 | enable_metrics: true
118 | metrics_port: 9090
119 | enable_tracing: true
120 | tracing_endpoint: null
--------------------------------------------------------------------------------
/git_commit_script.bat:
--------------------------------------------------------------------------------
@echo off
REM Script to create multiple logical commits for CogniDB v2.0.0
REM Each commit adds related files with appropriate messages
REM NOTE(review): run this from the repository root — every "git add" path
REM below is relative to it, and a missing file makes that commit empty.
REM Each extra "-m" flag appends another paragraph to the commit message;
REM the empty -m "" separates the subject line from the body.

echo Starting CogniDB v2.0.0 commit process...
echo.

REM Commit 1: Core abstractions and interfaces
echo Creating commit 1: Core abstractions...
git add cognidb/core/__init__.py
git add cognidb/core/exceptions.py
git add cognidb/core/interfaces.py
git add cognidb/core/query_intent.py
git commit -m "feat(core): Add core abstractions and interfaces" -m "" -m "- Implement QueryIntent for database-agnostic query representation" -m "- Define abstract interfaces for drivers, translators, and validators" -m "- Add comprehensive exception hierarchy" -m "- Create foundation for modular architecture"
echo.

REM Commit 2: Security layer
echo Creating commit 2: Security layer...
git add cognidb/security/__init__.py
git add cognidb/security/validator.py
git add cognidb/security/sanitizer.py
git add cognidb/security/query_parser.py
git add cognidb/security/access_control.py
git commit -m "feat(security): Implement comprehensive security layer" -m "" -m "- Add multi-layer query validation to prevent SQL injection" -m "- Implement input sanitization for all user inputs" -m "- Create SQL query parser for security analysis" -m "- Add access control with table/column/row-level permissions" -m "- Enforce parameterized queries throughout"
echo.

REM Commit 3: Configuration management
echo Creating commit 3: Configuration system...
git add cognidb/config/__init__.py
git add cognidb/config/settings.py
git add cognidb/config/secrets.py
git add cognidb/config/loader.py
git add cognidb.example.yaml
git commit -m "feat(config): Add flexible configuration management" -m "" -m "- Implement settings with dataclasses for type safety" -m "- Add secrets manager supporting multiple providers" -m "- Create config loader with YAML/JSON/env support" -m "- Add example configuration file" -m "- Support for environment variable interpolation"
echo.

REM Commit 4: AI/LLM integration
echo Creating commit 4: AI/LLM integration...
git add cognidb/ai/__init__.py
git add cognidb/ai/llm_manager.py
git add cognidb/ai/providers.py
git add cognidb/ai/prompt_builder.py
git add cognidb/ai/query_generator.py
git add cognidb/ai/cost_tracker.py
git commit -m "feat(ai): Implement modern LLM integration" -m "" -m "- Add multi-provider support (OpenAI, Anthropic, Azure, etc)" -m "- Implement cost tracking with daily limits" -m "- Create advanced prompt builder with few-shot learning" -m "- Add query generation with optimization suggestions" -m "- Include response caching to reduce costs" -m "- Support streaming and fallback providers"
echo.

REM Commit 5: Database drivers
echo Creating commit 5: Secure database drivers...
git add cognidb/drivers/__init__.py
git add cognidb/drivers/base_driver.py
git add cognidb/drivers/mysql_driver.py
git add cognidb/drivers/postgres_driver.py
git commit -m "feat(drivers): Add secure database drivers" -m "" -m "- Implement base driver with common functionality" -m "- Add MySQL driver with connection pooling and SSL support" -m "- Add PostgreSQL driver with prepared statements" -m "- Use parameterized queries exclusively" -m "- Include timeout management and result limiting" -m "- Add schema caching for performance"
echo.

REM Commit 6: Main CogniDB class
echo Creating commit 6: Main CogniDB interface...
REM The top-level __init__.py lives in the repository root, not in cognidb/.
git add __init__.py
git commit -m "feat: Implement main CogniDB class with new architecture" -m "" -m "- Create unified interface for natural language queries" -m "- Integrate all components (security, AI, drivers, config)" -m "- Add context manager support" -m "- Include query optimization and suggestions" -m "- Implement audit logging" -m "- Add comprehensive error handling"
echo.

REM Commit 7: Documentation and examples
echo Creating commit 7: Documentation and examples...
git add Readme.md
git add examples/basic_usage.py
git commit -m "docs: Update documentation for v2.0.0" -m "" -m "- Rewrite README with security-first approach" -m "- Add comprehensive usage examples" -m "- Document all features and configuration options" -m "- Include security best practices" -m "- Add performance tips" -m "- Update badges and project description"
echo.

REM Commit 8: Package setup and requirements
echo Creating commit 8: Package configuration...
git add setup.py
git add requirements.txt
git commit -m "build: Update package configuration for v2.0.0" -m "" -m "- Update requirements with all dependencies" -m "- Configure setup.py with optional extras" -m "- Add proper classifiers and metadata" -m "- Include console script entry point" -m "- Separate core and optional dependencies"
echo.

REM Commit 9: Remove old insecure files
echo Creating commit 9: Clean up old implementation...
REM "git add -A" stages everything left in the working tree, including
REM deletions of removed files — this is what captures the cleanup.
git add -A
git commit -m "refactor: Remove old insecure implementation" -m "" -m "- Remove vulnerable SQL string interpolation code" -m "- Delete unused modules" -m "- Remove redundant wrappers" -m "- Clean up old database implementations" -m "- Remove insecure query validator" -m "" -m "BREAKING CHANGE: Complete API redesign for v2.0.0"
echo.

echo All commits created successfully!
echo.
echo Repository status:
git status
echo.
echo Commit history:
git log --oneline -n 9
echo.
pause
/git_commit_script.ps1:
--------------------------------------------------------------------------------
# PowerShell script to create multiple logical commits for CogniDB v2.0.0
# Each commit adds related files with appropriate messages
# NOTE(review): run from the repository root — every "git add" path below is
# relative to it.  Commit messages are built with here-strings (@" ... "@);
# the closing "@ must stay at the start of its line or PowerShell will not
# terminate the literal.

Write-Host "Starting CogniDB v2.0.0 commit process..." -ForegroundColor Green
Write-Host ""

# Commit 1: Core abstractions and interfaces
Write-Host "Creating commit 1: Core abstractions..." -ForegroundColor Yellow
git add cognidb/core/__init__.py
git add cognidb/core/exceptions.py
git add cognidb/core/interfaces.py
git add cognidb/core/query_intent.py

$commit1Message = @"
feat(core): Add core abstractions and interfaces

- Implement QueryIntent for database-agnostic query representation
- Define abstract interfaces for drivers, translators, and validators
- Add comprehensive exception hierarchy
- Create foundation for modular architecture
"@
git commit -m $commit1Message
Write-Host ""

# Commit 2: Security layer
Write-Host "Creating commit 2: Security layer..." -ForegroundColor Yellow
git add cognidb/security/__init__.py
git add cognidb/security/validator.py
git add cognidb/security/sanitizer.py
git add cognidb/security/query_parser.py
git add cognidb/security/access_control.py

$commit2Message = @"
feat(security): Implement comprehensive security layer

- Add multi-layer query validation to prevent SQL injection
- Implement input sanitization for all user inputs
- Create SQL query parser for security analysis
- Add access control with table/column/row-level permissions
- Enforce parameterized queries throughout
"@
git commit -m $commit2Message
Write-Host ""

# Commit 3: Configuration management
Write-Host "Creating commit 3: Configuration system..." -ForegroundColor Yellow
git add cognidb/config/__init__.py
git add cognidb/config/settings.py
git add cognidb/config/secrets.py
git add cognidb/config/loader.py
git add cognidb.example.yaml

$commit3Message = @"
feat(config): Add flexible configuration management

- Implement settings with dataclasses for type safety
- Add secrets manager supporting multiple providers (env, file, AWS, Vault)
- Create config loader with YAML/JSON/env support
- Add example configuration file
- Support for environment variable interpolation
"@
git commit -m $commit3Message
Write-Host ""

# Commit 4: AI/LLM integration
Write-Host "Creating commit 4: AI/LLM integration..." -ForegroundColor Yellow
git add cognidb/ai/__init__.py
git add cognidb/ai/llm_manager.py
git add cognidb/ai/providers.py
git add cognidb/ai/prompt_builder.py
git add cognidb/ai/query_generator.py
git add cognidb/ai/cost_tracker.py

$commit4Message = @"
feat(ai): Implement modern LLM integration

- Add multi-provider support (OpenAI, Anthropic, Azure, HuggingFace, Local)
- Implement cost tracking with daily limits
- Create advanced prompt builder with few-shot learning
- Add query generation with optimization suggestions
- Include response caching to reduce costs
- Support streaming and fallback providers
"@
git commit -m $commit4Message
Write-Host ""

# Commit 5: Database drivers
Write-Host "Creating commit 5: Secure database drivers..." -ForegroundColor Yellow
git add cognidb/drivers/__init__.py
git add cognidb/drivers/base_driver.py
git add cognidb/drivers/mysql_driver.py
git add cognidb/drivers/postgres_driver.py

$commit5Message = @"
feat(drivers): Add secure database drivers

- Implement base driver with common functionality
- Add MySQL driver with connection pooling and SSL support
- Add PostgreSQL driver with prepared statements
- Use parameterized queries exclusively
- Include timeout management and result limiting
- Add schema caching for performance
"@
git commit -m $commit5Message
Write-Host ""

# Commit 6: Main CogniDB class
Write-Host "Creating commit 6: Main CogniDB interface..." -ForegroundColor Yellow
# The top-level __init__.py lives in the repository root, not in cognidb/.
git add __init__.py

$commit6Message = @"
feat: Implement main CogniDB class with new architecture

- Create unified interface for natural language queries
- Integrate all components (security, AI, drivers, config)
- Add context manager support
- Include query optimization and suggestions
- Implement audit logging
- Add comprehensive error handling
"@
git commit -m $commit6Message
Write-Host ""

# Commit 7: Documentation and examples
Write-Host "Creating commit 7: Documentation and examples..." -ForegroundColor Yellow
git add Readme.md
git add examples/basic_usage.py

$commit7Message = @"
docs: Update documentation for v2.0.0

- Rewrite README with security-first approach
- Add comprehensive usage examples
- Document all features and configuration options
- Include security best practices
- Add performance tips
- Update badges and project description
"@
git commit -m $commit7Message
Write-Host ""

# Commit 8: Package setup and requirements
Write-Host "Creating commit 8: Package configuration..." -ForegroundColor Yellow
git add setup.py
git add requirements.txt

$commit8Message = @"
build: Update package configuration for v2.0.0

- Update requirements with all dependencies
- Configure setup.py with optional extras
- Add proper classifiers and metadata
- Include console script entry point
- Separate core and optional dependencies
"@
git commit -m $commit8Message
Write-Host ""

# Commit 9: Remove old insecure files
Write-Host "Creating commit 9: Clean up old implementation..." -ForegroundColor Yellow
# "git add -A" stages everything left in the working tree, including
# deletions of removed files — this is what captures the cleanup.
git add -A

$commit9Message = @"
refactor: Remove old insecure implementation

- Remove vulnerable SQL string interpolation code
- Delete unused modules (clarification_handler, user_input_processor)
- Remove redundant wrappers (db_connection, schema_fetcher)
- Clean up old database implementations
- Remove insecure query validator
- Delete analysis report

BREAKING CHANGE: Complete API redesign for v2.0.0
"@
git commit -m $commit9Message
Write-Host ""

Write-Host "All commits created successfully!" -ForegroundColor Green
Write-Host ""
Write-Host "Repository status:" -ForegroundColor Cyan
git status
Write-Host ""
Write-Host "Commit history:" -ForegroundColor Cyan
git log --oneline -n 9
Write-Host ""
Write-Host "Press any key to continue..."
$null = $Host.UI.RawUI.ReadKey("NoEcho,IncludeKeyDown")
/cognidb/core/interfaces.py:
--------------------------------------------------------------------------------
1 | """Core interfaces for CogniDB components."""
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import Dict, List, Any, Optional, Tuple
5 | from .query_intent import QueryIntent
6 |
7 |
class DatabaseDriver(ABC):
    """Contract every concrete database backend must satisfy.

    A driver owns the connection lifecycle, runs parameterized native
    queries, exposes schema introspection, and validates identifiers
    against the live schema.
    """

    @abstractmethod
    def connect(self) -> None:
        """Open the connection to the database."""
        ...

    @abstractmethod
    def disconnect(self) -> None:
        """Release the database connection."""
        ...

    @abstractmethod
    def execute_native_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Run a native query with bound parameters.

        Args:
            query: Native query text containing parameter placeholders.
            params: Values to bind to those placeholders, if any.

        Returns:
            The result rows, each as a column-name -> value dictionary.
        """
        ...

    @abstractmethod
    def fetch_schema(self) -> Dict[str, Dict[str, str]]:
        """Introspect the database schema.

        Returns:
            A mapping of table name to a ``{column_name: data_type}``
            mapping, e.g. ``{'users': {'id': 'integer', ...}, ...}``.
        """
        ...

    @abstractmethod
    def validate_table_name(self, table_name: str) -> bool:
        """Return True when *table_name* exists and is safe to use."""
        ...

    @abstractmethod
    def validate_column_name(self, table_name: str, column_name: str) -> bool:
        """Return True when *column_name* exists on *table_name*."""
        ...

    @abstractmethod
    def get_connection_info(self) -> Dict[str, Any]:
        """Return connection details for debugging, with secrets omitted."""
        ...

    @property
    @abstractmethod
    def supports_transactions(self) -> bool:
        """True when the backend supports transactions."""
        ...

    @property
    @abstractmethod
    def supports_schemas(self) -> bool:
        """True when the backend supports schemas/namespaces."""
        ...
78 |
79 |
class QueryTranslator(ABC):
    """Contract for turning a QueryIntent into a backend-native query."""

    @abstractmethod
    def translate(self, query_intent: QueryIntent) -> Tuple[str, Dict[str, Any]]:
        """Convert *query_intent* into an executable native query.

        Args:
            query_intent: Database-agnostic description of the query.

        Returns:
            A ``(query_string, parameters)`` pair ready for a driver.
        """
        ...

    @abstractmethod
    def validate_intent(self, query_intent: QueryIntent) -> List[str]:
        """Check whether *query_intent* can be translated by this backend.

        Returns:
            Every validation error found; an empty list means valid.
        """
        ...

    @property
    @abstractmethod
    def supported_features(self) -> Dict[str, bool]:
        """Feature flags describing what this translator can emit, e.g.::

            {
                'joins': True,
                'subqueries': False,
                'window_functions': True,
                'cte': False
            }
        """
        ...
121 |
122 |
class SecurityValidator(ABC):
    """Contract for validating and sanitizing queries before execution."""

    @abstractmethod
    def validate_query_intent(self, query_intent: QueryIntent) -> Tuple[bool, Optional[str]]:
        """Inspect a query intent for security problems.

        Returns:
            ``(is_valid, error_message)`` — the message is None when valid.
        """
        ...

    @abstractmethod
    def validate_native_query(self, query: str) -> Tuple[bool, Optional[str]]:
        """Inspect a native query string for security problems.

        Returns:
            ``(is_valid, error_message)`` — the message is None when valid.
        """
        ...

    @abstractmethod
    def sanitize_identifier(self, identifier: str) -> str:
        """Return a safe version of a table/column identifier."""
        ...

    @abstractmethod
    def sanitize_value(self, value: Any) -> Any:
        """Return a safe version of a parameter value."""
        ...

    @property
    @abstractmethod
    def allowed_operations(self) -> List[str]:
        """Query operations this validator permits."""
        ...
161 |
162 |
class ResultNormalizer(ABC):
    """Contract for converting driver-specific results into a common shape."""

    @abstractmethod
    def normalize(self, raw_results: Any) -> List[Dict[str, Any]]:
        """Convert raw driver output into the standard row format.

        Args:
            raw_results: Whatever the database driver returned.

        Returns:
            Rows as a list of dictionaries with a consistent structure.
        """
        ...

    @abstractmethod
    def format_for_output(self,
                         normalized_results: List[Dict[str, Any]],
                         output_format: str = 'json') -> Any:
        """Render normalized rows in the requested output format.

        Args:
            normalized_results: Rows produced by :meth:`normalize`.
            output_format: One of 'json', 'csv', 'table', 'dataframe'.

        Returns:
            The formatted result set.
        """
        ...
194 |
195 |
class CacheProvider(ABC):
    """Contract for key/value cache backends."""

    @abstractmethod
    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for ``key``, or None on a miss."""
        ...

    @abstractmethod
    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Put a value into the cache.

        Args:
            key: Cache key.
            value: Value to store.
            ttl: Expiry in seconds; None means the entry never expires.

        Returns:
            True on success, False otherwise.
        """
        ...

    @abstractmethod
    def delete(self, key: str) -> bool:
        """Remove ``key`` from the cache; return success status."""
        ...

    @abstractmethod
    def clear(self) -> bool:
        """Drop every cached entry; return success status."""
        ...

    @abstractmethod
    def get_stats(self) -> Dict[str, Any]:
        """Return cache statistics (hits, misses, size, ...)."""
        ...
--------------------------------------------------------------------------------
/examples/basic_usage.py:
--------------------------------------------------------------------------------
1 | """
2 | Basic usage examples for CogniDB.
3 |
4 | This demonstrates how to use CogniDB for natural language database queries
5 | with various configurations and features.
6 | """
7 |
8 | import os
9 | from cognidb import CogniDB, create_cognidb
10 |
11 |
def basic_example():
    """Basic usage with environment variables."""
    # Set environment variables (in production, use .env file or secrets manager)
    os.environ.update({
        'DB_TYPE': 'postgresql',
        'DB_HOST': 'localhost',
        'DB_PORT': '5432',
        'DB_NAME': 'mydb',
        'DB_USER': 'myuser',
        'DB_PASSWORD': 'mypassword',
        'LLM_API_KEY': 'your-openai-api-key',
    })

    # Create CogniDB instance
    with CogniDB() as db:
        # Simple query
        result = db.query("Show me all customers who made a purchase last month")

        if not result['success']:
            print(f"Error: {result['error']}")
            return

        print(f"Generated SQL: {result['sql']}")
        print(f"Found {result['row_count']} results")
        for row in result['results'][:5]:  # Show first 5
            print(row)
35 |
36 |
def config_file_example():
    """Usage with configuration file."""
    # Drive CogniDB from a YAML configuration file instead of env vars.
    with CogniDB(config_file='cognidb.yaml') as db:
        # Ask for the generated SQL to be explained as well.
        result = db.query(
            "What are the top 5 products by revenue?",
            explain=True
        )

        if not result['success']:
            return

        print(f"SQL: {result['sql']}")
        print(f"\nExplanation: {result['explanation']}")
        print(f"\nResults:")
        for row in result['results']:
            print(f"  {row}")
53 |
54 |
def advanced_features_example():
    """Demonstrate advanced features."""
    db = create_cognidb(
        database={
            'type': 'postgresql',
            'host': 'localhost',
            'database': 'analytics_db'
        },
        llm={
            'provider': 'openai',
            'model_name': 'gpt-4',
            'temperature': 0.1
        }
    )

    try:
        # 1. Query suggestions
        suggestions = db.suggest_queries("customers who")
        print("Query suggestions:")
        for s in suggestions:
            print(f"  - {s}")

        # 2. Query optimization
        sql = "SELECT * FROM orders WHERE customer_id IN (SELECT id FROM customers WHERE country = 'USA')"
        optimization = db.optimize_query(sql)

        if optimization['success']:
            print(f"\nOriginal: {optimization['original_query']}")
            print(f"Optimized: {optimization['optimized_query']}")
            print(f"Explanation: {optimization['explanation']}")

        # 3. Schema inspection
        schema = db.get_schema('customers')
        print(f"\nCustomers table schema:")
        for column, dtype in schema['customers'].items():
            print(f"  {column}: {dtype}")

        # 4. Usage statistics
        stats = db.get_usage_stats()
        print(f"\nUsage stats:")
        print(f"  Total cost: ${stats['total_cost']:.2f}")
        print(f"  Requests today: {stats['request_count']}")

    finally:
        # Always release the connection, even if an example step fails.
        db.close()
100 |
101 |
def multi_database_example():
    """Example with different database types.

    Iterates over several backend configurations and runs the same
    natural-language query against each one.
    """
    databases = [
        {
            'type': 'mysql',
            'config': {
                'host': 'mysql.example.com',
                'database': 'app_db'
            }
        },
        {
            'type': 'postgresql',
            'config': {
                'host': 'postgres.example.com',
                'database': 'analytics_db'
            }
        },
        {
            'type': 'sqlite',
            'config': {
                'database': '/path/to/local.db'
            }
        }
    ]

    for db_info in databases:
        print(f"\nQuerying {db_info['type']} database:")

        # Bug fix: the database type must be part of the database config
        # (as in the other examples); previously only db_info['config'] was
        # passed and the type was silently dropped.
        db_config = {'type': db_info['type'], **db_info['config']}

        with create_cognidb(database=db_config) as db:
            result = db.query("Count the total number of records in the main table")

            if result['success']:
                print(f"  SQL: {result['sql']}")
                print(f"  Result: {result['results']}")
            else:
                # Surface failures instead of silently skipping a backend.
                print(f"  Error: {result.get('error', 'unknown error')}")
136 |
137 |
def security_example():
    """Demonstrate security features.

    Runs a set of benign queries that should pass validation and a set of
    destructive/injection queries that should be rejected by the strict
    security configuration.
    """
    # Configure with strict security
    db = create_cognidb(
        security={
            'allow_only_select': True,
            'max_query_complexity': 5,
            'allow_subqueries': False,
            'enable_rate_limiting': True,
            'rate_limit_per_minute': 10
        }
    )

    # These queries will be validated
    safe_queries = [
        "Show me all active users",
        "What's the average order value?",
        "List products with low stock"
    ]

    # These will be rejected
    unsafe_queries = [
        "Delete all users where active = false",  # Not SELECT
        "Update products set price = price * 1.1",  # Not SELECT
        "'; DROP TABLE users; --",  # SQL injection attempt
    ]

    # Bug fix: close the connection even if a query raises — previously
    # an exception would skip db.close() and leak the connection.
    try:
        print("Testing safe queries:")
        for query in safe_queries:
            result = db.query(query)
            print(f"  '{query}' - Success: {result['success']}")

        print("\nTesting unsafe queries (should be rejected):")
        for query in unsafe_queries:
            result = db.query(query)
            print(f"  '{query}' - Success: {result['success']}, Error: {result.get('error', '')[:50]}...")
    finally:
        db.close()
176 |
177 |
def custom_llm_example():
    """Example with custom LLM configuration.

    Shows how to configure Anthropic, a local model, and Azure OpenAI as
    alternative LLM providers.
    """
    # Use Anthropic Claude
    db_claude = create_cognidb(
        llm={
            'provider': 'anthropic',
            'api_key': os.environ.get('ANTHROPIC_API_KEY'),
            'model_name': 'claude-3-sonnet',
            'temperature': 0.0
        }
    )

    # Use local model
    db_local = create_cognidb(
        llm={
            'provider': 'local',
            'local_model_path': '/path/to/model.gguf',
            'max_tokens': 500
        }
    )

    # Use Azure OpenAI
    db_azure = create_cognidb(
        llm={
            'provider': 'azure_openai',
            'api_key': os.environ.get('AZURE_OPENAI_KEY'),
            'azure_endpoint': 'https://myorg.openai.azure.com/',
            'azure_deployment': 'gpt-4-deployment'
        }
    )

    # Bug fix: the instances were previously created and never closed,
    # leaking connections/resources (every other example closes its db).
    for db in (db_claude, db_local, db_azure):
        db.close()
208 |
209 |
210 | if __name__ == "__main__":
211 | # Run examples
212 | print("=== Basic Example ===")
213 | basic_example()
214 |
215 | print("\n=== Config File Example ===")
216 | config_file_example()
217 |
218 | print("\n=== Advanced Features ===")
219 | advanced_features_example()
220 |
221 | print("\n=== Security Example ===")
222 | security_example()
--------------------------------------------------------------------------------
/cognidb/drivers/base_driver.py:
--------------------------------------------------------------------------------
1 | """Base implementation for database drivers with common functionality."""
2 |
3 | import time
4 | import logging
5 | from typing import Dict, List, Any, Optional, Tuple
6 | from contextlib import contextmanager
7 | from abc import abstractmethod
8 | from ..core.interfaces import DatabaseDriver
9 | from ..core.exceptions import ConnectionError, ExecutionError, SchemaError
10 | from ..security.sanitizer import InputSanitizer
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
class BaseDriver(DatabaseDriver):
    """
    Base database driver with common functionality.

    Provides:
    - Connection management
    - Query execution with timeouts
    - Schema caching
    - Security validation
    - Error handling

    Subclasses implement the ``_``-prefixed hooks at the bottom of the
    class to supply database-specific behavior.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize base driver.

        Args:
            config: Database configuration
        """
        self.config = config
        self.connection = None  # set by the concrete connect() implementation
        self.schema_cache: Optional[Dict[str, Dict[str, str]]] = None
        self.schema_cache_time = 0  # epoch seconds of the last schema fetch
        self.schema_cache_ttl = 3600  # 1 hour
        self.sanitizer = InputSanitizer()
        self._connection_time = None

    @contextmanager
    def transaction(self):
        """Context manager for database transactions.

        A no-op for drivers that report no transaction support. On error
        the transaction is rolled back and ExecutionError is raised with
        the original exception chained.
        """
        if not self.supports_transactions:
            yield
            return

        try:
            self._begin_transaction()
            yield
            self._commit_transaction()
        except Exception as e:
            self._rollback_transaction()
            # Chain the cause so the root failure is not lost in the traceback.
            raise ExecutionError(f"Transaction failed: {str(e)}") from e

    def execute_native_query(self,
                           query: str,
                           params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """
        Execute a native query with parameters.

        Args:
            query: Native query string with parameter placeholders
            params: Parameter values for the query

        Returns:
            List of result rows as dictionaries

        Raises:
            ConnectionError: If no connection is established.
            ExecutionError: If the query fails (original error chained).
        """
        if not self.connection:
            raise ConnectionError("Not connected to database")

        # Log query for debugging (without params for security)
        logger.debug(f"Executing query: {query[:100]}...")

        start_time = time.time()

        try:
            # Execute with timeout
            results = self._execute_with_timeout(query, params)

            # Log execution time
            execution_time = time.time() - start_time
            logger.info(f"Query executed in {execution_time:.2f}s")

            return results

        except Exception as e:
            logger.error(f"Query execution failed: {str(e)}")
            raise ExecutionError(f"Query execution failed: {str(e)}") from e

    def fetch_schema(self) -> Dict[str, Dict[str, str]]:
        """
        Fetch database schema with caching.

        Returns:
            Dictionary mapping table names to column info

        Raises:
            SchemaError: If the schema cannot be fetched (original error chained).
        """
        # Serve from cache while it is fresh (TTL-based).
        if self.schema_cache and (time.time() - self.schema_cache_time) < self.schema_cache_ttl:
            logger.debug("Using cached schema")
            return self.schema_cache

        logger.info("Fetching database schema")

        try:
            schema = self._fetch_schema_impl()

            # Cache the schema
            self.schema_cache = schema
            self.schema_cache_time = time.time()

            logger.info(f"Schema fetched: {len(schema)} tables")
            return schema

        except Exception as e:
            logger.error(f"Failed to fetch schema: {str(e)}")
            raise SchemaError(f"Failed to fetch schema: {str(e)}") from e

    def validate_table_name(self, table_name: str) -> bool:
        """Validate that a table name exists and is safe."""
        # Reject names the sanitizer refuses outright.
        try:
            sanitized = self.sanitizer.sanitize_identifier(table_name)
        except ValueError:
            return False

        # Check if table exists in schema
        schema = self.fetch_schema()
        return sanitized in schema

    def validate_column_name(self, table_name: str, column_name: str) -> bool:
        """Validate that a column exists in the table."""
        # Validate table first
        if not self.validate_table_name(table_name):
            return False

        # Sanitize both identifiers before any schema lookup.
        try:
            sanitized_table = self.sanitizer.sanitize_identifier(table_name)
            sanitized_column = self.sanitizer.sanitize_identifier(column_name)
        except ValueError:
            return False

        # Bug fix: look the table up by its *sanitized* name, matching
        # validate_table_name — previously the raw table_name was used,
        # which fails whenever sanitization normalizes the identifier.
        schema = self.fetch_schema()
        table_columns = schema.get(sanitized_table, {})
        return sanitized_column in table_columns

    def get_connection_info(self) -> Dict[str, Any]:
        """Get connection information (for debugging, minus secrets)."""
        info = {
            'driver': self.__class__.__name__,
            'host': self.config.get('host', 'N/A'),
            'port': self.config.get('port', 'N/A'),
            'database': self.config.get('database', 'N/A'),
            'connected': self.connection is not None,
            'connection_time': self._connection_time,
            'schema_tables': len(self.schema_cache) if self.schema_cache else 0
        }

        # Add driver-specific info
        info.update(self._get_driver_info())

        return info

    def invalidate_schema_cache(self):
        """Invalidate the schema cache so the next fetch hits the database."""
        self.schema_cache = None
        self.schema_cache_time = 0
        logger.info("Schema cache invalidated")

    def ping(self) -> bool:
        """Check if connection is alive."""
        try:
            # Simple query to test connection
            self.execute_native_query("SELECT 1")
            return True
        except Exception:
            return False

    def reconnect(self):
        """Reconnect to the database."""
        logger.info("Attempting to reconnect")
        self.disconnect()
        self.connect()

    # Abstract methods to be implemented by subclasses

    @abstractmethod
    def _create_connection(self):
        """Create the actual database connection."""
        pass

    @abstractmethod
    def _close_connection(self):
        """Close the database connection."""
        pass

    @abstractmethod
    def _execute_with_timeout(self, query: str, params: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Execute query with timeout (implementation specific)."""
        pass

    @abstractmethod
    def _fetch_schema_impl(self) -> Dict[str, Dict[str, str]]:
        """Fetch schema implementation."""
        pass

    @abstractmethod
    def _begin_transaction(self):
        """Begin a transaction."""
        pass

    @abstractmethod
    def _commit_transaction(self):
        """Commit a transaction."""
        pass

    @abstractmethod
    def _rollback_transaction(self):
        """Rollback a transaction."""
        pass

    @abstractmethod
    def _get_driver_info(self) -> Dict[str, Any]:
        """Get driver-specific information."""
        pass
--------------------------------------------------------------------------------
/cognidb/security/query_parser.py:
--------------------------------------------------------------------------------
1 | """SQL query parser for security validation."""
2 |
3 | import re
4 | from typing import Dict, Any, List, Optional
5 | import sqlparse
6 | from sqlparse.sql import IdentifierList, Identifier, Token
7 | from sqlparse.tokens import Keyword, DML
8 |
9 |
class SQLQueryParser:
    """
    SQL query parser for security analysis.

    Parses SQL queries to extract structure and identify
    potential security issues.
    """

    # Cap on cached parse results to keep memory bounded.
    _MAX_CACHE_ENTRIES = 1024

    def __init__(self):
        """Initialize the parser with an empty parse cache."""
        # Maps the exact query string to its parsed summary dict.
        self.parsed_cache: Dict[str, Dict[str, Any]] = {}

    def parse(self, query: str) -> Dict[str, Any]:
        """
        Parse SQL query and extract security-relevant information.

        Args:
            query: SQL query string

        Returns:
            Dictionary with parsed query information:
            {
                'type': 'SELECT',
                'tables': ['users', 'orders'],
                'columns': ['id', 'name'],
                'has_subquery': False,
                'has_union': False,
                'has_join': True,
                'has_where': True,
                'complexity': 5
            }
        """
        # Clean and normalize query
        query = query.strip()

        # Bug fix: use the query string itself as the cache key.
        # hash(query) could collide, silently returning a *different*
        # query's analysis — unacceptable for a security component.
        if query in self.parsed_cache:
            return self.parsed_cache[query]

        # Parse with sqlparse
        parsed = sqlparse.parse(query)[0]

        result = {
            'type': self._get_query_type(parsed),
            'tables': self._extract_tables(parsed),
            'columns': self._extract_columns(parsed),
            'has_subquery': self._has_subquery(parsed),
            'has_union': self._has_union(query),
            'has_join': self._has_join(parsed),
            'has_where': self._has_where(parsed),
            'has_having': self._has_having(parsed),
            'has_order_by': self._has_order_by(parsed),
            'has_group_by': self._has_group_by(parsed),
            'complexity': self._calculate_complexity(parsed)
        }

        # Bound the cache so long-running processes cannot grow it forever.
        if len(self.parsed_cache) >= self._MAX_CACHE_ENTRIES:
            self.parsed_cache.clear()
        self.parsed_cache[query] = result

        return result

    def _get_query_type(self, parsed) -> str:
        """Extract the main query type (SELECT, INSERT, ...)."""
        for token in parsed.tokens:
            if token.ttype is DML:
                return token.value.upper()
        return "UNKNOWN"

    def _extract_tables(self, parsed) -> List[str]:
        """Extract table names from the FROM clause (best effort)."""
        tables = []
        from_seen = False

        for token in parsed.tokens:
            if from_seen:
                if isinstance(token, IdentifierList):
                    for identifier in token.get_identifiers():
                        tables.append(self._get_name(identifier))
                elif isinstance(token, Identifier):
                    tables.append(self._get_name(token))
                elif token.ttype is None:
                    tables.append(token.value)

            if token.ttype is Keyword and token.value.upper() == 'FROM':
                from_seen = True
            elif token.ttype is Keyword and token.value.upper() in ('WHERE', 'GROUP', 'ORDER', 'HAVING'):
                from_seen = False

        return [t.strip() for t in tables if t.strip()]

    def _extract_columns(self, parsed) -> List[str]:
        """Extract column names from the SELECT clause (best effort)."""
        columns = []
        select_seen = False

        for token in parsed.tokens:
            if select_seen and token.ttype is Keyword:
                break

            if select_seen:
                if isinstance(token, IdentifierList):
                    for identifier in token.get_identifiers():
                        columns.append(self._get_name(identifier))
                elif isinstance(token, Identifier):
                    columns.append(self._get_name(token))
                elif token.ttype is None and token.value not in (',', ' '):
                    columns.append(token.value)

            if token.ttype is DML and token.value.upper() == 'SELECT':
                select_seen = True

        return [c.strip() for c in columns if c.strip()]

    def _get_name(self, identifier) -> str:
        """Get the name from an identifier (falls back to str())."""
        if hasattr(identifier, 'get_name'):
            return identifier.get_name()
        return str(identifier)

    def _has_subquery(self, parsed) -> bool:
        """Check if query contains subqueries.

        Heuristic: more than one SELECT keyword in the statement.
        """
        query_str = str(parsed)
        # Bug fix: match case-insensitively on word boundaries — the old
        # `count('SELECT')` missed lowercase queries entirely.
        return len(re.findall(r'\bSELECT\b', query_str, re.IGNORECASE)) > 1

    def _has_union(self, query: str) -> bool:
        """Check if query contains UNION."""
        return bool(re.search(r'\bUNION\b', query, re.IGNORECASE))

    def _has_join(self, parsed) -> bool:
        """Check if query contains any JOIN variant."""
        for token in parsed.tokens:
            if token.ttype is Keyword and 'JOIN' in token.value.upper():
                return True
        return False

    def _has_where(self, parsed) -> bool:
        """Check if query has a WHERE clause."""
        for token in parsed.tokens:
            if token.ttype is Keyword and token.value.upper() == 'WHERE':
                return True
        return False

    def _has_having(self, parsed) -> bool:
        """Check if query has a HAVING clause."""
        for token in parsed.tokens:
            if token.ttype is Keyword and token.value.upper() == 'HAVING':
                return True
        return False

    def _has_order_by(self, parsed) -> bool:
        """Check if query has an ORDER BY clause."""
        query_str = str(parsed).upper()
        return 'ORDER BY' in query_str

    def _has_group_by(self, parsed) -> bool:
        """Check if query has a GROUP BY clause."""
        query_str = str(parsed).upper()
        return 'GROUP BY' in query_str

    def _calculate_complexity(self, parsed) -> int:
        """
        Calculate query complexity score.

        Higher scores indicate more complex queries that might
        need additional scrutiny.
        """
        score = 1  # Base score

        # Add complexity for various features
        if self._has_subquery(parsed):
            score += 3
        if self._has_union(str(parsed)):
            score += 2
        if self._has_join(parsed):
            score += 2
        if self._has_where(parsed):
            score += 1
        if self._has_group_by(parsed):
            score += 2
        if self._has_having(parsed):
            score += 2
        if self._has_order_by(parsed):
            score += 1

        # Add complexity for number of tables
        tables = self._extract_tables(parsed)
        if len(tables) > 1:
            score += len(tables) - 1

        return score

    def validate_structure(self, query: str) -> Optional[str]:
        """
        Validate query structure and return error if invalid.

        Args:
            query: SQL query string

        Returns:
            Error message if invalid, None if valid
        """
        try:
            parsed = sqlparse.parse(query)
            if not parsed:
                return "Empty or invalid query"

            # Check for multiple statements
            if len(parsed) > 1:
                return "Multiple statements not allowed"

            # Get query type
            query_type = self._get_query_type(parsed[0])
            if query_type == "UNKNOWN":
                return "Unknown query type"

            return None

        except Exception as e:
            return f"Query parsing error: {str(e)}"
--------------------------------------------------------------------------------
/cognidb/security/access_control.py:
--------------------------------------------------------------------------------
1 | """Access control and permissions management."""
2 |
from typing import Any, Dict, List, Optional, Set
4 | from dataclasses import dataclass, field
5 | from enum import Enum, auto
6 | from ..core.exceptions import SecurityError
7 |
8 |
class Permission(Enum):
    """Database permissions.

    Members mirror the SQL operations a user may be granted. Values come
    from ``auto()``; only member identity is used, never the number.
    """
    SELECT = auto()   # read rows
    INSERT = auto()   # add rows
    UPDATE = auto()   # modify rows
    DELETE = auto()   # remove rows
    CREATE = auto()   # create database objects
    DROP = auto()     # drop database objects
    ALTER = auto()    # alter database objects
    EXECUTE = auto()  # execute (e.g. stored routines) — semantics defined by callers
19 |
20 |
@dataclass
class TablePermissions:
    """Per-table permission set.

    Attributes:
        table_name: Table these permissions apply to.
        allowed_operations: Operations the holder may perform.
        allowed_columns: Accessible columns; None means every column.
        row_filter: Optional SQL condition for row-level security.
    """
    table_name: str
    allowed_operations: Set[Permission] = field(default_factory=set)
    allowed_columns: Optional[Set[str]] = None  # None means all columns
    row_filter: Optional[str] = None  # SQL condition for row-level security

    def can_access_column(self, column: str) -> bool:
        """Return True when the column is readable under these permissions."""
        return self.allowed_columns is None or column in self.allowed_columns

    def can_perform_operation(self, operation: Permission) -> bool:
        """Return True when the operation is granted on this table."""
        return operation in self.allowed_operations
38 |
39 |
@dataclass
class UserPermissions:
    """A user's database permissions and resource limits.

    Admins (``is_admin=True``) bypass table-level checks entirely.
    """
    user_id: str
    is_admin: bool = False
    table_permissions: Dict[str, TablePermissions] = field(default_factory=dict)
    global_permissions: Set[Permission] = field(default_factory=set)
    max_rows_per_query: int = 10000
    max_execution_time: int = 30  # seconds
    allowed_schemas: Set[str] = field(default_factory=set)

    def add_table_permission(self, table_perm: TablePermissions):
        """Register (or replace) the permission set for one table."""
        self.table_permissions[table_perm.table_name] = table_perm

    def can_access_table(self, table: str) -> bool:
        """Return True when this user may touch the table at all."""
        return self.is_admin or table in self.table_permissions

    def can_perform_operation_on_table(self, table: str, operation: Permission) -> bool:
        """Return True when this user may run ``operation`` on ``table``."""
        if self.is_admin:
            return True
        perm = self.table_permissions.get(table)
        return perm is not None and perm.can_perform_operation(operation)
68 |
69 |
class AccessController:
    """
    Controls access to database resources.

    Implements:
    - Table-level permissions
    - Column-level permissions
    - Row-level security
    - Operation restrictions
    - Resource limits

    Unknown users fall back to a conservative read-only default profile.
    """

    def __init__(self):
        """Initialize access controller with a restrictive default profile."""
        self.users: Dict[str, UserPermissions] = {}
        self.default_permissions = UserPermissions(
            user_id="default",
            global_permissions={Permission.SELECT},
            max_rows_per_query=1000,
            max_execution_time=10
        )

    def add_user(self, user_permissions: UserPermissions):
        """Add (or replace) a user's permission profile."""
        self.users[user_permissions.user_id] = user_permissions

    def get_user_permissions(self, user_id: str) -> UserPermissions:
        """Get user permissions or the default profile if not found."""
        return self.users.get(user_id, self.default_permissions)

    def check_table_access(self, user_id: str, tables: List[str]) -> None:
        """
        Check if user can access all tables.

        Raises:
            SecurityError: If access denied
        """
        permissions = self.get_user_permissions(user_id)

        for table in tables:
            if not permissions.can_access_table(table):
                raise SecurityError(f"Access denied to table: {table}")

    def check_column_access(self, user_id: str, table: str, columns: List[str]) -> None:
        """
        Check if user can access columns.

        Raises:
            SecurityError: If access denied
        """
        permissions = self.get_user_permissions(user_id)

        if not permissions.can_access_table(table):
            raise SecurityError(f"Access denied to table: {table}")

        # Admins bypass column-level restrictions.
        if permissions.is_admin:
            return

        if table in permissions.table_permissions:
            table_perm = permissions.table_permissions[table]
            for column in columns:
                if not table_perm.can_access_column(column):
                    raise SecurityError(f"Access denied to column: {table}.{column}")

    def check_operation(self, user_id: str, operation: Permission, tables: List[str]) -> None:
        """
        Check if user can perform operation.

        Raises:
            SecurityError: If operation not allowed
        """
        permissions = self.get_user_permissions(user_id)

        # A global grant covers every table.
        if operation in permissions.global_permissions:
            return

        # Otherwise each table needs an explicit grant.
        for table in tables:
            if not permissions.can_perform_operation_on_table(table, operation):
                raise SecurityError(
                    f"Operation {operation.name} not allowed on table: {table}"
                )

    def get_row_filters(self, user_id: str, table: str) -> Optional[str]:
        """Get the row-level security filter for a table, if any (admins: None)."""
        permissions = self.get_user_permissions(user_id)

        if permissions.is_admin:
            return None

        if table in permissions.table_permissions:
            return permissions.table_permissions[table].row_filter

        return None

    def get_resource_limits(self, user_id: str) -> Dict[str, int]:
        """Get the per-query resource limits for a user."""
        permissions = self.get_user_permissions(user_id)

        return {
            'max_rows': permissions.max_rows_per_query,
            'max_execution_time': permissions.max_execution_time
        }

    def create_read_only_user(self, user_id: str, allowed_tables: List[str]) -> UserPermissions:
        """Create and register a read-only user with access to specific tables."""
        user = UserPermissions(
            user_id=user_id,
            global_permissions={Permission.SELECT},
            max_rows_per_query=5000,
            max_execution_time=20
        )

        for table in allowed_tables:
            user.add_table_permission(
                TablePermissions(
                    table_name=table,
                    allowed_operations={Permission.SELECT}
                )
            )

        self.add_user(user)
        return user

    def create_restricted_user(self,
                              user_id: str,
                              table_permissions_dict: Dict[str, Dict[str, Any]]) -> UserPermissions:
        """
        Create and register a user with specific table permissions.

        Args:
            user_id: User identifier
            table_permissions_dict: Dictionary of table permissions
                {
                    'users': {
                        'operations': ['SELECT'],
                        'columns': ['id', 'name', 'email'],
                        'row_filter': "department = 'sales'"
                    }
                }

        Raises:
            SecurityError: If an operation name is not a valid Permission.
        """
        user = UserPermissions(user_id=user_id)

        for table, perms in table_permissions_dict.items():
            # Bug fix: an unknown operation name used to surface as a bare
            # KeyError; raise a descriptive SecurityError instead.
            try:
                operations = {
                    Permission[op] for op in perms.get('operations', ['SELECT'])
                }
            except KeyError as e:
                raise SecurityError(f"Unknown operation: {e.args[0]}") from e

            table_perm = TablePermissions(
                table_name=table,
                allowed_operations=operations,
                allowed_columns=set(perms['columns']) if 'columns' in perms else None,
                row_filter=perms.get('row_filter')
            )

            user.add_table_permission(table_perm)

        self.add_user(user)
        return user
--------------------------------------------------------------------------------
/cognidb/config/settings.py:
--------------------------------------------------------------------------------
1 | """Settings and configuration classes."""
2 |
3 | from dataclasses import dataclass, field
4 | from typing import Dict, List, Optional, Any
5 | from enum import Enum, auto
6 | import os
7 | from pathlib import Path
8 |
9 |
class DatabaseType(Enum):
    """Supported database types.

    The string value is the scheme used when building connection strings
    (see DatabaseConfig.get_connection_string) and in configuration files.
    """
    MYSQL = "mysql"
    POSTGRESQL = "postgresql"
    MONGODB = "mongodb"
    DYNAMODB = "dynamodb"
    SQLITE = "sqlite"
17 |
18 |
class LLMProvider(Enum):
    """Supported LLM providers.

    The string value of each member is the accepted value of the
    LLM_PROVIDER environment variable (see Settings.from_env).
    """
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    AZURE_OPENAI = "azure_openai"
    LOCAL = "local"  # the only provider that needs no api_key (see Settings.validate)
    HUGGINGFACE = "huggingface"
26 |
27 |
class CacheProvider(Enum):
    """Supported cache providers.

    The string value of each member is the accepted value of the
    CACHE_PROVIDER environment variable (see Settings.from_env).
    """
    IN_MEMORY = "in_memory"
    REDIS = "redis"
    MEMCACHED = "memcached"
    DISK = "disk"
34 |
35 |
@dataclass
class DatabaseConfig:
    """Database connection configuration.

    Covers connection identity, pooling, SSL/TLS material, and per-query
    limits. ``get_connection_string`` renders a loggable URL with the
    password masked.
    """
    type: DatabaseType
    host: str
    port: int
    database: str
    username: Optional[str] = None
    password: Optional[str] = None

    # Connection pool settings
    pool_size: int = 5
    max_overflow: int = 10
    pool_timeout: int = 30
    pool_recycle: int = 3600

    # SSL/TLS settings
    ssl_enabled: bool = False
    ssl_ca_cert: Optional[str] = None
    ssl_client_cert: Optional[str] = None
    ssl_client_key: Optional[str] = None

    # Query settings
    query_timeout: int = 30  # seconds
    max_result_size: int = 10000  # rows

    # Additional options
    options: Dict[str, Any] = field(default_factory=dict)

    def get_connection_string(self) -> str:
        """Generate connection string with the password masked as ``***``."""
        # SQLite is file-based: no host, port, or credentials.
        if self.type == DatabaseType.SQLITE:
            return f"sqlite:///{self.database}"

        credentials = f"{self.username}:***@" if self.username else ""
        return f"{self.type.value}://{credentials}{self.host}:{self.port}/{self.database}"
75 |
76 |
@dataclass
class LLMConfig:
    """LLM configuration.

    ``api_key`` must be set for every provider except LLMProvider.LOCAL,
    and ``temperature`` must fall in [0, 2] (both enforced by
    Settings.validate). Provider-specific fields are only consulted for
    their matching provider.
    """
    provider: LLMProvider
    api_key: Optional[str] = None

    # Model settings
    model_name: str = "gpt-4"
    temperature: float = 0.1
    max_tokens: int = 1000
    timeout: int = 30

    # Cost control
    max_tokens_per_query: int = 2000
    max_queries_per_minute: int = 60
    max_cost_per_day: float = 100.0

    # Prompt settings
    system_prompt: Optional[str] = None
    few_shot_examples: List[Dict[str, str]] = field(default_factory=list)

    # Provider-specific settings
    azure_endpoint: Optional[str] = None
    azure_deployment: Optional[str] = None
    huggingface_model_id: Optional[str] = None
    local_model_path: Optional[str] = None

    # Advanced settings
    enable_function_calling: bool = True
    enable_streaming: bool = False
    retry_attempts: int = 3
    retry_delay: float = 1.0
109 |
110 |
@dataclass
class CacheConfig:
    """Cache configuration.

    TTLs are in seconds, sizes in megabytes. Redis settings are only
    relevant for CacheProvider.REDIS; ``disk_cache_path`` only for
    CacheProvider.DISK.
    """
    provider: CacheProvider

    # TTL settings (in seconds)
    query_result_ttl: int = 3600  # 1 hour
    schema_ttl: int = 86400  # 24 hours
    llm_response_ttl: int = 7200  # 2 hours

    # Size limits
    max_cache_size_mb: int = 100
    max_entry_size_mb: int = 10
    eviction_policy: str = "lru"  # lru, lfu, ttl

    # Redis settings
    redis_host: str = "localhost"
    redis_port: int = 6379
    redis_password: Optional[str] = None
    redis_db: int = 0
    redis_ssl: bool = False

    # Disk cache settings
    disk_cache_path: str = str(Path.home() / ".cognidb" / "cache")

    # Performance settings
    enable_compression: bool = True
    enable_async_writes: bool = True
139 |
140 |
@dataclass
class SecurityConfig:
    """Security configuration.

    Groups query-validation limits, rate limiting, access control, audit
    logging, encryption, and network restrictions. Note that both
    ``encrypt_cache`` and ``encrypt_logs`` default to True, so a real
    deployment must supply ``encryption_key`` via the secrets mechanism.
    """
    # Query validation
    allow_only_select: bool = True
    max_query_complexity: int = 10
    allow_subqueries: bool = False
    allow_unions: bool = False

    # Rate limiting
    enable_rate_limiting: bool = True
    rate_limit_per_minute: int = 100
    rate_limit_per_hour: int = 1000

    # Access control
    enable_access_control: bool = True
    default_user_permissions: List[str] = field(default_factory=lambda: ["SELECT"])
    require_authentication: bool = False

    # Audit logging
    enable_audit_logging: bool = True
    audit_log_path: str = str(Path.home() / ".cognidb" / "audit.log")
    log_query_results: bool = False  # off by default: results may contain sensitive data

    # Encryption
    encrypt_cache: bool = True
    encrypt_logs: bool = True
    encryption_key: Optional[str] = None  # Should be loaded from secrets

    # Network security
    allowed_ip_ranges: List[str] = field(default_factory=list)
    require_ssl: bool = True
173 |
174 |
@dataclass
class Settings:
    """Main settings container.

    Aggregates the database, LLM, cache, and security configurations plus
    application-wide flags. Construct directly or via ``from_env``, then
    call ``validate`` before use.
    """
    # Core configurations
    database: DatabaseConfig
    llm: LLMConfig
    cache: CacheConfig
    security: SecurityConfig

    # Application settings
    app_name: str = "CogniDB"
    environment: str = "production"
    debug: bool = False
    log_level: str = "INFO"

    # Paths
    data_dir: str = str(Path.home() / ".cognidb")
    log_dir: str = str(Path.home() / ".cognidb" / "logs")

    # Feature flags
    enable_natural_language: bool = True
    enable_query_explanation: bool = True
    enable_query_optimization: bool = True
    enable_auto_indexing: bool = False

    # Monitoring
    enable_metrics: bool = True
    metrics_port: int = 9090
    enable_tracing: bool = True
    tracing_endpoint: Optional[str] = None

    @classmethod
    def from_env(cls) -> 'Settings':
        """Create settings from environment variables.

        Unset variables fall back to the documented defaults.

        Raises:
            ValueError: If DB_TYPE, LLM_PROVIDER, or CACHE_PROVIDER is not a
                valid enum value, or DB_PORT is not an integer.
        """
        return cls(
            database=DatabaseConfig(
                type=DatabaseType(os.getenv('DB_TYPE', 'postgresql')),
                host=os.getenv('DB_HOST', 'localhost'),
                port=int(os.getenv('DB_PORT', '5432')),
                database=os.getenv('DB_NAME', 'cognidb'),
                username=os.getenv('DB_USER'),
                password=os.getenv('DB_PASSWORD')
            ),
            llm=LLMConfig(
                provider=LLMProvider(os.getenv('LLM_PROVIDER', 'openai')),
                api_key=os.getenv('LLM_API_KEY'),
                model_name=os.getenv('LLM_MODEL', 'gpt-4')
            ),
            cache=CacheConfig(
                provider=CacheProvider(os.getenv('CACHE_PROVIDER', 'in_memory'))
            ),
            security=SecurityConfig(
                allow_only_select=os.getenv('SECURITY_SELECT_ONLY', 'true').lower() == 'true'
            ),
            environment=os.getenv('ENVIRONMENT', 'production'),
            debug=os.getenv('DEBUG', 'false').lower() == 'true'
        )

    def validate(self) -> List[str]:
        """Validate settings and return list of errors (empty if valid).

        Side effect: creates ``data_dir`` and ``log_dir`` if they do not
        already exist.
        """
        errors = []

        # Database validation
        if not self.database.host:
            errors.append("Database host is required")
        if self.database.port <= 0 or self.database.port > 65535:
            errors.append("Invalid database port")

        # LLM validation
        if self.llm.provider != LLMProvider.LOCAL and not self.llm.api_key:
            errors.append("LLM API key is required for non-local providers")
        if self.llm.temperature < 0 or self.llm.temperature > 2:
            errors.append("LLM temperature must be between 0 and 2")

        # Security validation.
        # FIX: the key is consumed by both cache and log encryption; the old
        # check only looked at encrypt_cache, so an encrypt_logs-only setup
        # passed validation without a key.
        needs_key = self.security.encrypt_cache or self.security.encrypt_logs
        if needs_key and not self.security.encryption_key:
            errors.append("Encryption key required when encryption is enabled")

        # Path validation (attempts creation so startup fails early)
        for path_attr in ['data_dir', 'log_dir']:
            path = getattr(self, path_attr)
            try:
                Path(path).mkdir(parents=True, exist_ok=True)
            except Exception as e:
                errors.append(f"Cannot create {path_attr}: {e}")

        return errors
--------------------------------------------------------------------------------
/cognidb/core/query_intent.py:
--------------------------------------------------------------------------------
1 | """Query intent representation - database agnostic query structure."""
2 |
3 | from dataclasses import dataclass, field
4 | from enum import Enum, auto
5 | from typing import List, Optional, Dict, Any, Union
6 |
7 |
class QueryType(Enum):
    """Supported query types.

    Values are auto() integers; only member identity matters (serialization
    uses the member name, see QueryIntent.to_dict).
    """
    SELECT = auto()
    AGGREGATE = auto()  # set automatically by QueryIntent.add_aggregation
    COUNT = auto()
    DISTINCT = auto()
14 |
15 |
class ComparisonOperator(Enum):
    """Comparison operators for conditions.

    The value of each member is its SQL spelling, used directly during
    serialization (see QueryIntent._condition_to_dict).
    """
    EQ = "="
    NE = "!="
    GT = ">"
    GTE = ">="
    LT = "<"
    LTE = "<="
    IN = "IN"  # value must be list/tuple/set (enforced by Condition)
    NOT_IN = "NOT IN"  # value must be list/tuple/set (enforced by Condition)
    LIKE = "LIKE"
    NOT_LIKE = "NOT LIKE"
    IS_NULL = "IS NULL"
    IS_NOT_NULL = "IS NOT NULL"
    BETWEEN = "BETWEEN"  # value must be a two-element list/tuple (enforced by Condition)
31 |
32 |
class LogicalOperator(Enum):
    """Logical operators for combining conditions within a ConditionGroup."""
    AND = "AND"
    OR = "OR"
37 |
38 |
class AggregateFunction(Enum):
    """Supported aggregate functions.

    The value of each member is its SQL spelling, rendered directly by
    Aggregation.__str__.
    """
    SUM = "SUM"
    AVG = "AVG"
    COUNT = "COUNT"
    MIN = "MIN"
    MAX = "MAX"
    GROUP_CONCAT = "GROUP_CONCAT"  # MySQL-specific name; other dialects may differ
47 |
48 |
class JoinType(Enum):
    """Types of joins; values are the SQL keywords used in serialization."""
    INNER = "INNER"
    LEFT = "LEFT"
    RIGHT = "RIGHT"
    FULL = "FULL"
55 |
56 |
@dataclass
class Column:
    """A column reference, optionally qualified by a table and given an alias.

    Note: ``alias`` is carried as metadata only; ``__str__`` does not
    render it.
    """
    name: str
    table: Optional[str] = None
    alias: Optional[str] = None

    def __str__(self) -> str:
        """Render as ``table.name`` when table-qualified, else just ``name``."""
        return f"{self.table}.{self.name}" if self.table else self.name
68 |
69 |
@dataclass
class Condition:
    """A single comparison between a column and a value.

    Construction validates that the value's shape matches the operator
    (pair for BETWEEN, collection for IN / NOT IN).
    """
    column: Column
    operator: ComparisonOperator
    value: Any

    def __post_init__(self):
        """Reject operator/value combinations that can never be valid."""
        op = self.operator
        if op == ComparisonOperator.BETWEEN:
            is_pair = isinstance(self.value, (list, tuple)) and len(self.value) == 2
            if not is_pair:
                raise ValueError("BETWEEN operator requires a list/tuple of two values")
        elif op in (ComparisonOperator.IN, ComparisonOperator.NOT_IN):
            if not isinstance(self.value, (list, tuple, set)):
                raise ValueError(f"{self.operator.value} operator requires a list/tuple/set")
85 |
86 |
@dataclass
class ConditionGroup:
    """Group of conditions with logical operator.

    Entries may be plain Condition objects or nested ConditionGroups,
    forming an arbitrary AND/OR tree.
    """
    conditions: List[Union[Condition, 'ConditionGroup']]
    operator: LogicalOperator = LogicalOperator.AND

    def add_condition(self, condition: Union[Condition, 'ConditionGroup']):
        """Add a condition (or nested group) to the group."""
        self.conditions.append(condition)
96 |
97 |
@dataclass
class JoinCondition:
    """Represents a join between two tables on a single column pair.

    ``additional_conditions`` can refine the ON clause; note it is not
    included in QueryIntent serialization (see QueryIntent._join_to_dict).
    """
    join_type: JoinType
    left_table: str
    right_table: str
    left_column: str
    right_column: str
    additional_conditions: Optional[ConditionGroup] = None
107 |
108 |
@dataclass
class Aggregation:
    """An aggregate function applied to a column.

    Note: ``__str__`` renders only ``FUNC(column)``; the alias is carried
    as metadata and is not part of the string form.
    """
    function: AggregateFunction
    column: Column
    alias: Optional[str] = None

    def __str__(self) -> str:
        return f"{self.function.value}({self.column})"
118 |
119 |
@dataclass
class OrderBy:
    """Ordering specification for one column (ascending by default)."""
    column: Column
    ascending: bool = True
125 |
126 |
@dataclass
class QueryIntent:
    """
    Database-agnostic representation of a query intent.

    This is the core abstraction that allows CogniDB to work with
    multiple database types by translating this intent into
    database-specific queries.

    Invariants are checked in ``__post_init__``; mutator methods
    (``add_*``, ``set_limit``) perform their own per-value checks.
    ``to_dict`` produces a JSON-friendly snapshot; it renders columns and
    aggregations via ``str()`` and therefore drops alias information.
    """
    query_type: QueryType
    tables: List[str]
    columns: List[Column] = field(default_factory=list)
    conditions: Optional[ConditionGroup] = None
    joins: List[JoinCondition] = field(default_factory=list)
    aggregations: List[Aggregation] = field(default_factory=list)
    group_by: List[Column] = field(default_factory=list)
    having: Optional[ConditionGroup] = None
    order_by: List[OrderBy] = field(default_factory=list)
    limit: Optional[int] = None
    offset: Optional[int] = None
    distinct: bool = False

    # Metadata for optimization and caching
    natural_language_query: Optional[str] = None
    estimated_cost: Optional[float] = None
    cache_ttl: Optional[int] = None  # seconds

    def __post_init__(self):
        """Validate query intent.

        Raises:
            ValueError: If no tables are given, if aggregations are mixed
                with non-aggregated columns without GROUP BY, or if
                HAVING is used without GROUP BY.
        """
        if not self.tables:
            raise ValueError("At least one table must be specified")

        if self.query_type == QueryType.SELECT and not self.columns:
            # Default to all columns if none specified
            self.columns = [Column("*")]

        if self.aggregations and not self.group_by:
            # Check if all columns are aggregated.
            # Matching is by column *name* only; table qualifiers are ignored.
            non_aggregated = [
                col for col in self.columns
                if not any(agg.column.name == col.name for agg in self.aggregations)
            ]
            if non_aggregated:
                raise ValueError(
                    "Non-aggregated columns must be in GROUP BY clause"
                )

        if self.having and not self.group_by:
            raise ValueError("HAVING clause requires GROUP BY")

    def add_column(self, column: Union[str, Column]):
        """Add a column to select; plain strings are wrapped in Column."""
        if isinstance(column, str):
            column = Column(column)
        self.columns.append(column)

    def add_condition(self, condition: Condition):
        """Add a WHERE condition (an AND group is created on first use)."""
        if not self.conditions:
            self.conditions = ConditionGroup([])
        self.conditions.add_condition(condition)

    def add_join(self, join: JoinCondition):
        """Add a join condition."""
        self.joins.append(join)

    def add_aggregation(self, aggregation: Aggregation):
        """Add an aggregation and switch the query type to AGGREGATE.

        Note: no re-validation happens here (unlike __post_init__).
        """
        self.aggregations.append(aggregation)
        self.query_type = QueryType.AGGREGATE

    def set_limit(self, limit: int, offset: int = 0):
        """Set result limit and offset.

        Raises:
            ValueError: If limit is not positive or offset is negative.
        """
        if limit <= 0:
            raise ValueError("Limit must be positive")
        if offset < 0:
            raise ValueError("Offset must be non-negative")
        self.limit = limit
        self.offset = offset

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization.

        Enum members are rendered by name or value; columns/aggregations
        via ``str()``. ``estimated_cost`` and ``cache_ttl`` are excluded.
        """
        return {
            'query_type': self.query_type.name,
            'tables': self.tables,
            'columns': [str(col) for col in self.columns],
            'conditions': self._condition_group_to_dict(self.conditions) if self.conditions else None,
            'joins': [self._join_to_dict(j) for j in self.joins],
            'aggregations': [str(agg) for agg in self.aggregations],
            'group_by': [str(col) for col in self.group_by],
            'having': self._condition_group_to_dict(self.having) if self.having else None,
            'order_by': [{'column': str(ob.column), 'asc': ob.ascending} for ob in self.order_by],
            'limit': self.limit,
            'offset': self.offset,
            'distinct': self.distinct,
            'natural_language_query': self.natural_language_query
        }

    def _condition_group_to_dict(self, group: ConditionGroup) -> Dict[str, Any]:
        """Convert condition group to dict, recursing into nested groups."""
        return {
            'operator': group.operator.value,
            'conditions': [
                self._condition_to_dict(c) if isinstance(c, Condition)
                else self._condition_group_to_dict(c)
                for c in group.conditions
            ]
        }

    def _condition_to_dict(self, condition: Condition) -> Dict[str, Any]:
        """Convert a single condition to dict."""
        return {
            'column': str(condition.column),
            'operator': condition.operator.value,
            'value': condition.value
        }

    def _join_to_dict(self, join: JoinCondition) -> Dict[str, Any]:
        """Convert join to dict (additional_conditions are not serialized)."""
        return {
            'type': join.join_type.value,
            'left_table': join.left_table,
            'right_table': join.right_table,
            'left_column': join.left_column,
            'right_column': join.right_column
        }
--------------------------------------------------------------------------------
/cognidb/security/sanitizer.py:
--------------------------------------------------------------------------------
1 | """Input sanitization utilities."""
2 |
3 | import re
4 | import html
5 | from typing import Any, Dict, List, Union
6 |
7 |
class InputSanitizer:
    """
    Comprehensive input sanitizer for all user inputs.

    Provides multiple sanitization strategies:
    1. SQL identifiers (tables, columns)
    2. String values
    3. Numeric values
    4. Natural language queries

    These helpers are a defense-in-depth layer only; real SQL escaping
    must still come from parameterized queries at the driver level.
    """

    # Matches every character NOT allowed in natural language queries;
    # matches are replaced with spaces.
    ALLOWED_NL_CHARS = re.compile(r'[^a-zA-Z0-9\s\-_.,!?\'"\(\)%$#@]')

    # Maximum lengths for various inputs
    MAX_NATURAL_LANGUAGE_LENGTH = 500
    MAX_IDENTIFIER_LENGTH = 64
    MAX_STRING_VALUE_LENGTH = 1000

    @staticmethod
    def sanitize_natural_language(query: str) -> str:
        """
        Sanitize natural language query.

        Args:
            query: Raw natural language query

        Returns:
            Sanitized query safe for LLM processing: truncated, filtered to
            the allowed character set, whitespace-normalized, HTML-escaped.
        """
        if not query:
            return ""

        # Truncate if too long
        query = query[:InputSanitizer.MAX_NATURAL_LANGUAGE_LENGTH]

        # Remove potentially harmful characters while preserving readability
        query = InputSanitizer.ALLOWED_NL_CHARS.sub(' ', query)

        # Normalize whitespace
        query = ' '.join(query.split())

        # HTML escape for additional safety (quote=False leaves ' and " alone)
        query = html.escape(query, quote=False)

        return query.strip()

    @staticmethod
    def sanitize_identifier(identifier: str) -> str:
        """
        Sanitize database identifier (table/column name).

        Args:
            identifier: Raw identifier

        Returns:
            Sanitized identifier containing only [a-zA-Z0-9_], starting
            with a letter or underscore, at most MAX_IDENTIFIER_LENGTH long.

        Raises:
            ValueError: If identifier is empty or has no valid characters.
        """
        if not identifier:
            raise ValueError("Identifier cannot be empty")

        # Remove any quotes or special characters
        identifier = re.sub(r'[^a-zA-Z0-9_]', '', identifier)

        # Ensure it starts with a letter or underscore
        if not re.match(r'^[a-zA-Z_]', identifier):
            identifier = f"_{identifier}"

        # Truncate if too long
        identifier = identifier[:InputSanitizer.MAX_IDENTIFIER_LENGTH]

        if not identifier:
            raise ValueError("Identifier contains no valid characters")

        return identifier

    @staticmethod
    def sanitize_string_value(value: str, allow_wildcards: bool = False) -> str:
        """
        Sanitize string value for use in queries.

        Args:
            value: Raw string value (non-strings are stringified and returned)
            allow_wildcards: Whether to allow SQL wildcards (% and _)

        Returns:
            Sanitized string value (truncated, null bytes stripped,
            wildcards escaped unless allowed)
        """
        if not isinstance(value, str):
            return str(value)

        # Truncate if too long
        value = value[:InputSanitizer.MAX_STRING_VALUE_LENGTH]

        # Remove null bytes
        value = value.replace('\x00', '')

        # Escape SQL wildcards unless the caller wants LIKE semantics
        if not allow_wildcards:
            value = value.replace('%', '\\%').replace('_', '\\_')

        # Note: Actual SQL escaping should be done by parameterized queries.
        # This is just an additional safety layer.

        return value

    @staticmethod
    def sanitize_numeric_value(value: Union[int, float, str]) -> Union[int, float, None]:
        """
        Sanitize numeric value.

        Args:
            value: Raw numeric value

        Returns:
            The value unchanged if already numeric, a parsed int/float for
            numeric strings, or None if it cannot be parsed.
        """
        if isinstance(value, (int, float)):
            return value

        if isinstance(value, str):
            try:
                # Strings containing '.' parse as float, others as int
                if '.' in value:
                    return float(value)
                else:
                    return int(value)
            except ValueError:
                return None

        return None

    @staticmethod
    def sanitize_list_value(values: List[Any], sanitize_func=None) -> List[Any]:
        """
        Sanitize a list of values.

        Args:
            values: List of raw values
            sanitize_func: Function to apply to each value
                (defaults to sanitize_string_value)

        Returns:
            List of sanitized values; values that fail sanitization or
            sanitize to None are silently dropped.

        Raises:
            ValueError: If values is not a list, tuple, or set.
        """
        if not isinstance(values, (list, tuple, set)):
            raise ValueError("Input must be a list, tuple, or set")

        if sanitize_func is None:
            sanitize_func = InputSanitizer.sanitize_string_value

        sanitized = []
        for value in values:
            try:
                sanitized_value = sanitize_func(value)
                if sanitized_value is not None:
                    sanitized.append(sanitized_value)
            except Exception:
                # Best-effort semantics: skip invalid values rather than fail
                continue

        return sanitized

    @staticmethod
    def sanitize_dict_value(data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Recursively sanitize dictionary values.

        Keys are run through sanitize_identifier; values are dispatched to
        the matching sanitizer by type (unknown types are stringified).

        Args:
            data: Dictionary with raw values

        Returns:
            Dictionary with sanitized keys and values

        Raises:
            ValueError: If data is not a dict, or a key sanitizes to empty.
        """
        if not isinstance(data, dict):
            raise ValueError("Input must be a dictionary")

        sanitized = {}
        for key, value in data.items():
            # Sanitize key
            safe_key = InputSanitizer.sanitize_identifier(key)

            # Sanitize value based on type
            if isinstance(value, str):
                sanitized[safe_key] = InputSanitizer.sanitize_string_value(value)
            elif isinstance(value, (int, float)):
                sanitized[safe_key] = InputSanitizer.sanitize_numeric_value(value)
            elif isinstance(value, (list, tuple, set)):
                sanitized[safe_key] = InputSanitizer.sanitize_list_value(value)
            elif isinstance(value, dict):
                sanitized[safe_key] = InputSanitizer.sanitize_dict_value(value)
            elif value is None:
                sanitized[safe_key] = None
            else:
                # Convert to string and sanitize
                sanitized[safe_key] = InputSanitizer.sanitize_string_value(str(value))

        return sanitized

    @staticmethod
    def escape_like_pattern(pattern: str) -> str:
        """
        Escape special characters in LIKE patterns.

        Args:
            pattern: Raw LIKE pattern

        Returns:
            Pattern with backslash, % and _ escaped (backslash first so the
            added escapes are not themselves re-escaped).
        """
        pattern = pattern.replace('\\', '\\\\')
        pattern = pattern.replace('%', '\\%')
        pattern = pattern.replace('_', '\\_')
        return pattern

    @staticmethod
    def validate_and_sanitize_limit(limit: Any) -> int:
        """
        Validate and sanitize LIMIT value.

        Args:
            limit: Raw limit value

        Returns:
            Sanitized limit value, clamped to at most 10000.

        Raises:
            ValueError: If limit is not convertible to int or is < 1.
        """
        # FIX: convert first, validate after. Previously the specific
        # "Limit must be positive" error was raised inside the same try and
        # swallowed by its own except clause, surfacing as the generic
        # "Invalid limit value" instead.
        try:
            limit = int(limit)
        except (TypeError, ValueError):
            raise ValueError("Invalid limit value")
        if limit < 1:
            raise ValueError("Limit must be positive")
        # Clamp to a reasonable maximum instead of failing
        return min(limit, 10000)

    @staticmethod
    def validate_and_sanitize_offset(offset: Any) -> int:
        """
        Validate and sanitize OFFSET value.

        Args:
            offset: Raw offset value

        Returns:
            Sanitized offset value.

        Raises:
            ValueError: If offset is not convertible to int, is negative,
                or exceeds 1,000,000.
        """
        # FIX: same masking bug as validate_and_sanitize_limit — the
        # specific messages were replaced by "Invalid offset value".
        try:
            offset = int(offset)
        except (TypeError, ValueError):
            raise ValueError("Invalid offset value")
        if offset < 0:
            raise ValueError("Offset must be non-negative")
        if offset > 1000000:  # Reasonable maximum
            raise ValueError("Offset too large")
        return offset
--------------------------------------------------------------------------------
/cognidb/drivers/mysql_driver.py:
--------------------------------------------------------------------------------
1 | """Secure MySQL driver implementation."""
2 |
3 | import time
4 | import logging
5 | from typing import Dict, List, Any, Optional
6 | import mysql.connector
7 | from mysql.connector import pooling, Error
8 | from .base_driver import BaseDriver
9 | from ..core.exceptions import ConnectionError, ExecutionError
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
class MySQLDriver(BaseDriver):
    """
    MySQL database driver with security enhancements.

    Features:
    - Connection pooling
    - Parameterized queries only
    - SSL/TLS support
    - Query timeout enforcement
    - Automatic reconnection
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize MySQL driver.

        Args:
            config: Connection settings (host, port, database, username,
                password, pool_size, ssl_* options, query_timeout,
                max_result_size, connection_timeout).
        """
        super().__init__(config)
        self.pool = None  # created lazily in connect()

    def connect(self) -> None:
        """Establish connection to MySQL database.

        Creates a connection pool and checks one connection out of it to
        verify connectivity.

        Raises:
            ConnectionError: If the pool cannot be created or a connection
                cannot be obtained.
        """
        try:
            # Prepare connection config
            pool_config = {
                'pool_name': 'cognidb_mysql_pool',
                'pool_size': self.config.get('pool_size', 5),
                'host': self.config['host'],
                'port': self.config.get('port', 3306),
                'database': self.config['database'],
                'user': self.config.get('username'),
                'password': self.config.get('password'),
                'autocommit': False,
                'raise_on_warnings': True,
                'sql_mode': 'TRADITIONAL',
                'time_zone': '+00:00',
                'connect_timeout': self.config.get('connection_timeout', 10)
            }

            # SSL configuration.
            # FIX: mysql-connector expects ssl_ca/ssl_cert/ssl_key keyword
            # arguments; the previous 'ca'/'cert'/'key' keys are not valid
            # connection arguments and broke pool creation with SSL enabled.
            if self.config.get('ssl_enabled'):
                if self.config.get('ssl_ca_cert'):
                    pool_config['ssl_ca'] = self.config['ssl_ca_cert']
                if self.config.get('ssl_client_cert'):
                    pool_config['ssl_cert'] = self.config['ssl_client_cert']
                if self.config.get('ssl_client_key'):
                    pool_config['ssl_key'] = self.config['ssl_client_key']
                pool_config['ssl_disabled'] = False
                pool_config['ssl_verify_cert'] = True
                pool_config['ssl_verify_identity'] = True

            # Create connection pool
            self.pool = pooling.MySQLConnectionPool(**pool_config)

            # Test connection
            self.connection = self.pool.get_connection()
            self._connection_time = time.time()

            logger.info("Connected to MySQL database: %s", self.config['database'])

        except Error as e:
            logger.error("MySQL connection failed: %s", e)
            raise ConnectionError(f"Failed to connect to MySQL: {str(e)}")

    def disconnect(self) -> None:
        """Close the database connection and clear connection state.

        For pooled connections, close() returns the connection to the pool.
        """
        if self.connection:
            try:
                self.connection.close()
                logger.info("Disconnected from MySQL database")
            except Error as e:
                logger.error("Error closing connection: %s", e)
            finally:
                self.connection = None
                self._connection_time = None

    def _create_connection(self):
        """Get connection from pool.

        Raises:
            ConnectionError: If connect() has not been called yet.
        """
        if not self.pool:
            raise ConnectionError("Connection pool not initialized")
        return self.pool.get_connection()

    def _close_connection(self):
        """Return connection to pool."""
        if self.connection:
            self.connection.close()

    def _execute_with_timeout(self,
                              query: str,
                              params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
        """Execute a query with a server-side timeout.

        Args:
            query: SQL text with positional %s placeholders.
            params: Parameter values. NOTE: values are passed positionally
                (dict insertion order must match placeholder order).

        Returns:
            Row dicts for statements that produce a result set, otherwise a
            single ``{'affected_rows': n}`` entry (after commit).

        Raises:
            ExecutionError: On any MySQL error (transaction is rolled back).
        """
        cursor = None

        try:
            # Get a fresh connection from the pool if needed
            if not self.connection or not self.connection.is_connected():
                self.connection = self._create_connection()

            # Enforce the timeout server-side; MAX_EXECUTION_TIME is in ms
            # (and applies to SELECT statements per MySQL docs).
            timeout = self.config.get('query_timeout', 30)
            self.connection.cmd_query(f"SET SESSION MAX_EXECUTION_TIME={timeout * 1000}")

            # Buffered dict cursor: rows come back as {column: value}
            cursor = self.connection.cursor(dictionary=True, buffered=True)

            # Execute query with parameters
            if params:
                # Convert dict params to a positional list for MySQL
                cursor.execute(query, list(params.values()))
            else:
                cursor.execute(query)

            if cursor.description:
                results = cursor.fetchall()

                # Apply result size limit
                max_results = self.config.get('max_result_size', 10000)
                if len(results) > max_results:
                    logger.warning("Result truncated from %d to %d rows",
                                   len(results), max_results)
                    results = results[:max_results]

                return results
            else:
                # For non-SELECT queries
                self.connection.commit()
                return [{'affected_rows': cursor.rowcount}]

        except Error as e:
            if self.connection:
                self.connection.rollback()
            raise ExecutionError(f"Query execution failed: {str(e)}")
        finally:
            if cursor:
                cursor.close()

    def _fetch_schema_impl(self) -> Dict[str, Dict[str, str]]:
        """Fetch MySQL schema using INFORMATION_SCHEMA.

        Returns:
            Mapping of table name -> {column name: type description};
            additionally '<table>_indexes' keys map to index summaries
            (see _fetch_indexes).
        """
        query = """
            SELECT
                TABLE_NAME,
                COLUMN_NAME,
                DATA_TYPE,
                IS_NULLABLE,
                COLUMN_KEY,
                COLUMN_DEFAULT,
                EXTRA
            FROM INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_SCHEMA = %s
            ORDER BY TABLE_NAME, ORDINAL_POSITION
        """

        cursor = None
        try:
            cursor = self.connection.cursor(dictionary=True)
            cursor.execute(query, (self.config['database'],))

            schema = {}
            for row in cursor:
                table_name = row['TABLE_NAME']
                if table_name not in schema:
                    schema[table_name] = {}

                # Build a human-readable column description
                col_type = row['DATA_TYPE']
                if row['IS_NULLABLE'] == 'NO':
                    col_type += ' NOT NULL'
                if row['COLUMN_KEY'] == 'PRI':
                    col_type += ' PRIMARY KEY'
                if row['EXTRA']:  # e.g. auto_increment
                    col_type += f" {row['EXTRA']}"

                schema[table_name][row['COLUMN_NAME']] = col_type

            # Fetch indexes (reuses the cursor after iteration is complete)
            self._fetch_indexes(schema, cursor)

            return schema

        finally:
            if cursor:
                cursor.close()

    def _fetch_indexes(self, schema: Dict[str, Dict[str, str]], cursor):
        """Fetch non-primary index information into '<table>_indexes' keys."""
        query = """
            SELECT
                TABLE_NAME,
                INDEX_NAME,
                GROUP_CONCAT(COLUMN_NAME ORDER BY SEQ_IN_INDEX) as COLUMNS
            FROM INFORMATION_SCHEMA.STATISTICS
            WHERE TABLE_SCHEMA = %s AND INDEX_NAME != 'PRIMARY'
            GROUP BY TABLE_NAME, INDEX_NAME
        """

        cursor.execute(query, (self.config['database'],))

        for row in cursor:
            table_name = row['TABLE_NAME']
            if table_name in schema:
                index_key = f"{table_name}_indexes"
                if index_key not in schema:
                    schema[index_key] = []
                schema[index_key].append(f"{row['INDEX_NAME']} ({row['COLUMNS']})")

    def _begin_transaction(self):
        """Begin a transaction."""
        self.connection.start_transaction()

    def _commit_transaction(self):
        """Commit a transaction."""
        self.connection.commit()

    def _rollback_transaction(self):
        """Rollback a transaction."""
        self.connection.rollback()

    def _get_driver_info(self) -> Dict[str, Any]:
        """Get MySQL-specific information (best-effort; never raises)."""
        info = {
            'server_version': None,
            'connection_id': None,
            'character_set': None
        }

        if self.connection and self.connection.is_connected():
            cursor = None
            try:
                cursor = self.connection.cursor()
                cursor.execute("SELECT VERSION(), CONNECTION_ID(), @@character_set_database")
                result = cursor.fetchone()
                info['server_version'] = result[0]
                info['connection_id'] = result[1]
                info['character_set'] = result[2]
            except Exception:
                # Diagnostic info only — never fail the caller over it
                pass
            finally:
                # FIX: previously the cursor leaked when execute/fetch raised;
                # close it unconditionally.
                if cursor:
                    cursor.close()

        return info

    @property
    def supports_transactions(self) -> bool:
        """MySQL supports transactions."""
        return True

    @property
    def supports_schemas(self) -> bool:
        """MySQL supports schemas (databases)."""
        return True
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # `CogniDB`
2 |
3 | A secure, production-ready natural language database interface that empowers users to query databases using plain English while maintaining enterprise-grade security and performance.
4 |
5 | ---
6 |
7 |
8 |
9 | ## Features
10 |
11 | ### Core Capabilities
12 | - 🗣️ Natural Language Querying: Ask questions in plain English
13 | - 🔍 Intelligent SQL Generation: Context-aware query generation with schema understanding
14 | - 🛡️ Enterprise Security: Multi-layer security validation and sanitization
15 | - 🚀 High Performance: Query caching, connection pooling, and optimization
16 | - 📊 Multi-Database Support: MySQL, PostgreSQL, MongoDB, DynamoDB, SQLite
17 | - 💰 Cost Control: LLM usage tracking with configurable limits
18 | - 📈 Query Optimization: AI-powered query performance suggestions
19 |
20 | ### Security Features
21 | - SQL Injection Prevention: Parameterized queries and comprehensive validation
22 | - Access Control: Table and column-level permissions
23 | - Rate Limiting: Configurable request limits
24 | - Audit Logging: Complete query audit trail
25 | - Encryption: At-rest and in-transit encryption support
26 |
27 | ### AI/LLM Features
28 | - Multi-Provider Support: OpenAI, Anthropic, Azure, HuggingFace, Local models
29 | - Cost Tracking: Real-time usage and cost monitoring
30 | - Smart Caching: Reduce costs with intelligent response caching
31 | - Few-Shot Learning: Improve accuracy with custom examples
32 |
33 | ---
34 |
35 | ## Quick Start
36 |
37 | ### Installation
38 |
1. Install dependencies:

   ```bash
   pip install cognidb
   ```

2. With all optional dependencies:

   ```bash
   pip install cognidb[all]
   ```

3. With specific features:

   ```bash
   pip install cognidb[redis,azure]
   ```
47 |
48 | ### Basic Usage
49 |
50 | ```python
51 | from cognidb import create_cognidb
52 |
53 | # Initialize with configuration
54 | db = create_cognidb(
55 | database={
56 | 'type': 'postgresql',
57 | 'host': 'localhost',
58 | 'database': 'mydb',
59 | 'username': 'user',
60 | 'password': 'pass'
61 | },
62 | llm={
63 | 'provider': 'openai',
64 | 'api_key': 'your-api-key'
65 | }
66 | )
67 |
68 | # Query in natural language
69 | result = db.query("Show me the top 10 customers by total purchase amount")
70 |
71 | if result['success']:
72 | print(f"SQL: {result['sql']}")
73 | print(f"Results: {result['results']}")
74 |
75 | # Always close when done
76 | db.close()
77 | ```
78 |
79 | ### Using Context Manager
80 |
81 | ```python
82 | from cognidb import CogniDB
83 |
84 | # Automatically handles connection cleanup
85 | with CogniDB(config_file='cognidb.yaml') as db:
86 | result = db.query(
87 | "What were the sales trends last quarter?",
88 | explain=True # Get explanation of the query
89 | )
90 |
91 | if result['success']:
92 | print(f"Explanation: {result['explanation']}")
93 | for row in result['results']:
94 | print(row)
95 | ```
96 |
97 | ---
98 |
99 | ## Configuration
100 |
101 | ### Environment Variables
102 |
103 | ```bash
104 | # Database settings
105 | export DB_TYPE=postgresql
106 | export DB_HOST=localhost
107 | export DB_PORT=5432
108 | export DB_NAME=mydb
109 | export DB_USER=dbuser
110 | export DB_PASSWORD=secure_password
111 |
112 | # LLM settings
113 | export LLM_PROVIDER=openai
114 | export LLM_API_KEY=your_api_key
115 | export LLM_MODEL=gpt-4
116 |
117 | # Optional: Use configuration file instead
118 | export COGNIDB_CONFIG=/path/to/cognidb.yaml
119 | ```
120 |
121 | ### Configuration File (YAML)
122 |
123 | Create a `cognidb.yaml` file:
124 |
125 | ```yaml
126 | database:
127 | type: postgresql
128 | host: localhost
129 | port: 5432
130 | database: analytics_db
131 | username: ${DB_USER} # Use environment variable
132 | password: ${DB_PASSWORD}
133 |
134 | # Connection settings
135 | pool_size: 5
136 | query_timeout: 30
137 | ssl_enabled: true
138 |
139 | llm:
140 | provider: openai
141 | api_key: ${LLM_API_KEY}
142 | model_name: gpt-4
143 | temperature: 0.1
144 | max_cost_per_day: 100.0
145 |
146 | # Improve accuracy with examples
147 | few_shot_examples:
148 | - query: "Show total sales by month"
149 | sql: "SELECT DATE_TRUNC('month', order_date) as month, SUM(amount) as total FROM orders GROUP BY month ORDER BY month"
150 |
151 | security:
152 | allow_only_select: true
153 | enable_rate_limiting: true
154 | rate_limit_per_minute: 100
155 | enable_audit_logging: true
156 | ```
157 |
158 | See `cognidb.example.yaml` for a complete configuration example.
159 |
160 | ---
161 |
162 | ## Advanced Features
163 |
164 | ### Query Optimization
165 |
166 | ```python
167 | # Get optimization suggestions
168 | sql = "SELECT * FROM orders WHERE customer_id IN (SELECT id FROM customers WHERE country = 'USA')"
169 | optimization = db.optimize_query(sql)
170 |
171 | print(f"Original: {optimization['original_query']}")
172 | print(f"Optimized: {optimization['optimized_query']}")
173 | print(f"Explanation: {optimization['explanation']}")
174 | ```
175 |
176 | ### Query Suggestions
177 |
178 | ```python
179 | # Get AI-powered query suggestions
180 | suggestions = db.suggest_queries("customers who haven't")
181 | for suggestion in suggestions:
182 | print(f"- {suggestion}")
183 | # Output:
184 | # - customers who haven't made a purchase in the last 30 days
185 | # - customers who haven't updated their profile
186 | # - customers who haven't verified their email
187 | ```
188 |
189 | ### Access Control
190 |
191 | ```python
192 | from cognidb.security import AccessController
193 |
194 | # Set up user permissions
195 | access = AccessController()
196 | access.create_restricted_user(
197 | user_id="analyst_1",
198 | table_permissions={
199 | 'customers': {
200 | 'operations': ['SELECT'],
201 | 'columns': ['id', 'name', 'email', 'country'],
202 | 'row_filter': "country = 'USA'" # Row-level security
203 | }
204 | }
205 | )
206 |
207 | # Query with user context
208 | result = db.query(
209 | "Show me all customer emails",
210 | user_id="analyst_1" # Will only see US customers
211 | )
212 | ```
213 |
214 | ### Cost Tracking
215 |
216 | ```python
217 | # Monitor LLM usage and costs
218 | stats = db.get_usage_stats()
219 | print(f"Total cost today: ${stats['daily_cost']:.2f}")
220 | print(f"Remaining budget: ${stats['remaining_budget']:.2f}")
221 | print(f"Queries today: {stats['request_count']}")
222 |
223 | # Export usage report
224 | report = db.export_usage_report(
225 | start_date='2024-01-01',
226 | end_date='2024-01-31',
227 | format='csv'
228 | )
229 | ```
230 |
231 | ---
232 |
233 | ## Architecture
234 |
235 | CogniDB uses a modular, secure architecture:
236 |
237 | ```
238 | ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
239 | │ User Input │────▶│ Security Layer │────▶│ LLM Manager │
240 | └─────────────────┘ └─────────────────┘ └─────────────────┘
241 | │ │
242 | ▼ ▼
243 | ┌─────────────────┐ ┌─────────────────┐
244 | │ Query Validator│ │ Query Generator │
245 | └─────────────────┘ └─────────────────┘
246 | │ │
247 | ▼ ▼
248 | ┌─────────────────┐ ┌─────────────────┐
249 | │ Database Driver │────▶│ Result Cache │
250 | └─────────────────┘ └─────────────────┘
251 | ```
252 |
253 | ### Key Components
254 | 1. Security Layer: Multi-stage validation and sanitization
255 | 2. LLM Manager: Handles all AI interactions with fallback support
256 | 3. Query Generator: Converts natural language to SQL with schema awareness
257 | 4. Database Drivers: Secure, parameterized database connections
258 | 5. Cache Layer: Reduces costs and improves performance
259 |
260 | ---
261 |
262 | ## Security Best Practices
263 | 1. Never expose credentials: Use environment variables or secrets managers
264 | 2. Enable SSL/TLS: Always use encrypted connections
265 | 3. Restrict permissions: Use read-only database users when possible
266 | 4. Monitor usage: Enable audit logging and review regularly
267 | 5. Update regularly: Keep CogniDB and dependencies up to date
268 |
269 | ---
270 |
271 | ## Testing
272 |
273 | ```bash
274 | # Run all tests
275 | pytest
276 |
277 | # Run with coverage
278 | pytest --cov=cognidb
279 |
280 | # Run security tests only
281 | pytest tests/security/
282 |
283 | # Run integration tests
284 | pytest tests/integration/ --db-host=localhost
285 | ```
286 |
287 | ---
288 |
289 | ## Performance Tips
290 | 1. Use connection pooling: Enabled by default for better performance
291 | 2. Enable caching: Reduces LLM costs and improves response time
292 | 3. Optimize schemas: Add appropriate indexes based on query patterns
293 | 4. Use prepared statements: For frequently executed queries
294 | 5. Monitor query performance: Use the optimization feature regularly
295 |
296 | ---
297 |
298 | ## Contributing
299 |
300 | We welcome contributions! Please see CONTRIBUTING.md for guidelines.
301 |
302 | ### Development Setup
303 |
304 | ```bash
305 | # Clone the repository
306 | git clone https://github.com/boxed-dev/cognidb
307 | cd cognidb
308 |
309 | # Install in development mode
310 | pip install -e .[dev]
311 |
312 | # Run pre-commit hooks
313 | pre-commit install
314 |
315 | # Run tests
316 | pytest
317 | ```
318 |
319 | ---
320 |
321 | ## License
322 |
323 | This project is licensed under the MIT License - see LICENSE for details.
324 |
325 | ---
326 |
327 | ## Acknowledgments
328 | - OpenAI, Anthropic, and the open-source LLM community
329 | - Contributors to SQLParse, psycopg2, and other dependencies
330 | - The CogniDB community for feedback and contributions
331 |
332 | ---
333 |
334 | Built with ❤️ for data democratization
335 |
336 |
337 |
--------------------------------------------------------------------------------
/cognidb/ai/llm_manager.py:
--------------------------------------------------------------------------------
1 | """LLM manager with multi-provider support and cost tracking."""
2 |
3 | import time
4 | from typing import Dict, Any, Optional, List, Union
5 | from abc import ABC, abstractmethod
6 | from dataclasses import dataclass
7 | from ..config.settings import LLMConfig, LLMProvider
8 | from ..core.exceptions import CogniDBError, RateLimitError
9 | from .cost_tracker import CostTracker
10 | from .providers import (
11 | OpenAIProvider,
12 | AnthropicProvider,
13 | AzureOpenAIProvider,
14 | HuggingFaceProvider,
15 | LocalProvider
16 | )
17 |
18 |
@dataclass
class LLMResponse:
    """Container for a single LLM completion and its metadata."""
    content: str            # generated text returned by the provider
    model: str              # model name that produced the completion
    provider: str           # provider name (e.g. 'openai', 'anthropic')
    usage: Dict[str, int]   # token counts, e.g. prompt_tokens / completion_tokens
    cost: float             # estimated cost in USD for this call
    latency: float          # wall-clock seconds from request to response
    cached: bool = False    # True when served from the response cache
29 |
30 |
class LLMManager:
    """
    Manages LLM interactions with multiple providers.

    Features:
    - Multi-provider support with fallback
    - Cost tracking and limits
    - Rate limiting
    - Response caching
    - Token usage tracking
    """

    def __init__(self, config: LLMConfig, cache_provider=None):
        """
        Initialize LLM manager.

        Args:
            config: LLM configuration
            cache_provider: Optional cache provider for response caching
        """
        self.config = config
        self.cache = cache_provider
        self.cost_tracker = CostTracker(max_daily_cost=config.max_cost_per_day)

        # Primary provider is always tried first
        self.primary_provider = self._create_provider(config.provider, config)

        # OpenAI acts as the fallback whenever it is not already primary
        self.fallback_providers = []
        if config.provider != LLMProvider.OPENAI:
            self.fallback_providers.append(
                self._create_provider(LLMProvider.OPENAI, config)
            )

        # Sliding-window rate limiting: timestamps of requests in the
        # last 60 seconds (pruned by _check_rate_limit)
        self._request_times: List[float] = []
        self._last_request_time = 0

    def generate(self,
                 prompt: str,
                 system_prompt: Optional[str] = None,
                 max_tokens: Optional[int] = None,
                 temperature: Optional[float] = None,
                 use_cache: bool = True) -> LLMResponse:
        """
        Generate response from LLM.

        Args:
            prompt: User prompt
            system_prompt: System prompt (overrides config)
            max_tokens: Maximum tokens (overrides config)
            temperature: Temperature (overrides config)
            use_cache: Whether to use cached responses

        Returns:
            LLM response

        Raises:
            RateLimitError: If rate limit exceeded
            CogniDBError: If generation fails
        """
        # Check rate limits
        self._check_rate_limit()

        # Check cost limits
        if self.cost_tracker.is_limit_exceeded():
            raise CogniDBError("Daily cost limit exceeded")

        # Compute the cache key once; None disables both lookup and store
        cache_key = None
        if use_cache and self.cache:
            cache_key = self._generate_cache_key(prompt, system_prompt)
            cached_response = self.cache.get(cache_key)
            if cached_response:
                cached_response['cached'] = True
                return LLMResponse(**cached_response)

        # Prepare parameters.  Use explicit "is not None" checks so callers
        # can pass 0 / 0.0 (e.g. temperature=0.0 for deterministic output)
        # without silently falling back to the configured defaults — the
        # previous "x or default" idiom discarded falsy overrides.
        params = {
            'prompt': prompt,
            'system_prompt': system_prompt or self.config.system_prompt,
            'max_tokens': max_tokens if max_tokens is not None else self.config.max_tokens,
            'temperature': temperature if temperature is not None else self.config.temperature
        }

        # Try primary provider first, then any fallbacks in order
        start_time = time.time()
        response = None
        last_error = None

        for provider in [self.primary_provider] + self.fallback_providers:
            try:
                response = provider.generate(**params)
                response['provider'] = provider.name
                response['latency'] = time.time() - start_time
                break
            except Exception as e:
                last_error = e
                continue

        if response is None:
            raise CogniDBError(f"All LLM providers failed: {last_error}")

        # Track cost before caching so the cached entry carries it
        cost = self._calculate_cost(response['usage'], response['model'])
        response['cost'] = cost
        self.cost_tracker.track_usage(cost, response['usage'])

        # Cache response (cache_key is only set when caching is active)
        if cache_key is not None:
            self.cache.set(
                cache_key,
                response,
                ttl=self.config.llm_response_ttl
            )

        # Record this request for the rate-limit window
        self._request_times.append(time.time())

        return LLMResponse(**response)

    def generate_with_examples(self,
                               prompt: str,
                               examples: List[Dict[str, str]],
                               **kwargs) -> LLMResponse:
        """
        Generate response with few-shot examples.

        Args:
            prompt: User prompt
            examples: List of input/output examples
            **kwargs: Additional generation parameters

        Returns:
            LLM response
        """
        # Prepend the examples to the prompt, then delegate to generate()
        formatted_prompt = self._format_with_examples(prompt, examples)
        return self.generate(formatted_prompt, **kwargs)

    def stream_generate(self,
                        prompt: str,
                        callback,
                        **kwargs):
        """
        Stream generation with callback.

        Args:
            prompt: User prompt
            callback: Function called with each token
            **kwargs: Additional generation parameters

        Raises:
            CogniDBError: If streaming is disabled or unsupported
        """
        if not self.config.enable_streaming:
            raise CogniDBError("Streaming not enabled in configuration")

        # Only the primary provider is used for streaming; fallbacks are
        # not attempted because partial output may already have been sent
        if not hasattr(self.primary_provider, 'stream_generate'):
            raise CogniDBError("Provider does not support streaming")

        self.primary_provider.stream_generate(prompt, callback, **kwargs)

    def get_usage_stats(self) -> Dict[str, Any]:
        """Get usage statistics.

        Note: 'request_count' reflects requests within the current
        rate-limit window (the timestamp list is pruned each call to
        _check_rate_limit), not a lifetime total.
        """
        return {
            'total_cost': self.cost_tracker.get_total_cost(),
            'daily_cost': self.cost_tracker.get_daily_cost(),
            'token_usage': self.cost_tracker.get_token_usage(),
            'request_count': len(self._request_times),
            'cache_stats': self.cache.get_stats() if self.cache else None
        }

    def _create_provider(self, provider_type: LLMProvider, config: LLMConfig):
        """Create an LLM provider instance for the given provider type.

        Raises:
            ValueError: If the provider type is not recognized.
        """
        # Dispatch table reads cleaner than an if/elif chain and makes
        # adding a provider a one-line change
        factories = {
            LLMProvider.OPENAI: OpenAIProvider,
            LLMProvider.ANTHROPIC: AnthropicProvider,
            LLMProvider.AZURE_OPENAI: AzureOpenAIProvider,
            LLMProvider.HUGGINGFACE: HuggingFaceProvider,
            LLMProvider.LOCAL: LocalProvider,
        }
        try:
            factory = factories[provider_type]
        except KeyError:
            raise ValueError(f"Unknown provider: {provider_type}") from None
        return factory(config)

    def _check_rate_limit(self) -> None:
        """Check and enforce the per-minute rate limit.

        Raises:
            RateLimitError: When the configured requests-per-minute cap
                has been reached; carries a retry_after hint in seconds.
        """
        current_time = time.time()

        # Drop timestamps older than the 60-second window
        minute_ago = current_time - 60
        self._request_times = [
            t for t in self._request_times if t > minute_ago
        ]

        if len(self._request_times) >= self.config.max_queries_per_minute:
            # Oldest request in the window determines when a slot frees up
            retry_after = 60 - (current_time - self._request_times[0])
            raise RateLimitError(
                f"Rate limit exceeded ({self.config.max_queries_per_minute}/min)",
                retry_after=int(retry_after)
            )

    def _generate_cache_key(self, prompt: str, system_prompt: Optional[str]) -> str:
        """Generate a deterministic cache key for a prompt.

        The key covers every input that can change the completion
        (provider, model, sampling settings, system prompt, user prompt)
        so distinct requests never collide.
        """
        import hashlib

        key_parts = [
            self.config.provider.value,
            self.config.model_name,
            str(self.config.temperature),
            str(self.config.max_tokens),
            system_prompt or self.config.system_prompt or "",
            prompt
        ]

        key_string = "|".join(key_parts)
        return f"llm:{hashlib.sha256(key_string.encode()).hexdigest()}"

    def _calculate_cost(self, usage: Dict[str, int], model: str) -> float:
        """Calculate USD cost from token usage.

        Args:
            usage: Token counts ('prompt_tokens', 'completion_tokens')
            model: Model name used for pricing lookup

        Returns:
            Estimated cost; unknown models use the 'default' rate.
        """
        # Model pricing (USD per 1K tokens)
        pricing = {
            # OpenAI
            'gpt-4': {'input': 0.03, 'output': 0.06},
            'gpt-4-turbo': {'input': 0.01, 'output': 0.03},
            'gpt-3.5-turbo': {'input': 0.0005, 'output': 0.0015},
            # Anthropic
            'claude-3-opus': {'input': 0.015, 'output': 0.075},
            'claude-3-sonnet': {'input': 0.003, 'output': 0.015},
            'claude-3-haiku': {'input': 0.00025, 'output': 0.00125},
            # Default
            'default': {'input': 0.001, 'output': 0.002}
        }

        model_pricing = pricing.get(model, pricing['default'])

        input_cost = (usage.get('prompt_tokens', 0) / 1000) * model_pricing['input']
        output_cost = (usage.get('completion_tokens', 0) / 1000) * model_pricing['output']

        return input_cost + output_cost

    def _format_with_examples(self, prompt: str, examples: List[Dict[str, str]]) -> str:
        """Format a prompt with few-shot input/output examples.

        Each example dict must contain 'input' and 'output' keys.
        """
        formatted_examples = []

        for example in examples:
            formatted_examples.append(
                f"Input: {example['input']}\nOutput: {example['output']}"
            )

        examples_text = "\n\n".join(formatted_examples)

        return f"""Here are some examples:

{examples_text}

Now, for the following input:
Input: {prompt}
Output:"""
--------------------------------------------------------------------------------
/cognidb/security/validator.py:
--------------------------------------------------------------------------------
1 | """Security validator implementation."""
2 |
3 | import re
4 | from typing import Any, List, Optional, Set, Tuple
5 | from ..core.interfaces import SecurityValidator
6 | from ..core.query_intent import QueryIntent, QueryType
7 | from ..core.exceptions import SecurityError
8 | from .query_parser import SQLQueryParser
9 |
10 |
class QuerySecurityValidator(SecurityValidator):
    """
    Comprehensive security validator for queries.

    Implements multiple layers of security:
    1. Query intent validation
    2. Native query validation
    3. Identifier sanitization
    4. Value sanitization
    """

    # Dangerous SQL keywords that should never appear in validated queries
    FORBIDDEN_KEYWORDS = {
        'DROP', 'DELETE', 'TRUNCATE', 'UPDATE', 'INSERT', 'ALTER',
        'CREATE', 'REPLACE', 'RENAME', 'GRANT', 'REVOKE', 'EXECUTE',
        'EXEC', 'CALL', 'MERGE', 'LOCK', 'UNLOCK'
    }

    # Patterns that might indicate SQL injection
    SQL_INJECTION_PATTERNS = [
        r';\s*--',  # Statement termination followed by comment
        r';\s*\/\*',  # Statement termination followed by block comment
        r'UNION\s+SELECT',  # UNION-based injection
        r'OR\s+1\s*=\s*1',  # Classic SQL injection
        r'OR\s+\'1\'\s*=\s*\'1\'',  # Classic SQL injection with quotes
        r'WAITFOR\s+DELAY',  # Time-based injection (SQL Server)
        r'BENCHMARK\s*\(',  # MySQL time-based injection
        r'PG_SLEEP\s*\(',  # PostgreSQL time-based injection
        r'LOAD_FILE\s*\(',  # File system access
        r'INTO\s+OUTFILE',  # File system write
        r'xp_cmdshell',  # SQL Server command execution
    ]

    # Valid identifier pattern (leading letter/underscore, then alphanumeric + underscore)
    VALID_IDENTIFIER_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')

    # Maximum identifier length (matches the common 64-char SQL limit)
    MAX_IDENTIFIER_LENGTH = 64

    def __init__(self,
                 allowed_operations: Optional[List[str]] = None,
                 max_query_complexity: int = 10,
                 allow_subqueries: bool = False):
        """
        Initialize security validator.

        Args:
            allowed_operations: List of allowed query types (default: SELECT only)
            max_query_complexity: Maximum allowed query complexity score
            allow_subqueries: Whether to allow subqueries
        """
        self._allowed_operations = allowed_operations or ['SELECT']
        self.max_query_complexity = max_query_complexity
        self.allow_subqueries = allow_subqueries
        self.parser = SQLQueryParser()

    @property
    def allowed_operations(self) -> List[str]:
        """List of allowed query operations."""
        return self._allowed_operations

    def validate_query_intent(self, query_intent: QueryIntent) -> Tuple[bool, Optional[str]]:
        """
        Validate query intent for security issues.

        Checks:
        1. Query type is allowed
        2. Table/column names are valid
        3. Query complexity is within limits
        4. No forbidden patterns in conditions

        Returns:
            (True, None) when valid, else (False, reason).
        """
        # Check query type
        if query_intent.query_type.name not in self.allowed_operations:
            return False, f"Query type {query_intent.query_type.name} is not allowed"

        # Validate table names
        for table in query_intent.tables:
            if not self._is_valid_identifier(table):
                return False, f"Invalid table name: {table}"

        # Validate column names ("*" is the only allowed non-identifier column)
        for column in query_intent.columns:
            if column.name != "*" and not self._is_valid_identifier(column.name):
                return False, f"Invalid column name: {column.name}"
            if column.table and not self._is_valid_identifier(column.table):
                return False, f"Invalid table reference in column: {column.table}"

        # Check query complexity
        complexity = self._calculate_complexity(query_intent)
        if complexity > self.max_query_complexity:
            return False, f"Query too complex (score: {complexity}, max: {self.max_query_complexity})"

        # Validate conditions
        if query_intent.conditions:
            valid, error = self._validate_conditions(query_intent.conditions)
            if not valid:
                return False, error

        # Validate joins (all four identifiers must pass)
        for join in query_intent.joins:
            if not self._is_valid_identifier(join.left_table):
                return False, f"Invalid table in join: {join.left_table}"
            if not self._is_valid_identifier(join.right_table):
                return False, f"Invalid table in join: {join.right_table}"
            if not self._is_valid_identifier(join.left_column):
                return False, f"Invalid column in join: {join.left_column}"
            if not self._is_valid_identifier(join.right_column):
                return False, f"Invalid column in join: {join.right_column}"

        return True, None

    def validate_native_query(self, query: str) -> Tuple[bool, Optional[str]]:
        """
        Validate native SQL query for security issues.

        Performs comprehensive security checks including:
        1. Forbidden keyword detection
        2. SQL injection pattern matching
        3. Query parsing and analysis

        Returns:
            (True, None) when valid, else (False, reason).
        """
        # Normalize query for keyword analysis
        normalized_query = query.upper().strip()

        # Check for forbidden keywords (word-bounded to avoid matching
        # substrings inside identifiers)
        for keyword in self.FORBIDDEN_KEYWORDS:
            if re.search(rf'\b{keyword}\b', normalized_query):
                return False, f"Forbidden keyword detected: {keyword}"

        # Check for SQL injection patterns
        for pattern in self.SQL_INJECTION_PATTERNS:
            if re.search(pattern, normalized_query, re.IGNORECASE):
                return False, "Potential SQL injection pattern detected"

        # Parse and validate query structure
        try:
            parsed = self.parser.parse(query)
            if parsed['type'] not in self.allowed_operations:
                return False, f"Query type {parsed['type']} is not allowed"

            # Additional checks based on parsed structure
            if not self.allow_subqueries and parsed.get('has_subquery'):
                return False, "Subqueries are not allowed"

        except Exception as e:
            # A query the parser cannot understand is rejected outright
            return False, f"Query parsing failed: {str(e)}"

        return True, None

    def sanitize_identifier(self, identifier: str) -> str:
        """
        Sanitize a table/column identifier.

        Args:
            identifier: The identifier to sanitize

        Returns:
            Sanitized identifier

        Raises:
            SecurityError: If identifier cannot be sanitized safely
        """
        # Remove surrounding quote characters of any dialect (", ', `, [])
        identifier = identifier.strip().strip('"\'`[]')

        # Validate after unquoting
        if not self._is_valid_identifier(identifier):
            raise SecurityError(f"Invalid identifier: {identifier}")

        return identifier

    def sanitize_value(self, value: Any) -> Any:
        """
        Sanitize a parameter value.

        Args:
            value: The value to sanitize (scalars, lists/tuples and dicts
                are handled; collections are sanitized recursively)

        Returns:
            Sanitized value
        """
        if value is None:
            return None

        if isinstance(value, str):
            # Remove any SQL comment indicators
            value = re.sub(r'--.*$', '', value, flags=re.MULTILINE)
            value = re.sub(r'/\*.*?\*/', '', value, flags=re.DOTALL)

            # Escape special characters
            # Note: Actual escaping should be done by the database driver
            # This is just an additional safety layer
            value = value.replace('\x00', '')  # Remove null bytes

        elif isinstance(value, (list, tuple)):
            # Recursively sanitize collections, preserving the container type
            return type(value)(self.sanitize_value(v) for v in value)

        elif isinstance(value, dict):
            # Recursively sanitize dictionary values (keys left untouched)
            return {k: self.sanitize_value(v) for k, v in value.items()}

        return value

    def _is_valid_identifier(self, identifier: str) -> bool:
        """Check if an identifier is non-empty, short enough and well-formed."""
        if not identifier or len(identifier) > self.MAX_IDENTIFIER_LENGTH:
            return False
        return bool(self.VALID_IDENTIFIER_PATTERN.match(identifier))

    def _calculate_complexity(self, query_intent: QueryIntent) -> int:
        """
        Calculate query complexity score.

        Factors:
        - Number of tables
        - Number of joins (weighted x2)
        - Number of conditions
        - Aggregations
        - GROUP BY / HAVING clauses
        """
        score = 0

        # Base score for tables
        score += len(query_intent.tables)

        # Joins add complexity
        score += len(query_intent.joins) * 2

        # Conditions add complexity
        if query_intent.conditions:
            score += self._count_conditions(query_intent.conditions)

        # Aggregations add complexity
        score += len(query_intent.aggregations)

        # Group by adds complexity
        if query_intent.group_by:
            score += 1

        # Having clause adds complexity
        if query_intent.having:
            score += 2

        return score

    def _count_conditions(self, condition_group) -> int:
        """Recursively count leaf conditions in a (possibly nested) group."""
        count = 0
        for condition in condition_group.conditions:
            if hasattr(condition, 'conditions'):  # It's a nested group
                count += self._count_conditions(condition)
            else:
                count += 1
        return count

    def _validate_conditions(self, condition_group) -> Tuple[bool, Optional[str]]:
        """Validate every condition in a (possibly nested) condition group."""
        for condition in condition_group.conditions:
            if hasattr(condition, 'conditions'):  # It's a nested group
                valid, error = self._validate_conditions(condition)
                if not valid:
                    return False, error
            else:
                # Validate column name
                if not self._is_valid_identifier(condition.column.name):
                    return False, f"Invalid column in condition: {condition.column.name}"

                # Validate value isn't attempting injection.
                # NOTE(review): this is a substring check, so benign values
                # containing a keyword (e.g. "Dropbox") are also rejected —
                # intentionally conservative; confirm before relaxing.
                if isinstance(condition.value, str):
                    if any(keyword in condition.value.upper() for keyword in self.FORBIDDEN_KEYWORDS):
                        return False, "Forbidden keyword in condition value"

        return True, None
--------------------------------------------------------------------------------
/cognidb/ai/prompt_builder.py:
--------------------------------------------------------------------------------
1 | """Advanced prompt builder for SQL generation."""
2 |
3 | from typing import Dict, List, Any, Optional
4 | from ..core.query_intent import QueryIntent
5 |
6 |
7 | class PromptBuilder:
8 | """
9 | Builds optimized prompts for SQL generation.
10 |
11 | Features:
12 | - Schema-aware prompts
13 | - Few-shot examples
14 | - Database-specific hints
15 | - Query optimization suggestions
16 | """
17 |
    # Database-specific SQL dialect snippets, keyed by database type.
    # Each entry maps a logical operation to the provider-specific SQL
    # template; '{...}' placeholders are filled via str.format when
    # prompts are constructed.
    SQL_DIALECTS = {
        'mysql': {
            'limit': 'LIMIT {limit}',
            'string_concat': "CONCAT({args})",
            'current_date': 'CURDATE()',
            'date_format': "DATE_FORMAT({date}, '{format}')"
        },
        'postgresql': {
            'limit': 'LIMIT {limit}',
            'string_concat': "{args}",  # PostgreSQL concatenates with ||
            'current_date': 'CURRENT_DATE',
            'date_format': "TO_CHAR({date}, '{format}')"
        },
        'sqlite': {
            'limit': 'LIMIT {limit}',
            'string_concat': "{args}",  # SQLite concatenates with ||
            'current_date': "DATE('now')",
            'date_format': "STRFTIME('{format}', {date})"
        }
    }
39 |
40 | def __init__(self, database_type: str = 'postgresql'):
41 | """
42 | Initialize prompt builder.
43 |
44 | Args:
45 | database_type: Type of database (mysql, postgresql, sqlite)
46 | """
47 | self.database_type = database_type
48 | self.dialect = self.SQL_DIALECTS.get(database_type, self.SQL_DIALECTS['postgresql'])
49 |
50 | def build_sql_generation_prompt(self,
51 | natural_language_query: str,
52 | schema: Dict[str, Dict[str, str]],
53 | examples: Optional[List[Dict[str, str]]] = None,
54 | context: Optional[Dict[str, Any]] = None) -> str:
55 | """
56 | Build prompt for SQL generation.
57 |
58 | Args:
59 | natural_language_query: User's natural language query
60 | schema: Database schema
61 | examples: Optional few-shot examples
62 | context: Optional context (user preferences, constraints)
63 |
64 | Returns:
65 | Optimized prompt for LLM
66 | """
67 | # Build schema description
68 | schema_desc = self._build_schema_description(schema)
69 |
70 | # Build examples section
71 | examples_section = ""
72 | if examples:
73 | examples_section = self._build_examples_section(examples)
74 |
75 | # Build context hints
76 | context_hints = self._build_context_hints(context or {})
77 |
78 | # Construct the prompt
79 | prompt = f"""You are an expert SQL query generator for {self.database_type} databases.
80 |
81 | Database Schema:
82 | {schema_desc}
83 |
84 | {context_hints}
85 |
86 | Important Instructions:
87 | 1. Generate ONLY the SQL query, no explanations or markdown
88 | 2. Use proper {self.database_type} syntax and functions
89 | 3. Always use table aliases for clarity
90 | 4. Include appropriate JOINs when querying multiple tables
91 | 5. Use parameterized placeholders (?) for any user-provided values
92 | 6. Ensure the query is optimized for performance
93 | 7. Handle NULL values appropriately
94 | 8. Use appropriate data type conversions when needed
95 |
96 | {examples_section}
97 |
98 | Now generate a SQL query for the following request:
99 | User Query: {natural_language_query}
100 |
101 | SQL Query:"""
102 |
103 | return prompt
104 |
105 | def build_query_explanation_prompt(self,
106 | sql_query: str,
107 | schema: Dict[str, Dict[str, str]]) -> str:
108 | """
109 | Build prompt for explaining SQL query.
110 |
111 | Args:
112 | sql_query: SQL query to explain
113 | schema: Database schema
114 |
115 | Returns:
116 | Prompt for explanation
117 | """
118 | schema_desc = self._build_schema_description(schema)
119 |
120 | return f"""Explain the following SQL query in simple terms:
121 |
122 | Database Schema:
123 | {schema_desc}
124 |
125 | SQL Query:
126 | {sql_query}
127 |
128 | Provide a clear, concise explanation of:
129 | 1. What the query does
130 | 2. Which tables and columns it uses
131 | 3. Any filters or conditions applied
132 | 4. The expected result format
133 |
134 | Explanation:"""
135 |
136 | def build_optimization_prompt(self,
137 | sql_query: str,
138 | schema: Dict[str, Dict[str, str]],
139 | performance_stats: Optional[Dict[str, Any]] = None) -> str:
140 | """
141 | Build prompt for query optimization suggestions.
142 |
143 | Args:
144 | sql_query: SQL query to optimize
145 | schema: Database schema with index information
146 | performance_stats: Optional query performance statistics
147 |
148 | Returns:
149 | Prompt for optimization
150 | """
151 | schema_desc = self._build_schema_description(schema, include_indexes=True)
152 |
153 | perf_section = ""
154 | if performance_stats:
155 | perf_section = f"\nPerformance Stats:\n{self._format_performance_stats(performance_stats)}"
156 |
157 | return f"""Analyze and optimize the following SQL query:
158 |
159 | Database Schema:
160 | {schema_desc}
161 |
162 | Current Query:
163 | {sql_query}
164 | {perf_section}
165 |
166 | Provide optimization suggestions considering:
167 | 1. Index usage
168 | 2. JOIN optimization
169 | 3. Subquery elimination
170 | 4. Filtering efficiency
171 | 5. Result set size reduction
172 |
173 | Optimized Query and Explanation:"""
174 |
175 | def build_intent_to_sql_prompt(self,
176 | query_intent: QueryIntent,
177 | database_type: str) -> str:
178 | """
179 | Build prompt to convert QueryIntent to SQL.
180 |
181 | Args:
182 | query_intent: Parsed query intent
183 | database_type: Target database type
184 |
185 | Returns:
186 | Prompt for SQL generation
187 | """
188 | intent_desc = self._describe_query_intent(query_intent)
189 |
190 | return f"""Convert the following query specification to {database_type} SQL:
191 |
192 | Query Specification:
193 | {intent_desc}
194 |
195 | Generate a properly formatted {database_type} SQL query that implements this specification.
196 | Use appropriate syntax and functions for {database_type}.
197 |
198 | SQL Query:"""
199 |
200 | def _build_schema_description(self,
201 | schema: Dict[str, Dict[str, str]],
202 | include_indexes: bool = False) -> str:
203 | """Build formatted schema description."""
204 | lines = []
205 |
206 | for table_name, columns in schema.items():
207 | lines.append(f"Table: {table_name}")
208 |
209 | for col_name, col_type in columns.items():
210 | lines.append(f" - {col_name}: {col_type}")
211 |
212 | if include_indexes and f"{table_name}_indexes" in schema:
213 | lines.append(" Indexes:")
214 | for index in schema[f"{table_name}_indexes"]:
215 | lines.append(f" - {index}")
216 |
217 | lines.append("")
218 |
219 | return "\n".join(lines)
220 |
221 | def _build_examples_section(self, examples: List[Dict[str, str]]) -> str:
222 | """Build few-shot examples section."""
223 | if not examples:
224 | return ""
225 |
226 | lines = ["Examples of similar queries:\n"]
227 |
228 | for i, example in enumerate(examples, 1):
229 | lines.append(f"Example {i}:")
230 | lines.append(f"User Query: {example['query']}")
231 | lines.append(f"SQL: {example['sql']}")
232 | lines.append("")
233 |
234 | return "\n".join(lines)
235 |
236 | def _build_context_hints(self, context: Dict[str, Any]) -> str:
237 | """Build context-specific hints."""
238 | hints = []
239 |
240 | if context.get('timezone'):
241 | hints.append(f"Timezone: {context['timezone']} (adjust date/time queries accordingly)")
242 |
243 | if context.get('date_format'):
244 | hints.append(f"Preferred date format: {context['date_format']}")
245 |
246 | if context.get('limit_default'):
247 | hints.append(f"Default result limit: {context['limit_default']}")
248 |
249 | if context.get('case_sensitive'):
250 | hints.append("Use case-sensitive string comparisons")
251 |
252 | if context.get('exclude_deleted'):
253 | hints.append("Exclude soft-deleted records (check for deleted_at IS NULL)")
254 |
255 | if hints:
256 | return "Context:\n" + "\n".join(f"- {hint}" for hint in hints) + "\n"
257 |
258 | return ""
259 |
260 | def _describe_query_intent(self, query_intent: QueryIntent) -> str:
261 | """Convert QueryIntent to human-readable description."""
262 | lines = []
263 |
264 | lines.append(f"Query Type: {query_intent.query_type.name}")
265 | lines.append(f"Tables: {', '.join(query_intent.tables)}")
266 |
267 | if query_intent.columns:
268 | cols = [str(col) for col in query_intent.columns]
269 | lines.append(f"Columns: {', '.join(cols)}")
270 |
271 | if query_intent.joins:
272 | lines.append("Joins:")
273 | for join in query_intent.joins:
274 | lines.append(
275 | f" - {join.join_type.value} JOIN {join.right_table} "
276 | f"ON {join.left_table}.{join.left_column} = "
277 | f"{join.right_table}.{join.right_column}"
278 | )
279 |
280 | if query_intent.conditions:
281 | lines.append("Conditions: [Complex condition group]")
282 |
283 | if query_intent.group_by:
284 | cols = [str(col) for col in query_intent.group_by]
285 | lines.append(f"Group By: {', '.join(cols)}")
286 |
287 | if query_intent.order_by:
288 | order_specs = []
289 | for order in query_intent.order_by:
290 | direction = "ASC" if order.ascending else "DESC"
291 | order_specs.append(f"{order.column} {direction}")
292 | lines.append(f"Order By: {', '.join(order_specs)}")
293 |
294 | if query_intent.limit:
295 | lines.append(f"Limit: {query_intent.limit}")
296 | if query_intent.offset:
297 | lines.append(f"Offset: {query_intent.offset}")
298 |
299 | return "\n".join(lines)
300 |
301 | def _format_performance_stats(self, stats: Dict[str, Any]) -> str:
302 | """Format performance statistics."""
303 | lines = []
304 |
305 | if 'execution_time' in stats:
306 | lines.append(f"- Execution Time: {stats['execution_time']}ms")
307 |
308 | if 'rows_examined' in stats:
309 | lines.append(f"- Rows Examined: {stats['rows_examined']}")
310 |
311 | if 'rows_returned' in stats:
312 | lines.append(f"- Rows Returned: {stats['rows_returned']}")
313 |
314 | if 'index_used' in stats:
315 | lines.append(f"- Index Used: {stats['index_used']}")
316 |
317 | return "\n".join(lines)
--------------------------------------------------------------------------------
/cognidb/ai/query_generator.py:
--------------------------------------------------------------------------------
1 | """Query generator using LLM with advanced features."""
2 |
3 | import re
4 | from typing import Dict, Any, List, Optional, Tuple
5 | from ..core.query_intent import QueryIntent, QueryType, Column, Condition, ComparisonOperator
6 | from ..core.exceptions import TranslationError
7 | from ..security.sanitizer import InputSanitizer
8 | from .llm_manager import LLMManager
9 | from .prompt_builder import PromptBuilder
10 |
11 |
class QueryGenerator:
    """
    Generates SQL queries from natural language using an LLM.

    Features:
    - Natural language to SQL conversion
    - Query intent parsing
    - Schema-aware generation
    - Query validation and correction
    - Query explanation, optimization, and auto-complete suggestions
    """

    # Markdown code-fence pattern shared by _extract_sql and optimize_query.
    # Tolerates an upper/lower-case "sql" tag and missing newlines around the
    # body — LLMs are inconsistent about both. The non-greedy body with \s*
    # trimming captures just the statement text.
    _CODE_FENCE = re.compile(r'```(?:sql)?\s*(.*?)\s*```', re.DOTALL | re.IGNORECASE)

    def __init__(self,
                 llm_manager: LLMManager,
                 database_type: str = 'postgresql'):
        """
        Initialize query generator.

        Args:
            llm_manager: LLM manager instance used for all generation calls
            database_type: Type of database (SQL dialect to target)
        """
        self.llm_manager = llm_manager
        self.database_type = database_type
        self.prompt_builder = PromptBuilder(database_type)
        self.sanitizer = InputSanitizer()

    def generate_sql(self,
                     natural_language_query: str,
                     schema: Dict[str, Dict[str, str]],
                     examples: Optional[List[Dict[str, str]]] = None,
                     context: Optional[Dict[str, Any]] = None) -> str:
        """
        Generate SQL from natural language query.

        Args:
            natural_language_query: User's query in natural language
            schema: Database schema
            examples: Optional few-shot examples
            context: Optional context information

        Returns:
            Generated SQL query (trailing semicolon guaranteed)

        Raises:
            TranslationError: If SQL generation fails or the result does not
                look like a valid read-only statement
        """
        # Sanitize the user text before it is embedded in the prompt.
        sanitized_query = self.sanitizer.sanitize_natural_language(natural_language_query)

        prompt = self.prompt_builder.build_sql_generation_prompt(
            sanitized_query,
            schema,
            examples,
            context
        )

        try:
            response = self.llm_manager.generate(prompt)
            sql_query = self._extract_sql(response.content)

            if not self._is_valid_sql(sql_query):
                raise TranslationError("Generated SQL is invalid")

            return sql_query

        except TranslationError:
            # Already a domain error with a precise message; re-raise as-is
            # instead of double-wrapping it ("Failed to generate SQL:
            # Generated SQL is invalid").
            raise
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise TranslationError(f"Failed to generate SQL: {str(e)}") from e

    def parse_to_intent(self,
                        natural_language_query: str,
                        schema: Dict[str, Dict[str, str]]) -> QueryIntent:
        """
        Parse natural language to QueryIntent.

        Args:
            natural_language_query: User's query
            schema: Database schema

        Returns:
            Parsed QueryIntent (via SQL generation followed by a simplified
            SQL parse — see _parse_sql_to_intent)
        """
        sql_query = self.generate_sql(natural_language_query, schema)
        return self._parse_sql_to_intent(sql_query, schema)

    def explain_query(self,
                      sql_query: str,
                      schema: Dict[str, Dict[str, str]]) -> str:
        """
        Generate natural language explanation of SQL query.

        Args:
            sql_query: SQL query to explain
            schema: Database schema

        Returns:
            Natural language explanation, or an apology string on failure
            (this method is deliberately best-effort and never raises)
        """
        prompt = self.prompt_builder.build_query_explanation_prompt(sql_query, schema)

        try:
            # Low temperature: explanations should be factual, not creative.
            response = self.llm_manager.generate(prompt, temperature=0.3)
            return response.content
        except Exception as e:
            return f"Could not generate explanation: {str(e)}"

    def optimize_query(self,
                       sql_query: str,
                       schema: Dict[str, Dict[str, str]],
                       performance_stats: Optional[Dict[str, Any]] = None) -> Tuple[str, str]:
        """
        Generate optimized version of SQL query.

        Args:
            sql_query: SQL query to optimize
            schema: Database schema with indexes
            performance_stats: Optional performance statistics

        Returns:
            Tuple of (optimized_query, explanation). On failure the original
            query is returned with an error explanation (never raises).
        """
        prompt = self.prompt_builder.build_optimization_prompt(
            sql_query,
            schema,
            performance_stats
        )

        try:
            response = self.llm_manager.generate(prompt, temperature=0.2)
            content = response.content

            # Preferred path: the model answered with a fenced code block;
            # the fence body is the query, everything else the explanation.
            fence_match = self._CODE_FENCE.search(content)
            if fence_match:
                optimized_query = fence_match.group(1).strip()
                explanation = content.replace(fence_match.group(0), '').strip()
                return optimized_query, explanation

            # Fallback: split free-form text at a recognizable explanation
            # header. Blank lines and comment lines are dropped entirely.
            sql_lines = []
            explanation_lines = []
            in_sql = True

            for line in content.split('\n'):
                if line.strip() and not line.startswith(('--', '#', '//')):
                    if in_sql and any(keyword in line.upper() for keyword in
                                      ['EXPLANATION:', 'CHANGES:', 'OPTIMIZATION:']):
                        in_sql = False

                    if in_sql:
                        sql_lines.append(line)
                    else:
                        explanation_lines.append(line)

            return '\n'.join(sql_lines).strip(), '\n'.join(explanation_lines).strip()

        except Exception as e:
            # Best-effort: fall back to the caller's query rather than fail.
            return sql_query, f"Could not optimize: {str(e)}"

    def suggest_queries(self,
                        partial_query: str,
                        schema: Dict[str, Dict[str, str]],
                        num_suggestions: int = 3) -> List[str]:
        """
        Generate query suggestions based on partial input.

        Args:
            partial_query: Partial natural language query
            schema: Database schema
            num_suggestions: Number of suggestions to generate

        Returns:
            List of suggested queries (empty on any failure)
        """
        prompt = f"""Based on the database schema and partial query, suggest {num_suggestions} complete queries.

Database Schema:
{self.prompt_builder._build_schema_description(schema)}

Partial Query: {partial_query}

Generate {num_suggestions} relevant query suggestions that complete or expand on the partial query.
Format each suggestion on a new line starting with "- ".

Suggestions:"""

        try:
            # Higher temperature on purpose: suggestions benefit from variety.
            response = self.llm_manager.generate(prompt, temperature=0.7)

            # Keep only lines formatted as "- <suggestion>".
            suggestions = []
            for line in response.content.split('\n'):
                if line.strip().startswith('- '):
                    suggestion = line.strip()[2:].strip()
                    if suggestion:
                        suggestions.append(suggestion)

            return suggestions[:num_suggestions]

        except Exception:
            return []

    def _extract_sql(self, llm_response: str) -> str:
        """Extract a bare SQL statement from an LLM response.

        Strips markdown code fences (``` or ```sql, any case, with or
        without surrounding newlines), surrounding whitespace, symmetric
        surrounding quotes, and guarantees a trailing semicolon.
        """
        match = self._CODE_FENCE.search(llm_response)
        if match:
            llm_response = match.group(1)

        sql_query = llm_response.strip()

        # Remove symmetric surrounding quotes some models add. The length
        # check prevents a lone quote character collapsing to "".
        if (len(sql_query) >= 2
                and sql_query[0] == sql_query[-1]
                and sql_query[0] in ('"', "'")):
            sql_query = sql_query[1:-1]

        if not sql_query.endswith(';'):
            sql_query += ';'

        return sql_query

    def _is_valid_sql(self, sql_query: str) -> bool:
        """Basic SQL validation: non-empty and starting with a read-only
        statement keyword. This is a sanity check, not a parser."""
        if not sql_query or not sql_query.strip():
            return False

        sql_upper = sql_query.upper()
        valid_starts = ['SELECT', 'WITH', 'SHOW', 'DESCRIBE', 'EXPLAIN']

        return any(sql_upper.strip().startswith(start) for start in valid_starts)

    def _parse_sql_to_intent(self, sql_query: str, schema: Dict[str, Dict[str, str]]) -> QueryIntent:
        """
        Parse SQL query to QueryIntent (simplified version).

        This is a basic regex implementation: it extracts the first FROM
        table and splits the SELECT list on commas, so it does not handle
        joins, nested queries, or commas inside function calls. In
        production you'd want a full SQL parser.
        """
        tables = []
        from_match = re.search(r'FROM\s+(\w+)', sql_query, re.IGNORECASE)
        if from_match:
            tables.append(from_match.group(1))

        columns = []
        select_match = re.search(r'SELECT\s+(.*?)\s+FROM', sql_query, re.IGNORECASE | re.DOTALL)
        if select_match:
            column_str = select_match.group(1)
            if column_str.strip() == '*':
                columns = [Column('*')]
            else:
                # Simple split by comma (doesn't handle commas inside
                # function calls or quoted identifiers).
                for col in column_str.split(','):
                    col = col.strip()
                    if ' AS ' in col.upper():
                        parts = re.split(r'\s+AS\s+', col, flags=re.IGNORECASE)
                        columns.append(Column(parts[0].strip(), alias=parts[1].strip()))
                    else:
                        columns.append(Column(col))

        # NOTE(review): the SQL text is stored in natural_language_query —
        # the original NL text is not available at this point; confirm this
        # is intended before relying on that field downstream.
        intent = QueryIntent(
            query_type=QueryType.SELECT,
            tables=tables,
            columns=columns,
            natural_language_query=sql_query
        )

        return intent
--------------------------------------------------------------------------------
/cognidb/ai/cost_tracker.py:
--------------------------------------------------------------------------------
1 | """Cost tracking for LLM usage."""
2 |
3 | import time
4 | from datetime import datetime, timedelta
5 | from typing import Dict, Any, List, Optional
6 | from collections import defaultdict
7 | import json
8 | from pathlib import Path
9 |
10 |
class CostTracker:
    """
    Tracks LLM usage and costs.

    Features:
    - Token usage tracking
    - Cost calculation and limits
    - Daily/monthly aggregation
    - Persistent storage (best-effort JSON file; I/O errors never propagate)
    """

    def __init__(self,
                 max_daily_cost: float = 100.0,
                 storage_path: Optional[str] = None):
        """
        Initialize cost tracker.

        Args:
            max_daily_cost: Maximum allowed daily cost (same currency units
                as the costs passed to track_usage)
            storage_path: Path to store usage data
                (default: ~/.cognidb/usage.json)
        """
        self.max_daily_cost = max_daily_cost
        self.storage_path = storage_path or str(
            Path.home() / '.cognidb' / 'usage.json'
        )

        # Usage keyed by 'YYYY-MM-DD'; missing dates get a fresh record with
        # every key present (see _empty_day).
        self.usage_data = defaultdict(self._empty_day)

        # Load existing data (best-effort; corrupt files start fresh).
        self._load_usage_data()

    @staticmethod
    def _empty_day() -> Dict[str, Any]:
        """Return a fresh per-day usage record with every key present."""
        return {
            'requests': 0,
            'tokens': {'prompt': 0, 'completion': 0, 'total': 0},
            'cost': 0.0,
            'models': defaultdict(int)
        }

    def track_usage(self,
                    cost: float,
                    token_usage: Dict[str, int],
                    model: Optional[str] = None) -> None:
        """
        Track usage for a request and persist immediately.

        Args:
            cost: Cost of the request
            token_usage: Token usage statistics; recognized keys are
                'prompt_tokens', 'completion_tokens', 'total_tokens'
                (missing keys count as 0)
            model: Model used (optional; counted per model when given)
        """
        today = datetime.now().strftime('%Y-%m-%d')

        # Update daily stats (the defaultdict creates today's record).
        self.usage_data[today]['requests'] += 1
        self.usage_data[today]['cost'] += cost
        self.usage_data[today]['tokens']['prompt'] += token_usage.get('prompt_tokens', 0)
        self.usage_data[today]['tokens']['completion'] += token_usage.get('completion_tokens', 0)
        self.usage_data[today]['tokens']['total'] += token_usage.get('total_tokens', 0)

        if model:
            self.usage_data[today]['models'][model] += 1

        self._save_usage_data()

    def get_daily_cost(self, date: Optional[str] = None) -> float:
        """
        Get cost for a specific day.

        Args:
            date: Date in YYYY-MM-DD format (default: today)

        Returns:
            Daily cost (0.0 for days with no recorded usage)
        """
        if date is None:
            date = datetime.now().strftime('%Y-%m-%d')

        # .get avoids creating an empty record for unseen dates.
        return self.usage_data.get(date, {}).get('cost', 0.0)

    def get_monthly_cost(self, year: int, month: int) -> float:
        """
        Get cost for a specific month.

        Args:
            year: Year
            month: Month (1-12)

        Returns:
            Monthly cost (sum over all recorded days of that month)
        """
        total_cost = 0.0
        month_str = f"{year:04d}-{month:02d}"

        for date, data in self.usage_data.items():
            if date.startswith(month_str):
                total_cost += data.get('cost', 0.0)

        return total_cost

    def get_total_cost(self) -> float:
        """Get total cost across all recorded days."""
        return sum(data.get('cost', 0.0) for data in self.usage_data.values())

    def get_token_usage(self, date: Optional[str] = None) -> Dict[str, int]:
        """
        Get token usage for a specific day.

        Args:
            date: Date in YYYY-MM-DD format (default: today)

        Returns:
            Token usage statistics ('prompt', 'completion', 'total'). A copy
            is returned so callers cannot mutate internal state.
        """
        if date is None:
            date = datetime.now().strftime('%Y-%m-%d')

        tokens = self.usage_data.get(date, {}).get('tokens')
        if tokens is None:
            return {'prompt': 0, 'completion': 0, 'total': 0}
        return dict(tokens)

    def is_limit_exceeded(self, date: Optional[str] = None) -> bool:
        """
        Check if daily cost limit is reached or exceeded.

        Args:
            date: Date to check (default: today)

        Returns:
            True if the day's cost is >= max_daily_cost
        """
        return self.get_daily_cost(date) >= self.max_daily_cost

    def get_remaining_budget(self, date: Optional[str] = None) -> float:
        """
        Get remaining budget for the day.

        Args:
            date: Date to check (default: today)

        Returns:
            Remaining budget, clamped at 0 (never negative)
        """
        return max(0, self.max_daily_cost - self.get_daily_cost(date))

    def get_usage_summary(self, days: int = 7) -> Dict[str, Any]:
        """
        Get usage summary for recent days (today inclusive).

        Args:
            days: Number of days to include (must be >= 1)

        Returns:
            Usage summary with totals, per-day breakdown and per-model counts

        Raises:
            ValueError: If days < 1 (previously this caused a
                ZeroDivisionError when computing the daily average)
        """
        if days < 1:
            raise ValueError("days must be >= 1")

        end_date = datetime.now()
        start_date = end_date - timedelta(days=days - 1)

        summary = {
            'period': f"{start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')}",
            'total_cost': 0.0,
            'total_requests': 0,
            'total_tokens': 0,
            'daily_breakdown': [],
            'model_usage': defaultdict(int)
        }

        current_date = start_date
        while current_date <= end_date:
            date_str = current_date.strftime('%Y-%m-%d')
            if date_str in self.usage_data:
                data = self.usage_data[date_str]
                summary['total_cost'] += data['cost']
                summary['total_requests'] += data['requests']
                summary['total_tokens'] += data['tokens']['total']

                summary['daily_breakdown'].append({
                    'date': date_str,
                    'cost': data['cost'],
                    'requests': data['requests'],
                    'tokens': data['tokens']['total']
                })

                for model, count in data['models'].items():
                    summary['model_usage'][model] += count

            current_date += timedelta(days=1)

        summary['model_usage'] = dict(summary['model_usage'])
        summary['average_daily_cost'] = summary['total_cost'] / days
        summary['average_cost_per_request'] = (
            summary['total_cost'] / summary['total_requests']
            if summary['total_requests'] > 0 else 0
        )

        return summary

    def cleanup_old_data(self, days_to_keep: int = 90) -> None:
        """
        Remove usage data older than specified days.

        Args:
            days_to_keep: Number of days of data to keep
        """
        cutoff_date = datetime.now() - timedelta(days=days_to_keep)
        cutoff_str = cutoff_date.strftime('%Y-%m-%d')

        # Lexicographic comparison is correct for zero-padded ISO dates.
        dates_to_remove = [
            date for date in self.usage_data.keys()
            if date < cutoff_str
        ]

        for date in dates_to_remove:
            del self.usage_data[date]

        if dates_to_remove:
            self._save_usage_data()

    def export_usage_report(self,
                            start_date: str,
                            end_date: str,
                            format: str = 'json') -> str:
        """
        Export usage report for date range (inclusive on both ends).

        Args:
            start_date: Start date (YYYY-MM-DD)
            end_date: End date (YYYY-MM-DD)
            format: Export format ('json' or 'csv'; CSV omits the per-model
                breakdown because it does not flatten to fixed columns)

        Returns:
            Formatted report

        Raises:
            ValueError: For unsupported formats or unparseable dates
        """
        report_data = []

        current = datetime.strptime(start_date, '%Y-%m-%d')
        end = datetime.strptime(end_date, '%Y-%m-%d')

        while current <= end:
            date_str = current.strftime('%Y-%m-%d')
            if date_str in self.usage_data:
                data = self.usage_data[date_str]
                report_data.append({
                    'date': date_str,
                    'requests': data['requests'],
                    'cost': data['cost'],
                    'prompt_tokens': data['tokens']['prompt'],
                    'completion_tokens': data['tokens']['completion'],
                    'total_tokens': data['tokens']['total'],
                    'models': dict(data['models'])
                })
            current += timedelta(days=1)

        if format == 'json':
            return json.dumps(report_data, indent=2)
        elif format == 'csv':
            import csv
            import io

            output = io.StringIO()
            if report_data:
                writer = csv.DictWriter(
                    output,
                    fieldnames=['date', 'requests', 'cost', 'prompt_tokens',
                                'completion_tokens', 'total_tokens']
                )
                writer.writeheader()
                for row in report_data:
                    # Drop the nested models field (not a flat column).
                    row_copy = row.copy()
                    del row_copy['models']
                    writer.writerow(row_copy)

            return output.getvalue()
        else:
            raise ValueError(f"Unsupported format: {format}")

    def _load_usage_data(self) -> None:
        """Load usage data from storage (best-effort).

        Every loaded record is normalized so that all expected keys exist.
        Without this, a legacy/partial file (e.g. missing 'models') made
        track_usage and _save_usage_data raise KeyError — and since saving
        swallows exceptions, new usage was silently never persisted.
        """
        try:
            if not Path(self.storage_path).exists():
                return
            with open(self.storage_path, 'r') as f:
                data = json.load(f)
        except Exception:
            # Unreadable or corrupt file: start fresh rather than fail.
            return

        try:
            for date, usage in data.items():
                tokens = usage.get('tokens', {})
                self.usage_data[date] = {
                    'requests': usage.get('requests', 0),
                    'tokens': {
                        'prompt': tokens.get('prompt', 0),
                        'completion': tokens.get('completion', 0),
                        'total': tokens.get('total', 0),
                    },
                    'cost': usage.get('cost', 0.0),
                    'models': defaultdict(int, usage.get('models', {})),
                }
        except Exception:
            # Structurally unexpected content: keep whatever loaded cleanly.
            pass

    def _save_usage_data(self) -> None:
        """Save usage data to storage (best-effort; never raises)."""
        try:
            # Ensure directory exists.
            Path(self.storage_path).parent.mkdir(parents=True, exist_ok=True)

            # Convert defaultdicts to plain dicts for JSON serialization.
            data_to_save = {}
            for date, usage in self.usage_data.items():
                data_to_save[date] = {
                    'requests': usage.get('requests', 0),
                    'tokens': usage.get('tokens', {}),
                    'cost': usage.get('cost', 0.0),
                    'models': dict(usage.get('models', {}))
                }

            with open(self.storage_path, 'w') as f:
                json.dump(data_to_save, f, indent=2)
        except Exception:
            # Deliberate best-effort: a failed save must not fail the
            # request that triggered it.
            pass
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | CogniDB - Secure Natural Language Database Interface
3 |
4 | A production-ready natural language to SQL interface with comprehensive
5 | security, multi-database support, and intelligent query generation.
6 | """
7 |
8 | __version__ = "2.0.0"
9 | __author__ = "CogniDB Team"
10 |
11 | import logging
12 | from typing import Dict, Any, Optional, List, Union
13 | from pathlib import Path
14 |
15 | # Core imports
16 | from .cognidb.core.exceptions import CogniDBError
17 | from .cognidb.config import ConfigLoader, Settings, DatabaseType
18 | from .cognidb.security import QuerySecurityValidator, AccessController, InputSanitizer
19 | from .cognidb.ai import LLMManager, QueryGenerator
20 | from .cognidb.drivers import (
21 | MySQLDriver,
22 | PostgreSQLDriver,
23 | MongoDBDriver,
24 | DynamoDBDriver,
25 | SQLiteDriver
26 | )
27 |
# Module logger. Deliberately NOT calling logging.basicConfig() here:
# libraries must not configure the root logger at import time, because that
# silently overrides the host application's logging setup. Applications opt
# in to output themselves; the NullHandler suppresses the "no handlers could
# be found" warning when they don't (standard library-logging pattern).
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
34 |
35 |
36 | class CogniDB:
37 | """
38 | Main CogniDB interface for natural language database queries.
39 |
40 | Features:
41 | - Natural language to SQL conversion
42 | - Multi-database support (MySQL, PostgreSQL, MongoDB, DynamoDB)
43 | - Comprehensive security validation
44 | - Query optimization and caching
45 | - Cost tracking and limits
46 | - Audit logging
47 | """
48 |
    def __init__(self,
                 config_file: Optional[str] = None,
                 **kwargs):
        """
        Initialize CogniDB.

        Loads configuration, wires up the driver/security/AI/cache
        components, connects to the database, and caches its schema.
        Initialization order matters: the _init_* helpers read
        self.settings, and fetch_schema() requires a live connection.

        Args:
            config_file: Path to configuration file
            **kwargs: Override configuration values (applied onto top-level
                attributes of the loaded settings)
        """
        # Load configuration
        self.config_loader = ConfigLoader(config_file)
        self.settings = self.config_loader.load()

        # Apply any overrides before components read the settings
        self._apply_config_overrides(kwargs)

        # Initialize components (each one reads self.settings)
        self._init_driver()
        self._init_security()
        self._init_ai()
        self._init_cache()

        # Connect to database (must precede the schema fetch below)
        self.driver.connect()

        # Cache schema once; query() reuses it for every request
        self.schema = self.driver.fetch_schema()

        logger.info("CogniDB initialized successfully")
79 |
80 | def query(self,
81 | natural_language_query: str,
82 | user_id: Optional[str] = None,
83 | explain: bool = False) -> Dict[str, Any]:
84 | """
85 | Execute a natural language query.
86 |
87 | Args:
88 | natural_language_query: Query in natural language
89 | user_id: Optional user ID for access control
90 | explain: Whether to include query explanation
91 |
92 | Returns:
93 | Dictionary with results and metadata
94 | """
95 | try:
96 | # Sanitize input
97 | sanitized_query = self.input_sanitizer.sanitize_natural_language(
98 | natural_language_query
99 | )
100 |
101 | # Check access permissions
102 | if self.settings.security.enable_access_control and user_id:
103 | # This would integrate with your access control system
104 | pass
105 |
106 | # Generate SQL from natural language
107 | sql_query = self.query_generator.generate_sql(
108 | sanitized_query,
109 | self.schema,
110 | examples=self.settings.llm.few_shot_examples
111 | )
112 |
113 | # Validate generated SQL
114 | is_valid, error = self.security_validator.validate_native_query(sql_query)
115 | if not is_valid:
116 | raise CogniDBError(f"Security validation failed: {error}")
117 |
118 | # Execute query
119 | results = self.driver.execute_native_query(sql_query)
120 |
121 | # Prepare response
122 | response = {
123 | 'success': True,
124 | 'query': natural_language_query,
125 | 'sql': sql_query,
126 | 'results': results,
127 | 'row_count': len(results),
128 | 'execution_time': None # Would be tracked by driver
129 | }
130 |
131 | # Add explanation if requested
132 | if explain:
133 | response['explanation'] = self.query_generator.explain_query(
134 | sql_query,
135 | self.schema
136 | )
137 |
138 | # Log query for audit
139 | self._audit_log(user_id, natural_language_query, sql_query, True)
140 |
141 | return response
142 |
143 | except Exception as e:
144 | logger.error(f"Query execution failed: {str(e)}")
145 | self._audit_log(user_id, natural_language_query, None, False, str(e))
146 |
147 | return {
148 | 'success': False,
149 | 'query': natural_language_query,
150 | 'error': str(e)
151 | }
152 |
153 | def optimize_query(self, sql_query: str) -> Dict[str, Any]:
154 | """
155 | Get optimization suggestions for a SQL query.
156 |
157 | Args:
158 | sql_query: SQL query to optimize
159 |
160 | Returns:
161 | Dictionary with optimized query and explanation
162 | """
163 | try:
164 | optimized_sql, explanation = self.query_generator.optimize_query(
165 | sql_query,
166 | self.schema
167 | )
168 |
169 | return {
170 | 'success': True,
171 | 'original_query': sql_query,
172 | 'optimized_query': optimized_sql,
173 | 'explanation': explanation
174 | }
175 |
176 | except Exception as e:
177 | logger.error(f"Query optimization failed: {str(e)}")
178 | return {
179 | 'success': False,
180 | 'error': str(e)
181 | }
182 |
183 | def suggest_queries(self, partial_query: str) -> List[str]:
184 | """
185 | Get query suggestions based on partial input.
186 |
187 | Args:
188 | partial_query: Partial natural language query
189 |
190 | Returns:
191 | List of suggested queries
192 | """
193 | try:
194 | return self.query_generator.suggest_queries(
195 | partial_query,
196 | self.schema
197 | )
198 | except Exception as e:
199 | logger.error(f"Failed to generate suggestions: {str(e)}")
200 | return []
201 |
202 | def get_schema(self, table_name: Optional[str] = None) -> Dict[str, Any]:
203 | """
204 | Get database schema information.
205 |
206 | Args:
207 | table_name: Optional specific table name
208 |
209 | Returns:
210 | Schema information
211 | """
212 | if table_name:
213 | return {
214 | table_name: self.schema.get(table_name, {})
215 | }
216 | return self.schema
217 |
    def get_usage_stats(self) -> Dict[str, Any]:
        """Get usage statistics including costs.

        Delegates to the LLM manager, which owns token/cost accounting.
        """
        return self.llm_manager.get_usage_stats()
221 |
    def close(self):
        """Close all connections and cleanup resources.

        Safe to call even if initialization failed before the driver was
        created (hence the hasattr guard); any error raised while
        disconnecting is logged rather than propagated, so shutdown paths
        never blow up.
        """
        try:
            if hasattr(self, 'driver'):
                self.driver.disconnect()
            logger.info("CogniDB closed successfully")
        except Exception as e:
            logger.error(f"Error closing CogniDB: {str(e)}")
230 |
    def __enter__(self):
        """Context manager entry: returns this (already-connected) instance."""
        return self
234 |
    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: always closes; returns None so any
        in-flight exception propagates to the caller."""
        self.close()
238 |
239 | # Private methods
240 |
def _init_driver(self):
    """Instantiate the database driver matching the configured DB type.

    Raises:
        CogniDBError: If the configured database type has no driver.
    """
    available_drivers = {
        DatabaseType.MYSQL: MySQLDriver,
        DatabaseType.POSTGRESQL: PostgreSQLDriver,
        DatabaseType.MONGODB: MongoDBDriver,
        DatabaseType.DYNAMODB: DynamoDBDriver,
        DatabaseType.SQLITE: SQLiteDriver,
    }

    db = self.settings.database
    driver_class = available_drivers.get(db.type)
    if driver_class is None:
        raise CogniDBError(f"Unsupported database type: {self.settings.database.type}")

    # Flatten the typed settings into the plain dict the drivers expect.
    # Free-form options are merged last, so they may override the
    # explicit keys above.
    driver_config = {
        'host': db.host,
        'port': db.port,
        'database': db.database,
        'username': db.username,
        'password': db.password,
        'ssl_enabled': db.ssl_enabled,
        'query_timeout': db.query_timeout,
        'max_result_size': db.max_result_size,
        **db.options
    }

    self.driver = driver_class(driver_config)
269 |
def _init_security(self):
    """Wire up query validation, access control and input sanitization."""
    security = self.settings.security
    # Only read-only statements are ever allowed through the validator.
    self.security_validator = QuerySecurityValidator(
        allowed_operations=['SELECT'],
        max_query_complexity=security.max_query_complexity,
        allow_subqueries=security.allow_subqueries
    )
    self.access_controller = AccessController()
    self.input_sanitizer = InputSanitizer()
280 |
def _init_ai(self):
    """Construct the LLM manager and the natural-language query generator."""
    # The cache provider is attached later, once _init_cache has run.
    self.llm_manager = LLMManager(
        self.settings.llm,
        cache_provider=None
    )
    self.query_generator = QueryGenerator(
        self.llm_manager,
        database_type=self.settings.database.type.value
    )
294 |
def _init_cache(self):
    """Initialize caching layer.

    Currently a no-op: the LLM manager's built-in in-memory cache is
    used. A Redis/Memcached client would be created here in production.
    """
    pass
300 |
def _apply_config_overrides(self, overrides: Dict[str, Any]):
    """Apply configuration overrides onto the loaded Settings object.

    Args:
        overrides: Mapping of Settings attribute name to new value.
            Keys that do not name an existing Settings attribute are
            logged and skipped (previously they were dropped silently,
            which hid typos in override names).
    """
    for key, value in overrides.items():
        if hasattr(self.settings, key):
            setattr(self.settings, key, value)
        else:
            # Surface typos instead of swallowing them.
            logger.warning(f"Ignoring unknown config override: {key}")
306 |
def _audit_log(self,
               user_id: Optional[str],
               natural_language_query: str,
               sql_query: Optional[str],
               success: bool,
               error: Optional[str] = None):
    """Log a query attempt for the audit trail (no-op when disabled).

    Args:
        user_id: Identifier of the requesting user, if any.
        natural_language_query: Raw NL query; truncated to 50 chars in
            the log line.
        sql_query: Generated SQL. NOTE(review): accepted but never
            written to the log — confirm whether this is intentional.
        success: Whether the query pipeline completed successfully.
        error: Error description when success is False.
    """
    # Honour the security switch: emit nothing when auditing is off.
    if not self.settings.security.enable_audit_logging:
        return

    # In production, this would write to a proper audit log.
    # The "..." suffix is appended even for short queries.
    logger.info(
        f"AUDIT: user={user_id}, query={natural_language_query[:50]}..., "
        f"success={success}, error={error}"
    )
322 |
323 |
324 | # Convenience function for quick usage
def create_cognidb(**kwargs) -> CogniDB:
    """Build and return a configured CogniDB instance.

    Args:
        **kwargs: Forwarded verbatim to the CogniDB constructor.

    Returns:
        A ready-to-use CogniDB instance.
    """
    instance = CogniDB(**kwargs)
    return instance
--------------------------------------------------------------------------------
/cognidb/config/secrets.py:
--------------------------------------------------------------------------------
1 | """Secrets management for sensitive configuration."""
2 |
3 | import os
4 | import json
5 | import base64
6 | from typing import Dict, Any, Optional
7 | from pathlib import Path
8 | from cryptography.fernet import Fernet
9 | from cryptography.hazmat.primitives import hashes
10 | from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
11 | from ..core.exceptions import ConfigurationError
12 |
13 |
class SecretsManager:
    """
    Secure secrets management.

    Supports:
    - Environment variables
    - Encrypted file storage
    - AWS Secrets Manager
    - HashiCorp Vault
    - Azure Key Vault
    """

    # Provider names accepted by __init__.
    SUPPORTED_PROVIDERS = ('env', 'file', 'aws', 'vault', 'azure')

    def __init__(self, provider: str = "env", **kwargs):
        """
        Initialize secrets manager.

        Args:
            provider: One of 'env', 'file', 'aws', 'vault', 'azure'
            **kwargs: Provider-specific configuration

        Raises:
            ConfigurationError: If `provider` is not a supported name, or
                the chosen provider fails to initialize.
        """
        if provider not in self.SUPPORTED_PROVIDERS:
            # Fail fast on typos instead of returning a manager that
            # silently resolves every secret to None.
            raise ConfigurationError(f"Unknown secrets provider: {provider}")

        self.provider = provider
        self.config = kwargs
        self._cache: Dict[str, Any] = {}
        # Fernet cipher; created lazily by _init_file_provider.
        self._cipher = None

        if provider == "file":
            self._init_file_provider()
        elif provider == "aws":
            self._init_aws_provider()
        elif provider == "vault":
            self._init_vault_provider()
        elif provider == "azure":
            self._init_azure_provider()

    def get_secret(self, key: str, default: Any = None) -> Any:
        """
        Retrieve a secret value.

        Args:
            key: Secret key
            default: Default value if not found

        Returns:
            Secret value, or `default` when the provider has no entry.
        """
        # Serve repeated lookups from the cache.
        if key in self._cache:
            return self._cache[key]

        # Ask the provider with no default so a miss is distinguishable
        # from a stored value. Previously the caller's default was cached
        # under the key, so a later call with a *different* default would
        # incorrectly get the first caller's default back.
        if self.provider == "env":
            value = os.getenv(key)
        elif self.provider == "file":
            value = self._get_file_secret(key, None)
        elif self.provider == "aws":
            value = self._get_aws_secret(key, None)
        elif self.provider == "vault":
            value = self._get_vault_secret(key, None)
        elif self.provider == "azure":
            value = self._get_azure_secret(key, None)
        else:  # pragma: no cover - __init__ rejects unknown providers
            value = None

        if value is None:
            # Defaults are never cached.
            return default

        self._cache[key] = value
        return value

    def set_secret(self, key: str, value: Any) -> None:
        """
        Store a secret value.

        Args:
            key: Secret key
            value: Secret value
        """
        if self.provider == "env":
            # Environment variables are always strings; cache the same
            # canonical form a later get_secret() would return.
            value = str(value)
            os.environ[key] = value
        elif self.provider == "file":
            self._set_file_secret(key, value)
        elif self.provider == "aws":
            self._set_aws_secret(key, value)
        elif self.provider == "vault":
            self._set_vault_secret(key, value)
        elif self.provider == "azure":
            self._set_azure_secret(key, value)

        # Update cache
        self._cache[key] = value

    def delete_secret(self, key: str) -> None:
        """Delete a secret from the active provider and the local cache."""
        if self.provider == "env":
            os.environ.pop(key, None)
        elif self.provider == "file":
            self._delete_file_secret(key)
        elif self.provider == "aws":
            self._delete_aws_secret(key)
        elif self.provider == "vault":
            self._delete_vault_secret(key)
        elif self.provider == "azure":
            self._delete_azure_secret(key)

        # Remove from cache
        self._cache.pop(key, None)

    def clear_cache(self) -> None:
        """Clear the secrets cache so the next lookup hits the provider."""
        self._cache.clear()

    # File-based provider methods

    def _init_file_provider(self) -> None:
        """Initialize file-based secrets provider.

        Derives a Fernet key from the master password (config key
        'master_password' or env COGNIDB_MASTER_PASSWORD) and loads the
        encrypted secrets file into memory.

        Raises:
            ConfigurationError: If no master password is available or
                the secrets file cannot be decrypted.
        """
        secrets_file = self.config.get('secrets_file',
                                       str(Path.home() / '.cognidb' / 'secrets.enc'))
        master_password = self.config.get('master_password')

        if not master_password:
            master_password = os.getenv('COGNIDB_MASTER_PASSWORD')
            if not master_password:
                raise ConfigurationError(
                    "Master password required for file-based secrets"
                )

        # Create directory if needed
        Path(secrets_file).parent.mkdir(parents=True, exist_ok=True)

        # Generate encryption key from password.
        # NOTE(review): the static salt keeps existing secret files
        # decryptable; a per-file random salt stored beside the
        # ciphertext would be stronger — TODO plan a migration.
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=b'cognidb_salt',  # In production, use random salt
            iterations=100000,
        )
        key = base64.urlsafe_b64encode(kdf.derive(master_password.encode()))
        self._cipher = Fernet(key)

        self.secrets_file = secrets_file
        self._load_secrets_file()

    def _load_secrets_file(self) -> None:
        """Load secrets from the encrypted file into `_secrets_data`."""
        if not Path(self.secrets_file).exists():
            self._secrets_data = {}
            return

        try:
            with open(self.secrets_file, 'rb') as f:
                encrypted_data = f.read()

            decrypted_data = self._cipher.decrypt(encrypted_data)
            self._secrets_data = json.loads(decrypted_data.decode())
        except Exception as e:
            raise ConfigurationError(f"Failed to load secrets file: {e}")

    def _save_secrets_file(self) -> None:
        """Encrypt `_secrets_data` and write it back to disk."""
        try:
            data = json.dumps(self._secrets_data).encode()
            encrypted_data = self._cipher.encrypt(data)

            with open(self.secrets_file, 'wb') as f:
                f.write(encrypted_data)
        except Exception as e:
            raise ConfigurationError(f"Failed to save secrets file: {e}")

    def _get_file_secret(self, key: str, default: Any) -> Any:
        """Get secret from the in-memory copy of the encrypted file."""
        return self._secrets_data.get(key, default)

    def _set_file_secret(self, key: str, value: Any) -> None:
        """Set secret and persist the encrypted file immediately."""
        self._secrets_data[key] = value
        self._save_secrets_file()

    def _delete_file_secret(self, key: str) -> None:
        """Delete secret and persist the encrypted file immediately."""
        self._secrets_data.pop(key, None)
        self._save_secrets_file()

    # AWS Secrets Manager provider methods

    def _init_aws_provider(self) -> None:
        """Initialize AWS Secrets Manager provider (requires boto3)."""
        try:
            import boto3
            self._aws_client = boto3.client(
                'secretsmanager',
                region_name=self.config.get('region', 'us-east-1')
            )
        except ImportError:
            raise ConfigurationError(
                "boto3 required for AWS Secrets Manager. Install with: pip install boto3"
            )

    def _get_aws_secret(self, key: str, default: Any) -> Any:
        """Get secret from AWS Secrets Manager.

        String secrets holding valid JSON are returned parsed; binary
        secrets are returned as raw bytes.
        """
        try:
            response = self._aws_client.get_secret_value(SecretId=key)
            if 'SecretString' in response:
                secret = response['SecretString']
                try:
                    return json.loads(secret)
                except json.JSONDecodeError:
                    return secret
            else:
                return base64.b64decode(response['SecretBinary'])
        except self._aws_client.exceptions.ResourceNotFoundException:
            return default
        except Exception as e:
            raise ConfigurationError(f"Failed to get AWS secret: {e}")

    def _set_aws_secret(self, key: str, value: Any) -> None:
        """Set secret in AWS Secrets Manager (update, create on miss)."""
        try:
            secret_string = json.dumps(value) if not isinstance(value, str) else value
            try:
                self._aws_client.update_secret(
                    SecretId=key,
                    SecretString=secret_string
                )
            except self._aws_client.exceptions.ResourceNotFoundException:
                self._aws_client.create_secret(
                    Name=key,
                    SecretString=secret_string
                )
        except Exception as e:
            raise ConfigurationError(f"Failed to set AWS secret: {e}")

    def _delete_aws_secret(self, key: str) -> None:
        """Delete secret from AWS Secrets Manager, skipping the recovery window."""
        try:
            self._aws_client.delete_secret(
                SecretId=key,
                ForceDeleteWithoutRecovery=True
            )
        except Exception as e:
            raise ConfigurationError(f"Failed to delete AWS secret: {e}")

    # HashiCorp Vault provider methods

    def _init_vault_provider(self) -> None:
        """Initialize HashiCorp Vault provider (requires hvac)."""
        try:
            import hvac
            self._vault_client = hvac.Client(
                url=self.config.get('url', 'http://localhost:8200'),
                token=self.config.get('token') or os.getenv('VAULT_TOKEN')
            )
            if not self._vault_client.is_authenticated():
                raise ConfigurationError("Vault authentication failed")
        except ImportError:
            raise ConfigurationError(
                "hvac required for HashiCorp Vault. Install with: pip install hvac"
            )

    def _get_vault_secret(self, key: str, default: Any) -> Any:
        """Get secret from the Vault KV v2 engine (whole data dict)."""
        try:
            mount_point = self.config.get('mount_point', 'secret')
            response = self._vault_client.secrets.kv.v2.read_secret_version(
                path=key,
                mount_point=mount_point
            )
            return response['data']['data']
        except Exception:
            # Missing paths and transport errors alike fall back to default.
            return default

    def _set_vault_secret(self, key: str, value: Any) -> None:
        """Set secret in Vault; non-dict values are wrapped as {'value': ...}."""
        try:
            mount_point = self.config.get('mount_point', 'secret')
            self._vault_client.secrets.kv.v2.create_or_update_secret(
                path=key,
                secret=dict(value=value) if not isinstance(value, dict) else value,
                mount_point=mount_point
            )
        except Exception as e:
            raise ConfigurationError(f"Failed to set Vault secret: {e}")

    def _delete_vault_secret(self, key: str) -> None:
        """Permanently delete a secret (metadata and all versions) from Vault."""
        try:
            mount_point = self.config.get('mount_point', 'secret')
            self._vault_client.secrets.kv.v2.delete_metadata_and_all_versions(
                path=key,
                mount_point=mount_point
            )
        except Exception as e:
            raise ConfigurationError(f"Failed to delete Vault secret: {e}")

    # Azure Key Vault provider methods

    def _init_azure_provider(self) -> None:
        """Initialize Azure Key Vault provider (requires azure SDKs)."""
        try:
            from azure.keyvault.secrets import SecretClient
            from azure.identity import DefaultAzureCredential

            vault_url = self.config.get('vault_url')
            if not vault_url:
                raise ConfigurationError("vault_url required for Azure Key Vault")

            credential = DefaultAzureCredential()
            self._azure_client = SecretClient(
                vault_url=vault_url,
                credential=credential
            )
        except ImportError:
            raise ConfigurationError(
                "azure-keyvault-secrets required. Install with: "
                "pip install azure-keyvault-secrets azure-identity"
            )

    def _get_azure_secret(self, key: str, default: Any) -> Any:
        """Get secret value (always a string) from Azure Key Vault."""
        try:
            secret = self._azure_client.get_secret(key)
            return secret.value
        except Exception:
            return default

    def _set_azure_secret(self, key: str, value: Any) -> None:
        """Set secret in Azure Key Vault; non-strings are JSON-encoded."""
        try:
            value_str = json.dumps(value) if not isinstance(value, str) else value
            self._azure_client.set_secret(key, value_str)
        except Exception as e:
            raise ConfigurationError(f"Failed to set Azure secret: {e}")

    def _delete_azure_secret(self, key: str) -> None:
        """Delete secret from Azure Key Vault, waiting for completion."""
        try:
            poller = self._azure_client.begin_delete_secret(key)
            poller.wait()
        except Exception as e:
            raise ConfigurationError(f"Failed to delete Azure secret: {e}")
--------------------------------------------------------------------------------
/cognidb/config/loader.py:
--------------------------------------------------------------------------------
1 | """Configuration loader with multiple source support."""
2 |
3 | import os
4 | import json
5 | import yaml
6 | from pathlib import Path
7 | from typing import Dict, Any, Optional, List
8 | from .settings import Settings, DatabaseConfig, LLMConfig, CacheConfig, SecurityConfig
9 | from .settings import DatabaseType, LLMProvider, CacheProvider
10 | from .secrets import SecretsManager
11 | from ..core.exceptions import ConfigurationError
12 |
13 |
class ConfigLoader:
    """
    Load configuration from multiple sources.

    Priority order (highest to lowest):
    1. Environment variables
    2. Config file (JSON/YAML)
    3. Defaults
    """

    def __init__(self,
                 config_file: Optional[str] = None,
                 secrets_manager: Optional[SecretsManager] = None):
        """
        Initialize config loader.

        Args:
            config_file: Path to configuration file; when omitted, the
                standard locations are searched (see _find_config_file).
            secrets_manager: Secrets manager instance; defaults to the
                environment-variable provider.
        """
        self.config_file = config_file or self._find_config_file()
        self.secrets_manager = secrets_manager or SecretsManager()
        self._config_data: Dict[str, Any] = {}

    def load(self) -> Settings:
        """
        Load configuration from all sources.

        Returns:
            Settings object

        Raises:
            ConfigurationError: If a source cannot be read/parsed or the
                merged configuration fails validation.
        """
        # Load from file if exists
        if self.config_file and Path(self.config_file).exists():
            self._load_from_file()

        # Environment variables override file values
        self._load_from_env()

        # Fill remaining sensitive values from the secrets manager
        self._load_secrets()

        # Create settings object
        settings = self._create_settings()

        # Validate before handing settings to callers
        errors = settings.validate()
        if errors:
            raise ConfigurationError(f"Configuration errors: {', '.join(errors)}")

        return settings

    def _find_config_file(self) -> Optional[str]:
        """Find a configuration file in standard locations.

        Returns:
            Path of the first existing candidate, or None.
        """
        # Explicit path via environment variable wins
        if 'COGNIDB_CONFIG' in os.environ:
            return os.environ['COGNIDB_CONFIG']

        # Check standard locations, most local first
        locations = [
            'cognidb.yaml',
            'cognidb.yml',
            'cognidb.json',
            '.cognidb.yaml',
            '.cognidb.yml',
            '.cognidb.json',
            str(Path.home() / '.cognidb' / 'config.yaml'),
            str(Path.home() / '.cognidb' / 'config.yml'),
            str(Path.home() / '.cognidb' / 'config.json'),
            '/etc/cognidb/config.yaml',
            '/etc/cognidb/config.yml',
            '/etc/cognidb/config.json',
        ]

        for location in locations:
            if Path(location).exists():
                return location

        return None

    def _load_from_file(self) -> None:
        """Load configuration from the config file.

        Raises:
            ConfigurationError: If the file cannot be read/parsed, or does
                not contain a mapping.
        """
        try:
            with open(self.config_file, 'r') as f:
                if self.config_file.endswith(('.yaml', '.yml')):
                    data = yaml.safe_load(f)
                else:
                    data = json.load(f)
        except Exception as e:
            raise ConfigurationError(f"Failed to load config file: {e}")

        # yaml.safe_load returns None for an empty document, which would
        # break every nested lookup later; a scalar/list document is a
        # user error worth reporting explicitly.
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise ConfigurationError(
                f"Config file must contain a mapping, got {type(data).__name__}"
            )
        self._config_data = data

    def _load_from_env(self) -> None:
        """Overlay configuration values from environment variables.

        Raises:
            ConfigurationError: If DB_PORT is not an integer.
        """
        # Database settings
        if 'DB_TYPE' in os.environ:
            self._set_nested('database.type', os.environ['DB_TYPE'])
        if 'DB_HOST' in os.environ:
            self._set_nested('database.host', os.environ['DB_HOST'])
        if 'DB_PORT' in os.environ:
            try:
                self._set_nested('database.port', int(os.environ['DB_PORT']))
            except ValueError:
                # A clear config error beats an unhandled ValueError mid-load.
                raise ConfigurationError(
                    f"DB_PORT must be an integer, got: {os.environ['DB_PORT']!r}"
                )
        if 'DB_NAME' in os.environ:
            self._set_nested('database.database', os.environ['DB_NAME'])
        if 'DB_USER' in os.environ:
            self._set_nested('database.username', os.environ['DB_USER'])
        if 'DB_PASSWORD' in os.environ:
            self._set_nested('database.password', os.environ['DB_PASSWORD'])

        # LLM settings
        if 'LLM_PROVIDER' in os.environ:
            self._set_nested('llm.provider', os.environ['LLM_PROVIDER'])
        if 'LLM_API_KEY' in os.environ:
            self._set_nested('llm.api_key', os.environ['LLM_API_KEY'])
        if 'LLM_MODEL' in os.environ:
            self._set_nested('llm.model_name', os.environ['LLM_MODEL'])

        # Cache settings
        if 'CACHE_PROVIDER' in os.environ:
            self._set_nested('cache.provider', os.environ['CACHE_PROVIDER'])

        # Security settings
        if 'SECURITY_SELECT_ONLY' in os.environ:
            self._set_nested('security.allow_only_select',
                           os.environ['SECURITY_SELECT_ONLY'].lower() == 'true')

        # Application settings
        if 'ENVIRONMENT' in os.environ:
            self._set_nested('environment', os.environ['ENVIRONMENT'])
        if 'DEBUG' in os.environ:
            self._set_nested('debug', os.environ['DEBUG'].lower() == 'true')
        if 'LOG_LEVEL' in os.environ:
            self._set_nested('log_level', os.environ['LOG_LEVEL'])

    def _load_secrets(self) -> None:
        """Fill sensitive values from the secrets manager.

        Only values still unset after file + env loading are fetched, so
        explicit configuration always wins over stored secrets.
        """
        # Database password
        if not self._get_nested('database.password'):
            password = self.secrets_manager.get_secret('DB_PASSWORD')
            if password:
                self._set_nested('database.password', password)

        # LLM API key
        if not self._get_nested('llm.api_key'):
            api_key = self.secrets_manager.get_secret('LLM_API_KEY')
            if api_key:
                self._set_nested('llm.api_key', api_key)

        # Redis password
        if not self._get_nested('cache.redis_password'):
            redis_password = self.secrets_manager.get_secret('REDIS_PASSWORD')
            if redis_password:
                self._set_nested('cache.redis_password', redis_password)

        # Encryption key
        if not self._get_nested('security.encryption_key'):
            encryption_key = self.secrets_manager.get_secret('ENCRYPTION_KEY')
            if encryption_key:
                self._set_nested('security.encryption_key', encryption_key)

    def _create_settings(self) -> Settings:
        """Build the typed Settings object from the merged config data."""
        # Database configuration
        db_config = self._get_nested('database', {})
        database = DatabaseConfig(
            type=DatabaseType(db_config.get('type', 'postgresql')),
            host=db_config.get('host', 'localhost'),
            port=db_config.get('port', 5432),
            database=db_config.get('database', 'cognidb'),
            username=db_config.get('username'),
            password=db_config.get('password'),
            pool_size=db_config.get('pool_size', 5),
            max_overflow=db_config.get('max_overflow', 10),
            pool_timeout=db_config.get('pool_timeout', 30),
            pool_recycle=db_config.get('pool_recycle', 3600),
            ssl_enabled=db_config.get('ssl_enabled', False),
            ssl_ca_cert=db_config.get('ssl_ca_cert'),
            ssl_client_cert=db_config.get('ssl_client_cert'),
            ssl_client_key=db_config.get('ssl_client_key'),
            query_timeout=db_config.get('query_timeout', 30),
            max_result_size=db_config.get('max_result_size', 10000),
            options=db_config.get('options', {})
        )

        # LLM configuration
        llm_config = self._get_nested('llm', {})
        llm = LLMConfig(
            provider=LLMProvider(llm_config.get('provider', 'openai')),
            api_key=llm_config.get('api_key'),
            model_name=llm_config.get('model_name', 'gpt-4'),
            temperature=llm_config.get('temperature', 0.1),
            max_tokens=llm_config.get('max_tokens', 1000),
            timeout=llm_config.get('timeout', 30),
            max_tokens_per_query=llm_config.get('max_tokens_per_query', 2000),
            max_queries_per_minute=llm_config.get('max_queries_per_minute', 60),
            max_cost_per_day=llm_config.get('max_cost_per_day', 100.0),
            system_prompt=llm_config.get('system_prompt'),
            few_shot_examples=llm_config.get('few_shot_examples', []),
            azure_endpoint=llm_config.get('azure_endpoint'),
            azure_deployment=llm_config.get('azure_deployment'),
            huggingface_model_id=llm_config.get('huggingface_model_id'),
            local_model_path=llm_config.get('local_model_path'),
            enable_function_calling=llm_config.get('enable_function_calling', True),
            enable_streaming=llm_config.get('enable_streaming', False),
            retry_attempts=llm_config.get('retry_attempts', 3),
            retry_delay=llm_config.get('retry_delay', 1.0)
        )

        # Cache configuration
        cache_config = self._get_nested('cache', {})
        cache = CacheConfig(
            provider=CacheProvider(cache_config.get('provider', 'in_memory')),
            query_result_ttl=cache_config.get('query_result_ttl', 3600),
            schema_ttl=cache_config.get('schema_ttl', 86400),
            llm_response_ttl=cache_config.get('llm_response_ttl', 7200),
            max_cache_size_mb=cache_config.get('max_cache_size_mb', 100),
            max_entry_size_mb=cache_config.get('max_entry_size_mb', 10),
            eviction_policy=cache_config.get('eviction_policy', 'lru'),
            redis_host=cache_config.get('redis_host', 'localhost'),
            redis_port=cache_config.get('redis_port', 6379),
            redis_password=cache_config.get('redis_password'),
            redis_db=cache_config.get('redis_db', 0),
            redis_ssl=cache_config.get('redis_ssl', False),
            disk_cache_path=cache_config.get('disk_cache_path',
                                            str(Path.home() / '.cognidb' / 'cache')),
            enable_compression=cache_config.get('enable_compression', True),
            enable_async_writes=cache_config.get('enable_async_writes', True)
        )

        # Security configuration
        security_config = self._get_nested('security', {})
        security = SecurityConfig(
            allow_only_select=security_config.get('allow_only_select', True),
            max_query_complexity=security_config.get('max_query_complexity', 10),
            allow_subqueries=security_config.get('allow_subqueries', False),
            allow_unions=security_config.get('allow_unions', False),
            enable_rate_limiting=security_config.get('enable_rate_limiting', True),
            rate_limit_per_minute=security_config.get('rate_limit_per_minute', 100),
            rate_limit_per_hour=security_config.get('rate_limit_per_hour', 1000),
            enable_access_control=security_config.get('enable_access_control', True),
            default_user_permissions=security_config.get('default_user_permissions', ['SELECT']),
            require_authentication=security_config.get('require_authentication', False),
            enable_audit_logging=security_config.get('enable_audit_logging', True),
            audit_log_path=security_config.get('audit_log_path',
                                              str(Path.home() / '.cognidb' / 'audit.log')),
            log_query_results=security_config.get('log_query_results', False),
            encrypt_cache=security_config.get('encrypt_cache', True),
            encrypt_logs=security_config.get('encrypt_logs', True),
            encryption_key=security_config.get('encryption_key'),
            allowed_ip_ranges=security_config.get('allowed_ip_ranges', []),
            require_ssl=security_config.get('require_ssl', True)
        )

        # Create settings
        return Settings(
            database=database,
            llm=llm,
            cache=cache,
            security=security,
            app_name=self._get_nested('app_name', 'CogniDB'),
            environment=self._get_nested('environment', 'production'),
            debug=self._get_nested('debug', False),
            log_level=self._get_nested('log_level', 'INFO'),
            data_dir=self._get_nested('data_dir', str(Path.home() / '.cognidb')),
            log_dir=self._get_nested('log_dir', str(Path.home() / '.cognidb' / 'logs')),
            enable_natural_language=self._get_nested('enable_natural_language', True),
            enable_query_explanation=self._get_nested('enable_query_explanation', True),
            enable_query_optimization=self._get_nested('enable_query_optimization', True),
            enable_auto_indexing=self._get_nested('enable_auto_indexing', False),
            enable_metrics=self._get_nested('enable_metrics', True),
            metrics_port=self._get_nested('metrics_port', 9090),
            enable_tracing=self._get_nested('enable_tracing', True),
            tracing_endpoint=self._get_nested('tracing_endpoint')
        )

    def _get_nested(self, path: str, default: Any = None) -> Any:
        """Get a dotted-path value from the merged config data."""
        keys = path.split('.')
        value = self._config_data

        for key in keys:
            if isinstance(value, dict) and key in value:
                value = value[key]
            else:
                return default

        return value

    def _set_nested(self, path: str, value: Any) -> None:
        """Set a dotted-path value, creating intermediate dicts as needed."""
        keys = path.split('.')
        target = self._config_data

        for key in keys[:-1]:
            # Replace missing or non-dict intermediates so an env override
            # can layer on top of a scalar file value instead of raising
            # a TypeError.
            if not isinstance(target.get(key), dict):
                target[key] = {}
            target = target[key]

        target[keys[-1]] = value
--------------------------------------------------------------------------------
/cognidb/drivers/postgres_driver.py:
--------------------------------------------------------------------------------
1 | """Secure PostgreSQL driver implementation."""
2 |
3 | import time
4 | import logging
5 | from typing import Dict, List, Any, Optional
6 | import psycopg2
7 | from psycopg2 import pool, sql, extras, OperationalError, DatabaseError
8 | from psycopg2.extensions import ISOLATION_LEVEL_READ_COMMITTED
9 | from .base_driver import BaseDriver
10 | from ..core.exceptions import ConnectionError, ExecutionError
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | class PostgreSQLDriver(BaseDriver):
16 | """
17 | PostgreSQL database driver with security enhancements.
18 |
19 | Features:
20 | - Connection pooling with pgbouncer support
21 | - Parameterized queries with proper escaping
22 | - SSL/TLS enforcement
23 | - Statement timeout enforcement
24 | - Prepared statements
25 | - EXPLAIN ANALYZE integration
26 | """
27 |
def __init__(self, config: Dict[str, Any]):
    """Initialize PostgreSQL driver.

    Args:
        config: Connection settings dict (host, port, database, username,
            password, ssl_*, pool_size, timeouts, max_result_size, ...).
    """
    super().__init__(config)
    self.pool = None  # ThreadedConnectionPool, created in connect()
    self._prepared_statements = {}  # statement name -> SQL text, see execute_prepared()
33 |
def connect(self) -> None:
    """Establish connection to PostgreSQL database.

    Creates a threaded connection pool, checks out one connection as the
    driver's working connection, and configures session parameters on it.

    Raises:
        ConnectionError: If the database cannot be reached.
    """
    try:
        # Prepare connection config; the server-side statement timeout is
        # passed via startup options so every statement on this session
        # is bounded.
        conn_params = {
            'host': self.config['host'],
            'port': self.config.get('port', 5432),
            'database': self.config['database'],
            'user': self.config.get('username'),
            'password': self.config.get('password'),
            'connect_timeout': self.config.get('connection_timeout', 10),
            'application_name': 'CogniDB',
            'options': f"-c statement_timeout={self.config.get('query_timeout', 30)}s"
        }

        # SSL configuration: on by default, with optional CA / client certs
        if self.config.get('ssl_enabled', True):
            conn_params['sslmode'] = 'require'
            if self.config.get('ssl_ca_cert'):
                conn_params['sslrootcert'] = self.config['ssl_ca_cert']
            if self.config.get('ssl_client_cert'):
                conn_params['sslcert'] = self.config['ssl_client_cert']
            if self.config.get('ssl_client_key'):
                conn_params['sslkey'] = self.config['ssl_client_key']

        # Create connection pool
        self.pool = psycopg2.pool.ThreadedConnectionPool(
            minconn=1,
            maxconn=self.config.get('pool_size', 5),
            **conn_params
        )

        # Check out a working connection (also validates the pool)
        self.connection = self.pool.getconn()
        self.connection.set_isolation_level(ISOLATION_LEVEL_READ_COMMITTED)
        self._connection_time = time.time()

        # Set additional session parameters.
        # NOTE(review): these apply only to this checked-out connection,
        # not to every connection in the pool — confirm intent.
        with self.connection.cursor() as cursor:
            cursor.execute("SET TIME ZONE 'UTC'")
            cursor.execute("SET lock_timeout = '5s'")
            cursor.execute("SET idle_in_transaction_session_timeout = '60s'")

        self.connection.commit()

        logger.info(f"Connected to PostgreSQL database: {self.config['database']}")

    except (OperationalError, DatabaseError) as e:
        logger.error(f"PostgreSQL connection failed: {str(e)}")
        raise ConnectionError(f"Failed to connect to PostgreSQL: {str(e)}")
84 |
def disconnect(self) -> None:
    """Close the database connection.

    Deallocates any prepared statements, returns the working connection
    to the pool, then closes the whole pool. Errors during teardown are
    logged, and the connection reference is always cleared.
    """
    if self.connection:
        try:
            # Clear prepared statements; each DEALLOCATE is best-effort
            for stmt_name in self._prepared_statements:
                try:
                    with self.connection.cursor() as cursor:
                        cursor.execute(f"DEALLOCATE {stmt_name}")
                except Exception:
                    pass

            self._prepared_statements.clear()

            # Return connection to pool
            if self.pool:
                self.pool.putconn(self.connection)

            logger.info("Disconnected from PostgreSQL database")

        except Exception as e:
            logger.error(f"Error closing connection: {str(e)}")
        finally:
            # Always drop the reference, even if teardown failed
            self.connection = None
            self._connection_time = None

    # Close the pool (closes every pooled connection)
    if self.pool:
        self.pool.closeall()
        self.pool = None
115 |
def _create_connection(self):
    """Check out a fresh connection from the pool.

    Raises:
        ConnectionError: If connect() has not created the pool yet.
    """
    if self.pool:
        return self.pool.getconn()
    raise ConnectionError("Connection pool not initialized")
121 |
def _close_connection(self):
    """Hand the current connection back to the pool, if both still exist."""
    if not (self.connection and self.pool):
        return
    self.pool.putconn(self.connection)
126 |
def _execute_with_timeout(self,
                         query: str,
                         params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
    """Execute query with timeout and proper parameterization.

    Args:
        query: SQL text, using psycopg2 placeholders when params is given.
        params: Values substituted server-side by psycopg2 (never by
            string formatting).

    Returns:
        For row-returning statements: list of row dicts, truncated to
        max_result_size. Otherwise: [{'affected_rows': n}] after commit.

    Raises:
        ExecutionError: On any database error (transaction is rolled back).
    """
    cursor = None

    try:
        # Ensure we have a valid connection; re-check out from the pool
        # if the previous one was closed
        if not self.connection or self.connection.closed:
            self.connection = self._create_connection()

        # Create cursor with RealDictCursor for dict results
        cursor = self.connection.cursor(cursor_factory=extras.RealDictCursor)

        # Execute query with parameters
        if params:
            # Use psycopg2's parameter substitution
            cursor.execute(query, params)
        else:
            cursor.execute(query)

        # cursor.description is set only for row-returning statements
        if cursor.description:
            results = cursor.fetchall()

            # Convert RealDictRow to regular dict
            results = [dict(row) for row in results]

            # Apply result size limit
            max_results = self.config.get('max_result_size', 10000)
            if len(results) > max_results:
                logger.warning(f"Result truncated from {len(results)} to {max_results} rows")
                results = results[:max_results]

            return results
        else:
            # For non-SELECT queries, commit and report the row count
            self.connection.commit()
            return [{'affected_rows': cursor.rowcount}]

    except (OperationalError, DatabaseError) as e:
        # Roll back so the connection is reusable after a failure
        if self.connection:
            self.connection.rollback()
        raise ExecutionError(f"Query execution failed: {str(e)}")
    finally:
        if cursor:
            cursor.close()
174 |
175 | def execute_prepared(self,
176 | name: str,
177 | query: str,
178 | params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
179 | """Execute a prepared statement for better performance."""
180 | cursor = None
181 |
182 | try:
183 | cursor = self.connection.cursor(cursor_factory=extras.RealDictCursor)
184 |
185 | # Prepare statement if not already prepared
186 | if name not in self._prepared_statements:
187 | cursor.execute(f"PREPARE {name} AS {query}")
188 | self._prepared_statements[name] = query
189 |
190 | # Execute prepared statement
191 | if params:
192 | execute_query = sql.SQL("EXECUTE {} ({})").format(
193 | sql.Identifier(name),
194 | sql.SQL(', ').join(sql.Placeholder() * len(params))
195 | )
196 | cursor.execute(execute_query, list(params.values()))
197 | else:
198 | cursor.execute(f"EXECUTE {name}")
199 |
200 | # Fetch results
201 | if cursor.description:
202 | return [dict(row) for row in cursor.fetchall()]
203 | else:
204 | self.connection.commit()
205 | return [{'affected_rows': cursor.rowcount}]
206 |
207 | except Exception as e:
208 | if self.connection:
209 | self.connection.rollback()
210 | raise ExecutionError(f"Prepared statement execution failed: {str(e)}")
211 | finally:
212 | if cursor:
213 | cursor.close()
214 |
215 | def explain_query(self, query: str, analyze: bool = False) -> Dict[str, Any]:
216 | """Get query execution plan."""
217 | explain_query = f"EXPLAIN {'ANALYZE' if analyze else ''} {query}"
218 |
219 | try:
220 | results = self._execute_with_timeout(explain_query)
221 | return {
222 | 'plan': results,
223 | 'query': query
224 | }
225 | except Exception as e:
226 | raise ExecutionError(f"Failed to explain query: {str(e)}")
227 |
    def _fetch_schema_impl(self) -> Dict[str, Dict[str, str]]:
        """Fetch the PostgreSQL schema of the 'public' schema.

        Builds a mapping of table name -> {column name -> type string},
        where the type string is assembled from information_schema
        metadata: base type, length/precision(+scale), NOT NULL,
        PRIMARY KEY, and a DEFAULT marker.

        NOTE(review): the LEFT JOINs to key_column_usage and
        table_constraints can return several rows for a column that
        participates in more than one constraint; the loop below keeps
        the last row seen per column — confirm this is intended.

        NOTE(review): _fetch_indexes() adds '<table>_indexes' keys whose
        values are *lists* of index names, which does not match the
        declared Dict[str, Dict[str, str]] return type.
        """
        # One row per (table, column, constraint) from information_schema,
        # restricted to ordinary base tables in the 'public' schema,
        # ordered so columns appear in table-definition order.
        query = """
            SELECT
                t.table_name,
                c.column_name,
                c.data_type,
                c.character_maximum_length,
                c.numeric_precision,
                c.numeric_scale,
                c.is_nullable,
                c.column_default,
                tc.constraint_type
            FROM information_schema.tables t
            JOIN information_schema.columns c
                ON t.table_schema = c.table_schema
                AND t.table_name = c.table_name
            LEFT JOIN information_schema.key_column_usage kcu
                ON c.table_schema = kcu.table_schema
                AND c.table_name = kcu.table_name
                AND c.column_name = kcu.column_name
            LEFT JOIN information_schema.table_constraints tc
                ON kcu.constraint_schema = tc.constraint_schema
                AND kcu.constraint_name = tc.constraint_name
            WHERE t.table_schema = 'public'
            AND t.table_type = 'BASE TABLE'
            ORDER BY t.table_name, c.ordinal_position
        """

        cursor = None
        try:
            # RealDictCursor lets rows be addressed by column name below.
            cursor = self.connection.cursor(cursor_factory=extras.RealDictCursor)
            cursor.execute(query)

            schema = {}
            for row in cursor:
                table_name = row['table_name']
                if table_name not in schema:
                    schema[table_name] = {}

                # Assemble a human-readable column type string, e.g.
                # "character varying(255) NOT NULL PRIMARY KEY".
                col_type = row['data_type']
                if row['character_maximum_length']:
                    col_type += f"({row['character_maximum_length']})"
                elif row['numeric_precision']:
                    # NOTE(review): a numeric_scale of 0 is falsy and is
                    # skipped, so numeric(10,0) renders as "numeric(10)".
                    col_type += f"({row['numeric_precision']}"
                    if row['numeric_scale']:
                        col_type += f",{row['numeric_scale']}"
                    col_type += ")"

                if row['is_nullable'] == 'NO':
                    col_type += ' NOT NULL'
                if row['constraint_type'] == 'PRIMARY KEY':
                    col_type += ' PRIMARY KEY'
                if row['column_default']:
                    # Only flags that a default exists; the default
                    # expression itself is not included.
                    col_type += ' DEFAULT'

                schema[table_name][row['column_name']] = col_type

            # Augment the mapping with per-table index name lists,
            # reusing the same open cursor.
            self._fetch_indexes(schema, cursor)

            return schema

        finally:
            if cursor:
                cursor.close()
295 |
296 | def _fetch_indexes(self, schema: Dict[str, Dict[str, str]], cursor):
297 | """Fetch index information."""
298 | query = """
299 | SELECT
300 | tablename,
301 | indexname,
302 | indexdef
303 | FROM pg_indexes
304 | WHERE schemaname = 'public'
305 | AND indexname NOT LIKE '%_pkey'
306 | ORDER BY tablename, indexname
307 | """
308 |
309 | cursor.execute(query)
310 |
311 | for row in cursor:
312 | table_name = row['tablename']
313 | if table_name in schema:
314 | index_key = f"{table_name}_indexes"
315 | if index_key not in schema:
316 | schema[index_key] = []
317 | schema[index_key].append(f"{row['indexname']}")
318 |
319 | def _begin_transaction(self):
320 | """Begin a transaction."""
321 | # PostgreSQL starts transaction automatically
322 | pass
323 |
324 | def _commit_transaction(self):
325 | """Commit a transaction."""
326 | self.connection.commit()
327 |
328 | def _rollback_transaction(self):
329 | """Rollback a transaction."""
330 | self.connection.rollback()
331 |
332 | def _get_driver_info(self) -> Dict[str, Any]:
333 | """Get PostgreSQL-specific information."""
334 | info = {
335 | 'server_version': None,
336 | 'connection_id': None,
337 | 'current_schema': None,
338 | 'encoding': None
339 | }
340 |
341 | if self.connection and not self.connection.closed:
342 | try:
343 | cursor = self.connection.cursor()
344 | cursor.execute("""
345 | SELECT
346 | version(),
347 | pg_backend_pid(),
348 | current_schema(),
349 | pg_encoding_to_char(encoding)
350 | FROM pg_database
351 | WHERE datname = current_database()
352 | """)
353 | result = cursor.fetchone()
354 | info['server_version'] = result[0]
355 | info['connection_id'] = result[1]
356 | info['current_schema'] = result[2]
357 | info['encoding'] = result[3]
358 | cursor.close()
359 | except Exception:
360 | pass
361 |
362 | return info
363 |
364 | @property
365 | def supports_transactions(self) -> bool:
366 | """PostgreSQL supports transactions."""
367 | return True
368 |
369 | @property
370 | def supports_schemas(self) -> bool:
371 | """PostgreSQL supports schemas."""
372 | return True
--------------------------------------------------------------------------------