├── .dockerignore ├── .gitignore ├── CLAUDE.md ├── Dockerfile ├── README.md ├── README_zh.md ├── config └── example.json ├── docker-compose.yml ├── main.py ├── pyproject.toml ├── src ├── __init__.py ├── api │ ├── __init__.py │ ├── handlers.py │ ├── middleware │ │ ├── __init__.py │ │ ├── auth.py │ │ └── timing.py │ └── routes.py ├── common │ ├── __init__.py │ ├── logging.py │ ├── token_cache.py │ └── token_counter.py ├── config │ ├── __init__.py │ ├── settings.py │ └── watcher.py ├── core │ ├── __init__.py │ ├── clients │ │ ├── __init__.py │ │ └── openai_client.py │ └── converters │ │ ├── __init__.py │ │ ├── request_converter.py │ │ ├── response_converter.py │ │ └── stream_converters.py ├── main.py └── models │ ├── __init__.py │ ├── anthropic.py │ ├── errors.py │ └── openai.py ├── tests ├── __init__.py ├── fixtures.py └── integration │ ├── test_end_to_end.py │ ├── test_error_handling.py │ ├── test_health_endpoint.py │ ├── test_messages_endpoint.py │ ├── test_model_mapping_integration.py │ └── test_streaming.py └── uv.lock /.dockerignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | *.egg-info/ 21 | .installed.cfg 22 | *.egg 23 | 24 | # Virtual environments 25 | .env 26 | .venv 27 | env/ 28 | venv/ 29 | ENV/ 30 | env.bak/ 31 | venv.bak/ 32 | 33 | # IDEs and editors 34 | .vscode/ 35 | .idea/ 36 | *.swp 37 | *.swo 38 | *~ 39 | 40 | # OS generated files 41 | .DS_Store 42 | .DS_Store? 
43 | ._* 44 | .Spotlight-V100 45 | .Trashes 46 | ehthumbs.db 47 | Thumbs.db 48 | 49 | # Git 50 | .git/ 51 | .gitignore 52 | 53 | # Log files 54 | logs/ 55 | *.log 56 | 57 | # Testing and coverage 58 | .coverage 59 | .pytest_cache/ 60 | .tox/ 61 | .nox/ 62 | htmlcov/ 63 | .coverage.* 64 | coverage.xml 65 | *.cover 66 | *.py,cover 67 | .hypothesis/ 68 | 69 | # Documentation 70 | docs/_build/ 71 | .readthedocs.yml 72 | 73 | # Temporary files 74 | *.tmp 75 | *.temp 76 | .cache/ 77 | 78 | # Configuration files 79 | config/settings.json 80 | 81 | # Docker 82 | Dockerfile* 83 | .dockerignore 84 | docker-compose*.yml -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .nox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | *.py,cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | cover/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | db.sqlite3-journal 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | .pybuilder/ 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | 
profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | Pipfile.lock 88 | 89 | # poetry 90 | #poetry.lock 91 | 92 | # pdm 93 | .pdm.toml 94 | 95 | # PEP 582 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | envs/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | # pytype static type analyzer 134 | .pytype/ 135 | 136 | # Cython debug symbols 137 | cython_debug/ 138 | 139 | # PyCharm 140 | .idea/ 141 | 142 | # VS Code 143 | .vscode/ 144 | 145 | # macOS 146 | .DS_Store 147 | .AppleDouble 148 | .LSOverride 149 | Icon 150 | ._* 151 | .DocumentRevisions-V100 152 | .fseventsd 153 | .Spotlight-V100 154 | .TemporaryItems 155 | .Trashes 156 | .VolumeIcon.icns 157 | .com.apple.timemachine.donotpresent 158 | .AppleDB 159 | .AppleDesktop 160 | Network Trash Folder 161 | Temporary Items 162 | .apdisk 163 | 164 | # Application specific 165 | .env.local 166 | .env.development.local 167 | .env.test.local 168 | .env.production.local 169 | 170 | # Configuration files with sensitive data 171 | config/local_*.json 172 | config/secrets.json 173 | 174 | # Logs 175 | *.log 176 | 177 | # Data files 178 | data/ 179 | tmp/ 180 | .temp/ 181 | 182 | # Docker 183 | 184 | # 185 | .claude 186 | .cunzhi-memory 187 | .ruff_cache 188 | .serena 189 | example 190 | config/settings.json -------------------------------------------------------------------------------- /CLAUDE.md: -------------------------------------------------------------------------------- 1 | # CLAUDE.md 2 | 3 | This file 
provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 4 | 5 | ## 项目概述 6 | 7 | 这是一个高性能的 RESTful API 代理服务,将 Anthropic Messages API 格式的请求转换为 OpenAI API 格式,允许开发者使用现有的 Anthropic 客户端代码(如 Claude Code)调用 OpenAI 兼容的模型。 8 | 9 | 主要功能: 10 | - ✅ 请求/响应格式转换 11 | - ✅ 流式响应支持 12 | - ✅ 错误处理映射 13 | - ✅ 配置管理 14 | - ✅ 结构化日志 15 | - ✅ 健康检查端点 16 | 17 | ## 项目结构 18 | 19 | ``` 20 | openai-to-claude/ 21 | ├── src/ 22 | │ ├── config/ # 配置管理 23 | │ ├── models/ # Pydantic 数据模型 24 | │ ├── core/converters/ # 数据格式转换器 25 | │ ├── core/clients/ # HTTP 客户端 26 | │ ├── api/ # API 端点和中间件 27 | │ └── common/ # 公共工具(日志、token计数等) 28 | ├── tests/ # 测试套件 29 | ├── config/ # 配置文件模板 30 | └── pyproject.toml # 项目依赖和配置 31 | ``` 32 | 33 | ## 核心架构 34 | 35 | ### 数据流向 36 | 1. 客户端发送 Anthropic 格式的请求到 `/v1/messages` 端点 37 | 2. `MessagesHandler` 接收请求并使用 `AnthropicToOpenAIConverter` 转换为 OpenAI 格式 38 | 3. `OpenAIServiceClient` 将请求发送到实际的 OpenAI API 服务 39 | 4. 收到 OpenAI 响应后,使用 `OpenAIToAnthropicConverter` 转换回 Anthropic 格式 40 | 5. 返回给客户端 41 | 42 | ### 主要组件 43 | - **src/api/handlers.py**: 核心消息处理逻辑 44 | - **src/core/converters/**: 请求/响应格式转换器 45 | - **src/core/clients/openai_client.py**: OpenAI API 客户端 46 | - **src/models/**: 数据模型定义 (Anthropic 和 OpenAI) 47 | - **src/config/settings.py**: 配置管理 48 | - **src/common/**: 公共工具(日志、token 计数等) 49 | 50 | ## 开发命令 51 | 52 | ### 安装依赖 53 | ```bash 54 | uv sync 55 | ``` 56 | 57 | ### 运行服务 58 | ```bash 59 | uvicorn src.main:app --reload --host 0.0.0.0 --port 8000 60 | ``` 61 | 62 | ### 运行测试 63 | ```bash 64 | # 运行所有测试 65 | pytest 66 | 67 | # 运行特定测试文件 68 | pytest tests/integration/test_streaming.py 69 | 70 | # 运行集成测试 71 | pytest tests/integration 72 | 73 | # 带覆盖率报告 74 | pytest --cov=src --cov-report=html 75 | ``` 76 | 77 | ### 代码质量检查 78 | ```bash 79 | # 运行 linting 80 | ruff check . 81 | 82 | # 自动修复代码风格 83 | ruff check . --fix 84 | 85 | # 代码格式化 86 | black .
87 | 88 | # 类型检查 89 | mypy src 90 | ``` 91 | 92 | ## 配置说明 93 | 94 | ### 环境变量 95 | - `LOG_LEVEL`: 日志级别 (默认: INFO) 96 | - `CONFIG_PATH`: 配置文件路径 (默认: config/settings.json) 97 | 98 | ### 配置文件 (config/settings.json) 99 | ```json 100 | { 101 | "openai": { 102 | "api_key": "your-openai-api-key-here", 103 | "base_url": "https://api.openai.com/v1" 104 | }, 105 | "server": { 106 | "host": "0.0.0.0", 107 | "port": 8000 108 | }, 109 | "api_key": "your-proxy-api-key-here", 110 | "logging": { 111 | "level": "INFO" 112 | }, 113 | "models": { 114 | "default": "claude-3-5-sonnet-20241022", 115 | "small": "claude-3-5-haiku-20241022", 116 | "tool": "claude-3-5-sonnet-20241022", 117 | "think": "claude-3-7-sonnet-20250219", 118 | "longContext": "claude-3-7-sonnet-20250219" 119 | }, 120 | "parameter_overrides": { 121 | "max_tokens": null, 122 | "temperature": null, 123 | "top_p": null, 124 | "top_k": null 125 | } 126 | } 127 | ``` 128 | 129 | ## API 端点 130 | 131 | - `POST /v1/messages` - Anthropic 消息 API(兼容 OpenAI 聊天 API) 132 | - `GET /health` - 健康检查端点 -------------------------------------------------------------------------------- /Dockerfile: --------------------------------------------------------------------------------
1 | # Build stage: install the locked dependencies into /app/.venv with uv
2 | FROM python:3.11-slim AS builder
3 | 
4 | WORKDIR /app
5 | 
6 | # Use the image's own interpreter so the venv's python link also resolves in the runtime stage
7 | ENV UV_PYTHON_DOWNLOADS=never
8 | RUN pip install --no-cache-dir uv
9 | 
10 | # `uv sync` installs into the project virtualenv (/app/.venv), not into the
11 | # system site-packages, so that directory is what the runtime stage copies.
12 | # --frozen installs exactly what uv.lock pins; --no-install-project skips
13 | # building this project itself (main.py imports it from the source tree).
14 | COPY pyproject.toml uv.lock ./
15 | RUN uv sync --frozen --no-install-project
16 | 
17 | # Runtime stage
18 | FROM python:3.11-slim
19 | 
20 | WORKDIR /app
21 | 
22 | # Copy the prepared virtualenv from the build stage
23 | COPY --from=builder /app/.venv /app/.venv
24 | 
25 | # Copy the project files (.dockerignore keeps any local .venv out)
26 | COPY . .
27 | 
28 | # Unbuffered output, no .pyc files, and the venv's bin directory first on PATH
29 | ENV PYTHONDONTWRITEBYTECODE=1 \
30 |     PYTHONUNBUFFERED=1 \
31 |     PATH="/app/.venv/bin:$PATH"
32 | 
33 | # Create the log and config directories
34 | RUN mkdir -p /app/logs /app/config
35 | 
36 | # Expose the service port
37 | EXPOSE 8000
38 | 
39 | # Run with the venv's interpreter; no dependency sync at container start
40 | CMD ["python", "main.py"]
41 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenAI-to-Claude API Proxy Service 2 | 3 | [![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/) 4 | [![FastAPI](https://img.shields.io/badge/FastAPI-0.104.1-009688.svg)](https://fastapi.tiangolo.com/) 5 | [![License](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE) 6 | 7 | High-performance proxy service that converts OpenAI API to Anthropic API compatible format. Allows developers to seamlessly call OpenAI models using existing Anthropic client code. 8 | 9 | [中文版本](README_zh.md) 10 | 11 | ## 🌟 Core Features 12 | 13 | - ✅ **Seamless Compatibility**: Call OpenAI models using standard Anthropic clients 14 | - ✅ **Full Functionality**: Supports text, tool calls, streaming responses, and more 15 | - ✅ **Intelligent Routing**: Automatically selects the most suitable OpenAI model based on request content 16 | - ✅ **Hot Reload**: Automatically reloads configuration file changes without restarting the service 17 | - ✅ **Structured Logging**: Detailed request/response logs for debugging and monitoring 18 | - ✅ **Error Mapping**: Comprehensive error handling and mapping mechanisms 19 | 20 | ## 🚀 Quick Start 21 | 22 | ### Requirements 23 | 24 | - Python 3.11+ 25 | - uv (recommended package manager) 26 | 27 | ### Install Dependencies 28 | 29 | ```bash 30 | # Install dependencies using uv (recommended) 31 | uv sync 32 | ``` 33 | 34 | ### Configuration 35 | 36 | 1. Copy the example configuration file: 37 | ```bash 38 | cp config/example.json config/settings.json 39 | ``` 40 | 41 | 2.
Edit `config/settings.json`: 42 | ```json 43 | { 44 | "openai": { 45 | "api_key": "your-openai-api-key-here", // Replace with your OpenAI API key 46 | "base_url": "https://api.openai.com/v1" // OpenAI API address 47 | }, 48 | "api_key": "your-proxy-api-key-here", // API key for the proxy service 49 | // Other configurations... 50 | } 51 | ``` 52 | 53 | ### Start the Service 54 | 55 | ```bash 56 | # Development mode 57 | uv run main.py --config config/settings.json 58 | 59 | # Production mode 60 | uv run main.py 61 | ``` 62 | 63 | ### Start with Docker 64 | 65 | ```bash 66 | # Build and start the service 67 | docker-compose up --build 68 | 69 | # Run in background 70 | docker-compose up --build -d 71 | 72 | # Stop the service 73 | docker-compose down 74 | ``` 75 | 76 | The service will start at `http://localhost:8000`. 77 | 78 | ## 🛠️ Usage 79 | 80 | ### Claude Code Usage 81 | 82 | This project can be used with [Claude Code](https://claude.ai/code) for development and testing. To configure Claude Code to work with this proxy service, create a `.claude/settings.json` file with the following configuration: 83 | 84 | ```json 85 | { 86 | "env": { 87 | "ANTHROPIC_API_KEY": "your-api-key", 88 | "ANTHROPIC_BASE_URL": "http://127.0.0.1:8000", 89 | "DISABLE_TELEMETRY": "1", 90 | "CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC": "1" 91 | }, 92 | "apiKeyHelper": "echo 'your-api-key'", 93 | "permissions": { 94 | "allow": [], 95 | "deny": [] 96 | } 97 | } 98 | ``` 99 | 100 | Configuration Notes: 101 | - Set `ANTHROPIC_API_KEY` to the proxy's `api_key` value from `config/settings.json` 102 | - Replace `ANTHROPIC_BASE_URL` with the actual URL where this proxy service is running 103 | - Set the key echoed by `apiKeyHelper` to the same `api_key` value from `config/settings.json` 104 | 105 | ### Using Anthropic Python Client 106 | 107 | ```python 108 | from anthropic import Anthropic 109 | 110 | # Initialize client pointing to the proxy service 111 | client = Anthropic( 112 | base_url="http://localhost:8000", # the SDK appends /v1/messages itself 113 |
api_key="your-proxy-api-key-here" # Use the api_key from the configuration file 114 | ) 115 | 116 | # Send a message request 117 | response = client.messages.create( 118 | model="gpt-4o", 119 | messages=[ 120 | {"role": "user", "content": "Hello, GPT!"} 121 | ], 122 | max_tokens=1024 123 | ) 124 | 125 | print(response.content[0].text) 126 | ``` 127 | 128 | ### Streaming Response 129 | 130 | ```python 131 | # Streaming response 132 | stream = client.messages.create( 133 | model="gpt-4o", 134 | messages=[ 135 | {"role": "user", "content": "Tell me a story about AI"} 136 | ], 137 | max_tokens=1024, 138 | stream=True 139 | ) 140 | 141 | for chunk in stream: 142 | if chunk.type == "content_block_delta": 143 | print(chunk.delta.text, end="", flush=True) 144 | ``` 145 | 146 | ### Tool Calls 147 | 148 | ```python 149 | # Tool calls 150 | tools = [ 151 | { 152 | "name": "get_current_weather", 153 | "description": "Get the current weather for a specified city", 154 | "input_schema": { 155 | "type": "object", 156 | "properties": { 157 | "city": { 158 | "type": "string", 159 | "description": "City name" 160 | } 161 | }, 162 | "required": ["city"] 163 | } 164 | } 165 | ] 166 | 167 | response = client.messages.create( 168 | model="gpt-4o", 169 | max_tokens=1024, # required by the Messages API 170 | messages=[ 171 | {"role": "user", "content": "What's the weather like in Beijing now?"} 172 | ], 173 | tools=tools, tool_choice={"type": "auto"} 174 | ) 175 | ``` 176 | 177 | ## 📁 Project Structure 178 | 179 | ``` 180 | openai-to-claude/ 181 | ├── src/ 182 | │ ├── api/ # API endpoints and middleware 183 | │ ├── config/ # Configuration management 184 | │ ├── core/ # Core business logic 185 | │ │ ├── clients/ # HTTP clients 186 | │ │ └── converters/ # Data format converters 187 | │ ├── models/ # Pydantic data models 188 | │ └── common/ # Common utilities (logging, token counting, etc.)
189 | ├── config/ # Configuration files 190 | ├── tests/ # Test suite 191 | ├── CLAUDE.md # Claude Code project guide 192 | └── pyproject.toml # Project dependencies and configuration 193 | ``` 194 | 195 | ## 🔧 Configuration 196 | 197 | ### Environment Variables 198 | 199 | - `CONFIG_PATH`: Configuration file path (default: `config/settings.json`) 200 | - `LOG_LEVEL`: Log level (default: `INFO`) 201 | 202 | ### Configuration File (`config/settings.json`) 203 | 204 | ```json 205 | { 206 | "openai": { 207 | "api_key": "your-openai-api-key-here", 208 | "base_url": "https://api.openai.com/v1" 209 | }, 210 | "server": { 211 | "host": "0.0.0.0", 212 | "port": 8000 213 | }, 214 | "api_key": "your-proxy-api-key-here", 215 | "logging": { 216 | "level": "INFO" 217 | }, 218 | "models": { 219 | "default": "Qwen/Qwen3-Coder", 220 | "small": "deepseek-ai/DeepSeek-V3-0324", 221 | "think": "deepseek-ai/DeepSeek-R1-0528", 222 | "long_context": "gemini-2.5-pro", 223 | "web_search": "gemini-2.5-flash" 224 | }, 225 | "parameter_overrides": { 226 | "max_tokens": null, 227 | "temperature": null, 228 | "top_p": null, 229 | "top_k": null 230 | } 231 | } 232 | ``` 233 | 234 | #### Configuration Items 235 | 236 | - **openai**: OpenAI API configuration 237 | - `api_key`: OpenAI API key for accessing OpenAI services 238 | - `base_url`: OpenAI API base URL, default is `https://api.openai.com/v1` 239 | 240 | - **server**: Server configuration 241 | - `host`: Service listening host address, default is `0.0.0.0` (listen on all network interfaces) 242 | - `port`: Service listening port, default is `8000` 243 | 244 | - **api_key**: API key for the proxy service, used to verify requests to the `/v1/messages` endpoint 245 | 246 | - **logging**: Logging configuration 247 | - `level`: Log level, options are `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`, default is `INFO` 248 | 249 | - **models**: Model configuration, defines model selection for different usage scenarios 250 | - `default`: Default 
general model for general requests 251 | - `small`: Lightweight model for simple tasks 252 | - `think`: Deep thinking model for complex reasoning tasks 253 | - `long_context`: Long context processing model for handling long text 254 | - `web_search`: Model used for web search requests; currently supports Gemini models 255 | 256 | - **parameter_overrides**: Parameter override configuration, allows administrators to set model parameter override values in the configuration file 257 | - `max_tokens`: Maximum token count override, when set, will override the max_tokens parameter in client requests 258 | - `temperature`: Temperature parameter override, controls the randomness of output, range 0.0-2.0 259 | - `top_p`: top_p sampling parameter override, controls the probability threshold of candidate words, range 0.0-1.0 260 | - `top_k`: top_k sampling parameter override, controls the number of candidate words, range >=0 261 | 262 | ## 🧪 Testing 263 | 264 | ```bash 265 | # Run all tests 266 | pytest 267 | 268 | # Run a single test file 269 | pytest tests/integration/test_streaming.py 270 | 271 | # Run integration tests 272 | pytest tests/integration 273 | 274 | # Generate coverage report 275 | pytest --cov=src --cov-report=html 276 | ``` 277 | 278 | ## 📊 API Endpoints 279 | 280 | - `POST /v1/messages` - Anthropic Messages API 281 | - `GET /health` - Health check endpoint 282 | - `GET /` - Welcome page 283 | 284 | ## 🛡️ Security 285 | 286 | - API key authentication 287 | - Request rate limiting (planned) 288 | - Input validation and sanitization 289 | - Structured logging 290 | 291 | ## 📈 Performance Monitoring 292 | 293 | - Request/response time monitoring 294 | - Memory usage tracking 295 | - Error rate statistics 296 | 297 | ## 🤝 Contributing 298 | 299 | Issues and Pull Requests are welcome! 300 | 301 | ## 📄 License 302 | 303 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
304 | 305 | ## 🙏 Acknowledgements 306 | 307 | - [claude-code-router](https://github.com/musistudio/claude-code-router) - An excellent project; much of this project's design draws on it 308 | - [FastAPI](https://fastapi.tiangolo.com/) - Modern high-performance web framework 309 | - [Anthropic](https://www.anthropic.com/) - Claude AI models 310 | - [OpenAI](https://openai.com/) - OpenAI API specification 311 | 312 | ## 🤖 Claude Code Usage 313 | 314 | This project can be used with [Claude Code](https://claude.ai/code) for development and testing. To configure Claude Code to work with this proxy service, create a `.claude/settings.json` file with the following configuration: 315 | 316 | ### Example Configuration 317 | 318 | ```json 319 | { 320 | "env": { 321 | "ANTHROPIC_API_KEY": "your-proxy-api-key-here", 322 | "ANTHROPIC_BASE_URL": "http://127.0.0.1:8000", 323 | "DISABLE_TELEMETRY": "1", 324 | "CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC": "1" 325 | }, 326 | "apiKeyHelper": "echo 'your-proxy-api-key-here'", 327 | "permissions": { 328 | "allow": [], 329 | "deny": [] 330 | } 331 | } 332 | ``` 333 | 334 | ### Configuration Notes 335 | 336 | - Set `ANTHROPIC_API_KEY` to the proxy's `api_key` from `config/settings.json` (not an Anthropic API key) 337 | - Replace `ANTHROPIC_BASE_URL` with the actual URL where this proxy service is running 338 | - The key echoed by `apiKeyHelper` must be that same proxy `api_key` -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | # OpenAI-to-Claude API 代理服务 2 | 3 | [![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/) 4 | [![FastAPI](https://img.shields.io/badge/FastAPI-0.104.1-009688.svg)](https://fastapi.tiangolo.com/) 5 | [![License](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE) 6 | 7 | 将 OpenAI API 转为 Anthropic API 兼容格式的高性能代理服务。允许开发者使用现有的 Anthropic 客户端代码无缝调用 OpenAI 模型。 8 | 9 | ## 🌟 核心特性
10 | 11 | - ✅ **无缝兼容**: 使用标准 Anthropic 客户端调用 OpenAI 模型 12 | - ✅ **完整功能**: 支持文本、工具调用、流式响应等功能 13 | - ✅ **智能路由**: 根据请求内容自动选择最适合的 OpenAI 模型 14 | - ✅ **热重载**: 配置文件修改后自动重载,无需重启服务 15 | - ✅ **结构化日志**: 详细的请求/响应日志,便于调试和监控 16 | - ✅ **错误映射**: 完善的错误处理和映射机制 17 | 18 | ## 🚀 快速开始 19 | 20 | ### 环境要求 21 | 22 | - Python 3.11+ 23 | - uv (推荐的包管理器) 24 | 25 | ### 安装依赖 26 | 27 | ```bash 28 | # 使用 uv 安装依赖(推荐) 29 | uv sync 30 | ``` 31 | 32 | ### 配置 33 | 34 | 1. 复制示例配置文件: 35 | ```bash 36 | cp config/example.json config/settings.json 37 | ``` 38 | 39 | 2. 编辑 `config/settings.json`: 40 | ```json 41 | { 42 | "openai": { 43 | "api_key": "your-openai-api-key-here", // 替换为你的 OpenAI API 密钥 44 | "base_url": "https://api.openai.com/v1" // OpenAI API 地址 45 | }, 46 | "api_key": "your-proxy-api-key-here", // 代理服务的 API 密钥 47 | // 其他配置... 48 | } 49 | ``` 50 | 51 | ### 启动服务 52 | 53 | ```bash 54 | # 开发模式 55 | uv run main.py --config config/settings.json 56 | 57 | # 生产模式 58 | uv run main.py 59 | ``` 60 | 61 | ### 使用 Docker 启动 62 | 63 | ```bash 64 | # 构建并启动服务 65 | docker-compose up --build 66 | 67 | # 后台运行 68 | docker-compose up --build -d 69 | 70 | # 停止服务 71 | docker-compose down 72 | ``` 73 | 74 | 服务将在 `http://localhost:8000` 启动。 75 | 76 | ## 🛠️ 使用方法 77 | 78 | ### Claude Code 使用方法 79 | 80 | 本项目可以与 [Claude Code](https://claude.ai/code) 一起使用进行开发和测试。要配置 Claude Code 以使用此代理服务,请创建一个 `.claude/settings.json` 文件,配置如下: 81 | 82 | ```json 83 | { 84 | "env": { 85 | "ANTHROPIC_API_KEY": "your-api-key", 86 | "ANTHROPIC_BASE_URL": "http://127.0.0.1:8000", 87 | "DISABLE_TELEMETRY": "1", 88 | "CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC": "1" 89 | }, 90 | "apiKeyHelper": "echo 'your-api-key'", 91 | "permissions": { 92 | "allow": [], 93 | "deny": [] 94 | } 95 | } 96 | ``` 97 | 98 | 配置说明: 99 | - 将 `ANTHROPIC_API_KEY` 替换为您配置的 API 密钥,在 `config/settings.json` 中 100 | - 将 `ANTHROPIC_BASE_URL` 替换为此代理服务实际运行的 URL 101 | - `apiKeyHelper` 字段也应更新为您的 API 密钥 102 | 103 | ### 使用 Anthropic Python 客户端 104 | 105 | ```python 106 | from anthropic 
import Anthropic 107 | 108 | # 初始化客户端,指向代理服务 109 | client = Anthropic( 110 | base_url="http://localhost:8000", # SDK 会自动追加 /v1/messages 111 | api_key="your-proxy-api-key-here" # 使用配置文件中的 api_key 112 | ) 113 | 114 | # 发送消息请求 115 | response = client.messages.create( 116 | model="gpt-4o", 117 | messages=[ 118 | {"role": "user", "content": "你好,GPT!"} 119 | ], 120 | max_tokens=1024 121 | ) 122 | 123 | print(response.content[0].text) 124 | ``` 125 | 126 | ### 流式响应 127 | 128 | ```python 129 | # 流式响应 130 | stream = client.messages.create( 131 | model="gpt-4o", 132 | messages=[ 133 | {"role": "user", "content": "给我讲一个关于 AI 的故事"} 134 | ], 135 | max_tokens=1024, 136 | stream=True 137 | ) 138 | 139 | for chunk in stream: 140 | if chunk.type == "content_block_delta": 141 | print(chunk.delta.text, end="", flush=True) 142 | ``` 143 | 144 | ### 工具调用 145 | 146 | ```python 147 | # 工具调用 148 | tools = [ 149 | { 150 | "name": "get_current_weather", 151 | "description": "获取指定城市的当前天气", 152 | "input_schema": { 153 | "type": "object", 154 | "properties": { 155 | "city": { 156 | "type": "string", 157 | "description": "城市名称" 158 | } 159 | }, 160 | "required": ["city"] 161 | } 162 | } 163 | ] 164 | 165 | response = client.messages.create( 166 | model="gpt-4o", 167 | max_tokens=1024, # Messages API 必填参数 168 | messages=[ 169 | {"role": "user", "content": "北京现在的天气怎么样?"} 170 | ], 171 | tools=tools, tool_choice={"type": "auto"} 172 | ) 173 | ``` 174 | 175 | ## 📁 项目结构 176 | 177 | ``` 178 | openai-to-claude/ 179 | ├── src/ 180 | │ ├── api/ # API 端点和中间件 181 | │ ├── config/ # 配置管理 182 | │ ├── core/ # 核心业务逻辑 183 | │ │ ├── clients/ # HTTP 客户端 184 | │ │ └── converters/ # 数据格式转换器 185 | │ ├── models/ # Pydantic 数据模型 186 | │ └── common/ # 公共工具(日志、token计数等) 187 | ├── config/ # 配置文件 188 | ├── tests/ # 测试套件 189 | ├── CLAUDE.md # Claude Code 项目指导 190 | └── pyproject.toml # 项目依赖和配置 191 | ``` 192 | 193 | ## 🤖 Claude Code 使用方法 194 | 195 | 本项目可以与 [Claude Code](https://claude.ai/code) 一起使用进行开发和测试。要配置 Claude Code 以使用此代理服务,请创建一个 `.claude/settings.json` 文件,配置如下: 196 | 197 | ### 
示例配置文件 198 | 199 | ```json 200 | { 201 | "env": { 202 | "ANTHROPIC_API_KEY": "your-proxy-api-key-here", 203 | "ANTHROPIC_BASE_URL": "http://127.0.0.1:8000", 204 | "DISABLE_TELEMETRY": "1", 205 | "CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC": "1" 206 | }, 207 | "apiKeyHelper": "echo 'your-proxy-api-key-here'", 208 | "permissions": { 209 | "allow": [], 210 | "deny": [] 211 | } 212 | } 213 | ``` 214 | 215 | ### 配置说明 216 | 217 | - 将 `ANTHROPIC_API_KEY` 设置为 `config/settings.json` 中代理服务的 `api_key`(而不是 Anthropic API 密钥) 218 | - 将 `ANTHROPIC_BASE_URL` 替换为此代理服务实际运行的 URL 219 | - `apiKeyHelper` 输出的密钥也应为同一个代理 `api_key` 220 | 221 | ## 🔧 配置说明 222 | 223 | ### 环境变量 224 | 225 | - `CONFIG_PATH`: 配置文件路径 (默认: `config/settings.json`) 226 | - `LOG_LEVEL`: 日志级别 (默认: `INFO`) 227 | 228 | ### 配置文件 (`config/settings.json`) 229 | 230 | ```json 231 | { 232 | "openai": { 233 | "api_key": "your-openai-api-key-here", 234 | "base_url": "https://api.openai.com/v1" 235 | }, 236 | "server": { 237 | "host": "0.0.0.0", 238 | "port": 8000 239 | }, 240 | "api_key": "your-proxy-api-key-here", 241 | "logging": { 242 | "level": "INFO" 243 | }, 244 | "models": { 245 | "default": "Qwen/Qwen3-Coder", 246 | "small": "deepseek-ai/DeepSeek-V3-0324", 247 | "think": "deepseek-ai/DeepSeek-R1-0528", 248 | "long_context": "gemini-2.5-pro", 249 | "web_search": "gemini-2.5-flash" 250 | }, 251 | "parameter_overrides": { 252 | "max_tokens": null, 253 | "temperature": null, 254 | "top_p": null, 255 | "top_k": null 256 | } 257 | } 258 | ``` 259 | 260 | #### 配置项说明 261 | 262 | - **openai**: OpenAI API 配置 263 | - `api_key`: OpenAI API 密钥,用于访问 OpenAI 服务 264 | - `base_url`: OpenAI API 基础 URL,默认为 `https://api.openai.com/v1` 265 | 266 | - **server**: 服务器配置 267 | - `host`: 服务监听主机地址,默认为 `0.0.0.0`(监听所有网络接口) 268 | - `port`: 服务监听端口,默认为 `8000` 269 | 270 | - **api_key**: 代理服务的 API 密钥,用于验证访问 `/v1/messages` 端点的请求 271 | 272 | - **logging**: 日志配置 273 | - `level`: 日志级别,可选值为 `DEBUG`、`INFO`、`WARNING`、`ERROR`、`CRITICAL`,默认为 `INFO` 274 | 275 | - **models**: 模型配置,定义不同使用场景下的模型选择 276 | - `default`: 默认通用模型,用于一般请求 277
| - `small`: 轻量级模型,用于简单任务 278 | - `think`: 深度思考模型,用于复杂推理任务 279 | - `long_context`: 长上下文处理模型,用于处理长文本 280 | - `web_search`: 网络搜索模型,用于网络搜索,目前支持 Gemini 模型 281 | 282 | - **parameter_overrides**: 参数覆盖配置,允许管理员在配置文件中设置模型参数的覆盖值 283 | - `max_tokens`: 最大 token 数覆盖,设置后会覆盖客户端请求中的 max_tokens 参数 284 | - `temperature`: 温度参数覆盖,控制输出的随机程度,范围为 0.0-2.0 285 | - `top_p`: top_p 采样参数覆盖,控制候选词汇的概率阈值,范围为 0.0-1.0 286 | - `top_k`: top_k 采样参数覆盖,控制候选词汇的数量,范围为 >=0 287 | 288 | ## 🧪 测试 289 | 290 | ```bash 291 | # 运行所有测试 292 | pytest 293 | 294 | # 运行单个测试文件 295 | pytest tests/integration/test_streaming.py 296 | 297 | # 运行集成测试 298 | pytest tests/integration 299 | 300 | # 生成覆盖率报告 301 | pytest --cov=src --cov-report=html 302 | ``` 303 | 304 | ## 📊 API 端点 305 | 306 | - `POST /v1/messages` - Anthropic 消息 API 307 | - `GET /health` - 健康检查端点 308 | - `GET /` - 欢迎页面 309 | 310 | ## 🛡️ 安全性 311 | 312 | - API 密钥验证 313 | - 请求频率限制(计划中) 314 | - 输入验证和清理 315 | - 结构化日志记录 316 | 317 | ## 📈 性能监控 318 | 319 | - 请求/响应时间监控 320 | - 内存使用情况跟踪 321 | - 错误率统计 322 | 323 | ## 🤝 贡献 324 | 325 | 欢迎提交 Issue 和 Pull Request!
326 | 327 | ## 📄 许可证 328 | 329 | 本项目采用 MIT 许可证 - 查看 [LICENSE](LICENSE) 文件了解详情。 330 | 331 | ## 🙏 致谢 332 | 333 | - [claude-code-router](https://github.com/musistudio/claude-code-router) - 很好的项目,本项目很多地方参考了这个项目 334 | - [FastAPI](https://fastapi.tiangolo.com/) - 现代高性能 Web 框架 335 | - [Anthropic](https://www.anthropic.com/) - Claude AI 模型 336 | - [OpenAI](https://openai.com/) - OpenAI API 规范 -------------------------------------------------------------------------------- /config/example.json: -------------------------------------------------------------------------------- 1 | { 2 | "openai": { 3 | "api_key": "your-openai-api-key-here", 4 | "base_url": "https://api.openai.com/v1" 5 | }, 6 | "server": { 7 | "host": "0.0.0.0", 8 | "port": 8000 9 | }, 10 | "api_key": "your-proxy-api-key-here", 11 | "logging": { 12 | "level": "INFO" 13 | }, 14 | "models": { 15 | "default": "Qwen/Qwen3-Coder", 16 | "small": "deepseek-ai/DeepSeek-V3-0324", 17 | "think": "deepseek-ai/DeepSeek-R1-0528", 18 | "long_context": "gemini-2.5-pro", 19 | "web_search": "gemini-2.5-flash" 20 | }, 21 | "parameter_overrides": { 22 | "max_tokens": null, 23 | "temperature": null, 24 | "top_p": null, 25 | "top_k": null 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | proxy: 5 | container_name: openai-to-claude # 指定容器名称 6 | build: . 
7 | ports: 8 | - "8000:8000" # 将主机的8000端口映射到容器的8000端口 9 | volumes: 10 | - ./config:/app/config # 挂载配置文件目录 11 | - ./logs:/app/logs # 挂载日志目录 12 | restart: unless-stopped # 容器在停止后会自动重启,除非手动停止 13 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Anthropic-OpenAI Proxy 启动脚本 4 | 5 | 使用 JSON 配置文件中的 host 和 port 启动服务器,而不是命令行参数。 6 | 配置优先级: 7 | 1. 命令行指定的 config 参数 8 | 2. 环境变量 CONFIG_PATH 指定的路径 9 | 3. ./config/settings.json (默认) 10 | 4. ./config/example.json (模板) 11 | """ 12 | 13 | import asyncio 14 | import uvicorn 15 | import os 16 | import sys 17 | import argparse 18 | from pathlib import Path 19 | 20 | # 添加 src 目录到 Python 路径 21 | sys.path.insert(0, str(Path(__file__).parent / "src")) 22 | 23 | from src.config.settings import Config 24 | 25 | 26 | def main(): 27 | """主启动函数""" 28 | try: 29 | parser = argparse.ArgumentParser(description="启动 Anthropic-OpenAI Proxy") 30 | parser.add_argument( 31 | "--config", type=str, help="JSON 配置文件路径 (默认为 config/settings.json)" 32 | ) 33 | parser.add_argument( 34 | "--config-path", 35 | type=str, 36 | default="config/settings.json", 37 | help="配置文件路径,可通过 CONFIG_PATH 环境变量指定", 38 | ) 39 | 40 | args = parser.parse_args() 41 | 42 | # 确保从项目根目录启动 43 | project_root = Path(__file__).parent 44 | os.chdir(project_root) 45 | 46 | # 确定配置文件路径 47 | config_path = args.config or os.getenv("CONFIG_PATH", args.config_path) 48 | 49 | # 同步加载配置 50 | config = Config.from_file_sync(config_path) 51 | 52 | # 获取服务器配置 53 | host, port = config.get_server_config() 54 | 55 | print(f"🚀 启动 OpenAI To Claude Server...") 56 | print(f" 配置文件: {config_path}") 57 | print(f" 监听地址: {host}:{port}") 58 | print() 59 | print("📋 重要端点:") 60 | print(f" 健康检查: http://{host}:{port}/health") 61 | print(f" API文档: http://{host}:{port}/docs") 62 | print(f" OpenAPI: http://{host}:{port}/openapi.json") 63 | print() 64 | 65 | # 启动 Uvicorn 
服务器 66 | uvicorn.run( 67 | "src.main:app", 68 | host=host, 69 | port=port, 70 | # reload=True, 71 | timeout_keep_alive=60, 72 | log_level=config.logging.level.lower(), 73 | ) 74 | except Exception as e: 75 | print(f"❌ 启动失败: {e}") 76 | sys.exit(1) 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "openai-to-claude" 3 | version = "0.1.0" 4 | description = "OpenAI to Anthropic API 兼容代理服务" 5 | readme = "README.md" 6 | requires-python = ">=3.11" 7 | authors = [ 8 | { name = "Dev Team", email = "dev@example.com" } 9 | ] 10 | keywords = ["anthropic", "openai", "api", "proxy", "fastapi"] 11 | classifiers = [ 12 | "Development Status :: 3 - Alpha", 13 | "Intended Audience :: Developers", 14 | "Programming Language :: Python :: 3", 15 | "Programming Language :: Python :: 3.11", 16 | "Programming Language :: Python :: 3.12", 17 | "Operating System :: OS Independent", 18 | "Topic :: Software Development :: Libraries :: Python Modules", 19 | "Topic :: Internet :: WWW/HTTP :: HTTP Servers", 20 | ] 21 | 22 | dependencies = [ 23 | "fastapi>=0.104.1,<0.105.0", 24 | "uvicorn[standard]>=0.24.0,<0.25.0", 25 | "httpx>=0.25.0,<0.26.0", 26 | "pydantic>=2.5.0,<3.0.0", 27 | "pydantic-settings>=2.1.0,<3.0.0", 28 | "python-multipart>=0.0.6", 29 | "uvloop>=0.19.0,<0.20.0", 30 | "orjson>=3.9.0,<4.0.0", 31 | "loguru>=0.7.0,<1.0.0", 32 | "tenacity>=8.2.0,<9.0.0", 33 | "aiofiles>=23.2.0,<24.0.0", 34 | "pytest>=7.4.4", 35 | "pytest-asyncio>=0.23.8", 36 | "tiktoken>=0.7.0,<1.0.0", 37 | "watchdog>=3.0.0,<4.0.0", 38 | ] 39 | 40 | [project.optional-dependencies] 41 | dev = [ 42 | "pytest>=7.4.0,<8.0.0", 43 | "pytest-asyncio>=0.21.0,<1.0.0", 44 | "pytest-cov>=4.1.0,<5.0.0", 45 | "httpx>=0.25.0,<0.26.0", # 用于测试 46 | "mypy>=1.7.0,<2.0.0", 47 | "ruff>=0.2.0,<1.0.0", 48 | "black>=23.11.0,<24.0.0", 
49 | ] 50 | 51 | [project.urls] 52 | Homepage = "https://github.com/example/openai-to-claude" 53 | Repository = "https://github.com/example/openai-to-claude" 54 | Issues = "https://github.com/example/openai-to-claude/issues" 55 | 56 | [tool.uv] 57 | package = true 58 | 59 | [tool.hatch.build.targets.wheel] 60 | packages = ["src"] 61 | 62 | [tool.hatch.build.targets.sdist] 63 | packages = ["src"] 64 | 65 | [tool.pytest.ini_options] 66 | testpaths = ["tests"] 67 | pythonpath = ["."] 68 | python_files = ["test_*.py", "*_test.py"] 69 | python_classes = ["Test*"] 70 | python_functions = ["test_*"] 71 | addopts = [ 72 | "--cov=src", 73 | "--cov-report=term-missing", 74 | "--cov-report=html", 75 | "--cov-report=xml", 76 | "-v", 77 | "--strict-markers", 78 | ] 79 | asyncio_mode = "auto" 80 | markers = [ 81 | "integration: mark test as integration test requiring mock server", 82 | "streaming: mark test as streaming response test", 83 | "timeout: mark test as timeout or error handling test", 84 | "slow: mark test as slow running", 85 | "e2e: mark test as end-to-end integration test" 86 | ] 87 | 88 | [tool.ruff] 89 | line-length = 88 90 | target-version = "py311" 91 | 92 | [tool.ruff.lint] 93 | select = [ 94 | "E", # pycodestyle errors 95 | "W", # pycodestyle warnings 96 | "F", # pyflakes 97 | "I", # isort 98 | "B", # flake8-bugbear 99 | "C4", # flake8-comprehensions 100 | "UP", # pyupgrade 101 | ] 102 | ignore = [ 103 | "E501", # line too long, handled by black 104 | "B008", # do not perform function calls in argument defaults 105 | "C901", # too complex 106 | ] 107 | 108 | [tool.ruff.lint.per-file-ignores] 109 | "__init__.py" = ["F401"] 110 | "tests/**/*" = ["B011", "B017"] 111 | 112 | [tool.black] 113 | line-length = 88 114 | target-version = ["py311"] 115 | include = '\.pyi?$' 116 | extend-exclude = ''' 117 | /( 118 | # directories 119 | \.eggs 120 | | \.git 121 | | \.hg 122 | | \.mypy_cache 123 | | \.tox 124 | | \.venv 125 | | build 126 | | dist 127 | )/ 128 | ''' 129 | 
130 | [tool.mypy] 131 | python_version = "3.11" 132 | warn_return_any = true 133 | warn_unused_configs = true 134 | disallow_untyped_defs = true 135 | disallow_incomplete_defs = true 136 | check_untyped_defs = true 137 | disallow_untyped_decorators = true 138 | no_implicit_optional = true 139 | warn_redundant_casts = true 140 | warn_unused_ignores = true 141 | warn_no_return = true 142 | warn_unreachable = true 143 | strict_equality = true 144 | 145 | [dependency-groups] 146 | dev = [ 147 | "requests>=2.32.4", 148 | ] 149 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | OpenAI-Claude Code Proxy 3 | 4 | 一个高性能的 Anthropic Claude API 到 OpenAI API 格式的代理服务。 5 | 提供完整的请求/响应转换、流式处理、错误处理和配置管理功能。 6 | 7 | 主要功能: 8 | - Anthropic Claude API 到 OpenAI API 格式的双向转换 9 | - 支持流式和非流式响应 10 | - 完整的工具调用支持 11 | - 请求ID追踪和日志记录 12 | - 配置文件热重载 13 | - 性能监控和错误处理 14 | 15 | 使用示例: 16 | from src import app 17 | 18 | # `app` is the FastAPI application object re-exported from src.main 19 | """ 20 | 21 | __version__ = "0.1.0" # keep in sync with [project].version in pyproject.toml 22 | __author__ = "OpenAI-Claude Code Proxy Team" 23 | __description__ = "Anthropic Claude API to OpenAI API format proxy service" 24 | 25 | # 导出主要的公共API 26 | from .common import configure_logging, request_logger 27 | from .config import get_config, reload_config 28 | from .main import app 29 | 30 | __all__ = [ 31 | "app", 32 | "get_config", 33 | "reload_config", 34 | "configure_logging", 35 | "request_logger", 36 | "__version__", 37 | "__author__", 38 | "__description__", 39 | ] 40 | -------------------------------------------------------------------------------- /src/api/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | API模块 3 | 4 | 提供FastAPI应用的路由、处理器和中间件。 5 | 6 | 主要功能: 7 | - API路由定义 8 | - 请求处理器 9 | - 中间件集成 10 | - 健康检查端点 
11 | 12 | 子模块: 13 | - handlers: API请求处理器 14 | - routes: 路由定义 15 | - middleware: 中间件实现 16 | """ 17 | 18 |
# 导入路由和处理器 19 | from .handlers import MessagesHandler, messages_endpoint 20 | from .handlers import router as handlers_router 21 | 22 | # 导入中间件 23 | from .middleware import ( 24 | APIKeyMiddleware, 25 | RequestTimingMiddleware, 26 | setup_middlewares, 27 | ) 28 | from .routes import health_check 29 | from .routes import router as routes_router 30 | 31 | __all__ = [ 32 | # 路由 33 | "routes_router", 34 | "handlers_router", 35 | "health_check", 36 | # 处理器 37 | "MessagesHandler", 38 | "messages_endpoint", 39 | # 中间件 40 | "APIKeyMiddleware", 41 | "RequestTimingMiddleware", 42 | "setup_middlewares", 43 | ] 44 | -------------------------------------------------------------------------------- /src/api/handlers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Anthropic /v1/messages 端点处理程序 3 | 4 | 实现Anthropic native messages API与OpenAI API的转换和代理 5 | """ 6 | 7 | import asyncio 8 | import json 9 | from collections.abc import AsyncGenerator 10 | 11 | from fastapi import APIRouter, BackgroundTasks, HTTPException, Request 12 | from fastapi.responses import JSONResponse, StreamingResponse 13 | from pydantic import ValidationError 14 | 15 | from src.core.clients.openai_client import OpenAIServiceClient 16 | from src.core.converters.request_converter import ( 17 | AnthropicToOpenAIConverter, 18 | ) 19 | from src.core.converters.response_converter import OpenAIToAnthropicConverter 20 | from src.models.anthropic import ( 21 | AnthropicMessageResponse, 22 | AnthropicRequest, 23 | ) 24 | from src.models.errors import get_error_response 25 | 26 | router = APIRouter(prefix="/v1", tags=["messages"]) 27 | 28 | 29 | class MessagesHandler: 30 | """处理Anthropic /v1/messages 端点请求""" 31 | 32 | def __init__(self, config): 33 | self.request_converter = AnthropicToOpenAIConverter() 34 | self.response_converter = OpenAIToAnthropicConverter() 35 | self.config = config 36 | self._config = config 37 | self.client =
OpenAIServiceClient( 38 | api_key=config.openai.api_key, 39 | base_url=config.openai.base_url, 40 | ) 41 | 42 | @classmethod 43 | async def create(cls, config=None): 44 | """异步工厂方法创建 MessagesHandler 实例""" 45 | if config is None: 46 | from src.config.settings import get_config 47 | 48 | config = await get_config() 49 | 50 | instance = cls.__new__(cls) 51 | instance.request_converter = AnthropicToOpenAIConverter() 52 | instance.response_converter = OpenAIToAnthropicConverter() 53 | instance.config = config 54 | instance._config = config 55 | instance.client = OpenAIServiceClient( 56 | api_key=config.openai.api_key, 57 | base_url=config.openai.base_url, 58 | ) 59 | return instance 60 | 61 | async def process_message( 62 | self, request: AnthropicRequest, request_id: str | None = None 63 | ) -> AnthropicMessageResponse: 64 | """处理非流式消息请求""" 65 | # 获取绑定了请求ID的logger 66 | from src.common.logging import get_logger_with_request_id 67 | 68 | bound_logger = get_logger_with_request_id(request_id) 69 | 70 | try: 71 | bound_logger.debug("处理非流式请求") 72 | # 验证请求 73 | # await validate_anthropic_request(request, request_id) 74 | # 将 Anthropic 请求转换为 OpenAI 格式(异步) 75 | openai_request = await self.request_converter.convert_anthropic_to_openai( 76 | request, request_id 77 | ) 78 | 79 | # 发送到 OpenAI 80 | openai_response = await self.client.send_request( 81 | openai_request, request_id=request_id 82 | ) 83 | bound_logger.debug( 84 | f"OpenAI 响应: {json.dumps(openai_response, ensure_ascii=False)}" 85 | ) 86 | 87 | # 将 OpenAI 响应转回 Anthropic 格式 88 | anthropic_response = await self.response_converter.convert_response( 89 | openai_response, request.model, request_id 90 | ) 91 | # 安全地提取响应文本 92 | response_text = "empty" 93 | if ( 94 | anthropic_response.content 95 | and len(anthropic_response.content) > 0 96 | and hasattr(anthropic_response.content[0], "text") 97 | and anthropic_response.content[0].text 98 | ): 99 | response_text = anthropic_response.content[0].text 100 | bound_logger.info( 101 |
f"Anthropic 响应生成完成 - Text: {response_text[:100]}..., Usage: {anthropic_response.usage}" 102 | ) 103 | 104 | return anthropic_response 105 | 106 | except ValidationError as e: 107 | bound_logger.warning(f"Validation error - Errors: {e.errors()}") 108 | error_response = get_error_response( 109 | 422, details={"validation_errors": e.errors(), "request_id": request_id} 110 | ) 111 | raise HTTPException(status_code=422, detail=error_response.model_dump()) 112 | 113 | except json.JSONDecodeError as e: 114 | # 专门处理JSON解析错误,这通常发生在OpenAI响应解析时 115 | bound_logger.exception( 116 | f"JSON解析错误 - Error: {str(e)}, Position: {e.pos if hasattr(e, 'pos') else 'unknown'}" 117 | ) 118 | error_response = get_error_response( 119 | 502, 120 | message="上游服务返回无效JSON格式", 121 | details={"json_error": str(e), "request_id": request_id}, 122 | ) 123 | raise HTTPException(status_code=502, detail=error_response.model_dump()) 124 | except HTTPException as e: 125 | bound_logger.exception( 126 | f"处理非流式消息请求错误 - Type: {type(e).__name__}, Error: {str(e)}" 127 | ) 128 | error_response = get_error_response( 129 | e.status_code, message=str(e.detail), details={"request_id": request_id} 130 | ) 131 | raise HTTPException( 132 | status_code=e.status_code, 133 | detail=error_response.model_dump(exclude_none=True), 134 | ) 135 | 136 | except Exception as e: 137 | bound_logger.exception( 138 | f"处理非流式消息请求错误 - Type: {type(e).__name__}, Error: {str(e)}" 139 | ) 140 | error_response = get_error_response( 141 | 500, message=str(e), details={"request_id": request_id} 142 | ) 143 | raise HTTPException( 144 | status_code=500, detail=error_response.model_dump(exclude_none=True) 145 | ) 146 | 147 | async def process_stream_message( 148 | self, request: AnthropicRequest, request_id: str | None = None 149 | ) -> AsyncGenerator[str, None]: 150 | """处理流式消息请求,使用新的流式转换器""" 151 | if not request.stream: 152 | raise ValueError("流式响应参数必须为true") 153 | 154 | # 获取绑定了请求ID的logger 155 | from src.common.logging import
get_logger_with_request_id 156 | 157 | bound_logger = get_logger_with_request_id(request_id) 158 | 159 | try: 160 | # await validate_anthropic_request(request, request_id) 161 | openai_request = await self.request_converter.convert_anthropic_to_openai( 162 | request, request_id 163 | ) 164 | 165 | # 创建 OpenAI 流式数据源 166 | async def openai_stream_generator(): 167 | bound_logger.info("开始OpenAI流式生成") 168 | chunk_count = 0 169 | async for chunk in self.client.send_streaming_request( 170 | openai_request, request_id=request_id 171 | ): 172 | # 跳过被解析器过滤掉的不完整chunk(通常是tool_calls片段) 173 | if chunk is not None: 174 | chunk_count += 1 175 | # 将 OpenAI 响应对象转换为字符串格式 176 | bound_logger.debug(f"OpenAI event: {chunk}") 177 | yield f"{chunk}\n\n" 178 | bound_logger.debug(f"OpenAI流式生成完成,总共{chunk_count}个chunk") 179 | 180 | # 使用新的流式转换器 181 | bound_logger.info("开始流式转换") 182 | async for ( 183 | anthropic_event 184 | ) in self.response_converter.convert_openai_stream_to_anthropic_stream( 185 | openai_stream_generator(), model=request.model, request_id=request_id 186 | ): 187 | bound_logger.debug(f"Anthropic event: {anthropic_event}") 188 | yield anthropic_event 189 | bound_logger.info("流式转换完成") 190 | 191 | except json.JSONDecodeError as e: 192 | # json.JSONDecodeError subclasses ValueError, so this clause must come before the (ValidationError, ValueError) clause or it is unreachable 193 | bound_logger.exception( 194 | f"流式模式JSON解析错误 - Error: {str(e)}, Position: {e.pos if hasattr(e, 'pos') else 'unknown'}" 195 | ) 196 | error_response = get_error_response( 197 | 502, 198 | message="流式响应中发现无效JSON格式", 199 | details={"json_error": str(e), "request_id": request_id}, 200 | ) 201 | error_data = error_response.model_dump() 202 | if request_id: 203 | error_data["request_id"] = request_id 204 | yield f"event: error\ndata: {json.dumps(error_data, ensure_ascii=False)}\n\n" 205 | 206 | except (ValidationError, ValueError) as e: 207 | error_detail = e.errors() if hasattr(e, "errors") else str(e) 208 |
bound_logger.warning(f"流式请求验证失败 - Errors: {error_detail}") 209 | error_response = get_error_response(422, message=str(error_detail))
210 | # 在错误响应中添加请求ID 211 | error_data = error_response.model_dump() 212 | if request_id: 213 | error_data["request_id"] = request_id 214 | yield f"event: error\ndata: {json.dumps(error_data, ensure_ascii=False)}\n\n" 215 | 216 | except Exception as e: 217 | bound_logger.exception( 218 | f"流式请求处理错误 - Type: {type(e).__name__}, Error: {str(e)}" 219 | ) 220 | error_response = get_error_response(500, message=str(e)) 221 | # 在错误响应中添加请求ID 222 | error_data = error_response.model_dump() 223 | if request_id: 224 | error_data["request_id"] = request_id 225 | yield f"event: error\ndata: {json.dumps(error_data, ensure_ascii=False)}\n\n" 226 | 227 | 228 | @router.post("/messages") 229 | async def messages_endpoint(request: Request, background_tasks: BackgroundTasks): 230 | """ 231 | Anthropic /v1/messages 端点 232 | 233 | 这个端点实现了Anthropic原生messages API的主要功能: 234 | - 接受Anthropic格式的请求 235 | - 转换为OpenAI格式发送到后端 236 | - 返回Anthropic格式的响应 237 | """ 238 | # 从应用状态获取消息处理器(已由main.py在启动时初始化) 239 | handler: MessagesHandler = request.app.state.messages_handler 240 | 241 | # 获取请求ID(由中间件生成,如果启用的话) 242 | from src.common.logging import ( 243 | get_logger_with_request_id, 244 | get_request_id_from_request, 245 | ) 246 | 247 | request_id = get_request_id_from_request(request) 248 | bound_logger = get_logger_with_request_id(request_id) 249 | 250 | # 记录请求 251 | client_ip = request.client.host if request.client else "unknown" 252 | bound_logger.info( 253 | f"收到Anthropic请求 - Method: {request.method}, URL: {str(request.url)}, IP: {client_ip}" 254 | ) 255 | 256 | try: 257 | # 解析请求体 258 | body = await request.json() 259 | # 记录请求 260 | log_body = body.copy() 261 | log_body["tools"] = [] 262 | bound_logger.debug( 263 | f"Anthropic请求体 - Model: {body.get('model', 'unknown')}, Messages: {len(body.get('messages', []))}, Stream: {body.get('stream', False)}\n{json.dumps(log_body, ensure_ascii=False, indent=2)}" 264 | ) 265 | 266 | anthropic_request = AnthropicRequest(**body) 267 | 268 | #
记录清理后的请求信息(移除敏感信息) 269 | # safe_body = sanitize_for_logging(body) 270 | # logger.debug("请求已清理", request_body=safe_body) 271 | 272 | # 根据请求类型处理响应 273 | if anthropic_request.stream: 274 | # 流式响应 - 优化配置确保真正的流式效果 275 | async def stream_wrapper(): 276 | """包装器确保流式响应的立即传输""" 277 | try: 278 | async for chunk in handler.process_stream_message( 279 | anthropic_request, request_id=request_id 280 | ): 281 | # 立即传输每个chunk,不缓冲 282 | # chunk已经是完整的SSE格式字符串,编码后返回 283 | yield chunk.encode("utf-8") 284 | # 强制刷新缓冲区(在某些环境中有效) 285 | await asyncio.sleep(0) 286 | except Exception as e: 287 | # 如果流式处理出错,记录完整错误并发送错误事件 288 | bound_logger.exception(f"流式处理出错 - Error: {str(e)}") 289 | error_data = {"error": str(e)} 290 | if request_id: 291 | error_data["request_id"] = request_id 292 | error_event = f"event: error\ndata: {json.dumps(error_data)}\n\n" 293 | yield error_event.encode("utf-8") 294 | 295 | return StreamingResponse( 296 | stream_wrapper(), 297 | media_type="text/event-stream; charset=utf-8", 298 | headers={ 299 | "Cache-Control": "no-cache, no-store, must-revalidate", 300 | "Pragma": "no-cache", 301 | "Expires": "0", 302 | "Connection": "keep-alive", 303 | "X-Accel-Buffering": "no", # 禁用nginx缓冲 304 | "X-Content-Type-Options": "nosniff", 305 | "Transfer-Encoding": "chunked", 306 | "Access-Control-Allow-Origin": "*", # CORS支持 307 | "Access-Control-Allow-Headers": "*", 308 | "Access-Control-Allow-Methods": "*", 309 | "X-Proxy-Buffering": "no", # 禁用代理缓冲 310 | "Buffering": "no", # 禁用缓冲 311 | }, 312 | ) 313 | else: 314 | # 非流式响应 315 | response = await handler.process_message( 316 | anthropic_request, request_id=request_id 317 | ) 318 | json_response = JSONResponse(content=response.model_dump(exclude_none=True)) 319 | if request_id: 320 | json_response.headers["X-Request-ID"] = request_id 321 | return json_response 322 | 323 | except ValidationError as e: 324 | bound_logger.warning(f"请求验证失败 - Errors: {e.errors()}") 325 | error_response = get_error_response( 326 | 422,
details={"validation_errors": e.errors()} 327 | ) 328 | error_detail = error_response.model_dump() 329 | error_detail["request_id"] = request_id 330 | raise HTTPException(status_code=422, detail=error_detail) 331 | 332 | except json.JSONDecodeError as e: 333 | bound_logger.warning(f"请求中的JSON格式错误 - Error: {str(e)}") 334 | error_response = get_error_response(400, message="无效的JSON格式") 335 | error_detail = error_response.model_dump() 336 | error_detail["request_id"] = request_id 337 | raise HTTPException(status_code=400, detail=error_detail) 338 | 339 | except Exception as e: 340 | # 检查是否为HTTPException,避免重复记录已处理的错误 341 | if isinstance(e, HTTPException): 342 | # HTTPException已经在内层处理过,直接重新抛出 343 | raise e 344 | 345 | bound_logger.exception( 346 | f"在messages端点发生意外错误 - Type: {type(e).__name__}, Error: {str(e)}" 347 | ) 348 | error_response = get_error_response(500, message=str(e)) 349 | error_detail = error_response.model_dump() 350 | error_detail["request_id"] = request_id 351 | raise HTTPException(status_code=500, detail=error_detail) 352 | -------------------------------------------------------------------------------- /src/api/middleware/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 中间件模块 3 | 4 | 提供FastAPI应用的各种中间件实现。 5 | 6 | 主要功能: 7 | - API密钥认证中间件 8 | - 请求计时中间件 9 | - 请求ID追踪 10 | - 中间件配置和设置 11 | 12 | 使用示例: 13 | from src.api.middleware import setup_middlewares 14 | 15 | setup_middlewares(app) 16 | """ 17 | 18 | # 导入中间件实现 19 | from .auth import APIKeyMiddleware 20 | from .timing import RequestTimingMiddleware, setup_middlewares 21 | 22 | __all__ = [ 23 | "APIKeyMiddleware", 24 | "RequestTimingMiddleware", 25 | "setup_middlewares", 26 | ] 27 | -------------------------------------------------------------------------------- /src/api/middleware/auth.py: -------------------------------------------------------------------------------- 1 | from fastapi import Request 2 | from starlette.middleware.base import
BaseHTTPMiddleware 3 | 4 | from src.models.errors import get_error_response 5 | 6 | 7 | class APIKeyMiddleware(BaseHTTPMiddleware): 8 | """中间件:验证API密钥""" 9 | 10 | def __init__(self, app, api_key: str): 11 | super().__init__(app) 12 | self.api_key = api_key 13 | 14 | async def dispatch(self, request: Request, call_next): 15 | # 检查是否为需要密钥验证的路径 16 | if not self._requires_auth(request.url.path): 17 | # 跳过认证,直接处理请求 18 | response = await call_next(request) 19 | return response 20 | 21 | # 从请求头中获取API密钥 22 | token = request.headers.get("x-api-key") 23 | 24 | import secrets 25 | 26 | # Compare in constant time so the key cannot be probed through response timing. 27 | # compare_digest raises on None and on non-ASCII str, hence the None guard and the bytes conversion. 28 | if token is None or not secrets.compare_digest( 29 | token.encode("utf-8"), self.api_key.encode("utf-8") 30 | ): 31 | error_response = get_error_response(401, message="API密钥无效") 32 | 33 | # 直接返回401响应,而不是抛出异常 34 | from fastapi.responses import JSONResponse 35 | 36 | return JSONResponse(status_code=401, content=error_response.model_dump()) 37 | 38 | response = await call_next(request) 39 | return response 40 | 41 | def _requires_auth(self, path: str) -> bool: 42 | """检查路径是否需要API密钥验证""" 43 | # 只有 /v1/messages 需要密钥验证 44 | auth_required_paths = ["/v1/messages"] 45 | return path in auth_required_paths 46 | -------------------------------------------------------------------------------- /src/api/middleware/timing.py: -------------------------------------------------------------------------------- 1 | """请求ID中间件""" 2 | 3 | import time 4 | from collections.abc import Callable 5 | 6 | from fastapi import FastAPI, Request 7 | from starlette.middleware.base import BaseHTTPMiddleware 8 | from starlette.responses import Response 9 | from starlette.types import ASGIApp 10 | 11 | # RequestIDMiddleware 已移除 - 如需request_id功能可重新添加 12 | 13 | 14 | class 
RequestTimingMiddleware(BaseHTTPMiddleware): 15 | """记录请求处理时间的中间件""" 16 | 17 | def __init__(self, app: ASGIApp) -> None: 18 | super().__init__(app) 19 | 20 | async def dispatch(self, request: Request, call_next: Callable) -> Response: 21 | start_time = time.time() 22 | 23 | # 延迟导入 24 | from src.common.logging import ( 25 | generate_request_id, 26 | get_logger_with_request_id, 27 |
get_request_id_header_name, 28 | ) 29 | 30 | # 生成请求ID并添加到请求状态中(默认启用) 31 | request_id = await generate_request_id() 32 | request.state.request_id = request_id 33 | 34 | # 获取绑定了请求ID的logger 35 | bound_logger = get_logger_with_request_id(request_id) 36 | 37 | try: 38 | response = await call_next(request) 39 | 40 | response_time = time.time() - start_time 41 | response_time_ms = round(response_time * 1000, 2) 42 | 43 | # 使用绑定了请求ID的logger记录响应 44 | bound_logger.info( 45 | f"请求完成 - Status: {response.status_code}, Time: {response_time_ms}ms" 46 | ) 47 | 48 | response.headers["X-Process-Time"] = f"{response_time:.3f}s" 49 | header_name = await get_request_id_header_name() 50 | response.headers[header_name] = request_id 51 | 52 | return response 53 | 54 | except Exception as exc: 55 | response_time = time.time() - start_time 56 | error_content = ( 57 | f'{{"error":"Internal Server Error","request_id":"{request_id}"}}' 58 | ) 59 | 60 | response = Response( 61 | content=error_content, 62 | status_code=500, 63 | media_type="application/json", 64 | ) 65 | header_name = await get_request_id_header_name() 66 | response.headers[header_name] = request_id 67 | 68 | # 使用绑定了请求ID的logger记录错误 69 | # 安全构造错误日志 - 避免格式字符串问题 70 | safe_url = str(request.url) if hasattr(request, "url") else "unknown" 71 | safe_method = request.method if hasattr(request, "method") else "unknown" 72 | 73 | # loguru has no stdlib-style `exc_info` kwarg (extra kwargs only feed message formatting and the record's `extra`), so the traceback is attached via opt(exception=exc) 74 | bound_logger.opt(exception=exc).error( 75 | "请求处理错误", 76 | error_type=type(exc).__name__, 77 | error_message=str(exc), 78 | url=safe_url, 79 | method=safe_method, 80 | ) 81 | 82 | return response 83 | 84 | 85 | def setup_middlewares(app: FastAPI) -> None: 86 | """设置所有中间件""" 87 | # 只保留请求计时中间件 88 | app.add_middleware(RequestTimingMiddleware) 89 | -------------------------------------------------------------------------------- /src/api/routes.py: 
-------------------------------------------------------------------------------- 1 | from datetime import datetime, timezone 2 | from typing import Any 3 | 4 | from
fastapi import APIRouter, Depends 5 | 6 | from src.config.settings import Config 7 | from src.core.clients.openai_client import OpenAIServiceClient 8 | 9 | router = APIRouter() 10 | 11 | 12 | async def get_openai_client() -> OpenAIServiceClient: 13 | """Build an OpenAIServiceClient from Config.from_file() on every request. NOTE(review): nothing here closes the client, and a Config.from_file() failure surfaces as a 500 rather than an "unhealthy" report - confirm whether a shared client / get_config() should be used instead.""" 14 | config = await Config.from_file() 15 | return OpenAIServiceClient( 16 | api_key=config.openai.api_key, 17 | base_url=config.openai.base_url, 18 | ) 19 | 20 | 21 | @router.get("/health", tags=["health"]) 22 | async def health_check( 23 | client: OpenAIServiceClient = Depends(get_openai_client), 24 | ) -> dict[str, Any]: 25 | """健康检查端点 - 验证OpenAI连通性""" 26 | 27 | health_status = { 28 | "status": "healthy", 29 | "service": "openai-to-claude", 30 | "timestamp": datetime.now(timezone.utc).isoformat(), 31 | "checks": {}, 32 | } 33 | 34 | try: 35 | # 检查OpenAI服务可用性 36 | openai_health = await client.health_check() 37 | health_status["checks"]["openai"] = openai_health 38 | 39 | # 如果任何一个检查失败,状态设为降级 40 | if not all(openai_health.values()): 41 | health_status["status"] = "degraded" 42 | 43 | except Exception as e: 44 | # client.health_check() raised; client construction happens in the dependency, outside this try 45 | health_status["status"] = "unhealthy" 46 | health_status["checks"]["openai"] = { 47 | "openai_service": False, 48 | "api_accessible": False, 49 | "error": str(e), 50 | } 51 | 52 | return health_status 53 | -------------------------------------------------------------------------------- /src/common/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 通用工具模块 3 | 4 | 提供项目中共享的工具和实用功能。 5 | 6 | 主要功能: 7 | - 日志配置和管理 8 | - 请求ID生成和追踪 9 | - Token计数功能 10 | - 异常处理工具 11 | 12 | 使用示例: 13 | from src.common import 
configure_logging, get_logger_with_request_id 14 | 15 | configure_logging(config.logging) 16 | logger = get_logger_with_request_id(request_id) 17 | """ 18 | 19 | # 导入日志相关功能 20 | from .logging import ( 21 | RequestLogger, 22 | configure_logging, 23 | generate_request_id, 24 | get_logger_with_request_id, 25 | get_request_id_from_request, 26 | get_request_id_header_name, 27 |
log_exception, 28 | request_logger, 29 | should_enable_request_id, 30 | ) 31 | 32 | # 导入Token计数功能 33 | from .token_counter import TokenCounter, token_counter 34 | 35 | __all__ = [ 36 | # 日志功能 37 | "configure_logging", 38 | "RequestLogger", 39 | "request_logger", 40 | "generate_request_id", 41 | "should_enable_request_id", 42 | "get_request_id_header_name", 43 | "get_request_id_from_request", 44 | "log_exception", 45 | "get_logger_with_request_id", 46 | # Token计数 47 | "TokenCounter", 48 | "token_counter", 49 | ] 50 | -------------------------------------------------------------------------------- /src/common/logging.py: -------------------------------------------------------------------------------- 1 | """Loguru日志配置""" 2 | 3 | import sys 4 | import uuid 5 | from pathlib import Path 6 | 7 | from loguru import logger 8 | 9 | 10 | def configure_logging(log_config) -> None: 11 | """配置Loguru日志系统 12 | 13 | Args: 14 | log_config: 日志配置对象 15 | """ 16 | # 移除默认的handler 17 | logger.remove() 18 | 19 | # 使用相对路径而不是绝对路径 20 | log_path = Path("logs/app.log") 21 | 22 | # 确保日志目录存在,并设置正确的权限 23 | log_path.parent.mkdir(parents=True, exist_ok=True) 24 | # 设置目录权限为755,确保当前用户可写 25 | log_path.parent.chmod(0o755) 26 | 27 | # 如果日志文件已存在,确保其可写 28 | if log_path.exists(): 29 | log_path.chmod(0o644) 30 | 31 | # 控制台日志格式(包含请求ID) 32 | console_format = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {extra[request_id]} | {name}:{line} - {message}" 33 | 34 | # 配置控制台日志 35 | logger.add( 36 | sys.stdout, 37 | format=console_format, 38 | level=log_config.level, 39 | colorize=True, 40 | filter=lambda record: record["extra"].setdefault("request_id", "---"), 41 | ) 42 | 43 | # 配置文件日志(包含截取的异常堆栈) 44 | logger.add( 45 | str(log_path), 46 | level=log_config.level, 47 | rotation="10 MB", 48 | retention="1 day", 49 | encoding="utf-8", 50 | enqueue=True, # 异步写入 51 | filter=lambda record: record["extra"].setdefault("request_id", "---"), 52 | ) 53 | 54 | # 配置全局异常处理 55 | def exception_handler(exc_type, exc_value,
exc_traceback): 56 | """全局异常处理器""" 57 | if issubclass(exc_type, KeyboardInterrupt): 58 | # 允许KeyboardInterrupt正常退出 59 | sys.__excepthook__(exc_type, exc_value, exc_traceback) 60 | return 61 | 62 | logger.opt(exception=(exc_type, exc_value, exc_traceback)).critical( 63 | "未捕获的异常" 64 | ) 65 | 66 | # 设置全局异常处理器 67 | sys.excepthook = exception_handler 68 | 69 | 70 | class RequestLogger: 71 | """请求日志处理器""" 72 | 73 | async def log_response( 74 | self, status_code: int, response_time: float, request_id: str | None = None 75 | ): 76 | """记录响应结束""" 77 | bound_logger = get_logger_with_request_id(request_id) 78 | 79 | response_time_ms = round(response_time * 1000, 2) 80 | bound_logger.info( 81 | f"请求完成 - Status: {status_code}, Time: {response_time_ms}ms" 82 | ) 83 | 84 | async def log_error( 85 | self, error: Exception, context: dict | None = None, request_id: str | None = None 86 | ): 87 | """记录错误情况""" 88 | bound_logger = get_logger_with_request_id(request_id) 89 | 90 | error_type = type(error).__name__ 91 | error_message = str(error) 92 | context_str = f", Context: {context}" if context else "" 93 | 94 | # Attach the traceback of `error` itself: .exception() would read sys.exc_info(), which is empty when this is called outside an except block 95 | bound_logger.opt(exception=error).error( 96 | f"请求处理错误 - Type: {error_type}, Message: {error_message}{context_str}" 97 | ) 98 | 99 | 100 | # 全局logger实例 101 | request_logger = RequestLogger() 102 | 103 | 104 | async def generate_request_id() -> str: 105 | """生成唯一的请求ID 106 | 107 | Returns: 108 | str: request ID of the form "req_" followed by 32 lowercase hex digits (uuid4().hex, no dashes) 109 | """ 110 | return f"req_{uuid.uuid4().hex}" 111 | 112 | 113 | async def should_enable_request_id() -> bool: 114 | """检查是否应该启用请求ID(始终启用) 115 | 116 | Returns: 117 | bool: 始终返回True,请求ID功能默认启用 118 | """ 119 | return True 120 | 121 | 122 | async def get_request_id_header_name() -> str: 123 | """获取请求ID响应头名称 124 | 125 | Returns: 126 | str: 固定返回 "X-Request-ID" 127 | """ 128 | return "X-Request-ID" 129 | 130 | 131 | def 
get_request_id_from_request(request) -> str | None: 132 | """从请求对象中安全地获取请求ID 133 | 134 | Args: 135 | request:
FastAPI Request对象 136 | 137 | Returns: 138 | str | None: 请求ID,如果不存在则返回None 139 | """ 140 | try: 141 | return getattr(request.state, "request_id", None) 142 | except AttributeError: 143 | return None 144 | 145 | 146 | async def log_exception(message: str = "发生异常", **kwargs): 147 | """记录异常的便捷函数 148 | 149 | 使用示例: 150 | try: 151 | # 一些可能出错的代码 152 | pass 153 | except Exception as e: 154 | await log_exception("处理请求时发生错误", request_id="123", user_id="456") 155 | 156 | Args: 157 | message: 异常描述信息 158 | **kwargs: 额外的上下文信息 159 | """ 160 | kwargs_str = ", ".join([f"{k}: {v}" for k, v in kwargs.items()]) if kwargs else "" 161 | full_message = f"{message} - {kwargs_str}" if kwargs_str else message 162 | logger.exception(full_message) 163 | 164 | 165 | def get_logger_with_request_id(request_id: str | None = None): 166 | """获取绑定了请求ID的日志器实例 167 | 168 | Args: 169 | request_id: 请求ID,如果为None则使用默认值 170 | 171 | Returns: 172 | 绑定了请求ID的logger实例 173 | """ 174 | if request_id: 175 | return logger.bind(request_id=request_id) 176 | else: 177 | return logger.bind(request_id="---") 178 | -------------------------------------------------------------------------------- /src/common/token_cache.py: -------------------------------------------------------------------------------- 1 | """ 2 | 简单的请求token缓存模块 3 | 4 | 基于KISS原则,使用全局字典实现请求ID与token数量的临时缓存。 5 | 主要用于在OpenAI响应缺失usage信息时提供fallback。 6 | """ 7 | 8 | from typing import Dict, Optional 9 | 10 | # 全局缓存字典 - 遵循KISS原则 11 | _cache: Dict[str, int] = {} 12 | 13 | 14 | def cache_tokens(request_id: str, tokens: int) -> None: 15 | """ 16 | 缓存请求的token数量 17 | 18 | Args: 19 | request_id: 请求ID 20 | tokens: token数量 21 | """ 22 | if request_id and tokens > 0: 23 | _cache[request_id] = tokens 24 | 25 | 26 | def get_cached_tokens(request_id: str, delete: bool = False) -> Optional[int]: 27 | """ 28 | Return the cached token count for request_id, optionally removing it. 29 | 30 | Args: 31 | request_id: request ID; a falsy value always yields None 32 | delete: if True, pop the entry while reading it; otherwise leave it cached 33 | Returns: 34 | 
缓存的token数量,如果不存在则返回None 35 | 36 | Note: 37 | Entries are removed only when delete=True (or by clear_cache()); plain reads keep them, so the final read for a request should pass delete=True to keep the dict from growing. 38 | """ 39 | if not request_id: 40 | return
None 41 | if delete: 42 | return _cache.pop(request_id, None) 43 | else: 44 | return _cache.get(request_id, None) 45 | 46 | 47 | def get_cache_size() -> int: 48 | """ 49 | 获取当前缓存大小,用于调试和监控 50 | 51 | Returns: 52 | 缓存中的条目数量 53 | """ 54 | return len(_cache) 55 | 56 | 57 | def clear_cache() -> None: 58 | """ 59 | 清空所有缓存,用于测试或重置 60 | """ 61 | _cache.clear() 62 | -------------------------------------------------------------------------------- /src/common/token_counter.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any 3 | 4 | import tiktoken 5 | 6 | 7 | class TokenCounter: 8 | """Token计数器,基于Node.js实现完整功能复现""" 9 | 10 | def __init__(self): 11 | self.encoder = tiktoken.get_encoding("o200k_base") 12 | 13 | def _extract_text_content(self, obj, field_name: str) -> str: 14 | """统一提取文本内容的方法,简化代码重复""" 15 | value = getattr(obj, field_name, None) 16 | if value is None and isinstance(obj, dict): 17 | value = obj.get(field_name) 18 | # A field explicitly set to None (e.g. an optional tool description) must not be counted as the text "None" 19 | return "" if value is None else str(value) 20 | 21 | def _process_content_part(self, content_part) -> list[str]: 22 | """处理消息内容部分,返回文本列表""" 23 | texts = [] 24 | part_type = ( 25 | content_part.type 26 | if hasattr(content_part, "type") 27 | else content_part.get("type") if isinstance(content_part, dict) else None 28 | ) 29 | 30 | if part_type == "text": 31 | text = self._extract_text_content(content_part, "text") 32 | if text: 33 | texts.append(text) 34 | elif part_type == "tool_use": 35 | input_data = ( 36 | getattr(content_part, "input", {}) 37 | if hasattr(content_part, "input") 38 | else content_part.get("input", {}) if isinstance(content_part, dict) else {} 39 | ) 40 | if input_data: 41 | texts.append(json.dumps(input_data, ensure_ascii=False)) 42 | 43 | return texts 44 | 45 | async def count_tokens( 46 | self, 47 | messages: list[Any] | None = None, 48 | system: Any = None, 49 | tools: list[Any] | None = None, 50 | ) -> int: 51 | 
"""计算完整请求的token总数 52 | 53 | Args: 54 | messages: 消息列表 55 | system: 系统提示 56 | tools: 工具定义 57 | 58 |
Returns: 59 | int: 总计token数量 60 | """ 61 | # 收集所有文本内容到单个列表,遵循KISS原则 62 | text_parts = [] 63 | 64 | # 处理消息内容 65 | if messages: 66 | for message in messages: 67 | # 获取消息内容 68 | content = ( 69 | message.content 70 | if hasattr(message, "content") 71 | else message.get("content", "") if isinstance(message, dict) else "" 72 | ) 73 | 74 | if isinstance(content, str): 75 | text_parts.append(content) 76 | elif isinstance(content, list): 77 | for content_part in content: 78 | text_parts.extend(self._process_content_part(content_part)) 79 | 80 | # 处理系统提示 81 | if system: 82 | if isinstance(system, str): 83 | text_parts.append(system) 84 | elif isinstance(system, list): 85 | for item in system: 86 | item_type = ( 87 | item.type 88 | if hasattr(item, "type") 89 | else item.get("type") if isinstance(item, dict) else None 90 | ) 91 | if item_type == "text": 92 | text_content = self._extract_text_content(item, "text") 93 | if text_content: 94 | text_parts.append(text_content) 95 | 96 | # 处理工具定义 97 | if tools: 98 | for tool in tools: 99 | # 统一获取name和description 100 | name = self._extract_text_content(tool, "name") 101 | description = self._extract_text_content(tool, "description") 102 | 103 | if name: 104 | text_parts.append(name) 105 | if description: 106 | text_parts.append(description) 107 | 108 | # 处理schema 109 | schema = ( 110 | getattr(tool, "input_schema", None) 111 | if hasattr(tool, "input_schema") 112 | else tool.get("input_schema") if isinstance(tool, dict) else None 113 | ) 114 | if schema: 115 | text_parts.append(json.dumps(schema, ensure_ascii=False)) 116 | 117 | # 一次性拼接所有文本并计算token(KISS原则) 118 | combined_text = "".join(text_parts) 119 | return len(self.encoder.encode(combined_text)) 120 | 121 | def count_response_tokens(self, content_blocks: list) -> int: 122 | """计算响应内容的token数量 123 | 124 | Args: 125 | content_blocks: Anthropic格式的内容块列表 126 | 127 | Returns: 128 | int: 响应内容的token数量 129 | """ 130 | # 收集所有文本内容到单个列表,遵循KISS原则 131 | text_parts = [] 132 | 133 | # 处理内容块 134 |
if content_blocks: 135 | for block in content_blocks: 136 | # 处理文本内容 137 | if hasattr(block, "text") and block.text: 138 | text_parts.append(str(block.text)) 139 | elif isinstance(block, dict) and block.get("text"): 140 | text_parts.append(str(block["text"])) 141 | 142 | # 处理思考内容 143 | if hasattr(block, "thinking") and block.thinking: 144 | text_parts.append(str(block.thinking)) 145 | elif isinstance(block, dict) and block.get("thinking"): 146 | text_parts.append(str(block["thinking"])) 147 | 148 | # 处理工具调用内容 149 | if hasattr(block, "input") and block.input: 150 | text_parts.append(json.dumps(block.input, ensure_ascii=False)) 151 | elif isinstance(block, dict) and block.get("input"): 152 | text_parts.append(json.dumps(block["input"], ensure_ascii=False)) 153 | 154 | # 处理工具名称 155 | if hasattr(block, "name") and block.name: 156 | text_parts.append(str(block.name)) 157 | elif isinstance(block, dict) and block.get("name"): 158 | text_parts.append(str(block["name"])) 159 | 160 | # 一次性拼接所有文本并计算token(KISS原则) 161 | combined_text = "".join(text_parts) 162 | return len(self.encoder.encode(combined_text)) 163 | 164 | 165 | # 全局实例 166 | token_counter = TokenCounter() 167 | -------------------------------------------------------------------------------- /src/config/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 配置管理模块 3 | 4 | 提供应用程序配置的加载、验证和热重载功能。 5 | 6 | 主要功能: 7 | - 配置文件加载和验证 8 | - 配置热重载监听 9 | - 配置模型定义 10 | - 配置实例管理 11 | 12 | 使用示例: 13 | from src.config import get_config, reload_config 14 | 15 | config = await get_config() 16 | print(config.server.host) 17 | """ 18 | 19 | # 导入配置相关功能 20 | from .settings import ( 21 | Config, 22 | LoggingConfig, 23 | ModelConfig, 24 | OpenAIConfig, 25 | ParameterOverridesConfig, 26 | ServerConfig, 27 | get_config, 28 | get_config_file_path, 29 | reload_config, 30 | ) 31 | from .watcher import ConfigFileHandler, ConfigWatcher 32 | 33 | __all__ = [ 34 | # 配置管理函数 35 | "get_config", 36 |
"reload_config", 37 | "get_config_file_path", 38 | # 配置模型 39 | "Config", 40 | "OpenAIConfig", 41 | "ServerConfig", 42 | "LoggingConfig", 43 | "ModelConfig", 44 | "ParameterOverridesConfig", 45 | # 配置监听器 46 | "ConfigWatcher", 47 | "ConfigFileHandler", 48 | ] 49 | -------------------------------------------------------------------------------- /src/config/settings.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | 5 | import aiofiles 6 | from loguru import logger 7 | from pydantic import BaseModel, Field, field_validator 8 | 9 | # 全局配置缓存 10 | _config_instance = None 11 | 12 | 13 | async def get_config() -> "Config": 14 | """ 15 | 获取全局配置对象(带缓存的单例模式) 16 | """ 17 | global _config_instance 18 | if _config_instance is None: 19 | try: 20 | _config_instance = await Config.from_file() 21 | except Exception: 22 | # 如果配置文件读取失败,创建默认配置 23 | _config_instance = Config( 24 | openai={ 25 | "api_key": "your-openai-api-key-here", 26 | "base_url": "https://api.openai.com/v1", 27 | } 28 | ) 29 | return _config_instance 30 | 31 | 32 | async def reload_config(config_path: str | None = None) -> "Config": 33 | """重新加载全局配置对象 34 | 35 | Args: 36 | config_path: 配置文件路径,如果为None则使用默认路径 37 | 38 | Returns: 39 | Config: 重新加载的配置实例 40 | 41 | Raises: 42 | Exception: 配置加载失败时保持原配置不变 43 | """ 44 | global _config_instance 45 | 46 | try: 47 | # 尝试加载新配置 48 | new_config = await Config.from_file(config_path) 49 | _config_instance = new_config 50 | logger.info(f"配置重载成功: {new_config.model_dump_json()}") 51 | return _config_instance 52 | except Exception as e: 53 | logger.error(f"配置重载失败,保持原配置: {e}") 54 | if _config_instance is None: 55 | # 如果没有原配置,则创建默认配置 56 | _config_instance = Config( 57 | openai={ 58 | "api_key": "your-openai-api-key-here", 59 | "base_url": "https://api.openai.com/v1", 60 | } 61 | ) 62 | return _config_instance 63 | 64 | 65 | def get_config_file_path() -> str: 66 | """获取当前使用的配置文件路径 67 | 68 | 
Returns: 69 | str: 配置文件路径 70 | """ 71 | import os 72 | 73 | return os.getenv("CONFIG_PATH", "config/settings.json") 74 | 75 | 76 | class OpenAIConfig(BaseModel): 77 | """OpenAI API 配置""" 78 | 79 | api_key: str = Field(..., description="OpenAI API密钥") 80 | base_url: str = Field("https://api.openai.com/v1", description="OpenAI API基础URL") 81 | 82 | 83 | class ServerConfig(BaseModel): 84 | """服务器配置""" 85 | 86 | host: str = Field("0.0.0.0", description="服务监听主机") 87 | port: int = Field(8000, gt=0, lt=65536, description="服务监听端口") 88 | 89 | 90 | class LoggingConfig(BaseModel): 91 | """日志配置""" 92 | 93 | level: str = Field( 94 | "INFO", description="日志级别 (DEBUG, INFO, WARNING, ERROR, CRITICAL)" 95 | ) 96 | 97 | def __init__(self, **data): 98 | """初始化时支持环境变量覆盖""" 99 | # 环境变量覆盖 100 | if "LOG_LEVEL" in os.environ: 101 | data["level"] = os.environ["LOG_LEVEL"] 102 | 103 | super().__init__(**data) 104 | 105 | @field_validator("level") 106 | @classmethod 107 | def validate_log_level(cls, v: str) -> str: 108 | """验证日志级别""" 109 | valid_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} 110 | if v.upper() not in valid_levels: 111 | raise ValueError(f"日志级别必须是以下之一: {', '.join(valid_levels)}") 112 | return v.upper() 113 | 114 | 115 | class ModelConfig(BaseModel): 116 | """模型配置类 117 | 118 | 定义不同使用场景下的模型选择 119 | """ 120 | 121 | default: str = Field( 122 | description="默认通用模型", default="claude-3-5-sonnet-20241022" 123 | ) 124 | small: str = Field( 125 | description="轻量级模型,用于简单任务", default="claude-3-5-haiku-20241022" 126 | ) 127 | tool: str = Field( 128 | description="工具使用专用模型", default="claude-3-5-sonnet-20241022" 129 | ) 130 | think: str = Field( 131 | description="深度思考模型,用于复杂推理任务", 132 | default="claude-3-7-sonnet-20250219", 133 | ) 134 | long_context: str = Field( 135 | description="长上下文处理模型", default="claude-3-7-sonnet-20250219" 136 | ) 137 | web_search: str = Field(description="网络搜索模型", default="gemini-2.5-flash") 138 | 139 | 140 | class ParameterOverridesConfig(BaseModel): 
141 | """参数覆盖配置类 142 | 143 | 允许管理员在配置文件中设置模型参数的覆盖值。 144 | 当设置了这些参数时,会覆盖客户端请求中的相应参数。 145 | """ 146 | 147 | max_tokens: int | None = Field( 148 | None, 149 | gt=0, 150 | description="最大token数覆盖,设置后会覆盖客户端请求中的max_tokens参数", 151 | ) 152 | temperature: float | None = Field( 153 | None, ge=0.0, le=2.0, description="温度参数覆盖,控制输出的随机程度" 154 | ) 155 | top_p: float | None = Field( 156 | None, ge=0.0, le=1.0, description="top_p采样参数覆盖,控制候选词汇的概率阈值" 157 | ) 158 | top_k: int | None = Field( 159 | None, ge=0, description="top_k采样参数覆盖,控制候选词汇的数量" 160 | ) 161 | 162 | 163 | class Config(BaseModel): 164 | """应用配置根类 165 | 166 | 使用 JSON 配置文件加载配置。 167 | 配置文件优先级: 168 | 1. 命令行指定的配置路径 169 | 2. 环境变量 CONFIG_PATH 指定的路径 170 | 3. ./config/settings.json (默认) 171 | 4. ./config/example.json (示例配置) 172 | 5. 默认值 173 | """ 174 | 175 | # 各模块配置 176 | openai: OpenAIConfig 177 | server: ServerConfig = ServerConfig() 178 | api_key: str = Field(..., description="/v1/messages接口的API密钥") 179 | logging: LoggingConfig = LoggingConfig() 180 | models: ModelConfig = ModelConfig() 181 | parameter_overrides: ParameterOverridesConfig = ParameterOverridesConfig() 182 | 183 | @classmethod 184 | async def from_file(cls, config_path: str | None = None) -> "Config": 185 | """ 186 | 从 JSON 配置文件加载配置 187 | Args: 188 | config_path: JSON配置文件路径,如果为None则使用默认路径 189 | 190 | Returns: 191 | Config: 配置实例 192 | 193 | Raises: 194 | FileNotFoundError: 配置文件不存在 195 | json.JSONDecodeError: JSON格式错误 196 | ValidationError: 配置数据验证错误 197 | """ 198 | import os 199 | 200 | if config_path is None: 201 | # 优先使用环境变量指定的路径 202 | config_path = os.getenv("CONFIG_PATH", "config/settings.json") 203 | 204 | config_file = Path(config_path) 205 | 206 | if config_file.exists(): 207 | try: 208 | async with aiofiles.open(config_file, encoding="utf-8") as f: 209 | config_data = await f.read() 210 | config_data = json.loads(config_data) 211 | except json.JSONDecodeError as e: 212 | print(f"❌ 配置文件格式错误: {e}") 213 | raise 214 | else: 215 | print(f"⚠️ 配置文件 
{config_file.absolute()} 不存在") 216 | print("📦 使用 config/example.json 作为模板") 217 | 218 | # 尝试使用 example 配置 219 | example_file = Path("config/example.json") 220 | if example_file.exists(): 221 | try: 222 | async with aiofiles.open(example_file, encoding="utf-8") as f: 223 | config_data = await f.read() 224 | config_data = json.loads(config_data) 225 | # 创建 settings.json 作为实际配置文件 226 | async with aiofiles.open(config_file, "w", encoding="utf-8") as f: 227 | await f.write( 228 | json.dumps(config_data, indent=2, ensure_ascii=False) 229 | ) 230 | print(f"✅ 已从模板创建 {config_file}") 231 | 232 | except (json.JSONDecodeError, OSError) as e: 233 | print(f"❌ 无法创建配置文件: {e}") 234 | config_data = {} 235 | else: 236 | config_data = {} 237 | 238 | # 验证必填的 openai 配置 239 | if "openai" not in config_data: 240 | config_data["openai"] = { 241 | "api_key": "your-openai-api-key-here", 242 | "base_url": "https://api.openai.com/v1", 243 | } 244 | 245 | # 确保api_key存在(这是一个必填项) 246 | if "api_key" not in config_data: 247 | config_data["api_key"] = "your-proxy-api-key-here" 248 | 249 | return cls(**config_data) 250 | 251 | @classmethod 252 | def from_file_sync(cls, config_path: str | None = None) -> "Config": 253 | """ 254 | 从 JSON 配置文件加载配置 255 | Args: 256 | config_path: JSON配置文件路径,如果为None则使用默认路径 257 | 258 | Returns: 259 | Config: 配置实例 260 | 261 | Raises: 262 | FileNotFoundError: 配置文件不存在 263 | json.JSONDecodeError: JSON格式错误 264 | ValidationError: 配置数据验证错误 265 | """ 266 | import os 267 | 268 | if config_path is None: 269 | # 优先使用环境变量指定的路径 270 | config_path = os.getenv("CONFIG_PATH", "config/settings.json") 271 | 272 | config_file = Path(config_path) 273 | 274 | if config_file.exists(): 275 | try: 276 | with open(config_file, encoding="utf-8") as f: 277 | config_data = json.load(f) 278 | except json.JSONDecodeError as e: 279 | print(f"❌ 配置文件格式错误: {e}") 280 | raise 281 | else: 282 | print(f"⚠️ 配置文件 {config_file.absolute()} 不存在") 283 | print("📦 使用 config/example.json 作为模板") 284 | 285 | # 尝试使用 example 
配置 286 | example_file = Path("config/example.json") 287 | if example_file.exists(): 288 | try: 289 | with open(example_file, encoding="utf-8") as f: 290 | config_data = json.load(f) 291 | # 创建 settings.json 作为实际配置文件 292 | with open(config_file, "w", encoding="utf-8") as f: 293 | f.write(json.dumps(config_data, indent=2, ensure_ascii=False)) 294 | print(f"✅ 已从模板创建 {config_file}") 295 | except (json.JSONDecodeError, OSError) as e: 296 | print(f"❌ 无法创建配置文件: {e}") 297 | config_data = {} 298 | else: 299 | config_data = {} 300 | 301 | # 验证必填的 openai 配置 302 | if "openai" not in config_data: 303 | config_data["openai"] = { 304 | "api_key": "your-openai-api-key-here", 305 | "base_url": "https://api.openai.com/v1", 306 | } 307 | 308 | # 确保api_key存在(这是一个必填项) 309 | if "api_key" not in config_data: 310 | config_data["api_key"] = "your-proxy-api-key-here" 311 | 312 | return cls(**config_data) 313 | 314 | def get_server_config(self) -> tuple[str, int]: 315 | """获取服务器配置 (host, port) 316 | 317 | Returns: 318 | tuple[str, int]: (host, port) 319 | """ 320 | return self.server.host, self.server.port 321 | -------------------------------------------------------------------------------- /src/config/watcher.py: -------------------------------------------------------------------------------- 1 | """配置文件监听和热重载模块 2 | 3 | 监听配置文件的变化,当配置文件被修改时自动重新加载配置。 4 | 使用 watchdog 库监听文件系统事件。 5 | """ 6 | 7 | import asyncio 8 | import json 9 | import os 10 | import threading 11 | from collections.abc import Callable 12 | from concurrent.futures import ThreadPoolExecutor 13 | from pathlib import Path 14 | from typing import Any 15 | 16 | from loguru import logger 17 | from watchdog.events import FileSystemEventHandler 18 | from watchdog.observers import Observer 19 | 20 | 21 | class ConfigFileHandler(FileSystemEventHandler): 22 | """配置文件变化事件处理器""" 23 | 24 | def __init__(self, config_path: Path, callback: Callable[[], None]): 25 | """ 26 | 初始化配置文件处理器 27 | 28 | Args: 29 | config_path: 要监听的配置文件路径 30 | callback: 
配置文件变化时的回调函数 31 | """ 32 | self.config_path = config_path.resolve() 33 | self.callback = callback 34 | self._last_modified = 0 35 | 36 | def on_modified(self, event) -> None: 37 | """处理文件修改事件""" 38 | if event.is_directory: 39 | return 40 | 41 | # 检查是否是我们监听的配置文件 42 | event_path = Path(event.src_path).resolve() 43 | if event_path != self.config_path: 44 | return 45 | 46 | # 防止重复触发 47 | try: 48 | current_modified = event_path.stat().st_mtime 49 | if current_modified == self._last_modified: 50 | return 51 | self._last_modified = current_modified 52 | except OSError: 53 | return 54 | 55 | logger.info(f"配置文件已修改: {self.config_path}") 56 | 57 | # 延迟一点执行,确保文件写入完成 58 | threading.Timer(0.1, self._execute_callback).start() 59 | 60 | def _execute_callback(self) -> None: 61 | """执行回调函数""" 62 | try: 63 | self.callback() 64 | except Exception as e: 65 | logger.error(f"配置重载回调执行失败: {e}") 66 | 67 | 68 | class ConfigWatcher: 69 | """配置文件监听器 70 | 71 | 监听指定的配置文件,当文件发生变化时触发重新加载。 72 | """ 73 | 74 | def __init__(self, config_path: str | None = None): 75 | """ 76 | 初始化配置监听器 77 | 78 | Args: 79 | config_path: 配置文件路径,如果为None则使用默认路径 80 | """ 81 | if config_path is None: 82 | config_path = os.getenv("CONFIG_PATH", "config/settings.json") 83 | 84 | self.config_path = Path(config_path).resolve() 85 | self.observer: Observer | None = None 86 | self.handler: ConfigFileHandler | None = None 87 | self._reload_callbacks: list[Callable[[], None]] = [] 88 | self._async_reload_callbacks: list[Callable[[], Any]] = [] 89 | self._executor: ThreadPoolExecutor | None = None 90 | 91 | def add_reload_callback(self, callback: Callable[[], Any]) -> None: 92 | """ 93 | 添加异步配置重载回调函数 94 | 95 | Args: 96 | callback: 异步回调函数 97 | """ 98 | self._async_reload_callbacks.append(callback) 99 | 100 | async def start_watching(self) -> None: 101 | """开始监听配置文件变化""" 102 | if self.observer is not None: 103 | logger.warning("配置监听器已在运行") 104 | return 105 | 106 | if not self.config_path.exists(): 107 | 
logger.warning(f"配置文件不存在,跳过监listen: {self.config_path}") 108 | return 109 | 110 | # 创建线程池执行器用于异步回调 111 | if self._executor is None: 112 | self._executor = ThreadPoolExecutor( 113 | max_workers=1, thread_name_prefix="config-watcher" 114 | ) 115 | 116 | # 创建事件处理器 117 | self.handler = ConfigFileHandler(self.config_path, self._on_config_changed) 118 | 119 | # 创建观察者并开始监听 120 | self.observer = Observer() 121 | watch_dir = self.config_path.parent 122 | self.observer.schedule(self.handler, str(watch_dir), recursive=False) 123 | self.observer.start() 124 | 125 | logger.info(f"开始监听配置文件: {self.config_path}") 126 | 127 | def stop_watching(self) -> None: 128 | """停止监听配置文件变化""" 129 | if self.observer is None: 130 | return 131 | 132 | logger.info("停止配置文件监听") 133 | self.observer.stop() 134 | self.observer.join() 135 | self.observer = None 136 | self.handler = None 137 | 138 | # 关闭线程池 139 | if self._executor is not None: 140 | self._executor.shutdown(wait=True) 141 | self._executor = None 142 | 143 | def _on_config_changed(self) -> None: 144 | """配置文件变化时的处理逻辑""" 145 | logger.info("检测到配置文件变化,开始重新加载...") 146 | 147 | # 在线程池中执行异步验证和回调 148 | if self._executor is not None: 149 | self._executor.submit(self._handle_config_change) 150 | else: 151 | logger.error("线程池执行器未初始化,跳过配置重载") 152 | 153 | async def _process_config_change(self) -> None: 154 | """处理配置变化的异步逻辑""" 155 | # 验证配置文件格式 156 | if not await self._validate_config_file(): 157 | logger.error("配置文件格式无效,跳过重载") 158 | return 159 | 160 | # 执行同步回调 161 | for callback in self._reload_callbacks: 162 | try: 163 | callback() 164 | logger.debug(f"同步配置重载回调执行成功: {callback.__name__}") 165 | except Exception as e: 166 | logger.error(f"同步配置重载回调执行失败 {callback.__name__}: {e}") 167 | 168 | # 执行异步回调 169 | await self._execute_async_callbacks() 170 | 171 | logger.info("配置重载完成") 172 | 173 | def _handle_config_change(self) -> None: 174 | """在线程池中处理配置变化""" 175 | try: 176 | # 在新线程中创建事件循环 177 | loop = asyncio.new_event_loop() 178 | asyncio.set_event_loop(loop) 
179 | try: 180 | # 运行异步验证和回调 181 | loop.run_until_complete(self._process_config_change()) 182 | finally: 183 | loop.close() 184 | except Exception as e: 185 | logger.error(f"配置变化处理失败: {e}") 186 | 187 | async def _execute_async_callbacks(self) -> None: 188 | """执行异步回调函数""" 189 | for callback in self._async_reload_callbacks: 190 | try: 191 | if asyncio.iscoroutinefunction(callback): 192 | await callback() 193 | else: 194 | callback() 195 | except Exception as e: 196 | logger.error(f"异步配置重载回调执行失败 {callback.__name__}: {e}") 197 | 198 | async def _validate_config_file(self) -> bool: 199 | """验证配置文件格式是否正确""" 200 | try: 201 | import aiofiles 202 | 203 | async with aiofiles.open(self.config_path, encoding="utf-8") as f: 204 | content = await f.read() 205 | json.loads(content) 206 | return True 207 | except (json.JSONDecodeError, OSError) as e: 208 | logger.error(f"配置文件验证失败: {e}") 209 | return False 210 | 211 | def __enter__(self): 212 | """上下文管理器入口""" 213 | self.start_watching() 214 | return self 215 | 216 | def __exit__(self, exc_type, exc_val, exc_tb): 217 | """上下文管理器退出""" 218 | # 忽略异常信息,直接停止监听 219 | self.stop_watching() 220 | -------------------------------------------------------------------------------- /src/core/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 核心功能模块 3 | 4 | 提供代理服务的核心功能,包括: 5 | - OpenAI 客户端封装 6 | - 请求/响应格式转换器 7 | - 数据模型定义 8 | 9 | 子模块: 10 | - clients: OpenAI API 客户端实现 11 | - converters: Anthropic ↔ OpenAI 格式转换器 12 | """ 13 | 14 | # 导入核心客户端 15 | from .clients import OpenAIServiceClient 16 | 17 | # 导入转换器 18 | from .converters import ( 19 | AnthropicToOpenAIConverter, 20 | OpenAIToAnthropicConverter, 21 | ) 22 | 23 | __all__ = [ 24 | # 客户端 25 | "OpenAIServiceClient", 26 | # 转换器 27 | "AnthropicToOpenAIConverter", 28 | "OpenAIToAnthropicConverter", 29 | ] 30 | -------------------------------------------------------------------------------- /src/core/clients/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .openai_client import OpenAIServiceClient 2 | 3 | __all__ = ["OpenAIServiceClient"] 4 | -------------------------------------------------------------------------------- /src/core/clients/openai_client.py: -------------------------------------------------------------------------------- 1 | """OpenAI API client for making asynchronous requests to OpenAI endpoints.""" 2 | 3 | import json 4 | from collections.abc import AsyncGenerator 5 | from typing import Any 6 | 7 | import httpx 8 | from loguru import logger 9 | 10 | from src.models.errors import StandardErrorResponse, get_error_response 11 | from src.models.openai import OpenAIRequest, OpenAIStreamResponse 12 | 13 | 14 | class OpenAIClientError(Exception): 15 | """Base exception for OpenAI client errors.""" 16 | 17 | def __init__(self, error_response: StandardErrorResponse): 18 | self.error_response = error_response 19 | super().__init__(str(error_response)) 20 | 21 | 22 | class OpenAIServiceClient: 23 | """Async OpenAI API client with connection pooling and retry logic.""" 24 | 25 | def __init__( 26 | self, 27 | api_key: str, 28 | base_url: str = "https://api.openai.com/v1", 29 | timeout: float = 60.0, 30 | ): 31 | """Initialize OpenAI client with connection pool. 
32 | 33 | Args: 34 | api_key: OpenAI API密钥 35 | base_url: OpenAI API基础URL 36 | timeout: 请求超时时间(秒) 37 | """ 38 | self.api_key = api_key 39 | self.base_url = base_url.rstrip("/") 40 | self.timeout = timeout 41 | 42 | self.client = httpx.AsyncClient( 43 | headers={ 44 | "Authorization": f"Bearer {api_key}", 45 | "Content-Type": "application/json", 46 | "Connection": "keep-alive", 47 | }, 48 | # 确保自动解压缩响应 49 | follow_redirects=True, 50 | timeout=timeout, 51 | ) 52 | 53 | async def __aenter__(self): 54 | """Async context manager entry.""" 55 | return self 56 | 57 | async def __aexit__(self, exc_type, exc_val, exc_tb): 58 | """Async context manager exit.""" 59 | await self.aclose() 60 | 61 | async def aclose(self): 62 | """Close the HTTP client.""" 63 | await self.client.aclose() 64 | 65 | async def send_request( 66 | self, 67 | request: OpenAIRequest, 68 | endpoint: str = "/chat/completions", 69 | request_id: str = None, 70 | ) -> dict[str, Any]: 71 | """Send synchronous request to OpenAI API. 
72 | 73 | Args: 74 | request: OpenAI request object 75 | endpoint: API endpoint path 76 | request_id: 请求ID用于日志追踪 77 | 78 | Returns: 79 | OpenAI API响应 80 | 81 | Raises: 82 | OpenAIClientError: 当API返回错误时 83 | """ 84 | # 获取绑定了请求ID的logger 85 | from src.common.logging import get_logger_with_request_id 86 | 87 | bound_logger = get_logger_with_request_id(request_id) 88 | 89 | url = f"{self.base_url}{endpoint}" 90 | request_data = request.model_dump(exclude_none=True) 91 | 92 | # 记录请求详情 93 | bound_logger.info( 94 | f"发送OpenAI请求 - URL: {url}, Model: {request_data.get('model', 'unknown')}, Messages: {len(request_data.get('messages', []))}" 95 | ) 96 | 97 | try: 98 | response = await self.client.post( 99 | url, 100 | json=request_data, 101 | ) 102 | response.raise_for_status() 103 | # 记录响应状态 104 | content_type = response.headers.get("content-type", "unknown") 105 | bound_logger.info( 106 | f"收到OpenAI响应 - Status: {response.status_code}, Content-Type: {content_type}, Size: {len(response.content)} bytes" 107 | ) 108 | 109 | # 使用 response.text 让 httpx 自动处理编码和解压缩 110 | try: 111 | text = response.text 112 | result = json.loads(text) 113 | 114 | # 记录响应内容(如果启用详细日志) 115 | bound_logger.debug( 116 | f"OpenAI响应内容 - ID: {result.get('id', 'unknown')}, Model: {result.get('model', 'unknown')}, Usage: {result.get('usage', {})}" 117 | ) 118 | 119 | except json.JSONDecodeError as e: 120 | # 记录详细的JSON解析错误信息 121 | response_preview = ( 122 | response.text[:500] if response.text else "Empty response" 123 | ) 124 | content_type = response.headers.get("content-type", "unknown") 125 | bound_logger.exception( 126 | f"OpenAI JSON解析失败 - Status: {response.status_code}, Content-Type: {content_type}, " 127 | f"Error: {str(e)}, Response Preview: {response_preview}" 128 | ) 129 | # 抛出包含更多上下文信息的异常 130 | raise json.JSONDecodeError( 131 | f"Failed to parse OpenAI response (Status: {response.status_code}): {str(e)}", 132 | response.text, 133 | e.pos, 134 | ) 135 | return result 136 | except httpx.HTTPStatusError 
as e: 137 | # 安全读取响应内容(非流式模式) 138 | response_body = "" 139 | try: 140 | response_body = e.response.text 141 | except httpx.ResponseNotRead: 142 | # 如果响应未被读取,直接获取错误信息 143 | response_body = str(e) 144 | 145 | bound_logger.error( 146 | f"OpenAI API返回错误 - Endpoint: {endpoint}, Status: {e.response.status_code}, Response: {response_body[:200]}" 147 | ) 148 | 149 | raise OpenAIClientError( 150 | get_error_response( 151 | status_code=e.response.status_code, 152 | message=response_body, 153 | details={"type": "http_error"}, 154 | ) 155 | ) 156 | 157 | except httpx.TimeoutException as e: 158 | bound_logger.error( 159 | f"OpenAI API request timeout - Endpoint: {endpoint}, Timeout: {self.timeout}s" 160 | ) 161 | raise OpenAIClientError( 162 | get_error_response( 163 | status_code=504, 164 | message=str(e), 165 | details={"type": "timeout_error", "original_error": str(e)}, 166 | ) 167 | ) 168 | 169 | except httpx.ConnectError as e: 170 | bound_logger.error( 171 | f"OpenAI API connection error - Endpoint: {endpoint}, Error: {str(e)}" 172 | ) 173 | raise OpenAIClientError( 174 | get_error_response( 175 | status_code=502, 176 | message=str(e), 177 | details={"type": "connection_error", "original_error": str(e)}, 178 | ) 179 | ) 180 | 181 | async def send_streaming_request( 182 | self, 183 | request: OpenAIRequest, 184 | endpoint: str = "/chat/completions", 185 | request_id: str = None, 186 | ) -> AsyncGenerator[str, None]: 187 | """Send streaming request to OpenAI API. 
188 | 189 | Args: 190 | request: OpenAI request object 191 | endpoint: API endpoint path 192 | request_id: 请求ID用于日志追踪 193 | 194 | Yields: 195 | 原始的Server-Sent Events数据行 196 | 197 | Raises: 198 | OpenAIClientError: 当API返回错误时 199 | """ 200 | # 获取绑定了请求ID的logger 201 | from src.common.logging import get_logger_with_request_id 202 | 203 | bound_logger = get_logger_with_request_id(request_id) 204 | 205 | url = f"{self.base_url}{endpoint}" 206 | 207 | # Ensure streaming is enabled 208 | request_dict = request.model_dump(exclude_none=True) 209 | request_dict["stream"] = True 210 | 211 | # 记录流式请求详情 212 | bound_logger.info( 213 | f"发送OpenAI流式请求 - URL: {url}, Model: {request_dict.get('model', 'unknown')}, Messages: {len(request_dict.get('messages', []))}, Stream: True" 214 | ) 215 | 216 | try: 217 | async with self.client.stream( 218 | "POST", 219 | url, 220 | json=request_dict, 221 | ) as response: 222 | response.raise_for_status() 223 | 224 | # 记录流式响应开始 225 | content_type = response.headers.get("content-type", "unknown") 226 | bound_logger.info( 227 | f"开始接收OpenAI流式响应 - Status: {response.status_code}, Content-Type: {content_type}" 228 | ) 229 | 230 | buffer = "" 231 | 232 | async for chunk_bytes in response.aiter_bytes(chunk_size=1024): 233 | chunk_text = chunk_bytes.decode("utf-8", errors="ignore") 234 | buffer += chunk_text 235 | 236 | # 处理完整的行 237 | while "\n" in buffer: 238 | line, buffer = buffer.split("\n", 1) 239 | line = line.strip() 240 | 241 | # 直接转发非空行 242 | if line: 243 | # logger.debug(f"Forwarding line: {line}") 244 | yield line 245 | # 检查是否结束 246 | if line == "data: [DONE]": 247 | return 248 | 249 | # 处理最后可能剩余的数据 250 | if buffer.strip(): 251 | yield buffer.strip() 252 | 253 | except httpx.HTTPStatusError as e: 254 | error_body = "" 255 | try: 256 | # 尝试读取完整的错误响应体 257 | error_body = await e.response.aread() 258 | error_body = error_body.decode("utf-8", errors="ignore") 259 | except Exception as read_error: 260 | error_body = f"无法读取错误响应: {str(read_error)}" 261 | 
262 | # 记录完整错误信息,但在日志中截断过长内容 263 | error_summary = ( 264 | error_body[:500] + "..." if len(error_body) > 500 else error_body 265 | ) 266 | bound_logger.error( 267 | f"OpenAI API 错误 - Status: {e.response.status_code}, URL: {url}" 268 | ) 269 | bound_logger.error(f"Error Response: {error_summary}") 270 | raise OpenAIClientError( 271 | get_error_response( 272 | status_code=e.response.status_code, 273 | message=f"HTTP {e.response.status_code} error", 274 | details={"type": "http_error"}, 275 | ) 276 | ) 277 | 278 | except httpx.TimeoutException as e: 279 | bound_logger.error(f"OpenAI API 超时 - Error: {str(e)}") 280 | raise OpenAIClientError( 281 | get_error_response( 282 | status_code=504, 283 | message="Request timeout", 284 | details={"type": "timeout_error"}, 285 | ) 286 | ) 287 | 288 | except httpx.ConnectError as e: 289 | bound_logger.error(f"OpenAI API 连接错误 - Error: {str(e)}") 290 | raise OpenAIClientError( 291 | get_error_response( 292 | status_code=502, 293 | message="Connection error", 294 | details={"type": "connection_error"}, 295 | ) 296 | ) 297 | 298 | async def _parse_streaming_chunk( 299 | self, chunk_data: str, tool_calls_state: dict 300 | ) -> OpenAIStreamResponse | None: 301 | """解析流式响应chunk,优雅处理不完整的JSON数据。 302 | 303 | Args: 304 | chunk_data: JSON字符串的响应块 305 | tool_calls_state: 预留参数(未使用) 306 | 307 | Returns: 308 | 解析后的响应对象,如果数据不完整则返回None 309 | """ 310 | import json 311 | 312 | try: 313 | # 尝试解析JSON数据 314 | raw_data = json.loads(chunk_data) 315 | 316 | result = OpenAIStreamResponse.model_validate(raw_data) 317 | return result 318 | 319 | except json.JSONDecodeError as e: 320 | # JSON解析失败,通常是因为数据被分割,静默跳过 321 | logger.debug( 322 | f"Skipping incomplete JSON chunk - Error: {str(e)}, Data: {chunk_data[:100]}" 323 | ) 324 | return None 325 | except Exception as e: 326 | # Pydantic验证失败,可能是tool_calls的增量数据不完整 327 | logger.debug( 328 | f"Skipping chunk due to validation error - Error: {str(e)}, Data: {chunk_data[:100]}" 329 | ) 330 | return None 331 | 332 | 
async def health_check(self) -> dict[str, bool]: 333 | """Check OpenAI API availability. 334 | 335 | Returns: 336 | 健康检查结果 337 | """ 338 | try: 339 | url = f"{self.base_url}/models" 340 | response = await self.client.get(url) 341 | 342 | return { 343 | "openai_service": response.status_code == 200, 344 | "api_accessible": True, 345 | "last_check": True, 346 | } 347 | 348 | except Exception as e: 349 | logger.exception(f"OpenAI health check failed - Error: {str(e)}") 350 | return { 351 | "openai_service": False, 352 | "api_accessible": False, 353 | "last_check": True, 354 | } 355 | -------------------------------------------------------------------------------- /src/core/converters/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 转换器模块 3 | 4 | 提供Anthropic和OpenAI API格式之间的数据转换功能。 5 | """ 6 | 7 | from .request_converter import AnthropicToOpenAIConverter 8 | from .response_converter import OpenAIToAnthropicConverter 9 | 10 | __all__ = ["AnthropicToOpenAIConverter", "OpenAIToAnthropicConverter"] 11 | -------------------------------------------------------------------------------- /src/core/converters/response_converter.py: -------------------------------------------------------------------------------- 1 | """ 2 | OpenAI-to-Anthropic 响应转换器 3 | 4 | 实现将OpenAI格式的响应转换为Anthropic格式的功能 5 | """ 6 | 7 | import json 8 | import time 9 | import traceback 10 | from collections.abc import AsyncIterator 11 | from typing import Any 12 | 13 | from src.common.token_cache import get_cached_tokens 14 | 15 | from .stream_converters import ( 16 | StreamState, 17 | _log_stream_completion_details, 18 | format_event, 19 | process_finish_event, 20 | process_regular_content, 21 | process_thinking_content, 22 | process_tool_calls, 23 | safe_json_parse, 24 | ) 25 | 26 | 27 | from src.models.anthropic import ( 28 | AnthropicContentBlock, 29 | AnthropicContentTypes, 30 | AnthropicMessageResponse, 31 | AnthropicMessageTypes, 32 | AnthropicRoles, 
33 | AnthropicStreamEventTypes, 34 | AnthropicStreamMessage, 35 | AnthropicStreamMessageStartMessage, 36 | AnthropicUsage, 37 | ) 38 | from src.models.openai import ( 39 | OpenAIChoice, 40 | OpenAIMessage, 41 | ) 42 | 43 | 44 | class OpenAIToAnthropicConverter: 45 | """OpenAI响应到Anthropic格式的转换器""" 46 | 47 | @staticmethod 48 | async def convert_response( 49 | openai_response: dict[str, Any], 50 | original_model: str = None, 51 | request_id: str = None, 52 | ) -> AnthropicMessageResponse: 53 | """ 54 | 将OpenAI非流式响应转换为Anthropic格式 55 | 56 | Args: 57 | openai_response: OpenAI响应字典 58 | original_model: 原始请求的Anthropic模型 59 | request_id: 请求ID,用于获取缓存的token数量 60 | 61 | Returns: 62 | AnthropicMessageResponse: 转换后的Anthropic格式响应 63 | """ 64 | choices = openai_response.get("choices") 65 | if not choices: 66 | raise ValueError("OpenAI响应没有有效的choices") 67 | 68 | # 使用第一个choice作为主要响应 69 | first_choice_data = choices[0] 70 | message_data = first_choice_data.get("message", {}) 71 | choice = OpenAIChoice( 72 | message=( 73 | OpenAIMessage( 74 | role=message_data.get("role"), 75 | content=message_data.get("content", ""), 76 | tool_calls=message_data.get("tool_calls"), 77 | ) 78 | if message_data 79 | else None 80 | ), 81 | finish_reason=first_choice_data.get("finish_reason"), 82 | index=first_choice_data.get("index", 0), 83 | ) 84 | 85 | # 提取内容块 86 | content_blocks = ( 87 | OpenAIToAnthropicConverter._extract_content_blocks_with_reasoning( 88 | choice, first_choice_data 89 | ) 90 | ) 91 | 92 | # 转换使用统计 93 | usage_data = openai_response.get("usage", {}) 94 | usage = OpenAIToAnthropicConverter._convert_usage( 95 | usage_data, request_id, content_blocks 96 | ) 97 | 98 | # 确定模型ID 99 | model = original_model if original_model else openai_response.get("model") 100 | 101 | mapping = { 102 | "stop": "end_turn", 103 | "length": "max_tokens", 104 | "content_filter": "content_filter", 105 | "tool_calls": "tool_use", 106 | "function_call": "tool_use", 107 | } 108 | 109 | # 映射完成原因 110 | stop_reason = 
mapping.get(choice.finish_reason, "end_turn") 111 | 112 | return AnthropicMessageResponse( 113 | id=openai_response.get("id", ""), 114 | type=AnthropicMessageTypes.MESSAGE, 115 | role=AnthropicRoles.ASSISTANT, 116 | content=content_blocks, 117 | model=model, 118 | stop_reason=stop_reason, 119 | stop_sequence=None, 120 | usage=usage, 121 | ) 122 | 123 | @staticmethod 124 | def _extract_content_blocks_with_reasoning( 125 | choice, first_choice_data 126 | ) -> list[AnthropicContentBlock]: 127 | """ 128 | 从OpenAI choice中提取内容块,包括推理内容 129 | 130 | Args: 131 | choice: OpenAI选择对象 132 | first_choice_data: OpenAI choice的原始数据 133 | 134 | Returns: 135 | List[AnthropicContentBlock]: 内容块列表 136 | """ 137 | if not choice.message: 138 | return [] 139 | 140 | content_blocks = [] 141 | message_data = first_choice_data.get("message", {}) 142 | 143 | # 处理推理内容 - 作为独立的thinking类型内容块 144 | reasoning_content = message_data.get("reasoning_content") 145 | if ( 146 | reasoning_content 147 | and isinstance(reasoning_content, str) 148 | and reasoning_content.strip() 149 | ): 150 | content_blocks.append( 151 | AnthropicContentBlock( 152 | type=AnthropicContentTypes.THINKING, 153 | thinking=reasoning_content.strip(), 154 | signature=f"{int(time.time()*1000)}", 155 | ) 156 | ) 157 | 158 | # 处理普通内容 - 作为独立的text类型内容块 159 | content_str = message_data.get("content", "") 160 | if content_str and isinstance(content_str, str) and content_str.strip(): 161 | # 检查content中是否包含标签 162 | if "" in content_str and "" in content_str: 163 | # 分离思考内容和普通内容 164 | import re 165 | 166 | think_pattern = r"(.*?)" 167 | think_matches = re.findall(think_pattern, content_str, re.DOTALL) 168 | 169 | # 如果还没有添加thinking块且找到了思考内容,添加thinking块 170 | if think_matches and not any( 171 | block.type == AnthropicContentTypes.THINKING 172 | for block in content_blocks 173 | ): 174 | thinking_content = think_matches[0].strip() 175 | if thinking_content: 176 | content_blocks.append( 177 | AnthropicContentBlock( 178 | 
type=AnthropicContentTypes.THINKING, 179 | thinking=thinking_content, 180 | signature=f"{int(time.time()*1000)}", 181 | ) 182 | ) 183 | 184 | # 移除标签,保留普通内容 185 | clean_content = re.sub( 186 | think_pattern, "", content_str, flags=re.DOTALL 187 | ).strip() 188 | if clean_content: 189 | content_blocks.append( 190 | AnthropicContentBlock( 191 | type=AnthropicContentTypes.TEXT, text=clean_content 192 | ) 193 | ) 194 | else: 195 | # 没有思考标签,直接作为普通内容 196 | content_blocks.append( 197 | AnthropicContentBlock( 198 | type=AnthropicContentTypes.TEXT, text=content_str.strip() 199 | ) 200 | ) 201 | 202 | # 处理工具调用 203 | if choice.message.tool_calls: 204 | from src.models.openai import OpenAIToolCall 205 | 206 | for tool_call_data in choice.message.tool_calls: 207 | tool_call = OpenAIToolCall.model_validate(tool_call_data) 208 | if hasattr(tool_call, "function") and tool_call.function: 209 | content_blocks.append( 210 | AnthropicContentBlock( 211 | type=AnthropicContentTypes.TOOL_USE, 212 | id=tool_call.id, 213 | name=tool_call.function.name, 214 | input=( 215 | safe_json_parse(tool_call.function.arguments) 216 | if tool_call.function.arguments 217 | else {} 218 | ), 219 | ) 220 | ) 221 | 222 | # 如果没有任何内容,返回空的text块 223 | if not content_blocks: 224 | content_blocks = [ 225 | AnthropicContentBlock(type=AnthropicContentTypes.TEXT, text="") 226 | ] 227 | 228 | return content_blocks 229 | 230 | @staticmethod 231 | def _convert_usage( 232 | usage_data: dict[str, Any], request_id: str = None, content_blocks: list = None 233 | ) -> AnthropicUsage: 234 | """ 235 | 将OpenAI使用统计转换为Anthropic格式,支持缓存fallback 236 | 237 | Args: 238 | usage_data: OpenAI使用统计数据 239 | request_id: 请求ID,用于获取缓存的token数量 240 | content_blocks: 内容块列表,用于计算输出token数量 241 | 242 | Returns: 243 | AnthropicUsage: Anthropic格式的使用统计 244 | """ 245 | prompt_tokens = usage_data.get("prompt_tokens", 0) if usage_data else 0 246 | completion_tokens = usage_data.get("completion_tokens", 0) if usage_data else 0 247 | 248 | # 
如果OpenAI没有返回prompt_tokens,使用缓存的值 249 | if not prompt_tokens and request_id: 250 | from src.common.token_cache import get_cached_tokens 251 | 252 | cached_tokens = get_cached_tokens(request_id) 253 | if cached_tokens: 254 | prompt_tokens = cached_tokens 255 | 256 | # 如果OpenAI没有返回completion_tokens,使用我们的计算方法 257 | if not completion_tokens and content_blocks: 258 | from src.common.token_counter import token_counter 259 | 260 | # 使用同步版本,保持简单性(KISS原则) 261 | completion_tokens = token_counter.count_response_tokens(content_blocks) 262 | 263 | return AnthropicUsage( 264 | input_tokens=prompt_tokens, 265 | output_tokens=completion_tokens, 266 | ) 267 | 268 | @staticmethod 269 | async def convert_openai_stream_to_anthropic_stream( 270 | openai_stream: AsyncIterator[str], 271 | model: str = "unknown", 272 | request_id: str = None, 273 | ) -> AsyncIterator[str]: 274 | """将 OpenAI 流式响应转换为 Anthropic 流式响应格式 275 | 276 | Args: 277 | openai_stream: OpenAI 流式数据源 278 | model: 模型名称 279 | request_id: 请求ID用于日志追踪 280 | 281 | Yields: 282 | str: Anthropic 格式的流式事件字符串 283 | """ 284 | # 获取绑定了请求ID的logger 285 | from src.common.logging import get_logger_with_request_id 286 | 287 | bound_logger = get_logger_with_request_id(request_id) 288 | 289 | state = StreamState() 290 | 291 | try: 292 | async for chunk in openai_stream: 293 | if state.has_finished: 294 | break 295 | 296 | state.buffer += chunk 297 | lines = state.buffer.split("\n") 298 | state.buffer = lines.pop() if lines else "" 299 | 300 | for line in lines: 301 | if state.has_finished: 302 | break 303 | 304 | if not line.startswith("data: "): 305 | continue 306 | 307 | data = line[6:] 308 | if data == "[DONE]": 309 | continue 310 | try: 311 | chunk_data = json.loads(data) 312 | state.total_chunks += 1 313 | # 处理错误 314 | if "error" in chunk_data: 315 | error_event = { 316 | "type": "error", 317 | "message": { 318 | "type": "api_error", 319 | "message": json.dumps(chunk_data["error"]), 320 | }, 321 | } 322 | yield format_event("error", 
error_event) 323 | continue 324 | 325 | # 发送 message_start 事件 326 | if not state.has_started and not state.has_finished: 327 | # 获取input token缓存 328 | input_tokens = 0 329 | cached_tokens = get_cached_tokens(request_id) 330 | if cached_tokens: 331 | input_tokens = cached_tokens 332 | state.has_started = True 333 | message_start = AnthropicStreamMessage( 334 | message=AnthropicStreamMessageStartMessage( 335 | id=state.message_id, 336 | model=model, 337 | usage=AnthropicUsage(input_tokens=input_tokens), 338 | ), 339 | ) 340 | yield format_event( 341 | AnthropicStreamEventTypes.MESSAGE_START, 342 | message_start.model_dump(exclude=["delta", "usage"]), 343 | ) 344 | 345 | choices = chunk_data.get("choices", []) 346 | if not choices: 347 | continue 348 | 349 | choice = choices[0] 350 | delta = choice.get("delta", None) 351 | if delta is None: 352 | continue 353 | 354 | content = delta.get("content", None) 355 | reasoning_content = delta.get("reasoning_content", None) 356 | tool_calls = delta.get("tool_calls", None) 357 | 358 | # 检查是否有任何内容需要处理 359 | has_content = content is not None and content != "" 360 | has_reasoning = ( 361 | reasoning_content is not None and reasoning_content != "" 362 | ) 363 | has_tool_calls = tool_calls is not None and len(tool_calls) > 0 364 | 365 | if not has_content and not has_reasoning and not has_tool_calls: 366 | if choice.get("finish_reason") is None: 367 | continue 368 | 369 | # 处理思考内容 370 | events = process_thinking_content(delta, state) 371 | if events: 372 | for event in events: 373 | yield event 374 | 375 | # 处理普通文本内容 376 | events = process_regular_content(delta, state) 377 | if events: 378 | for event in events: 379 | yield event 380 | continue 381 | 382 | # 处理工具调用 383 | if has_tool_calls: 384 | events = process_tool_calls(delta, state) 385 | if events: 386 | for event in events: 387 | yield event 388 | continue 389 | 390 | # 处理完成事件 391 | finish_reason = choice.get("finish_reason") 392 | if finish_reason: 393 | finish_events = 
process_finish_event( 394 | chunk_data, state, request_id 395 | ) 396 | for event in finish_events: 397 | yield event 398 | 399 | # 在所有事件生成完成后记录详细日志(遵循KISS原则) 400 | _log_stream_completion_details( 401 | state, 402 | request_id, 403 | model, 404 | ) 405 | 406 | except json.JSONDecodeError as parse_error: 407 | bound_logger.error( 408 | f"Parse error - Error: {str(parse_error.args[0])}, Data: {data[:100]}", 409 | exc_info=True, 410 | ) 411 | except Exception as e: 412 | bound_logger.error( 413 | f"Unexpected error processing chunk - Error: {str(e)}", 414 | exc_info=True, 415 | ) 416 | traceback.print_exc() 417 | 418 | except Exception as error: 419 | bound_logger.error( 420 | f"Stream conversion error - Error: {str(error)}", exc_info=True 421 | ) 422 | error_event = { 423 | "type": "error", 424 | "message": {"type": "api_error", "message": str(error)}, 425 | } 426 | yield format_event("error", error_event) 427 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | from contextlib import asynccontextmanager 2 | 3 | from fastapi import FastAPI 4 | from fastapi.middleware.cors import CORSMiddleware 5 | from loguru import logger 6 | 7 | from src.api.handlers import router as messages_router 8 | from src.api.middleware.auth import APIKeyMiddleware 9 | from src.api.middleware.timing import setup_middlewares 10 | from src.api.routes import router as health_router 11 | 12 | # 启动时同步加载配置(模块级别,应用启动时执行) 13 | from .common.logging import get_request_id_from_request 14 | from src.config.settings import Config 15 | 16 | config = Config.from_file_sync() 17 | 18 | 19 | @asynccontextmanager 20 | async def lifespan(app: FastAPI): 21 | """应用生命周期管理""" 22 | from src.api.handlers import MessagesHandler 23 | from src.common.logging import configure_logging 24 | from src.config.settings import get_config_file_path, reload_config 25 | from src.config.watcher import 
ConfigWatcher 26 | 27 | host, port = config.get_server_config() 28 | 29 | # 配置Loguru日志 30 | configure_logging(config.logging) 31 | 32 | # 创建配置化的消息处理器并缓存到应用中 33 | app.state.messages_handler = await MessagesHandler.create(config) 34 | 35 | # 配置重载回调函数 36 | async def on_config_reload(): 37 | """配置重载时的回调函数""" 38 | try: 39 | # 重新加载配置 40 | new_config = await reload_config() 41 | 42 | # 重新配置日志 43 | configure_logging(new_config.logging) 44 | 45 | # 重新创建消息处理器 46 | app.state.messages_handler = await MessagesHandler.create(new_config) 47 | 48 | logger.info("配置热重载完成,服务已更新") 49 | except Exception as e: 50 | logger.error(f"配置热重载失败: {e}") 51 | 52 | # 启动配置文件监听器 53 | config_watcher = ConfigWatcher(get_config_file_path()) 54 | config_watcher.add_reload_callback(on_config_reload) 55 | await config_watcher.start_watching() 56 | 57 | # 缓存监听器到应用状态 58 | app.state.config_watcher = config_watcher 59 | 60 | logger.info( 61 | f"启动 OpenAI To Claude 服务器 - Host: {host}, Port: {port}, LogLevel: {config.logging.level}" 62 | ) 63 | logger.info(f"配置文件监听已启用: {get_config_file_path()}") 64 | 65 | yield 66 | 67 | # 关闭时的清理工作 68 | logger.info("正在停止配置文件监听...") 69 | if hasattr(app.state, "config_watcher"): 70 | app.state.config_watcher.stop_watching() 71 | logger.info("服务器已停止") 72 | 73 | 74 | app = FastAPI( 75 | title="OpenAI To Claude Server", 76 | version="0.1.0", 77 | description="A server to convert OpenAI API calls to Claude format.", 78 | lifespan=lifespan, 79 | ) 80 | 81 | # 设置CORS中间件 82 | app.add_middleware( 83 | CORSMiddleware, 84 | allow_origins=["*"], # 允许所有来源,生产环境建议指定具体域名 85 | allow_credentials=True, 86 | allow_methods=["*"], # 允许所有HTTP方法 87 | allow_headers=["*"], # 允许所有请求头 88 | ) 89 | 90 | # 设置其他中间件 91 | setup_middlewares(app) 92 | app.add_middleware(APIKeyMiddleware, api_key=config.api_key) 93 | 94 | app.include_router(health_router) 95 | app.include_router(messages_router) 96 | 97 | 98 | @app.get("/") 99 | async def root(): 100 | return {"message": "Welcome to the OpenAI To Claude Server"} 
101 | 102 | 103 | from fastapi.responses import JSONResponse 104 | from src.models.errors import get_error_response 105 | 106 | 107 | # 全局异常处理程序 108 | @app.exception_handler(Exception) 109 | async def global_exception_handler(request, exc): 110 | """全局异常处理,防止Internal Server Error直接返回给客户端""" 111 | from src.common.logging import ( 112 | get_logger_with_request_id, 113 | get_request_id_from_request, 114 | ) 115 | 116 | request_id = get_request_id_from_request(request) 117 | bound_logger = get_logger_with_request_id(request_id) 118 | 119 | bound_logger.exception( 120 | "捕获未处理的服务器异常", 121 | error_type=type(exc).__name__, 122 | error_message=str(exc), 123 | request_path=str(request.url), 124 | request_method=request.method, 125 | ) 126 | 127 | # 使用标准错误响应格式返回客户端 128 | error_response = get_error_response(500, message="服务器内部错误,请稍后重试") 129 | return JSONResponse(status_code=500, content=error_response.model_dump()) 130 | 131 | 132 | # Pydantic验证错误处理 133 | @app.exception_handler(422) 134 | async def validation_exception_handler(request, exc): 135 | """处理Pydantic验证错误""" 136 | from src.common.logging import get_logger_with_request_id 137 | 138 | request_id = get_request_id_from_request(request) 139 | bound_logger = get_logger_with_request_id(request_id) 140 | 141 | bound_logger.warning("请求验证失败") 142 | return JSONResponse(status_code=422, content=exc.detail) 143 | 144 | 145 | # 404错误处理 146 | @app.exception_handler(404) 147 | async def not_found_handler(request, exc): 148 | """处理404错误""" 149 | error_response = get_error_response(404, message="请求的资源不存在") 150 | 151 | return JSONResponse(status_code=404, content=error_response.model_dump()) 152 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # 数据模型模块 - 定义所有API相关的数据模型 2 | 3 | # Anthropic模型 4 | from .anthropic import ( 5 | AnthropicContentBlock, 6 | AnthropicContentTypes, 7 | 
AnthropicErrorDetail, 8 | AnthropicErrorResponse, 9 | AnthropicMessage, 10 | # 数据模型 11 | AnthropicMessageContent, 12 | AnthropicMessageResponse, 13 | AnthropicMessageTypes, 14 | AnthropicPing, 15 | AnthropicRequest, 16 | AnthropicRoles, 17 | # 流式响应相关 18 | AnthropicStreamContentBlock, 19 | # 常量类 20 | AnthropicStreamEventTypes, 21 | AnthropicSystemMessage, 22 | AnthropicTextContent, 23 | AnthropicToolDefinition, 24 | AnthropicToolUse, 25 | AnthropicUsage, 26 | ContentBlock, 27 | Delta, 28 | ) 29 | 30 | # 错误模型 31 | from .errors import ( 32 | ERROR_CODE_MAPPING, 33 | BadRequestError, 34 | ErrorDetail, 35 | ExternalServiceError, 36 | NotFoundError, 37 | RateLimitError, 38 | ServerError, 39 | ServiceUnavailableError, 40 | StandardErrorResponse, 41 | TimeoutError, 42 | UnauthorizedError, 43 | ValidationError, 44 | ValidationErrorItem, 45 | ValidationErrorResponse, 46 | get_error_response, 47 | ) 48 | 49 | # OpenAI模型 50 | from .openai import ( 51 | OpenAIChoice, 52 | OpenAIChoiceDelta, 53 | OpenAICompletionUsage, 54 | OpenAIErrorDetail, 55 | OpenAIErrorResponse, 56 | OpenAIImageUrl, 57 | OpenAIMessage, 58 | OpenAIMessageContent, 59 | OpenAIRequest, 60 | OpenAIResponse, 61 | OpenAIStreamResponse, 62 | OpenAITool, 63 | OpenAIToolCall, 64 | OpenAIToolCallFunction, 65 | OpenAIToolFunction, 66 | OpenAIUsage, 67 | ) 68 | 69 | __all__ = [ 70 | # Anthropic 常量类 71 | "AnthropicStreamEventTypes", 72 | "AnthropicContentTypes", 73 | "AnthropicMessageTypes", 74 | "AnthropicRoles", 75 | # Anthropic 数据模型 76 | "AnthropicMessageContent", 77 | "AnthropicMessage", 78 | "AnthropicSystemMessage", 79 | "AnthropicToolDefinition", 80 | "AnthropicRequest", 81 | "AnthropicToolUse", 82 | "AnthropicTextContent", 83 | "AnthropicContentBlock", 84 | "AnthropicUsage", 85 | "AnthropicMessageResponse", 86 | "AnthropicErrorDetail", 87 | "AnthropicErrorResponse", 88 | "AnthropicPing", 89 | "AnthropicStreamContentBlock", 90 | "Delta", 91 | "ContentBlock", 92 | # OpenAI 93 | "OpenAIMessageContent", 94 | 
"OpenAIMessage", 95 | "OpenAIImageUrl", 96 | "OpenAIToolCallFunction", 97 | "OpenAIToolCall", 98 | "OpenAIToolFunction", 99 | "OpenAITool", 100 | "OpenAIRequest", 101 | "OpenAIChoiceDelta", 102 | "OpenAIChoice", 103 | "OpenAIUsage", 104 | "OpenAICompletionUsage", 105 | "OpenAIResponse", 106 | "OpenAIStreamResponse", 107 | "OpenAIErrorDetail", 108 | "OpenAIErrorResponse", 109 | # 错误处理 110 | "ErrorDetail", 111 | "StandardErrorResponse", 112 | "ValidationErrorItem", 113 | "ValidationError", 114 | "ValidationErrorResponse", 115 | "UnauthorizedError", 116 | "RateLimitError", 117 | "ServerError", 118 | "TimeoutError", 119 | "NotFoundError", 120 | "BadRequestError", 121 | "ServiceUnavailableError", 122 | "ExternalServiceError", 123 | "ERROR_CODE_MAPPING", 124 | "get_error_response", 125 | ] 126 | -------------------------------------------------------------------------------- /src/models/anthropic.py: -------------------------------------------------------------------------------- 1 | """Anthropic API 数据模型定义""" 2 | 3 | from typing import Any, Literal, Optional, Union 4 | 5 | from pydantic import BaseModel, Field 6 | 7 | 8 | class AnthropicStreamEventTypes: 9 | """Anthropic流式响应事件类型常量""" 10 | 11 | # 消息相关事件 12 | MESSAGE_START = "message_start" 13 | MESSAGE_DELTA = "message_delta" 14 | MESSAGE_STOP = "message_stop" 15 | 16 | # 内容块相关事件 17 | CONTENT_BLOCK_START = "content_block_start" 18 | CONTENT_BLOCK_DELTA = "content_block_delta" 19 | CONTENT_BLOCK_STOP = "content_block_stop" 20 | 21 | # 其他事件 22 | PING = "ping" 23 | 24 | 25 | class AnthropicContentTypes: 26 | """Anthropic内容类型常量""" 27 | 28 | # 基础内容类型 29 | TEXT = "text" 30 | IMAGE = "image" 31 | TOOL_USE = "tool_use" 32 | TOOL_RESULT = "tool_result" 33 | THINKING = "thinking" 34 | 35 | # 增量类型 36 | TEXT_DELTA = "text_delta" 37 | INPUT_JSON_DELTA = "input_json_delta" 38 | THINKING_DELTA = "thinking_delta" 39 | SIGNATURE_DELTA = "signature_delta" 40 | 41 | 42 | class AnthropicMessageTypes: 43 | """Anthropic消息类型常量""" 44 | 45 | 
MESSAGE = "message" 46 | ERROR = "error" 47 | 48 | 49 | class AnthropicRoles: 50 | """Anthropic角色常量""" 51 | 52 | USER = "user" 53 | ASSISTANT = "assistant" 54 | 55 | 56 | class AnthropicMessageContent(BaseModel): 57 | """Anthropic消息内容项""" 58 | 59 | type: Literal["text", "thinking", "image", "tool_use", "tool_result"] = Field( 60 | description="内容类型" 61 | ) 62 | text: str | None = Field(None, description="文本内容(当type为text时)") 63 | source: dict[str, Any] | None = Field(None, description="当type为image时的源信息") 64 | id: str | None = Field(None, description="工具调用ID(当type为tool_use时)") 65 | name: str | None = Field(None, description="工具名称(当type为tool_use时)") 66 | input: dict[str, Any] | None = Field( 67 | None, description="工具输入参数(当type为tool_use时)" 68 | ) 69 | tool_use_id: str | None = Field( 70 | None, description="工具使用ID(当type为tool_result时)" 71 | ) 72 | content: str | list[dict[str, Any]] | None = Field( 73 | None, description="工具结果内容(当type为tool_result时)" 74 | ) 75 | is_error: bool | None = Field( 76 | None, description="工具调用是否为错误结果(当type为tool_result时)" 77 | ) 78 | 79 | 80 | class AnthropicMessage(BaseModel): 81 | """Anthropic消息格式""" 82 | 83 | role: Literal["user", "assistant"] = Field(description="消息角色") 84 | content: str | list[AnthropicMessageContent] = Field(description="消息内容") 85 | 86 | 87 | class AnthropicSystemMessage(BaseModel): 88 | """Anthropic系统消息""" 89 | 90 | type: Literal["text"] = Field( 91 | default=AnthropicContentTypes.TEXT, description="系统消息类型,固定为text" 92 | ) 93 | text: str = Field(description="系统消息文本内容") 94 | 95 | 96 | class AnthropicToolDefinition(BaseModel): 97 | """Anthropic工具定义""" 98 | 99 | name: str = Field(description="工具名称") 100 | description: Optional[str] = Field(None, description="工具描述") 101 | input_schema: Optional[dict[str, Any]] = Field( 102 | None, description="JSON Schema格式的输入参数定义" 103 | ) 104 | type: Optional[str] = Field(None, description="工具名称") 105 | max_uses: Optional[int] = Field(None, description="工具名称") 106 | 107 | 108 | class 
AnthropicRequest(BaseModel): 109 | """Anthropic API请求模型""" 110 | 111 | model: str = Field(description="使用的模型ID,如claude-3-5-sonnet-20241022") 112 | messages: list[AnthropicMessage] = Field(description="对话消息列表") 113 | max_tokens: int = Field(description="最大输出token数量") 114 | system: str | list[AnthropicSystemMessage] | None = Field( 115 | None, description="系统提示信息" 116 | ) 117 | tools: list[AnthropicToolDefinition] | None = Field( 118 | None, description="可用工具定义" 119 | ) 120 | tool_choice: str | dict[str, Any] | None = Field(None, description="工具选择配置") 121 | metadata: dict[str, Any] | None = Field(None, description="可选元数据") 122 | stop_sequences: list[str] | None = Field(None, description="停止序列") 123 | stream: bool | None = Field(False, description="是否使用流式响应") 124 | temperature: float | None = Field(None, ge=0.0, le=1.0, description="采样温度") 125 | top_p: float | None = Field(None, ge=0.0, le=1.0, description="top-p采样参数") 126 | top_k: int | None = Field(None, ge=1, le=1000, description="top-k采样参数") 127 | thinking: bool | dict[str, Any] | None = Field( 128 | None, description="是否启用推理模型模式或配置对象" 129 | ) 130 | 131 | 132 | class AnthropicToolUse(BaseModel): 133 | """工具使用响应""" 134 | 135 | id: str = Field(description="工具调用ID") 136 | type: str = Field(default=AnthropicContentTypes.TOOL_USE, description="响应类型") 137 | name: str = Field(description="工具名称") 138 | input: dict[str, Any] = Field(description="工具输入参数") 139 | 140 | 141 | class AnthropicTextContent(BaseModel): 142 | """文本内容响应""" 143 | 144 | type: str = Field(default=AnthropicContentTypes.TEXT, description="内容类型") 145 | text: str = Field(description="文本内容") 146 | 147 | 148 | class AnthropicContentBlock(BaseModel): 149 | """Anthropic内容块""" 150 | 151 | type: Literal["text", "tool_use", "thinking"] 152 | text: str | None = Field(None, description="文本内容,当type为text时") 153 | id: str | None = Field(None, description="工具调用ID,当type为tool_use时") 154 | name: str | None = Field(None, description="工具名称,当type为tool_use时") 155 | input: 
dict[str, Any] | None = Field( 156 | None, description="工具输入,当type为tool_use时" 157 | ) 158 | thinking: str | None = Field(None, description="思考内容,当type为thinking时") 159 | signature: str | None = Field(None, description="思考内容签名,当type为thinking时") 160 | 161 | 162 | class AnthropicUsage(BaseModel): 163 | """Anthropic使用统计""" 164 | 165 | input_tokens: int = Field(1, description="输入token数量") 166 | output_tokens: int | None = Field(1, description="输出token数量") 167 | cache_creation_input_tokens: int | None = Field( 168 | 0, description="缓存创建输入token数量" 169 | ) 170 | cache_read_input_tokens: int | None = Field(0, description="缓存读取输入token数量") 171 | service_tier: str | None = Field("standard", description="服务层级") 172 | 173 | 174 | class MessageDelta(BaseModel): 175 | """消息增量""" 176 | 177 | stop_reason: str | None = Field(None, description="停止原因") 178 | stop_sequence: str | None = Field(None, description="停止序列") 179 | 180 | 181 | class AnthropicStreamMessageStartMessage(BaseModel): 182 | """Anthropic流式消息开始事件中的消息详情""" 183 | 184 | id: str = Field(description="消息ID") 185 | type: str = Field(default=AnthropicMessageTypes.MESSAGE, description="消息类型") 186 | role: Literal["assistant"] = Field( 187 | default=AnthropicRoles.ASSISTANT, description="消息角色" 188 | ) 189 | model: str = Field(description="使用的模型ID") 190 | content: list[Any] = Field([], description="内容块,通常为空") 191 | stop_reason: str | None = Field(None, description="停止原因") 192 | stop_sequence: str | None = Field(None, description="停止序列") 193 | usage: AnthropicUsage = Field(description="使用统计") 194 | 195 | 196 | class AnthropicStreamMessage(BaseModel): 197 | """流式消息开始事件""" 198 | 199 | type: str = Field( 200 | default=AnthropicStreamEventTypes.MESSAGE_START, description="事件类型" 201 | ) 202 | message: AnthropicStreamMessageStartMessage = Field(None, description="消息详情") 203 | delta: MessageDelta = Field(None, description="消息增量") 204 | usage: AnthropicUsage = Field(None, description="使用统计") 205 | 206 | 207 | class Delta(BaseModel): 208 | 
"""文本增量""" 209 | 210 | type: str = Field(default=AnthropicContentTypes.TEXT_DELTA, description="增量类型") 211 | text: str | None = Field(None, description="文本增量内容") 212 | thinking: str | None = Field(None, description="思考内容") 213 | signature: str | None = Field(None, description="签名内容") 214 | partial_json: str | None = Field(None, description="部分JSON字符串") 215 | 216 | 217 | class InputJsonDelta(BaseModel): 218 | """输入JSON增量""" 219 | 220 | type: str = Field( 221 | default=AnthropicContentTypes.INPUT_JSON_DELTA, description="增量类型" 222 | ) 223 | 224 | 225 | class AnthropicUsageDelta(BaseModel): 226 | """使用统计增量""" 227 | 228 | output_tokens: int = Field(description="输出token数量增量") 229 | 230 | 231 | class AnthropicMessageResponse(BaseModel): 232 | """Anthropic消息响应""" 233 | 234 | id: str = Field(description="响应唯一ID") 235 | type: str = Field(default=AnthropicMessageTypes.MESSAGE, description="响应类型") 236 | role: Literal["assistant"] = Field( 237 | default=AnthropicRoles.ASSISTANT, description="消息角色" 238 | ) 239 | content: list[AnthropicContentBlock] = Field(description="消息内容块") 240 | model: str = Field(description="使用的模型ID") 241 | stop_reason: str | None = Field(None, description="停止原因") 242 | stop_sequence: str | None = Field(None, description="停止序列") 243 | usage: AnthropicUsage = Field(description="使用统计") 244 | 245 | 246 | class AnthropicErrorDetail(BaseModel): 247 | """Anthropic错误详情""" 248 | 249 | type: str = Field(description="错误类型") 250 | message: str = Field(description="错误消息") 251 | 252 | 253 | class AnthropicErrorResponse(BaseModel): 254 | """Anthropic错误响应模型""" 255 | 256 | type: str = Field(default=AnthropicMessageTypes.ERROR, description="响应类型") 257 | error: AnthropicErrorDetail = Field(description="错误详情") 258 | 259 | 260 | class ContentBlock(BaseModel): 261 | """内容块""" 262 | 263 | type: str = Field(default=AnthropicContentTypes.TEXT, description="内容块类型") 264 | text: str | None = Field(None, description="文本内容") 265 | thinking: str | None = Field(None, 
description="思考内容") 266 | signature: str | None = Field(None, description="签名内容") 267 | # tool_use相关字段 268 | id: str | None = Field(None, description="工具调用ID,当type为tool_use时") 269 | name: str | None = Field(None, description="工具名称,当type为tool_use时") 270 | input: dict[str, Any] | None = Field( 271 | None, description="工具输入,当type为tool_use时" 272 | ) 273 | 274 | 275 | class AnthropicStreamContentBlockStart(BaseModel): 276 | """流式内容块开始""" 277 | 278 | type: str = Field( 279 | default=AnthropicStreamEventTypes.CONTENT_BLOCK_START, description="事件类型" 280 | ) 281 | index: int = Field(default=0, description="内容块索引") 282 | content_block: ContentBlock = Field( 283 | default_factory=lambda: ContentBlock(type=AnthropicContentTypes.TEXT, text="") 284 | ) 285 | 286 | 287 | class AnthropicStreamContentBlock(BaseModel): 288 | """流式内容块增量""" 289 | 290 | type: str = Field( 291 | default=AnthropicStreamEventTypes.CONTENT_BLOCK_DELTA, description="事件类型" 292 | ) 293 | index: int = Field(0, description="内容块索引") 294 | delta: Delta = Field(None, description="增量内容") 295 | content_block: ContentBlock = Field(None, description="内容块") 296 | usage: Delta | None = Field(None, description="使用统计") 297 | 298 | 299 | class AnthropicStreamContentBlockStop(BaseModel): 300 | """流式内容块结束""" 301 | 302 | type: str = Field( 303 | default=AnthropicStreamEventTypes.CONTENT_BLOCK_STOP, description="事件类型" 304 | ) 305 | index: int = Field(0, description="内容块索引") 306 | 307 | 308 | class AnthropicPing(BaseModel): 309 | """流式ping消息""" 310 | 311 | type: str = Field(default=AnthropicStreamEventTypes.PING, description="事件类型") 312 | 313 | 314 | AnthropicStreamResponse = Union[ 315 | AnthropicStreamMessage, 316 | AnthropicStreamContentBlockStart, 317 | AnthropicStreamContentBlock, 318 | AnthropicStreamContentBlockStop, 319 | AnthropicPing, 320 | ] 321 | -------------------------------------------------------------------------------- /src/models/errors.py: 
"""Standardized error response models.

Each model below describes one kind of error payload returned by the proxy.
Default messages and ``Field`` descriptions are runtime strings (sent to
clients and embedded in the generated JSON schema), so they are kept verbatim.
"""

