├── src
├── __init__.py
├── environment.py
├── pqc_utils.py
├── config_utils.py
└── mcp_client.py
├── tests
├── __init__.py
├── test_pqc_utils.py
├── test_config_utils.py
└── test_mcp_client.py
├── .env
├── scripts
├── __init__.py
├── install.sh
└── mock_mcp_server.py
├── requirements.txt
├── .dockerignore
├── docker-compose.yml
├── .pre-commit-config.yaml
├── Dockerfile
├── config.yaml
├── pyproject.toml
├── qu3-32x32.svg
├── DEVELOPMENT.md
├── TROUBLESHOOTING.md
└── README.md
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.env:
--------------------------------------------------------------------------------
1 | QU3_ENV=development
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi
2 | uvicorn[standard]
3 | typer[all]
4 | requests
5 | liboqs-python@git+https://github.com/open-quantum-safe/liboqs-python@main
6 | cryptography
7 | PyYAML
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | # Git
2 | .git
3 | .gitignore
4 |
5 | # Python
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 | *.so
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # Virtual environments
28 | venv/
29 | env/
30 | ENV/
31 |
32 | # IDE
33 | .vscode/
34 | .idea/
35 | *.swp
36 | *.swo
37 | *~
38 |
39 | # OS
40 | .DS_Store
41 | Thumbs.db
42 |
43 | # Logs
44 | *.log
45 |
46 | # Keys (security)
47 | *.sec
48 | *.key
49 | *.pem
50 |
51 | # Temporary files
52 | *.tmp
53 | *.temp
54 |
55 | # Documentation build
56 | docs/_build/
57 |
58 | # Coverage reports
59 | htmlcov/
60 | .coverage
61 | .coverage.*
62 |
63 | # pytest
64 | .pytest_cache/
65 |
66 | # mypy
67 | .mypy_cache/
68 |
69 | # Docker
70 | Dockerfile
71 | docker-compose.yml
72 | .dockerignore
73 |
74 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3.8'
2 |
3 | services:
4 | qu3-client:
5 | build: .
6 | container_name: qu3-client
7 | environment:
8 | - QU3_ENV=development
9 | - PYTHONPATH=/app
10 | volumes:
11 | - ./config.yaml:/app/config.yaml:ro
12 | - qu3-keys:/home/qu3user/.qu3/keys
13 | networks:
14 | - qu3-network
15 | depends_on:
16 | - qu3-server
17 | command: ["python3", "-m", "src.main", "generate-keys"]
18 |
19 | qu3-server:
20 | build: .
21 | container_name: qu3-server
22 | environment:
23 | - QU3_ENV=development
24 | - PYTHONPATH=/app
25 | volumes:
26 | - ./config.yaml:/app/config.yaml:ro
27 | - qu3-keys:/home/qu3user/.qu3/keys
28 | ports:
29 | - "8000:8000"
30 | networks:
31 | - qu3-network
32 | command: ["python3", "-m", "scripts.mock_mcp_server"]
33 | healthcheck:
34 | test: ["CMD", "curl", "-f", "http://localhost:8000/"]
35 | interval: 30s
36 | timeout: 10s
37 | retries: 3
38 | start_period: 40s
39 |
40 | volumes:
41 | qu3-keys:
42 | driver: local
43 |
44 | networks:
45 | qu3-network:
46 | driver: bridge
47 |
48 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.4.0
4 | hooks:
5 | - id: trailing-whitespace
6 | - id: end-of-file-fixer
7 | - id: check-yaml
8 | - id: check-added-large-files
9 | - id: check-case-conflict
10 | - id: check-merge-conflict
11 | - id: check-json
12 | - id: check-toml
13 | - id: debug-statements
14 | - id: detect-private-key
15 |
16 | - repo: https://github.com/psf/black
17 | rev: 23.3.0
18 | hooks:
19 | - id: black
20 | language_version: python3
21 |
22 | - repo: https://github.com/pycqa/isort
23 | rev: 5.12.0
24 | hooks:
25 | - id: isort
26 | args: ["--profile", "black"]
27 |
28 | - repo: https://github.com/pycqa/flake8
29 | rev: 6.0.0
30 | hooks:
31 | - id: flake8
32 | args: [--max-line-length=100, --extend-ignore=E203,W503]
33 |
34 | - repo: https://github.com/pre-commit/mirrors-mypy
35 | rev: v1.3.0
36 | hooks:
37 | - id: mypy
38 | additional_dependencies: [types-PyYAML, types-requests]
39 | args: [--ignore-missing-imports]
40 |
41 | - repo: https://github.com/PyCQA/bandit
42 | rev: 1.7.5
43 | hooks:
44 | - id: bandit
45 | args: ["-c", "pyproject.toml"]
46 | additional_dependencies: ["bandit[toml]"]
47 |
48 | - repo: https://github.com/pycqa/pydocstyle
49 | rev: 6.3.0
50 | hooks:
51 | - id: pydocstyle
52 | args: [--convention=google]
53 |
54 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Dockerfile for qu3-app with liboqs support
2 | FROM ubuntu:latest
3 |
4 | # Install system dependencies
5 | RUN apt-get update && apt-get install -y \
6 | build-essential \
7 | git \
8 | cmake \
9 | libssl-dev \
10 | python3 \
11 | python3-venv \
12 | pip \
13 | curl \
14 | && rm -rf /var/lib/apt/lists/*
15 |
16 | # Get and install liboqs
17 | RUN git clone --depth 1 --branch main https://github.com/open-quantum-safe/liboqs
18 | RUN cmake -S liboqs -B liboqs/build -DBUILD_SHARED_LIBS=ON && \
19 | cmake --build liboqs/build --parallel 4 && \
20 | cmake --build liboqs/build --target install
21 |
22 | # Create non-root user for security
23 | RUN useradd -m qu3user && \
24 | mkdir -p /home/qu3user/.qu3/keys && \
25 | chown -R qu3user:qu3user /home/qu3user/.qu3 && \
26 | chmod 700 /home/qu3user/.qu3/keys
27 |
28 | # Switch to non-root user
29 | USER qu3user
30 | WORKDIR /home/qu3user
31 |
32 | # Create Python virtual environment
33 | RUN python3 -m venv venv
34 |
35 | # Set up application directory
36 | WORKDIR /app
37 | COPY requirements.txt .
38 |
39 | # Set environment variables for library paths
40 | ENV LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
41 | ENV PATH="/home/qu3user/venv/bin:$PATH"
42 |
43 | # Install Python dependencies
44 | RUN . /home/qu3user/venv/bin/activate && \
45 | pip install --no-cache-dir --upgrade pip && \
46 | pip install --no-cache-dir -r requirements.txt
47 |
48 | # Copy application code
49 | COPY . .
50 | USER root
51 | RUN chown -R qu3user:qu3user /app
52 | USER qu3user
53 |
54 | # Set environment variables
55 | ENV PYTHONPATH=/app
56 | ENV QU3_ENV=production
57 | ENV PYTHONDONTWRITEBYTECODE=1
58 | ENV PYTHONUNBUFFERED=1
59 |
60 | # Health check
61 | HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
62 | CMD python3 -c "from src.environment import validate_environment; validate_environment()" || exit 1
63 |
64 | # Default command
65 | CMD ["python3", "-m", "src.main", "--help"]
66 |
67 |
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for qu3-app client
2 |
3 | # Default URL for the MCP server
4 | server_url: "http://127.0.0.1:8000"
5 |
6 | # Directory to store client and server keys
7 | # Use "~/.qu3/keys" for user-specific storage, or a relative path like "./keys"
8 | key_directory: "~/.qu3/keys"
9 |
10 | # Cryptographic algorithms (should match server configuration)
11 | algorithms:
12 | kem: "Kyber768"
13 | sig: "SPHINCS+-SHA2-128f-simple"
14 |
15 | # Connection timeouts (in seconds)
16 | timeouts:
17 | connect: 30
18 | read: 30
19 | handshake: 15
20 |
21 | # Retry configuration
22 | retry:
23 | max_attempts: 3
24 | backoff_factor: 0.3
25 | delay: 1.0
26 |
27 | # Logging configuration
28 | logging:
29 | level: "INFO" # DEBUG, INFO, WARNING, ERROR, CRITICAL
30 | file: null # Optional log file path, e.g., "/var/log/qu3_client.log"
31 | format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
32 |
33 | # Security settings
34 | security:
35 | # Validate server certificates (set to false only for development)
36 | verify_ssl: true
37 | # Maximum allowed response size (in bytes)
38 | max_response_size: 10485760 # 10MB
39 | # Session key rotation interval (in seconds, 0 to disable)
40 | session_rotation_interval: 3600 # 1 hour
41 |
42 | # Performance settings
43 | performance:
44 | # HTTP connection pool settings
45 | pool_connections: 10
46 | pool_maxsize: 20
47 | # Enable request/response compression
48 | compression: true
49 |
50 | # Development settings (only used when QU3_ENV=development)
51 | development:
52 | # Enable debug mode
53 | debug: false
54 | # Mock server settings
55 | mock_server:
56 | host: "127.0.0.1"
57 | port: 8000
58 | auto_start: false
59 | # Skip certain validations for development
60 | skip_validations: []
61 |
62 | # Environment-specific overrides
63 | environments:
64 | production:
65 | logging:
66 | level: "WARNING"
67 | security:
68 | verify_ssl: true
69 | development:
70 | debug: false
71 |
72 | development:
73 | logging:
74 | level: "DEBUG"
75 | security:
76 | verify_ssl: false
77 | development:
78 | debug: true
79 |
80 | testing:
81 | logging:
82 | level: "ERROR"
83 | security:
84 | max_response_size: 1048576 # 1MB for testing
85 | development:
86 | skip_validations: ["ssl", "response_size"]
87 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "qu3-app"
7 | version = "1.0.0"
8 | description = "Quantum-Safe MCP Client Application"
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | license = {text = "MIT"}
12 | authors = [
13 | {name = "QU3 Team", email = "joseph@qu3.ai"}
14 | ]
15 | keywords = ["quantum-safe", "post-quantum", "cryptography", "mcp", "client"]
16 | classifiers = [
17 | "Development Status :: 4 - Beta",
18 | "Intended Audience :: Developers",
19 | "License :: OSI Approved :: MIT License",
20 | "Operating System :: OS Independent",
21 | "Programming Language :: Python :: 3",
22 | "Programming Language :: Python :: 3.8",
23 | "Programming Language :: Python :: 3.9",
24 | "Programming Language :: Python :: 3.10",
25 | "Programming Language :: Python :: 3.11",
26 | "Programming Language :: Python :: 3.12",
27 | "Topic :: Security :: Cryptography",
28 | "Topic :: Software Development :: Libraries :: Python Modules",
29 | ]
30 | dependencies = [
31 | "fastapi>=0.68.0",
32 | "uvicorn[standard]>=0.15.0",
33 | "typer[all]>=0.7.0",
34 | "requests>=2.20.0",
35 | "liboqs-python @ git+https://github.com/open-quantum-safe/liboqs-python@main",
36 | "cryptography>=3.0.0",
37 | "PyYAML>=5.4.0",
38 | ]
39 |
40 | [project.optional-dependencies]
41 | dev = [
42 | "pytest>=6.0.0",
43 | "pytest-cov>=2.10.0",
44 | "pytest-asyncio>=0.18.0",
45 | "black>=22.0.0",
46 | "isort>=5.10.0",
47 | "flake8>=4.0.0",
48 | "mypy>=0.950",
49 | "pre-commit>=2.15.0",
50 | ]
51 | docs = [
52 | "sphinx>=4.0.0",
53 | "sphinx-rtd-theme>=1.0.0",
54 | "myst-parser>=0.17.0",
55 | ]
56 |
57 | [project.urls]
58 | Homepage = "https://github.com/qu3ai/qu3-app"
59 | Documentation = "https://github.com/qu3ai/qu3-app#readme"
60 | Repository = "https://github.com/qu3ai/qu3-app"
61 | "Bug Tracker" = "https://github.com/qu3ai/qu3-app/issues"
62 |
63 | [project.scripts]
64 | qu3 = "src.main:app"
65 |
66 | [tool.setuptools.packages.find]
67 | where = ["."]
68 | include = ["src*"]
69 |
70 | [tool.black]
71 | line-length = 100
72 | target-version = ['py38']
73 | include = '\.pyi?$'
74 | extend-exclude = '''
75 | /(
76 | # directories
77 | \.eggs
78 | | \.git
79 | | \.hg
80 | | \.mypy_cache
81 | | \.tox
82 | | \.venv
83 | | build
84 | | dist
85 | )/
86 | '''
87 |
88 | [tool.isort]
89 | profile = "black"
90 | line_length = 100
91 | multi_line_output = 3
92 | include_trailing_comma = true
93 | force_grid_wrap = 0
94 | use_parentheses = true
95 | ensure_newline_before_comments = true
96 |
97 | [tool.mypy]
98 | python_version = "3.8"
99 | warn_return_any = true
100 | warn_unused_configs = true
101 | disallow_untyped_defs = true
102 | disallow_incomplete_defs = true
103 | check_untyped_defs = true
104 | disallow_untyped_decorators = true
105 | no_implicit_optional = true
106 | warn_redundant_casts = true
107 | warn_unused_ignores = true
108 | warn_no_return = true
109 | warn_unreachable = true
110 | strict_equality = true
111 |
112 | [[tool.mypy.overrides]]
113 | module = [
114 | "oqs.*",
115 | "uvicorn.*",
116 | ]
117 | ignore_missing_imports = true
118 |
119 | [tool.pytest.ini_options]
120 | minversion = "6.0"
121 | addopts = "-ra -q --strict-markers --strict-config"
122 | testpaths = ["tests"]
123 | python_files = ["test_*.py", "*_test.py"]
124 | python_classes = ["Test*"]
125 | python_functions = ["test_*"]
126 | markers = [
127 | "slow: marks tests as slow (deselect with '-m \"not slow\"')",
128 | "integration: marks tests as integration tests",
129 | "unit: marks tests as unit tests",
130 | ]
131 |
132 | [tool.coverage.run]
133 | source = ["src"]
134 | omit = [
135 | "*/tests/*",
136 | "*/test_*",
137 | "*/__pycache__/*",
138 | ]
139 |
140 | [tool.coverage.report]
141 | exclude_lines = [
142 | "pragma: no cover",
143 | "def __repr__",
144 | "if self.debug:",
145 | "if settings.DEBUG",
146 | "raise AssertionError",
147 | "raise NotImplementedError",
148 | "if 0:",
149 | "if __name__ == .__main__.:",
150 | "class .*\\bProtocol\\):",
151 | "@(abc\\.)?abstractmethod",
152 | ]
153 |
154 |
--------------------------------------------------------------------------------
/qu3-32x32.svg:
--------------------------------------------------------------------------------
1 |
25 |
--------------------------------------------------------------------------------
/scripts/install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # QU3-App Installation Script
4 | # This script helps with the complex installation process of qu3-app
5 |
6 | set -e
7 |
8 | echo "🚀 QU3-App Installation Script"
9 | echo "=============================="
10 |
11 | # Colors for output
12 | RED='\033[0;31m'
13 | GREEN='\033[0;32m'
14 | YELLOW='\033[1;33m'
15 | BLUE='\033[0;34m'
16 | NC='\033[0m' # No Color
17 |
18 | # Function to print colored output
# Print a success line: the message in $1, prefixed with a check mark and
# wrapped in the GREEN / NC colour escape sequences.
print_status() {
    local message="$1"
    # %b expands backslash escapes in the argument the same way `echo -e` does.
    printf '%b\n' "${GREEN}✅ ${message}${NC}"
}
22 |
# Print a warning line: the message in $1, prefixed with a warning sign and
# wrapped in the YELLOW / NC colour escape sequences.
print_warning() {
    local message="$1"
    # %b expands backslash escapes in the argument the same way `echo -e` does.
    printf '%b\n' "${YELLOW}⚠️ ${message}${NC}"
}
26 |
# Print an error line: the message in $1, prefixed with a cross mark and
# wrapped in the RED / NC colour escape sequences.
print_error() {
    local message="$1"
    # %b expands backslash escapes in the argument the same way `echo -e` does.
    printf '%b\n' "${RED}❌ ${message}${NC}"
}
30 |
# Print an informational line: the message in $1, prefixed with an info sign
# and wrapped in the BLUE / NC colour escape sequences.
print_info() {
    local message="$1"
    # %b expands backslash escapes in the argument the same way `echo -e` does.
    printf '%b\n' "${BLUE}ℹ️ ${message}${NC}"
}
34 |
35 | # Check if running in Docker
36 | if [ -f /.dockerenv ]; then
37 | print_info "Running in Docker environment"
38 | IN_DOCKER=true
39 | else
40 | IN_DOCKER=false
41 | fi
42 |
# Check Python version
print_info "Checking Python version..."

# Guard against a missing interpreter. Without this, the pipeline below would
# capture the shell's "command not found" text (the pipeline's exit status is
# that of `cut`, so `set -e` does not trigger), the integer tests would error
# out as false, and the script would wrongly report a compatible version.
if ! command -v python3 >/dev/null 2>&1; then
    print_error "Python 3.8 or higher is required, but python3 was not found"
    exit 1
fi

PYTHON_VERSION=$(python3 --version 2>&1 | cut -d' ' -f2)
PYTHON_MAJOR=$(echo "$PYTHON_VERSION" | cut -d'.' -f1)
PYTHON_MINOR=$(echo "$PYTHON_VERSION" | cut -d'.' -f2)

# Make sure both components are plain non-negative integers before they are
# handed to the numeric `[ ... -lt ... ]` tests.
for component in "$PYTHON_MAJOR" "$PYTHON_MINOR"; do
    case "$component" in
        ''|*[!0-9]*)
            print_error "Could not determine the Python version. Found: $PYTHON_VERSION"
            exit 1
            ;;
    esac
done

# A brace group keeps the compound test in the current shell (no subshell).
if [ "$PYTHON_MAJOR" -lt 3 ] || { [ "$PYTHON_MAJOR" -eq 3 ] && [ "$PYTHON_MINOR" -lt 8 ]; }; then
    print_error "Python 3.8 or higher is required. Found: $PYTHON_VERSION"
    exit 1
fi

