├── src ├── __init__.py ├── api │ ├── __init__.py │ ├── models.py │ └── main.py ├── core │ ├── __init__.py │ ├── env_example.py │ ├── config.py │ ├── llm_factory.py │ └── cli.py ├── policy │ ├── __init__.py │ └── loader.py ├── rag │ ├── __init__.py │ └── engine.py └── __main__.py ├── tests ├── __init__.py ├── deployment-with-violations.yaml ├── conftest.py ├── deployment-compliant.yaml ├── test_index_optimization.py └── test_api_integration.py ├── .gitignore ├── pytest.ini ├── start_server.py ├── setup.py ├── .env.example ├── pyproject.toml ├── run_tests.sh ├── docs └── demo.md ├── run_cli.sh ├── data └── company_policy.txt ├── README.md ├── .github └── workflows │ └── ci.yml └── LICENSE /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/policy/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/rag/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | 7 | # 
Distribution / packaging 8 | build/ 9 | dist/ 10 | *.egg-info/ 11 | *.egg 12 | 13 | # Virtual environments 14 | venv/ 15 | .venv/ 16 | env/ 17 | .env 18 | .env.* 19 | # Keep the committed template tracked even though .env.* is ignored 20 | !.env.example 21 | 22 | # LlamaIndex 23 | .llamaindex/ 24 | storage/ 25 | 26 | # FastAPI 27 | .uvicorn/ 28 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | # pytest.ini must use a [pytest] section: [tool:pytest] is only read from setup.cfg, so under that header every option below was silently ignored 2 | [pytest] 3 | testpaths = tests 4 | python_files = test_*.py 5 | python_classes = Test* 6 | python_functions = test_* 7 | addopts = 8 | -v 9 | --tb=short 10 | --strict-markers 11 | --disable-warnings 12 | --color=yes 13 | markers = 14 | unit: Unit tests 15 | integration: Integration tests 16 | api: API tests 17 | slow: Slow tests that may take longer to run 18 | filterwarnings = 19 | ignore::DeprecationWarning 20 | ignore::PendingDeprecationWarning -------------------------------------------------------------------------------- /start_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Start the AIPA API server.""" 3 | 4 | import os 5 | import sys 6 | import logging 7 | 8 | # Add src to path 9 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) 10 | 11 | from src.api.main import start 12 | 13 | if __name__ == "__main__": 14 | print("Starting AIPA API server...") 15 | print("Server will be available at: http://localhost:8000") 16 | print("API docs at: http://localhost:8000/docs") 17 | print("Press Ctrl+C to stop") 18 | start() 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | 
name="aipa", 5 | version="0.1.0", 6 | packages=find_packages(), 7 | install_requires=[ 8 | "llama-index-core>=0.10.0", 9 | "llama-index-readers-file>=0.1.0", 10 |
"llama-stack-client>=0.1.0", 11 | "fastapi>=0.103.1", 12 | "uvicorn>=0.23.2", 13 | "python-dotenv>=1.0.0", 14 | "langchain>=0.0.267", 15 | "pydantic>=2.3.0", 16 | ], 17 | entry_points={ 18 | "console_scripts": [ 19 | "aipa=src.__main__:main", 20 | ], 21 | }, 22 | ) 23 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # LLM Provider Configuration 2 | # Choose one: llamastack, anthropic, openai 3 | LLM_PROVIDER=anthropic 4 | LLM_TEMPERATURE=0.1 5 | LLM_MAX_TOKENS=1024 6 | 7 | # LlamaStack Configuration (if using llamastack provider) 8 | LLAMASTACK_API_URL=http://localhost:8000 9 | LLAMASTACK_MODEL=llama2 10 | 11 | # Anthropic Configuration (if using anthropic provider) 12 | ANTHROPIC_API_KEY=your_anthropic_api_key_here 13 | ANTHROPIC_MODEL=claude-3-haiku-20240307 14 | 15 | # OpenAI Configuration (if using openai provider) 16 | OPENAI_API_KEY=your_openai_api_key_here 17 | OPENAI_MODEL=gpt-3.5-turbo 18 | 19 | # RAG Configuration 20 | RAG_CHUNK_SIZE=512 21 | RAG_CHUNK_OVERLAP=50 22 | RAG_TOP_K=3 23 | 24 | # Policy Configuration 25 | POLICY_DIR=data 26 | 27 | # Debug 28 | DEBUG=true 29 | -------------------------------------------------------------------------------- /src/core/env_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample environment variables for configuration. 3 | 4 | To use, create a .env file in the project root with the following variables 5 | (names match those read by src/core/config.py; see also .env.example): 6 | 7 | # LLM Provider Configuration 8 | LLM_PROVIDER=llamastack # One of: llamastack, anthropic, openai 9 | LLM_TEMPERATURE=0.1 # Temperature for generation 10 | LLM_MAX_TOKENS=1024 # Maximum tokens to generate 11 | 12 | # Llama Stack Configuration (used when LLM_PROVIDER=llamastack) 13 | LLAMASTACK_API_URL=http://localhost:8000 # URL for the llama-stack API 14 | LLAMASTACK_MODEL=llama2 # Model name to use 15 |
16 | # Anthropic Configuration (used when LLM_PROVIDER=anthropic) 17 | ANTHROPIC_API_KEY=your_anthropic_api_key_here 18 | ANTHROPIC_MODEL=claude-3-haiku-20240307 19 | 20 | # OpenAI Configuration (used when LLM_PROVIDER=openai) 21 | OPENAI_API_KEY=your_openai_api_key_here 22 | OPENAI_MODEL=gpt-3.5-turbo 23 | 24 | # RAG Configuration 25 | RAG_CHUNK_SIZE=512 # Size of chunks for indexing 26 | RAG_CHUNK_OVERLAP=50 # Overlap between chunks 27 | RAG_TOP_K=3 # Number of chunks to retrieve 28 | 29 | # Policy Configuration 30 | POLICY_DIR=data # Directory containing policy documents 31 | 32 | # Debug mode 33 | DEBUG=false # Enable debug mode 34 | """ 35 | -------------------------------------------------------------------------------- /src/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | from src.api.main import start as start_api 5 | from src.core.cli import cli as start_cli 6 | 7 | 8 | def main(): 9 | """Main entry point for the application.""" 10 | parser = argparse.ArgumentParser( 11 | description="AI Policy Advisor - A RAG-based policy engine using llama-stack" 12 | ) 13 | subparsers = parser.add_subparsers(dest="command", help="Command to run") 14 | 15 | # API server command 16 | subparsers.add_parser("api", help="Start the API server") 17 | 18 | # CLI command - pass through to Click; add_help=False so --help is forwarded to Click instead of being answered by argparse 19 | subparsers.add_parser("cli", help="Run CLI commands", add_help=False) 20 | 21 | args, remaining = parser.parse_known_args() 22 | 23 | if args.command == "api": 24 | start_api() 25 | elif args.command == "cli": 26 | # Pass remaining args to Click CLI 27 | sys.argv = ["cli"] + remaining 28 | start_cli() 29 | else: 30 | parser.print_help() 31 | 32 | 33 | if __name__ == "__main__": 34 | main() 35 | -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "aipa" 7 | version = "0.1.0" 8 | description = "AI Policy Advisor - A
RAG-based policy engine using llama-stack" 9 | authors = [ 10 | {name = "AI Policy Advisor Team"} 11 | ] 12 | requires-python = ">=3.10" 13 | dependencies = [ 14 | "llama-index-core>=0.10.0", 15 | "llama-index-readers-file>=0.1.0", 16 | "llama-index-llms-anthropic>=0.1.0", 17 | "llama-index-llms-openai>=0.1.0", 18 | "llama-index-embeddings-huggingface>=0.2.0", 19 | "llama-stack-client>=0.1.0", 20 | "fastapi>=0.103.1", 21 | "uvicorn>=0.23.2", 22 | "python-dotenv>=1.0.0", 23 | "langchain>=0.0.267", 24 | "pydantic>=2.3.0", 25 | "click>=8.0.0", 26 | "pyyaml>=6.0.0", 27 | "httpx>=0.24.0", 28 | ] 29 | 30 | [project.optional-dependencies] 31 | dev = [ 32 | "pytest>=7.0.0", 33 | "pytest-asyncio>=0.21.0", 34 | "pytest-mock>=3.11.0", 35 | "pytest-cov>=4.1.0", 36 | "httpx>=0.24.0", 37 | "black>=23.7.0", 38 | "isort>=5.12.0", 39 | ] 40 | 41 | # Console script entry point. Recent setuptools ignores entry_points passed via setup.py once [project] is declared here. 42 | [project.scripts] 43 | aipa = "src.__main__:main" 44 | -------------------------------------------------------------------------------- /tests/deployment-with-violations.yaml: -------------------------------------------------------------------------------- 1 | # deployment-with-violations.yaml 2 | apiVersion: apps/v1 3 | kind: Deployment 4 | metadata: 5 | name: nginx-deployment 6 | labels: 7 | app: nginx 8 | spec: 9 | replicas: 1 # Violation: Less than minimum 2 replicas 10 | selector: 11 | matchLabels: 12 | app: nginx 13 | template: 14 | metadata: 15 | labels: 16 | app: nginx 17 | spec: 18 | containers: 19 | - name: nginx 20 | image: nginx:latest # Violation: 'latest' tag; image is not fully qualified and not from an approved registry 21 | resources: 22 | requests: # Compliant: exactly at the 100m CPU / 128Mi memory minimums 23 | cpu: "100m" 24 | memory: "128Mi" 25 | limits: # Compliant: limits are defined and within the 2 core / 2Gi maximums 26 | cpu: "200m" 27 | memory: "256Mi" 28 | securityContext: # Defined, but see runAsUser below 29 | runAsUser: 0 # 
Violation: Running as root 30 | ports: 31 | - containerPort: 80 32 | livenessProbe: # Violation: readinessProbe is missing (policy requires both readiness and liveness probes) 33 | httpGet: 34 | path: / 35 | port: 80 36 | initialDelaySeconds: 30 37 | periodSeconds: 10 38 |
-------------------------------------------------------------------------------- /src/policy/loader.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Dict, Optional, Union 3 | 4 | from src.core.config import config 5 | 6 | 7 | class PolicyLoader: 8 | """Load policy documents from the policy directory.""" 9 | 10 | def __init__(self, policy_dir: Optional[Union[str, Path]] = None): 11 | """Initialize the policy loader. 12 | 13 | Args: 14 | policy_dir: Directory containing policy documents (str or Path). Defaults to config value. 15 | """ 16 | # Coerce to Path so a plain string argument still supports .glob() 17 | self.policy_dir = Path(policy_dir or config.policy.policy_dir) 18 | 19 | def _policy_files(self) -> List[Path]: 20 | """Return the *.txt files in the policy directory, sorted for a deterministic order.""" 21 | return sorted(self.policy_dir.glob("*.txt")) 22 | 23 | def load_policies(self) -> List[Dict[str, Union[str, Path]]]: 24 | """Load all policy documents from the policy directory. 25 | 26 | Files are decoded as UTF-8 so the result does not depend on the platform's locale encoding. 27 | 28 | Returns: 29 | List of dictionaries with "source" (file path), "filename" and "content" keys. 30 | """ 31 | return [ 32 | { 33 | "source": str(policy_file), 34 | "filename": policy_file.name, 35 | "content": policy_file.read_text(encoding="utf-8"), 36 | } 37 | for policy_file in self._policy_files() 38 | ] 39 | 40 | def get_policy_count(self) -> int: 41 | """Get the number of policy documents. 42 | 43 | Returns: 44 | Number of *.txt files in the policy directory. 
45 | """ 46 | return len(self._policy_files()) 47 | 48 | 49 | # Create a singleton policy loader instance 50 | policy_loader = PolicyLoader() 51 | -------------------------------------------------------------------------------- /run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | echo "🧪 Running AIPA Test Suite" 5 | echo "==========================" 6 | 7 | # Check if virtual environment is activated 8 | if [[ "$VIRTUAL_ENV" == "" ]]; then 9 | echo "⚠️ Warning: No virtual environment detected. Activating venv..."
10 | if [[ -f "venv/bin/activate" ]]; then 11 | source venv/bin/activate 12 | else 13 | echo "❌ No venv found. Please create one with: python -m venv venv && source venv/bin/activate" 14 | exit 1 15 | fi 16 | fi 17 | 18 | # Install dependencies if needed 19 | echo "📦 Installing dependencies..." 20 | pip install -e ".[dev]" > /dev/null 2>&1 21 | 22 | # Run linting 23 | echo "" 24 | echo "🔍 Running code quality checks..." 25 | echo " - Black formatting check..." 26 | black --check --diff src tests || { 27 | echo "❌ Black formatting issues found. Run: black src tests" 28 | exit 1 29 | } 30 | 31 | echo " - Import sorting check..." 32 | isort --check-only --diff src tests || { 33 | echo "❌ Import sorting issues found. Run: isort src tests" 34 | exit 1 35 | } 36 | 37 | # Run tests 38 | echo "" 39 | echo "🧪 Running tests..." 40 | 41 | echo " - Unit tests (index optimization)..." 42 | pytest tests/test_index_optimization.py -v -q 43 | 44 | echo " - Integration tests (API)..." 45 | pytest tests/test_api_integration.py -v -q 46 | 47 | echo " - All tests with coverage..." 48 | pytest --cov=src --cov-report=term-missing --cov-report=html -q 49 | 50 | echo "" 51 | echo "✅ All tests passed!" 52 | echo "" 53 | echo "📊 Coverage report generated in htmlcov/index.html" 54 | echo "🎉 Test suite completed successfully!"
-------------------------------------------------------------------------------- /src/api/models.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Optional 2 | from pydantic import BaseModel, Field 3 | 4 | 5 | class QueryRequest(BaseModel): 6 | """Request model for policy queries.""" 7 | query: str = Field(..., description="The query about policy to check") 8 | 9 | 10 | class ManifestValidationRequest(BaseModel): 11 | """Request model for Kubernetes manifest validation.""" 12 | manifest: str = Field(..., description="The Kubernetes manifest YAML content") 13 | 14 | 15 | class PolicyViolationResponse(BaseModel): 16 | """Response model for policy violations.""" 17 | rule: str = Field(..., description="The policy rule that was violated") 18 | manifest_path: str = Field(..., description="Path in the manifest where violation occurred") 19 | violation: str = Field(..., description="Description of the violation") 20 | severity: str = Field(default="error", description="Severity of the violation") 21 | 22 | 23 | class ManifestValidationResponse(BaseModel): 24 | """Response model for manifest validation.""" 25 | violations: List[PolicyViolationResponse] = Field( 26 | default_factory=list, description="List of policy violations found" 27 | ) 28 | compliant: bool = Field(..., description="Whether the manifest is compliant") 29 | metadata: Dict[str, Any] = Field( 30 | default_factory=dict, description="Metadata about the validation" 31 | ) 32 | 33 | 34 | class SourceInfo(BaseModel): 35 | """Information about a source document.""" 36 | source: str = Field(..., description="The source file path") 37 | text: str = Field(..., description="The relevant text from the source") 38 | 39 | 40 | class QueryResponse(BaseModel): 41 | """Response model for policy queries.""" 42 | answer: str = Field(..., description="The policy decision or answer") 43 | sources: List[SourceInfo] = Field( 44 | default_factory=list,
description="Source documents used for the answer" 45 | ) 46 | metadata: Dict[str, Any] = Field( 47 | default_factory=dict, description="Metadata about the query and response" 48 | ) 49 | -------------------------------------------------------------------------------- /src/core/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Optional, Literal 4 | 5 | from pydantic import BaseModel 6 | from dotenv import load_dotenv 7 | 8 | # Load environment variables 9 | load_dotenv() 10 | 11 | # Base directory 12 | BASE_DIR = Path(__file__).parent.parent.parent 13 | 14 | LLMProvider = Literal["llamastack", "anthropic", "openai"] 15 | 16 | 17 | class LLMConfig(BaseModel): 18 | """Configuration for LLM providers.""" 19 | provider: LLMProvider = os.getenv("LLM_PROVIDER", "llamastack") 20 | temperature: float = float(os.getenv("LLM_TEMPERATURE", "0.1")) 21 | max_tokens: int = int(os.getenv("LLM_MAX_TOKENS", "1024")) 22 | 23 | 24 | class LlamaStackConfig(BaseModel): 25 | """Configuration for Llama Stack.""" 26 | api_url: str = os.getenv("LLAMASTACK_API_URL", "http://localhost:8000") 27 | model_name: str = os.getenv("LLAMASTACK_MODEL", "llama2") 28 | 29 | 30 | class AnthropicConfig(BaseModel): 31 | """Configuration for Anthropic.""" 32 | api_key: str = os.getenv("ANTHROPIC_API_KEY", "") 33 | model_name: str = os.getenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307") 34 | 35 | 36 | class OpenAIConfig(BaseModel): 37 | """Configuration for OpenAI.""" 38 | api_key: str = os.getenv("OPENAI_API_KEY", "") 39 | model_name: str = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo") 40 | 41 | 42 | class RagConfig(BaseModel): 43 | """Configuration for RAG.""" 44 | chunk_size: int = int(os.getenv("RAG_CHUNK_SIZE", "512")) 45 | chunk_overlap: int = int(os.getenv("RAG_CHUNK_OVERLAP", "50")) 46 | similarity_top_k: int = int(os.getenv("RAG_TOP_K", "3")) 47 | 48 |
49 | def _resolve_policy_dir(raw: str) -> Path: 50 | """Return *raw* as a Path, anchoring relative paths at BASE_DIR instead of the current working directory.""" 51 | path = Path(raw).expanduser() 52 | return path if path.is_absolute() else BASE_DIR / path 53 | 54 | 55 | class PolicyConfig(BaseModel): 56 | """Configuration for Policy documents.""" 57 | # A relative POLICY_DIR (e.g. "data" in .env.example) resolves against the project root, 58 | # so the CLI and API find the same policies regardless of where they are launched from. 59 | policy_dir: Path = _resolve_policy_dir(os.getenv("POLICY_DIR", str(BASE_DIR / "data"))) 60 | 61 | 62 | class Config(BaseModel): 63 | """Main configuration class.""" 64 | llm: LLMConfig = LLMConfig() 65 | llamastack: LlamaStackConfig = LlamaStackConfig() 66 | anthropic: AnthropicConfig = AnthropicConfig() 67 | openai: OpenAIConfig = OpenAIConfig() 68 | rag: RagConfig = RagConfig() 69 | policy: PolicyConfig = PolicyConfig() 70 | debug: bool = os.getenv("DEBUG", "False").lower() in ("true", "1", "t") 71 | 72 | 73 | # Create a singleton config instance 74 | config = Config() 75 | -------------------------------------------------------------------------------- /docs/demo.md: -------------------------------------------------------------------------------- 1 | # Demo 2 | 3 | AIPA 4 | 5 | A RAG-based policy engine using natural language. 6 | 7 | Ground the LLM using Retrieval Augmented Generation (RAG) by creating an index 8 | of the policy document. LlamaIndex builds the index as follows: 9 | 10 | 1. Chunking: It first breaks down your documents into smaller pieces called 11 | Nodes. 12 | 2. Vector Embeddings: It then creates vector embeddings (numerical 13 | representations of text meaning) for each node using an embedding model. 14 | 3. Storing in a Vector Store: These vector embeddings, along with the 15 | corresponding nodes, are stored in a chosen vector store/vector database. 16 | 17 | LlamaIndex leverages various vector stores/vector databases as its underlying 18 | storage backend. 
By default, LlamaIndex uses a simple in-memory vector store 19 | (SimpleVectorStore) for quick experimentation; others include Faiss and Hnswlib. The default store can be easily 20 | persisted to disk.
21 | 22 | ## Setup 23 | 24 | ### System 1 25 | 26 | #### Window 1 27 | 28 | ```bash 29 | ollama run llama3.2:3b-instruct-fp16 --keepalive 60m 30 | ``` 31 | 32 | #### Window 2 33 | 34 | ```bash 35 | export LLAMA_STACK_MODEL="meta-llama/Llama-3.2-3B-Instruct" 36 | export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" 37 | export LLAMA_STACK_PORT=8321 38 | export LLAMA_STACK_SERVER=http://localhost:$LLAMA_STACK_PORT 39 | export LLAMA_STACK_ENDPOINT=$LLAMA_STACK_SERVER 40 | podman run -it -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT -v ~/.llama:/root/.llama:Z --network=host llamastack/distribution-ollama --port $LLAMA_STACK_PORT --env INFERENCE_MODEL=$LLAMA_STACK_MODEL --env OLLAMA_URL=http://localhost:11434 41 | ``` 42 | 43 | ### System 2 44 | 45 | Set `LLAMASTACK_API_URL` in `.env` to the Llama Stack server started above (port 8321 in this demo, e.g. `http://<system-1-host>:8321`); otherwise AIPA falls back to its default of `http://localhost:8000`. 46 | 47 | ## Demo 48 | 49 | ```bash 50 | cp .env.example .env && vi .env 51 | vi data/company_policy.txt 52 | python src/core/cli.py -p llamastack ask 'What software can I install?' 53 | python src/core/cli.py -p llamastack ask "Can I use my work laptop for personal use? If so, how much?" 54 | python src/core/cli.py -p llamastack ask "What OCI registries are approved?" 55 | ``` 56 | 57 | ```bash 58 | python -m src api 59 | python src/core/cli.py --use-api ask "What OCI registries are approved?"
60 | vi tests/deployment-with-violations.yaml 61 | python src/core/cli.py --use-api validate-manifest tests/deployment-with-violations.yaml 62 | vi tests/deployment-compliant.yaml 63 | python src/core/cli.py --use-api validate-manifest tests/deployment-compliant.yaml 64 | ``` 65 | 66 | -------------------------------------------------------------------------------- /run_cli.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # AIPA CLI Runner 4 | # This script sets up the environment and runs the CLI 5 | 6 | SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" 7 | 8 | # Default environment variables, applied only when there is no .env file: 9 | # load_dotenv() in src/core/config.py never overrides variables that are already set, 10 | # so exporting these unconditionally would silently shadow every setting in .env. 11 | if [ ! -f "$SCRIPT_DIR/.env" ]; then 12 | export LLM_PROVIDER=${LLM_PROVIDER:-anthropic} 13 | export LLM_TEMPERATURE=${LLM_TEMPERATURE:-0.1} 14 | export LLM_MAX_TOKENS=${LLM_MAX_TOKENS:-1024} 15 | export LLAMASTACK_API_URL=${LLAMASTACK_API_URL:-http://localhost:8000} 16 | export LLAMASTACK_MODEL=${LLAMASTACK_MODEL:-llama2} 17 | export ANTHROPIC_MODEL=${ANTHROPIC_MODEL:-claude-3-haiku-20240307} 18 | export OPENAI_MODEL=${OPENAI_MODEL:-gpt-3.5-turbo} 19 | export RAG_CHUNK_SIZE=${RAG_CHUNK_SIZE:-512} 20 | export RAG_CHUNK_OVERLAP=${RAG_CHUNK_OVERLAP:-50} 21 | export RAG_TOP_K=${RAG_TOP_K:-3} 22 | export POLICY_DIR=${POLICY_DIR:-data} 23 | export DEBUG=${DEBUG:-false} 24 | fi 25 | 26 | # API mode is selected by --use-api wherever it appears among the arguments 27 | use_api=false 28 | for arg in "$@"; do 29 | if [ "$arg" = "--use-api" ]; then 30 | use_api=true 31 | break 32 | fi 33 | done 34 | 35 | # Check if API key is set for the chosen provider (local mode only). 36 | # NOTE: only the shell environment is inspected here; values kept in .env are not seen. 
37 | if [ "$use_api" = false ]; then 38 | if [ "$LLM_PROVIDER" = "anthropic" ] && [ -z "$ANTHROPIC_API_KEY" ]; then 39 | echo "Warning: ANTHROPIC_API_KEY is not set. Please set it to use Anthropic Claude." 40 | echo "export ANTHROPIC_API_KEY='your_api_key_here'" 41 | echo "" 42 | echo "Or use API mode: $0 --use-api [command]" 43 | fi 44 | 45 | if [ "$LLM_PROVIDER" = "openai" ] && [ -z "$OPENAI_API_KEY" ]; then 46 | echo "Warning: OPENAI_API_KEY is not set. Please set it to use OpenAI GPT."
47 | echo "export OPENAI_API_KEY='your_api_key_here'" 48 | echo "" 49 | echo "Or use API mode: $0 --use-api [command]" 50 | fi 51 | fi 52 | 53 | # Show usage examples if no arguments provided 54 | if [ $# -eq 0 ]; then 55 | echo "AIPA CLI - AI Policy Advisor" 56 | echo "" 57 | echo "Usage examples:" 58 | echo "" 59 | echo "LOCAL MODE (processes policies locally):" 60 | echo " $0 ask 'Can I use Docker Hub images?'" 61 | echo " $0 validate-manifest tests/deployment-compliant.yaml" 62 | echo " $0 --provider anthropic --model claude-3-5-sonnet-20241022 ask 'Security policies?'" 63 | echo "" 64 | echo "API MODE (queries running server, more efficient):" 65 | echo " $0 --use-api ask 'Can I use Docker Hub images?'" 66 | echo " $0 --use-api validate-manifest tests/deployment-compliant.yaml" 67 | echo " $0 --use-api --api-url http://localhost:8001 ask 'Security policies?'" 68 | echo "" 69 | echo "Other commands:" 70 | echo " $0 providers # Show available LLM providers" 71 | echo "" 72 | echo "Note: API mode requires a running server (python -m src.api.main)" 73 | exit 1 74 | fi 75 | 76 | # Run the CLI with all arguments passed through (script path anchored at this file's directory) 77 | python "$SCRIPT_DIR/src/core/cli.py" "$@" 78 | -------------------------------------------------------------------------------- /data/company_policy.txt: -------------------------------------------------------------------------------- 1 | # Company Technology Policy 2 | 3 | ## Software Installation 4 | 1. Employees may install approved software from the company's software portal. 5 | 2. Installation of non-approved software requires IT department approval. 6 | 3. Open source software must be reviewed by the security team before installation. 7 | 4. All software must comply with licensing requirements. 8 | 9 | ## Data Protection 10 | 1. Company data must not be shared on public platforms. 11 | 2. Customer information is confidential and must be encrypted when stored. 12 | 3. 
Use company-provided cloud storage for sensitive documents. 13 | 4.
Regular backups of work-related files are required. 14 | 15 | ## Device Usage 16 | 1. Work laptops may be used for reasonable personal activities. 17 | 2. Device passwords must be changed every 90 days. 18 | 3. Company devices must not be shared with non-employees. 19 | 4. Lost or stolen devices must be reported immediately. 20 | 21 | ## Remote Work 22 | 1. VPN must be used when accessing company resources remotely. 23 | 2. Work from public networks requires additional security measures. 24 | 3. Home networks used for work must have WPA2 or better security. 25 | 4. Remote workers must ensure their workspace complies with confidentiality requirements. 26 | 27 | # Kubernetes Policy Rules 28 | 29 | ## Container Image Registry Policy 30 | - Only container images from approved registries are allowed 31 | - Approved registries are: 32 | * quay.io 33 | * registry.access.redhat.com 34 | - All container images must be fully qualified (include registry, repository, and tag) 35 | - Container images must use specific tags, not 'latest' 36 | - Container images must be scanned for vulnerabilities before deployment 37 | 38 | ## Resource Limits Policy 39 | - All containers must have resource requests and limits defined 40 | - CPU requests must be at least 100m 41 | - Memory requests must be at least 128Mi 42 | - CPU limits must not exceed 2 cores 43 | - Memory limits must not exceed 2Gi 44 | - Resource requests must not exceed resource limits 45 | 46 | ## Security Policy 47 | - All pods must have security contexts defined 48 | - Containers must not run as root 49 | - Containers must not have privileged access 50 | - All pods must have network policies defined 51 | - All pods must have pod security policies applied 52 | - All secrets must be stored in Kubernetes secrets, not in environment variables 53 | 54 | ## Deployment Policy 55 | - All deployments must have at least 2 replicas for high availability 56 | - All deployments must have a rolling update strategy 57 | - All 
deployments must have resource limits defined 58 | - All deployments must have health checks (readiness and liveness probes) 59 | - All deployments must have proper labels and annotations 60 | 61 | ## Storage Policy 62 | - All persistent volumes must be backed up 63 | - All persistent volumes must have proper access modes 64 | - All persistent volumes must have proper storage class 65 | - All persistent volumes must have proper size limits 66 | 67 | ## Network Policy 68 | - All services must have proper type (ClusterIP, NodePort, or LoadBalancer) 69 | - All services must have proper ports defined 70 | - All services must have proper selectors 71 | - All services must have proper labels and annotations 72 | 73 | ## Monitoring Policy 74 | - All pods must have proper logging configuration 75 | - All pods must have proper metrics exposed 76 | - All pods must have proper alerting rules 77 | - All pods must have proper dashboards 78 | 79 | ## Compliance Policy 80 | - All resources must have proper labels for cost allocation 81 | - All resources must have proper labels for environment 82 | - All resources must have proper labels for team ownership 83 | - All resources must have proper labels for compliance requirements 84 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Pytest configuration and shared fixtures.""" 3 | 4 | import pytest 5 | import os 6 | import sys 7 | from unittest.mock import patch, MagicMock 8 | 9 | # Add src to path for all tests 10 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) 11 | 12 | 13 | @pytest.fixture(autouse=True) 14 | def mock_llm_setup(): 15 | """Auto-use fixture to mock LLM setup for all tests to avoid external dependencies; yields the mock LLM returned by setup_global_llm.""" 16 | with patch('src.core.llm_factory.llm_factory.setup_global_llm') as mock_setup: 17 | mock_llm = 
MagicMock() 18 |
mock_setup.return_value = mock_llm 19 | 20 | # Also mock Settings to avoid import issues; the patched object itself is not needed 21 | with patch('src.rag.engine.Settings'): 22 | yield mock_llm 23 | 24 | 25 | @pytest.fixture(autouse=True) 26 | def mock_embedding_setup(): 27 | """Auto-use fixture to mock HuggingFace embeddings to avoid downloading models.""" 28 | with patch('src.rag.engine.HuggingFaceEmbedding') as mock_embedding: 29 | mock_embed_instance = MagicMock() 30 | mock_embedding.return_value = mock_embed_instance 31 | yield mock_embed_instance 32 | 33 | 34 | @pytest.fixture(autouse=True) 35 | def mock_policy_loader(): 36 | """Auto-use fixture to mock policy loader to avoid file system dependencies.""" 37 | with patch('src.rag.engine.policy_loader') as mock_loader: 38 | mock_loader.load_policies.return_value = [ 39 | { 40 | "content": "Test policy content for CI testing", 41 | "source": "test_policy.txt" 42 | } 43 | ] 44 | yield mock_loader 45 | 46 | 47 | @pytest.fixture 48 | def sample_k8s_manifest(): 49 | """Fixture providing a sample Kubernetes manifest for testing. Note: not fully compliant with data/company_policy.txt (e.g. 64Mi memory request, image not from an approved registry).""" 50 | return """ 51 | apiVersion: apps/v1 52 | kind: Deployment 53 | metadata: 54 | name: test-deployment 55 | namespace: default 56 | spec: 57 | replicas: 2 58 | selector: 59 | matchLabels: 60 | app: test-app 61 | template: 62 | metadata: 63 | labels: 64 | app: test-app 65 | spec: 66 | containers: 67 | - name: test-container 68 | image: nginx:1.20 69 | ports: 70 | - containerPort: 80 71 | resources: 72 | requests: 73 | memory: "64Mi" 74 | cpu: "250m" 75 | limits: 76 | memory: "128Mi" 77 | cpu: "500m" 78 | """ 79 | 80 | 81 | @pytest.fixture 82 | def sample_non_compliant_manifest(): 83 | """Fixture providing a non-compliant Kubernetes manifest for testing.""" 84 | return """ 85 | apiVersion: apps/v1 86 | kind: Deployment 87 | metadata: 88 | name: bad-deployment 89 | spec: 90 | replicas: 1 91 | selector: 92 | 
matchLabels: 93 | app: bad-app 94 | template: 95 | metadata: 96 | labels: 97 | app: bad-app 98 | spec: 99 |
containers: 100 | - name: bad-container 101 | image: nginx:latest 102 | securityContext: 103 | runAsUser: 0 104 | """ 105 | 106 | 107 | @pytest.fixture 108 | def sample_policy_violations(): 109 | """Fixture providing sample policy violations for testing.""" 110 | from src.rag.engine import PolicyViolation 111 | 112 | return [ 113 | PolicyViolation( 114 | rule="Container images must use specific tags", 115 | manifest_path="spec.template.spec.containers[0].image", 116 | violation="Using 'latest' tag is not allowed", 117 | severity="error" 118 | ), 119 | PolicyViolation( 120 | rule="Containers should not run as root", 121 | manifest_path="spec.template.spec.containers[0].securityContext.runAsUser", 122 | violation="Container is configured to run as root (UID 0)", 123 | severity="warning" 124 | ) 125 | ] -------------------------------------------------------------------------------- /tests/deployment-compliant.yaml: -------------------------------------------------------------------------------- 1 | # deployment-compliant.yaml 2 | apiVersion: apps/v1 3 | kind: Deployment 4 | metadata: 5 | name: nginx-deployment 6 | labels: 7 | app: nginx 8 | environment: production 9 | team: platform 10 | cost-center: platform-team 11 | annotations: 12 | monitoring.coreos.com/enabled: "true" 13 | logging.kubernetes.io/enabled: "true" 14 | policy.kubernetes.io/pod-security-policy: "restricted" 15 | spec: 16 | replicas: 3 # Compliant: More than minimum 2 replicas 17 | strategy: 18 | type: RollingUpdate 19 | rollingUpdate: 20 | maxSurge: 1 21 | maxUnavailable: 0 22 |
Above minimum 100m 39 | memory: "512Mi" # Above minimum 128Mi 40 | limits: # Compliant: Has resource limits within policy 41 | cpu: "1000m" # Under maximum 2 cores 42 | memory: "1Gi" # Under maximum 2Gi 43 | securityContext: # Compliant: Has security context 44 | runAsUser: 1000 45 | runAsNonRoot: true # Compliant: Not running as root 46 | allowPrivilegeEscalation: false # Compliant: No privileged access 47 | capabilities: 48 | drop: 49 | - ALL 50 | ports: 51 | - containerPort: 80 52 | name: http 53 | - containerPort: 9113 54 | name: metrics # Compliant: Metrics port for monitoring 55 | livenessProbe: # Compliant: Has health checks 56 | httpGet: 57 | path: / 58 | port: 80 59 | initialDelaySeconds: 30 60 | periodSeconds: 10 61 | timeoutSeconds: 5 62 | failureThreshold: 3 63 | readinessProbe: # Compliant: Has readiness probe 64 | httpGet: 65 | path: / 66 | port: 80 67 | initialDelaySeconds: 5 68 | periodSeconds: 10 69 | timeoutSeconds: 5 70 | failureThreshold: 3 71 | --- 72 | # NetworkPolicy for the deployment 73 | apiVersion: networking.k8s.io/v1 74 | kind: NetworkPolicy 75 | metadata: 76 | name: nginx-network-policy 77 | labels: 78 | app: nginx 79 | environment: production 80 | team: platform 81 | cost-center: platform-team 82 | spec: 83 | podSelector: 84 | matchLabels: 85 | app: nginx 86 | policyTypes: 87 | - Ingress 88 | - Egress 89 | ingress: 90 | - from: 91 | - podSelector: 92 | matchLabels: 93 | role: frontend 94 | ports: 95 | - protocol: TCP 96 | port: 80 97 | egress: 98 | - to: [] 99 | ports: 100 | - protocol: TCP 101 | port: 53 # DNS 102 | - protocol: UDP 103 | port: 53 # DNS 104 | --- 105 | # Service for the deployment 106 | apiVersion: v1 107 | kind: Service 108 | metadata: 109 | name: nginx-service 110 | labels: 111 | app: nginx 112 | environment: production 113 | team: platform 114 | cost-center: platform-team 115 | annotations: 116 | monitoring.coreos.com/scrape: "true" 117 | monitoring.coreos.com/port: "9113" 118 | monitoring.coreos.com/path: 
"/metrics" 119 | spec: 120 | type: ClusterIP # Compliant: Proper service type 121 | selector: 122 | app: nginx 123 | ports: 124 | - name: http 125 | port: 80 126 | targetPort: http 127 | protocol: TCP 128 | - name: metrics 129 | port: 9113 130 | targetPort: metrics 131 | protocol: TCP 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AIPA (AI Policy Advisor) 2 | 3 | A simple AI-based policy engine that uses natural language policies to make decisions. 4 | 5 | ## Features 6 | 7 | - Uses RAG (Retrieval Augmented Generation) to provide context to LLMs 8 | - Supports multiple LLM providers: LlamaStack, Anthropic Claude, OpenAI GPT 9 | - Policy enforcement based on natural language policy documents 10 | - Kubernetes manifest validation against company policies 11 | - CLI and API interfaces 12 | - Minimal viable implementation for extensibility 13 | 14 | ## Quick Start 15 | 16 | 1. Install dependencies: 17 | ```bash 18 | pip install -e . 19 | ``` 20 | 21 | 2. Configure your LLM provider by copying the example config: 22 | ```bash 23 | cp .env.example .env 24 | # Edit .env with your API keys and preferences 25 | ``` 26 | 27 | 3. Place your policy documents in the `data/` directory 28 | 29 | 4. Check available providers: 30 | ```bash 31 | python src/core/cli.py providers 32 | ``` 33 | 34 | 5. Ask policy questions: 35 | ```bash 36 | python src/core/cli.py ask "Can I install software on my work laptop?" 37 | ``` 38 | 39 | 6. 
Validate Kubernetes manifests: 40 | ```bash 41 | python src/core/cli.py validate-manifest deployment.yaml 42 | ``` 43 | 44 | ## LLM Provider Configuration 45 | 46 | ### Anthropic Claude (Recommended) 47 | ```env 48 | LLM_PROVIDER=anthropic 49 | ANTHROPIC_API_KEY=your_api_key_here 50 | ANTHROPIC_MODEL=claude-3-haiku-20240307 51 | ``` 52 | 53 | ### OpenAI GPT 54 | ```env 55 | LLM_PROVIDER=openai 56 | OPENAI_API_KEY=your_api_key_here 57 | OPENAI_MODEL=gpt-3.5-turbo 58 | ``` 59 | 60 | ### LlamaStack (Local) 61 | ```env 62 | LLM_PROVIDER=llamastack 63 | LLAMASTACK_API_URL=http://localhost:8000 64 | LLAMASTACK_MODEL=llama2 65 | ``` 66 | 67 | ## CLI Usage 68 | 69 | The CLI supports two modes of operation: 70 | 71 | ### Local Mode (Default) 72 | Processes policies locally, builds RAG index on each run: 73 | 74 | ```bash 75 | # Ask policy questions 76 | python src/core/cli.py ask "What are the password requirements?" 77 | 78 | # Use specific provider 79 | python src/core/cli.py --provider anthropic ask "Security policy question" 80 | 81 | # Validate Kubernetes manifest 82 | python src/core/cli.py validate-manifest tests/deployment-with-violations.yaml 83 | 84 | # Check provider status 85 | python src/core/cli.py providers 86 | ``` 87 | 88 | ### API Mode (Efficient) 89 | Queries a running API server, more efficient as it reuses pre-built RAG indices: 90 | 91 | ```bash 92 | # Start the API server (in another terminal) 93 | python start_server.py 94 | 95 | # Use API mode for queries 96 | python src/core/cli.py --use-api ask "What are the password requirements?" 97 | 98 | # Validate manifests via API 99 | python src/core/cli.py --use-api validate-manifest tests/deployment-compliant.yaml 100 | 101 | # Use custom API URL 102 | python src/core/cli.py --use-api --api-url http://localhost:8001 ask "Security policy?" 
103 | ``` 104 | 105 | ### Convenience Script 106 | Use the `run_cli.sh` script for easier usage: 107 | 108 | ```bash 109 | # Show help and examples 110 | ./run_cli.sh 111 | 112 | # Local mode 113 | ./run_cli.sh --provider anthropic ask "Can I use Docker Hub images?" 114 | 115 | # API mode (more efficient) 116 | ./run_cli.sh --use-api ask "Can I use Docker Hub images?" 117 | ``` 118 | 119 | ## API Usage 120 | 121 | ### Start the Server 122 | ```bash 123 | python start_server.py 124 | # or 125 | python -m src.api.main 126 | ``` 127 | 128 | ### Query Policies 129 | ```bash 130 | curl -X POST http://localhost:8000/query \ 131 | -H "Content-Type: application/json" \ 132 | -d '{"query": "Can I install software on my work laptop?"}' 133 | ``` 134 | 135 | ### Validate Kubernetes Manifests 136 | ```bash 137 | curl -X POST http://localhost:8000/validate-manifest \ 138 | -H "Content-Type: application/json" \ 139 | -d '{"manifest": "apiVersion: apps/v1\nkind: Deployment\n..."}' 140 | ``` 141 | 142 | ### API Documentation 143 | Visit http://localhost:8000/docs for interactive API documentation. 
144 | 145 | ## Architecture 146 | 147 | - `src/core/` - Core functionality and config 148 | - `src/policy/` - Policy loading and management 149 | - `src/rag/` - Retrieval Augmented Generation engine 150 | - `src/api/` - API interfaces for querying the policy engine 151 | -------------------------------------------------------------------------------- /src/api/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from contextlib import asynccontextmanager 3 | from fastapi import FastAPI, HTTPException 4 | import uvicorn 5 | 6 | from src.api.models import ( 7 | QueryRequest, QueryResponse, 8 | ManifestValidationRequest, ManifestValidationResponse, PolicyViolationResponse 9 | ) 10 | from src.rag.engine import rag_engine, K8sPolicyEnforcer 11 | from src.core.config import config 12 | 13 | # Configure logging 14 | logging.basicConfig( 15 | level=logging.DEBUG if config.debug else logging.INFO, 16 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 17 | ) 18 | logger = logging.getLogger(__name__) 19 | 20 | # Global variable to hold the K8s policy enforcer 21 | k8s_enforcer = None 22 | 23 | 24 | @asynccontextmanager 25 | async def lifespan(app: FastAPI): 26 | """Lifespan event handler for FastAPI app startup and shutdown.""" 27 | global k8s_enforcer 28 | 29 | # Startup: Build the RAG index once 30 | logger.info("Building RAG index at startup...") 31 | rag_engine.build_index() 32 | 33 | # Initialize K8s policy enforcer 34 | k8s_enforcer = K8sPolicyEnforcer(rag_engine) 35 | logger.info("API server startup complete.") 36 | 37 | yield 38 | 39 | # Shutdown: cleanup if needed 40 | logger.info("API server shutting down.") 41 | 42 | 43 | # Create FastAPI app with lifespan handler 44 | app = FastAPI( 45 | title="AI Policy Advisor", 46 | description="A RAG-based policy engine using llama-stack", 47 | version="0.1.0", 48 | lifespan=lifespan, 49 | ) 50 | 51 | 52 | @app.post("/query", response_model=QueryResponse) 
53 | async def query_policy(request: QueryRequest) -> QueryResponse: 54 | """Query the policy engine with a natural language question.""" 55 | try: 56 | logger.info(f"Received query: {request.query}") 57 | result = rag_engine.query(request.query) 58 | logger.debug(f"Query result: {result}") 59 | return result 60 | except Exception as e: 61 | logger.error(f"Error processing query: {e}", exc_info=True) 62 | raise HTTPException( 63 | status_code=500, detail=f"Error processing query: {str(e)}" 64 | ) 65 | 66 | 67 | @app.post("/validate-manifest", response_model=ManifestValidationResponse) 68 | async def validate_manifest(request: ManifestValidationRequest) -> ManifestValidationResponse: 69 | """Validate a Kubernetes manifest against policies.""" 70 | try: 71 | logger.info("Received manifest validation request") 72 | 73 | # Ensure k8s_enforcer is initialized 74 | if k8s_enforcer is None: 75 | raise HTTPException( 76 | status_code=503, detail="Service not ready: policy enforcer not initialized" 77 | ) 78 | 79 | violations = k8s_enforcer.enforce_policy(request.manifest) 80 | 81 | # Convert PolicyViolation objects to response models 82 | violation_responses = [ 83 | PolicyViolationResponse( 84 | rule=v.rule, 85 | manifest_path=v.manifest_path, 86 | violation=v.violation, 87 | severity=v.severity 88 | ) 89 | for v in violations 90 | ] 91 | 92 | return ManifestValidationResponse( 93 | violations=violation_responses, 94 | compliant=len(violations) == 0, 95 | metadata={ 96 | "violation_count": len(violations), 97 | "error_count": len([v for v in violations if v.severity == "error"]), 98 | "warning_count": len([v for v in violations if v.severity == "warning"]), 99 | } 100 | ) 101 | except Exception as e: 102 | logger.error(f"Error validating manifest: {e}", exc_info=True) 103 | raise HTTPException( 104 | status_code=500, detail=f"Error validating manifest: {str(e)}" 105 | ) 106 | 107 | 108 | @app.get("/health") 109 | async def health_check(): 110 | """Health check endpoint.""" 
111 | return {"status": "healthy"} 112 | 113 | 114 | def start(): 115 | """Start the API server.""" 116 | uvicorn.run( 117 | "src.api.main:app", 118 | host="0.0.0.0", 119 | port=8000, 120 | reload=config.debug, 121 | ) 122 | 123 | 124 | if __name__ == "__main__": 125 | start() 126 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ main, develop ] 6 | pull_request: 7 | branches: [ main, develop ] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.10", "3.11", "3.12"] 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | 24 | - name: Cache pip dependencies 25 | uses: actions/cache@v3 26 | with: 27 | path: ~/.cache/pip 28 | key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml') }} 29 | restore-keys: | 30 | ${{ runner.os }}-pip- 31 | 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | pip install -e ".[dev]" 36 | 37 | - name: Lint with black 38 | run: | 39 | black --check --diff src tests 40 | 41 | - name: Sort imports with isort 42 | run: | 43 | isort --check-only --diff src tests 44 | 45 | - name: Run unit tests 46 | run: | 47 | pytest tests/test_index_optimization.py -v -m "not slow" 48 | 49 | - name: Run integration tests 50 | run: | 51 | pytest tests/test_api_integration.py -v 52 | 53 | - name: Run all tests with coverage 54 | run: | 55 | pytest --cov=src --cov-report=xml --cov-report=term-missing 56 | 57 | - name: Upload coverage to Codecov 58 | uses: codecov/codecov-action@v3 59 | with: 60 | file: ./coverage.xml 61 | flags: unittests 62 | name: codecov-umbrella 63 | fail_ci_if_error: false 64 | 65 | 
test-existing-manifests: 66 | runs-on: ubuntu-latest 67 | needs: test 68 | 69 | steps: 70 | - uses: actions/checkout@v4 71 | 72 | - name: Set up Python 3.11 73 | uses: actions/setup-python@v4 74 | with: 75 | python-version: "3.11" 76 | 77 | - name: Install dependencies 78 | run: | 79 | python -m pip install --upgrade pip 80 | pip install -e ".[dev]" 81 | 82 | - name: Test with existing test manifests 83 | run: | 84 | python -c " 85 | import sys 86 | sys.path.insert(0, 'src') 87 | from src.rag.engine import RagEngine, K8sPolicyEnforcer 88 | from unittest.mock import patch, MagicMock 89 | 90 | # Mock external dependencies 91 | with patch('src.rag.engine.policy_loader') as mock_loader: 92 | mock_loader.load_policies.return_value = [ 93 | {'content': 'Test policy', 'source': 'test.txt'} 94 | ] 95 | 96 | with patch('src.rag.engine.HuggingFaceEmbedding'): 97 | with patch('src.core.llm_factory.llm_factory.setup_global_llm'): 98 | with patch('src.rag.engine.Settings'): 99 | with patch('src.rag.engine.VectorStoreIndex') as mock_index_class: 100 | mock_index = MagicMock() 101 | mock_query_engine = MagicMock() 102 | mock_index.as_query_engine.return_value = mock_query_engine 103 | mock_query_engine.query.return_value = 'No violations found.' 
104 | mock_index_class.from_documents.return_value = mock_index 105 | 106 | # Test the optimization 107 | engine = RagEngine() 108 | engine.build_index() 109 | enforcer = K8sPolicyEnforcer(engine) 110 | 111 | # Read and test existing manifests 112 | import os 113 | for manifest_file in ['tests/deployment-compliant.yaml', 'tests/deployment-with-violations.yaml']: 114 | if os.path.exists(manifest_file): 115 | with open(manifest_file, 'r') as f: 116 | manifest = f.read() 117 | violations = enforcer.enforce_policy(manifest) 118 | print(f'Tested {manifest_file}: {len(violations)} violations') 119 | 120 | print('✅ All existing manifest tests passed!') 121 | " -------------------------------------------------------------------------------- /src/core/llm_factory.py: -------------------------------------------------------------------------------- 1 | """LLM Factory for creating different LLM providers.""" 2 | 3 | import logging 4 | from typing import Optional 5 | 6 | from llama_index.core.llms import LLM 7 | from llama_index.core.settings import Settings 8 | 9 | from src.core.config import config, LLMProvider 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class LLMFactory: 15 | """Factory for creating LLM instances based on provider configuration.""" 16 | 17 | @staticmethod 18 | def create_llm(provider: Optional[LLMProvider] = None) -> LLM: 19 | """Create an LLM instance based on the provider. 20 | 21 | Args: 22 | provider: LLM provider to use. 
If None, uses config.llm.provider 23 | 24 | Returns: 25 | LLM instance 26 | 27 | Raises: 28 | ValueError: If provider is not supported or configuration is invalid 29 | """ 30 | provider = provider or config.llm.provider 31 | 32 | logger.info(f"Creating LLM with provider: {provider}") 33 | 34 | if provider == "anthropic": 35 | return LLMFactory._create_anthropic_llm() 36 | elif provider == "openai": 37 | return LLMFactory._create_openai_llm() 38 | elif provider == "llamastack": 39 | return LLMFactory._create_llamastack_llm() 40 | else: 41 | raise ValueError(f"Unsupported LLM provider: {provider}") 42 | 43 | @staticmethod 44 | def _create_anthropic_llm() -> LLM: 45 | """Create Anthropic LLM instance.""" 46 | try: 47 | from llama_index.llms.anthropic import Anthropic 48 | 49 | if not config.anthropic.api_key: 50 | raise ValueError("ANTHROPIC_API_KEY environment variable is required for Anthropic provider") 51 | 52 | llm = Anthropic( 53 | model=config.anthropic.model_name, 54 | api_key=config.anthropic.api_key, 55 | temperature=config.llm.temperature, 56 | max_tokens=config.llm.max_tokens, 57 | ) 58 | 59 | logger.info(f"Created Anthropic LLM with model: {config.anthropic.model_name}") 60 | return llm 61 | 62 | except ImportError as e: 63 | logger.error(f"Failed to import Anthropic LLM: {e}") 64 | raise ValueError("llama-index-llms-anthropic package is required for Anthropic provider") 65 | 66 | @staticmethod 67 | def _create_openai_llm() -> LLM: 68 | """Create OpenAI LLM instance.""" 69 | try: 70 | from llama_index.llms.openai import OpenAI 71 | 72 | if not config.openai.api_key: 73 | raise ValueError("OPENAI_API_KEY environment variable is required for OpenAI provider") 74 | 75 | llm = OpenAI( 76 | model=config.openai.model_name, 77 | api_key=config.openai.api_key, 78 | temperature=config.llm.temperature, 79 | max_tokens=config.llm.max_tokens, 80 | ) 81 | 82 | logger.info(f"Created OpenAI LLM with model: {config.openai.model_name}") 83 | return llm 84 | 85 | except 
ImportError as e: 86 | logger.error(f"Failed to import OpenAI LLM: {e}") 87 | raise ValueError("llama-index-llms-openai package is required for OpenAI provider") 88 | 89 | @staticmethod 90 | def _create_llamastack_llm() -> LLM: 91 | """Create LlamaStack LLM instance.""" 92 | try: 93 | from llama_stack_client import LlamaStackClient 94 | from src.rag.engine import LlamaStackLLM 95 | 96 | # Initialize LlamaStackClient with timeout configuration 97 | client = LlamaStackClient( 98 | base_url=config.llamastack.api_url, 99 | timeout=30.0, # 30 second timeout 100 | ) 101 | 102 | llm = LlamaStackLLM( 103 | client=client, 104 | model_id=config.llamastack.model_name 105 | ) 106 | 107 | logger.info(f"Created LlamaStack LLM with model: {config.llamastack.model_name}") 108 | return llm 109 | 110 | except ImportError as e: 111 | logger.error(f"Failed to import LlamaStack client: {e}") 112 | raise ValueError("llama-stack-client package is required for LlamaStack provider") 113 | 114 | @staticmethod 115 | def setup_global_llm(provider: Optional[LLMProvider] = None) -> LLM: 116 | """Set up global LLM settings. 117 | 118 | Args: 119 | provider: LLM provider to use. 
If None, uses config.llm.provider 120 | 121 | Returns: 122 | LLM instance that was set globally 123 | """ 124 | llm = LLMFactory.create_llm(provider) 125 | 126 | # Configure global settings 127 | Settings.llm = llm 128 | Settings.chunk_size = config.rag.chunk_size 129 | Settings.chunk_overlap = config.rag.chunk_overlap 130 | 131 | return llm 132 | 133 | 134 | # Create a singleton factory instance 135 | llm_factory = LLMFactory() 136 | -------------------------------------------------------------------------------- /tests/test_index_optimization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Tests for index building optimization.""" 3 | 4 | import pytest 5 | import unittest.mock as mock 6 | from unittest.mock import patch, MagicMock 7 | import sys 8 | import os 9 | 10 | # Add src to path for imports 11 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) 12 | 13 | from src.rag.engine import RagEngine, K8sPolicyEnforcer 14 | 15 | 16 | class TestIndexOptimization: 17 | """Test suite for index building optimization.""" 18 | 19 | def test_rag_engine_index_building(self): 20 | """Test that RagEngine builds index correctly.""" 21 | engine = RagEngine() 22 | 23 | # Initially no index 24 | assert engine.index is None 25 | 26 | # Build index 27 | engine.build_index() 28 | 29 | # Index should now exist 30 | assert engine.index is not None 31 | 32 | def test_k8s_policy_enforcer_uses_rag_engine_index(self): 33 | """Test that K8sPolicyEnforcer uses the RAG engine's index instead of building its own.""" 34 | engine = RagEngine() 35 | engine.build_index() 36 | 37 | # Create policy enforcer 38 | enforcer = K8sPolicyEnforcer(engine) 39 | 40 | # Should reference the same RAG engine 41 | assert enforcer.rag_engine is engine 42 | 43 | # Should not have its own policy_index attribute 44 | assert not hasattr(enforcer, 'policy_index') 45 | 46 | def test_api_server_startup_sequence(self): 47 | 
"""Test that API server builds index once at startup.""" 48 | # Mock the rag_engine singleton to track build_index calls 49 | with patch('src.rag.engine.rag_engine') as mock_rag_engine: 50 | mock_rag_engine.index = None 51 | mock_rag_engine.build_index = MagicMock() 52 | 53 | # Import the API main module (simulates startup) 54 | with patch('src.api.main.rag_engine', mock_rag_engine): 55 | import src.api.main 56 | 57 | # Verify build_index was called during module import 58 | mock_rag_engine.build_index.assert_called_once() 59 | 60 | def test_cli_backwards_compatibility(self): 61 | """Test that CLI mode still works with on-demand index building.""" 62 | # Create new RAG engine instance (simulates CLI usage) 63 | engine = RagEngine() 64 | 65 | # Create policy enforcer 66 | enforcer = K8sPolicyEnforcer(engine) 67 | 68 | # Mock the build_index method to track calls 69 | with patch.object(engine, 'build_index', wraps=engine.build_index) as mock_build: 70 | with patch.object(engine, 'index', None): 71 | # Test manifest 72 | test_manifest = """ 73 | apiVersion: apps/v1 74 | kind: Deployment 75 | metadata: 76 | name: test 77 | spec: 78 | replicas: 1 79 | template: 80 | spec: 81 | containers: 82 | - name: test 83 | image: nginx:latest 84 | """ 85 | 86 | # Mock the index to avoid actual LLM calls 87 | mock_index = MagicMock() 88 | mock_query_engine = MagicMock() 89 | mock_index.as_query_engine.return_value = mock_query_engine 90 | mock_query_engine.query.return_value = "No policy violations found." 
91 | 92 | # Set up the mock to return our mock index when build_index is called 93 | def mock_build_side_effect(): 94 | engine.index = mock_index 95 | 96 | mock_build.side_effect = mock_build_side_effect 97 | 98 | # Call enforce_policy (should trigger index building) 99 | violations = enforcer.enforce_policy(test_manifest) 100 | 101 | # Verify build_index was called 102 | mock_build.assert_called_once() 103 | 104 | # Should return empty violations for "No policy violations found" 105 | assert isinstance(violations, list) 106 | 107 | def test_policy_enforcement_reuses_existing_index(self): 108 | """Test that policy enforcement reuses existing index without rebuilding.""" 109 | engine = RagEngine() 110 | 111 | # Mock the index and build_index 112 | mock_index = MagicMock() 113 | mock_query_engine = MagicMock() 114 | mock_index.as_query_engine.return_value = mock_query_engine 115 | mock_query_engine.query.return_value = "No policy violations found." 116 | 117 | # Set the index (simulates already built) 118 | engine.index = mock_index 119 | 120 | enforcer = K8sPolicyEnforcer(engine) 121 | 122 | # Mock build_index to track calls 123 | with patch.object(engine, 'build_index') as mock_build: 124 | test_manifest = """ 125 | apiVersion: apps/v1 126 | kind: Deployment 127 | metadata: 128 | name: test 129 | spec: 130 | replicas: 1 131 | template: 132 | spec: 133 | containers: 134 | - name: test 135 | image: nginx:latest 136 | """ 137 | 138 | # Call enforce_policy 139 | violations = enforcer.enforce_policy(test_manifest) 140 | 141 | # build_index should NOT be called since index already exists 142 | mock_build.assert_not_called() 143 | 144 | # Should still work correctly 145 | assert isinstance(violations, list) 146 | 147 | def test_multiple_policy_enforcements_single_index(self): 148 | """Test that multiple policy enforcements use the same index.""" 149 | engine = RagEngine() 150 | 151 | # Mock the index 152 | mock_index = MagicMock() 153 | mock_query_engine = MagicMock() 154 
| mock_index.as_query_engine.return_value = mock_query_engine 155 | mock_query_engine.query.return_value = "No policy violations found." 156 | 157 | # Build index once 158 | with patch.object(engine, 'build_index') as mock_build: 159 | def mock_build_side_effect(): 160 | engine.index = mock_index 161 | mock_build.side_effect = mock_build_side_effect 162 | 163 | enforcer = K8sPolicyEnforcer(engine) 164 | 165 | test_manifest = """ 166 | apiVersion: apps/v1 167 | kind: Deployment 168 | metadata: 169 | name: test 170 | spec: 171 | replicas: 1 172 | template: 173 | spec: 174 | containers: 175 | - name: test 176 | image: nginx:latest 177 | """ 178 | 179 | # First enforcement - should build index 180 | violations1 = enforcer.enforce_policy(test_manifest) 181 | assert mock_build.call_count == 1 182 | 183 | # Second enforcement - should reuse existing index 184 | violations2 = enforcer.enforce_policy(test_manifest) 185 | assert mock_build.call_count == 1 # Still only called once 186 | 187 | # Both should work 188 | assert isinstance(violations1, list) 189 | assert isinstance(violations2, list) 190 | 191 | def test_manifest_parsing_integration(self): 192 | """Test that manifest parsing works correctly with the optimized index.""" 193 | engine = RagEngine() 194 | enforcer = K8sPolicyEnforcer(engine) 195 | 196 | # Test valid YAML manifest 197 | valid_manifest = """ 198 | apiVersion: apps/v1 199 | kind: Deployment 200 | metadata: 201 | name: test-app 202 | spec: 203 | replicas: 2 204 | selector: 205 | matchLabels: 206 | app: test-app 207 | template: 208 | metadata: 209 | labels: 210 | app: test-app 211 | spec: 212 | containers: 213 | - name: test-container 214 | image: nginx:latest 215 | """ 216 | 217 | # Should parse without errors 218 | parsed = enforcer._parse_manifest(valid_manifest) 219 | assert len(parsed) == 1 220 | assert parsed[0]['kind'] == 'Deployment' 221 | assert parsed[0]['metadata']['name'] == 'test-app' 222 | 223 | def test_invalid_manifest_handling(self): 224 | 
"""Test that invalid manifests are handled correctly.""" 225 | engine = RagEngine() 226 | enforcer = K8sPolicyEnforcer(engine) 227 | 228 | # Test invalid YAML 229 | invalid_manifest = """ 230 | apiVersion: apps/v1 231 | kind: Deployment 232 | metadata: 233 | name: test-app 234 | spec: 235 | replicas: 2 236 | invalid_yaml: [unclosed 237 | """ 238 | 239 | # Should raise ValueError for invalid YAML 240 | with pytest.raises(ValueError, match="Invalid YAML manifest"): 241 | enforcer._parse_manifest(invalid_manifest) 242 | 243 | def test_empty_manifest_handling(self): 244 | """Test that empty manifests are handled correctly.""" 245 | engine = RagEngine() 246 | enforcer = K8sPolicyEnforcer(engine) 247 | 248 | # Test empty manifest 249 | empty_manifest = "" 250 | 251 | # Should raise ValueError for empty manifest 252 | with pytest.raises(ValueError, match="No valid documents found"): 253 | enforcer._parse_manifest(empty_manifest) 254 | 255 | 256 | if __name__ == '__main__': 257 | pytest.main([__file__]) -------------------------------------------------------------------------------- /tests/test_api_integration.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Integration tests for API endpoints with index optimization.""" 3 | 4 | import pytest 5 | from fastapi.testclient import TestClient 6 | from unittest.mock import patch, MagicMock 7 | import sys 8 | import os 9 | 10 | # Add src to path for imports 11 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) 12 | 13 | 14 | class TestAPIIntegration: 15 | """Test suite for API integration with index optimization.""" 16 | 17 | def test_startup_builds_index(self): 18 | """Test that API startup builds the index.""" 19 | mock_rag_engine = MagicMock() 20 | mock_rag_engine.build_index = MagicMock() 21 | 22 | with patch('src.rag.engine.rag_engine', mock_rag_engine): 23 | with patch('src.api.main.rag_engine', mock_rag_engine): 24 | # Clear 
module cache 25 | modules_to_clear = [m for m in sys.modules.keys() if m.startswith('src.api.main')] 26 | for module in modules_to_clear: 27 | if module in sys.modules: 28 | del sys.modules[module] 29 | 30 | # Import the app (simulates startup) 31 | from src.api.main import app 32 | 33 | # Verify build_index was called at least once during startup 34 | assert mock_rag_engine.build_index.call_count >= 1 35 | 36 | def test_health_endpoint_simple(self): 37 | """Test the /health endpoint works.""" 38 | with patch('src.rag.engine.rag_engine') as mock_rag_engine: 39 | with patch('src.api.main.rag_engine', mock_rag_engine): 40 | # Clear module cache 41 | modules_to_clear = [m for m in sys.modules.keys() if m.startswith('src.api.main')] 42 | for module in modules_to_clear: 43 | if module in sys.modules: 44 | del sys.modules[module] 45 | 46 | from src.api.main import app 47 | client = TestClient(app) 48 | 49 | response = client.get("/health") 50 | assert response.status_code == 200 51 | data = response.json() 52 | assert data["status"] == "healthy" 53 | 54 | def test_query_endpoint_integration(self): 55 | """Test query endpoint with actual integration.""" 56 | mock_rag_engine = MagicMock() 57 | mock_rag_engine.build_index = MagicMock() 58 | mock_rag_engine.query.return_value = { 59 | "answer": "Test policy answer", 60 | "sources": [{"source": "test.txt", "text": "Test content"}], 61 | "metadata": {"provider": "test", "model": "test-model"} 62 | } 63 | 64 | with patch('src.rag.engine.rag_engine', mock_rag_engine): 65 | with patch('src.api.main.rag_engine', mock_rag_engine): 66 | # Clear module cache 67 | modules_to_clear = [m for m in sys.modules.keys() if m.startswith('src.api.main')] 68 | for module in modules_to_clear: 69 | if module in sys.modules: 70 | del sys.modules[module] 71 | 72 | from src.api.main import app 73 | client = TestClient(app) 74 | 75 | response = client.post("/query", json={"query": "What is the password policy?"}) 76 | 77 | assert response.status_code 
== 200 78 | data = response.json() 79 | assert "answer" in data 80 | assert "sources" in data 81 | assert "metadata" in data 82 | 83 | # Verify the RAG engine query was called 84 | mock_rag_engine.query.assert_called_with("What is the password policy?") 85 | 86 | def test_validate_manifest_endpoint_structure(self): 87 | """Test that the validate manifest endpoint has the correct structure.""" 88 | with patch('src.rag.engine.rag_engine') as mock_rag_engine: 89 | with patch('src.api.main.rag_engine', mock_rag_engine): 90 | # Clear module cache 91 | modules_to_clear = [m for m in sys.modules.keys() if m.startswith('src.api.main')] 92 | for module in modules_to_clear: 93 | if module in sys.modules: 94 | del sys.modules[module] 95 | 96 | from src.api.main import app 97 | client = TestClient(app) 98 | 99 | test_manifest = """ 100 | apiVersion: apps/v1 101 | kind: Deployment 102 | metadata: 103 | name: test-app 104 | spec: 105 | replicas: 1 106 | template: 107 | spec: 108 | containers: 109 | - name: test 110 | image: nginx:latest 111 | """ 112 | 113 | response = client.post("/validate-manifest", json={"manifest": test_manifest}) 114 | 115 | assert response.status_code == 200 116 | data = response.json() 117 | 118 | # Check response structure 119 | assert "violations" in data 120 | assert "compliant" in data 121 | assert "metadata" in data 122 | assert isinstance(data["violations"], list) 123 | assert isinstance(data["compliant"], bool) 124 | assert isinstance(data["metadata"], dict) 125 | assert "violation_count" in data["metadata"] 126 | assert "error_count" in data["metadata"] 127 | assert "warning_count" in data["metadata"] 128 | 129 | def test_query_endpoint_handles_errors(self): 130 | """Test query endpoint error handling.""" 131 | mock_rag_engine = MagicMock() 132 | mock_rag_engine.build_index = MagicMock() 133 | mock_rag_engine.query.side_effect = Exception("Test error") 134 | 135 | with patch('src.rag.engine.rag_engine', mock_rag_engine): 136 | with 
patch('src.api.main.rag_engine', mock_rag_engine): 137 | # Clear module cache 138 | modules_to_clear = [m for m in sys.modules.keys() if m.startswith('src.api.main')] 139 | for module in modules_to_clear: 140 | if module in sys.modules: 141 | del sys.modules[module] 142 | 143 | from src.api.main import app 144 | client = TestClient(app) 145 | 146 | response = client.post("/query", json={"query": "test query"}) 147 | 148 | assert response.status_code == 500 149 | assert "Error processing query" in response.json()["detail"] 150 | 151 | def test_multiple_query_requests_reuse_index(self): 152 | """Test that multiple query requests reuse the same index without rebuilding.""" 153 | mock_rag_engine = MagicMock() 154 | mock_rag_engine.build_index = MagicMock() 155 | mock_rag_engine.query.return_value = { 156 | "answer": "Test answer", 157 | "sources": [], 158 | "metadata": {} 159 | } 160 | 161 | with patch('src.rag.engine.rag_engine', mock_rag_engine): 162 | with patch('src.api.main.rag_engine', mock_rag_engine): 163 | # Clear module cache 164 | modules_to_clear = [m for m in sys.modules.keys() if m.startswith('src.api.main')] 165 | for module in modules_to_clear: 166 | if module in sys.modules: 167 | del sys.modules[module] 168 | 169 | from src.api.main import app 170 | client = TestClient(app) 171 | 172 | # Reset call count after startup 173 | mock_rag_engine.build_index.reset_mock() 174 | 175 | # Make multiple requests 176 | client.post("/query", json={"query": "First query"}) 177 | client.post("/query", json={"query": "Second query"}) 178 | client.post("/query", json={"query": "Third query"}) 179 | 180 | # build_index should not be called again after startup 181 | mock_rag_engine.build_index.assert_not_called() 182 | 183 | # But the engine query method should be called 184 | assert mock_rag_engine.query.call_count == 3 185 | 186 | def test_api_endpoints_exist(self): 187 | """Test that all expected API endpoints exist and return valid responses.""" 188 | with 
patch('src.rag.engine.rag_engine') as mock_rag_engine: 189 | mock_rag_engine.query.return_value = { 190 | "answer": "Test", 191 | "sources": [], 192 | "metadata": {} 193 | } 194 | 195 | with patch('src.api.main.rag_engine', mock_rag_engine): 196 | # Clear module cache 197 | modules_to_clear = [m for m in sys.modules.keys() if m.startswith('src.api.main')] 198 | for module in modules_to_clear: 199 | if module in sys.modules: 200 | del sys.modules[module] 201 | 202 | from src.api.main import app 203 | client = TestClient(app) 204 | 205 | # Test health endpoint 206 | response = client.get("/health") 207 | assert response.status_code == 200 208 | 209 | # Test query endpoint 210 | response = client.post("/query", json={"query": "test"}) 211 | assert response.status_code == 200 212 | 213 | # Test validate manifest endpoint 214 | response = client.post("/validate-manifest", json={"manifest": "apiVersion: v1\nkind: Pod"}) 215 | assert response.status_code == 200 216 | 217 | def test_invalid_requests_handled(self): 218 | """Test that invalid requests are handled properly.""" 219 | with patch('src.rag.engine.rag_engine') as mock_rag_engine: 220 | with patch('src.api.main.rag_engine', mock_rag_engine): 221 | # Clear module cache 222 | modules_to_clear = [m for m in sys.modules.keys() if m.startswith('src.api.main')] 223 | for module in modules_to_clear: 224 | if module in sys.modules: 225 | del sys.modules[module] 226 | 227 | from src.api.main import app 228 | client = TestClient(app) 229 | 230 | # Test query endpoint with missing data 231 | response = client.post("/query", json={}) 232 | assert response.status_code == 422 # Validation error 233 | 234 | # Test validate manifest endpoint with missing data 235 | response = client.post("/validate-manifest", json={}) 236 | assert response.status_code == 422 # Validation error 237 | 238 | 239 | if __name__ == '__main__': 240 | pytest.main([__file__]) -------------------------------------------------------------------------------- 
/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/core/cli.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Command-line interface for the policy engine.""" 3 | 4 | import logging 5 | import click 6 | import yaml 7 | import httpx 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | from src.rag.engine import RagEngine, K8sPolicyEnforcer 12 | from src.core.config import config 13 | 14 | # Configure logging 15 | logging.basicConfig(level=logging.INFO) 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def _override_config(provider: str, model: Optional[str] = None): 20 | """Override configuration with CLI parameters.""" 21 | import os 22 | 23 | # Override provider 24 | os.environ['LLM_PROVIDER'] = provider 25 | config.llm.provider = provider 26 | 27 | # Override model if specified 28 | if model: 29 | if provider == 'anthropic': 30 | os.environ['ANTHROPIC_MODEL'] = model 31 | config.anthropic.model_name = model 32 | elif provider == 'openai': 33 | os.environ['OPENAI_MODEL'] = model 34 | 
config.openai.model_name = model 35 | elif provider == 'llamastack': 36 | os.environ['LLAMASTACK_MODEL'] = model 37 | config.llamastack.model_name = model 38 | 39 | @click.group() 40 | @click.option('--provider', '-p', 41 | type=click.Choice(['llamastack', 'anthropic', 'openai']), 42 | help='LLM provider to use (overrides config, local mode only)') 43 | @click.option('--model', '-m', help='Model name to use (overrides config, local mode only)') 44 | @click.option('--api-url', default='http://localhost:8000', 45 | help='API server URL for remote mode (default: http://localhost:8000)') 46 | @click.option('--use-api', is_flag=True, 47 | help='Use API server instead of local processing (more efficient, avoids re-indexing)') 48 | @click.pass_context 49 | def cli(ctx, provider, model, api_url, use_api): 50 | """Policy engine CLI. 51 | 52 | Two modes of operation: 53 | 54 | 1. LOCAL MODE (default): Processes policies locally, builds RAG index on each run 55 | 56 | 2. API MODE (--use-api): Queries a running API server, more efficient as it 57 | reuses pre-built RAG indices and avoids re-parsing policy documents 58 | """ 59 | # Store options in context for subcommands 60 | ctx.ensure_object(dict) 61 | ctx.obj['provider'] = provider 62 | ctx.obj['model'] = model 63 | ctx.obj['api_url'] = api_url 64 | ctx.obj['use_api'] = use_api 65 | 66 | @cli.command() 67 | @click.argument('query') 68 | @click.pass_context 69 | def ask(ctx, query: str): 70 | """Ask a question about company policies.""" 71 | use_api = ctx.obj.get('use_api', False) 72 | api_url = ctx.obj.get('api_url', 'http://localhost:8000') 73 | 74 | if use_api: 75 | # Use API server 76 | try: 77 | with httpx.Client() as client: 78 | response = client.post( 79 | f"{api_url}/query", 80 | json={"query": query}, 81 | timeout=30.0 82 | ) 83 | response.raise_for_status() 84 | result = response.json() 85 | 86 | click.echo("\nAnswer:") 87 | click.echo(result["answer"]) 88 | 89 | if result.get("metadata"): 90 | 
click.echo(f"\nMode: API Server ({api_url})") 91 | 92 | if result.get("sources"): 93 | click.echo("\nSources:") 94 | for source in result["sources"]: 95 | click.echo(f"- {source['source']}") 96 | 97 | except httpx.RequestError as e: 98 | logger.error(f"Error connecting to API server: {e}") 99 | raise click.ClickException(f"Failed to connect to API server at {api_url}: {e}") 100 | except httpx.HTTPStatusError as e: 101 | logger.error(f"API server error: {e}") 102 | raise click.ClickException(f"API server error: {e.response.status_code} - {e.response.text}") 103 | except Exception as e: 104 | logger.error(f"Error processing query via API: {e}") 105 | raise click.ClickException(f"Failed to process query via API: {e}") 106 | else: 107 | # Use local processing 108 | provider = ctx.obj.get('provider') 109 | model = ctx.obj.get('model') 110 | 111 | if provider: 112 | _override_config(provider, model) 113 | 114 | try: 115 | engine = RagEngine() 116 | result = engine.query(query) 117 | 118 | click.echo("\nAnswer:") 119 | click.echo(result["answer"]) 120 | 121 | click.echo(f"\nUsed: {result['metadata']['provider']} - {result['metadata']['model']}") 122 | 123 | if result["sources"]: 124 | click.echo("\nSources:") 125 | for source in result["sources"]: 126 | click.echo(f"- {source['source']}") 127 | 128 | except Exception as e: 129 | logger.error(f"Error processing query: {e}") 130 | click.echo(f"Error: {e}", err=True) 131 | raise click.ClickException(f"Failed to process query: {e}") 132 | 133 | @cli.command() 134 | @click.argument('manifest_path', type=click.Path(exists=True)) 135 | @click.option('--output', '-o', type=click.Path(), help='Output file for violations') 136 | @click.pass_context 137 | def validate_manifest(ctx, manifest_path: str, output: Optional[str]): 138 | """Validate a Kubernetes manifest against company policies.""" 139 | use_api = ctx.obj.get('use_api', False) 140 | api_url = ctx.obj.get('api_url', 'http://localhost:8000') 141 | 142 | # Read manifest file 
143 | try: 144 | with open(manifest_path, 'r') as f: 145 | manifest = f.read() 146 | except Exception as e: 147 | logger.error(f"Failed to read manifest file: {e}") 148 | raise click.ClickException(f"Failed to read manifest file: {e}") 149 | 150 | if use_api: 151 | # Use API server 152 | try: 153 | with httpx.Client() as client: 154 | response = client.post( 155 | f"{api_url}/validate-manifest", 156 | json={"manifest": manifest}, 157 | timeout=60.0 # Longer timeout for manifest validation 158 | ) 159 | response.raise_for_status() 160 | result = response.json() 161 | 162 | violations = result.get("violations", []) 163 | 164 | if violations: 165 | click.echo(f"\nFound {len(violations)} policy violations in {manifest_path}:") 166 | for violation in violations: 167 | click.echo(f"\nRule: {violation['rule']}") 168 | click.echo(f"Violation: {violation['violation']}") 169 | click.echo(f"Severity: {violation['severity']}") 170 | 171 | # Write violations to output file if specified 172 | if output: 173 | try: 174 | with open(output, 'w') as f: 175 | yaml.dump(violations, f) 176 | click.echo(f"\nViolations written to {output}") 177 | except Exception as e: 178 | logger.error(f"Failed to write violations to output file: {e}") 179 | raise click.ClickException(f"Failed to write violations to output file: {e}") 180 | 181 | click.echo(f"\nMode: API Server ({api_url})") 182 | # Exit with error code if there are violations 183 | raise click.ClickException("Policy violations found") 184 | else: 185 | click.echo(f"\nNo policy violations found in {manifest_path}") 186 | click.echo(f"Mode: API Server ({api_url})") 187 | 188 | except httpx.RequestError as e: 189 | logger.error(f"Error connecting to API server: {e}") 190 | raise click.ClickException(f"Failed to connect to API server at {api_url}: {e}") 191 | except httpx.HTTPStatusError as e: 192 | logger.error(f"API server error: {e}") 193 | raise click.ClickException(f"API server error: {e.response.status_code} - {e.response.text}") 
194 | except Exception as e: 195 | logger.error(f"Error validating manifest via API: {e}") 196 | raise click.ClickException(f"Failed to validate manifest via API: {e}") 197 | else: 198 | # Use local processing 199 | provider = ctx.obj.get('provider') 200 | model = ctx.obj.get('model') 201 | 202 | if provider: 203 | _override_config(provider, model) 204 | 205 | # Initialize RAG engine and policy enforcer 206 | engine = RagEngine() 207 | policy_enforcer = K8sPolicyEnforcer(engine) 208 | 209 | # Validate manifest 210 | try: 211 | violations = policy_enforcer.enforce_policy(manifest) 212 | except Exception as e: 213 | logger.error(f"Failed to validate manifest: {e}") 214 | raise click.ClickException(f"Failed to validate manifest: {e}") 215 | 216 | # Output results 217 | if violations: 218 | click.echo(f"\nFound {len(violations)} policy violations in {manifest_path}:") 219 | for violation in violations: 220 | click.echo(f"\nRule: {violation.rule}") 221 | click.echo(f"Violation: {violation.violation}") 222 | click.echo(f"Severity: {violation.severity}") 223 | 224 | # Write violations to output file if specified 225 | if output: 226 | try: 227 | with open(output, 'w') as f: 228 | yaml.dump([v.dict() for v in violations], f) 229 | click.echo(f"\nViolations written to {output}") 230 | except Exception as e: 231 | logger.error(f"Failed to write violations to output file: {e}") 232 | raise click.ClickException(f"Failed to write violations to output file: {e}") 233 | 234 | # Exit with error code if there are violations 235 | raise click.ClickException("Policy violations found") 236 | else: 237 | click.echo(f"\nNo policy violations found in {manifest_path}") 238 | 239 | @cli.command() 240 | def providers(): 241 | """List available LLM providers and their configuration status.""" 242 | click.echo("Available LLM Providers:") 243 | click.echo("=" * 50) 244 | 245 | # Check LlamaStack 246 | click.echo("\n🦙 LlamaStack:") 247 | click.echo(f" URL: {config.llamastack.api_url}") 248 | 
click.echo(f" Model: {config.llamastack.model_name}") 249 | try: 250 | from llama_stack_client import LlamaStackClient 251 | client = LlamaStackClient(base_url=config.llamastack.api_url, timeout=5.0) 252 | # Try a simple health check - this might fail but that's ok 253 | click.echo(" Status: ✅ Client available") 254 | except Exception as e: 255 | click.echo(f" Status: ❌ {str(e)[:50]}...") 256 | 257 | # Check Anthropic 258 | click.echo("\n🤖 Anthropic:") 259 | click.echo(f" Model: {config.anthropic.model_name}") 260 | if config.anthropic.api_key: 261 | click.echo(" API Key: ✅ Set") 262 | try: 263 | from llama_index.llms.anthropic import Anthropic 264 | click.echo(" Status: ✅ Available") 265 | except ImportError: 266 | click.echo(" Status: ❌ Package not installed (llama-index-llms-anthropic)") 267 | else: 268 | click.echo(" API Key: ❌ Not set (ANTHROPIC_API_KEY)") 269 | click.echo(" Status: ❌ Not configured") 270 | 271 | # Check OpenAI 272 | click.echo("\n🧠 OpenAI:") 273 | click.echo(f" Model: {config.openai.model_name}") 274 | if config.openai.api_key: 275 | click.echo(" API Key: ✅ Set") 276 | try: 277 | from llama_index.llms.openai import OpenAI 278 | click.echo(" Status: ✅ Available") 279 | except ImportError: 280 | click.echo(" Status: ❌ Package not installed (llama-index-llms-openai)") 281 | else: 282 | click.echo(" API Key: ❌ Not set (OPENAI_API_KEY)") 283 | click.echo(" Status: ❌ Not configured") 284 | 285 | click.echo(f"\nCurrent Provider: {config.llm.provider}") 286 | click.echo("=" * 50) 287 | 288 | def main(): 289 | """Main entry point for the CLI.""" 290 | cli() 291 | 292 | if __name__ == '__main__': 293 | main() 294 | -------------------------------------------------------------------------------- /src/rag/engine.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Optional, AsyncGenerator, Generator, Union 2 | import logging 3 | import yaml 4 | 5 | from llama_index.core import 
Document, VectorStoreIndex 6 | from llama_index.core.node_parser import SimpleNodeParser 7 | from llama_index.core.settings import Settings 8 | from llama_index.core.llms import LLM, ChatMessage, ChatResponse, CompletionResponse, LLMMetadata 9 | from llama_index.embeddings.huggingface import HuggingFaceEmbedding 10 | from llama_stack_client import LlamaStackClient 11 | from llama_stack_client.types import Model 12 | from pydantic import Field, BaseModel 13 | 14 | from src.core.config import config 15 | from src.policy.loader import policy_loader 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | class PolicyViolation(BaseModel): 20 | """Model for policy violations.""" 21 | rule: str 22 | manifest_path: str 23 | violation: str 24 | severity: str = "error" 25 | 26 | class K8sPolicyEnforcer: 27 | """Enforces Kubernetes manifest policies using natural language rules.""" 28 | 29 | def __init__(self, rag_engine: 'RagEngine'): 30 | """Initialize the K8s policy enforcer. 31 | 32 | Args: 33 | rag_engine: RAG engine instance to use for policy queries 34 | """ 35 | self.rag_engine = rag_engine 36 | 37 | def _parse_manifest(self, manifest: str) -> List[Dict[str, Any]]: 38 | """Parse a Kubernetes manifest (potentially containing multiple documents). 39 | 40 | Args: 41 | manifest: YAML manifest string 42 | 43 | Returns: 44 | List of parsed manifest dictionaries 45 | """ 46 | try: 47 | # Handle multiple documents separated by --- 48 | documents = list(yaml.safe_load_all(manifest)) 49 | # Filter out None/empty documents 50 | documents = [doc for doc in documents if doc is not None] 51 | if not documents: 52 | raise ValueError("No valid documents found in manifest") 53 | return documents 54 | except yaml.YAMLError as e: 55 | logger.error(f"Failed to parse manifest: {e}") 56 | raise ValueError(f"Invalid YAML manifest: {e}") 57 | 58 | def _format_manifest_for_prompt(self, manifests: List[Dict[str, Any]]) -> str: 59 | """Format manifests for inclusion in prompt. 
60 | 61 | Args: 62 | manifests: List of parsed manifest dictionaries 63 | 64 | Returns: 65 | Formatted manifest string 66 | """ 67 | formatted_docs = [] 68 | for i, manifest in enumerate(manifests): 69 | formatted_docs.append(f"# Document {i+1}: {manifest.get('kind', 'Unknown')} - {manifest.get('metadata', {}).get('name', 'Unnamed')}") 70 | formatted_docs.append(yaml.dump(manifest, default_flow_style=False)) 71 | formatted_docs.append("---") 72 | return "\n".join(formatted_docs) 73 | 74 | def enforce_policy(self, manifest: Union[str, Dict[str, Any], List[Dict[str, Any]]]) -> List[PolicyViolation]: 75 | """Enforce policy on a Kubernetes manifest. 76 | 77 | Args: 78 | manifest: Kubernetes manifest as YAML string, parsed dictionary, or list of dictionaries 79 | 80 | Returns: 81 | List of policy violations found 82 | """ 83 | # Ensure the RAG engine's index is built (for CLI compatibility) 84 | if not self.rag_engine.index: 85 | self.rag_engine.build_index() 86 | 87 | if isinstance(manifest, str): 88 | manifests = self._parse_manifest(manifest) 89 | elif isinstance(manifest, dict): 90 | manifests = [manifest] 91 | else: 92 | manifests = manifest 93 | 94 | formatted_manifest = self._format_manifest_for_prompt(manifests) 95 | 96 | # Create query engine using the RAG engine's index 97 | query_engine = self.rag_engine.index.as_query_engine( 98 | similarity_top_k=config.rag.similarity_top_k, 99 | ) 100 | 101 | # Construct prompt for policy enforcement 102 | prompt = f"""Analyze this Kubernetes manifest against our company policies: 103 | 104 | {formatted_manifest} 105 | 106 | Please check if this manifest violates any of our policies. For each violation found, provide: 107 | 1. The specific policy rule that was violated 108 | 2. The exact part of the manifest that violates the rule 109 | 3. The severity of the violation (error or warning) 110 | 111 | If no violations are found, respond with "No policy violations found." 
112 | 113 | Format your response as a list of violations, one per line, with each violation containing: 114 | - Rule: [policy rule] 115 | - Violation: [description of violation] 116 | - Severity: [error/warning] 117 | """ 118 | 119 | # Execute query 120 | response = query_engine.query(prompt) 121 | 122 | # Parse violations from response 123 | violations = [] 124 | if "No policy violations found" not in str(response): 125 | # Parse the response to extract violations 126 | # This is a simple implementation - you might want to make it more robust 127 | lines = str(response).split('\n') 128 | current_violation = {} 129 | 130 | for line in lines: 131 | if line.startswith('- Rule:'): 132 | if current_violation: 133 | violations.append(PolicyViolation(**current_violation)) 134 | current_violation = {'rule': line[7:].strip()} 135 | elif line.startswith('- Violation:'): 136 | current_violation['violation'] = line[12:].strip() 137 | elif line.startswith('- Severity:'): 138 | current_violation['severity'] = line[11:].strip() 139 | current_violation['manifest_path'] = 'root' # You might want to make this more specific 140 | 141 | if current_violation: 142 | violations.append(PolicyViolation(**current_violation)) 143 | 144 | return violations 145 | 146 | 147 | class LlamaStackLLM(LLM): 148 | """Wrapper for LlamaStack LLM to work with LlamaIndex.""" 149 | 150 | client: LlamaStackClient = Field(description="LlamaStackClient instance") 151 | model_id: str = Field(description="Model ID to use") 152 | 153 | def __init__(self, client: LlamaStackClient, model_id: str): 154 | """Initialize the LlamaStack LLM wrapper. 
155 | 156 | Args: 157 | client: LlamaStackClient instance 158 | model_id: Model ID to use 159 | """ 160 | super().__init__(client=client, model_id=model_id) 161 | 162 | @property 163 | def metadata(self) -> LLMMetadata: 164 | """Get LLM metadata.""" 165 | return LLMMetadata( 166 | model_name=self.model_id, 167 | is_chat_model=True, 168 | is_function_calling_model=False, 169 | context_window=4096, # Default value, adjust based on your model 170 | num_output=2048, # Default value, adjust based on your model 171 | ) 172 | 173 | def complete(self, prompt: str, **kwargs) -> CompletionResponse: 174 | """Complete the prompt using LlamaStack. 175 | 176 | Args: 177 | prompt: The prompt to complete 178 | **kwargs: Additional arguments 179 | 180 | Returns: 181 | The completed text 182 | """ 183 | try: 184 | response = self.client.inference.chat_completion( 185 | model_id=self.model_id, 186 | messages=[ 187 | {"role": "system", "content": "You are a helpful AI assistant."}, 188 | {"role": "user", "content": prompt} 189 | ] 190 | ) 191 | # Handle the response format from LlamaStack 192 | if hasattr(response, 'message'): 193 | content = response.message.content 194 | elif hasattr(response, 'content'): 195 | content = response.content 196 | else: 197 | raise ValueError("Unexpected response format from LlamaStack") 198 | 199 | return CompletionResponse(text=content) 200 | 201 | except Exception as e: 202 | logger.error(f"Error in LlamaStack completion: {e}") 203 | # Return a fallback response 204 | return CompletionResponse(text=f"Error: Unable to complete request due to {type(e).__name__}: {str(e)}") 205 | 206 | def stream_complete(self, prompt: str, **kwargs) -> Generator[CompletionResponse, None, None]: 207 | """Stream complete the prompt using LlamaStack. 
208 | 209 | Args: 210 | prompt: The prompt to complete 211 | **kwargs: Additional arguments 212 | 213 | Yields: 214 | The completed text chunks 215 | """ 216 | response = self.client.inference.chat_completion( 217 | model_id=self.model_id, 218 | messages=[ 219 | {"role": "system", "content": "You are a helpful AI assistant."}, 220 | {"role": "user", "content": prompt} 221 | ], 222 | stream=True 223 | ) 224 | for chunk in response: 225 | if hasattr(chunk, 'delta') and hasattr(chunk.delta, 'content') and chunk.delta.content: 226 | yield CompletionResponse(text=chunk.delta.content) 227 | elif hasattr(chunk, 'content') and chunk.content: 228 | yield CompletionResponse(text=chunk.content) 229 | 230 | def chat(self, messages: List[ChatMessage], **kwargs) -> ChatResponse: 231 | """Chat with the model using LlamaStack. 232 | 233 | Args: 234 | messages: List of chat messages 235 | **kwargs: Additional arguments 236 | 237 | Returns: 238 | The chat response 239 | """ 240 | try: 241 | response = self.client.inference.chat_completion( 242 | model_id=self.model_id, 243 | messages=[{"role": msg.role.value, "content": msg.content} for msg in messages] 244 | ) 245 | except Exception as e: 246 | logger.error(f"Error in LlamaStack chat completion: {e}") 247 | # Return a fallback response 248 | return ChatResponse(message=ChatMessage(role="assistant", content=f"Error: Unable to complete request due to {type(e).__name__}: {str(e)}")) 249 | 250 | # Debug logging 251 | logger.debug(f"LlamaStack response type: {type(response)}") 252 | logger.debug(f"LlamaStack response attributes: {dir(response)}") 253 | 254 | # Handle the response format from LlamaStack 255 | try: 256 | if hasattr(response, 'completion_message'): 257 | content = response.completion_message.content 258 | elif hasattr(response, 'message'): 259 | content = response.message.content 260 | elif hasattr(response, 'content'): 261 | content = response.content 262 | elif hasattr(response, 'text'): 263 | content = response.text 264 
| elif hasattr(response, 'response'): 265 | content = response.response 266 | elif isinstance(response, str): 267 | content = response 268 | elif isinstance(response, dict): 269 | if 'completion_message' in response: 270 | content = response['completion_message'].get('content', '') 271 | elif 'message' in response: 272 | content = response['message'].get('content', '') 273 | elif 'content' in response: 274 | content = response['content'] 275 | elif 'text' in response: 276 | content = response['text'] 277 | elif 'response' in response: 278 | content = response['response'] 279 | else: 280 | raise ValueError(f"Unexpected response dict format: {response}") 281 | else: 282 | raise ValueError(f"Unexpected response type: {type(response)}") 283 | 284 | if not content: 285 | raise ValueError("Empty response content") 286 | 287 | return ChatResponse(message=ChatMessage(role="assistant", content=content)) 288 | 289 | except Exception as e: 290 | logger.error(f"Error processing LlamaStack response: {str(e)}") 291 | logger.error(f"Response object: {response}") 292 | raise ValueError(f"Failed to process LlamaStack response: {str(e)}") 293 | 294 | def stream_chat(self, messages: List[ChatMessage], **kwargs) -> Generator[ChatResponse, None, None]: 295 | """Stream chat with the model using LlamaStack. 
296 | 297 | Args: 298 | messages: List of chat messages 299 | **kwargs: Additional arguments 300 | 301 | Yields: 302 | The chat response chunks 303 | """ 304 | response = self.client.inference.chat_completion( 305 | model_id=self.model_id, 306 | messages=[{"role": msg.role.value, "content": msg.content} for msg in messages], 307 | stream=True 308 | ) 309 | 310 | for chunk in response: 311 | try: 312 | content = None 313 | if hasattr(chunk, 'completion_message'): 314 | content = chunk.completion_message.content 315 | elif hasattr(chunk, 'delta') and hasattr(chunk.delta, 'content'): 316 | content = chunk.delta.content 317 | elif hasattr(chunk, 'content'): 318 | content = chunk.content 319 | elif hasattr(chunk, 'text'): 320 | content = chunk.text 321 | elif hasattr(chunk, 'response'): 322 | content = chunk.response 323 | elif isinstance(chunk, str): 324 | content = chunk 325 | elif isinstance(chunk, dict): 326 | if 'completion_message' in chunk: 327 | content = chunk['completion_message'].get('content', '') 328 | elif 'delta' in chunk and 'content' in chunk['delta']: 329 | content = chunk['delta']['content'] 330 | elif 'content' in chunk: 331 | content = chunk['content'] 332 | elif 'text' in chunk: 333 | content = chunk['text'] 334 | elif 'response' in chunk: 335 | content = chunk['response'] 336 | 337 | if content: 338 | yield ChatResponse(message=ChatMessage(role="assistant", content=content)) 339 | 340 | except Exception as e: 341 | logger.error(f"Error processing LlamaStack stream chunk: {str(e)}") 342 | logger.error(f"Chunk object: {chunk}") 343 | continue 344 | 345 | async def acomplete(self, prompt: str, **kwargs) -> CompletionResponse: 346 | """Async complete the prompt using LlamaStack. 347 | 348 | Note: LlamaStack's client doesn't have native async support, so this is a wrapper 349 | around the synchronous complete method. For true async support, we would need to 350 | implement this using an async HTTP client. 
351 | 352 | Args: 353 | prompt: The prompt to complete 354 | **kwargs: Additional arguments 355 | 356 | Returns: 357 | The completed text 358 | """ 359 | return self.complete(prompt, **kwargs) 360 | 361 | async def astream_complete(self, prompt: str, **kwargs) -> AsyncGenerator[CompletionResponse, None]: 362 | """Async stream complete the prompt using LlamaStack. 363 | 364 | Note: LlamaStack's client doesn't have native async support, so this is a wrapper 365 | around the synchronous stream_complete method. For true async support, we would need to 366 | implement this using an async HTTP client. 367 | 368 | Args: 369 | prompt: The prompt to complete 370 | **kwargs: Additional arguments 371 | 372 | Yields: 373 | The completed text chunks 374 | """ 375 | for response in self.stream_complete(prompt, **kwargs): 376 | yield response 377 | 378 | async def achat(self, messages: List[ChatMessage], **kwargs) -> ChatResponse: 379 | """Async chat with the model using LlamaStack. 380 | 381 | Note: LlamaStack's client doesn't have native async support, so this is a wrapper 382 | around the synchronous chat method. For true async support, we would need to 383 | implement this using an async HTTP client. 384 | 385 | Args: 386 | messages: List of chat messages 387 | **kwargs: Additional arguments 388 | 389 | Returns: 390 | The chat response 391 | """ 392 | return self.chat(messages, **kwargs) 393 | 394 | async def astream_chat(self, messages: List[ChatMessage], **kwargs) -> AsyncGenerator[ChatResponse, None]: 395 | """Async stream chat with the model using LlamaStack. 396 | 397 | Note: LlamaStack's client doesn't have native async support, so this is a wrapper 398 | around the synchronous stream_chat method. For true async support, we would need to 399 | implement this using an async HTTP client. 
400 | 401 | Args: 402 | messages: List of chat messages 403 | **kwargs: Additional arguments 404 | 405 | Yields: 406 | The chat response chunks 407 | """ 408 | for response in self.stream_chat(messages, **kwargs): 409 | yield response 410 | 411 | 412 | class RagEngine: 413 | """Retrieval-Augmented Generation engine for policy queries.""" 414 | 415 | def __init__(self, llm: Optional[LLM] = None): 416 | """Initialize the RAG engine. 417 | 418 | Args: 419 | llm: LLM instance to use. If None, will use llama-stack-client. 420 | """ 421 | self.llm = llm 422 | self.index = None 423 | self._setup_llm() 424 | self._setup_embeddings() 425 | 426 | def _setup_llm(self): 427 | """Set up the LLM using the configured provider.""" 428 | try: 429 | from src.core.llm_factory import llm_factory 430 | 431 | # Use the provided LLM or create one using the factory 432 | if not self.llm: 433 | self.llm = llm_factory.setup_global_llm() 434 | else: 435 | # Configure global settings with the provided LLM 436 | Settings.llm = self.llm 437 | Settings.chunk_size = config.rag.chunk_size 438 | Settings.chunk_overlap = config.rag.chunk_overlap 439 | 440 | except Exception as e: 441 | logger.error(f"Failed to setup LLM: {e}") 442 | raise 443 | 444 | def _setup_embeddings(self): 445 | """Set up the embedding model.""" 446 | try: 447 | # Initialize HuggingFace embedding model 448 | embed_model = HuggingFaceEmbedding( 449 | model_name="BAAI/bge-small-en-v1.5" 450 | ) 451 | 452 | # Configure global settings 453 | Settings.embed_model = embed_model 454 | 455 | except ImportError as e: 456 | logger.error(f"Failed to import HuggingFace embedding model: {e}") 457 | raise 458 | 459 | def build_index(self): 460 | """Build the index from policy documents.""" 461 | # Load policy documents 462 | policy_docs = policy_loader.load_policies() 463 | 464 | if not policy_docs: 465 | logger.warning("No policy documents found.") 466 | return 467 | 468 | # Convert to LlamaIndex documents 469 | documents = [ 470 | 
Document(text=doc["content"], metadata={"source": doc["source"]}) 471 | for doc in policy_docs 472 | ] 473 | 474 | # Create node parser 475 | node_parser = SimpleNodeParser.from_defaults( 476 | chunk_size=config.rag.chunk_size, 477 | chunk_overlap=config.rag.chunk_overlap, 478 | ) 479 | 480 | # Build the index 481 | self.index = VectorStoreIndex.from_documents( 482 | documents, 483 | node_parser=node_parser, 484 | ) 485 | 486 | logger.info(f"Built index with {len(documents)} policy documents.") 487 | 488 | def query(self, query_text: str) -> Dict[str, Any]: 489 | """Query the policy engine. 490 | 491 | Args: 492 | query_text: The query string. 493 | 494 | Returns: 495 | A dictionary containing the response, relevant documents, and metadata. 496 | """ 497 | if not self.index: 498 | self.build_index() 499 | 500 | # Create query engine 501 | query_engine = self.index.as_query_engine( 502 | similarity_top_k=config.rag.similarity_top_k, 503 | ) 504 | 505 | # Execute query 506 | response = query_engine.query(query_text) 507 | 508 | # Debug logging for source nodes 509 | logger.debug("Source nodes from response:") 510 | for node in getattr(response, "source_nodes", []): 511 | logger.debug(f"Node ID: {node.node.node_id}") 512 | logger.debug(f"Source: {node.node.metadata.get('source', 'unknown')}") 513 | logger.debug(f"Text: {node.node.get_text()[:100]}...") # First 100 chars 514 | 515 | # Format result and deduplicate sources by source path only 516 | seen_sources = set() 517 | unique_sources = [] 518 | 519 | for node in getattr(response, "source_nodes", []): 520 | source_path = node.node.metadata.get("source", "unknown") 521 | source_text = node.node.get_text() 522 | 523 | logger.debug(f"Processing source: {source_path}") 524 | logger.debug(f"Already seen: {source_path in seen_sources}") 525 | 526 | if source_path not in seen_sources: 527 | seen_sources.add(source_path) 528 | unique_sources.append({ 529 | "source": source_path, 530 | "text": source_text 531 | }) 532 | 
logger.debug("Added to unique sources") 533 | else: 534 | logger.debug("Skipped duplicate source") 535 | 536 | result = { 537 | "answer": str(response), 538 | "sources": unique_sources, 539 | "metadata": { 540 | "provider": config.llm.provider, 541 | "model": self._get_model_name(), 542 | } 543 | } 544 | 545 | return result 546 | 547 | def _get_model_name(self) -> str: 548 | """Get the model name based on the current provider.""" 549 | provider = config.llm.provider 550 | if provider == "anthropic": 551 | return config.anthropic.model_name 552 | elif provider == "openai": 553 | return config.openai.model_name 554 | elif provider == "llamastack": 555 | return config.llamastack.model_name 556 | else: 557 | return "unknown" 558 | 559 | 560 | # Create a singleton RAG engine instance 561 | rag_engine = RagEngine() 562 | --------------------------------------------------------------------------------