from typing import Any

from pydantic import BaseModel, Field


class ErrorDetail(BaseModel):
    """Detailed error information nested under ``StandardErrorResponse.error``."""

    code: str = Field(description="错误代码")
    message: str = Field(description="错误消息")
    param: str | None = Field(None, description="相关参数名称")
    type: str | None = Field(None, description="错误类型")
    # Validation failures are reported as a *list* of per-field items (the 422
    # branch of get_error_response builds a list, possibly empty), while other
    # errors attach a dict of extra context. Both shapes must be accepted, or
    # building a validation error response fails pydantic validation.
    details: dict[str, Any] | list[Any] | None = Field(
        None, description="额外错误详情"
    )
    request_id: str | None = Field(None, description="请求ID用于追踪")
    # NOTE(review): get_error_response also forwards "retry_after", "service"
    # and "original_error" keys. They are not declared on this model, so
    # pydantic's default extra-field handling ("ignore") drops them. Decide
    # whether they should be exposed here.


class StandardErrorResponse(BaseModel):
    """Top-level error envelope of the form ``{"type": "error", "error": {...}}``."""

    type: str = Field("error", description="响应类型")
    error: ErrorDetail = Field(description="错误详情")


class ValidationErrorItem(BaseModel):
    """A single field-level validation failure."""

    # Path to the offending field. Pydantic error locations mix field names
    # (str) with list indices (int), e.g. ["messages", 0, "content"]; a
    # str-only list would reject the integer entries.
    loc: list[str | int] = Field(description="错误位置字段路径")
    msg: str = Field(description="错误消息")
    type: str = Field(description="错误类型")