print_status "Python version $PYTHON_VERSION is compatible"
55 |
56 | # Check for required system dependencies
57 | print_info "Checking system dependencies..."
58 |
# Report whether the name given in $1 resolves as a command (per `command -v`).
#
# Prints a status line when it does and a warning line when it does not.
# Returns 0 when the command is available, 1 otherwise.
check_command() {
    local cmd="$1"
    # Quote the expansion so the name is never word-split or glob-expanded.
    if command -v "$cmd" >/dev/null 2>&1; then
        print_status "$cmd is available"
        return 0
    fi
    print_warning "$cmd is not available"
    return 1
}
68 |
69 | MISSING_DEPS=()
70 |
71 | if ! check_command cmake; then
72 | MISSING_DEPS+=("cmake")
73 | fi
74 |
75 | if ! check_command gcc; then
76 | MISSING_DEPS+=("build-essential")
77 | fi
78 |
79 | if ! check_command git; then
80 | MISSING_DEPS+=("git")
81 | fi
82 |
83 | # Install missing dependencies if we can
84 | if [ ${#MISSING_DEPS[@]} -gt 0 ]; then
85 | print_warning "Missing dependencies: ${MISSING_DEPS[*]}"
86 |
87 | if [ "$IN_DOCKER" = true ] || [ "$EUID" -eq 0 ]; then
88 | print_info "Attempting to install missing dependencies..."
89 |
90 | # Detect package manager
91 | if command -v apt-get >/dev/null 2>&1; then
92 | apt-get update
93 | for dep in "${MISSING_DEPS[@]}"; do
94 | apt-get install -y $dep
95 | done
96 | elif command -v yum >/dev/null 2>&1; then
97 | for dep in "${MISSING_DEPS[@]}"; do
98 | yum install -y $dep
99 | done
100 | elif command -v brew >/dev/null 2>&1; then
101 | for dep in "${MISSING_DEPS[@]}"; do
102 | brew install $dep
103 | done
104 | else
105 | print_error "Cannot automatically install dependencies. Please install: ${MISSING_DEPS[*]}"
106 | exit 1
107 | fi
108 |
109 | print_status "Dependencies installed successfully"
110 | else
111 | print_error "Please install the following dependencies manually: ${MISSING_DEPS[*]}"
112 | print_info "On Ubuntu/Debian: sudo apt-get install ${MISSING_DEPS[*]}"
113 | print_info "On CentOS/RHEL: sudo yum install ${MISSING_DEPS[*]}"
114 | print_info "On macOS: brew install ${MISSING_DEPS[*]}"
115 | exit 1
116 | fi
117 | fi
118 |
119 | # Check if virtual environment exists
120 | if [ ! -d "venv" ]; then
121 | print_info "Creating Python virtual environment..."
122 | python3 -m venv venv
123 | print_status "Virtual environment created"
124 | else
125 | print_status "Virtual environment already exists"
126 | fi
127 |
128 | # Activate virtual environment
129 | print_info "Activating virtual environment..."
130 | source venv/bin/activate
131 |
132 | # Upgrade pip
133 | print_info "Upgrading pip..."
134 | pip install --upgrade pip
135 |
136 | # Install requirements
137 | print_info "Installing Python dependencies..."
138 | print_warning "This may take several minutes due to liboqs-python compilation..."
139 |
140 | # Install requirements with retry logic (up to MAX_RETRIES attempts, 5 seconds apart)
141 | MAX_RETRIES=3
142 | RETRY_COUNT=0
143 |
144 | while [ $RETRY_COUNT -lt $MAX_RETRIES ]; do
145 | if pip install -r requirements.txt; then
146 | print_status "Python dependencies installed successfully"
147 | break
148 | else
149 | RETRY_COUNT=$((RETRY_COUNT + 1))
150 | if [ $RETRY_COUNT -lt $MAX_RETRIES ]; then
151 | print_warning "Installation failed, retrying ($RETRY_COUNT/$MAX_RETRIES)..."
152 | sleep 5
153 | else
154 | print_error "Failed to install dependencies after $MAX_RETRIES attempts"
155 | print_info "You may need to install liboqs manually. See: https://github.com/open-quantum-safe/liboqs"
156 | exit 1
157 | fi
158 | fi
159 | done
160 |
161 | # Install development dependencies if requested
162 | if [ "$1" = "--dev" ]; then
163 | print_info "Installing development dependencies..."
164 | pip install pytest pytest-cov black isort flake8 mypy pre-commit
165 | print_status "Development dependencies installed"
166 | fi
167 |
168 | # Validate installation
169 | print_info "Validating installation..."
170 |
171 | if python -c "import oqs; print('liboqs-python:', oqs.__version__)" 2>/dev/null; then
172 | print_status "liboqs-python is working correctly"
173 | else
174 | print_error "liboqs-python installation validation failed"
175 | exit 1
176 | fi
177 |
178 | if python -c "import src.main" 2>/dev/null; then
179 | print_status "QU3-App modules can be imported"
180 | else
181 | print_error "QU3-App module import failed"
182 | exit 1
183 | fi
184 |
185 | # Create default configuration if it doesn't exist
186 | if [ ! -f "config.yaml" ]; then
187 | print_info "Creating default configuration..."
188 | cat > config.yaml << EOF
189 | # QU3-App Configuration
190 | key_directory: "~/.qu3/keys/"
191 | server_url: "http://127.0.0.1:8000"
192 |
193 | # Logging configuration
194 | logging:
195 | level: "INFO"
196 | file: null # Set to a file path to enable file logging
197 |
198 | # Network configuration
199 | network:
200 | timeout: 30
201 | verify_ssl: true
202 | EOF
203 | print_status "Default configuration created"
204 | fi
205 |
206 | # Run environment validation
207 | print_info "Running environment validation..."
208 | if python -m src.environment 2>/dev/null; then
209 | print_status "Environment validation passed"
210 | else
211 | print_warning "Environment validation had issues (this may be normal)"
212 | fi
213 |
214 | echo ""
215 | print_status "🎉 QU3-App installation completed successfully!"
216 | echo ""
217 | print_info "Next steps:"
218 | echo " 1. Activate the virtual environment: source venv/bin/activate"
219 | echo " 2. Generate client keys: python -m src.main generate-keys"
220 | echo " 3. Start the mock server: python -m scripts.mock_mcp_server"
221 | echo " 4. Test the installation: python -m src.main validate-config"
222 | echo ""
223 | print_info "For help: python -m src.main --help"
224 |
225 |
--------------------------------------------------------------------------------
/DEVELOPMENT.md:
--------------------------------------------------------------------------------
1 | # QU3-App Development Guide
2 |
3 | This guide provides comprehensive information for developers working on the qu3-app project.
4 |
5 | ## Table of Contents
6 |
7 | - [Development Setup](#development-setup)
8 | - [Architecture Overview](#architecture-overview)
9 | - [Testing](#testing)
10 | - [Docker Development](#docker-development)
11 | - [Security Considerations](#security-considerations)
12 | - [Troubleshooting](#troubleshooting)
13 |
14 | ## Development Setup
15 |
16 | ### Prerequisites
17 |
18 | - Python 3.8 or higher
19 | - Git
20 | - Docker (optional, for containerized development)
21 | - CMake and build tools (for liboqs compilation)
22 |
23 | ### Local Development Setup
24 |
25 | 1. **Clone the repository:**
26 | ```bash
27 | git clone https://github.com/qu3ai/qu3-app.git
28 | cd qu3-app
29 | ```
30 |
31 | 2. **Create and activate virtual environment:**
32 | ```bash
33 | python3 -m venv venv
34 | source venv/bin/activate # On Windows: venv\Scripts\activate
35 | ```
36 |
37 | 3. **Install dependencies:**
38 | ```bash
39 | pip install -r requirements.txt
40 | pip install -e ".[dev]" # Install development dependencies
41 | ```
42 |
43 | 4. **Install pre-commit hooks:**
44 | ```bash
45 | pre-commit install
46 | ```
47 |
48 | 5. **Run environment validation:**
49 | ```bash
50 | python -m src.environment
51 | ```
52 |
53 | ### Docker Development Setup
54 |
55 | For easier dependency management, especially with liboqs:
56 |
57 | 1. **Build the Docker image:**
58 | ```bash
59 | docker build -t qu3-app .
60 | ```
61 |
62 | 2. **Run with Docker Compose:**
63 | ```bash
64 | docker-compose up --build
65 | ```
66 |
67 | ## Architecture Overview
68 |
69 | ### Core Components
70 |
71 | ```
72 | src/
73 | ├── main.py # CLI interface (Typer)
74 | ├── mcp_client.py # MCP client implementation
75 | ├── pqc_utils.py # Post-quantum cryptography utilities
76 | ├── config_utils.py # Configuration management
77 | └── environment.py # Environment validation and setup
78 | ```
79 |
80 | ### Key Design Patterns
81 |
82 | - **Dependency Injection**: Configuration and keys are injected into components
83 | - **Error Handling**: Comprehensive error handling with custom exception hierarchy
84 | - **Retry Logic**: Exponential backoff for network operations
85 | - **Security by Default**: Secure defaults with explicit opt-out for development
86 |
87 | ### Data Flow
88 |
89 | 1. **Initialization**: Load configuration and validate environment
90 | 2. **Key Management**: Generate or load PQC key pairs
91 | 3. **Connection**: Establish secure session via KEM handshake
92 | 4. **Request/Response**: Sign, encrypt, send, decrypt, verify
93 | 5. **Cleanup**: Secure cleanup of sensitive data
94 |
95 |
96 | ## Testing
97 |
98 | ### Test Structure
99 |
100 | ```
101 | tests/
102 | ├── test_pqc_utils.py # PQC utilities tests
103 | ├── test_config_utils.py # Configuration tests
104 | ├── test_mcp_client.py # Client integration tests
105 | └── conftest.py # Pytest configuration
106 | ```
107 |
108 |
109 | ## Docker Development
110 |
111 | ### Building Images
112 |
113 | ```bash
114 | # Build development image
115 | docker build -t qu3-app:dev .
116 |
117 | # Build a specific stage (requires a multi-stage Dockerfile that defines a `builder` stage)
118 | docker build --target builder -t qu3-app:builder .
119 | ```
120 |
121 | ### Development Workflow
122 |
123 | ```bash
124 | # Start development environment
125 | docker-compose up -d
126 |
127 | # Access container shell
128 | docker-compose exec qu3-client bash
129 |
130 | # View logs
131 | docker-compose logs -f qu3-client
132 | ```
133 |
134 | ### Volume Mounts
135 |
136 | For development, mount source code:
137 |
138 | ```yaml
139 | volumes:
140 | - ./src:/app/src:ro
141 | - ./tests:/app/tests:ro
142 | - ./config.yaml:/app/config.yaml:ro
143 | ```
144 |
145 | ## Security Considerations
146 |
147 | ### Key Management
148 |
149 | - Keys are stored with restrictive permissions (600 for private, 644 for public)
150 | - Key directories have 700 permissions
151 | - No keys should ever be committed to version control
152 |
153 | ### Secure Coding Practices
154 |
155 | 1. **Input Validation**: Validate all inputs at boundaries
156 | 2. **Error Handling**: Don't leak sensitive information in error messages
157 | 3. **Logging**: Be careful not to log sensitive data
158 | 4. **Dependencies**: Keep dependencies updated and scan for vulnerabilities
159 |
160 |
161 | ### Monitoring
162 |
163 | - Use structured logging for observability
164 | - Monitor key metrics: response times, error rates, resource usage
165 | - Set up alerts for critical failures
166 |
167 | ## Troubleshooting
168 |
169 | ### Common Issues
170 |
171 | #### liboqs Installation Fails
172 |
173 | **Problem**: CMake or compilation errors during liboqs installation
174 |
175 | **Solutions**:
176 | 1. Use Docker for consistent environment
177 | 2. Install system dependencies: `apt-get install build-essential cmake`
178 | 3. Use pre-built wheels if available
179 |
180 | #### Key Generation Errors
181 |
182 | **Problem**: PQC key generation fails
183 |
184 | **Solutions**:
185 | 1. Check liboqs installation: `python -c "import oqs; print(oqs.get_enabled_kem_mechanisms())"`
186 | 2. Verify algorithm names match exactly
187 | 3. Check file permissions on key directory
188 |
189 | #### Connection Timeouts
190 |
191 | **Problem**: Client cannot connect to server
192 |
193 | **Solutions**:
194 | 1. Check server is running: `curl http://localhost:8000/`
195 | 2. Verify network connectivity
196 | 3. Check firewall settings
197 | 4. Increase the `timeouts` values (`connect`, `read`, `handshake`) in `config.yaml`
198 |
199 | #### SSL/TLS Issues
200 |
201 | **Problem**: Certificate verification errors
202 |
203 | **Solutions**:
204 | 1. For development only: set `verify_ssl: false` under `security:` in `config.yaml`
205 | 2. For production: ensure valid certificates
206 | 3. Check system certificate store
207 |
208 | ### Debug Mode
209 |
210 | Enable debug logging for detailed troubleshooting:
211 |
212 | ```yaml
213 | # config.yaml
214 | logging:
215 | level: "DEBUG"
216 | file: "debug.log"
217 | ```
218 |
219 | ### Environment Diagnostics
220 |
221 | Run comprehensive environment check:
222 |
223 | ```bash
224 | python -m src.environment
225 | ```
226 |
227 | This will output:
228 | - System information
229 | - Python version and packages
230 | - liboqs status and available algorithms
231 | - Configuration validation
232 | - Security recommendations
233 |
234 | ## Contributing
235 |
236 | ### Development Workflow
237 |
238 | 1. Create feature branch: `git checkout -b feature/your-feature`
239 | 2. Make changes with tests
240 | 3. Run quality checks: `pre-commit run --all-files`
241 | 4. Run tests: `pytest`
242 | 5. Commit changes: `git commit -m "feat: add your feature"`
243 | 6. Push and create pull request
244 |
245 | ### Commit Message Format
246 |
247 | Use conventional commits:
248 |
249 | - `feat:` New features
250 | - `fix:` Bug fixes
251 | - `docs:` Documentation changes
252 | - `style:` Code style changes
253 | - `refactor:` Code refactoring
254 | - `test:` Test additions or changes
255 | - `chore:` Maintenance tasks
256 |
257 | ### Code Review Guidelines
258 |
259 | - All code must be reviewed before merging
260 | - Tests must pass and coverage should not decrease
261 | - Documentation should be updated for new features
262 | - Security implications should be considered
263 |
264 |
--------------------------------------------------------------------------------
/tests/test_pqc_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | import sys
4 | import base64
5 |
6 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
7 | if project_root not in sys.path:
8 | sys.path.insert(0, project_root)
9 |
10 | from src import pqc_utils
11 |
12 | class TestPqcUtils(unittest.TestCase):
13 |
@classmethod
def setUpClass(cls):
    """Create the key pairs shared by every test in this class.

    Reads the KEM and signature algorithm names from ``pqc_utils.ALGORITHMS`` and
    stores them on the class, together with a client and a server key pair for
    each algorithm (``<owner>_kem_pk/sk`` and ``<owner>_sign_pk/sk``).
    """
    algorithms = pqc_utils.ALGORITHMS
    cls.kem_algo = algorithms["kem"]
    cls.sig_algo = algorithms["sig"]

    print(f"\nGenerating test keys (KEM: {cls.kem_algo}, SIG: {cls.sig_algo})...")
    # Generation order: client KEM, server KEM, client SIG, server SIG.
    for suffix, algo in (("kem", cls.kem_algo), ("sign", cls.sig_algo)):
        for owner in ("client", "server"):
            public_key, secret_key = pqc_utils.generate_key_pair(algo)
            setattr(cls, f"{owner}_{suffix}_pk", public_key)
            setattr(cls, f"{owner}_{suffix}_sk", secret_key)
    print("Test keys generated.")
26 |
def test_01_generate_key_pair(self):
    """Check the generated client keys and the rejection of an unknown algorithm.

    Each client KEM and signature key must be a non-empty ``bytes`` object, and
    ``generate_key_pair`` must raise ``PQCKeyGenerationError`` for a bogus name.
    """
    client_keys = (
        self.client_kem_pk,
        self.client_kem_sk,
        self.client_sign_pk,
        self.client_sign_sk,
    )

    # Three passes keep the check order: presence, then type, then length.
    for key in client_keys:
        self.assertIsNotNone(key)
    for key in client_keys:
        self.assertIsInstance(key, bytes)
    for key in client_keys:
        self.assertGreater(len(key), 0)

    # An algorithm name that does not exist must be reported as a key-generation error.
    with self.assertRaises(pqc_utils.PQCKeyGenerationError):
        pqc_utils.generate_key_pair("NonExistentAlgo")
45 |
def test_02_kem_encap_decap(self):
    """Exercise KEM round trips in both directions plus a tampered ciphertext.

    Encapsulating against a public key and decapsulating with the matching secret
    key must give the same shared secret on both sides; decapsulating a modified
    ciphertext is expected to raise.
    """
    kem = self.kem_algo

    # Direction 1: encapsulate against the server's public key.
    ciphertext1, shared_secret1_client = pqc_utils.kem_encapsulate(kem, self.server_kem_pk)
    encapsulated = (ciphertext1, shared_secret1_client)
    for value in encapsulated:
        self.assertIsNotNone(value)
    for value in encapsulated:
        self.assertIsInstance(value, bytes)
    for value in encapsulated:
        self.assertGreater(len(value), 0)

    # The server recovers the secret with its own secret key.
    shared_secret1_server = pqc_utils.kem_decapsulate(kem, ciphertext1, self.server_kem_sk)
    self.assertIsNotNone(shared_secret1_server)
    self.assertIsInstance(shared_secret1_server, bytes)

    # Both sides must now hold the same secret.
    self.assertEqual(shared_secret1_client, shared_secret1_server)

    # Direction 2: the same exchange against the client's key pair.
    ciphertext2, shared_secret2_server = pqc_utils.kem_encapsulate(kem, self.client_kem_pk)
    shared_secret2_client = pqc_utils.kem_decapsulate(kem, ciphertext2, self.client_kem_sk)
    self.assertEqual(shared_secret2_server, shared_secret2_client)

    # Decapsulating ciphertext1 with the wrong secret key (the client's) is
    # deliberately NOT asserted to raise: for some KEMs that yields a different
    # shared secret instead of an error. What matters is that a legitimate
    # exchange gives both parties the same secret.

    # Change the final byte of the ciphertext (wrapping at 256) and decapsulate it.
    mutated = bytearray(ciphertext1)
    mutated[-1] = (mutated[-1] + 1) % 256
    tampered_ciphertext = bytes(mutated)
    # NOTE(review): the exception is spelled POCKEMError here, unlike the
    # PQC-prefixed PQCKeyGenerationError used elsewhere in this file - confirm the
    # name against pqc_utils.
    # NOTE(review): KEMs with implicit rejection return a different secret for a
    # modified ciphertext rather than failing - confirm kem_decapsulate really
    # raises in this case.
    with self.assertRaises(pqc_utils.POCKEMError):
        pqc_utils.kem_decapsulate(kem, tampered_ciphertext, self.server_kem_sk)
80 |
def test_03_sign_verify(self):
    """Sign a message and check that only the genuine combination verifies.

    Verification must succeed for the original message, signature and
    signer's public key, and must return False when the key, the message or
    the signature is swapped or altered.
    """
    payload = b"This is a test message for SPHINCS+ signing."

    sig = pqc_utils.sign_message(payload, self.client_sign_sk, self.sig_algo)
    self.assertIsNotNone(sig)
    self.assertIsInstance(sig, bytes)
    self.assertGreater(len(sig), 0)

    def verifies(msg, signature, public_key):
        # Thin wrapper so each case below reads as (message, signature, key).
        return pqc_utils.verify_signature(msg, signature, public_key, self.sig_algo)

    # Genuine triple: accepted.
    self.assertTrue(verifies(payload, sig, self.client_sign_pk))

    # Someone else's public key: rejected.
    self.assertFalse(verifies(payload, sig, self.server_sign_pk))

    # Different message under the same signature: rejected.
    self.assertFalse(verifies(b"This is a different message.", sig, self.client_sign_pk))

    # Signature with its last byte bumped (wrapping at 256): rejected.
    forged = bytearray(sig)
    forged[-1] = (forged[-1] + 1) % 256
    self.assertFalse(verifies(payload, bytes(forged), self.client_sign_pk))
105 |
def test_04_derive_aes_key(self):
    """Check AES key derivation from a KEM shared secret.

    The derived key must be 32 bytes, identical for repeated derivations from
    the same secret, and different for a different secret.
    """
    _, secret = pqc_utils.kem_encapsulate(self.kem_algo, self.server_kem_pk)

    first_key = pqc_utils.derive_aes_key(secret)
    self.assertIsInstance(first_key, bytes)
    # 32 bytes, i.e. a 256-bit key.
    self.assertEqual(len(first_key), 32)

    # Deriving again from the same secret must give the same key.
    second_key = pqc_utils.derive_aes_key(secret)
    self.assertEqual(first_key, second_key)

    # A second, independent encapsulation supplies a different secret.
    _, other_secret = pqc_utils.kem_encapsulate(self.kem_algo, self.client_kem_pk)
    if secret == other_secret:
        # Identical secrets would legitimately derive identical keys, so
        # there is nothing further to compare.
        return
    other_key = pqc_utils.derive_aes_key(other_secret)
    self.assertNotEqual(first_key, other_key)
126 |
def test_05_aes_gcm_encrypt_decrypt(self):
    """Round-trip AES-GCM encryption and verify the rejection cases.

    Decryption must return the original plaintext with the right key and
    nonce, and raise ``PQCDecryptionError`` for a wrong key, a wrong nonce,
    or a ciphertext whose last byte was altered.
    """
    _, secret = pqc_utils.kem_encapsulate(self.kem_algo, self.server_kem_pk)
    session_key = pqc_utils.derive_aes_key(secret)
    message = b"This is the secret data to be encrypted."

    nonce, sealed = pqc_utils.encrypt_aes_gcm(session_key, message)
    self.assertIsInstance(nonce, bytes)
    self.assertIsInstance(sealed, bytes)
    self.assertEqual(len(nonce), 12)
    # The sealed output must be longer than the plaintext (the original
    # author's note: it carries the authentication tag).
    self.assertGreater(len(sealed), len(message))

    # Correct key and nonce recover the plaintext.
    recovered = pqc_utils.decrypt_aes_gcm(session_key, nonce, sealed)
    self.assertEqual(message, recovered)

    # Wrong key, derived from an unrelated encapsulation. The guard skips
    # the check only if the two keys happen to coincide.
    _, unrelated_secret = pqc_utils.kem_encapsulate(self.kem_algo, self.client_kem_pk)
    unrelated_key = pqc_utils.derive_aes_key(unrelated_secret)
    if session_key != unrelated_key:
        with self.assertRaises(pqc_utils.PQCDecryptionError):
            pqc_utils.decrypt_aes_gcm(unrelated_key, nonce, sealed)

    # Wrong nonce, then tampered ciphertext (last byte bumped, wrapping at 256).
    fresh_nonce = os.urandom(12)
    flipped = bytearray(sealed)
    flipped[-1] = (flipped[-1] + 1) % 256
    for bad_nonce, bad_ciphertext in ((fresh_nonce, sealed), (nonce, bytes(flipped))):
        with self.assertRaises(pqc_utils.PQCDecryptionError):
            pqc_utils.decrypt_aes_gcm(session_key, bad_nonce, bad_ciphertext)
164 |
# Allow the test module to be executed directly as a script.
if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/TROUBLESHOOTING.md:
--------------------------------------------------------------------------------
1 | # QU3-App Troubleshooting Guide
2 |
3 | This guide helps you resolve common issues when working with QU3-App.
4 |
5 | ## 🚨 Common Issues
6 |
7 | ### Installation Problems
8 |
9 | #### liboqs-python Installation Fails
10 |
11 | **Symptoms:**
12 | - CMake errors during installation
13 | - Compilation failures
14 | - "No module named 'oqs'" errors
15 |
16 | **Solutions:**
17 |
18 | 1. **Use the automated installer:**
19 | ```bash
20 | ./scripts/install.sh
21 | ```
22 |
23 | 2. **Manual dependency installation:**
24 | ```bash
25 | # Ubuntu/Debian
26 | sudo apt-get update
27 | sudo apt-get install build-essential cmake git
28 |
29 | # CentOS/RHEL
30 | sudo yum install gcc gcc-c++ cmake git
31 |
32 | # macOS
33 | brew install cmake
34 | ```
35 |
36 | 3. **Docker alternative:**
37 | ```bash
38 | docker-compose up --build
39 | ```
40 |
41 | 4. **Pre-built wheels (if available):**
42 | ```bash
43 | pip install --find-links https://github.com/open-quantum-safe/liboqs-python/releases liboqs-python
44 | ```
45 |
46 | #### Python Version Issues
47 |
48 | **Symptoms:**
49 | - "Python 3.8 or higher is required"
50 | - Syntax errors in modern Python code
51 |
52 | **Solutions:**
53 | 1. **Check Python version:**
54 | ```bash
55 | python3 --version
56 | ```
57 |
58 | 2. **Install newer Python:**
59 | ```bash
60 | # Ubuntu/Debian
61 | sudo apt-get install python3.9 python3.9-venv
62 |
63 | # macOS
64 | brew install python@3.9
65 | ```
66 |
67 | 3. **Use pyenv for version management:**
68 | ```bash
69 | pyenv install 3.9.0
70 | pyenv local 3.9.0
71 | ```
72 |
73 | ### Configuration Issues
74 |
75 | #### Server URL Not Configured
76 |
77 | **Symptoms:**
78 | - "Server URL not configured" errors
79 | - Connection failures
80 |
81 | **Solutions:**
82 | 1. **Check config.yaml:**
83 | ```yaml
84 | server_url: "http://127.0.0.1:8000"
85 | ```
86 |
87 | 2. **Use CLI override:**
88 | ```bash
89 | python -m src.main run-inference model_name '{}' --server-url http://localhost:8000
90 | ```
91 |
92 | 3. **Validate configuration:**
93 | ```bash
94 | python -m src.main validate-config
95 | ```
96 |
97 | #### Key Directory Issues
98 |
99 | **Symptoms:**
100 | - Permission denied errors
101 | - Keys not found
102 |
103 | **Solutions:**
104 | 1. **Check directory permissions:**
105 | ```bash
106 | ls -la ~/.qu3/keys/
107 | ```
108 |
109 | 2. **Fix permissions:**
110 | ```bash
111 | chmod 700 ~/.qu3/keys/
112 | chmod 600 ~/.qu3/keys/*.sec
113 | chmod 644 ~/.qu3/keys/*.pub
114 | ```
115 |
116 | 3. **Regenerate keys:**
117 | ```bash
118 | python -m src.main generate-keys --force
119 | ```
120 |
121 | ### Connection Problems
122 |
123 | #### Cannot Connect to Server
124 |
125 | **Symptoms:**
126 | - Connection timeout errors
127 | - "Server is not reachable" messages
128 |
129 | **Solutions:**
130 | 1. **Test basic connectivity:**
131 | ```bash
132 | python -m src.main test-connection
133 | ```
134 |
135 | 2. **Check if server is running:**
136 | ```bash
137 | curl http://127.0.0.1:8000/
138 | ```
139 |
140 | 3. **Start the mock server:**
141 | ```bash
142 | python -m scripts.mock_mcp_server
143 | ```
144 |
145 | 4. **Check firewall settings:**
146 | ```bash
147 | # Linux
148 | sudo ufw status
149 |
150 | # macOS
151 | sudo pfctl -sr
152 | ```
153 |
154 | #### SSL/TLS Certificate Issues
155 |
156 | **Symptoms:**
157 | - Certificate verification errors
158 | - SSL handshake failures
159 |
160 | **Solutions:**
161 | 1. **For local development only, disable SSL verification (never in production — it removes protection against man-in-the-middle attacks):**
162 | ```yaml
163 | # config.yaml
164 | network:
165 | verify_ssl: false
166 | ```
167 |
168 | 2. **Update certificates:**
169 | ```bash
170 | # Ubuntu/Debian
171 | sudo apt-get update && sudo apt-get install ca-certificates
172 |
173 | # macOS
174 | brew install ca-certificates
175 | ```
176 |
177 | ### Runtime Errors
178 |
179 | #### Key Generation Failures
180 |
181 | **Symptoms:**
182 | - "Failed to generate keys" errors
183 | - PQC algorithm not found
184 |
185 | **Solutions:**
186 | 1. **Verify liboqs installation:**
187 | ```bash
188 | python -c "import oqs; print(oqs.get_enabled_kem_mechanisms())"
189 | ```
190 |
191 | 2. **Check available algorithms:**
192 | ```bash
193 | python -c "import oqs; print('KEM:', oqs.get_enabled_kem_mechanisms()); print('SIG:', oqs.get_enabled_sig_mechanisms())"
194 | ```
195 |
196 | 3. **Reinstall liboqs-python:**
197 | ```bash
198 | pip uninstall liboqs-python
199 | pip install git+https://github.com/open-quantum-safe/liboqs-python@main
200 | ```
201 |
202 | #### Memory Issues
203 |
204 | **Symptoms:**
205 | - Out of memory errors during key generation
206 | - Slow performance
207 |
208 | **Solutions:**
209 | 1. **Increase available memory:**
210 | - Close other applications
211 | - Use a machine with more RAM
212 |
213 | 2. **Use Docker with memory limits:**
214 | ```bash
215 | docker run --memory=2g qu3-app
216 | ```
217 |
218 | 3. **Monitor memory usage:**
219 | ```bash
220 | python -m src.main benchmark --iterations 5
221 | ```
222 |
223 | ### Performance Issues
224 |
225 | #### Slow Key Generation
226 |
227 | **Symptoms:**
228 | - Key generation takes very long
229 | - Timeouts during operations
230 |
231 | **Solutions:**
232 | 1. **Benchmark performance:**
233 | ```bash
234 | python -m src.main benchmark
235 | ```
236 |
237 | 2. **Check system resources:**
238 | ```bash
239 | top
240 | htop
241 | ```
242 |
243 | 3. **Use faster algorithms (if acceptable for your use case):**
244 | - Consider different parameter sets
245 | - Check algorithm documentation
246 |
247 | #### Network Timeouts
248 |
249 | **Symptoms:**
250 | - Request timeout errors
251 | - Slow server responses
252 |
253 | **Solutions:**
254 | 1. **Increase timeout values:**
255 | ```yaml
256 | # config.yaml
257 | network:
258 | timeout: 60
259 | ```
260 |
261 | 2. **Check network latency:**
262 | ```bash
263 | ping your-server-hostname   # hostname or IP only, without http:// or a port
264 | ```
265 |
266 | 3. **Use local server for testing:**
267 | ```bash
268 | python -m scripts.mock_mcp_server
269 | ```
270 |
271 | ## 🔧 Debugging Tools
272 |
273 | ### Environment Validation
274 | ```bash
275 | python -m src.environment
276 | ```
277 |
278 | ### Configuration Check
279 | ```bash
280 | python -m src.main validate-config
281 | ```
282 |
283 | ### Key Inspection
284 | ```bash
285 | python -m src.main inspect-keys
286 | ```
287 |
288 | ### Connection Testing
289 | ```bash
290 | python -m src.main test-connection
291 | ```
292 |
293 | ### Performance Benchmarking
294 | ```bash
295 | python -m src.main benchmark --iterations 10
296 | ```
297 |
298 | ### Verbose Logging
299 | ```yaml
300 | # config.yaml
301 | logging:
302 | level: "DEBUG"
303 | file: "debug.log"
304 | ```
305 |
306 | ## 🆘 Getting Help
307 |
308 | ### Self-Help Resources
309 | 1. **Check this troubleshooting guide**
310 | 2. **Review the main README.md**
311 | 3. **Check DEVELOPMENT.md for detailed setup**
312 | 4. **Run diagnostic commands**
313 |
314 | ### Community Support
315 | 1. **GitHub Issues**: Report bugs and request features
316 | 2. **Discussions**: Ask questions and share experiences
317 | 3. **Documentation**: Contribute improvements
318 |
319 | ### Professional Support
320 | For enterprise users requiring dedicated support:
321 | - **Email**: joseph@qu3.ai
322 | - **Priority Support**: Available for enterprise customers
323 | - **Custom Integration**: Professional services available
324 |
325 | ## 🔍 Diagnostic Information
326 |
327 | When reporting issues, please include:
328 |
329 | 1. **System Information:**
330 | ```bash
331 | python -m src.environment
332 | ```
333 |
334 | 2. **Configuration:**
335 | ```bash
336 | python -m src.main validate-config
337 | ```
338 |
339 | 3. **Error Messages:**
340 | - Full error output
341 | - Stack traces
342 | - Log files (if enabled)
343 |
344 | 4. **Steps to Reproduce:**
345 | - Exact commands used
346 | - Expected vs actual behavior
347 | - Environment details
348 |
349 | ## 📝 Known Issues
350 |
351 | ### Current Limitations
352 | - **Installation Time**: liboqs-python compilation can take 10-30 minutes
353 | - **Memory Usage**: Key generation requires significant memory
354 | - **Platform Support**: Limited Windows support (use WSL)
355 |
356 | ### Workarounds
357 | - **Use Docker**: Avoids compilation issues
358 | - **Pre-built Images**: Available for common platforms
359 | - **Cloud Development**: Use cloud instances for development
360 |
361 | ### Future Improvements
362 | - **Binary Distributions**: Pre-compiled packages
363 | - **Optimized Algorithms**: Faster implementations
364 | - **Better Windows Support**: Native Windows builds
365 |
366 | ---
367 |
368 | **Still having issues?**
369 |
370 | 1. Check our [GitHub Issues](https://github.com/qu3ai/qu3-app/issues)
371 | 2. Create a new issue with diagnostic information
372 | 3. Contact support at joseph@qu3.ai
373 |
374 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # QU3 - Quantum-Safe MCP Client
2 |
3 | This project provides a client application (`qu3-app`) for secure interaction with Quantum-Safe Multi-Compute Provider (MCP) environments. It leverages post-quantum cryptography (PQC) standards for establishing secure communication channels, ensuring client authenticity, and verifying server attestations.
4 |
5 | This client is designed to work with MCP servers that support the QU3 interaction protocols. For development and testing, a compatible mock server implementation is included in `scripts/mock_mcp_server.py`.
6 |
7 | ## Architecture & Flow
8 |
9 | ### Secure Communication Flow
10 |
11 | The following diagram illustrates the end-to-end secure communication pattern implemented between the QU3 Client and the MCP Server:
12 |
13 | ```mermaid
14 | sequenceDiagram
15 | participant Client as QU3 Client (CLI)
16 | participant Server as MCP Server
17 |
18 | Note over Client,Server: Initial Setup (First Run / Keys Missing)
19 | Client->>+Server: GET /keys (Fetch Server Public Keys)
20 | Server-->>-Client: Server KEM PK, Server Sign PK (Base64)
21 | Client->>Client: Store Server Public Keys Locally
22 |
23 | Note over Client,Server: Establish Secure Session
24 | Client->>+Server: POST /kem-handshake/initiate { Client KEM PK, Client Sign PK }
25 | Server->>Server: Encapsulate Shared Secret (using Client KEM PK)
26 | Server->>Server: Derive AES-256 Session Key & Store Session Details
27 | Server-->>-Client: { KEM Ciphertext }
28 | Client->>Client: Decapsulate Shared Secret & Derive AES-256 Session Key
29 | Client->>Client: Store AES Session Key
30 |
31 | Note over Client,Server: Secured Inference Request
32 | Client->>Client: Prepare Secure Request (Sign Payload, Encrypt with AES Key)
33 | Client->>+Server: POST /inference { Client KEM PK, nonce, ciphertext }
34 | Server->>Server: Process Secure Request (Lookup Session, Validate Timestamp, Decrypt, Verify Signature)
35 | Server->>Server: Execute Model(input_data) -> output_data
36 | Server->>Server: Prepare Secure Response (Prepare Attestation, Sign, Encrypt with AES Key)
37 | Server-->>-Client: { resp_nonce, resp_ciphertext }
38 | Client->>Client: Process Secure Response (Decrypt, Verify Attestation)
39 | Client->>Client: Process Result
40 |
41 | Note over Client,Server: Secured Policy Update
42 | Client->>Client: Prepare Secure Policy (Read, Sign, Encrypt with AES Key)
43 | Client->>+Server: POST /policy-update { Client KEM PK, policy_nonce, policy_ciphertext, policy_sig }
44 | Server->>Server: Process Secure Policy (Lookup Session, Validate Timestamp, Decrypt, Verify Signature)
45 | Server->>Server: Process Policy (Mock)
46 | Server->>Server: Prepare Secure Status (Sign Status, Encrypt with AES Key)
47 | Server-->>-Client: { status_nonce, status_ciphertext, status_sig }
48 | Client->>Client: Process Secure Status (Decrypt, Verify Signature)
49 | Client->>Client: Display Status
50 | ```
51 |
52 | ### Client Component Interaction
53 |
54 | This diagram shows how the main Python modules within the client application interact:
55 |
56 | ```mermaid
57 | graph TD
58 | A("User CLI (Typer)") --> B("src_main");
59 | B -- Initiates --> C("src_mcp_client (MCPClient)");
60 | B -- Uses --> D("src_config_utils");
61 | C -- Uses --> E("src_pqc_utils");
62 | C -- Uses --> G("requests Session");
63 | D -- Uses --> H("PyYAML");
64 | D -- Manages --> I("Key Files");
65 | D -- Uses --> G;
66 | E -- Uses --> J("liboqs_python");
67 | E -- Uses --> K("cryptography");
68 |
69 | subgraph Cryptography
70 | J
71 | K
72 | end
73 |
74 | subgraph Networking
75 | G
76 | end
77 |
78 | subgraph "Configuration & Keys"
79 | H
80 | I
81 | end
82 | ```
83 |
84 | ### Key Management Overview
85 |
86 | Keys are crucial for the security protocols. Here's how they are managed:
87 |
88 | ```mermaid
89 | graph LR
90 | subgraph Client Side
91 | A["CLI: generate-keys"] --> B{"src_main.py"};
92 | B --> C["src_pqc_utils.py"]:::pqc --> D{"Generate KEM/Sign Pairs"};
93 | D --> E["src_config_utils.py"]:::config --> F["Save Client Keys (.pub, .sec)"];
94 |
95 | G["CLI: run-inference/etc."] --> B;
96 | B --> H{"Initialize Client"};
97 | H --> E --> I{"Load Client Keys?"};
98 | I -- Found --> J["Use Keys"];
99 | I -- Not Found --> D;
100 |
101 | H --> E --> K{"Load Server Keys?"};
102 | K -- Found --> J;
103 | K -- Not Found --> L{"Fetch Server Keys?"};
104 | L -- Calls --> E --> M["GET /keys"]:::net;
105 | M -- Response --> E --> N["Save Server Keys (.pub)"];
106 | N --> J;
107 | L -- Fetch Fail --> O["(Error - Cannot Proceed)"];
108 | end
109 |
110 | subgraph "Server Side (Mock)"
111 | P["Server Startup"] --> Q["scripts_mock_mcp_server.py"];
112 | Q --> R["src_config_utils.py"]:::config --> S{"Load/Generate Server Keys"};
113 | S --> T["Save Server Keys (.pub, .sec)"];
114 | Q --> U["Register /keys Endpoint"];
115 | U -- Request for /keys --> V{"Return Server Public Keys"};
116 | end
117 |
118 | subgraph "Filesystem (Key Dir)"
119 | F
120 | T
121 | N
122 | end
123 |
124 | classDef pqc fill:#f9d,stroke:#333,stroke-width:2px;
125 | classDef config fill:#cfc,stroke:#333,stroke-width:2px;
126 | classDef net fill:#cdf,stroke:#333,stroke-width:2px;
127 | ```
128 |
129 | ## Core Components
130 |
131 | * **`src/main.py`**: Command-line interface (CLI) built with Typer. Handles user commands, orchestrates client operations, and displays results. Includes commands for key generation, inference, agent workflows, and policy updates.
132 | * **`src/mcp_client.py`**: The main client class (`MCPClient`) responsible for:
133 | * Managing PQC keys.
134 | * Defining request (`MCPRequest`) and response (`MCPResponse`) data structures.
135 | * Establishing secure sessions via KEM handshake (`connect`).
136 | * Sending signed and encrypted requests (`send_request`).
137 | * Processing and verifying encrypted/signed responses.
138 | * Handling disconnection (`disconnect`).
139 | * **`src/pqc_utils.py`**: Utility functions for PQC operations (Kyber KEM, SPHINCS+ signing) using `liboqs-python`, AES-GCM encryption/decryption using `cryptography`, and HKDF key derivation.
140 | * **`src/config_utils.py`**: Handles loading configuration from `config.yaml`, loading/saving keys from/to files, and fetching server public keys from the `/keys` endpoint.
141 | * **`scripts/mock_mcp_server.py`**: A FastAPI development/test server that simulates an MCP environment. Implements the server-side logic for KEM handshake, request decryption/verification, basic model execution, attestation signing, response encryption, policy updates, and key distribution.
142 | * **`config.yaml`**: Configuration file for storing settings like the default key directory (`key_directory`) and server URL (`server_url`).
143 | * **`tests/`**: Directory containing unit tests (`unittest`) for core components (`pqc_utils`, `config_utils`, `mcp_client`).
144 |
145 | ## Features
146 |
147 | * **PQC Algorithms**: Uses algorithms selected by NIST for post-quantum standardization:
148 | * KEM: `Kyber768` (liboqs identifier)
149 | * Signature: `SPHINCS+-SHA2-128f-simple`
150 | * **PQC Key Management**: Generates and loads Kyber and SPHINCS+ key pairs.
151 | * **Secure Session Establishment**: Uses Kyber KEM over a network handshake (`/kem-handshake/initiate`) to establish a shared secret.
152 | * **Key Derivation**: Derives a 32-byte AES-256 key from the KEM shared secret using HKDF-SHA256.
153 | * **Encrypted Communication**: Encrypts request/response payloads (after KEM handshake) using AES-256-GCM with the derived session key.
154 | * **Client Authentication**: Client signs requests using SPHINCS+; server verifies.
155 | * **Server Attestation**: Server signs responses (attestation data) using SPHINCS+; client verifies.
156 | * **Configuration**: Loads key directory and server URL from `config.yaml`.
157 | * **Automated Server Key Fetching**: Client automatically fetches server public keys from the `/keys` endpoint if not found locally.
158 | * **CLI Commands**:
159 | * `generate-keys`: Creates client key pairs.
160 | * `run-inference`: Sends a single, secured inference request.
161 | * `run-agent`: Executes a sequential workflow (`modelA->modelB`), passing outputs as inputs (wraps non-dict output), with step-by-step reporting and robust failure handling.
162 | * `update-policy`: Sends an encrypted and signed policy file to the server.
163 | * **Mock Server**: Includes endpoints (`/`, `/keys`, `/kem-handshake/initiate`, `/inference`, `/policy-update`) implementing the corresponding server-side PQC and communication logic for testing. Provides example models `model_caps` and `model_reverse`.
164 | * **Unit Tests**: Includes unit tests covering core cryptographic utilities, configuration management, and client communication logic (with network mocking).
165 |
166 | ## Setup
167 |
168 | 1. **Clone the repository:**
169 | ```bash
170 | git clone https://github.com/qu3ai/qu3-app.git
171 | cd qu3-app
172 | ```
173 | 2. **Create a virtual environment:**
174 | ```bash
175 | python3 -m venv venv
176 | source venv/bin/activate
177 | ```
178 | 3. **Install dependencies:**
179 | ```bash
180 | pip install -r requirements.txt
181 | ```
182 | *(Note: `liboqs-python` might require system dependencies like a C compiler and the `liboqs` C library. Refer to its documentation if installation fails.)*
183 | 4. **(Optional) Configure `config.yaml`:** Modify `key_directory` or `server_url` if needed. The default key directory is `~/.qu3/keys/`.
184 |
185 | ### Quick Start
186 | ```bash
187 | # Easy installation
188 | ./scripts/install.sh
189 | # Validate setup
190 | python -m src.main validate-config
191 | # Test connection
192 | python -m src.main test-connection
193 | # Run benchmarks
194 | python -m src.main benchmark
195 | ```
196 |
197 | ## Running the Mock Server
198 |
199 | In one terminal, run:
200 |
201 | ```bash
202 | # Ensure virtual environment is active
203 | source venv/bin/activate
204 |
205 | python -m scripts.mock_mcp_server
206 | ```
207 |
208 | The server will start (usually on `http://127.0.0.1:8000`) and automatically generate its own key pairs in the configured key directory if they don't exist.
209 |
210 | ## Running the Client CLI
211 |
212 | In another terminal (with the virtual environment activated):
213 |
214 | 1. **Generate Client Keys:** (Only needs to be done once unless `--force` is used)
215 | ```bash
216 | python -m src.main generate-keys
217 | ```
218 |
219 | 2. **Fetch Server Keys (Automatic):** The client will attempt to fetch keys from the server's `/keys` endpoint during initialization (`run-inference`, `run-agent`, `update-policy`) if `server_kem.pub` and `server_sign.pub` are not found in the key directory specified in `config.yaml`. Ensure the mock server is running before executing client commands that require connection.
220 |
221 | 3. **Run Single Inference:**
222 | ```bash
223 | # Example using mock server's model_caps
224 | python -m src.main run-inference model_caps '{"text": "process this data"}'
225 |
226 | # Example specifying server URL
227 | python -m src.main run-inference model_reverse '{"text": "backward"}' --server-url http://127.0.0.1:8000
228 | ```
229 |
230 | 4. **Run Agent Workflow:**
231 | ```bash
232 | # Example chaining two mock models
233 | python -m src.main run-agent "model_caps -> model_reverse" '{"text": "flow start"}'
234 | ```
235 |
236 | 5. **Update Policy:**
237 | Create a policy file (e.g., `my_policy.txt`) with some content.
238 | ```bash
239 | echo "Allow model_caps access." > my_policy.txt
240 | python -m src.main update-policy --policy-file my_policy.txt
241 | ```
242 |
243 | ## Running Tests
244 |
245 | To run the unit tests:
246 |
247 | ```bash
248 | # Ensure virtual environment is active
249 | source venv/bin/activate
250 |
251 | python -m unittest discover tests
252 | ```
253 |
254 | Or run specific test files:
255 |
256 | ```bash
257 | python -m unittest tests.test_pqc_utils
258 | python -m unittest tests.test_config_utils
259 | python -m unittest tests.test_mcp_client
260 | ```
261 |
262 | ## 🆕 Latest Updates
263 |
264 | The newest release includes significant improvements to developer experience:
265 |
266 | ### New CLI Commands
267 | - `validate-config`: Comprehensive environment and configuration validation
268 | - `inspect-keys`: Detailed key inspection and status reporting
269 | - `test-connection`: Tests server connectivity without performing any other operations
270 | - `benchmark`: Performance measurement for quantum-safe operations
271 |
272 | ### Enhanced Installation
273 | - Automated installation script (`scripts/install.sh`) handles complex dependencies
274 | - Docker development environment improvements
275 | - Better error handling and diagnostics
276 |
277 | ### Documentation & Examples
278 | - Comprehensive troubleshooting guide (`TROUBLESHOOTING.md`)
279 |
--------------------------------------------------------------------------------
/src/environment.py:
--------------------------------------------------------------------------------
1 | """
2 | Environment and configuration validation utilities for qu3-app.
3 |
4 | This module provides enhanced environment detection, configuration validation,
5 | and runtime environment setup for the quantum-safe MCP client.
6 | """
7 |
8 | import os
9 | import sys
10 | import platform
11 | from pathlib import Path
12 | from typing import Dict, Any, Optional, List, Tuple
13 | import logging
14 | import subprocess
15 | import importlib.util
16 |
17 | log = logging.getLogger(__name__)
18 |
class EnvironmentError(Exception):
    """Signals that validation of the runtime environment failed.

    NOTE: inside this module the name shadows Python's builtin
    ``EnvironmentError`` (an alias of ``OSError``); this class derives
    directly from ``Exception`` and is unrelated to OS-level errors.
    """
22 |
class DependencyError(EnvironmentError):
    """Signals that a required dependency is absent or incompatible."""
26 |
def get_system_info() -> Dict[str, Any]:
    """Collect a snapshot of the host and interpreter for diagnostics.

    Returns:
        Dictionary with platform details, Python interpreter details, the
        working directory, user and home directory, the OS path separator,
        and the subset of environment variables whose names start with
        ``QU3_``, ``LIBOQS_`` or ``OQS_``.
    """
    # Each of these keys is the name of a zero-argument function in `platform`.
    platform_facts = (
        'platform',
        'system',
        'release',
        'version',
        'machine',
        'processor',
        'python_version',
        'python_implementation',
    )
    info: Dict[str, Any] = {
        fact: getattr(platform, fact)() for fact in platform_facts
    }

    # Only project-related variables are reported, not the whole environment.
    tracked_prefixes = ('QU3_', 'LIBOQS_', 'OQS_')
    project_env = {}
    for name, value in os.environ.items():
        if name.startswith(tracked_prefixes):
            project_env[name] = value

    info['python_executable'] = sys.executable
    info['working_directory'] = str(Path.cwd())
    # USER is tried first, then USERNAME, then a fixed placeholder.
    info['user'] = os.getenv('USER', os.getenv('USERNAME', 'unknown'))
    info['home'] = str(Path.home())
    info['path_separator'] = os.pathsep
    info['environment_variables'] = project_env
    return info
52 |
def check_python_version(min_version: Tuple[int, int] = (3, 8)) -> bool:
    """Verify that the running interpreter is at least ``min_version``.

    Args:
        min_version: Lowest acceptable version as ``(major, minor)``.

    Returns:
        True when the running version is equal to or newer than ``min_version``.

    Raises:
        DependencyError: If the running interpreter is older than ``min_version``.
    """
    # Only (major, minor) is compared; the patch level is ignored.
    running = sys.version_info[:2]
    if running >= min_version:
        log.debug(f"Python version check passed: {running} >= {min_version}")
        return True

    required_label = f"{min_version[0]}.{min_version[1]}"
    running_label = f"{running[0]}.{running[1]}"
    raise DependencyError(
        f"Python {required_label}+ required, but running {running_label}"
    )
74 |
def check_required_packages() -> Dict[str, Any]:
    """Report which of the project's Python packages can be imported.

    For every package in the table below the function records whether it can
    be located, its ``__version__`` (or ``'unknown'``), and whether it is
    critical. The ``min_version`` values in the table are informational
    only: this function does not compare them against the installed version.

    Returns:
        Mapping of package name to a status dictionary with the keys
        ``available``, ``version``, ``critical`` and, when the lookup itself
        failed, ``error``.

    Raises:
        DependencyError: If at least one critical package is unavailable.
    """
    required_packages = {
        'oqs': {'critical': True, 'min_version': None},
        'cryptography': {'critical': True, 'min_version': '3.0'},
        'requests': {'critical': True, 'min_version': '2.20'},
        'yaml': {'critical': True, 'min_version': None, 'import_name': 'yaml'},
        'typer': {'critical': True, 'min_version': '0.7'},
        'fastapi': {'critical': False, 'min_version': '0.68'},
        'uvicorn': {'critical': False, 'min_version': '0.15'},
    }

    report = {}
    absent_critical = []

    for name, details in required_packages.items():
        module_name = details.get('import_name', name)
        is_critical = details['critical']
        # Start from "not available"; upgraded below once the module is found.
        entry = {'available': False, 'version': None, 'critical': is_critical}

        try:
            located = importlib.util.find_spec(module_name) is not None
        except Exception as e:
            # The lookup itself failed; keep the reason next to the status.
            log.warning(f"Error checking package {name}: {e}")
            entry['error'] = str(e)
            located = False

        if located:
            # Importing can fail even when the spec exists; in that case the
            # package is still reported as available with an unknown version.
            try:
                module = importlib.import_module(module_name)
                found_version = getattr(module, '__version__', 'unknown')
            except Exception:
                found_version = 'unknown'
            entry['available'] = True
            entry['version'] = found_version
        elif is_critical:
            absent_critical.append(name)

        report[name] = entry

    if absent_critical:
        raise DependencyError(
            f"Critical packages missing: {', '.join(absent_critical)}. "
            f"Run 'pip install -r requirements.txt' to install dependencies."
        )

    return report
141 |
def check_liboqs_installation() -> Dict[str, Any]:
    """Probe the ``oqs`` module and report which algorithms it has enabled.

    Returns:
        On success, a dictionary with ``available`` set to True, the module
        version, the enabled KEM and signature mechanism lists, and two flags
        saying whether the algorithms this project requires are among them.
        On failure, a dictionary with ``available`` set to False and the
        error text (plus an install suggestion when the import itself failed).
    """
    needed_kem = "Kyber768"
    needed_sig = "SPHINCS+-SHA2-128f-simple"

    try:
        # Imported lazily so a missing library becomes a report, not a crash.
        import oqs

        enabled_kems = oqs.get_enabled_kem_mechanisms()
        enabled_sigs = oqs.get_enabled_sig_mechanisms()
        kem_ok = needed_kem in enabled_kems
        sig_ok = needed_sig in enabled_sigs

        report = {
            'available': True,
            'version': getattr(oqs, '__version__', 'unknown'),
            'kem_algorithms': enabled_kems,
            'sig_algorithms': enabled_sigs,
            'required_kem_available': kem_ok,
            'required_sig_available': sig_ok,
        }

        # A missing required algorithm is only warned about here; the caller
        # decides what to do with the flags.
        if not kem_ok:
            log.warning(f"Required KEM algorithm {needed_kem} not available")
        if not sig_ok:
            log.warning(f"Required signature algorithm {needed_sig} not available")

        return report

    except ImportError as e:
        log.error(f"liboqs not available: {e}")
        return {
            'available': False,
            'error': str(e),
            'suggestion': 'Install liboqs-python: pip install liboqs-python'
        }
    except Exception as e:
        log.error(f"Error checking liboqs: {e}")
        return {
            'available': False,
            'error': str(e)
        }
188 |
def validate_file_permissions(file_path: Path, expected_mode: int) -> bool:
    """Check that a file exists and has exactly the expected permission bits.

    Args:
        file_path: Path to the file to inspect
        expected_mode: Required permission bits (e.g., 0o600)

    Returns:
        True only when the file exists and its permission bits equal
        expected_mode; False when it is missing, differs, or cannot be
        inspected
    """
    try:
        if not file_path.exists():
            return False

        # Only the lower nine rwx bits take part in the comparison
        permission_bits = file_path.stat().st_mode & 0o777
        if permission_bits == expected_mode:
            return True

        log.warning(
            f"File {file_path} has permissions {oct(permission_bits)}, "
            f"expected {oct(expected_mode)}"
        )
        return False
    except Exception as e:
        log.error(f"Error checking permissions for {file_path}: {e}")
        return False
214 |
def setup_secure_environment() -> None:
    """Apply process-level hardening: bytecode flag, umask and history variables."""
    # Export the flag that tells Python not to write bytecode files
    os.environ['PYTHONDONTWRITEBYTECODE'] = '1'

    # Files created from now on are accessible to the owner only
    try:
        os.umask(0o077)
        log.debug("Set secure umask (077)")
    except Exception as e:
        log.warning(f"Failed to set secure umask: {e}")

    # Remove history-file settings from the environment
    for name in ('HISTFILE', 'LESSHISTFILE'):
        if name not in os.environ:
            continue
        del os.environ[name]
        log.debug(f"Cleared sensitive environment variable: {name}")
233 |
def get_runtime_diagnostics() -> Dict[str, Any]:
    """Collect runtime diagnostics for troubleshooting.

    Each check runs independently: a check that raises is reported as
    {'error': <message>} under its own key instead of aborting the
    collection.

    Returns:
        Dictionary with 'timestamp' (UTC, ISO 8601), 'working_directory',
        'system_info', 'python_version_check', 'package_status' and
        'liboqs_status'
    """
    from datetime import datetime, timezone

    diagnostics: Dict[str, Any] = {
        # 'timestamp' used to hold the working directory; report the real
        # collection time and keep the directory under its own key.
        'timestamp': datetime.now(timezone.utc).isoformat(),
        'working_directory': str(Path.cwd()),
        'system_info': get_system_info(),
    }

    checks = (
        ('python_version_check', check_python_version),
        ('package_status', check_required_packages),
        ('liboqs_status', check_liboqs_installation),
    )
    for key, run_check in checks:
        try:
            diagnostics[key] = run_check()
        except Exception as e:
            diagnostics[key] = {'error': str(e)}

    return diagnostics
261 |
def validate_environment() -> bool:
    """Run every environment check, then apply the secure process settings.

    Returns:
        True when all checks pass

    Raises:
        DependencyError: Propagated unchanged from the checks
        EnvironmentError: If liboqs or a required algorithm is unavailable,
            or if any other unexpected error occurs during validation
    """
    log.info("Validating runtime environment...")

    try:
        check_python_version()

        package_status = check_required_packages()
        log.debug(f"Package status: {package_status}")

        liboqs_status = check_liboqs_installation()
        if not liboqs_status.get('available', False):
            reason = liboqs_status.get('error', 'unknown error')
            raise EnvironmentError(f"liboqs not available: {reason}")

        # A flag missing from the status dictionary counts as "not available"
        required_flags = (
            ('required_kem_available',
             "Required KEM algorithm not available in liboqs"),
            ('required_sig_available',
             "Required signature algorithm not available in liboqs"),
        )
        for flag, failure_message in required_flags:
            if not liboqs_status.get(flag, False):
                raise EnvironmentError(failure_message)

        # Hardening is applied only after every check has passed
        setup_secure_environment()

        log.info("Environment validation passed")
        return True

    except (DependencyError, EnvironmentError):
        # Already the error types callers expect
        raise
    except Exception as e:
        raise EnvironmentError(f"Environment validation failed: {e}")
304 |
def get_config_recommendations() -> List[str]:
    """Build configuration recommendations for the current environment.

    Three independent checks are made: whether QU3_ENV is 'production',
    whether the key directory grants any group/other permissions, and
    whether the configured server URL uses plain HTTP to a non-loopback
    host. The last two are best-effort: if the configuration cannot be
    read, the check is skipped and the reason is logged at debug level.

    Returns:
        List of human-readable recommendations (empty when there are none)
    """
    from urllib.parse import urlsplit

    recommendations: List[str] = []

    # Development vs production
    if os.getenv('QU3_ENV') != 'production':
        recommendations.append("Set QU3_ENV=production for production deployments")

    # Key directory must not be accessible to group/other
    try:
        from .config_utils import get_key_dir
        key_dir = get_key_dir()
        if key_dir.exists() and key_dir.stat().st_mode & 0o077:
            # Name the configured directory in the hint; it is not
            # necessarily the default location.
            recommendations.append(
                f"Key directory {key_dir} has overly permissive permissions. "
                f"Run: chmod 700 {key_dir}"
            )
    except Exception as e:
        log.debug(f"Skipped key directory permission check: {e}")

    # Plain HTTP is acceptable only towards the local machine
    try:
        from .config_utils import get_server_url
        url_parts = urlsplit(get_server_url())
        # Compare the parsed host, not a substring of the URL, so that e.g.
        # "http://localhost.example.com" is not mistaken for loopback.
        loopback_hosts = ('localhost', '127.0.0.1', '::1')
        if url_parts.scheme == 'http' and url_parts.hostname not in loopback_hosts:
            recommendations.append(
                "Consider using HTTPS for server communication in production"
            )
    except Exception as e:
        log.debug(f"Skipped server URL check: {e}")

    return recommendations
343 |
if __name__ == "__main__":
    # Diagnostic CLI: validate, dump diagnostics as JSON, list recommendations.
    # Any failure is reported on stdout and the process exits with status 1.
    import json

    try:
        validate_environment()
        print("✅ Environment validation passed")

        report = get_runtime_diagnostics()
        print("\n📊 Runtime Diagnostics:")
        print(json.dumps(report, indent=2, default=str))

        advice = get_config_recommendations()
        if advice:
            print("\n💡 Recommendations:")
        for item in advice:
            print(f" • {item}")

    except Exception as e:
        print(f"❌ Environment validation failed: {e}")
        sys.exit(1)
365 |
366 |
--------------------------------------------------------------------------------
/src/pqc_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple, Any, Optional
2 | import os
3 | import oqs
4 | import logging
5 | import time
6 | from cryptography.hazmat.primitives.ciphers.aead import AESGCM
7 | from cryptography.hazmat.primitives import hashes
8 | from cryptography.hazmat.primitives.kdf.hkdf import HKDF
9 | from cryptography.exceptions import InvalidTag
10 |
class PQCError(Exception):
    """Root of the PQC exception hierarchy; catch this to handle any PQC failure."""


class PQCKeyGenerationError(PQCError):
    """Raised when generating a PQC key pair fails."""


class PQCSignatureError(PQCError):
    """Raised for failures in PQC signing or verification."""


class PQCKEMError(PQCError):
    """Raised for failures in PQC KEM operations."""


class PQCEncryptionError(PQCError):
    """Raised when AES encryption fails."""


class PQCDecryptionError(PQCError):
    """Raised when AES decryption fails (e.g., InvalidTag)."""


class PQCRetryableError(PQCError):
    """Raised for an error that can be retried."""
38 |
# liboqs mechanism names this module requires; both are checked against the
# installed liboqs build when the module is imported.
ALGORITHMS = {
    "kem": "Kyber768",
    "sig": "SPHINCS+-SHA2-128f-simple"
}

# Nonce length generated by encrypt_aes_gcm and required by decrypt_aes_gcm
AES_NONCE_BYTES = 12
# Length of the key produced by derive_aes_key (32 bytes = AES-256)
AES_KEY_BYTES = 32

# Retry configuration
# Number of retries after the first attempt (max_retries + 1 attempts in total)
MAX_RETRIES = 3
# Sleep before the first retry
RETRY_DELAY = 0.1 # seconds

# Module-level logger shared by every function in this file
log = logging.getLogger(__name__)
52 |
def _validate_algorithm_support():
    """Fail fast when the liboqs build lacks an algorithm listed in ALGORITHMS.

    Raises:
        ImportError: If the configured KEM or signature algorithm is not
            enabled in the installed liboqs build
    """
    # Entry type -> (liboqs availability probe, label used in the error text)
    probes = {
        "kem": (oqs.is_kem_enabled, "KEM"),
        "sig": (oqs.is_sig_enabled, "Signature"),
    }
    for algo_type, algo_name in ALGORITHMS.items():
        probe = probes.get(algo_type)
        if probe is not None:
            is_enabled, label = probe
            if not is_enabled(algo_name):
                raise ImportError(f"Required {label} algorithm '{algo_name}' is not enabled in this liboqs build.")
        log.debug(f"PQC Algorithm Confirmed: {algo_type} = {algo_name}")


# Checked once at import so an unusable liboqs build is reported immediately
_validate_algorithm_support()
63 |
def _retry_on_failure(func, max_retries: int = MAX_RETRIES, delay: float = RETRY_DELAY):
    """Retry decorator for PQC operations that may fail transiently.

    Args:
        func: The operation to wrap
        max_retries: Number of retries after the first attempt
        delay: Sleep in seconds before the first retry; multiplied by 1.5
            after every failed attempt

    Returns:
        A wrapper that calls func up to max_retries + 1 times

    Raises (from the wrapper):
        PQCError: Immediately, when liboqs reports the mechanism as
            unsupported or not enabled (never retried)
        PQCRetryableError: When every attempt failed; the last failure is
            attached as the cause
    """
    def wrapper(*args, **kwargs):
        # Work on a per-call copy. Augmenting the enclosing `delay` here would
        # turn it into a local of wrapper and raise UnboundLocalError on its
        # first read, so no retry would ever take place.
        current_delay = delay
        last_exception = None
        for attempt in range(max_retries + 1):
            try:
                return func(*args, **kwargs)
            except (oqs.MechanismNotSupportedError, oqs.MechanismNotEnabledError) as e:
                # These are not retryable
                raise PQCError(f"PQC mechanism not supported: {e}") from e
            except Exception as e:
                last_exception = e
                if attempt < max_retries:
                    log.warning(f"Attempt {attempt + 1} failed for {func.__name__}: {e}. Retrying in {current_delay}s...")
                    time.sleep(current_delay)
                    current_delay *= 1.5
                else:
                    log.error(f"All {max_retries + 1} attempts failed for {func.__name__}")

        raise PQCRetryableError(
            f"Operation failed after {max_retries + 1} attempts: {last_exception}"
        ) from last_exception
    return wrapper
85 |
def generate_key_pair(algo_name: str) -> Tuple[bytes, bytes]:
    """Generates a public and private key pair for the specified PQC algorithm.

    Args:
        algo_name: The name of the PQC algorithm (must be supported by liboqs)

    Returns:
        Tuple of (public_key_bytes, secret_key_bytes)

    Raises:
        PQCKeyGenerationError: If key generation fails unexpectedly
        PQCError: If the retried generation itself reports a PQC error
        ValueError: If the name is empty or the algorithm is not supported
    """
    if not isinstance(algo_name, str) or not algo_name.strip():
        raise ValueError("Algorithm name must be a non-empty string")

    # Resolve the algorithm family before entering the retry wrapper: an
    # unknown name is a caller error, not a transient failure, and has to reach
    # the caller as the documented ValueError.
    if oqs.is_kem_enabled(algo_name):
        mechanism_cls, family = oqs.KeyEncapsulation, "KEM"
    elif oqs.is_sig_enabled(algo_name):
        mechanism_cls, family = oqs.Signature, "Signature"
    else:
        raise ValueError(f"Unsupported or unknown PQC algorithm: {algo_name}")

    @_retry_on_failure
    def _generate():
        with mechanism_cls(algo_name) as mechanism:
            public_key = mechanism.generate_keypair()
            secret_key = mechanism.export_secret_key()
            log.info(f"Generated {algo_name} {family} key pair.")
            return public_key, secret_key

    try:
        return _generate()
    except (PQCError, ValueError):
        raise
    except Exception as e:
        raise PQCKeyGenerationError(
            f"Unexpected error during {algo_name} key generation: {e}"
        ) from e
125 |
def sign_message(message: bytes, secret_key: bytes, sig_algo: str) -> bytes:
    """Produce a signature over a message with the given private key.

    Args:
        message: The message bytes to sign
        secret_key: The private key bytes for signing
        sig_algo: The signature algorithm name

    Returns:
        The signature bytes

    Raises:
        PQCSignatureError: If the algorithm is not enabled or signing fails
        ValueError: If inputs are invalid
    """
    if not isinstance(message, bytes):
        raise ValueError("Message must be bytes")
    if not (isinstance(secret_key, bytes) and secret_key):
        raise ValueError("Secret key must be non-empty bytes")
    if not (isinstance(sig_algo, str) and sig_algo.strip()):
        raise ValueError("Signature algorithm must be a non-empty string")

    if not oqs.is_sig_enabled(sig_algo):
        raise PQCSignatureError(f"Signature algorithm '{sig_algo}' is not enabled or supported.")

    @_retry_on_failure
    def _sign():
        with oqs.Signature(sig_algo, secret_key) as signer:
            produced = signer.sign(message)
            log.debug(f"Message signed using {sig_algo}.")
            return produced

    try:
        return _sign()
    except PQCError:
        # Already a PQC error; pass it through untouched
        raise
    except Exception as e:
        raise PQCSignatureError(f"Unexpected error during message signing with {sig_algo}: {e}")
164 |
def verify_signature(message: bytes, signature: bytes, public_key: bytes, sig_algo: str) -> bool:
    """Check a signature over a message with the signer's public key.

    This function does not raise for invalid input or verification errors:
    each such problem is logged as a warning and reported as False.

    Args:
        message: The original message bytes
        signature: The signature bytes to verify
        public_key: The public key bytes for verification
        sig_algo: The signature algorithm name

    Returns:
        True if signature is valid, False otherwise
    """
    # (rejected?, warning) pairs, evaluated in argument order; the first
    # rejected input decides the warning that is logged.
    input_checks = (
        (not isinstance(message, bytes),
         "Message must be bytes for signature verification"),
        (not (isinstance(signature, bytes) and signature),
         "Signature must be non-empty bytes for verification"),
        (not (isinstance(public_key, bytes) and public_key),
         "Public key must be non-empty bytes for verification"),
        (not (isinstance(sig_algo, str) and sig_algo.strip()),
         "Signature algorithm must be a non-empty string"),
    )
    for rejected, warning_text in input_checks:
        if rejected:
            log.warning(warning_text)
            return False

    if not oqs.is_sig_enabled(sig_algo):
        log.warning(f"Attempt to verify with unsupported/disabled sig algo: {sig_algo}")
        return False

    try:
        with oqs.Signature(sig_algo) as verifier:
            outcome = verifier.verify(message, signature, public_key)
            log.debug(f"Signature verification result using {sig_algo}: {outcome}")
            return outcome
    except (oqs.MechanismNotSupportedError, oqs.MechanismNotEnabledError) as e:
        log.warning(f"Signature verification failed for {sig_algo} (OQS Error): {e}")
        return False
    except Exception as e:
        log.warning(f"Unexpected error during signature verification for {sig_algo}: {e}")
        return False
205 |
def kem_encapsulate(kem_algo: str, public_key: bytes) -> Tuple[bytes, bytes]:
    """Encapsulate a fresh shared secret against the recipient's public key.

    Args:
        kem_algo: The KEM algorithm name
        public_key: The recipient's public key bytes

    Returns:
        Tuple of (ciphertext_bytes, shared_secret_bytes)

    Raises:
        PQCKEMError: If the algorithm is not enabled or the operation fails
        ValueError: If inputs are invalid
    """
    if not (isinstance(kem_algo, str) and kem_algo.strip()):
        raise ValueError("KEM algorithm must be a non-empty string")
    if not (isinstance(public_key, bytes) and public_key):
        raise ValueError("Public key must be non-empty bytes")

    if not oqs.is_kem_enabled(kem_algo):
        raise PQCKEMError(f"KEM algorithm '{kem_algo}' is not enabled or supported by current build flags.")

    @_retry_on_failure
    def _encapsulate():
        with oqs.KeyEncapsulation(kem_algo) as kem:
            encapsulated = kem.encap_secret(public_key)
            kem_ciphertext, secret = encapsulated
            log.debug(f"Performed KEM encapsulation using {kem_algo}.")
            return kem_ciphertext, secret

    try:
        return _encapsulate()
    except PQCError:
        # Already a PQC error; pass it through untouched
        raise
    except Exception as e:
        raise PQCKEMError(f"Unexpected error during KEM encapsulation with {kem_algo}: {e}")
241 |
def kem_decapsulate(kem_algo: str, ciphertext: bytes, secret_key: bytes) -> bytes:
    """Recover the shared secret from a KEM ciphertext with the private key.

    Args:
        kem_algo: The KEM algorithm name
        ciphertext: The KEM ciphertext bytes
        secret_key: The recipient's private key bytes

    Returns:
        The derived shared secret bytes

    Raises:
        PQCKEMError: If the algorithm is not enabled or the operation fails
        ValueError: If inputs are invalid
    """
    if not (isinstance(kem_algo, str) and kem_algo.strip()):
        raise ValueError("KEM algorithm must be a non-empty string")
    if not (isinstance(ciphertext, bytes) and ciphertext):
        raise ValueError("Ciphertext must be non-empty bytes")
    if not (isinstance(secret_key, bytes) and secret_key):
        raise ValueError("Secret key must be non-empty bytes")

    if not oqs.is_kem_enabled(kem_algo):
        raise PQCKEMError(f"KEM algorithm '{kem_algo}' is not enabled or supported by current build flags.")

    @_retry_on_failure
    def _decapsulate():
        with oqs.KeyEncapsulation(kem_algo, secret_key) as kem:
            recovered = kem.decap_secret(ciphertext)
            log.debug(f"Performed KEM decapsulation using {kem_algo}.")
            return recovered

    try:
        return _decapsulate()
    except PQCError:
        # Already a PQC error; pass it through untouched
        raise
    except Exception as e:
        raise PQCKEMError(f"Unexpected error during KEM decapsulation with {kem_algo}: {e}")
280 |
def derive_aes_key(kem_shared_secret: bytes) -> bytes:
    """Turn a KEM shared secret into an AES key of AES_KEY_BYTES bytes via HKDF.

    HKDF is run with SHA-256, no salt and the fixed info string
    b'qu3-aes-gcm-key'.

    Args:
        kem_shared_secret: The shared secret bytes from KEM operation

    Returns:
        The derived AES key bytes

    Raises:
        ValueError: If shared secret is not non-empty bytes
        PQCError: If key derivation fails
    """
    if not (isinstance(kem_shared_secret, bytes) and kem_shared_secret):
        raise ValueError("KEM shared secret must be non-empty bytes")

    try:
        kdf = HKDF(
            algorithm=hashes.SHA256(),
            length=AES_KEY_BYTES,
            salt=None,
            info=b'qu3-aes-gcm-key',
        )
        aes_key = kdf.derive(kem_shared_secret)
        log.debug(f"Derived AES-{AES_KEY_BYTES*8} key using HKDF.")
        return aes_key
    except Exception as e:
        raise PQCError(f"Failed to derive AES key from shared secret: {e}")
309 |
def encrypt_aes_gcm(key: bytes, plaintext: bytes) -> Tuple[bytes, bytes]:
    """Encrypt plaintext with AES-GCM under a freshly generated random nonce.

    No associated data is used.

    Args:
        key: The AES key (must be exactly AES_KEY_BYTES bytes)
        plaintext: The data to encrypt

    Returns:
        Tuple of (nonce_bytes, ciphertext_bytes)

    Raises:
        PQCEncryptionError: If encryption fails
        ValueError: If inputs are invalid
    """
    key_is_bytes = isinstance(key, bytes)
    if not key_is_bytes or len(key) != AES_KEY_BYTES:
        observed = len(key) if key_is_bytes else 'non-bytes'
        raise ValueError(f"AES key must be exactly {AES_KEY_BYTES} bytes, got {observed}. Use derive_aes_key.")
    if not isinstance(plaintext, bytes):
        raise ValueError("Plaintext must be bytes")

    try:
        cipher = AESGCM(key)
        # A new random nonce for every call
        nonce = os.urandom(AES_NONCE_BYTES)
        sealed = cipher.encrypt(nonce, plaintext, None)
        log.debug(f"AES-GCM Encryption complete. Nonce: {nonce.hex()[:16]}..., Ciphertext length: {len(sealed)}")
        return nonce, sealed
    except Exception as e:
        log.error(f"AES-GCM Encryption failed: {e}")
        raise PQCEncryptionError(f"AES-GCM encryption failed: {e}")
338 |
def decrypt_aes_gcm(key: bytes, nonce: bytes, ciphertext: bytes) -> bytes:
    """Decrypt AES-GCM ciphertext with the given key and nonce.

    No associated data is used.

    Args:
        key: The AES key (must be exactly AES_KEY_BYTES bytes)
        nonce: The nonce (IV) used during encryption (AES_NONCE_BYTES bytes)
        ciphertext: The encrypted data

    Returns:
        The original plaintext bytes

    Raises:
        PQCDecryptionError: If decryption fails, including InvalidTag
        ValueError: If inputs are invalid
    """
    key_is_bytes = isinstance(key, bytes)
    if not key_is_bytes or len(key) != AES_KEY_BYTES:
        observed_key = len(key) if key_is_bytes else 'non-bytes'
        raise ValueError(f"AES key must be exactly {AES_KEY_BYTES} bytes, got {observed_key}. Use derive_aes_key.")

    nonce_is_bytes = isinstance(nonce, bytes)
    if not nonce_is_bytes or len(nonce) != AES_NONCE_BYTES:
        observed_nonce = len(nonce) if nonce_is_bytes else 'non-bytes'
        raise ValueError(f"Invalid nonce: expected {AES_NONCE_BYTES} bytes, got {observed_nonce}")

    if not isinstance(ciphertext, bytes):
        raise ValueError("Ciphertext must be bytes")

    try:
        recovered = AESGCM(key).decrypt(nonce, ciphertext, None)
        log.debug("AES-GCM Decryption successful.")
        return recovered
    except InvalidTag as e:
        # Raised by the cipher when authentication of the ciphertext fails
        log.error(f"AES-GCM Decryption failed due to InvalidTag: {e}")
        raise PQCDecryptionError(f"AES-GCM decryption failed (InvalidTag): {e}")
    except Exception as e:
        log.error(f"AES-GCM Decryption failed with an unexpected error: {e}")
        raise PQCDecryptionError(f"AES-GCM decryption failed unexpectedly: {e}")
372 |
--------------------------------------------------------------------------------
/tests/test_config_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | import sys
4 | import yaml
5 | from pathlib import Path
6 | import tempfile
7 | import shutil
8 | import base64
9 | import json
10 | import requests
11 |
12 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
13 | if project_root not in sys.path:
14 | sys.path.insert(0, project_root)
15 |
16 | from unittest.mock import patch
17 |
18 | MOCK_CONFIG_FILENAME_BASENAME = "temp_test_config.yaml"
19 |
20 | from src import config_utils
21 | from src import pqc_utils
22 |
23 | class TestConfigUtils(unittest.TestCase):
24 |
25 | @classmethod
26 | def setUpClass(cls):
27 | """Set up a temporary directory for test file operations."""
28 | cls.temp_dir = tempfile.mkdtemp(prefix="qu3_test_config_")
29 | cls.temp_dir_path = Path(cls.temp_dir)
30 | cls.MOCK_CONFIG_ABSOLUTE_PATH = cls.temp_dir_path / MOCK_CONFIG_FILENAME_BASENAME
31 |
32 | cls.kem_algo = pqc_utils.ALGORITHMS["kem"]
33 | cls.sig_algo = pqc_utils.ALGORITHMS["sig"]
34 | cls.dummy_kem_pk, cls.dummy_kem_sk = pqc_utils.generate_key_pair(cls.kem_algo)
35 | cls.dummy_sig_pk, cls.dummy_sig_sk = pqc_utils.generate_key_pair(cls.sig_algo)
36 |
37 | @classmethod
38 | def tearDownClass(cls):
39 | """Clean up the temporary directory."""
40 | shutil.rmtree(cls.temp_dir)
41 | # MOCK_CONFIG_ABSOLUTE_PATH is removed by individual tests or setUp
42 |
43 | def setUp(self):
44 | """Ensure clean state before each test (clear cache, remove mock file)."""
45 | config_utils._config_cache = None
46 | if self.MOCK_CONFIG_ABSOLUTE_PATH.exists():
47 | os.remove(self.MOCK_CONFIG_ABSOLUTE_PATH)
48 |
49 | def tearDown(self):
50 | # Ensure the mock config file is cleaned up after each test, if not already by setUp of next
51 | if self.MOCK_CONFIG_ABSOLUTE_PATH.exists():
52 | os.remove(self.MOCK_CONFIG_ABSOLUTE_PATH)
53 | config_utils._config_cache = None # Clear cache again
54 |
55 | def test_01_load_config_defaults(self):
56 | """Test loading config when file doesn't exist (uses defaults)."""
57 | with patch('src.config_utils.CONFIG_FILE_PATH', self.MOCK_CONFIG_ABSOLUTE_PATH):
58 | print(f"\n[Test 01] Testing with MOCK_CONFIG_ABSOLUTE_PATH: {self.MOCK_CONFIG_ABSOLUTE_PATH}")
59 |
60 | if self.MOCK_CONFIG_ABSOLUTE_PATH.exists():
61 | os.remove(self.MOCK_CONFIG_ABSOLUTE_PATH)
62 | self.assertFalse(self.MOCK_CONFIG_ABSOLUTE_PATH.exists())
63 | print(f"[Test 01] Exists before load_config: {self.MOCK_CONFIG_ABSOLUTE_PATH.exists()}")
64 | config_utils._config_cache = None
65 | config = config_utils.load_config()
66 | print(f"[Test 01] Raw output of load_config: {config}")
67 | self.assertEqual(config, {})
68 |
69 | # These should now use the defaults because load_config returned empty
70 | self.assertEqual(config_utils.get_key_dir(), Path(config_utils.DEFAULT_KEY_DIR_STR).expanduser().resolve())
71 | self.assertEqual(config_utils.get_server_url(), config_utils.DEFAULT_SERVER_URL)
72 |
73 | def test_02_load_config_custom(self):
74 | """Test loading config from a custom YAML file."""
75 | custom_config = {
76 | 'key_directory': str(self.temp_dir_path / "custom_keys"),
77 | 'server_url': 'http://custom.example.com:9000',
78 | 'some_other_setting': 123
79 | }
80 | with patch('src.config_utils.CONFIG_FILE_PATH', self.MOCK_CONFIG_ABSOLUTE_PATH):
81 | print(f"\n[Test 02] Testing with MOCK_CONFIG_ABSOLUTE_PATH: {self.MOCK_CONFIG_ABSOLUTE_PATH}")
82 | with open(self.MOCK_CONFIG_ABSOLUTE_PATH, 'w') as f:
83 | yaml.dump(custom_config, f)
84 |
85 | print(f"[Test 02] Exists before load_config: {self.MOCK_CONFIG_ABSOLUTE_PATH.exists()}")
86 | config_utils._config_cache = None
87 | config = config_utils.load_config()
88 | print(f"[Test 02] Raw output of load_config: {config}")
89 | self.assertEqual(config, custom_config)
90 |
91 | # These should use the custom values
92 | self.assertEqual(config_utils.get_key_dir(), self.temp_dir_path / "custom_keys")
93 | self.assertEqual(config_utils.get_server_url(), 'http://custom.example.com:9000')
94 | self.assertEqual(config_utils.get_config_value('some_other_setting'), 123)
95 |
96 | # Test cache modification (though direct cache mod is not typical usage)
97 | if config_utils._config_cache:
98 | config_utils._config_cache['server_url'] = 'cached_value'
99 | self.assertEqual(config_utils.get_server_url(), 'cached_value')
100 |
101 | def test_03_load_config_invalid_yaml(self):
102 | """Test loading config with invalid YAML content."""
103 | with patch('src.config_utils.CONFIG_FILE_PATH', self.MOCK_CONFIG_ABSOLUTE_PATH):
104 | print(f"\n[Test 03] Testing with MOCK_CONFIG_ABSOLUTE_PATH: {self.MOCK_CONFIG_ABSOLUTE_PATH}")
105 | with open(self.MOCK_CONFIG_ABSOLUTE_PATH, 'w') as f:
106 | f.write("key: value: nested_invalid")
107 |
108 | print(f"[Test 03] Exists before load_config: {self.MOCK_CONFIG_ABSOLUTE_PATH.exists()}")
109 | config_utils._config_cache = None
110 | config = config_utils.load_config()
111 | print(f"[Test 03] Raw output of load_config: {config}")
112 | self.assertEqual(config, {})
113 |
114 | # These should use the defaults
115 | self.assertEqual(config_utils.get_key_dir(), Path(config_utils.DEFAULT_KEY_DIR_STR).expanduser().resolve())
116 | self.assertEqual(config_utils.get_server_url(), config_utils.DEFAULT_SERVER_URL)
117 |
118 | def test_04_save_load_key_pair(self):
119 | """Test saving and loading a key pair."""
120 | pub_path = self.temp_dir_path / "test_kem.pub"
121 | sec_path = self.temp_dir_path / "test_kem.sec"
122 |
123 | # Ensure files don't exist initially
124 | self.assertFalse(pub_path.exists())
125 | self.assertFalse(sec_path.exists())
126 |
127 |
128 | config_utils.save_key_pair_to_files(self.dummy_kem_pk, self.dummy_kem_sk, pub_path, sec_path)
129 | self.assertTrue(pub_path.exists())
130 | self.assertTrue(sec_path.exists())
131 |
132 | self.assertEqual(sec_path.stat().st_mode & 0o777, 0o600)
133 | self.assertEqual(pub_path.stat().st_mode & 0o777, 0o644)
134 |
135 |
136 | loaded_pk, loaded_sk = config_utils.load_key_pair_from_files(pub_path, sec_path)
137 | self.assertEqual(loaded_pk, self.dummy_kem_pk)
138 | self.assertEqual(loaded_sk, self.dummy_kem_sk)
139 |
140 | def test_05_load_public_key(self):
141 | """Test loading just a public key."""
142 | pub_path = self.temp_dir_path / "test_sig.pub"
143 | sec_path = self.temp_dir_path / "test_sig.sec"
144 | config_utils.save_key_pair_to_files(self.dummy_sig_pk, self.dummy_sig_sk, pub_path, sec_path)
145 |
146 | loaded_pk = config_utils.load_public_key_from_file(pub_path)
147 | self.assertEqual(loaded_pk, self.dummy_sig_pk)
148 |
149 | def test_06_load_key_pair_not_found(self):
150 | """Test loading non-existent key pair raises FileNotFoundError."""
151 | pub_path = self.temp_dir_path / "non_existent.pub"
152 | sec_path = self.temp_dir_path / "non_existent.sec"
153 | with self.assertRaises(FileNotFoundError):
154 | config_utils.load_key_pair_from_files(pub_path, sec_path)
155 |
156 | def test_07_load_public_key_not_found(self):
157 | """Test loading non-existent public key raises FileNotFoundError."""
158 | pub_path = self.temp_dir_path / "non_existent.pub"
159 | with self.assertRaises(FileNotFoundError):
160 | config_utils.load_public_key_from_file(pub_path)
161 |
162 |
163 |
164 | @patch('src.config_utils.requests.get')
165 | def test_08_fetch_keys_success(self, mock_get):
166 | """Test successfully fetching and saving server keys."""
167 |
168 | mock_response = unittest.mock.Mock()
169 | mock_response.status_code = 200
170 | mock_response.json.return_value = {
171 | 'server_kem_public_key_b64': base64.b64encode(self.dummy_kem_pk).decode('utf-8'),
172 | 'server_sign_public_key_b64': base64.b64encode(self.dummy_sig_pk).decode('utf-8')
173 | }
174 | mock_get.return_value = mock_response
175 |
176 | server_url = "http://mockserver.test:8000"
177 | kem_file = "server_kem.pub"
178 | sign_file = "server_sign.pub"
179 | kem_path = self.temp_dir_path / kem_file
180 | sign_path = self.temp_dir_path / sign_file
181 |
182 | self.assertFalse(kem_path.exists())
183 | self.assertFalse(sign_path.exists())
184 |
185 |
186 | result = config_utils.fetch_and_save_server_keys(server_url, self.temp_dir_path, kem_file, sign_file)
187 |
188 |
189 | self.assertTrue(result)
190 | mock_get.assert_called_once_with(f"{server_url}/keys", timeout=10)
191 | mock_response.raise_for_status.assert_called_once()
192 | self.assertTrue(kem_path.exists())
193 | self.assertTrue(sign_path.exists())
194 | self.assertEqual(kem_path.read_bytes(), self.dummy_kem_pk)
195 | self.assertEqual(sign_path.read_bytes(), self.dummy_sig_pk)
196 | self.assertEqual(kem_path.stat().st_mode & 0o777, 0o644)
197 | self.assertEqual(sign_path.stat().st_mode & 0o777, 0o644)
198 |
199 | @patch('src.config_utils.requests.get')
200 | def test_09_fetch_keys_network_error(self, mock_get):
201 | """Test fetching keys with a network error."""
202 | mock_get.side_effect = requests.exceptions.ConnectionError("Test connection error")
203 | server_url = "http://unreachable.test:8000"
204 | result = config_utils.fetch_and_save_server_keys(server_url, self.temp_dir_path)
205 | self.assertFalse(result)
206 | mock_get.assert_called_once_with(f"{server_url}/keys", timeout=10)
207 |
208 | @patch('src.config_utils.requests.get')
209 | def test_10_fetch_keys_http_error(self, mock_get):
210 | """Test fetching keys with an HTTP error response (e.g., 404)."""
211 | mock_response = unittest.mock.Mock()
212 | mock_response.status_code = 404
213 | mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Not Found")
214 | mock_get.return_value = mock_response
215 | server_url = "http://mockserver.test:8000"
216 | result = config_utils.fetch_and_save_server_keys(server_url, self.temp_dir_path)
217 | self.assertFalse(result)
218 | mock_get.assert_called_once_with(f"{server_url}/keys", timeout=10)
219 | mock_response.raise_for_status.assert_called_once()
220 |
221 | @patch('src.config_utils.requests.get')
222 | def test_11_fetch_keys_bad_json(self, mock_get):
223 | """Test fetching keys with invalid JSON in the response."""
224 | mock_response = unittest.mock.Mock()
225 | mock_response.status_code = 200
226 | mock_response.json.side_effect = json.JSONDecodeError("Bad JSON", "", 0)
227 | mock_get.return_value = mock_response
228 | server_url = "http://mockserver.test:8000"
229 | result = config_utils.fetch_and_save_server_keys(server_url, self.temp_dir_path)
230 | self.assertFalse(result)
231 | mock_get.assert_called_once_with(f"{server_url}/keys", timeout=10)
232 | mock_response.raise_for_status.assert_called_once()
233 | mock_response.json.assert_called_once()
234 |
235 | @patch('src.config_utils.requests.get')
236 | def test_12_fetch_keys_missing_data(self, mock_get):
237 | """Test fetching keys with missing key data in the JSON response."""
238 | mock_response = unittest.mock.Mock()
239 | mock_response.status_code = 200
240 | mock_response.json.return_value = {'server_kem_public_key_b64': 'abc='}
241 | mock_get.return_value = mock_response
242 | server_url = "http://mockserver.test:8000"
243 | result = config_utils.fetch_and_save_server_keys(server_url, self.temp_dir_path)
244 | self.assertFalse(result)
245 |
@patch('src.config_utils.requests.get')
@patch('builtins.open', new_callable=unittest.mock.mock_open)
def test_13_fetch_keys_save_error(self, mock_open, mock_get):
    """Test fetching keys when writing the key files fails.

    The HTTP response is valid, but every ``open()`` call raises ``IOError``;
    the function under test is expected to return False.
    """
    mock_response = unittest.mock.Mock()
    mock_response.status_code = 200
    # Provide a real string Content-Type so the header check in
    # fetch_and_save_server_keys behaves as it would with a real response
    # and execution can reach the file-writing step.
    mock_response.headers = {'content-type': 'application/json'}
    # NOTE(review): assumes self.dummy_kem_pk / self.dummy_sig_pk are at least
    # 32 bytes; shorter keys are rejected by the size check before open().
    mock_response.json.return_value = {
        'server_kem_public_key_b64': base64.b64encode(self.dummy_kem_pk).decode('utf-8'),
        'server_sign_public_key_b64': base64.b64encode(self.dummy_sig_pk).decode('utf-8'),
    }
    mock_get.return_value = mock_response

    mock_open.side_effect = IOError("Permission denied")

    server_url = "http://mockserver.test:8000"
    result = config_utils.fetch_and_save_server_keys(server_url, self.temp_dir_path)

    self.assertFalse(result)
    mock_get.assert_called_once()

    # The failure must come from the save step, i.e. open() was attempted.
    self.assertTrue(mock_open.called)
267 |
def test_14_get_logging_config(self):
    """Check get_logging_config() against a table of config.yaml contents.

    Each case writes (or omits) a mock config file, points
    ``config_utils.CONFIG_FILE_PATH`` at it, clears the config cache and
    compares the returned logging settings with the expected dictionary.
    """
    mock_path = self.MOCK_CONFIG_ABSOLUTE_PATH

    # label -> (config file content or None for "no file", expected result)
    cases = {
        "fully_configured": (
            {'logging': {'level': 'DEBUG', 'file': 'test.log'}},
            {'level': 'DEBUG', 'file': 'test.log'},
        ),
        "missing_level": (
            {'logging': {'file': 'app.log'}},
            {'level': 'INFO', 'file': 'app.log'},
        ),
        "missing_file": (
            {'logging': {'level': 'WARNING'}},
            {'level': 'WARNING', 'file': None},
        ),
        "invalid_level_str": (
            {'logging': {'level': 'INVALID', 'file': 'err.log'}},
            {'level': 'INFO', 'file': 'err.log'},
        ),
        "invalid_level_type": (
            {'logging': {'level': 123, 'file': 'err.log'}},
            {'level': 'INFO', 'file': 'err.log'},
        ),
        "no_logging_section": (
            {'other_config': 'value'},
            {'level': 'INFO', 'file': None},
        ),
        "logging_not_dict": (
            {'logging': 'not_a_dictionary'},
            {'level': 'INFO', 'file': None},
        ),
        "empty_config_file": (
            {},
            {'level': 'INFO', 'file': None},
        ),
        "no_config_file": (
            None,
            {'level': 'INFO', 'file': None},
        ),
    }

    def wipe_state():
        # Remove any leftover mock config file and drop the cached config.
        if mock_path.exists():
            os.remove(mock_path)
        config_utils._config_cache = None

    for label, (file_content, wanted) in cases.items():
        with self.subTest(name=label):
            wipe_state()

            if file_content is not None:
                with open(mock_path, 'w') as handle:
                    yaml.dump(file_content, handle)

            with patch('src.config_utils.CONFIG_FILE_PATH', mock_path):
                # Clear again so the patched path is what gets loaded.
                config_utils._config_cache = None
                observed = config_utils.get_logging_config()
                self.assertEqual(observed, wanted)

    wipe_state()
319 |
320 |
321 | if __name__ == '__main__':
322 | unittest.main()
--------------------------------------------------------------------------------
/src/config_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import stat
3 | from pathlib import Path
4 | from typing import Dict, Optional, Tuple, Any
5 | import logging
6 | import yaml
7 | import requests
8 | import json
9 | import base64
10 | from urllib.parse import urljoin
11 | import time
12 |
13 | CONFIG_FILE_PATH = Path("config.yaml")
14 | _config_cache: Optional[Dict[str, Any]] = None
15 |
16 | DEFAULT_KEY_DIR_STR = "~/.qu3/keys"
17 | DEFAULT_SERVER_URL = "http://127.0.0.1:8000"
18 |
19 | # Security and reliability constants
20 | MAX_CONFIG_FILE_SIZE = 1024 * 1024 # 1MB max config file size
21 | REQUEST_TIMEOUT = 30 # seconds
22 | MAX_RETRIES = 3
23 | RETRY_DELAY = 1.0 # seconds
24 |
25 | log = logging.getLogger(__name__)
26 |
def _validate_config_security(config_path: Path) -> bool:
    """Check that a configuration file is safe to load.

    The file is rejected when it is larger than ``MAX_CONFIG_FILE_SIZE``,
    when its mode has the world-writable bit set, or when it cannot be
    stat'ed at all.

    Args:
        config_path: Path of the configuration file to inspect

    Returns:
        True if the file passes every check, False otherwise
    """
    try:
        file_stat = config_path.stat()
    except OSError as e:
        log.error(f"Failed to check security of config file {config_path}: {e}")
        return False

    size_in_bytes = file_stat.st_size
    if size_in_bytes > MAX_CONFIG_FILE_SIZE:
        log.warning(f"Configuration file {config_path} is unusually large ({size_in_bytes} bytes)")
        return False

    writable_by_others = bool(file_stat.st_mode & stat.S_IWOTH)
    if writable_by_others:
        log.warning(f"Configuration file {config_path} is world-writable (security risk)")
        return False

    return True
46 |
def load_config() -> Dict[str, Any]:
    """Load configuration from ``CONFIG_FILE_PATH``, cache it, and return it.

    The result is stored in the module-level ``_config_cache``; later calls
    return that cached object without reading the file again. An empty dict
    is used (and cached) when the file does not exist, fails
    ``_validate_config_security``, or does not contain a YAML mapping.
    Nothing is cached when an exception is raised.

    Returns:
        Dictionary containing configuration values

    Raises:
        ValueError: If the file is not valid YAML or cannot be decoded as UTF-8
        OSError: If reading the file fails for any other reason
    """
    global _config_cache
    if _config_cache is not None:
        return _config_cache

    config: Dict[str, Any] = {}
    if not CONFIG_FILE_PATH.exists():
        log.info(f"Configuration file {CONFIG_FILE_PATH} not found. Using default values.")
    elif not _validate_config_security(CONFIG_FILE_PATH):
        # Refuse to read a file that is oversized or world-writable.
        log.warning(f"Security validation failed for {CONFIG_FILE_PATH}, using defaults")
    else:
        try:
            with open(CONFIG_FILE_PATH, 'r', encoding='utf-8') as f:
                loaded_yaml = yaml.safe_load(f)
        except yaml.YAMLError as e:
            log.error(f"Error parsing configuration file {CONFIG_FILE_PATH}: {e}")
            raise ValueError(f"Invalid YAML in configuration file: {e}") from e
        except UnicodeDecodeError as e:
            log.error(f"Encoding error reading configuration file {CONFIG_FILE_PATH}: {e}")
            raise ValueError(f"Configuration file encoding error: {e}") from e
        except Exception as e:
            # Any other failure is reported as OSError, the documented
            # contract; the original exception is kept as __cause__.
            log.error(f"Error reading configuration file {CONFIG_FILE_PATH}: {e}")
            raise OSError(f"Failed to read configuration file: {e}") from e

        if isinstance(loaded_yaml, dict):
            config = loaded_yaml
            log.info(f"Loaded configuration from {CONFIG_FILE_PATH}")
        else:
            # Empty files (None) and top-level scalars/lists fall back to {}.
            log.warning(f"Configuration file {CONFIG_FILE_PATH} does not contain a valid YAML dictionary.")

    # ``config`` is always a dict here, so no further type check is needed.
    _config_cache = config
    return _config_cache
95 |
def get_config_value(key: str, default: Any = None) -> Any:
    """Look up a single top-level key in the loaded configuration.

    Args:
        key: Configuration key to retrieve; must be a non-blank string
        default: Value returned when the key is invalid or absent

    Returns:
        The configured value, or ``default``
    """
    key_is_usable = isinstance(key, str) and bool(key.strip())
    if key_is_usable:
        return load_config().get(key, default)

    log.warning("Configuration key must be a non-empty string")
    return default
112 |
def get_key_dir() -> Path:
    """Return the configured directory path for storing PQC keys.

    Reads ``key_directory`` from the configuration (default
    ``DEFAULT_KEY_DIR_STR``), expands ``~`` and resolves the path. If the
    directory already exists and is world-writable, a warning is logged but
    the path is still returned. The directory is not created here.

    Returns:
        Path object for the key directory

    Raises:
        ValueError: If the key directory path cannot be resolved or inspected
    """
    key_dir_str = get_config_value('key_directory', DEFAULT_KEY_DIR_STR)

    if not isinstance(key_dir_str, str) or not key_dir_str.strip():
        log.warning("Invalid key_directory in config, using default")
        key_dir_str = DEFAULT_KEY_DIR_STR

    try:
        key_dir = Path(key_dir_str).expanduser().resolve()

        # Security check: warn when the existing directory is world-writable.
        if key_dir.exists() and key_dir.stat().st_mode & stat.S_IWOTH:
            log.warning(f"Key directory {key_dir} is world-writable (security risk)")

        return key_dir
    except Exception as e:
        log.error(f"Failed to resolve key directory path '{key_dir_str}': {e}")
        # Chain explicitly so the underlying failure stays attached as __cause__.
        raise ValueError(f"Invalid key directory path: {e}") from e
141 |
def ensure_key_dir_exists(key_dir_path: Optional[Path] = None) -> Path:
    """Create the key directory if needed and restrict it to the owner.

    Args:
        key_dir_path: Directory to ensure. When None, the configured key
            directory from ``get_key_dir()`` is used.

    Returns:
        The path to the key directory.

    Raises:
        PermissionError: If the directory cannot be created for lack of permissions
        OSError: If directory creation fails for another reason
    """
    target_path = get_key_dir() if key_dir_path is None else key_dir_path

    try:
        target_path.mkdir(parents=True, exist_ok=True)
    except PermissionError as e:
        log.error(f"Permission denied creating key directory {target_path}: {e}")
        raise
    except OSError as e:
        log.error(f"Failed to create key directory {target_path}: {e}")
        raise

    # Tightening the mode is best-effort: a failure is logged, not raised.
    try:
        os.chmod(target_path, 0o700)
    except OSError as e:
        log.warning(f"Failed to set secure permissions on {target_path}: {e}")
    else:
        log.debug(f"Set secure permissions (700) on key directory: {target_path}")

    log.info(f"Ensured key directory exists: {target_path}")
    return target_path
176 |
def get_server_url() -> str:
    """Return the configured default MCP server URL.

    Reads ``server_url`` from the configuration. A missing, non-string or
    blank value falls back to ``DEFAULT_SERVER_URL``. A URL without an
    http/https scheme is still returned, but a warning is logged.

    Returns:
        Server URL string with surrounding whitespace removed
    """
    url = get_config_value('server_url', DEFAULT_SERVER_URL)

    if not isinstance(url, str) or not url.strip():
        log.warning("Invalid server_url in config, using default")
        return DEFAULT_SERVER_URL

    # Strip before validating so padding around an otherwise valid URL does
    # not trigger a spurious warning about the scheme.
    url = url.strip()

    # Basic URL validation; URL schemes are case-insensitive.
    if not url.lower().startswith(('http://', 'https://')):
        log.warning(f"Server URL '{url}' does not start with http:// or https://")

    return url
194 |
def save_key_pair_to_files(public_key: bytes, secret_key: bytes, pub_path: Path, sec_path: Path):
    """Save a public/private key pair to the given file paths with secure permissions.

    The public key file ends up with mode 0o644. The secret key file is
    created with mode 0o600 from the start and is chmod'ed to 0o600 before
    any key material is written, so it is not left readable by others while
    the key is being written. Parent directories of both paths are created
    if missing.

    Args:
        public_key: Public key bytes
        secret_key: Secret key bytes
        pub_path: Path for public key file
        sec_path: Path for secret key file

    Raises:
        ValueError: If inputs are invalid
        OSError: If file operations fail
    """
    if not isinstance(public_key, bytes) or len(public_key) == 0:
        raise ValueError("Public key must be non-empty bytes")
    if not isinstance(secret_key, bytes) or len(secret_key) == 0:
        raise ValueError("Secret key must be non-empty bytes")
    if not isinstance(pub_path, Path) or not isinstance(sec_path, Path):
        raise ValueError("Paths must be Path objects")

    def _owner_only_opener(path, flags):
        # Create the file with 0o600 instead of the umask-derived default.
        return os.open(path, flags, 0o600)

    try:
        pub_path.parent.mkdir(parents=True, exist_ok=True)
        # The two files may live in different directories.
        sec_path.parent.mkdir(parents=True, exist_ok=True)

        # Save public key with readable permissions
        with open(pub_path, 'wb') as f_pub:
            f_pub.write(public_key)
        os.chmod(pub_path, 0o644)

        # Save secret key with restricted permissions
        with open(sec_path, 'wb', opener=_owner_only_opener) as f_sec:
            # The creation mode only applies to new files; tighten a
            # pre-existing file before the secret is written into it.
            os.chmod(sec_path, 0o600)
            f_sec.write(secret_key)

        log.info(f"Saved key pair: Public='{pub_path.name}', Secret='{sec_path.name}' with secure permissions.")
    except OSError as e:
        log.exception(f"Failed to save key pair ('{pub_path.name}', '{sec_path.name}'): {e}")
        raise
232 |
def load_key_pair_from_files(pub_path: Path, sec_path: Path) -> Tuple[bytes, bytes]:
    """Read a public/private key pair from disk.

    A warning is logged when the secret key file is readable by group or
    others; the keys are still loaded in that case.

    Args:
        pub_path: Path to public key file
        sec_path: Path to secret key file

    Returns:
        Tuple of (public_key_bytes, secret_key_bytes)

    Raises:
        FileNotFoundError: If key files don't exist
        ValueError: If a path is not a Path object or a key file is empty
        OSError: If file operations fail
    """
    if not (isinstance(pub_path, Path) and isinstance(sec_path, Path)):
        raise ValueError("Paths must be Path objects")

    try:
        # Permission check is advisory only.
        if sec_path.exists():
            readable_by_others = stat.S_IRGRP | stat.S_IROTH
            if sec_path.stat().st_mode & readable_by_others:
                log.warning(f"Secret key file {sec_path} has overly permissive read permissions")

        with open(pub_path, 'rb') as pub_handle:
            public_key = pub_handle.read()
        with open(sec_path, 'rb') as sec_handle:
            secret_key = sec_handle.read()

        if not public_key:
            raise ValueError(f"Public key file {pub_path} is empty")
        if not secret_key:
            raise ValueError(f"Secret key file {sec_path} is empty")

        log.debug(f"Loaded key pair: Public='{pub_path.name}', Secret='{sec_path.name}'")
        return public_key, secret_key
    except FileNotFoundError:
        log.debug(f"Key pair files not found: '{pub_path.name}', '{sec_path.name}'")
        raise
    except OSError as e:
        log.exception(f"Failed to load key pair ('{pub_path.name}', '{sec_path.name}'): {e}")
        raise
276 |
def load_public_key_from_file(pub_path: Path) -> bytes:
    """Read a public key from disk.

    Args:
        pub_path: Path to public key file

    Returns:
        Public key bytes

    Raises:
        FileNotFoundError: If key file doesn't exist
        ValueError: If the path is not a Path object or the file is empty
        OSError: If file operations fail
    """
    if not isinstance(pub_path, Path):
        raise ValueError("Path must be a Path object")

    try:
        with open(pub_path, 'rb') as handle:
            key_bytes = handle.read()

        if not key_bytes:
            raise ValueError(f"Public key file {pub_path} is empty")

        log.debug(f"Loaded public key from '{pub_path.name}'")
        return key_bytes
    except FileNotFoundError:
        log.debug(f"Public key file not found: '{pub_path.name}'")
        raise
    except OSError as e:
        log.exception(f"Failed to load public key from '{pub_path.name}': {e}")
        raise
309 |
def fetch_and_save_server_keys(
    server_url: str,
    key_dir: Path,
    kem_pub_filename: str = "server_kem.pub",
    sign_pub_filename: str = "server_sign.pub"
) -> bool:
    """Fetch server public keys from the /keys endpoint and save them, with retry logic.

    Timeouts, connection errors and other request-level failures are retried
    up to ``MAX_RETRIES`` times, sleeping ``RETRY_DELAY`` seconds in between.
    HTTP error statuses, unparseable or malformed responses, and file errors
    fail immediately without a retry.

    Args:
        server_url: The base URL of the MCP server.
        key_dir: The directory to save the keys into.
        kem_pub_filename: The filename for the server KEM public key.
        sign_pub_filename: The filename for the server signing public key.

    Returns:
        True if keys were fetched and saved successfully, False otherwise.
    """
    if not isinstance(server_url, str) or not server_url.strip():
        log.error("Server URL must be a non-empty string")
        return False
    if not isinstance(key_dir, Path):
        log.error("Key directory must be a Path object")
        return False

    keys_endpoint = urljoin(server_url.rstrip('/') + '/', "keys")
    log.info(f"Attempting to fetch server public keys from {keys_endpoint}...")

    for attempt in range(MAX_RETRIES):
        try:
            response = requests.get(
                keys_endpoint,
                timeout=REQUEST_TIMEOUT,
                headers={'User-Agent': 'qu3-client/1.0'}
            )
            response.raise_for_status()

            # Validate response content type (advisory only). Media types are
            # case-insensitive; a non-string header value is reported rather
            # than allowed to abort the fetch.
            content_type = response.headers.get('content-type', '')
            if not isinstance(content_type, str) or 'application/json' not in content_type.lower():
                log.warning(f"Unexpected content type from server: {content_type}")

            # Parse the body in its own try block: the decode error raised by
            # requests is also a RequestException, so the generic handler
            # below would otherwise treat a bad body as a retryable error.
            try:
                keys_data = response.json()
            except ValueError as e:
                log.error(f"Error decoding/parsing server keys response: {e}")
                return False  # Don't retry on parsing errors

            if not isinstance(keys_data, dict):
                log.error("Server response from /keys is not a JSON object.")
                return False

            server_kem_pk_b64 = keys_data.get('server_kem_public_key_b64')
            server_sign_pk_b64 = keys_data.get('server_sign_public_key_b64')

            if not server_kem_pk_b64 or not server_sign_pk_b64:
                log.error("Server response from /keys is missing required key fields.")
                return False

            # Validate base64 encoding; validate=True rejects characters
            # outside the base64 alphabet instead of silently dropping them.
            try:
                server_kem_pk = base64.b64decode(server_kem_pk_b64, validate=True)
                server_sign_pk = base64.b64decode(server_sign_pk_b64, validate=True)
            except (ValueError, TypeError) as e:
                log.error(f"Failed to decode base64 keys from server: {e}")
                return False

            # Validate key sizes (basic sanity check)
            if len(server_kem_pk) < 32 or len(server_sign_pk) < 32:
                log.error("Server keys appear to be too small (possible corruption)")
                return False

            kem_path = key_dir / kem_pub_filename
            sign_path = key_dir / sign_pub_filename

            # Ensure directory exists
            key_dir.mkdir(parents=True, exist_ok=True)

            # Save keys; public keys are world-readable (0o644)
            with open(kem_path, 'wb') as f:
                f.write(server_kem_pk)
            os.chmod(kem_path, 0o644)
            log.info(f"Saved server KEM public key to {kem_path} with permissions set.")

            with open(sign_path, 'wb') as f:
                f.write(server_sign_pk)
            os.chmod(sign_path, 0o644)
            log.info(f"Saved server signing public key to {sign_path} with permissions set.")

            return True

        except requests.exceptions.Timeout as e:
            log.warning(f"Timeout fetching server keys (attempt {attempt + 1}/{MAX_RETRIES}): {e}")
        except requests.exceptions.ConnectionError as e:
            log.warning(f"Connection error fetching server keys (attempt {attempt + 1}/{MAX_RETRIES}): {e}")
        except requests.exceptions.HTTPError as e:
            log.error(f"HTTP error fetching server keys: {e}")
            return False  # Don't retry on HTTP errors
        except requests.exceptions.RequestException as e:
            log.warning(f"Network error fetching server keys (attempt {attempt + 1}/{MAX_RETRIES}): {e}")
        except (json.JSONDecodeError, ValueError, TypeError) as e:
            log.error(f"Error decoding/parsing server keys response: {e}")
            return False  # Don't retry on parsing errors
        except IOError as e:
            log.error(f"Error saving server keys to {key_dir}: {e}")
            return False  # Don't retry on I/O errors
        except Exception:
            log.exception("Unexpected error fetching or saving server keys:")
            return False

        # Only reached after a retryable failure.
        if attempt < MAX_RETRIES - 1:
            log.info(f"Retrying in {RETRY_DELAY} seconds...")
            time.sleep(RETRY_DELAY)

    log.error(f"Failed to fetch server keys after {MAX_RETRIES} attempts")
    return False
417 |
def get_logging_config() -> Dict[str, Any]:
    """
    Build the logging settings (level and file) from the global config.

    The level falls back to INFO when it is missing, not a string, or not one
    of the standard level names. The file falls back to None when it is
    missing, not a non-blank string, or its parent directory cannot be
    created; otherwise ``~`` is expanded and the parent directory is created.

    Returns:
        Dictionary with 'level' and 'file' keys
    """
    section = load_config().get('logging', {})
    if not isinstance(section, dict):
        log.warning("Logging configuration is not a dictionary. Using default logging settings.")
        section = {}

    requested_level = section.get('level', "INFO")
    known_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
    if isinstance(requested_level, str) and requested_level.upper() in known_levels:
        level_name = requested_level.upper()
    else:
        log.warning(f"Invalid logging level '{requested_level}'. Defaulting to INFO.")
        level_name = "INFO"

    requested_file = section.get('file')
    resolved_file = None
    if requested_file is not None:
        if isinstance(requested_file, str) and requested_file.strip():
            # Validate log file path
            try:
                expanded = Path(requested_file).expanduser()
                expanded.parent.mkdir(parents=True, exist_ok=True)
                resolved_file = str(expanded)
            except Exception as e:
                log.warning(f"Invalid log file path '{requested_file}': {e}. Defaulting to no file.")
                resolved_file = None
        else:
            log.warning(f"Invalid logging file path '{requested_file}'. Defaulting to no file.")

    return {'level': level_name, 'file': resolved_file}
458 |
--------------------------------------------------------------------------------
/tests/test_mcp_client.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | import sys
4 | import base64
5 | import json
6 | from pathlib import Path
7 | from unittest.mock import patch, MagicMock
8 |
9 | import requests
10 |
11 |
12 | project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
13 | if project_root not in sys.path:
14 | sys.path.insert(0, project_root)
15 |
16 | from src import pqc_utils
17 | from src.mcp_client import MCPClient, MCPRequest, MCPResponse
18 |
19 |
20 | class TestMCPClient(unittest.TestCase):
21 |
@classmethod
def setUpClass(cls):
    """Create the client and server KEM/signature key pairs shared by all tests."""
    kem_name = pqc_utils.ALGORITHMS["kem"]
    sig_name = pqc_utils.ALGORITHMS["sig"]
    cls.kem_algo = kem_name
    cls.sig_algo = sig_name

    # Client side first, then server side; each call yields (public, secret).
    cls.client_kem_pk, cls.client_kem_sk = pqc_utils.generate_key_pair(kem_name)
    cls.client_sign_pk, cls.client_sign_sk = pqc_utils.generate_key_pair(sig_name)
    cls.server_kem_pk, cls.server_kem_sk = pqc_utils.generate_key_pair(kem_name)
    cls.server_sign_pk, cls.server_sign_sk = pqc_utils.generate_key_pair(sig_name)

    cls.server_url = "http://mock-mcp-server.test:8000"
32 |
def setUp(self):
    """Build a fresh, unconnected client from the class-level keys."""
    client_kwargs = {
        "server_url": self.server_url,
        "client_kem_key_pair": (self.client_kem_pk, self.client_kem_sk),
        "client_sign_key_pair": (self.client_sign_pk, self.client_sign_sk),
        "server_kem_public_key": self.server_kem_pk,
        "server_sign_public_key": self.server_sign_pk,
    }
    self.client = MCPClient(**client_kwargs)
42 |
43 |
def tearDown(self):
    """Disconnect the client if a test left it connected."""
    client = self.client
    if not client:
        return
    if client._is_connected:
        client.disconnect()
48 |
def test_01_init_success(self):
    """A freshly built client is disconnected and stores the keys it was given."""
    client = self.client
    self.assertIsNotNone(client)
    self.assertFalse(client._is_connected)
    self.assertIsNone(client._session_key)

    # (attribute on the client, key passed to the constructor)
    stored_vs_given = (
        (client.client_kem_pk_bytes, self.client_kem_pk),
        (client.client_kem_sk_bytes, self.client_kem_sk),
        (client.client_sign_pk_bytes, self.client_sign_pk),
        (client.client_sign_sk_bytes, self.client_sign_sk),
        (client.server_kem_pk_bytes, self.server_kem_pk),
        (client.server_sign_pk_bytes, self.server_sign_pk),
    )
    for stored, given in stored_vs_given:
        self.assertEqual(stored, given)
61 |
def test_02_init_missing_keys(self):
    """Constructing a client with a missing key raises ValueError."""
    complete_kwargs = {
        "server_url": self.server_url,
        "client_kem_key_pair": (self.client_kem_pk, self.client_kem_sk),
        "client_sign_key_pair": (self.client_sign_pk, self.client_sign_sk),
        "server_kem_public_key": self.server_kem_pk,
        "server_sign_public_key": self.server_sign_pk,
    }
    # Each override knocks out exactly one required key.
    broken_overrides = (
        {"client_kem_key_pair": (self.client_kem_pk, None)},
        {"server_kem_public_key": None},
    )
    for override in broken_overrides:
        with self.assertRaises(ValueError):
            MCPClient(**{**complete_kwargs, **override})
80 |
81 |
@patch('src.mcp_client.requests.Session.post')
def test_03_connect_success(self, mock_post):
    """connect() completes the KEM handshake and derives the session key."""
    # Play the server: encapsulate against the client's KEM public key.
    ciphertext, server_side_secret = pqc_utils.kem_encapsulate(self.kem_algo, self.client_kem_pk)
    handshake_reply = MagicMock()
    handshake_reply.status_code = 200
    handshake_reply.json.return_value = {
        "kemCiphertextB64": base64.b64encode(ciphertext).decode('utf-8')
    }
    mock_post.return_value = handshake_reply

    connected = self.client.connect(self.server_url)

    self.assertTrue(connected)
    self.assertTrue(self.client._is_connected)
    self.assertEqual(self.client._connected_server_url, self.server_url)
    self.assertIsNotNone(self.client._session_key)

    # The client must arrive at the same AES key as the server side.
    self.assertEqual(self.client._session_key, pqc_utils.derive_aes_key(server_side_secret))

    # Check the handshake request piece by piece rather than the exact call.
    handshake_url = f"{self.server_url}/kem-handshake/initiate"
    client_pk_b64 = base64.b64encode(self.client_kem_pk).decode('utf-8')
    mock_post.assert_called_once()
    positional, keyword = mock_post.call_args
    self.assertEqual(positional[0], handshake_url)
    self.assertIn("json", keyword)
    self.assertEqual(keyword["json"]["client_kem_pub_key_b64"], client_pk_b64)
    self.assertEqual(keyword["timeout"], 15)
    handshake_reply.raise_for_status.assert_called_once()
121 |
@patch('src.mcp_client.requests.Session.post')
def test_04_connect_server_error(self, mock_post):
    """connect() returns False and stays disconnected on an HTTP error."""
    failing_reply = MagicMock()
    failing_reply.status_code = 500
    failing_reply.raise_for_status.side_effect = requests.exceptions.HTTPError("Server Error")
    mock_post.return_value = failing_reply

    connected = self.client.connect(self.server_url)

    self.assertFalse(connected)
    client = self.client
    self.assertFalse(client._is_connected)
    self.assertIsNone(client._session_key)
    mock_post.assert_called_once()
    failing_reply.raise_for_status.assert_called_once()
137 |
@patch('src.mcp_client.requests.Session.post')
def test_05_connect_bad_response(self, mock_post):
    """connect() returns False when the reply carries no KEM ciphertext."""
    malformed_reply = MagicMock()
    malformed_reply.status_code = 200
    malformed_reply.json.return_value = {"wrong_field": "abc"}
    mock_post.return_value = malformed_reply

    connected = self.client.connect(self.server_url)

    self.assertFalse(connected)
    client = self.client
    self.assertFalse(client._is_connected)
    self.assertIsNone(client._session_key)
    mock_post.assert_called_once()
    malformed_reply.raise_for_status.assert_called_once()
    malformed_reply.json.assert_called_once()
154 |
155 |
156 |
157 |
@patch('src.mcp_client.requests.Session.post')
@patch('src.mcp_client.requests.Session.close')
def test_06_disconnect(self, mock_close, mock_post):
    """disconnect() clears session state and closes the session only once."""
    # Establish a connection first.
    ciphertext, _ = pqc_utils.kem_encapsulate(self.kem_algo, self.client_kem_pk)
    handshake_reply = MagicMock()
    handshake_reply.status_code = 200
    handshake_reply.json.return_value = {
        "kemCiphertextB64": base64.b64encode(ciphertext).decode('utf-8')
    }
    mock_post.return_value = handshake_reply
    client = self.client
    client.connect(self.server_url)
    self.assertTrue(client._is_connected)
    self.assertIsNotNone(client._session_key)

    client.disconnect()

    self.assertFalse(client._is_connected)
    self.assertIsNone(client._session_key)
    self.assertIsNone(client._connected_server_url)
    mock_close.assert_called_once()

    # Disconnecting again must not close the session a second time.
    mock_close.reset_mock()
    client.disconnect()
    mock_close.assert_not_called()
187 |
188 |
def _mock_connect(self, mock_post):
    """Connect the client against a mocked handshake reply.

    Returns the AES key derived from the shared secret produced on the
    mocked server side, for use in encrypting/decrypting test payloads.
    """
    ciphertext, server_side_secret = pqc_utils.kem_encapsulate(self.kem_algo, self.client_kem_pk)
    handshake_reply = MagicMock()
    handshake_reply.status_code = 200
    handshake_reply.json.return_value = {
        "kemCiphertextB64": base64.b64encode(ciphertext).decode('utf-8')
    }
    mock_post.return_value = handshake_reply

    connected = self.client.connect(self.server_url)
    self.assertTrue(connected, "Mock connection setup failed")

    return pqc_utils.derive_aes_key(server_side_secret)
201 |
202 |
@patch('src.mcp_client.requests.Session.post')
def test_07_send_request_success(self, mock_post):
    """send_request() round-trips an encrypted, signed request with valid attestation."""

    def to_b64(raw):
        return base64.b64encode(raw).decode('utf-8')

    def canonical_json(obj):
        # Same serialization the test uses for every signed payload.
        return json.dumps(obj, sort_keys=True, separators=(',', ':')).encode('utf-8')

    sig_name = pqc_utils.ALGORITHMS["sig"]
    aes_key = self._mock_connect(mock_post)
    mock_post.reset_mock()

    input_data = {"text": "test input"}
    request = MCPRequest(
        target_server_url=self.server_url,
        model_id="test_model",
        input_data=input_data,
    )

    # Build the server's reply: attestation signed with the server signing key.
    attestation_data = {"serverVersion": "test-1.0", "modelId": "test_model", "status": "success"}
    attestation_sig = pqc_utils.sign_message(canonical_json(attestation_data), self.server_sign_sk, sig_name)

    reply_plain = {
        "status": "success",
        "output_data": {"result": "mock success"},
        "error_message": None,
        "attestation_data": attestation_data,
        "attestation_signature_b64": to_b64(attestation_sig),
        "audit_hash": None,
    }
    reply_nonce, reply_ciphertext = pqc_utils.encrypt_aes_gcm(aes_key, json.dumps(reply_plain).encode('utf-8'))

    inference_reply = MagicMock()
    inference_reply.status_code = 200
    inference_reply.json.return_value = {
        "nonceB64": to_b64(reply_nonce),
        "encryptedPayloadB64": to_b64(reply_ciphertext),
    }
    mock_post.return_value = inference_reply

    response = self.client.send_request(request)

    # The decrypted response is surfaced unchanged.
    self.assertIsNotNone(response)
    self.assertEqual(response.status, "success")
    self.assertEqual(response.output_data, {"result": "mock success"})
    self.assertIsNone(response.error_message)
    self.assertEqual(response.attestation_data, attestation_data)
    self.assertEqual(response.attestation_signature, attestation_sig)

    # Inspect the outgoing envelope.
    mock_post.assert_called_once()
    positional, keyword = mock_post.call_args
    self.assertEqual(positional[0], f"{self.server_url}/inference")
    self.assertIn('json', keyword)
    envelope = keyword['json']

    self.assertEqual(envelope['clientKemPublicKeyB64'], to_b64(self.client_kem_pk))
    self.assertIn('nonceB64', envelope)
    self.assertIn('encryptedPayloadB64', envelope)

    # Decrypt what the client sent, using the session key.
    sent_plain_bytes = pqc_utils.decrypt_aes_gcm(
        aes_key,
        base64.b64decode(envelope['nonceB64']),
        base64.b64decode(envelope['encryptedPayloadB64']),
    )
    sent_plain = json.loads(sent_plain_bytes.decode('utf-8'))

    self.assertEqual(sent_plain['target_server_url'], self.server_url)
    self.assertEqual(sent_plain['model_id'], "test_model")
    self.assertEqual(sent_plain['input_data'], input_data)
    self.assertIn('pqc_signature_b64', sent_plain)

    # The request signature must verify against the client's signing key.
    sent_signature = base64.b64decode(sent_plain['pqc_signature_b64'])
    signed_fields = {
        "target_server_url": self.server_url,
        "model_id": "test_model",
        "input_data": input_data,
    }
    self.assertTrue(
        pqc_utils.verify_signature(canonical_json(signed_fields), sent_signature, self.client_sign_pk, sig_name)
    )
280 |
@patch('src.mcp_client.requests.Session.post')
def test_08_send_request_not_connected(self, mock_post):
    """send_request() on an unconnected client yields an error response and no HTTP call."""
    unsent_request = MCPRequest(self.server_url, "test", {})

    outcome = self.client.send_request(unsent_request)

    self.assertIsNotNone(outcome)
    self.assertEqual(outcome.status, 'error')
    self.assertIn("Client not connected or session key missing", outcome.error_message)
    mock_post.assert_not_called()
291 |
@patch('src.mcp_client.requests.Session.post')
def test_09_send_request_invalid_attestation(self, mock_post):
    """Test send_request with invalid server attestation signature.

    The mocked server signs the attestation with a secret key that does not
    belong to the server signing public key the client was configured with.
    The client is expected to report an error while still exposing the
    decrypted output and attestation fields.
    """
    session_aes_key = self._mock_connect(mock_post)
    mock_post.reset_mock()

    request = MCPRequest(self.server_url, "test_model", {"text": "test input"})

    attestation_data = {"serverVersion": "test-1.0", "modelId": "test_model", "status": "success"}

    # generate_key_pair returns (public_key, secret_key); take the SECRET key
    # of an unrelated pair so the signature is well-formed but untrusted.
    _, wrong_sk = pqc_utils.generate_key_pair(self.sig_algo)
    attestation_bytes = json.dumps(attestation_data, sort_keys=True, separators=(',', ':')).encode('utf-8')
    bad_server_attestation_sig = pqc_utils.sign_message(attestation_bytes, wrong_sk, pqc_utils.ALGORITHMS["sig"])

    response_cleartext_dict = {
        "status": "success",
        "output_data": {"result": "mock success"},
        "error_message": None,
        "attestation_data": attestation_data,
        "attestation_signature_b64": base64.b64encode(bad_server_attestation_sig).decode('utf-8'),
        "audit_hash": None
    }
    response_cleartext_bytes = json.dumps(response_cleartext_dict).encode('utf-8')
    resp_nonce, resp_ciphertext = pqc_utils.encrypt_aes_gcm(session_aes_key, response_cleartext_bytes)

    mock_inference_response = MagicMock()
    mock_inference_response.status_code = 200
    mock_inference_response.json.return_value = {
        "nonceB64": base64.b64encode(resp_nonce).decode('utf-8'),
        "encryptedPayloadB64": base64.b64encode(resp_ciphertext).decode('utf-8')
    }
    mock_post.return_value = mock_inference_response

    response = self.client.send_request(request)

    self.assertIsNotNone(response)
    self.assertEqual(response.status, "error")
    self.assertIn("ATTENTION: Attestation verification FAILED", response.error_message)

    # The payload fields are still passed through alongside the error.
    self.assertEqual(response.output_data, {"result": "mock success"})
    self.assertEqual(response.attestation_data, attestation_data)
    self.assertEqual(response.attestation_signature, bad_server_attestation_sig)
337 |
@patch('src.mcp_client.requests.Session.post')
def test_10_send_request_server_error_response(self, mock_post):
    """A correctly attested error reported by the server is passed through to the caller."""
    aes_key = self._mock_connect(mock_post)
    mock_post.reset_mock()

    failing_request = MCPRequest(self.server_url, "bad_model", {"text": "test input"})

    # Attestation signed with the real server key, so verification succeeds.
    attestation = {"serverVersion": "test-1.0", "modelId": "bad_model", "status": "error"}
    canonical_attestation = json.dumps(attestation, sort_keys=True, separators=(',', ':')).encode('utf-8')
    genuine_sig = pqc_utils.sign_message(canonical_attestation, self.server_sign_sk, pqc_utils.ALGORITHMS["sig"])

    plaintext = json.dumps({
        "status": "error",
        "output_data": None,
        "error_message": "Model not found",
        "attestation_data": attestation,
        "attestation_signature_b64": base64.b64encode(genuine_sig).decode('utf-8'),
        "audit_hash": None
    }).encode('utf-8')
    nonce, ciphertext = pqc_utils.encrypt_aes_gcm(aes_key, plaintext)

    http_reply = MagicMock()
    http_reply.status_code = 200
    http_reply.json.return_value = {
        "nonceB64": base64.b64encode(nonce).decode('utf-8'),
        "encryptedPayloadB64": base64.b64encode(ciphertext).decode('utf-8')
    }
    mock_post.return_value = http_reply

    result = self.client.send_request(failing_request)

    # The client mirrors the server's error verbatim; the attestation itself is valid.
    self.assertIsNotNone(result)
    self.assertEqual(result.status, "error")
    self.assertEqual(result.error_message, "Model not found")
    self.assertIsNone(result.output_data)
    self.assertEqual(result.attestation_data, attestation)
    self.assertEqual(result.attestation_signature, genuine_sig)
381 |
@patch('src.mcp_client.requests.Session.post')
def test_11_send_request_decryption_error(self, mock_post):
    """A response whose ciphertext was altered in transit is reported as a decryption error."""
    aes_key = self._mock_connect(mock_post)
    mock_post.reset_mock()

    outgoing = MCPRequest(self.server_url, "test_model", {"text": "test input"})

    plaintext = json.dumps({"status": "success", "output_data": "abc"}).encode('utf-8')
    nonce, ciphertext = pqc_utils.encrypt_aes_gcm(aes_key, plaintext)

    # Bump the final byte (mod 256) so AES-GCM authentication can no longer succeed.
    corrupted = bytearray(ciphertext)
    corrupted[-1] = (corrupted[-1] + 1) % 256
    corrupted = bytes(corrupted)

    http_reply = MagicMock()
    http_reply.status_code = 200
    http_reply.json.return_value = {
        "nonceB64": base64.b64encode(nonce).decode('utf-8'),
        "encryptedPayloadB64": base64.b64encode(corrupted).decode('utf-8')
    }
    mock_post.return_value = http_reply

    result = self.client.send_request(outgoing)

    self.assertIsNotNone(result)
    self.assertEqual(result.status, "error")
    self.assertIn("AES-GCM decryption failed (InvalidTag)", result.error_message)
    # Nothing from the undecryptable payload may leak into the response object.
    self.assertIsNone(result.output_data)
    self.assertIsNone(result.attestation_data)
412 |
413 |
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/src/mcp_client.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Any, Dict, Optional, Tuple
3 | import base64
4 | import os
5 | import requests
6 | import json
7 | import logging
8 | import time
9 | from pathlib import Path
10 | from cryptography.exceptions import InvalidTag
11 | from urllib3.util.retry import Retry
12 | from requests.adapters import HTTPAdapter
13 |
14 | from .pqc_utils import (
15 | ALGORITHMS,
16 | sign_message,
17 | verify_signature,
18 | kem_decapsulate,
19 | encrypt_aes_gcm,
20 | decrypt_aes_gcm,
21 | derive_aes_key,
22 | PQCError,
23 | PQCKEMError,
24 | PQCSignatureError,
25 | PQCEncryptionError,
26 | PQCDecryptionError
27 | )
28 |
# Defaults for HTTP timeouts and retry behaviour used by MCPClient.
DEFAULT_TIMEOUT: int = 30  # per-request timeout, in seconds
MAX_RETRIES: int = 3  # retry attempts after the first try
BACKOFF_FACTOR: float = 0.3  # base delay; doubled on each successive retry
RETRY_STATUS_CODES: list = [500, 502, 503, 504]  # HTTP statuses the transport adapter retries
34 |
@dataclass
class MCPRequest:
    """A single inference request addressed to an MCP server.

    Only ``target_server_url``, ``model_id`` and ``input_data`` are required;
    the remaining fields are optional metadata.
    """
    target_server_url: str
    model_id: str
    input_data: Dict[str, Any]
    policy_id: Optional[str] = None
    pqc_signature: Optional[bytes] = None
    client_signing_pubkey: Optional[bytes] = None

    def __post_init__(self):
        """Reject requests whose URL or model id is not a non-blank string.

        Raises:
            ValueError: If ``target_server_url`` or ``model_id`` is not a
                string, or is empty/whitespace-only.
        """
        for attr_name in ("target_server_url", "model_id"):
            attr_value = getattr(self, attr_name)
            if not (isinstance(attr_value, str) and attr_value.strip()):
                raise ValueError(f"{attr_name} must be a non-empty string")

    def to_dict(self) -> Dict[str, Any]:
        """Return the camelCase dictionary form, leaving out every ``None`` value.

        Binary fields are base64-encoded; an empty or absent binary field is
        treated as ``None`` and therefore omitted.
        """
        def as_b64(raw: Optional[bytes]) -> Optional[str]:
            if raw:
                return base64.b64encode(raw).decode('utf-8')
            return None

        candidates = (
            ("modelId", self.model_id),
            ("inputData", self.input_data),
            ("policyId", self.policy_id),
            ("pqcSignature", as_b64(self.pqc_signature)),
            ("clientSigningPubkey", as_b64(self.client_signing_pubkey)),
        )

        result: Dict[str, Any] = {}
        for wire_key, wire_value in candidates:
            if wire_value is not None:
                result[wire_key] = wire_value
        return result
64 |
65 |
@dataclass
class MCPResponse:
    """Represents a response from an MCP server.

    ``attestation_signature`` holds the raw (already base64-decoded) signature
    bytes, or ``None`` when the server sent none or it could not be decoded.
    """
    status: str
    output_data: Optional[Any] = None
    error_message: Optional[str] = None
    attestation_data: Optional[Dict[str, Any]] = None
    attestation_signature: Optional[bytes] = None
    audit_hash: Optional[str] = None

    def __post_init__(self):
        """Validate response data after initialization.

        Raises:
            ValueError: If ``status`` is not a non-blank string.
        """
        if not isinstance(self.status, str) or not self.status.strip():
            raise ValueError("status must be a non-empty string")

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'MCPResponse':
        """Create MCPResponse from dictionary data.

        Args:
            data: Decrypted response payload. Reads the keys ``status``,
                ``output_data``, ``error_message``, ``attestation_data``,
                ``attestation_signature_b64`` and ``audit_hash``; a missing
                ``status`` defaults to ``'error'``.

        Returns:
            The populated response. A signature that is absent, empty, of the
            wrong type, or not decodable as base64 becomes ``None`` (with a
            warning logged for the last two cases).

        Raises:
            ValueError: If ``data`` is not a dict, or the resulting ``status``
                is not a non-blank string.
        """
        if not isinstance(data, dict):
            raise ValueError("Response data must be a dictionary")

        logger = logging.getLogger(__name__)
        raw_signature_b64 = data.get('attestation_signature_b64')

        # Distinguish a key that is truly absent from one that is present but
        # empty; dict.get() alone cannot tell the two apart.
        if 'attestation_signature_b64' not in data:
            presence = 'MISSING'
        elif raw_signature_b64:
            presence = 'PRESENT and NON-EMPTY'
        else:
            presence = 'PRESENT BUT EMPTY/NONE'
        logger.debug("MCPResponse.from_dict: Raw attestation_signature_b64 from dict: %s", presence)

        attestation_signature = None
        if raw_signature_b64:
            if isinstance(raw_signature_b64, (str, bytes)):
                if len(raw_signature_b64) > 20:
                    logger.debug(
                        "MCPResponse.from_dict: att_sig_b64 (first 20 chars): %s...",
                        raw_signature_b64[:20],
                    )
                try:
                    attestation_signature = base64.b64decode(raw_signature_b64)
                except (ValueError, TypeError) as e:
                    # binascii.Error (bad padding etc.) is a ValueError subclass.
                    logger.warning(f"Failed to decode attestation signature: {e}")
            else:
                # A non-text value (e.g. a number) must degrade to "no
                # signature" rather than raise out of the constructor.
                logger.warning(
                    "Failed to decode attestation signature: expected base64 text, got %s",
                    type(raw_signature_b64).__name__,
                )

        return cls(
            status=data.get('status', 'error'),
            output_data=data.get('output_data'),
            error_message=data.get('error_message'),
            attestation_data=data.get('attestation_data'),
            attestation_signature=attestation_signature,
            audit_hash=data.get('audit_hash')
        )
108 |
109 |
110 | class MCPClient:
111 | """Client for secure interaction with Quantum-Safe MCP Servers.
112 |
113 | Handles PQC key management (Kyber for KEM, SPHINCS+ for signatures),
114 | request signing, response attestation verification, and session management
115 | with enhanced error handling and connection pooling.
116 | """
117 |
def __init__(
    self,
    server_url: str,
    client_kem_key_pair: Tuple[bytes, bytes],
    client_sign_key_pair: Tuple[bytes, bytes],
    server_kem_public_key: bytes,
    server_sign_public_key: bytes,
    timeout: int = DEFAULT_TIMEOUT,
    max_retries: int = MAX_RETRIES
):
    """Set up a disconnected client holding the supplied key material.

    No network traffic happens here; call ``connect()`` to open a session.

    Args:
        server_url: Base URL of the target MCP server. It is validated
            here but not stored; ``connect()`` takes the URL to use.
        client_kem_key_pair: ``(public_key_bytes, secret_key_bytes)`` for client KEM.
        client_sign_key_pair: ``(public_key_bytes, secret_key_bytes)`` for client signing.
        server_kem_public_key: Server KEM public key bytes.
        server_sign_public_key: Server signing public key bytes.
        timeout: Per-request timeout in seconds.
        max_retries: Maximum number of retry attempts.

    Raises:
        ValueError: If the URL or any key fails validation.
    """
    self.log = logging.getLogger(__name__)

    # Fail fast on malformed arguments before any state is stored.
    self._validate_initialization_params(
        server_url,
        client_kem_key_pair,
        client_sign_key_pair,
        server_kem_public_key,
        server_sign_public_key,
    )

    # Key material
    kem_public, kem_secret = client_kem_key_pair
    sign_public, sign_secret = client_sign_key_pair
    self.client_kem_pk_bytes = kem_public
    self.client_kem_sk_bytes = kem_secret
    self.client_sign_pk_bytes = sign_public
    self.client_sign_sk_bytes = sign_secret
    self.server_kem_pk_bytes = server_kem_public_key
    self.server_sign_pk_bytes = server_sign_public_key

    # Session state: nothing is established until connect() succeeds.
    self._session_key: Optional[bytes] = None
    self._is_connected = False
    self._connected_server_url: Optional[str] = None

    # Transport settings; _create_http_session() reads _max_retries.
    self._timeout = timeout
    self._max_retries = max_retries
    self._http_session = self._create_http_session()

    self.log.info("MCPClient initialized successfully with provided keys.")
164 |
def _validate_initialization_params(
    self,
    server_url: str,
    client_kem_key_pair: Tuple[bytes, bytes],
    client_sign_key_pair: Tuple[bytes, bytes],
    server_kem_public_key: bytes,
    server_sign_public_key: bytes
):
    """Check the constructor arguments, raising on the first problem found.

    Checks run in a fixed order: the URL, the shape of each key pair, then
    each individual key.

    Raises:
        ValueError: If ``server_url`` is not a non-blank string, a key pair
            is not a 2-tuple, or any key is not a non-empty ``bytes`` object.
    """
    if not (isinstance(server_url, str) and server_url.strip()):
        raise ValueError("server_url must be a non-empty string")

    pairs = (
        ("client_kem_key_pair", client_kem_key_pair),
        ("client_sign_key_pair", client_sign_key_pair),
    )
    for pair_name, pair in pairs:
        if not (isinstance(pair, tuple) and len(pair) == 2):
            raise ValueError(f"{pair_name} must be a tuple of (public_key, secret_key)")

    kem_public, kem_secret = client_kem_key_pair
    sign_public, sign_secret = client_sign_key_pair
    labelled_keys = {
        "client_kem_public_key": kem_public,
        "client_kem_secret_key": kem_secret,
        "client_sign_public_key": sign_public,
        "client_sign_secret_key": sign_secret,
        "server_kem_public_key": server_kem_public_key,
        "server_sign_public_key": server_sign_public_key,
    }
    for label, material in labelled_keys.items():
        if not isinstance(material, bytes) or not material:
            raise ValueError(f"{label} must be non-empty bytes")
193 |
def _create_http_session(self) -> requests.Session:
    """Build a pooled ``requests`` session with retries and JSON default headers.

    The same adapter serves both ``http://`` and ``https://`` URLs. It
    retries GET and POST on the status codes in ``RETRY_STATUS_CODES``, up
    to ``self._max_retries`` times, backing off by ``BACKOFF_FACTOR``.
    """
    http = requests.Session()

    pooled_adapter = HTTPAdapter(
        max_retries=Retry(
            total=self._max_retries,
            status_forcelist=RETRY_STATUS_CODES,
            backoff_factor=BACKOFF_FACTOR,
            allowed_methods=["GET", "POST"]
        ),
        pool_connections=10,
        pool_maxsize=20
    )
    for scheme in ("http://", "https://"):
        http.mount(scheme, pooled_adapter)

    default_headers = {
        'Content-Type': 'application/json',
        'User-Agent': 'qu3-client/1.0',
        'Accept': 'application/json'
    }
    http.headers.update(default_headers)

    return http
221 |
def connect(self, server_url: str) -> bool:
    """Establishes a secure session with the target server via KEM handshake.

    The client posts its KEM and signing public keys to
    ``<server_url>/kem-handshake/initiate``, decapsulates the returned KEM
    ciphertext with its KEM secret key and derives the AES session key from
    the shared secret.

    Timeouts, connection errors and other transport errors are retried with
    exponential backoff. HTTP error statuses, malformed responses and KEM
    failures are not retried.

    Args:
        server_url: The base URL of the MCP server.

    Returns:
        bool: True if the connection and KEM handshake succeed (or the
        client is already connected to ``server_url``), False otherwise.
    """
    if not isinstance(server_url, str) or not server_url.strip():
        self.log.error("Server URL must be a non-empty string")
        return False

    if self._is_connected:
        if server_url == self._connected_server_url:
            self.log.info(f"Already connected to {self._connected_server_url}. Skipping connection attempt.")
            return True
        self.log.warning(f"Client already connected to {self._connected_server_url}, disconnecting before connecting to {server_url}.")
        self.disconnect()

    if not self.client_kem_pk_bytes or not self.client_kem_sk_bytes or not self.client_sign_pk_bytes:
        self.log.error("Connection failed: Client KEM or Signing public keys are not loaded.")
        return False

    self.log.info(f"Initiating connection and KEM ({ALGORITHMS['kem']}) handshake with {server_url}...")

    handshake_endpoint = f"{server_url.rstrip('/')}/kem-handshake/initiate"

    # The handshake body is identical on every attempt, so build it once.
    try:
        request_payload = {
            "client_kem_pub_key_b64": base64.b64encode(self.client_kem_pk_bytes).decode('utf-8'),
            "client_sign_pub_key_b64": base64.b64encode(self.client_sign_pk_bytes).decode('utf-8')
        }
    except (TypeError, ValueError) as e:
        self.log.error(f"Failed to encode client public keys for handshake: {e}")
        return False

    total_attempts = self._max_retries + 1

    for attempt in range(total_attempts):
        try:
            self.log.debug(f"Sending client KEM & Signing public keys to {handshake_endpoint}...")

            http_response = self._http_session.post(
                handshake_endpoint,
                json=request_payload,
                timeout=self._timeout
            )
            http_response.raise_for_status()

            response_data = http_response.json()
            # A JSON array/scalar has no .get(); treat it as a malformed
            # response instead of letting an AttributeError trigger retries.
            if not isinstance(response_data, dict):
                self.log.error("Handshake failed: Server response is not a JSON object.")
                return False

            ciphertext_b64 = response_data.get("kemCiphertextB64")
            if not ciphertext_b64:
                self.log.error("Handshake failed: Server response missing KEM ciphertext.")
                return False

            ciphertext = base64.b64decode(ciphertext_b64)
            self.log.debug(f"Received KEM ciphertext (length: {len(ciphertext)} bytes) from server.")

            kem_shared_secret = kem_decapsulate(ALGORITHMS['kem'], ciphertext, self.client_kem_sk_bytes)

            self._session_key = derive_aes_key(kem_shared_secret)
            self.log.info(f"Successfully derived AES session key (length: {len(self._session_key)} bytes). Secure session established.")

            self._connected_server_url = server_url
            self._is_connected = True
            self.log.info(f"Connection to {self._connected_server_url} successful.")
            return True

        except requests.exceptions.Timeout as e:
            self.log.warning(f"Timeout during handshake attempt {attempt + 1}/{total_attempts}: {e}")
        except requests.exceptions.ConnectionError as e:
            self.log.warning(f"Connection error during handshake attempt {attempt + 1}/{total_attempts}: {e}")
        except requests.exceptions.HTTPError as e:
            self.log.error(f"HTTP error during KEM handshake: {e}")
            return False  # Don't retry on HTTP errors
        except json.JSONDecodeError as e:
            # Must precede the RequestException handler: recent `requests`
            # raises a JSONDecodeError from Response.json() that is *also* a
            # RequestException, which would otherwise be retried as a
            # "network error".
            self.log.error(f"Failed to decode or parse server handshake response: {e}")
            return False  # Don't retry on parsing errors
        except requests.exceptions.RequestException as e:
            self.log.warning(f"Network error during handshake attempt {attempt + 1}/{total_attempts}: {e}")
        except (TypeError, ValueError) as e:
            # binascii.Error from base64 decoding is a ValueError subclass.
            self.log.error(f"Failed to decode or parse server handshake response: {e}")
            return False  # Don't retry on parsing errors
        except PQCKEMError as e:
            self.log.error(f"PQC KEM operation failed during handshake: {e}")
            return False  # Don't retry on PQC errors
        except Exception as e:
            self.log.exception(f"Unexpected error during KEM handshake attempt {attempt + 1}/{total_attempts}: {e}")

        # Reached only after a retryable failure: back off exponentially.
        if attempt < self._max_retries:
            delay = BACKOFF_FACTOR * (2 ** attempt)
            self.log.info(f"Retrying handshake in {delay:.1f} seconds...")
            time.sleep(delay)

    self.log.error(f"Failed to establish connection after {total_attempts} attempts")
    return False
313 |
def send_request(self, request: MCPRequest) -> Optional[MCPResponse]:
    """Sends a signed and encrypted request to the connected MCP server.

    Handles PQC signing of the request payload, AES-GCM encryption,
    network communication, decryption of the response, and verification
    of the server's attestation signature.

    Failures are reported through the returned object rather than raised:
    every error path returns an ``MCPResponse`` with ``status='error'``.
    Timeouts, connection errors and other transport errors are retried
    with exponential backoff. HTTP error statuses, malformed responses
    and decryption failures are not retried.

    Args:
        request: The request to send. It is always posted to the server
            this client is connected to; ``request.target_server_url`` is
            included in the signed payload but does not select the endpoint.

    Returns:
        The server's response. If attestation verification fails, the
        decrypted response is still returned, but with ``status`` forced to
        ``'error'`` and ``error_message`` prefixed by an attestation warning.
    """
    if not isinstance(request, MCPRequest):
        self.log.error("Request must be an MCPRequest instance")
        return MCPResponse(status='error', error_message="Invalid request type")

    if not self._session_key:
        self.log.error("Cannot send request: No active session key. Call connect() first.")
        return MCPResponse(status='error', error_message="Client not connected or session key missing.")
    if not self._connected_server_url:
        self.log.error("Cannot send request: Not connected to any server.")
        return MCPResponse(status='error', error_message="Client not connected.")

    self.log.info(f"Preparing encrypted & signed request for model '{request.model_id}' to {self._connected_server_url}...")

    payload_to_sign = {
        "target_server_url": request.target_server_url,
        "model_id": request.model_id,
        "input_data": request.input_data,
    }

    # Canonical form (sorted keys, compact separators), used for the signature.
    payload_bytes_to_sign = json.dumps(payload_to_sign, sort_keys=True, separators=(',', ':')).encode('utf-8')
    # NOTE(review): this writes the plaintext request payload to the debug log.
    self.log.debug(f"Payload to sign: {payload_bytes_to_sign.decode()}")

    try:
        self.log.debug(f"Signing payload with {ALGORITHMS['sig']}...")
        signature_bytes = sign_message(payload_bytes_to_sign, self.client_sign_sk_bytes, ALGORITHMS["sig"])
        self.log.debug(f"Signature generated ({len(signature_bytes)} bytes).")
    except PQCSignatureError as e:
        self.log.exception(f"PQC signing failed: {e}")
        return MCPResponse(status='error', error_message=f"Client-side PQC signing failed: {e}")
    except Exception as e:
        self.log.exception(f"Unexpected error during signing: {e}")
        return MCPResponse(status='error', error_message=f"Unexpected client-side signing error: {e}")

    payload_with_sig = payload_to_sign.copy()
    payload_with_sig["pqc_signature_b64"] = base64.b64encode(signature_bytes).decode('utf-8')
    self.log.debug("Added signature to payload.")

    try:
        final_payload_bytes = json.dumps(payload_with_sig, sort_keys=True, separators=(',', ':')).encode('utf-8')
        self.log.debug(f"Encrypting final payload ({len(final_payload_bytes)} bytes) with AES-GCM...")
        nonce_bytes, ciphertext_bytes = encrypt_aes_gcm(self._session_key, final_payload_bytes)
        self.log.debug(f"Encryption successful. Nonce: {nonce_bytes.hex()[:16]}..., Ciphertext: {ciphertext_bytes.hex()[:16]}...")
    except PQCEncryptionError as e:
        self.log.exception(f"AES-GCM encryption failed: {e}")
        return MCPResponse(status='error', error_message=f"Client-side encryption failed: {e}")
    except (ValueError, TypeError) as e:
        self.log.exception(f"AES-GCM encryption failed: {e}")
        return MCPResponse(status='error', error_message=f"Client-side encryption failed: {e}")
    except Exception as e:
        self.log.exception(f"Unexpected error during encryption: {e}")
        return MCPResponse(status='error', error_message=f"Unexpected client-side encryption error: {e}")

    request_body = {
        "clientKemPublicKeyB64": base64.b64encode(self.client_kem_pk_bytes).decode('utf-8'),
        "nonceB64": base64.b64encode(nonce_bytes).decode('utf-8'),
        "encryptedPayloadB64": base64.b64encode(ciphertext_bytes).decode('utf-8'),
    }

    inference_endpoint = f"{self._connected_server_url.rstrip('/')}/inference"
    self.log.info(f"Sending encrypted POST request to {inference_endpoint}...")

    total_attempts = self._max_retries + 1
    # Stays None if the loop below never runs (negative max_retries).
    server_response: Optional[MCPResponse] = None

    for attempt in range(total_attempts):
        try:
            http_response = self._http_session.post(
                inference_endpoint,
                json=request_body,
                timeout=self._timeout
            )
            http_response.raise_for_status()
            self.log.info(f"Received HTTP response: {http_response.status_code}")

            encrypted_response_data = http_response.json()
            # A JSON array/scalar has no .get(); report it as a format
            # error instead of retrying on the resulting AttributeError.
            if not isinstance(encrypted_response_data, dict):
                self.log.error("Server response is not a JSON object.")
                return MCPResponse(status='error', error_message="Invalid encrypted response format from server")

            nonce_b64 = encrypted_response_data.get("nonceB64")
            encrypted_payload_b64 = encrypted_response_data.get("encryptedPayloadB64")

            if not nonce_b64 or not encrypted_payload_b64:
                self.log.error("Server response missing nonce or encrypted payload.")
                return MCPResponse(status='error', error_message="Invalid encrypted response format from server")

            self.log.debug("Decrypting server response...")
            nonce = base64.b64decode(nonce_b64)
            encrypted_payload = base64.b64decode(encrypted_payload_b64)

            decrypted_payload_bytes = decrypt_aes_gcm(self._session_key, nonce, encrypted_payload)
            response_dict = json.loads(decrypted_payload_bytes.decode('utf-8'))
            self.log.debug(f"Received and decrypted response payload: {json.dumps(response_dict)[:200]}...")

            server_response = MCPResponse.from_dict(response_dict)
            break

        except requests.exceptions.Timeout as e:
            self.log.warning(f"Timeout sending request attempt {attempt + 1}/{total_attempts}: {e}")
            if attempt == self._max_retries:
                return MCPResponse(status='error', error_message=f"Request timeout after {total_attempts} attempts")
        except requests.exceptions.ConnectionError as e:
            self.log.warning(f"Connection error sending request attempt {attempt + 1}/{total_attempts}: {e}")
            if attempt == self._max_retries:
                return MCPResponse(status='error', error_message=f"Connection error after {total_attempts} attempts")
        except requests.exceptions.HTTPError as e:
            self.log.error(f"HTTP error sending request to {inference_endpoint}: {e}")
            return MCPResponse(status='error', error_message=f"HTTP error: {e}")
        except json.JSONDecodeError as e:
            # Must precede the RequestException handler: recent `requests`
            # raises a JSONDecodeError from Response.json() that is *also* a
            # RequestException, which would otherwise be retried and
            # reported as a "network error".
            self.log.error(f"Error decoding/parsing server response: {e}")
            return MCPResponse(status='error', error_message=f"Failed to process server response: {e}")
        except requests.exceptions.RequestException as e:
            self.log.warning(f"Network error sending request attempt {attempt + 1}/{total_attempts}: {e}")
            if attempt == self._max_retries:
                return MCPResponse(status='error', error_message=f"Network error after {total_attempts} attempts: {e}")
        except (ValueError, TypeError) as e:
            # Covers base64 errors (binascii.Error is a ValueError), UTF-8
            # decoding errors and MCPResponse validation errors.
            self.log.error(f"Error decoding/parsing server response: {e}")
            return MCPResponse(status='error', error_message=f"Failed to process server response: {e}")
        except PQCDecryptionError as e:
            self.log.error(f"Failed to decrypt server response: {e}")
            return MCPResponse(status='error', error_message=f"Server response decryption failed: {e}")
        except Exception as e:
            self.log.exception(f"Unexpected error handling server response attempt {attempt + 1}/{total_attempts}:")
            if attempt == self._max_retries:
                return MCPResponse(status='error', error_message=f"Client-side response handling error: {e}")

        # Reached only after a retryable failure: back off exponentially.
        if attempt < self._max_retries:
            delay = BACKOFF_FACTOR * (2 ** attempt)
            self.log.info(f"Retrying request in {delay:.1f} seconds...")
            time.sleep(delay)

    if server_response is None:
        self.log.error("Request was not sent: no attempt was made.")
        return MCPResponse(status='error', error_message="Request failed: no response received from server")

    # Verify server attestation
    self.log.info("Verifying server response attestation signature...")
    attestation_ok, verification_error_msg = self._verify_attestation(server_response)
    if attestation_ok:
        self.log.info("Server attestation signature VERIFIED.")
        return server_response

    error_prefix = "ATTENTION: Attestation verification FAILED"
    if verification_error_msg:
        error_prefix += f" ({verification_error_msg})"

    # Keep the decrypted payload for inspection, but never report success
    # for a response whose attestation could not be verified.
    server_response.status = 'error'
    server_response.error_message = f"{error_prefix}{'; ' + server_response.error_message if server_response.error_message else ''}"
    self.log.warning(server_response.error_message)
    return server_response
458 |
def encrypt_payload(self, payload_bytes: bytes) -> Tuple[bytes, bytes]:
    """Encrypts the given payload using the current session key.

    Args:
        payload_bytes: The bytes to encrypt.

    Returns:
        Tuple[bytes, bytes]: The nonce and the ciphertext.

    Raises:
        ValueError: If ``payload_bytes`` is not a ``bytes`` object.
        RuntimeError: If no active session key is available (client not
            connected), or encryption fails with an unexpected error.
        PQCEncryptionError: If AES-GCM encryption fails.
    """
    if not isinstance(payload_bytes, bytes):
        raise ValueError("Payload must be bytes")

    if not self._session_key:
        self.log.error("Cannot encrypt payload: No active session key. Call connect() first.")
        raise RuntimeError("Cannot encrypt payload: No active session key. Client must be connected.")

    self.log.debug(f"Encrypting payload of length {len(payload_bytes)} bytes with AES-GCM using session key...")
    try:
        nonce, ciphertext = encrypt_aes_gcm(self._session_key, payload_bytes)
    except PQCEncryptionError as e:
        self.log.exception(f"AES-GCM encryption failed during encrypt_payload: {e}")
        # Chain explicitly so the original failure is recorded as the cause.
        raise PQCEncryptionError(f"Payload encryption failed: {e}") from e
    except Exception as e:
        self.log.exception(f"Unexpected error during encrypt_payload: {e}")
        raise RuntimeError(f"Payload encryption failed with an unexpected error: {e}") from e

    self.log.debug(f"Payload encryption successful. Nonce: {nonce.hex()[:16]}..., Ciphertext: {ciphertext.hex()[:16]}...")
    return nonce, ciphertext
490 |
def is_connected(self) -> bool:
    """Report whether a session is currently established.

    Returns:
        True only when the connected flag is set and a session key is held.
    """
    connected_flag = self._is_connected
    if connected_flag:
        return self._session_key is not None
    return connected_flag
498 |
def get_connected_server(self) -> Optional[str]:
    """Return the URL of the server this client is connected to.

    Returns:
        The stored server URL while the connected flag is set, else None.
    """
    if self._is_connected:
        return self._connected_server_url
    return None
506 |
def disconnect(self):
    """Drop the current session and reset the HTTP transport.

    Does nothing when the client is not connected. Otherwise the session
    key and connection state are cleared, the HTTP session is closed, and
    a fresh one is created so the client can connect again later.
    """
    if not self._is_connected:
        return

    previous_url = self._connected_server_url
    self.log.info(f"Disconnecting from {previous_url}...")

    # Forget the session key and connection state.
    self._session_key, self._is_connected, self._connected_server_url = None, False, None

    if hasattr(self, '_http_session'):
        self._http_session.close()
        # Recreate session for potential future connections
        self._http_session = self._create_http_session()

    self.log.info(f"Disconnected from {previous_url}.")
525 |
def _verify_attestation(self, response: MCPResponse) -> Tuple[bool, Optional[str]]:
    """Check the server's attestation signature on a decrypted response.

    The attestation data is serialized as canonical JSON (sorted keys,
    compact separators) and verified against the server signing public
    key held by this client.

    Returns:
        Tuple[bool, Optional[str]]: ``(True, None)`` when the signature is
        valid; otherwise ``(False, reason)``, where the reason covers a
        missing key, missing data or signature, a serialization failure,
        a mismatch, or a verification error.
    """
    if not self.server_sign_pk_bytes:
        failure = "Cannot verify attestation: Server signing public key is not loaded."
        self.log.error(failure)
        return False, failure

    # Diagnostics only: describe what was received without dumping it in full.
    data_state = 'PRESENT' if response.attestation_data else 'MISSING/EMPTY'
    self.log.debug(f"_verify_attestation: response.attestation_data is {data_state}")

    signature = response.attestation_signature
    has_signature_bytes = bool(signature and isinstance(signature, bytes) and len(signature) > 0)
    if has_signature_bytes:
        signature_state = 'PRESENT and NON-EMPTY bytes'
    elif signature == b'' or signature is None:
        signature_state = 'NONE or EMPTY bytes'
    else:
        signature_state = 'UNEXPECTED TYPE'
    self.log.debug(f"_verify_attestation: response.attestation_signature is {signature_state}")
    if has_signature_bytes:
        self.log.debug(f"_verify_attestation: att_sig (first 20 hex): {signature.hex()[:20]}...")

    if not (response.attestation_data and response.attestation_signature):
        failure = "Cannot verify attestation: Attestation data or signature missing."
        self.log.warning(failure)
        return False, failure

    try:
        canonical_json = json.dumps(response.attestation_data, sort_keys=True, separators=(',', ':'))
        signed_bytes = canonical_json.encode('utf-8')
    except Exception as e:
        failure = f"Error serializing attestation data: {e}"
        self.log.exception(failure)
        return False, failure

    self.log.debug(f"Verifying server attestation signature ({ALGORITHMS['sig']}) against known server public key...")
    try:
        signature_matches = verify_signature(
            signed_bytes,
            response.attestation_signature,
            self.server_sign_pk_bytes,
            ALGORITHMS['sig']
        )
    except PQCSignatureError as e:
        failure = f"Attestation signature verification failed (PQC Error): {e}"
        self.log.warning(failure)
        return False, failure
    except Exception as e:
        failure = f"Unexpected error during attestation signature verification: {e}"
        self.log.exception(failure)
        return False, failure

    if signature_matches:
        return True, None
    return False, "Signature mismatch"
575 |
def __del__(self):
    """Release the session key and HTTP session when the object is destroyed.

    This does not delegate to ``disconnect()``. That method returns early
    when the client is not connected, leaving the HTTP session open, and
    otherwise creates a replacement session that a dying object would
    never close. Finalization only tears down.
    """
    try:
        self._session_key = None
        self._is_connected = False
        self._connected_server_url = None

        # The attribute is absent if __init__ raised during validation.
        http_session = getattr(self, '_http_session', None)
        if http_session is not None:
            http_session.close()
    except Exception:
        # Best effort: a finalizer may run during interpreter shutdown.
        pass
582 |
--------------------------------------------------------------------------------
/scripts/mock_mcp_server.py:
--------------------------------------------------------------------------------
1 | import uvicorn
2 | from fastapi import FastAPI, HTTPException, Body
3 | from pydantic import BaseModel, Field
4 | import base64
5 | import os
6 | from pathlib import Path
7 | import logging
8 | import oqs
9 | import json
10 | from typing import Dict, Any, Optional
11 | from datetime import datetime, timedelta, timezone
12 | from cryptography.exceptions import InvalidTag, InvalidSignature
13 |
14 |
15 | import sys
16 | sys.path.insert(0, Path(__file__).parent.parent.resolve().__str__())
17 |
18 | from src.mcp_client import MCPRequest as MCPRequestSchema, MCPResponse as MCPResponseSchema
19 | from src.pqc_utils import (
20 | generate_key_pair,
21 | sign_message,
22 | verify_signature,
23 | ALGORITHMS,
24 | kem_encapsulate,
25 | kem_decapsulate,
26 | encrypt_aes_gcm,
27 | decrypt_aes_gcm,
28 | derive_aes_key,
29 | )
30 | from src.config_utils import (
31 | save_key_pair_to_files,
32 | load_key_pair_from_files,
33 | load_public_key_from_file,
34 | get_key_dir,
35 | )
36 |
# FastAPI application object; the route decorators below register on it.
app = FastAPI(title="MCP Server (Development/Test Instance)")

log = logging.getLogger(__name__)

# Key directory and the algorithm names taken from the shared ALGORITHMS table.
KEY_DIR = get_key_dir()
SERVER_KEM_ALGO = ALGORITHMS["kem"]
SERVER_SIGN_ALGO = ALGORITHMS["sig"]

# (public key file, secret key file) pairs for the server's own key material.
server_kem_key_pair_files = (
    KEY_DIR / "server_kem.pub",
    KEY_DIR / "server_kem.sec",
)
server_sign_key_pair_files = (
    KEY_DIR / "server_sign.pub",
    KEY_DIR / "server_sign.sec",
)
client_sign_pub_file = KEY_DIR / "client_sign.pub"

# Key material; populated by initialize_server_keys() at startup.
server_kem_pk = None
server_kem_sk = None
server_sign_pk = None
server_sign_sk = None
client_sign_pk = None

# In-memory session table keyed by the raw client KEM public key bytes.
# Each entry stores "session_key", "client_sign_pk" and "timestamp".
SESSION_KEYS: Dict[bytes, Dict[str, Any]] = {}
# Sessions older than this are rejected by the request handlers.
SESSION_TIMEOUT = timedelta(hours=1)
56 |
def _load_or_generate_key_pair(algorithm, key_files, label, error_label):
    """Load a server key pair from disk, generating and saving one if the files are missing.

    Args:
        algorithm: Algorithm name passed to ``generate_key_pair``.
        key_files: ``(public_key_path, secret_key_path)`` tuple.
        label: Key name used in the progress messages (e.g. ``"KEM"``).
        error_label: Key name used in the failure message.

    Returns:
        ``(public_key, secret_key)``, or ``(None, None)`` when loading,
        generating or saving fails for any reason.
    """
    try:
        try:
            public_key, secret_key = load_key_pair_from_files(*key_files)
            print(f"Server {label} keys loaded from {KEY_DIR}")
        except FileNotFoundError:
            print(f"Server {label} keys not found. Generating (for development instance)...")
            public_key, secret_key = generate_key_pair(algorithm)
            save_key_pair_to_files(public_key, secret_key, *key_files)
            print(f"Server {label} keys generated and saved to {KEY_DIR}")
    except Exception as e:
        # The outer handler also covers failures raised while generating or
        # saving inside the FileNotFoundError branch, so a broken key store
        # is reported here instead of aborting application startup.
        print(f"CRITICAL: Error initializing server {error_label} keys: {e}")
        return None, None
    return public_key, secret_key


def initialize_server_keys():
    """Populate the module-level server and client key globals.

    The server KEM and signing key pairs are loaded from ``KEY_DIR`` or
    generated and saved there when missing; on failure the corresponding
    globals are left as ``None``. The client signing public key is only
    loaded: if its file is missing a warning is printed and
    ``client_sign_pk`` keeps its previous value.
    """
    global server_kem_pk, server_kem_sk, server_sign_pk, server_sign_sk, client_sign_pk
    print("Initializing Server Keys (Development Instance)...")
    KEY_DIR.mkdir(parents=True, exist_ok=True)

    server_kem_pk, server_kem_sk = _load_or_generate_key_pair(
        SERVER_KEM_ALGO, server_kem_key_pair_files, "KEM", "KEM"
    )
    server_sign_pk, server_sign_sk = _load_or_generate_key_pair(
        SERVER_SIGN_ALGO, server_sign_key_pair_files, "Signing", "signing"
    )

    try:
        client_sign_pk = load_public_key_from_file(client_sign_pub_file)
        print(f"Client Signing public key loaded from {client_sign_pub_file}")
    except FileNotFoundError:
        print(f"WARNING: Client signing public key ({client_sign_pub_file.name}) not found.")
        print("Server will proceed but cannot verify client signatures.")
    except Exception as e:
        print(f"Error loading client signing public key: {e}")

    print("Server Key initialization complete.")
101 |
102 |
@app.on_event("startup")
def startup_event():
    """Run server key initialization when the application starts."""
    initialize_server_keys()
106 |
107 |
108 |
class EncryptedRequest(BaseModel):
    # Request body of POST /inference. The JSON keys are the camelCase aliases.

    # Base64 of the client's KEM public key; /inference uses it to look up the session.
    client_kem_public_key_b64: str = Field(..., alias="clientKemPublicKeyB64")
    # Base64 of the AES-GCM nonce used for the payload.
    nonce_b64: str = Field(..., alias="nonceB64")
    # Base64 of the AES-GCM encrypted JSON payload.
    encrypted_payload_b64: str = Field(..., alias="encryptedPayloadB64")
113 |
class EncryptedResponse(BaseModel):
    # Response body of POST /inference. The JSON keys are the camelCase aliases.

    # Base64 of the AES-GCM nonce used for the response payload.
    nonce_b64: str = Field(..., alias="nonceB64")
    # Base64 of the AES-GCM encrypted JSON response payload.
    encrypted_payload_b64: str = Field(..., alias="encryptedPayloadB64")
117 |
@app.get("/")
def read_root():
    """Basic health check endpoint."""
    # Fixed payload: receiving it shows the application is answering requests.
    health_payload = {"message": "MCP Development Server is running."}
    return health_payload
122 |
123 |
class KEMHandshakeRequest(BaseModel):
    # Request body of POST /kem-handshake/initiate.

    # Base64 of the client's KEM public key; the server encapsulates against it.
    client_kem_pub_key_b64: str
    # Base64 of the client's signing public key; stored with the session and
    # used later to verify the client's request signatures.
    client_sign_pub_key_b64: str
127 |
class KEMHandshakeResponse(BaseModel):
    # Response body of POST /kem-handshake/initiate.

    # Base64 of the KEM ciphertext produced by the encapsulation step.
    kem_ciphertext_b64: str = Field(..., alias="kemCiphertextB64")
130 |
131 |
@app.post("/kem-handshake/initiate", response_model=KEMHandshakeResponse, response_model_by_alias=True)
def handle_kem_handshake(request: KEMHandshakeRequest):
    """Performs KEM encapsulation using the provided client KEM public key and stores the session key.

    The derived AES key, the client's signing public key and the current UTC
    time are stored in ``SESSION_KEYS`` under the raw client KEM public key
    bytes; a later handshake with the same key replaces the entry.

    Returns:
        A dict with ``kemCiphertextB64``: the base64 KEM ciphertext.

    Raises:
        HTTPException: 400 when a ``ValueError`` (including a base64 decoding
            error) is raised while processing the request; 500 for any other
            failure.
    """
    log.info("Received KEM handshake initiation request.")

    # The server's own KEM key is not needed here: encapsulation is performed
    # against the client's public key.
    try:
        client_kem_pk_bytes = base64.b64decode(request.client_kem_pub_key_b64)
        client_sign_pk_bytes = base64.b64decode(request.client_sign_pub_key_b64)
        log.debug(f"Received client KEM public key (length: {len(client_kem_pk_bytes)} bytes).")
        log.debug(f"Received client Signing public key (length: {len(client_sign_pk_bytes)} bytes).")

        ciphertext, shared_secret = kem_encapsulate(SERVER_KEM_ALGO, client_kem_pk_bytes)
        session_key = derive_aes_key(shared_secret)

        SESSION_KEYS[client_kem_pk_bytes] = {
            "session_key": session_key,
            "client_sign_pk": client_sign_pk_bytes,
            "timestamp": datetime.now(timezone.utc),
        }
        log.info(f"KEM encapsulation successful. Stored session info for client {client_kem_pk_bytes.hex()[:16]}...")

        return {"kemCiphertextB64": base64.b64encode(ciphertext).decode('utf-8')}

    except (base64.binascii.Error, ValueError) as e:
        log.error(f"Failed to decode client public keys: {e}")
        raise HTTPException(status_code=400, detail="Invalid base64 encoding for client KEM public key.")
    except Exception as e:
        # NOTE(review): the former ``except oqs.OpenSSLError`` clause was removed.
        # The oqs module does not appear to export that name, in which case
        # evaluating the clause raises AttributeError while the handlers are
        # being matched and hides the real error. All encapsulation failures
        # are reported here instead.
        log.exception(f"Unexpected error during KEM handshake: {e}")
        raise HTTPException(status_code=500, detail="Unexpected server error during KEM handshake.")
170 |
171 |
172 |
# Words ignored by the keyword_extraction mock model.
_STOP_WORDS = frozenset({
    "the", "and", "for", "are", "but", "not", "you", "all", "can", "had", "her", "was",
    "one", "our", "out", "day", "get", "has", "him", "his", "how", "man", "new", "now",
    "old", "see", "two", "way", "who", "boy", "did", "its", "let", "put", "say", "she",
    "too", "use", "may", "come", "could", "each", "even", "find", "from", "have", "here",
    "into", "just", "like", "look", "make", "more", "most", "only", "over", "such", "take",
    "than", "that", "them", "well", "were", "what", "when", "will", "with", "would", "your",
    "this", "they", "been", "their", "said", "which", "there", "time", "very", "after",
    "first", "never", "these", "think", "where", "being", "every", "great", "might", "shall",
    "still", "those", "under", "while",
})


def _coerce_input_data(raw_input_data: Any) -> tuple[Dict[str, Any], Optional[str]]:
    """Turn the request's ``input_data`` into a dict.

    Accepts a dict or a JSON string that decodes to a dict.
    Returns ``(input_dict, None)`` on success and ``({}, error_message)`` otherwise.
    """
    if isinstance(raw_input_data, dict):
        return raw_input_data, None
    if not isinstance(raw_input_data, str):
        log.warning(f"input_data is of unexpected type: {type(raw_input_data)}. Value: {raw_input_data}")
        return {}, "Invalid input_data type: expected JSON object or JSON string."
    try:
        parsed_data = json.loads(raw_input_data)
    except json.JSONDecodeError as e:
        log.warning(f"Failed to parse input_data string '{raw_input_data}': {e}")
        return {}, f"Invalid input_data: failed to parse JSON string. ({e})"
    if not isinstance(parsed_data, dict):
        log.warning(f"input_data string parsed, but not to a dict: {raw_input_data}")
        return {}, "Invalid input_data: content of JSON string is not an object."
    return parsed_data, None


# Each _model_* helper below takes the input dict and returns
# ``(output_data, error_message)``; a non-None error_message marks the request as failed.

def _model_caps(input_data: Dict[str, Any]) -> tuple[Any, Optional[str]]:
    """Upper-case the ``text`` field."""
    text = input_data.get("text")
    if isinstance(text, str):
        return {"capitalized_text": text.upper()}, None
    log.warning(f"model_caps: 'text' is not a string or missing. input_data: {input_data}")
    return None, ("Invalid input for model_caps: 'text' field must be a string "
                  "and present in input_data.")


def _model_reverse(input_data: Dict[str, Any]) -> tuple[Any, Optional[str]]:
    """Reverse the ``text`` field character by character."""
    text = input_data.get("text")
    if isinstance(text, str):
        return {"reversed_text": text[::-1]}, None
    log.warning(f"model_reverse: 'text' is not a string or missing. input_data: {input_data}")
    return None, ("Invalid input for model_reverse: 'text' field must be a string "
                  "and present in input_data.")


def _model_sentiment_analysis(input_data: Dict[str, Any]) -> tuple[Any, Optional[str]]:
    """Score ``text`` with VADER, or return a neutral fallback when VADER is not installed."""
    text = input_data.get("text")
    if not (isinstance(text, str) and text.strip()):
        return None, "Invalid input for sentiment_analysis: 'text' field must be a non-empty string."

    try:
        # Optional dependency, imported lazily so the server runs without it.
        from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
    except ImportError:
        return {
            "sentiment": "neutral",
            "confidence": 0.1,
            "error": "VADER sentiment analysis not available",
            "analysis_method": "fallback",
        }, None

    scores = SentimentIntensityAnalyzer().polarity_scores(text)
    compound = scores['compound']
    if compound >= 0.05:
        sentiment = "positive"
    elif compound <= -0.05:
        sentiment = "negative"
    else:
        sentiment = "neutral"

    # Confidence is the magnitude of the compound score, floored at 0.1 for neutral text.
    confidence = abs(compound)
    if confidence < 0.05:
        confidence = 0.1

    return {
        "sentiment": sentiment,
        "confidence": round(confidence, 3),
        "compound_score": round(compound, 3),
        "positive_score": round(scores['pos'], 3),
        "negative_score": round(scores['neg'], 3),
        "neutral_score": round(scores['neu'], 3),
        "text_length": len(text),
        "analysis_method": "VADER",
    }, None


def _model_keyword_extraction(input_data: Dict[str, Any]) -> tuple[Any, Optional[str]]:
    """Return up to ten keywords from ``text`` ranked by a simplified TF-IDF-like score."""
    import math
    import re
    from collections import Counter

    text = input_data.get("text")
    if not (isinstance(text, str) and text.strip()):
        return None, "Invalid input for keyword_extraction: 'text' field must be a non-empty string."

    # ASCII-letter words of two or more characters, lower-cased.
    words = re.findall(r'\b[a-zA-Z]{2,}\b', text.lower())
    filtered_words = [word for word in words if word not in _STOP_WORDS and len(word) >= 3]

    if not filtered_words:
        return {
            "keywords": [],
            "total_words": len(words),
            "unique_words": 0,
            "filtered_words": 0,
            "analysis_method": "frequency_based",
        }, None

    total_filtered = len(filtered_words)
    keywords_with_scores = []
    for word, count in Counter(filtered_words).items():
        tf = count / total_filtered
        # IDF approximation from the word's rarity in this text plus a length bonus.
        idf = math.log(total_filtered / count) + (len(word) / 10)
        keywords_with_scores.append({
            "word": word,
            "frequency": count,
            "tf_score": round(tf, 4),
            "relevance_score": round(tf * idf, 4),
        })
    keywords_with_scores.sort(key=lambda item: item["relevance_score"], reverse=True)

    return {
        "keywords": keywords_with_scores[:10],
        "total_words": len(words),
        "unique_words": len(set(words)),
        "filtered_words": total_filtered,
        "unique_filtered": len(set(filtered_words)),
        "analysis_method": "tf_idf_based",
    }, None


def _model_json_formatter(input_data: Dict[str, Any]) -> tuple[Any, Optional[str]]:
    """Pretty-print ``data`` (a JSON string or an already-decoded value)."""
    data = input_data.get("data")
    if data is None:
        return None, "Invalid input for json_formatter: 'data' field is required."
    try:
        parsed_data = json.loads(data) if isinstance(data, str) else data
        formatted_json = json.dumps(parsed_data, indent=2, sort_keys=True)
    except (json.JSONDecodeError, TypeError) as e:
        return {"formatted_json": None, "is_valid": False, "error": str(e)}, None
    return {
        "formatted_json": formatted_json,
        "is_valid": True,
        "size_bytes": len(formatted_json),
    }, None


def _model_csv_analyzer(input_data: Dict[str, Any]) -> tuple[Any, Optional[str]]:
    """Summarize ``csv_data``: headers, row/column counts and per-column fill statistics."""
    import csv
    from io import StringIO

    csv_data = input_data.get("csv_data")
    if not isinstance(csv_data, str):
        return None, "Invalid input for csv_analyzer: 'csv_data' field must be a string."

    try:
        rows = list(csv.reader(StringIO(csv_data)))
    except csv.Error as e:
        return {"error": f"CSV parsing error: {str(e)}"}, None
    if not rows:
        return {"error": "Empty CSV data"}, None

    # The first row is treated as the header row.
    headers, data_rows = rows[0], rows[1:]
    column_stats = {}
    for index, header in enumerate(headers):
        # Short rows are padded with empty strings.
        column_data = [row[index] if index < len(row) else "" for row in data_rows]
        non_empty = [value for value in column_data if value.strip()]
        column_stats[header] = {
            "non_empty_count": len(non_empty),
            "empty_count": len(column_data) - len(non_empty),
            "sample_values": non_empty[:3],
        }

    return {
        "total_rows": len(data_rows),
        "total_columns": len(headers),
        "headers": headers,
        "sample_data": data_rows[:3],
        "column_stats": column_stats,
    }, None


def _model_code_formatter(input_data: Dict[str, Any]) -> tuple[Any, Optional[str]]:
    """Re-indent ``code`` using a simple opener/closer heuristic."""
    code = input_data.get("code")
    language = input_data.get("language", "python")
    if not isinstance(code, str):
        return None, "Invalid input for code_formatter: 'code' field must be a string."

    dedent_prefixes = ('end', '}', ')', ']', 'else:', 'elif', 'except:', 'finally:')
    indent_suffixes = ('{', '(', '[', ':', 'then', 'do')

    lines = code.split('\n')
    formatted_lines = []
    indent_level = 0
    for line in lines:
        stripped = line.strip()
        if not stripped:
            formatted_lines.append("")
            continue
        # A closing token is placed one level shallower than the block it ends.
        if stripped.startswith(dedent_prefixes):
            indent_level = max(0, indent_level - 1)
        formatted_lines.append(" " * indent_level + stripped)
        if stripped.endswith(indent_suffixes):
            indent_level += 1

    return {
        "formatted_code": '\n'.join(formatted_lines),
        "language": language,
        "original_lines": len(lines),
        "formatted_lines": len(formatted_lines),
    }, None


def _model_code_validator(input_data: Dict[str, Any]) -> tuple[Any, Optional[str]]:
    """Run heuristic line checks and, for Python, a compile check on ``code``."""
    code = input_data.get("code")
    language = input_data.get("language", "python")
    if not isinstance(code, str):
        return None, "Invalid input for code_validator: 'code' field must be a string."

    # A non-string ``language`` simply disables the Python-specific checks
    # instead of failing on ``.lower()``.
    is_python = isinstance(language, str) and language.lower() == "python"
    lines = code.split('\n')
    issues = []
    syntax_valid = True

    if is_python:
        block_keywords = ('if ', 'for ', 'while ', 'def ', 'class ', 'try', 'except', 'else', 'elif')
        for line_number, line in enumerate(lines, 1):
            stripped = line.strip()
            if not stripped:
                continue
            if '\t' in line and ' ' in line:
                issues.append(f"Line {line_number}: Mixed tabs and spaces")
            if stripped.startswith(block_keywords) and not stripped.endswith(':'):
                issues.append(f"Line {line_number}: Missing colon after {stripped.split()[0]}")

        try:
            compile(code, '', 'exec')
        except (SyntaxError, ValueError) as e:
            # ValueError covers source that compile() rejects outright
            # (e.g. embedded null bytes on some Python versions).
            syntax_valid = False
            issues.append(f"Syntax Error: {e}")

    return {
        "is_valid": syntax_valid and len(issues) == 0,
        "syntax_valid": syntax_valid,
        "issues": issues,
        "language": language,
        "lines_checked": len(lines),
    }, None


def _model_math_calculator(input_data: Dict[str, Any]) -> tuple[Any, Optional[str]]:
    """Evaluate the arithmetic ``expression`` with a small whitelist of math names."""
    import math

    expression = input_data.get("expression")
    if not isinstance(expression, str):
        return None, "Invalid input for math_calculator: 'expression' field must be a string."

    allowed_chars = set('0123456789+-*/().^ ')
    if not all(c in allowed_chars or c.isalpha() for c in expression):
        return {
            "result": None,
            "expression": expression,
            "is_valid": False,
            "error": "Expression contains invalid characters",
        }, None

    safe_namespace = {
        "__builtins__": {},
        "sin": math.sin, "cos": math.cos, "tan": math.tan,
        "log": math.log, "sqrt": math.sqrt, "abs": abs,
        "pow": pow, "pi": math.pi, "e": math.e,
    }
    try:
        # SECURITY: eval() runs on client-supplied text. The character filter
        # and the empty __builtins__ limit what can be reached, but this is not
        # a sandbox: e.g. nested exponentiation can still exhaust CPU/memory.
        # Acceptable only because this is a development/test server.
        result = eval(expression.replace('^', '**'), safe_namespace)
        if isinstance(result, bool) or not isinstance(result, (int, float)):
            # e.g. a bare function name or a complex number: not JSON-serializable.
            raise TypeError("Expression did not evaluate to a real number")
    except Exception as e:
        # Any failure of a user-supplied expression is reported as an invalid expression.
        return {
            "result": None,
            "expression": expression,
            "is_valid": False,
            "error": str(e),
        }, None

    return {"result": result, "expression": expression, "is_valid": True}, None


def _model_statistics_analyzer(input_data: Dict[str, Any]) -> tuple[Any, Optional[str]]:
    """Compute descriptive statistics for the ``numbers`` list."""
    import statistics

    numbers = input_data.get("numbers")
    if not (isinstance(numbers, list) and all(isinstance(x, (int, float)) for x in numbers)):
        return None, "Invalid input for statistics_analyzer: 'numbers' field must be a list of numbers."
    if not numbers:
        return {"error": "Empty numbers list"}, None

    has_spread = len(numbers) > 1  # stdev/variance need at least two samples
    try:
        return {
            "count": len(numbers),
            "sum": sum(numbers),
            "mean": statistics.mean(numbers),
            "median": statistics.median(numbers),
            "min": min(numbers),
            "max": max(numbers),
            "range": max(numbers) - min(numbers),
            "std_dev": statistics.stdev(numbers) if has_spread else 0,
            "variance": statistics.variance(numbers) if has_spread else 0,
        }, None
    except (ArithmeticError, ValueError) as e:
        return {"error": f"Statistics calculation error: {str(e)}"}, None


# model_id -> handler for the mock models served by /inference.
_MODEL_HANDLERS = {
    "model_caps": _model_caps,
    "model_reverse": _model_reverse,
    "sentiment_analysis": _model_sentiment_analysis,
    "keyword_extraction": _model_keyword_extraction,
    "json_formatter": _model_json_formatter,
    "csv_analyzer": _model_csv_analyzer,
    "code_formatter": _model_code_formatter,
    "code_validator": _model_code_validator,
    "math_calculator": _model_math_calculator,
    "statistics_analyzer": _model_statistics_analyzer,
}


def _run_mock_model(model_id: Any, raw_input_data: Any) -> tuple[str, Any, Optional[str]]:
    """Validate the input, dispatch to the model handler, and return ``(status, output_data, error_message)``."""
    input_data, error_message = _coerce_input_data(raw_input_data)
    if error_message is not None:
        return "error", None, error_message

    # model_id comes from client JSON and may be unhashable (list/dict).
    handler = _MODEL_HANDLERS.get(model_id) if isinstance(model_id, str) else None
    if handler is None:
        return "error", None, f"Unknown model ID: '{model_id}'"

    output_data, error_message = handler(input_data)
    status = "success" if error_message is None else "error"
    return status, output_data, error_message


@app.post("/inference", response_model=EncryptedResponse)
def run_inference_secure(request: EncryptedRequest):
    """Handles encrypted and signed inference requests.

    Steps: look up the session by the client's KEM public key, decrypt the
    payload with the session key, verify the client's signature over the
    canonical JSON of ``target_server_url``/``model_id``/``input_data``,
    run the requested mock model, sign an attestation, and return the
    response encrypted with the session key.

    Raises:
        HTTPException: 400 for malformed or undecryptable requests, 401 when
            there is no live session, 403 when the client signature is
            invalid, 500 for unexpected verification or encryption failures.
    """
    client_id_b64 = request.client_kem_public_key_b64
    log.info(f"Received /inference request from client ID: {client_id_b64[:10]}...")

    try:
        client_kem_pk_bytes = base64.b64decode(client_id_b64)
    except (base64.binascii.Error, ValueError):
        raise HTTPException(status_code=400, detail="Invalid client KEM public key format.")

    session_info = SESSION_KEYS.get(client_kem_pk_bytes)
    if not session_info:
        log.warning(f"No session info found for client ID: {client_id_b64[:10]}... Handshake required?")
        raise HTTPException(status_code=401, detail="No active session key. Perform KEM handshake.")

    if datetime.now(timezone.utc) - session_info['timestamp'] > SESSION_TIMEOUT:
        log.warning(f"Session expired for client ID: {client_id_b64[:10]}...")
        # Drop the stale entry so expired key material does not linger in memory.
        SESSION_KEYS.pop(client_kem_pk_bytes, None)
        raise HTTPException(status_code=401, detail="Session expired. Perform KEM handshake again.")

    session_key = session_info["session_key"]
    client_sign_pk_bytes = session_info["client_sign_pk"]

    # NOTE: ``json`` must stay the module-level import. A function-local
    # ``import json`` anywhere in this function would make the name local and
    # turn the json.loads call below into an UnboundLocalError.
    try:
        nonce = base64.b64decode(request.nonce_b64)
        encrypted_payload = base64.b64decode(request.encrypted_payload_b64)
        decrypted_payload_bytes = decrypt_aes_gcm(session_key, nonce, encrypted_payload)
        request_payload_dict = json.loads(decrypted_payload_bytes.decode('utf-8'))
        log.debug(f"Successfully decrypted request payload: {request_payload_dict}")
    except (base64.binascii.Error, json.JSONDecodeError, ValueError) as e:
        log.warning(f"Failed to decode/decrypt payload or parse JSON: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid encrypted request format: {e}")
    except Exception as e:
        log.error(f"Decryption failed for client {client_id_b64[:10]}...: {e}")
        raise HTTPException(status_code=400, detail=f"Payload decryption failed: {e}")

    if not isinstance(request_payload_dict, dict):
        log.warning("Decrypted payload is not a JSON object.")
        raise HTTPException(
            status_code=400,
            detail="Invalid encrypted request format: payload must be a JSON object.",
        )

    try:
        signature_b64 = request_payload_dict.get("pqc_signature_b64")
        if not signature_b64:
            raise ValueError("Missing 'pqc_signature_b64' in decrypted payload.")
        signature_bytes = base64.b64decode(signature_b64)

        # The signature covers the canonical JSON of exactly these three fields.
        data_to_verify_dict = {
            "target_server_url": request_payload_dict.get("target_server_url"),
            "model_id": request_payload_dict.get("model_id"),
            "input_data": request_payload_dict.get("input_data"),
        }
        log.debug(f"Verifying signature against target_server_url: {data_to_verify_dict['target_server_url']}")
        message_bytes = json.dumps(data_to_verify_dict, sort_keys=True, separators=(',', ':')).encode('utf-8')

        signature_ok = verify_signature(message_bytes, signature_bytes, client_sign_pk_bytes, SERVER_SIGN_ALGO)
    except (base64.binascii.Error, ValueError) as e:
        log.warning(f"Error decoding signature or missing fields for verification: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid signature data format: {e}")
    except InvalidSignature:
        log.warning(f"Client signature verification FAILED for client {client_id_b64[:10]}...")
        raise HTTPException(status_code=403, detail="Invalid client signature")
    except Exception as e:
        log.exception(f"Unexpected error during signature verification: {e}")
        raise HTTPException(status_code=500, detail="Signature verification failed")

    # verify_signature reports a mismatch through its return value, so the
    # result must be checked; ignoring it would accept forged requests.
    if not signature_ok:
        log.warning(f"Client signature verification FAILED for client {client_id_b64[:10]}...")
        raise HTTPException(status_code=403, detail="Invalid client signature")
    log.info(f"Client signature verified successfully for client {client_id_b64[:10]}...")

    model_id = request_payload_dict.get("model_id")
    raw_input_data = request_payload_dict.get("input_data", {})
    log.info(f"Processing inference for model '{model_id}'...")

    status, output_data, error_message = _run_mock_model(model_id, raw_input_data)

    log.info(f"Inference result - Status: {status}, Output Keys: {list(output_data.keys()) if isinstance(output_data, dict) else type(output_data)}")

    # NOTE: despite the names, inputHash/outputHash are base64 of the JSON text, not digests.
    attestation_data = {
        "serverVersion": "mock-0.1.0",
        "modelId": model_id,
        "status": status,
        "inputHash": base64.b64encode(json.dumps(raw_input_data, sort_keys=True).encode()).decode(),
        "outputHash": base64.b64encode(json.dumps(output_data, sort_keys=True).encode()).decode(),
        "timestamp": datetime.now(timezone.utc).isoformat()
    }

    log.debug(f"Attempting to sign attestation. Server sign SK is {'set' if server_sign_sk else 'None'}. Algorithm: {SERVER_SIGN_ALGO}")
    try:
        attestation_string = json.dumps(attestation_data, sort_keys=True, separators=(',', ':')).encode('utf-8')
        attestation_signature_bytes = sign_message(attestation_string, server_sign_sk, SERVER_SIGN_ALGO)

        if not attestation_signature_bytes:
            log.error("sign_message returned None or empty bytes, but no exception was raised.")
            attestation_signature_bytes = None
            error_message = error_message or "Failed to generate server attestation signature (sign_message returned empty)."
            status = "error"
        else:
            log.debug(f"Server attestation data signed successfully (signature length: {len(attestation_signature_bytes)} bytes).")
    except Exception as e:
        # A signing failure is reported to the client inside the encrypted
        # response rather than as an HTTP error.
        log.exception(f"Failed to sign server attestation: {e}")
        attestation_signature_bytes = None
        error_message = error_message or "Failed to generate server attestation signature (exception during sign_message)."
        status = "error"

    response_payload_cleartext = {
        "status": status,
        "output_data": output_data,
        "error_message": error_message,
        "attestation_data": attestation_data,
        "attestation_signature_b64": base64.b64encode(attestation_signature_bytes).decode('utf-8') if attestation_signature_bytes else None,
        "audit_hash": None
    }

    try:
        response_payload_json = json.dumps(response_payload_cleartext).encode('utf-8')
        resp_nonce, resp_ciphertext = encrypt_aes_gcm(session_key, response_payload_json)
        log.debug("Response payload encrypted successfully.")

        return {
            "nonceB64": base64.b64encode(resp_nonce).decode('utf-8'),
            "encryptedPayloadB64": base64.b64encode(resp_ciphertext).decode('utf-8')
        }
    except Exception as e:
        log.exception(f"Failed to encrypt response payload: {e}")
        raise HTTPException(status_code=500, detail="Failed to encrypt server response")
705 |
706 |
class PolicyUpdateRequest(BaseModel):
    # Request body of POST /policy-update; every field is base64 text.

    # Client KEM public key; identifies the session established by the handshake.
    client_kem_pub_key_b64: str
    # AES-GCM nonce for the encrypted policy.
    nonce_b64: str
    # AES-GCM ciphertext of the policy content.
    ciphertext_b64: str
    # Client signature, verified against the decrypted policy bytes.
    signature_b64: str
712 |
@app.post("/policy-update")
async def policy_update(request: PolicyUpdateRequest):
    """Receive an encrypted, client-signed policy update and acknowledge it.

    The policy is decrypted with the session key, the client's signature is
    verified over the decrypted bytes, and a signed, encrypted status message
    is returned.

    Returns:
        A dict with ``nonce_b64``, ``ciphertext_b64`` and ``signature_b64``
        (the server's signature over the plaintext status JSON).

    Raises:
        HTTPException: 400 for malformed input or a failed AES-GCM tag check,
            401 when there is no live session, 403 when the client signature
            is invalid, 500 for unexpected failures.
    """
    log.info("Received request for /policy-update")
    try:
        client_kem_pk_bytes = base64.b64decode(request.client_kem_pub_key_b64)
        session_info = SESSION_KEYS.get(client_kem_pk_bytes)
        if not session_info:
            log.error("Session info not found for client KEM PK.")
            raise HTTPException(status_code=401, detail="Session not established or expired.")

        if datetime.now(timezone.utc) - session_info['timestamp'] > SESSION_TIMEOUT:
            log.warning(f"Session expired for client KEM PK: {client_kem_pk_bytes.hex()[:16]}...")
            # Drop the stale entry so expired key material does not linger in memory.
            SESSION_KEYS.pop(client_kem_pk_bytes, None)
            raise HTTPException(status_code=401, detail="Session expired. Perform KEM handshake again.")

        session_key = session_info["session_key"]
        client_sign_pk_bytes = session_info["client_sign_pk"]
        log.debug(f"Retrieved session info for client KEM PK (first 10 bytes): {client_kem_pk_bytes[:10].hex()}...")

        nonce_bytes = base64.b64decode(request.nonce_b64)
        ciphertext_bytes = base64.b64decode(request.ciphertext_b64)
        signature_bytes = base64.b64decode(request.signature_b64)
        log.debug("Decoded nonce, ciphertext, and signature from request.")

        try:
            decrypted_policy_bytes = decrypt_aes_gcm(session_key, nonce_bytes, ciphertext_bytes)
            log.debug("Policy content decrypted successfully.")
        except InvalidTag:
            log.error("Decryption failed: Invalid AES-GCM tag.")
            raise HTTPException(status_code=400, detail="Policy decryption failed (invalid tag).")
        except Exception as e:
            log.error(f"Decryption failed with unexpected error: {e}")
            raise HTTPException(status_code=500, detail=f"Policy decryption failed: {e}")

        try:
            signature_ok = verify_signature(
                decrypted_policy_bytes, signature_bytes, client_sign_pk_bytes, SERVER_SIGN_ALGO
            )
        except InvalidSignature:
            log.error("Client signature verification FAILED.")
            raise HTTPException(status_code=403, detail="Invalid client signature on policy.")
        except Exception as e:
            log.error(f"Signature verification failed with unexpected error: {e}")
            raise HTTPException(status_code=500, detail=f"Policy signature verification failed: {e}")

        # verify_signature reports a mismatch through its return value, so the
        # result must be checked; ignoring it would accept any signature.
        if not signature_ok:
            log.error("Client signature verification FAILED.")
            raise HTTPException(status_code=403, detail="Invalid client signature on policy.")
        log.info("Client signature VERIFIED successfully.")

        policy_content = decrypted_policy_bytes.decode('utf-8')
        log.info(f"Successfully received and verified policy update:\n--- POLICY START ---\n{policy_content}\n--- POLICY END ---")
        print(f"\n--- Received Policy Update ---\n{policy_content}\n-----------------------------")

        response_status = {"status": "Policy update received and verified successfully."}
        response_payload_bytes = json.dumps(response_status).encode('utf-8')

        # The signature covers the plaintext status; the same bytes are then encrypted.
        server_signature = sign_message(response_payload_bytes, server_sign_sk, SERVER_SIGN_ALGO)
        log.debug("Server response signed.")

        resp_nonce_bytes, resp_ciphertext_bytes = encrypt_aes_gcm(session_key, response_payload_bytes)
        log.debug("Server response encrypted.")

        return {
            "nonce_b64": base64.b64encode(resp_nonce_bytes).decode('utf-8'),
            "ciphertext_b64": base64.b64encode(resp_ciphertext_bytes).decode('utf-8'),
            "signature_b64": base64.b64encode(server_signature).decode('utf-8')
        }

    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except (ValueError, TypeError, base64.binascii.Error) as e:
        log.error(f"Error decoding request data: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid request format or base64 encoding: {e}")
    except Exception:
        log.exception("Unexpected error processing /policy-update:")
        raise HTTPException(status_code=500, detail="Internal server error during policy update.")
798 |
799 |
class ServerKeysResponse(BaseModel):
    # Response body of GET /keys.

    # Base64 of the server's KEM public key.
    server_kem_public_key_b64: str
    # Base64 of the server's signing public key.
    server_sign_public_key_b64: str
803 |
@app.get("/keys", response_model=ServerKeysResponse)
def get_server_public_keys():
    """Returns the server's public KEM and signing keys."""
    log.info("Request received for /keys endpoint.")

    # Both keys are set by the startup initialization; either one missing
    # means the server cannot answer.
    keys_available = bool(server_kem_pk) and bool(server_sign_pk)
    if not keys_available:
        log.error("Server keys are not initialized. Cannot serve public keys.")
        raise HTTPException(status_code=503, detail="Server keys unavailable.")

    try:
        encoded_keys = {
            "server_kem_public_key_b64": base64.b64encode(server_kem_pk).decode('utf-8'),
            "server_sign_public_key_b64": base64.b64encode(server_sign_pk).decode('utf-8'),
        }
        log.info("Sending server public keys.")
        return ServerKeysResponse(**encoded_keys)
    except Exception as e:
        log.exception("Error encoding server public keys:")
        raise HTTPException(status_code=500, detail="Error preparing server keys.")
824 |
825 |
class ModelInfo(BaseModel):
    # Catalogue entry describing one model in the GET /models response.

    # Identifier and human-readable description.
    id: str
    name: str
    description: str
    category: str

    # Field-name -> description mappings for the model's input and output.
    input_schema: Dict[str, Any]
    output_schema: Dict[str, Any]

    # A sample input and the output it is expected to produce.
    example_input: Dict[str, Any]
    example_output: Dict[str, Any]
835 |
class ModelsResponse(BaseModel):
    """Response body of the ``/models`` endpoint."""

    # One entry per advertised model.
    models: list[ModelInfo]
    # Count of entries; the /models handler sets it to len(models).
    total_count: int
839 |
@app.get("/models", response_model=ModelsResponse)
def get_available_models():
    """Describe every model this server advertises.

    The catalogue is a fixed, in-code list; nothing is looked up at request
    time.

    Returns:
        ModelsResponse: One ``ModelInfo`` per catalogue entry, in declaration
        order, with ``total_count`` set to the number of entries.
    """
    log.info("Request received for /models endpoint.")

    # Each spec maps one-to-one onto the ModelInfo fields; entries are grouped
    # by category in the order they are served.
    catalogue = (
        # --- text_processing ---
        {
            "id": "model_caps",
            "name": "Text Capitalizer",
            "description": "Converts input text to uppercase",
            "category": "text_processing",
            "input_schema": {"text": "string (required)"},
            "output_schema": {"capitalized_text": "string"},
            "example_input": {"text": "hello world"},
            "example_output": {"capitalized_text": "HELLO WORLD"},
        },
        {
            "id": "model_reverse",
            "name": "Text Reverser",
            "description": "Reverses the input text character by character",
            "category": "text_processing",
            "input_schema": {"text": "string (required)"},
            "output_schema": {"reversed_text": "string"},
            "example_input": {"text": "hello world"},
            "example_output": {"reversed_text": "dlrow olleh"},
        },
        # --- text_analysis ---
        {
            "id": "sentiment_analysis",
            "name": "Sentiment Analyzer",
            "description": "Analyzes the sentiment of input text (positive, negative, neutral)",
            "category": "text_analysis",
            "input_schema": {"text": "string (required)"},
            "output_schema": {
                "sentiment": "string (positive|negative|neutral)",
                "confidence": "float (0.0-1.0)",
                "positive_indicators": "integer",
                "negative_indicators": "integer",
            },
            "example_input": {"text": "I love this amazing product!"},
            "example_output": {
                "sentiment": "positive",
                "confidence": 0.7,
                "positive_indicators": 2,
                "negative_indicators": 0,
            },
        },
        {
            "id": "keyword_extraction",
            "name": "Keyword Extractor",
            "description": "Extracts important keywords from text with frequency analysis",
            "category": "text_analysis",
            "input_schema": {"text": "string (required)"},
            "output_schema": {
                "keywords": "array of {word: string, frequency: integer}",
                "total_words": "integer",
                "unique_words": "integer",
            },
            "example_input": {
                "text": "Python programming is great for data analysis and machine learning"
            },
            "example_output": {
                "keywords": [
                    {"word": "python", "frequency": 1},
                    {"word": "programming", "frequency": 1},
                ],
                "total_words": 10,
                "unique_words": 9,
            },
        },
        # --- data_processing ---
        {
            "id": "json_formatter",
            "name": "JSON Formatter",
            "description": "Formats and validates JSON data with proper indentation",
            "category": "data_processing",
            "input_schema": {"data": "string or object (required)"},
            "output_schema": {
                "formatted_json": "string",
                "is_valid": "boolean",
                "size_bytes": "integer",
            },
            "example_input": {"data": '{"name":"John","age":30}'},
            "example_output": {
                "formatted_json": "{\n \"age\": 30,\n \"name\": \"John\"\n}",
                "is_valid": True,
                "size_bytes": 32,
            },
        },
        {
            "id": "csv_analyzer",
            "name": "CSV Analyzer",
            "description": "Analyzes CSV data structure and provides statistics",
            "category": "data_processing",
            "input_schema": {"csv_data": "string (required)"},
            "output_schema": {
                "total_rows": "integer",
                "total_columns": "integer",
                "headers": "array of strings",
                "sample_data": "array of arrays",
                "column_stats": "object",
            },
            "example_input": {"csv_data": "name,age,city\nJohn,30,NYC\nJane,25,LA"},
            "example_output": {
                "total_rows": 2,
                "total_columns": 3,
                "headers": ["name", "age", "city"],
                "sample_data": [["John", "30", "NYC"]],
                "column_stats": {"name": {"non_empty_count": 2, "empty_count": 0}},
            },
        },
        # --- code_utilities ---
        {
            "id": "code_formatter",
            "name": "Code Formatter",
            "description": "Formats code with proper indentation and structure",
            "category": "code_utilities",
            "input_schema": {
                "code": "string (required)",
                "language": "string (optional, default: python)",
            },
            "output_schema": {
                "formatted_code": "string",
                "language": "string",
                "original_lines": "integer",
                "formatted_lines": "integer",
            },
            "example_input": {"code": "def hello():\nprint('world')", "language": "python"},
            "example_output": {
                "formatted_code": "def hello():\n print('world')",
                "language": "python",
                "original_lines": 2,
                "formatted_lines": 2,
            },
        },
        {
            "id": "code_validator",
            "name": "Code Validator",
            "description": "Validates code syntax and checks for common issues",
            "category": "code_utilities",
            "input_schema": {
                "code": "string (required)",
                "language": "string (optional, default: python)",
            },
            "output_schema": {
                "is_valid": "boolean",
                "syntax_valid": "boolean",
                "issues": "array of strings",
                "language": "string",
                "lines_checked": "integer",
            },
            "example_input": {"code": "def hello()\nprint('world')", "language": "python"},
            "example_output": {
                "is_valid": False,
                "syntax_valid": False,
                "issues": ["Syntax Error: invalid syntax"],
                "language": "python",
                "lines_checked": 2,
            },
        },
        # --- mathematics ---
        {
            "id": "math_calculator",
            "name": "Math Calculator",
            "description": "Evaluates mathematical expressions safely",
            "category": "mathematics",
            "input_schema": {"expression": "string (required)"},
            "output_schema": {
                "result": "number or null",
                "expression": "string",
                "is_valid": "boolean",
                "error": "string (optional)",
            },
            "example_input": {"expression": "2 + 3 * 4"},
            "example_output": {
                "result": 14,
                "expression": "2 + 3 * 4",
                "is_valid": True,
            },
        },
        {
            "id": "statistics_analyzer",
            "name": "Statistics Analyzer",
            "description": "Calculates statistical measures for numerical data",
            "category": "mathematics",
            "input_schema": {"numbers": "array of numbers (required)"},
            "output_schema": {
                "count": "integer",
                "sum": "number",
                "mean": "number",
                "median": "number",
                "min": "number",
                "max": "number",
                "range": "number",
                "std_dev": "number",
                "variance": "number",
            },
            "example_input": {"numbers": [1, 2, 3, 4, 5]},
            "example_output": {
                "count": 5,
                "sum": 15,
                "mean": 3.0,
                "median": 3,
                "min": 1,
                "max": 5,
                "range": 4,
                "std_dev": 1.58,
                "variance": 2.5,
            },
        },
    )

    # Validate each spec through the response model, preserving order.
    entries = [ModelInfo(**spec) for spec in catalogue]
    return ModelsResponse(models=entries, total_count=len(entries))
1038 |
1039 |
if __name__ == "__main__":
    # Direct-execution entry point: turn on DEBUG-level logging, announce the
    # start on stdout, then serve the app on the loopback interface only.
    logging.basicConfig(level=logging.DEBUG)
    print("Starting MCP Development Server...")
    uvicorn.run(
        app,
        host="127.0.0.1",
        port=8000,
    )
1044 |
--------------------------------------------------------------------------------