class ValidationError(BaseModel):
    """Validation error payload (mapped to HTTP 422) listing the failed fields.

    Not to be confused with the ``pydantic.ValidationError`` exception; a
    module that needs both must import one of them under an alias.
    """

    code: str = Field("validation_error", description="错误代码")
    message: str = Field("请求参数验证失败", description="错误消息")
    details: list[ValidationErrorItem] = Field(..., description="验证错误列表")


class ValidationErrorResponse(BaseModel):
    """Envelope wrapping a ``ValidationError`` payload."""

    type: str = Field("validation_error", description="响应类型")
    error: ValidationError = Field(description="验证错误详情")


class UnauthorizedError(BaseModel):
    """Invalid API key or unauthorized access (mapped to HTTP 401)."""

    code: str = Field("unauthorized", description="错误代码")
    message: str = Field("无效的API密钥或未经授权的访问", description="错误消息")
    type: str = Field("authentication_error", description="错误类型")


class RateLimitError(BaseModel):
    """Request rate limit exceeded (mapped to HTTP 429)."""

    code: str = Field("rate_limit_exceeded", description="错误代码")
    message: str = Field("请求频率超出限制,请稍后重试", description="错误消息")
    type: str = Field("rate_limit_error", description="错误类型")
    retry_after: int | None = Field(None, description="重试等待时间(秒)")


class ServerError(BaseModel):
    """Internal server error (HTTP 500; also the fallback for unmapped codes)."""

    code: str = Field("internal_server_error", description="错误代码")
    message: str = Field("服务器内部错误,请稍后重试", description="错误消息")
    type: str = Field("server_error", description="错误类型")
    request_id: str | None = Field(None, description="请求ID用于排查问题")


class TimeoutError(BaseModel):
    """Request timed out (mapped to HTTP 504).

    NOTE: this name shadows the builtin ``TimeoutError`` in this module and in
    any module that imports it by name; it is kept for backward compatibility.
    """

    code: str = Field("timeout", description="错误代码")
    message: str = Field("请求超时,请稍后重试", description="错误消息")
    type: str = Field("timeout_error", description="错误类型")
    timeout: int | None = Field(None, description="超时时间(秒)")


class NotFoundError(BaseModel):
    """Requested resource does not exist (mapped to HTTP 404)."""

    code: str = Field("not_found", description="错误代码")
    message: str = Field("请求的资源不存在", description="错误消息")
    type: str = Field("not_found_error", description="错误类型")


class BadRequestError(BaseModel):
    """Malformed request or invalid parameters (mapped to HTTP 400)."""

    code: str = Field("bad_request", description="错误代码")
    message: str = Field("请求格式错误或参数无效", description="错误消息")
    type: str = Field("invalid_request_error", description="错误类型")


class ServiceUnavailableError(BaseModel):
    """Service temporarily unavailable (mapped to HTTP 503)."""

    code: str = Field("service_unavailable", description="错误代码")
    message: str = Field("服务暂时不可用,请稍后重试", description="错误消息")
    type: str = Field("server_error", description="错误类型")
    retry_after: int | None = Field(None, description="建议重试时间(秒)")


class ExternalServiceError(BaseModel):
    """Error reported by an upstream/external service (mapped to HTTP 502)."""

    code: str = Field("external_service_error", description="错误代码")
    message: str = Field("外部服务错误,请稍后重试", description="错误消息")
    type: str = Field("api_error", description="错误类型")
    service: str | None = Field(None, description="出错的外部服务名称")
    original_error: dict[str, Any] | None = Field(
        None, description="原始错误信息(生产环境可能会省略)"
    )


# Mapping from HTTP status code to the error model used to build the response.
| ERROR_CODE_MAPPING = { 123 | 400: BadRequestError, 124 | 401: UnauthorizedError, 125 | 404: NotFoundError, 126 | 422: ValidationError, 127 | 429: RateLimitError, 128 | 500: ServerError, 129 | 502: ExternalServiceError, 130 | 503: ServiceUnavailableError, 131 | 504: TimeoutError, 132 | } 133 | 134 | 135 | def format_compact_traceback(error: Exception, max_lines: int = 10) -> str: 136 | """格式化紧凑的错误堆栈信息,只保留项目相关部分""" 137 | import traceback 138 | 139 | error_traceback = "".join( 140 | traceback.format_exception(type(error), error, error.__traceback__) 141 | ) 142 | lines = error_traceback.split("\n") 143 | 144 | # 过滤出项目相关的行 145 | filtered_lines = [] 146 | for line in lines: 147 | if ( 148 | "openai-claude-code-proxy/src" in line 149 | or line.strip().startswith('File "/Users') 150 | or any(keyword in line for keyword in ["Error:", "Exception:", " "]) 151 | or line.strip() == "" 152 | ): 153 | filtered_lines.append(line) 154 | 155 | # 只保留最后max_lines行 156 | return "\n".join(filtered_lines[-max_lines:]) if filtered_lines else str(error) 157 | 158 | 159 | def get_error_response( 160 | status_code: int, 161 | message: str | None = None, 162 | details: dict[str, Any] | None = None, 163 | ) -> StandardErrorResponse: 164 | """根据HTTP状态码获取对应的错误响应模型,返回Pydantic模型实例""" 165 | error_class = ERROR_CODE_MAPPING.get(status_code, ServerError) 166 | 167 | # 创建错误详情 168 | if error_class == ValidationError: 169 | # 验证错误需要特殊处理 170 | if details and "validation_errors" in details: 171 | validation_items = [ 172 | { 173 | "loc": error.get("loc", []), 174 | "msg": error.get("msg", ""), 175 | "type": error.get("type", "value_error"), 176 | } 177 | for error in details["validation_errors"] 178 | ] 179 | error_detail_data = { 180 | "code": "validation_error", 181 | "message": message or "请求参数验证失败", 182 | "details": validation_items, 183 | } 184 | else: 185 | error_detail_data = { 186 | "code": "validation_error", 187 | "message": message or "请求参数验证失败", 188 | "details": [], 189 | } 190 | else: 191 | 
# 其他错误类型 192 | error_detail_data = { 193 | "code": error_class.model_fields["code"].default, 194 | "message": message or error_class.model_fields["message"].default, 195 | } 196 | 197 | if details: 198 | # Add optional fields if present in details 199 | for key in [ 200 | "param", 201 | "type", 202 | "retry_after", 203 | "request_id", 204 | "service", 205 | "original_error", 206 | ]: 207 | if key in details: 208 | error_detail_data[key] = details[key] 209 | # 将其他details作为details字段 210 | other_details = { 211 | k: v 212 | for k, v in details.items() 213 | if k 214 | not in [ 215 | "param", 216 | "type", 217 | "retry_after", 218 | "request_id", 219 | "service", 220 | "original_error", 221 | ] 222 | } 223 | if other_details: 224 | error_detail_data["details"] = other_details 225 | 226 | # 创建并返回StandardErrorResponse实例 227 | return StandardErrorResponse(error=ErrorDetail(**error_detail_data)) 228 | -------------------------------------------------------------------------------- /src/models/openai.py: -------------------------------------------------------------------------------- 1 | """OpenAI API 数据模型定义""" 2 | 3 | from typing import Any, Literal 4 | 5 | from pydantic import BaseModel, Field 6 | 7 | 8 | class OpenAIMessageContent(BaseModel): 9 | """OpenAI消息内容项""" 10 | 11 | type: Literal["text", "image_url"] = Field(description="内容类型") 12 | text: str | None = Field(None, description="文本内容") 13 | image_url: dict[str, str] | None = Field(None, description="图像URL配置") 14 | 15 | 16 | class OpenAIMessage(BaseModel): 17 | """OpenAI消息格式""" 18 | 19 | role: Literal["system", "user", "assistant", "tool"] = Field(description="消息角色") 20 | content: str | list[OpenAIMessageContent] | None = Field(description="消息内容") 21 | name: str | None = Field(None, description="消息作者名称") 22 | tool_calls: list[dict[str, Any]] | None = Field( 23 | None, description="工具调用信息(当role为assistant时)" 24 | ) 25 | tool_call_id: str | None = Field( 26 | None, description="工具调用ID(当role为tool时)" 27 | ) 28 | refusal: 
str | None = Field(None, description="拒绝服务的详细信息") 29 | reasoning_content: str | None = Field(None, description="推理内容") 30 | annotations: list[dict[str, Any]] | None = Field( 31 | None, description="模型生成的标注" 32 | ) 33 | 34 | 35 | class OpenAIImageUrl(BaseModel): 36 | """OpenAI图像URL""" 37 | 38 | url: str = Field(description="图像URL") 39 | detail: Literal["auto", "low", "high"] | None = Field( 40 | "auto", description="图像细节级别" 41 | ) 42 | 43 | 44 | class OpenAIToolCallFunction(BaseModel): 45 | """工具调用函数""" 46 | 47 | name: str | None = Field(None, description="函数名称") # 改为可选,因为流式响应中可能缺失 48 | arguments: str | None = Field(None, description="JSON格式的函数参数") # 改为可选,支持增量传输 49 | 50 | 51 | class OpenAIToolCall(BaseModel): 52 | """工具调用""" 53 | 54 | id: str = Field(description="工具调用ID") 55 | type: Literal["function"] = Field("function", description="调用类型") 56 | function: OpenAIToolCallFunction = Field(description="函数详情") 57 | 58 | 59 | class OpenAIDeltaToolCall(BaseModel): 60 | """工具调用增量""" 61 | 62 | index: int | None = Field(None, description="工具调用索引") # 改为可选,因为流式响应中可能缺失 63 | id: str | None = Field(None, description="工具调用ID") 64 | type: Literal["function"] | None = Field(None, description="调用类型") 65 | function: OpenAIToolCallFunction | None = Field(None, description="函数详情增量") 66 | 67 | 68 | class OpenAIToolFunction(BaseModel): 69 | """OpenAI工具函数定义""" 70 | 71 | name: str = Field(description="函数名称") 72 | description: str | None = Field(None, description="函数描述") 73 | parameters: dict[str, Any] | None = Field( 74 | None, description="JSON Schema格式的函数参数" 75 | ) 76 | 77 | 78 | class OpenAITool(BaseModel): 79 | """OpenAI工具定义""" 80 | 81 | type: Literal["function"] = Field("function", description="工具类型") 82 | function: OpenAIToolFunction = Field(description="函数定义") 83 | 84 | 85 | class OpenAIRequest(BaseModel): 86 | """OpenAI API请求模型""" 87 | 88 | model: str = Field(description="使用的模型ID,如gpt-4o或claude-3-5-sonnet") 89 | messages: list[OpenAIMessage] = Field(description="对话消息列表") 90 | 
max_tokens: int | None = Field(None, description="最大输出token数量") 91 | max_completion_tokens: int | None = Field( 92 | None, description="最大完成token数量(新格式)" 93 | ) 94 | temperature: float | None = Field(None, ge=0.0, le=2.0, description="采样温度") 95 | top_p: float | None = Field(None, ge=0.0, le=1.0, description="top-p采样参数") 96 | top_k: int | None = Field(None, ge=0, description="top-k采样参数") 97 | stream: bool | None = Field(False, description="是否使用流式响应") 98 | stream_options: dict[str, Any] | None = Field(None, description="流式响应选项") 99 | stop: str | list[str] | None = Field(None, description="停止序列") 100 | frequency_penalty: float | None = Field(None, description="频率惩罚") 101 | presence_penalty: float | None = Field(None, description="存在惩罚") 102 | logprobs: bool | None = Field(False, description="是否返回log概率") 103 | top_logprobs: int | None = Field( 104 | None, description="返回的top log probability数量" 105 | ) 106 | logit_bias: dict[str, int] | None = Field(None, description="logit偏差") 107 | n: int | None = Field(None, ge=1, le=128, description="生成的消息数量") 108 | seed: int | None = Field(None, description="随机种子") 109 | response_format: dict[str, Any] | None = Field(None, description="响应格式配置") 110 | tools: list[OpenAITool] | None = Field(None, description="可用工具定义") 111 | tool_choice: str | dict[str, Any] | None = Field( 112 | None, description="工具选择配置" 113 | ) 114 | parallel_tool_calls: bool | None = Field( 115 | None, description="是否允许并行工具调用" 116 | ) 117 | user: str | None = Field(None, description="用户信息") 118 | think: bool | None = Field(None, description="是否启用推理模型模式") 119 | 120 | 121 | class OpenAIChoiceDelta(BaseModel): 122 | """流式响应增量内容""" 123 | 124 | role: str | None = Field(None, description="消息角色") 125 | content: str | None = Field(None, description="内容增量") 126 | reasoning_content: str | None = Field(None, description="推理内容增量") 127 | tool_calls: list[OpenAIDeltaToolCall] | None = Field( 128 | None, description="工具调用增量" 129 | ) 130 | 131 | 132 | class 
OpenAIChoice(BaseModel): 133 | """OpenAI响应选项""" 134 | 135 | index: int = Field(description="选项索引") 136 | message: OpenAIMessage | None = Field(None, description="完整消息响应") 137 | delta: OpenAIChoiceDelta | None = Field(None, description="流式增量内容") 138 | finish_reason: str | None = Field( 139 | None, 140 | description="完成原因: stop, length, content_filter, tool_calls, function_call", 141 | ) 142 | logprobs: dict[str, Any] | None = Field(None, description="log概率信息") 143 | 144 | 145 | class OpenAIUsage(BaseModel): 146 | """OpenAI使用统计""" 147 | 148 | prompt_tokens: int = Field(description="提示token数量") 149 | completion_tokens: int = Field(description="完成token数量") 150 | total_tokens: int = Field(description="总token数量") 151 | completion_tokens_details: dict[str, Any] | None = Field( 152 | None, description="完成token的详细信息" 153 | ) 154 | prompt_tokens_details: dict[str, Any] | None = Field( 155 | None, description="提示token的详细信息" 156 | ) 157 | 158 | 159 | class OpenAICompletionUsage(BaseModel): 160 | """OpenAI完成使用统计""" 161 | 162 | completion_tokens: int = Field(description="完成token数量") 163 | prompt_tokens: int = Field(description="提示token数量") 164 | total_tokens: int = Field(description="总token数量") 165 | 166 | 167 | class OpenAIResponse(BaseModel): 168 | """OpenAI API响应模型""" 169 | 170 | id: str = Field(description="响应唯一ID") 171 | object: Literal["chat.completion"] = Field(description="对象类型") 172 | created: int = Field(description="创建时间戳") 173 | model: str = Field(description="使用的模型ID") 174 | choices: list[OpenAIChoice] = Field(description="响应选项列表") 175 | usage: OpenAIUsage = Field(description="使用统计") 176 | system_fingerprint: str | None = Field(None, description="系统指纹") 177 | 178 | 179 | class OpenAIStreamResponse(BaseModel): 180 | """OpenAI流式响应模型""" 181 | 182 | id: str = Field(description="响应唯一ID") 183 | object: Literal["chat.completion.chunk"] = Field(description="对象类型") 184 | created: int = Field(description="创建时间戳") 185 | model: str = Field(description="使用的模型ID") 186 | choices: 
list[OpenAIChoice] = Field(description="响应选项列表") 187 | usage: OpenAIUsage | None = Field( 188 | None, description="使用统计(仅在流式响应最后一块出现)" 189 | ) 190 | system_fingerprint: str | None = Field(None, description="系统指纹") 191 | 192 | 193 | class OpenAIErrorDetail(BaseModel): 194 | """OpenAI错误详情""" 195 | 196 | message: str = Field(description="错误消息") 197 | type: str = Field(description="错误类型") 198 | param: str | None = Field(None, description="相关参数") 199 | code: str | None = Field(None, description="错误代码") 200 | 201 | 202 | class OpenAIErrorResponse(BaseModel): 203 | """OpenAI错误响应模型""" 204 | 205 | error: OpenAIErrorDetail = Field(description="错误详情") 206 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 测试模块 3 | 4 | 包含项目的单元测试和集成测试。 5 | 6 | 测试结构: 7 | - unit/: 单元测试 8 | - integration/: 集成测试 9 | - fixtures/: 测试夹具和工具 10 | 11 | 测试覆盖: 12 | - API端点测试 13 | - 数据转换测试 14 | - 错误处理测试 15 | - 流式响应测试 16 | - 配置管理测试 17 | """ 18 | 19 | # 导入测试夹具 20 | from .fixtures import * 21 | 22 | __all__ = [ 23 | # 测试夹具将通过 fixtures 模块的 __all__ 自动导出 24 | ] -------------------------------------------------------------------------------- /tests/fixtures.py: -------------------------------------------------------------------------------- 1 | """Mock OpenAI server for end-to-end testing.""" 2 | 3 | from typing import Dict, Any, List, Optional, AsyncGenerator 4 | from fastapi import FastAPI, HTTPException, Header 5 | from pydantic import BaseModel, Field 6 | import time 7 | import asyncio 8 | from contextlib import asynccontextmanager 9 | 10 | 11 | class MockMessage(BaseModel): 12 | role: str 13 | content: str 14 | 15 | 16 | class MockChoice(BaseModel): 17 | index: int = 0 18 | message: MockMessage 19 | finish_reason: str = "stop" 20 | 21 | 22 | class MockDelta(BaseModel): 23 | content: Optional[str] = None 24 | role: Optional[str] = None 25 | 26 | 27 | class 
MockStreamChoice(BaseModel): 28 | index: int = 0 29 | delta: MockDelta 30 | finish_reason: Optional[str] = None 31 | 32 | 33 | class MockUsage(BaseModel): 34 | prompt_tokens: int 35 | completion_tokens: int 36 | total_tokens: int 37 | 38 | 39 | class MockCompletionResponse(BaseModel): 40 | id: str = Field(default="mock-completion-id") 41 | object: str = "chat.completion" 42 | created: int = Field(default_factory=lambda: int(time.time())) 43 | model: str 44 | choices: List[MockChoice] 45 | usage: MockUsage 46 | 47 | 48 | class MockStreamResponse(BaseModel): 49 | id: str = Field(default="mock-completion-id") 50 | object: str = "chat.completion.chunk" 51 | created: int = Field(default_factory=lambda: int(time.time())) 52 | model: str 53 | choices: List[MockStreamChoice] 54 | 55 | 56 | class MockChatCompletionRequest(BaseModel): 57 | model: str 58 | messages: List[MockMessage] 59 | max_tokens: Optional[int] = None 60 | temperature: Optional[float] = None 61 | stream: bool = False 62 | 63 | 64 | mock_responses: Dict[str, Dict[str, Any]] = {} 65 | error_trigger: str = "" 66 | delay_ms: int = 0 67 | 68 | 69 | @asynccontextmanager 70 | async def lifespan(app: FastAPI): 71 | """Manage mock server lifecycle.""" 72 | print("Mock OpenAI server starting...") 73 | yield 74 | print("Mock OpenAI server shutting down...") 75 | 76 | 77 | app = FastAPI(lifespan=lifespan) 78 | 79 | 80 | @app.middleware("http") 81 | async def add_delay(request, call_next): 82 | """Add configurable delay to all responses.""" 83 | if delay_ms > 0: 84 | await asyncio.sleep(delay_ms / 1000) 85 | response = await call_next(request) 86 | return response 87 | 88 | 89 | @app.post("/v1/chat/completions") 90 | async def mock_chat_completions( 91 | request: MockChatCompletionRequest, 92 | authorization: str = Header(...), 93 | ): 94 | """Mock the OpenAI chat completions endpoint.""" 95 | 96 | if not authorization.startswith("Bearer mock-key"): 97 | raise HTTPException(status_code=401, detail="Invalid API key") 98 
| 99 | if error_trigger == "invalid_model": 100 | raise HTTPException(status_code=400, detail="Invalid model") 101 | 102 | if error_trigger == "rate_limit": 103 | raise HTTPException(status_code=429, detail="Rate limit exceeded") 104 | 105 | if error_trigger == "server_error": 106 | raise HTTPException(status_code=500, detail="Internal server error") 107 | 108 | model = request.model 109 | if model == "invalid-model": 110 | raise HTTPException(status_code=404, detail="Model not found") 111 | 112 | last_message = request.messages[-1] 113 | 114 | if request.stream: 115 | return mock_stream_response(model, last_message.content or "") 116 | 117 | return mock_non_stream_response(model, last_message.content or "") 118 | 119 | 120 | def mock_non_stream_response(model: str, content: str) -> MockCompletionResponse: 121 | """Generate a non-streaming chat completion response.""" 122 | prompt_tokens = len(content.split()) 123 | completion_tokens = 10 # Fixed for testing 124 | 125 | return MockCompletionResponse( 126 | model=model, 127 | choices=[ 128 | MockChoice( 129 | message=MockMessage( 130 | role="assistant", 131 | content=f"Mock response for: {content}" 132 | ), 133 | finish_reason="stop" 134 | ) 135 | ], 136 | usage=MockUsage( 137 | prompt_tokens=prompt_tokens, 138 | completion_tokens=completion_tokens, 139 | total_tokens=prompt_tokens + completion_tokens 140 | ) 141 | ) 142 | 143 | 144 | async def mock_stream_response(model: str, content: str): 145 | """Generate a streaming chat completion response.""" 146 | words = ["This", "is", "a", "mock", "response", "for:", content[:10], "..."] 147 | 148 | # First chunk with role 149 | yield MockStreamResponse( 150 | model=model, 151 | choices=[ 152 | MockStreamChoice( 153 | delta=MockDelta(role="assistant", content=""), 154 | finish_reason=None 155 | ) 156 | ] 157 | ).model_dump_json() + "\n" 158 | 159 | # Stream content chunks 160 | for word in words: 161 | chunk = MockStreamResponse( 162 | model=model, 163 | choices=[ 164 | 
MockStreamChoice( 165 | delta=MockDelta(content=f"{word}"), 166 | finish_reason=None 167 | ) 168 | ] 169 | ) 170 | yield chunk.model_dump_json() + "\n" 171 | await asyncio.sleep(0.05) # Simulate streaming delay 172 | 173 | # Final chunk 174 | chunk = MockStreamResponse( 175 | model=model, 176 | choices=[ 177 | MockStreamChoice( 178 | delta=MockDelta(content=""), 179 | finish_reason="stop" 180 | ) 181 | ] 182 | ) 183 | yield chunk.model_dump_json() + "\n" 184 | 185 | 186 | @app.get("/health") 187 | async def health_check(): 188 | """Health check endpoint.""" 189 | return {"status": "healthy"} 190 | 191 | 192 | @app.get("/mock/config") 193 | async def get_config(): 194 | """Get current mock configuration.""" 195 | return { 196 | "error_trigger": error_trigger, 197 | "delay_ms": delay_ms, 198 | "responses": list(mock_responses.keys()) 199 | } 200 | 201 | 202 | @app.post("/mock/configure") 203 | async def configure_mock(config: Dict[str, Any]): 204 | """Configure mock server behavior.""" 205 | global error_trigger, delay_ms, mock_responses 206 | 207 | if "error_trigger" in config: 208 | error_trigger = config["error_trigger"] 209 | if "delay_ms" in config: 210 | delay_ms = config["delay_ms"] 211 | if "responses" in config: 212 | mock_responses.update(config["responses"]) 213 | 214 | return {"message": "Mock server configured"} 215 | 216 | 217 | @app.delete("/mock/reset") 218 | async def reset_mock(): 219 | """Reset mock server to default state.""" 220 | global error_trigger, delay_ms, mock_responses 221 | error_trigger = "" 222 | delay_ms = 0 223 | mock_responses.clear() 224 | return {"message": "Mock server reset"} -------------------------------------------------------------------------------- /tests/integration/test_end_to_end.py: -------------------------------------------------------------------------------- 1 | """End-to-end integration tests for the OpenAI-to-Claude proxy.""" 2 | 3 | import asyncio 4 | import json 5 | import os 6 | import requests 7 | import 
pytest 8 | from typing import Dict, Any, List 9 | import httpx 10 | from fastapi.testclient import TestClient 11 | from unittest.mock import patch, AsyncMock 12 | 13 | from src.main import app 14 | from tests.fixtures.mock_openai_server import app as mock_app 15 | 16 | 17 | class TestEndToEndIntegration: 18 | """End-to-end integration tests covering the full proxy flow.""" 19 | 20 | @pytest.fixture(autouse=True) 21 | def setup_client(self): 22 | """Set up test environment and create a test client with auth headers.""" 23 | api_key = "mock-key" 24 | os.environ["API_KEY"] = api_key 25 | os.environ["OPENAI_API_KEY"] = "mock-openai-key" 26 | os.environ["OPENAI_BASE_URL"] = "http://localhost:8001" 27 | 28 | # Reset mock server configuration 29 | try: 30 | requests.delete("http://localhost:8001/mock/reset") 31 | except requests.ConnectionError: 32 | # Mock server might not be running, which is fine for some tests 33 | pass 34 | 35 | self.client = TestClient(app) 36 | self.client.headers = {"Authorization": f"Bearer {api_key}"} 37 | 38 | yield 39 | 40 | # Clean up environment 41 | if "API_KEY" in os.environ: 42 | del os.environ["API_KEY"] 43 | if "OPENAI_API_KEY" in os.environ: 44 | del os.environ["OPENAI_API_KEY"] 45 | if "OPENAI_BASE_URL" in os.environ: 46 | del os.environ["OPENAI_BASE_URL"] 47 | 48 | def test_basic_chat_completion(self): 49 | """Test complete flow from Anthropic request to OpenAI proxy.""" 50 | client = self.client 51 | 52 | payload = { 53 | "model": "claude-3-sonnet-20240229", 54 | "max_tokens": 1024, 55 | "messages": [ 56 | {"role": "user", "content": "Hello, world!"} 57 | ] 58 | } 59 | 60 | response = client.post("/v1/messages", json=payload) 61 | 62 | assert response.status_code == 200 63 | data = response.json() 64 | 65 | # Verify Anthropic response structure 66 | assert "id" in data 67 | assert "type" in data 68 | assert "role" in data 69 | assert "content" in data 70 | assert "model" in data 71 | assert "stop_reason" in data 72 | assert "usage" 
in data 73 | 74 | # Verify content appeared in mock response 75 | assert "Hello, world!" in data["content"][0]["text"] 76 | 77 | def test_chat_completion_with_system_message(self): 78 | """Test chat completion with system message included.""" 79 | client = self.client 80 | 81 | payload = { 82 | "model": "claude-3-sonnet-20240229", 83 | "max_tokens": 1024, 84 | "messages": [ 85 | {"role": "system", "content": "You are a helpful assistant."}, 86 | {"role": "user", "content": "Tell me a joke"} 87 | ] 88 | } 89 | 90 | response = client.post("/v1/messages", json=payload) 91 | 92 | assert response.status_code == 200 93 | data = response.json() 94 | 95 | assert data["role"] == "assistant" 96 | assert "Tell me a joke" in data["content"][0]["text"] 97 | 98 | def test_chat_completion_with_temperature(self): 99 | """Test parameter conversion including temperature.""" 100 | client = self.client 101 | 102 | payload = { 103 | "model": "claude-3-sonnet-20240229", 104 | "max_tokens": 500, 105 | "temperature": 0.7, 106 | "messages": [ 107 | {"role": "user", "content": "Give me a creative response"} 108 | ] 109 | } 110 | 111 | response = client.post("/v1/messages", json=payload) 112 | 113 | assert response.status_code == 200 114 | data = response.json() 115 | 116 | assert response.status_code == 200 117 | assert "creative response" in data["content"][0]["text"] 118 | 119 | def test_invalid_model_error(self): 120 | """Test error handling for invalid model.""" 121 | client = self.client 122 | 123 | # Configure mock to return error 124 | requests.post( 125 | "http://localhost:8001/mock/configure", 126 | json={"error_trigger": "invalid_model"} 127 | ) 128 | 129 | payload = { 130 | "model": "invalid-model-name", 131 | "max_tokens": 1024, 132 | "messages": [{"role": "user", "content": "test"}] 133 | } 134 | 135 | response = client.post("/v1/messages", json=payload) 136 | 137 | assert response.status_code == 400 138 | data = response.json() 139 | 140 | assert "type" in data 141 | assert 
"error" in data 142 | assert data["error"]["type"] == "invalid_request_error" 143 | 144 | def test_rate_limit_error(self): 145 | """Test rate limit error response.""" 146 | client = self.client 147 | 148 | # Configure mock to return rate limit error 149 | requests.post( 150 | "http://localhost:8001/mock/configure", 151 | json={"error_trigger": "rate_limit"} 152 | ) 153 | 154 | payload = { 155 | "model": "claude-3-sonnet-20240229", 156 | "max_tokens": 1024, 157 | "messages": [{"role": "user", "content": "test"}] 158 | } 159 | 160 | response = client.post("/v1/messages", json=payload) 161 | 162 | assert response.status_code == 429 163 | data = response.json() 164 | 165 | assert data["error"]["type"] == "rate_limit_error" 166 | 167 | def test_server_error(self): 168 | """Test server error response.""" 169 | client = self.client 170 | 171 | # Configure mock to return server error 172 | requests.post( 173 | "http://localhost:8001/mock/configure", 174 | json={"error_trigger": "server_error"} 175 | ) 176 | 177 | payload = { 178 | "model": "claude-3-sonnet-20240229", 179 | "max_tokens": 1024, 180 | "messages": [{"role": "user", "content": "test"}] 181 | } 182 | 183 | response = client.post("/v1/messages", json=payload) 184 | 185 | assert response.status_code == 500 186 | data = response.json() 187 | 188 | assert data["error"]["type"].startswith("api") 189 | 190 | def test_invalid_api_key_error(self): 191 | """Test unauthorized error with invalid API key.""" 192 | os.environ["OPENAI_API_KEY"] = "invalid-key" 193 | 194 | client = self.client 195 | 196 | payload = { 197 | "model": "claude-3-sonnet-20240229", 198 | "max_tokens": 1024, 199 | "messages": [{"role": "user", "content": "test"}] 200 | } 201 | 202 | response = client.post("/v1/messages", json=payload) 203 | 204 | assert response.status_code == 401 205 | data = response.json() 206 | 207 | assert data["error"]["type"] == "authentication_error" 208 | 209 | def test_request_validation_error(self): 210 | """Test request 
validation error handling.""" 211 | client = self.client 212 | 213 | # Invalid payload structure 214 | payload = { 215 | "model": "claude-3-sonnet-20240229", 216 | # Missing required 'messages' 217 | "max_tokens": 1024 218 | } 219 | 220 | response = client.post("/v1/messages", json=payload) 221 | 222 | assert response.status_code == 422 223 | data = response.json() 224 | 225 | assert "messages" in data["error"]["type"] or "messages" in str(data) 226 | 227 | def test_empty_message_content(self): 228 | """Test handling of empty message content.""" 229 | client = self.client 230 | 231 | payload = { 232 | "model": "claude-3-sonnet-20240229", 233 | "max_tokens": 1024, 234 | "messages": [{"role": "user", "content": ""}] 235 | } 236 | 237 | response = client.post("/v1/messages", json=payload) 238 | 239 | assert response.status_code == 200 240 | data = response.json() 241 | 242 | assert "content" in data 243 | assert len(data["content"]) > 0 244 | 245 | def test_concurrent_requests(self): 246 | """Test handling multiple concurrent requests.""" 247 | client = self.client 248 | 249 | def make_request(content: str): 250 | payload = { 251 | "model": "claude-3-sonnet-20240229", 252 | "max_tokens": 100, 253 | "messages": [{"role": "user", "content": content}] 254 | } 255 | return client.post("/v1/messages", json=payload) 256 | 257 | # Make 5 concurrent requests 258 | import concurrent.futures 259 | 260 | with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: 261 | futures = [ 262 | executor.submit(make_request, f"Request {i}") 263 | for i in range(5) 264 | ] 265 | 266 | results = [future.result() for future in concurrent.futures.as_completed(futures)] 267 | 268 | # All should succeed 269 | for response in results: 270 | assert response.status_code == 200 271 | data = response.json() 272 | assert "content" in data 273 | 274 | def test_response_timing(self): 275 | """Test response timing is under expected threshold.""" 276 | import time 277 | 278 | client = 
self.client 279 | 280 | # Configure mock for 100ms delay 281 | requests.post( 282 | "http://localhost:8001/mock/configure", 283 | json={"delay_ms": 100} 284 | ) 285 | 286 | payload = { 287 | "model": "claude-3-sonnet-20240229", 288 | "max_tokens": 1024, 289 | "messages": [{"role": "user", "content": "Why is the sky blue?"}] 290 | } 291 | 292 | start_time = time.time() 293 | response = client.post("/v1/messages", json=payload) 294 | end_time = time.time() 295 | 296 | assert response.status_code == 200 297 | duration = (end_time - start_time) * 1000 # Convert to milliseconds 298 | 299 | # Should be under 50ms plus mock delay (adjusted for testing environment) 300 | assert duration < 200 # Conservative threshold including mock overhead 301 | 302 | @pytest.mark.skip 303 | def test_health_check_endpoint(self): 304 | """Test health check endpoint connectivity.""" 305 | client = self.client 306 | 307 | response = client.get("/health") 308 | 309 | assert response.status_code == 200 310 | data = response.json() 311 | assert data["status"] == "running" 312 | 313 | 314 | if __name__ == "__main__": 315 | import uvicorn 316 | uvicorn.run(mock_app, host="0.0.0.0", port=8001) -------------------------------------------------------------------------------- /tests/integration/test_error_handling.py: -------------------------------------------------------------------------------- 1 | """Comprehensive error handling tests for the proxy.""" 2 | 3 | import pytest 4 | import asyncio 5 | from typing import Dict, Any 6 | import requests 7 | from fastapi.testclient import TestClient 8 | from unittest.mock import patch, MagicMock, AsyncMock 9 | import os 10 | 11 | from src.main import app 12 | from src.core.clients.openai_client import OpenAIServiceClient 13 | 14 | 15 | class TestErrorHandlingIntegration: 16 | """Integration tests for various error paths and edge cases.""" 17 | 18 | @pytest.fixture(autouse=True) 19 | def setup(self): 20 | """Set up test environment.""" 21 | 
os.environ["OPENAI_API_KEY"] = "mock-key" 22 | os.environ["OPENAI_BASE_URL"] = "http://localhost:8001" 23 | os.environ["REQUEST_TIMEOUT"] = "2" # Short timeout for testing 24 | 25 | # Reset mock server 26 | requests.delete("http://localhost:8001/mock/reset") 27 | 28 | yield 29 | 30 | # Clean up environment 31 | if "OPENAI_API_KEY" in os.environ: 32 | del os.environ["OPENAI_API_KEY"] 33 | if "OPENAI_BASE_URL" in os.environ: 34 | del os.environ["OPENAI_BASE_URL"] 35 | if "REQUEST_TIMEOUT" in os.environ: 36 | del os.environ["REQUEST_TIMEOUT"] 37 | 38 | def test_network_timeout(self): 39 | """Test network timeout handling.""" 40 | # Configure mock server with long delay 41 | requests.post( 42 | "http://localhost:8001/mock/configure", 43 | json={"delay_ms": 3000} # 3s delay, longer than 2s timeout 44 | ) 45 | 46 | client = TestClient(app) 47 | 48 | payload = { 49 | "model": "claude-3-sonnet-20240229", 50 | "max_tokens": 100, 51 | "messages": [{"role": "user", "content": "Timeout test"}] 52 | } 53 | 54 | response = client.post("/v1/messages", json=payload) 55 | 56 | assert response.status_code == 504 # Gateway timeout 57 | data = response.json() 58 | 59 | assert data["error"]["type"] == "timeout_error" 60 | assert "timeout" in data["error"]["message"].lower() 61 | 62 | def test_connection_refused(self): 63 | """Test connection refused error handling.""" 64 | # Configure invalid base URL 65 | os.environ["OPENAI_BASE_URL"] = "http://localhost:9999" # Invalid port 66 | 67 | client = TestClient(app) 68 | 69 | payload = { 70 | "model": "claude-3-sonnet-20240229", 71 | "max_tokens": 100, 72 | "messages": [{"role": "user", "content": "Connection test"}] 73 | } 74 | 75 | response = client.post("/v1/messages", json=payload) 76 | 77 | assert response.status_code == 503 # Service unavailable 78 | data = response.json() 79 | 80 | assert "connection" in data["error"]["type"] or "service_unavailable" in data["error"]["type"] 81 | 82 | def test_malformed_openai_response(self): 83 | 
"""Test handling of malformed JSON from OpenAI.""" 84 | from unittest.mock import patch 85 | import httpx 86 | 87 | client = TestClient(app) 88 | 89 | # Mock httpx to return invalid JSON 90 | with patch('httpx.AsyncClient.post') as mock_post: 91 | mock_response = MagicMock() 92 | mock_response.status_code = 200 93 | mock_response.json.side_effect = ValueError("Invalid JSON") 94 | mock_response.text = "invalid json response" 95 | mock_post.return_value.__aenter__.return_value = mock_response 96 | 97 | payload = { 98 | "model": "claude-3-sonnet-20240229", 99 | "max_tokens": 100, 100 | "messages": [{"role": "user", "content": "Malformed response test"}] 101 | } 102 | 103 | response = client.post("/v1/messages", json=payload) 104 | 105 | assert response.status_code == 502 # Bad gateway 106 | data = response.json() 107 | 108 | assert data["error"]["type"] == "api_error" 109 | 110 | def test_openai_500_error(self): 111 | """Test handling of 500 errors from OpenAI.""" 112 | client = TestClient(app) 113 | 114 | requests.post( 115 | "http://localhost:8001/mock/configure", 116 | json={"error_trigger": "server_error"} 117 | ) 118 | 119 | payload = { 120 | "model": "claude-3-sonnet-20240229", 121 | "max_tokens": 100, 122 | "messages": [{"role": "user", "content": "Server error test"}] 123 | } 124 | 125 | response = client.post("/v1/messages", json=payload) 126 | 127 | assert response.status_code == 500 128 | data = response.json() 129 | 130 | assert data["error"]["type"] == "api_error" 131 | assert "server" in data["error"]["message"].lower() 132 | 133 | def test_openai_429_rate_limit(self): 134 | """Test handling of rate limit errors from OpenAI.""" 135 | client = TestClient(app) 136 | 137 | requests.post( 138 | "http://localhost:8001/mock/configure", 139 | json={"error_trigger": "rate_limit"} 140 | ) 141 | 142 | payload = { 143 | "model": "claude-3-sonnet-20240229", 144 | "max_tokens": 100, 145 | "messages": [{"role": "user", "content": "Rate limit test"}] 146 | } 147 | 148 
| response = client.post("/v1/messages", json=payload) 149 | 150 | assert response.status_code == 429 151 | data = response.json() 152 | 153 | assert data["error"]["type"] == "rate_limit_error" 154 | 155 | def test_openai_401_unauthorized(self): 156 | """Test handling of unauthorized errors from OpenAI.""" 157 | os.environ["OPENAI_API_KEY"] = "invalid-mock-key" 158 | 159 | client = TestClient(app) 160 | 161 | payload = { 162 | "model": "claude-3-sonnet-20240229", 163 | "max_tokens": 100, 164 | "messages": [{"role": "user", "content": "Unauthorized test"}] 165 | } 166 | 167 | response = client.post("/v1/messages", json=payload) 168 | 169 | assert response.status_code == 401 170 | data = response.json() 171 | 172 | assert data["error"]["type"] == "authentication_error" 173 | 174 | def test_invalid_json_payload(self): 175 | """Test handling of invalid JSON payload.""" 176 | client = TestClient(app) 177 | 178 | # Send raw text instead of JSON 179 | response = client.post( 180 | "/v1/messages", 181 | data="invalid json payload", 182 | headers={"content-type": "application/json"} 183 | ) 184 | 185 | assert response.status_code == 422 186 | data = response.json() 187 | 188 | assert "json" in str(data).lower() or "invalid" in str(data).lower() 189 | 190 | def test_missing_required_fields(self): 191 | """Test validation errors for missing required fields.""" 192 | client = TestClient(app) 193 | 194 | test_cases = [ 195 | # Missing messages 196 | {"model": "claude-3-sonnet-20240229", "max_tokens": 100}, 197 | # Empty messages 198 | {"model": "claude-3-sonnet-20240229", "max_tokens": 100, "messages": []}, 199 | # Missing model 200 | {"max_tokens": 100, "messages": [{"role": "user", "content": "test"}]}, 201 | # Invalid message structure 202 | { 203 | "model": "claude-3-sonnet-20240229", 204 | "max_tokens": 100, 205 | "messages": [{"invalid_role": "test", "content": "test"}] 206 | } 207 | ] 208 | 209 | for payload in test_cases: 210 | response = client.post("/v1/messages", 
json=payload) 211 | assert response.status_code == 422, f"Expected validation error for {payload}" 212 | 213 | def test_invalid_temperature_range(self): 214 | """Test validation of temperature parameter.""" 215 | client = TestClient(app) 216 | 217 | test_cases = [ 218 | -0.1, # Below 0.0 219 | 2.1, # Above 2.0 220 | "abc", # Non-numeric 221 | None # Null (optional should be fine) 222 | ] 223 | 224 | for temp in test_cases: 225 | payload = { 226 | "model": "claude-3-sonnet-20240229", 227 | "max_tokens": 100, 228 | "temperature": temp, 229 | "messages": [{"role": "user", "content": "test"}] 230 | } 231 | 232 | if temp is None or temp == "abc": 233 | # These might pass validation, skip 234 | continue 235 | 236 | response = client.post("/v1/messages", json=payload) 237 | 238 | if temp < 0.0 or temp > 2.0: 239 | assert response.status_code == 422 240 | 241 | def test_invalid_max_tokens(self): 242 | """Test validation of max_tokens parameter.""" 243 | client = TestClient(app) 244 | 245 | test_cases = [ 246 | -1, # Negative 247 | 0, # Zero 248 | 1_000_000, # Too large 249 | "abc", # Non-integer 250 | None # Should be required 251 | ] 252 | 253 | for max_tokens in test_cases: 254 | payload = { 255 | "model": "claude-3-sonnet-20240229", 256 | "max_tokens": max_tokens, 257 | "messages": [{"role": "user", "content": "test"}] 258 | } 259 | 260 | if max_tokens is None or max_tokens == "abc" or max_tokens <= 0: 261 | response = client.post("/v1/messages", json=payload) 262 | if max_tokens is None or max_tokens <= 0: 263 | assert response.status_code == 422 264 | 265 | def test_large_payload_handling(self): 266 | """Test handling of very large payloads.""" 267 | client = TestClient(app) 268 | 269 | # Create large message content 270 | large_content = "x" * 100_000 # 100KB of content 271 | 272 | payload = { 273 | "model": "claude-3-sonnet-20240229", 274 | "max_tokens": 100, 275 | "messages": [ 276 | {"role": "user", "content": large_content}, 277 | {"role": "assistant", "content": 
"That's a large payload"}, 278 | {"role": "user", "content": "Yes it is"} 279 | ] 280 | } 281 | 282 | response = client.post("/v1/messages", json=payload) 283 | 284 | # Should handle large content gracefully (might succeed or validation error) 285 | assert response.status_code in [200, 422] 286 | 287 | def test_special_characters_in_content(self): 288 | """Test handling of special characters in message content.""" 289 | client = TestClient(app) 290 | 291 | special_contents = [ 292 | "Unicode: 你好世界 🌍", 293 | "Emojis: 🚀✈️🚁🛸", 294 | "Newlines: Hello\nWorld\nTest", 295 | "Quotes: \"Hello\" and 'World'", 296 | "HTML:
test
", 297 | "JSON: {'key': 'value'}", 298 | "Code: `print('hello')`", 299 | "Long text: " + "word " * 100 300 | ] 301 | 302 | for content in special_contents: 303 | payload = { 304 | "model": "claude-3-sonnet-20240229", 305 | "max_tokens": 100, 306 | "messages": [{"role": "user", "content": content}] 307 | } 308 | 309 | response = client.post("/v1/messages", json=payload) 310 | assert response.status_code == 200, f"Failed for content: {content[:50]}..." 311 | 312 | def test_streaming_timeout(self): 313 | """Test streaming timeout handling.""" 314 | # Configure mock server with long delay for streaming 315 | requests.post( 316 | "http://localhost:8001/mock/configure", 317 | json={"delay_ms": 3000} 318 | ) 319 | 320 | client = TestClient(app) 321 | 322 | payload = { 323 | "model": "claude-3-sonnet-20240229", 324 | "max_tokens": 100, 325 | "stream": True, 326 | "messages": [{"role": "user", "content": "Streaming timeout test"}] 327 | } 328 | 329 | with client.stream("POST", "/v1/messages", json=payload) as response: 330 | # Streaming timeout might not be caught until we read 331 | assert response.status_code == 200 # Headers sent immediately 332 | 333 | # Collect chunks 334 | chunks = [] 335 | for line in response.iter_lines(): 336 | if line and line.startswith("data: "): 337 | data_str = line[6:] 338 | if data_str.strip() == "[DONE]": 339 | break 340 | 341 | try: 342 | data = json.loads(data_str) 343 | chunks.append(data) 344 | except Exception: 345 | # Might get timeout error in streaming chunks 346 | pass 347 | 348 | def test_concurrent_error_handling(self): 349 | """Test error handling with concurrent requests.""" 350 | import concurrent.futures 351 | import time 352 | 353 | client = TestClient(app) 354 | 355 | def make_request(request_num): 356 | if request_num % 2 == 0: 357 | # Valid request 358 | payload = { 359 | "model": "claude-3-sonnet-20240229", 360 | "max_tokens": 100, 361 | "messages": [{"role": "user", "content": f"Valid request {request_num}"}] 362 | } 
363 | else: 364 | # Invalid request (missing messages) 365 | payload = { 366 | "model": "claude-3-sonnet-20240229", 367 | "max_tokens": 100 368 | } 369 | 370 | return client.post("/v1/messages", json=payload) 371 | 372 | # Make concurrent requests with mix of valid and invalid 373 | with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: 374 | futures = [executor.submit(make_request, i) for i in range(10)] 375 | results = [future.result() for future in concurrent.futures.as_completed(futures)] 376 | 377 | # Count success vs error responses 378 | success_count = sum(1 for r in results if r.status_code == 200) 379 | error_count = sum(1 for r in results if r.status_code != 200) 380 | 381 | # Should have both success and error responses 382 | assert success_count == 5 # Even-numbered requests 383 | assert error_count == 5 # Odd-numbered requests -------------------------------------------------------------------------------- /tests/integration/test_health_endpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Health check endpoint integration tests. 
3 | """ 4 | import pytest 5 | from httpx import AsyncClient 6 | from fastapi import FastAPI 7 | from unittest.mock import AsyncMock, patch 8 | 9 | from src.main import app 10 | 11 | 12 | class TestHealthEndpoint: 13 | """Test cases for the health check endpoint.""" 14 | 15 | @pytest.fixture 16 | def test_app(self) -> FastAPI: 17 | """Create test FastAPI app.""" 18 | return app 19 | 20 | @pytest.mark.asyncio 21 | async def test_health_check_success(self, test_app): 22 | """Test successful health check with OpenAI service available.""" 23 | 24 | mock_health_result = { 25 | "openai_service": True, 26 | "api_accessible": True, 27 | "last_check": True 28 | } 29 | 30 | with patch( 31 | "src.clients.openai_client.OpenAIServiceClient.health_check", 32 | new_callable=AsyncMock 33 | ) as mock_health_check: 34 | mock_health_check.return_value = mock_health_result 35 | 36 | async with AsyncClient(app=test_app, base_url="http://test") as ac: 37 | response = await ac.get("/health") 38 | 39 | assert response.status_code == 200 40 | data = response.json() 41 | 42 | assert data["status"] == "healthy" 43 | assert data["service"] == "openai-to-claude" 44 | assert "timestamp" in data 45 | assert data["checks"]["openai"] == mock_health_result 46 | 47 | @pytest.mark.asyncio 48 | async def test_health_check_degraded(self, test_app): 49 | """Test health check when OpenAI service is degraded.""" 50 | 51 | mock_health_result = { 52 | "openai_service": False, 53 | "api_accessible": True, 54 | "last_check": True 55 | } 56 | 57 | with patch( 58 | "src.clients.openai_client.OpenAIServiceClient.health_check", 59 | new_callable=AsyncMock 60 | ) as mock_health_check: 61 | mock_health_check.return_value = mock_health_result 62 | 63 | async with AsyncClient(app=test_app, base_url="http://test") as ac: 64 | response = await ac.get("/health") 65 | 66 | assert response.status_code == 200 67 | data = response.json() 68 | 69 | assert data["status"] == "degraded" 70 | assert data["service"] == 
"openai-to-claude" 71 | assert data["checks"]["openai"] == mock_health_result 72 | 73 | @pytest.mark.asyncio 74 | async def test_health_check_unhealthy(self, test_app): 75 | """Test health check when OpenAI service is unavailable.""" 76 | 77 | with patch( 78 | "src.clients.openai_client.OpenAIServiceClient.health_check", 79 | new_callable=AsyncMock 80 | ) as mock_health_check: 81 | mock_health_check.side_effect = Exception("Connection failed") 82 | 83 | async with AsyncClient(app=test_app, base_url="http://test") as ac: 84 | response = await ac.get("/health") 85 | 86 | assert response.status_code == 200 87 | data = response.json() 88 | 89 | assert data["status"] == "unhealthy" 90 | assert data["service"] == "openai-to-claude" 91 | assert data["checks"]["openai"]["openai_service"] is False 92 | assert data["checks"]["openai"]["api_accessible"] is False 93 | assert "error" in data["checks"]["openai"] 94 | 95 | @pytest.mark.asyncio 96 | async def test_health_check_response_structure(self, test_app): 97 | """Test health check response structure and required fields.""" 98 | 99 | mock_health_result = { 100 | "openai_service": True, 101 | "api_accessible": True, 102 | "last_check": True 103 | } 104 | 105 | with patch( 106 | "src.clients.openai_client.OpenAIServiceClient.health_check", 107 | new_callable=AsyncMock 108 | ) as mock_health_check: 109 | mock_health_check.return_value = mock_health_result 110 | 111 | async with AsyncClient(app=test_app, base_url="http://test") as ac: 112 | response = await ac.get("/health") 113 | 114 | assert response.status_code == 200 115 | assert response.headers["content-type"] == "application/json" 116 | 117 | data = response.json() 118 | 119 | # Check required top-level fields 120 | assert "status" in data 121 | assert "service" in data 122 | assert "timestamp" in data 123 | assert "checks" in data 124 | 125 | # Check status values 126 | assert data["status"] in ["healthy", "degraded", "unhealthy"] 127 | assert data["service"] == 
"openai-to-claude" 128 | 129 | # Check timestamp format 130 | from datetime import datetime 131 | datetime.fromisoformat(data["timestamp"].replace('Z', '+00:00')) 132 | 133 | # Check checks structure 134 | assert isinstance(data["checks"], dict) 135 | assert "openai" in data["checks"] -------------------------------------------------------------------------------- /tests/integration/test_messages_endpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | _integration tests for /v1/messages endpoint 3 | 4 | ٛKՌ�t*�B-͔A �pnlb�OpenAI API� 5 | """ 6 | 7 | import pytest 8 | from fastapi.testclient import TestClient 9 | from unittest.mock import AsyncMock, patch, MagicMock 10 | import json 11 | 12 | from src.main import app 13 | from src.models.anthropic import AnthropicRequest 14 | from src.models.openai import OpenAIRequest, OpenAIResponse 15 | 16 | 17 | class TestMessagesEndpoint: 18 | """K� /v1/messages ﹟�""" 19 | 20 | @pytest.fixture 21 | def client(self): 22 | """�Kբ7�""" 23 | return TestClient(app) 24 | 25 | @pytest.fixture 26 | def valid_anthropic_request(self): 27 | """ H�Anthropic�B:�""" 28 | return { 29 | "model": "claude-3-5-sonnet-20241022", 30 | "max_tokens": 100, 31 | "messages": [ 32 | { 33 | "role": "user", 34 | "content": "Hello, how are you?" 35 | } 36 | ] 37 | } 38 | 39 | @pytest.fixture 40 | def mock_openai_response(self): 41 | """Mock�OpenAI͔""" 42 | return { 43 | "id": "chatcmpl-123", 44 | "object": "chat.completion", 45 | "created": 1700000000, 46 | "model": "gpt-3.5-turbo", 47 | "choices": [ 48 | { 49 | "index": 0, 50 | "message": { 51 | "role": "assistant", 52 | "content": "I'm doing well! How can I help you today?" 
53 | }, 54 | "finish_reason": "stop" 55 | } 56 | ], 57 | "usage": { 58 | "prompt_tokens": 10, 59 | "completion_tokens": 15, 60 | "total_tokens": 25 61 | } 62 | } 63 | 64 | def test_endpoint_exists(self, client): 65 | """���X(""" 66 | response = client.post("/v1/messages", json={"model": "test"}) 67 | assert response.status_code != 404 68 | 69 | def test_invalid_request_format(self, client): 70 | """K��H�B<""" 71 | response = client.post("/v1/messages", json={"invalid": "format"}) 72 | assert response.status_code == 422 73 | 74 | def test_missing_required_fields(self, client): 75 | """K�:�kW�""" 76 | response = client.post("/v1/messages", json={"model": "test"}) 77 | assert response.status_code == 422 78 | 79 | def test_empty_messages(self, client): 80 | """K�z�oh""" 81 | request = { 82 | "model": "claude-3-5-sonnet-20241022", 83 | "messages": [], 84 | "max_tokens": 100 85 | } 86 | response = client.post("/v1/messages", json=request) 87 | assert response.status_code == 422 88 | 89 | @patch('src.handlers.messages.MessagesHandler.process_message') 90 | @pytest.mark.asyncio 91 | async def test_successful_non_streaming_request( 92 | self, mock_process, client, valid_anthropic_request, mock_openai_response 93 | ): 94 | """K���^A�B""" 95 | from src.models.anthropic import AnthropicMessageResponse 96 | 97 | # ���Anthropic͔ 98 | expected_response = AnthropicMessageResponse( 99 | id="msg_123", 100 | type="message", 101 | role="assistant", 102 | content=[{"type": "text", "text": "I'm doing well! 
How can I help you today?"}], 103 | model="claude-3-5-sonnet-20241022", 104 | stop_reason="end_turn", 105 | usage={"input_tokens": 10, "output_tokens": 15} 106 | ) 107 | 108 | mock_process.return_value = expected_response 109 | 110 | response = client.post("/v1/messages", json=valid_anthropic_request) 111 | 112 | assert response.status_code == 200 113 | response_data = response.json() 114 | 115 | assert response_data["type"] == "message" 116 | assert response_data["role"] == "assistant" 117 | assert len(response_data["content"]) > 0 118 | assert response_data["content"][0]["type"] == "text" 119 | assert response_data["model"] == "claude-3-5-sonnet-20241022" 120 | assert "usage" in response_data 121 | 122 | @patch('src.handlers.messages.MessagesHandler.process_stream_message') 123 | @pytest.mark.asyncio 124 | async def test_successful_streaming_request( 125 | self, mock_stream, client, valid_anthropic_request, mock_openai_response 126 | ): 127 | """K���A�B""" 128 | valid_anthropic_request["stream"] = True 129 | 130 | # !�A͔ 131 | async def mock_generator(): 132 | yield 'event: content_block_start\ndata: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}}\n\n' 133 | yield 'event: content_block_delta\ndata: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}\n\n' 134 | yield 'event: message_stop\ndata: {"type": "message_stop"}\n\n' 135 | 136 | mock_stream.return_value = mock_generator() 137 | 138 | response = client.post("/v1/messages", json=valid_anthropic_request) 139 | 140 | assert response.status_code == 200 141 | assert response.headers["content-type"] == "text/event-stream; charset=utf-8" 142 | 143 | lines = response.content.decode('utf-8').split('\\n') 144 | assert any('content_block_start' in line for line in lines) 145 | 146 | def test_system_message_support(self, client): 147 | """K��߈o/""" 148 | request = { 149 | "model": "claude-3-5-sonnet-20241022", 150 | "messages": 
[{"role": "user", "content": "Hello"}], 151 | "system": [{ 152 | "type": "text", 153 | "text": "You are a helpful assistant" 154 | }], 155 | "max_tokens": 100 156 | } 157 | 158 | response = client.post("/v1/messages", json=request) 159 | assert response.status_code != 400 160 | 161 | def test_tool_definition_support(self, client): 162 | """K��w�I/""" 163 | request = { 164 | "model": "claude-3-5-sonnet-20241022", 165 | "messages": [{"role": "user", "content": "Hello"}], 166 | "tools": [ 167 | { 168 | "name": "get_weather", 169 | "description": "Get weather information", 170 | "input_schema": { 171 | "type": "object", 172 | "properties": { 173 | "location": {"type": "string"} 174 | }, 175 | "required": ["location"] 176 | } 177 | } 178 | ], 179 | "max_tokens": 100 180 | } 181 | 182 | response = client.post("/v1/messages", json=request) 183 | assert response.status_code != 400 184 | 185 | def test_temperature_parameter(self, client): 186 | """K�temperature�p""" 187 | request = { 188 | "model": "claude-3-5-sonnet-20241022", 189 | "messages": [{"role": "user", "content": "Hello"}], 190 | "temperature": 0.5, 191 | "max_tokens": 100 192 | } 193 | 194 | response = client.post("/v1/messages", json=request) 195 | # ��nj� �E��1converter 196 | assert response.status_code != 422 197 | 198 | def test_max_tokens_parameter(self, client): 199 | """K�max_tokens�p""" 200 | request = { 201 | "model": "claude-3-5-sonnet-20241022", 202 | "messages": [{"role": "user", "content": "Hello"}], 203 | "max_tokens": 1000 204 | } 205 | 206 | response = client.post("/v1/messages", json=request) 207 | assert response.status_code != 422 208 | 209 | def test_top_p_parameter(self, client): 210 | """K�top_p�p""" 211 | request = { 212 | "model": "claude-3-5-sonnet-20241022", 213 | "messages": [{"role": "user", "content": "Hello"}], 214 | "top_p": 0.9, 215 | "max_tokens": 100 216 | } 217 | 218 | response = client.post("/v1/messages", json=request) 219 | assert response.status_code != 422 220 | 221 | def 
test_stop_sequences_parameter(self, client): 222 | """K�stop_sequences�p""" 223 | request = { 224 | "model": "claude-3-5-sonnet-20241022", 225 | "messages": [{"role": "user", "content": "Hello"}], 226 | "stop_sequences": ["\n\n"], 227 | "max_tokens": 100 228 | } 229 | 230 | response = client.post("/v1/messages", json=request) 231 | assert response.status_code != 422 232 | 233 | def test_metadata_parameter(self, client): 234 | """K�metadata�p""" 235 | request = { 236 | "model": "claude-3-5-sonnet-20241022", 237 | "messages": [{"role": "user", "content": "Hello"}], 238 | "metadata": {"trace_id": "123456"}, 239 | "max_tokens": 100 240 | } 241 | 242 | response = client.post("/v1/messages", json=request) 243 | assert response.status_code != 422 244 | 245 | def test_image_content_support(self, client): 246 | """K��υ�/""" 247 | request = { 248 | "model": "claude-3-5-sonnet-20241022", 249 | "messages": [ 250 | { 251 | "role": "user", 252 | "content": [ 253 | {"type": "text", "text": "What's in this image?"}, 254 | { 255 | "type": "image", 256 | "source": { 257 | "type": "base64", 258 | "media_type": "image/jpeg", 259 | "data": "base64_data_here" 260 | } 261 | } 262 | ] 263 | } 264 | ], 265 | "max_tokens": 100 266 | } 267 | 268 | response = client.post("/v1/messages", json=request) 269 | assert response.status_code != 422 270 | 271 | @patch.dict('src.models.config.Config', {'openai': {'api_key': 'test-key'}}) 272 | def test_environment_variables(self, client): 273 | """Kկ���Mn""" 274 | # ��Mn��� 275 | response = client.post("/v1/messages", json={ 276 | "model": "claude-3-5-sonnet-20241022", 277 | "messages": [{"role": "user", "content": "Hello"}], 278 | "max_tokens": 100 279 | }) 280 | 281 | # Config import 282 | assert response.status_code != 500 283 | 284 | def test_error_response_format(self, client): 285 | """K��͔<&Anthropic�""" 286 | response = client.post("/v1/messages", json={"invalid": "format"}) 287 | 288 | assert response.status_code == 422 289 | error_data = 
response.json() 290 | 291 | # ���<+�W� 292 | assert "type" in error_data 293 | assert "error" in error_data 294 | assert "type" in error_data["error"] 295 | assert "message" in error_data["error"] 296 | 297 | def test_all_required_fields_present(self, client): 298 | """K��kW��t'""" 299 | required_fields = [ 300 | "model", 301 | "messages", 302 | "max_tokens" 303 | ] 304 | 305 | for field in required_fields: 306 | base_request = { 307 | "model": "claude-3-5-sonnet-20241022", 308 | "messages": [{"role": "user", "content": "Hello"}], 309 | "max_tokens": 100 310 | } 311 | 312 | # �d�kW� 313 | del base_request[field] 314 | 315 | response = client.post("/v1/messages", json=base_request) 316 | assert response.status_code == 422 -------------------------------------------------------------------------------- /tests/integration/test_model_mapping_integration.py: -------------------------------------------------------------------------------- 1 | """ 2 | 集成测试:验证模型映射的完整功能 3 | """ 4 | import json 5 | from src.models.anthropic import AnthropicRequest, AnthropicMessage 6 | from src.models.openai import OpenAIRequest 7 | from src.core.converters.request_converter import AnthropicToOpenAIConverter 8 | 9 | 10 | class TestModelMappingIntegration: 11 | """集成测试:验证从Anthropic到OpenAI的完整转换""" 12 | 13 | def test_convert_with_thinking_model(self): 14 | """使用thinking模型的完整转换测试""" 15 | anthropic_request = AnthropicRequest( 16 | model="claude-3-5-sonnet-20241022", 17 | messages=[AnthropicMessage(role="user", content="hello")], 18 | max_tokens=100, 19 | thinking=True 20 | ) 21 | 22 | openai_request = AnthropicToOpenAIConverter.convert_anthropic_to_openai(anthropic_request) 23 | 24 | assert isinstance(openai_request, OpenAIRequest) 25 | assert openai_request.model == "claude-3-7-sonnet-thinking" 26 | assert len(openai_request.messages) == 1 27 | assert openai_request.messages[0].content == "hello" 28 | assert openai_request.max_tokens == 100 29 | 30 | def test_convert_with_sonnet_model(self): 31 
| """使用sonnet模型的完整转换测试""" 32 | anthropic_request = AnthropicRequest( 33 | model="claude-3-5-sonnet-20241022", 34 | messages=[AnthropicMessage(role="user", content="hello")], 35 | max_tokens=100, 36 | thinking=None 37 | ) 38 | 39 | openai_request = AnthropicToOpenAIConverter.convert_anthropic_to_openai(anthropic_request) 40 | 41 | assert isinstance(openai_request, OpenAIRequest) 42 | assert openai_request.model == "claude-3-5-sonnet" 43 | 44 | def test_convert_with_haiku_model(self): 45 | """使用haiku模型的完整转换测试""" 46 | anthropic_request = AnthropicRequest( 47 | model="claude-3-5-haiku", 48 | messages=[AnthropicMessage(role="user", content="hello")], 49 | max_tokens=100, 50 | thinking=None 51 | ) 52 | 53 | openai_request = AnthropicToOpenAIConverter.convert_anthropic_to_openai(anthropic_request) 54 | 55 | assert isinstance(openai_request, OpenAIRequest) 56 | assert openai_request.model == "claude-3-5-haiku" 57 | 58 | def test_convert_with_system_support(self): 59 | """带系统提示的完整转换测试""" 60 | anthropic_request = AnthropicRequest( 61 | model="claude-3-5-sonnet", 62 | messages=[AnthropicMessage(role="user", content="hello")], 63 | max_tokens=100, 64 | system="你是一个有用的助手", 65 | thinking=True 66 | ) 67 | 68 | openai_request = AnthropicToOpenAIConverter.convert_anthropic_to_openai(anthropic_request) 69 | 70 | assert openai_request.model == "claude-3-7-sonnet-thinking" 71 | assert openai_request.messages[0].role == "system" 72 | assert openai_request.messages[0].content == "你是一个有用的助手" 73 | 74 | def test_convert_with_tools(self): 75 | """带工具定义的完整转换测试""" 76 | tools = [ 77 | { 78 | "name": "get_weather", 79 | "description": "获取天气信息", 80 | "input_schema": { 81 | "type": "object", 82 | "properties": { 83 | "location": {"type": "string"}, 84 | "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]} 85 | }, 86 | "required": ["location"] 87 | } 88 | } 89 | ] 90 | 91 | anthropic_request = AnthropicRequest( 92 | model="claude-3-5-haiku", 93 | messages=[AnthropicMessage(role="user", 
content="天气如何")], 94 | max_tokens=100, 95 | tools=tools, 96 | thinking=None 97 | ) 98 | 99 | openai_request = AnthropicToOpenAIConverter.convert_anthropic_to_openai(anthropic_request) 100 | 101 | assert openai_request.model == "claude-3-5-haiku" 102 | assert openai_request.tools is not None 103 | assert len(openai_request.tools) == 1 104 | assert openai_request.tools[0].function.name == "get_weather" 105 | 106 | def test_model_passthrough(self): 107 | """测试带有逗号的模型名不被转换""" 108 | anthropic_request = AnthropicRequest( 109 | model="claude-custom,variant", 110 | messages=[AnthropicMessage(role="user", content="hello")], 111 | max_tokens=100 112 | ) 113 | 114 | # thinking字段存在,但因为模型名有逗号,应该保留原样 115 | openai_request = AnthropicToOpenAIConverter.convert_anthropic_to_openai(anthropic_request) 116 | 117 | assert openai_request.model == "claude-custom,variant" 118 | 119 | def test_serialization_roundtrip(self): 120 | """测试序列化和反序列化""" 121 | original_request = { 122 | "model": "claude-3-5-sonnet", 123 | "messages": [{"role": "user", "content": "hello"}], 124 | "max_tokens": 100, 125 | "thinking": True 126 | } 127 | 128 | anthropic_request = AnthropicRequest(**original_request) 129 | openai_request = AnthropicToOpenAIConverter.convert_anthropic_to_openai(anthropic_request) 130 | 131 | # 确保可以序列化为JSON 132 | openai_dict = openai_request.dict() 133 | assert openai_dict["model"] == "claude-3-7-sonnet-thinking" 134 | assert openai_dict["messages"][0]["role"] == "user" 135 | assert openai_dict["messages"][0]["content"] == "hello" -------------------------------------------------------------------------------- /tests/integration/test_streaming.py: -------------------------------------------------------------------------------- 1 | """Streaming response tests for end-to-end testing.""" 2 | 3 | import pytest 4 | import requests 5 | from fastapi.testclient import TestClient 6 | import os 7 | 8 | from src.main import app 9 | 10 | 11 | class TestStreamingIntegration: 12 | """Tests for 
streaming response handling in the proxy.""" 13 | 14 | @pytest.fixture(autouse=True) 15 | def setup(self): 16 | """Set up test environment.""" 17 | os.environ["OPENAI_API_KEY"] = "mock-key" 18 | os.environ["OPENAI_BASE_URL"] = "http://localhost:8001" 19 | 20 | # Reset mock server 21 | requests.delete("http://localhost:8001/mock/reset") 22 | 23 | yield 24 | 25 | # Clean up environment 26 | if "OPENAI_API_KEY" in os.environ: 27 | del os.environ["OPENAI_API_KEY"] 28 | if "OPENAI_BASE_URL" in os.environ: 29 | del os.environ["OPENAI_BASE_URL"] 30 | 31 | def test_streaming_chat_completion(self): 32 | """Test complete streaming flow from Anthropic to OpenAI proxy.""" 33 | client = TestClient(app) 34 | 35 | payload = { 36 | "model": "claude-3-sonnet-20240229", 37 | "max_tokens": 1024, 38 | "stream": True, 39 | "messages": [ 40 | {"role": "user", "content": "Write a short poem about Python"} 41 | ] 42 | } 43 | 44 | with client.stream("POST", "/v1/messages", json=payload) as response: 45 | assert response.status_code == 200 46 | assert response.headers["content-type"] == "text/plain; charset=utf-8" 47 | 48 | chunks = [] 49 | for line in response.iter_lines(): 50 | if line and line.startswith("data: "): 51 | data_str = line[6:] # Remove "data: " prefix 52 | if data_str.strip() == "[DONE]": 53 | break 54 | 55 | try: 56 | data = json.loads(data_str) 57 | chunks.append(data) 58 | except json.JSONDecodeError: 59 | continue 60 | 61 | # Verify we received some chunks 62 | assert len(chunks) > 0 63 | 64 | # Check structure of first chunk 65 | first_chunk = chunks[0] 66 | assert "type" in first_chunk 67 | assert first_chunk["type"] == "content_block_start" 68 | assert "content_block" in first_chunk 69 | 70 | # Check structure of content chunks 71 | content_chunks = [c for c in chunks if c.get("type") == "content_block_delta"] 72 | assert len(content_chunks) > 0 73 | 74 | # Check final chunk 75 | final_chunks = [c for c in chunks if c.get("type") == "message_stop"] 76 | assert 
len(final_chunks) == 1 77 | 78 | def test_streaming_with_system_message(self): 79 | """Test streaming with system message included.""" 80 | client = TestClient(app) 81 | 82 | payload = { 83 | "model": "claude-3-sonnet-20240229", 84 | "max_tokens": 512, 85 | "stream": True, 86 | "temperature": 0.8, 87 | "messages": [ 88 | {"role": "system", "content": "You are a Python programming assistant."}, 89 | {"role": "user", "content": "Explain decorators"} 90 | ] 91 | } 92 | 93 | with client.stream("POST", "/v1/messages", json=payload) as response: 94 | assert response.status_code == 200 95 | 96 | chunks = [] 97 | for line in response.iter_lines(): 98 | if line and line.startswith("data: "): 99 | data_str = line[6:] 100 | if data_str.strip() == "[DONE]": 101 | break 102 | 103 | try: 104 | data = json.loads(data_str) 105 | chunks.append(data) 106 | except json.JSONDecodeError: 107 | continue 108 | 109 | # Should still receive chunks even with system message 110 | assert len(chunks) > 2 111 | 112 | # Verify we got the message start 113 | start_chunks = [c for c in chunks if c.get("type") == "message_start"] 114 | assert len(start_chunks) == 1 115 | 116 | # Verify we got content blocks 117 | content_start = [c for c in chunks if c.get("type") == "content_block_start"] 118 | assert len(content_start) == 1 119 | 120 | def test_streaming_empty_content_handling(self): 121 | """Test streaming with empty or minimal content.""" 122 | client = TestClient(app) 123 | 124 | payload = { 125 | "model": "claude-3-sonnet-20240229", 126 | "max_tokens": 10, 127 | "stream": True, 128 | "messages": [ 129 | {"role": "user", "content": ""} 130 | ] 131 | } 132 | 133 | with client.stream("POST", "/v1/messages", json=payload) as response: 134 | assert response.status_code == 200 135 | 136 | chunks = [] 137 | for line in response.iter_lines(): 138 | if line and line.startswith("data: "): 139 | data_str = line[6:] 140 | if data_str.strip() == "[DONE]": 141 | break 142 | 143 | try: 144 | data = 
json.loads(data_str) 145 | chunks.append(data) 146 | except json.JSONDecodeError: 147 | continue 148 | 149 | # Should still receive proper streaming structure 150 | assert len(chunks) >= 3 # message_start, content_block_start, content_block_delta, message_stop 151 | 152 | def test_streaming_usage_information(self): 153 | """Test that usage information is included in streaming response.""" 154 | client = TestClient(app) 155 | 156 | payload = { 157 | "model": "claude-3-sonnet-20240229", 158 | "max_tokens": 100, 159 | "stream": True, 160 | "messages": [ 161 | {"role": "user", "content": "Count to 5"} 162 | ] 163 | } 164 | 165 | with client.stream("POST", "/v1/messages", json=payload) as response: 166 | assert response.status_code == 200 167 | 168 | # Collect all chunks to ensure complete flow 169 | chunks = [] 170 | for line in response.iter_lines(): 171 | if line and line.startswith("data: "): 172 | data_str = line[6:] 173 | if data_str.strip() == "[DONE]": 174 | break 175 | 176 | try: 177 | data = json.loads(data_str) 178 | chunks.append(data) 179 | except json.JSONDecodeError: 180 | continue 181 | 182 | # Look for message_delta with usage info 183 | usage_chunks = [c for c in chunks if c.get("type") == "message_delta"] 184 | assert len(usage_chunks) >= 1 185 | 186 | final_chunk = usage_chunks[0] 187 | assert "usage" in final_chunk 188 | usage = final_chunk["usage"] 189 | assert "input_tokens" in usage 190 | assert "output_tokens" in usage 191 | 192 | def test_streaming_response_timing(self): 193 | """Test streaming response timing performance.""" 194 | import time 195 | 196 | client = TestClient(app) 197 | 198 | payload = { 199 | "model": "claude-3-sonnet-20240229", 200 | "max_tokens": 50, 201 | "stream": True, 202 | "messages": [ 203 | {"role": "user", "content": "Say 'hello' three times"} 204 | ] 205 | } 206 | 207 | start_time = time.time() 208 | 209 | with client.stream("POST", "/v1/messages", json=payload) as response: 210 | assert response.status_code == 200 
211 | 212 | chunks = [] 213 | for line in response.iter_lines(): 214 | if line and line.startswith("data: "): 215 | data_str = line[6:] 216 | if data_str.strip() == "[DONE]": 217 | break 218 | 219 | try: 220 | data = json.loads(data_str) 221 | chunks.append(data) 222 | except json.JSONDecodeError: 223 | continue 224 | 225 | end_time = time.time() 226 | duration = (end_time - start_time) * 1000 227 | 228 | # Ensure reasonable performance for streaming 229 | assert duration < 1000 # Should complete within 1 second for test 230 | assert len(chunks) > 3 # Should have multiple streaming chunks 231 | 232 | def test_streaming_with_stream_false_comparison(self): 233 | """Test that streaming works differently from non-streaming.""" 234 | client = TestClient(app) 235 | 236 | request_data = { 237 | "model": "claude-3-sonnet-20240229", 238 | "max_tokens": 50, 239 | "messages": [ 240 | {"role": "user", "content": "Reply with a single word: Python"} 241 | ] 242 | } 243 | 244 | # Non-streaming version 245 | non_stream_payload = request_data.copy() 246 | non_stream_payload["stream"] = False 247 | sync_response = client.post("/v1/messages", json=non_stream_payload) 248 | assert sync_response.status_code == 200 249 | 250 | sync_data = sync_response.json() 251 | assert "content" in sync_data 252 | 253 | # Streaming version 254 | stream_payload = request_data.copy() 255 | stream_payload["stream"] = True 256 | 257 | with client.stream("POST", "/v1/messages", json=stream_payload) as response: 258 | assert response.status_code == 200 259 | 260 | # Should be server-sent events format 261 | assert response.headers["content-type"] == "text/plain; charset=utf-8" 262 | 263 | # Collect streaming content 264 | collected_content = "" 265 | chunks = [] 266 | 267 | for line in response.iter_lines(): 268 | if line and line.startswith("data: "): 269 | data_str = line[6:] 270 | if data_str.strip() == "[DONE]": 271 | break 272 | 273 | try: 274 | data = json.loads(data_str) 275 | chunks.append(data) 
276 | 277 | # Extract text content from delta chunks 278 | if data.get("type") == "content_block_delta": 279 | delta = data.get("delta", {}) 280 | if "text" in delta: 281 | collected_content += delta["text"] 282 | except json.JSONDecodeError: 283 | continue 284 | 285 | # Both should return similar content, streaming breaks it into chunks 286 | assert len(collected_content) > 0 287 | assert "Python" in collected_content or "python" in collected_content.lower() --------------------------------------------------------------------------------