├── .dockerignore
├── .gitignore
├── CLAUDE.md
├── Dockerfile
├── README.md
├── README_zh.md
├── config
│   └── example.json
├── docker-compose.yml
├── main.py
├── pyproject.toml
├── src
│   ├── __init__.py
│   ├── api
│   │   ├── __init__.py
│   │   ├── handlers.py
│   │   ├── middleware
│   │   │   ├── __init__.py
│   │   │   ├── auth.py
│   │   │   └── timing.py
│   │   └── routes.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── logging.py
│   │   ├── token_cache.py
│   │   └── token_counter.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── settings.py
│   │   └── watcher.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── clients
│   │   │   ├── __init__.py
│   │   │   └── openai_client.py
│   │   └── converters
│   │       ├── __init__.py
│   │       ├── request_converter.py
│   │       ├── response_converter.py
│   │       └── stream_converters.py
│   ├── main.py
│   └── models
│       ├── __init__.py
│       ├── anthropic.py
│       ├── errors.py
│       └── openai.py
├── tests
│   ├── __init__.py
│   ├── fixtures.py
│   └── integration
│       ├── test_end_to_end.py
│       ├── test_error_handling.py
│       ├── test_health_endpoint.py
│       ├── test_messages_endpoint.py
│       ├── test_model_mapping_integration.py
│       └── test_streaming.py
└── uv.lock
/.dockerignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | *.egg-info/
21 | .installed.cfg
22 | *.egg
23 |
24 | # Virtual environments
25 | .env
26 | .venv
27 | env/
28 | venv/
29 | ENV/
30 | env.bak/
31 | venv.bak/
32 |
33 | # IDEs and editors
34 | .vscode/
35 | .idea/
36 | *.swp
37 | *.swo
38 | *~
39 |
40 | # OS generated files
41 | .DS_Store
42 | .DS_Store?
43 | ._*
44 | .Spotlight-V100
45 | .Trashes
46 | ehthumbs.db
47 | Thumbs.db
48 |
49 | # Git
50 | .git/
51 | .gitignore
52 |
53 | # Log files
54 | logs/
55 | *.log
56 |
57 | # Testing and coverage
58 | .coverage
59 | .pytest_cache/
60 | .tox/
61 | .nox/
62 | htmlcov/
63 | .coverage.*
64 | coverage.xml
65 | *.cover
66 | *.py,cover
67 | .hypothesis/
68 |
69 | # Documentation
70 | docs/_build/
71 | .readthedocs.yml
72 |
73 | # Temporary files
74 | *.tmp
75 | *.temp
76 | .cache/
77 |
78 | # Configuration files
79 | config/settings.json
80 |
81 | # Docker
82 | Dockerfile*
83 | .dockerignore
84 | docker-compose*.yml
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .nox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | *.py,cover
48 | .hypothesis/
49 | .pytest_cache/
50 | cover/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 | db.sqlite3-journal
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | .pybuilder/
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # pipenv
87 | Pipfile.lock
88 |
89 | # poetry
90 | #poetry.lock
91 |
92 | # pdm
93 | .pdm.toml
94 |
95 | # PEP 582
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 | envs/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 | # pytype static type analyzer
134 | .pytype/
135 |
136 | # Cython debug symbols
137 | cython_debug/
138 |
139 | # PyCharm
140 | .idea/
141 |
142 | # VS Code
143 | .vscode/
144 |
145 | # macOS
146 | .DS_Store
147 | .AppleDouble
148 | .LSOverride
149 | Icon
150 | ._*
151 | .DocumentRevisions-V100
152 | .fseventsd
153 | .Spotlight-V100
154 | .TemporaryItems
155 | .Trashes
156 | .VolumeIcon.icns
157 | .com.apple.timemachine.donotpresent
158 | .AppleDB
159 | .AppleDesktop
160 | Network Trash Folder
161 | Temporary Items
162 | .apdisk
163 |
164 | # Application specific
165 | .env.local
166 | .env.development.local
167 | .env.test.local
168 | .env.production.local
169 |
170 | # Configuration files with sensitive data
171 | config/local_*.json
172 | config/secrets.json
173 |
174 | # Logs
175 | *.log
176 |
177 | # Data files
178 | data/
179 | tmp/
180 | .temp/
181 |
182 | # Docker
183 |
184 | # Local tooling and project-specific files
185 | .claude
186 | .cunzhi-memory
187 | .ruff_cache
188 | .serena
189 | example
190 | config/settings.json
--------------------------------------------------------------------------------
/CLAUDE.md:
--------------------------------------------------------------------------------
1 | # CLAUDE.md
2 |
3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4 |
5 | ## 项目概述
6 |
7 | 这是一个高性能的 RESTful API 代理服务,对外提供 Anthropic Claude API 兼容接口,并在内部将请求转换为 OpenAI API 格式,允许开发者使用现有的 Anthropic 客户端代码调用 OpenAI 模型。
8 |
9 | 主要功能:
10 | - ✅ 请求/响应格式转换
11 | - ✅ 流式响应支持
12 | - ✅ 错误处理映射
13 | - ✅ 配置管理
14 | - ✅ 结构化日志
15 | - ✅ 健康检查端点
16 |
17 | ## 项目结构
18 |
19 | ```
20 | openai-to-claude/
21 | ├── src/
22 | │   ├── api/               # API 端点和中间件
23 | │   ├── config/            # 配置管理
24 | │   ├── core/clients/      # HTTP 客户端
25 | │   ├── core/converters/   # 数据格式转换器
26 | │   ├── models/            # Pydantic 数据模型
27 | │   └── common/            # 公共工具(日志、token 计数等)
28 | ├── tests/                 # 测试套件
29 | ├── config/                # 配置文件模板
30 | └── pyproject.toml         # 项目依赖和配置
31 | ```
32 |
33 | ## 核心架构
34 |
35 | ### 数据流向
36 | 1. 客户端发送 Anthropic 格式的请求到 `/v1/messages` 端点
37 | 2. `MessagesHandler` 接收请求并使用 `AnthropicToOpenAIConverter` 转换为 OpenAI 格式
38 | 3. `OpenAIServiceClient` 将请求发送到实际的 OpenAI API 服务
39 | 4. 收到 OpenAI 响应后,使用 `OpenAIToAnthropicConverter` 转换回 Anthropic 格式
40 | 5. 返回给客户端
41 |
42 | ### 主要组件
43 | - **src/api/handlers.py**: 核心消息处理逻辑
44 | - **src/core/converters/**: 请求/响应格式转换器
45 | - **src/core/clients/openai_client.py**: OpenAI API 客户端
46 | - **src/models/**: 数据模型定义 (Anthropic 和 OpenAI)
47 | - **src/config/settings.py**: 配置管理
48 | - **src/common/**: 公共工具(日志、token 计数等)
49 |
50 | ## 开发命令
51 |
52 | ### 安装依赖
53 | ```bash
54 | uv sync
55 | ```
56 |
57 | ### 运行服务
58 | ```bash
59 | uvicorn src.main:app --reload --host 0.0.0.0 --port 8000
60 | ```
61 |
62 | ### 运行测试
63 | ```bash
64 | # 运行所有测试
65 | pytest
66 |
67 | # 运行特定测试文件
68 | pytest tests/integration/test_messages_endpoint.py
69 |
70 | # 运行集成测试
71 | pytest tests/integration
72 |
73 | # 带覆盖率报告
74 | pytest --cov=src --cov-report=html
75 | ```
76 |
77 | ### 代码质量检查
78 | ```bash
79 | # 运行 linting
80 | ruff check .
81 |
82 | # 自动修复代码风格
83 | ruff check . --fix
84 |
85 | # 代码格式化
86 | black .
87 |
88 | # 类型检查
89 | mypy src
90 | ```
91 |
92 | ## 配置说明
93 |
94 | ### 环境变量
95 | - `LOG_LEVEL`: 日志级别 (默认: INFO)
96 | - `CONFIG_PATH`: 配置文件路径 (默认: config/settings.json)
97 |
98 | ### 配置文件 (config/settings.json)
99 | ```json
100 | {
101 | "openai": {
102 | "api_key": "your-openai-api-key-here",
103 | "base_url": "https://api.openai.com/v1"
104 | },
105 | "server": {
106 | "host": "0.0.0.0",
107 | "port": 8000
108 | },
109 | "api_key": "your-proxy-api-key-here",
110 | "logging": {
111 | "level": "INFO"
112 | },
113 | "models": {
114 | "default": "claude-3-5-sonnet-20241022",
115 | "small": "claude-3-5-haiku-20241022",
116 | "tool": "claude-3-5-sonnet-20241022",
117 | "think": "claude-3-7-sonnet-20250219",
118 | "longContext": "claude-3-7-sonnet-20250219"
119 | },
120 | "parameter_overrides": {
121 | "max_tokens": null,
122 | "temperature": null,
123 | "top_p": null,
124 | "top_k": null
125 | }
126 | }
127 | ```
128 |
129 | ## API 端点
130 |
131 | - `POST /v1/messages` - Anthropic 消息 API(请求在内部转换并转发到 OpenAI 兼容后端)
132 | - `GET /health` - 健康检查端点
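133 |
134 | ## 转换流程示例
135 |
136 | 下面是一段最小示意代码,串起上述数据流(基于 src/api/handlers.py 中已有的类接口;省略了错误处理与流式分支,仅用于说明,并非完整实现):
137 |
138 | ```python
139 | from src.core.clients.openai_client import OpenAIServiceClient
140 | from src.core.converters.request_converter import AnthropicToOpenAIConverter
141 | from src.core.converters.response_converter import OpenAIToAnthropicConverter
142 |
143 |
144 | async def proxy_once(config, request, request_id=None):
145 |     """一次非流式请求的完整转换往返(示意)"""
146 |     # 1. Anthropic 请求 -> OpenAI 格式
147 |     openai_request = await AnthropicToOpenAIConverter().convert_anthropic_to_openai(
148 |         request, request_id
149 |     )
150 |     # 2. 转发到 OpenAI 兼容后端
151 |     client = OpenAIServiceClient(
152 |         api_key=config.openai.api_key, base_url=config.openai.base_url
153 |     )
154 |     openai_response = await client.send_request(openai_request, request_id=request_id)
155 |     # 3. OpenAI 响应 -> Anthropic 格式
156 |     return await OpenAIToAnthropicConverter().convert_response(
157 |         openai_response, request.model, request_id
158 |     )
159 | ```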
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # 构建阶段
2 | FROM python:3.11-slim AS builder
3 |
4 | # 设置工作目录
5 | WORKDIR /app
6 |
7 | # 复制依赖清单(锁文件保证可复现安装)
8 | COPY pyproject.toml uv.lock ./
9 |
10 | # 安装 uv,并按锁文件把依赖装入 /app/.venv(--no-install-project:此阶段尚无源码)
11 | RUN pip install --no-cache-dir uv && uv sync --frozen --no-install-project
12 |
13 | # 运行阶段
14 | FROM python:3.11-slim
15 |
16 | # 设置工作目录
17 | WORKDIR /app
18 |
19 | # 安装 uv 工具
20 | RUN pip install --no-cache-dir uv
21 |
22 | # 从构建阶段复制虚拟环境(uv sync 安装到 /app/.venv,而非系统 site-packages)
23 | COPY --from=builder /app/.venv /app/.venv
24 |
25 | # 复制项目文件(.venv 已列入 .dockerignore,不会被覆盖)
26 | COPY . .
27 |
28 | # 设置环境变量
29 | ENV PYTHONDONTWRITEBYTECODE=1 \
30 |     PYTHONUNBUFFERED=1
31 |
32 | # 创建日志和配置目录
33 | RUN mkdir -p /app/logs /app/config
34 |
35 | # 暴露端口
36 | EXPOSE 8000
37 |
38 | # 启动应用(--no-sync:直接复用已复制的 /app/.venv,启动时不再解析依赖)
39 | CMD ["uv", "run", "--no-sync", "main.py"]
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OpenAI-to-Claude API Proxy Service
2 |
3 | [](https://www.python.org/downloads/)
4 | [](https://fastapi.tiangolo.com/)
5 | [](LICENSE)
6 |
7 | A high-performance proxy service that exposes an Anthropic-compatible API in front of any OpenAI-compatible backend, allowing developers to seamlessly call OpenAI models with existing Anthropic client code.
8 |
9 | [中文版本](README_zh.md)
10 |
11 | ## 🌟 Core Features
12 |
13 | - ✅ **Seamless Compatibility**: Call OpenAI models using standard Anthropic clients
14 | - ✅ **Full Functionality**: Supports text, tool calls, streaming responses, and more
15 | - ✅ **Intelligent Routing**: Automatically selects the most suitable OpenAI model based on request content
16 | - ✅ **Hot Reload**: Automatically reloads configuration file changes without restarting the service
17 | - ✅ **Structured Logging**: Detailed request/response logs for debugging and monitoring
18 | - ✅ **Error Mapping**: Comprehensive error handling and mapping mechanisms
19 |
20 | ## 🚀 Quick Start
21 |
22 | ### Requirements
23 |
24 | - Python 3.11+
25 | - uv (recommended package manager)
26 |
27 | ### Install Dependencies
28 |
29 | ```bash
30 | # Install dependencies using uv (recommended)
31 | uv sync
32 | ```
33 |
34 | ### Configuration
35 |
36 | 1. Copy the example configuration file:
37 | ```bash
38 | cp config/example.json config/settings.json
39 | ```
40 |
41 | 2. Edit `config/settings.json`:
42 | ```json
43 | {
44 | "openai": {
45 | "api_key": "your-openai-api-key-here", // Replace with your OpenAI API key
46 | "base_url": "https://api.openai.com/v1" // OpenAI API address
47 | },
48 | "api_key": "your-proxy-api-key-here", // API key for the proxy service
49 | // Other configurations...
50 | }
51 | ```
52 |
53 | ### Start the Service
54 |
55 | ```bash
56 | # Development mode
57 | uv run main.py --config config/settings.json
58 |
59 | # Production mode
60 | uv run main.py
61 | ```
62 |
63 | ### Start with Docker
64 |
65 | ```bash
66 | # Build and start the service
67 | docker-compose up --build
68 |
69 | # Run in background
70 | docker-compose up --build -d
71 |
72 | # Stop the service
73 | docker-compose down
74 | ```
75 |
76 | The service will start at `http://localhost:8000`.
77 |
78 | ## 🛠️ Usage
79 |
80 | ### Claude Code Usage
81 |
82 | This project can be used with [Claude Code](https://claude.ai/code) for development and testing. To configure Claude Code to work with this proxy service, create a `.claude/settings.json` file with the following configuration:
83 |
84 | ```json
85 | {
86 | "env": {
87 | "ANTHROPIC_API_KEY": "your-api-key",
88 | "ANTHROPIC_BASE_URL": "http://127.0.0.1:8000",
89 | "DISABLE_TELEMETRY": "1",
90 | "CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC": "1"
91 | },
92 | "apiKeyHelper": "echo 'your-api-key'",
93 | "permissions": {
94 | "allow": [],
95 | "deny": []
96 | }
97 | }
98 | ```
99 |
100 | Configuration Notes:
101 | - Set `ANTHROPIC_API_KEY` to the proxy `api_key` configured in `config/settings.json`
102 | - Set `ANTHROPIC_BASE_URL` to the URL where this proxy service is actually running
103 | - Update `apiKeyHelper` to echo the same API key
104 |
105 | ### Using Anthropic Python Client
106 |
107 | ```python
108 | from anthropic import Anthropic
109 |
110 | # Initialize client pointing to the proxy service
111 | client = Anthropic(
112 | base_url="http://localhost:8000/v1",
113 | api_key="your-proxy-api-key-here" # Use the api_key from the configuration file
114 | )
115 |
116 | # Send a message request
117 | response = client.messages.create(
118 | model="gpt-4o",
119 | messages=[
120 | {"role": "user", "content": "Hello, GPT!"}
121 | ],
122 | max_tokens=1024
123 | )
124 |
125 | print(response.content[0].text)
126 | ```
127 |
128 | ### Streaming Response
129 |
130 | ```python
131 | # Streaming response
132 | stream = client.messages.create(
133 | model="gpt-4o",
134 | messages=[
135 | {"role": "user", "content": "Tell me a story about AI"}
136 | ],
137 | max_tokens=1024,
138 | stream=True
139 | )
140 |
141 | for chunk in stream:
142 | if chunk.type == "content_block_delta":
143 | print(chunk.delta.text, end="", flush=True)
144 | ```
145 |
146 | ### Tool Calls
147 |
148 | ```python
149 | # Tool calls
150 | tools = [
151 | {
152 | "name": "get_current_weather",
153 | "description": "Get the current weather for a specified city",
154 | "input_schema": {
155 | "type": "object",
156 | "properties": {
157 | "city": {
158 | "type": "string",
159 | "description": "City name"
160 | }
161 | },
162 | "required": ["city"]
163 | }
164 | }
165 | ]
166 |
167 | response = client.messages.create(
168 | model="gpt-4o",
169 | messages=[
170 | {"role": "user", "content": "What's the weather like in Beijing now?"}
171 | ],
172 | tools=tools,
173 | tool_choice={"type": "auto"}
174 | )
175 | ```
176 |
177 | ## 📁 Project Structure
178 |
179 | ```
180 | openai-to-claude/
181 | ├── src/
182 | │ ├── api/ # API endpoints and middleware
183 | │ ├── config/ # Configuration management
184 | │ ├── core/ # Core business logic
185 | │ │ ├── clients/ # HTTP clients
186 | │ │ └── converters/ # Data format converters
187 | │ ├── models/ # Pydantic data models
188 | │ └── common/ # Common utilities (logging, token counting, etc.)
189 | ├── config/ # Configuration files
190 | ├── tests/ # Test suite
191 | ├── CLAUDE.md # Claude Code project guide
192 | └── pyproject.toml # Project dependencies and configuration
193 | ```
194 |
195 | ## 🔧 Configuration
196 |
197 | ### Environment Variables
198 |
199 | - `CONFIG_PATH`: Configuration file path (default: `config/settings.json`)
200 | - `LOG_LEVEL`: Log level (default: `INFO`)
201 |
202 | ### Configuration File (`config/settings.json`)
203 |
204 | ```json
205 | {
206 | "openai": {
207 | "api_key": "your-openai-api-key-here",
208 | "base_url": "https://api.openai.com/v1"
209 | },
210 | "server": {
211 | "host": "0.0.0.0",
212 | "port": 8000
213 | },
214 | "api_key": "your-proxy-api-key-here",
215 | "logging": {
216 | "level": "INFO"
217 | },
218 | "models": {
219 | "default": "Qwen/Qwen3-Coder",
220 | "small": "deepseek-ai/DeepSeek-V3-0324",
221 | "think": "deepseek-ai/DeepSeek-R1-0528",
222 | "long_context": "gemini-2.5-pro",
223 | "web_search": "gemini-2.5-flash"
224 | },
225 | "parameter_overrides": {
226 | "max_tokens": null,
227 | "temperature": null,
228 | "top_p": null,
229 | "top_k": null
230 | }
231 | }
232 | ```
233 |
234 | #### Configuration Items
235 |
236 | - **openai**: OpenAI API configuration
237 | - `api_key`: OpenAI API key for accessing OpenAI services
238 | - `base_url`: OpenAI API base URL, default is `https://api.openai.com/v1`
239 |
240 | - **server**: Server configuration
241 | - `host`: Service listening host address, default is `0.0.0.0` (listen on all network interfaces)
242 | - `port`: Service listening port, default is `8000`
243 |
244 | - **api_key**: API key for the proxy service, used to verify requests to the `/v1/messages` endpoint
245 |
246 | - **logging**: Logging configuration
247 | - `level`: Log level, options are `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`, default is `INFO`
248 |
249 | - **models**: Model configuration, defines model selection for different usage scenarios
250 | - `default`: Default general model for general requests
251 | - `small`: Lightweight model for simple tasks
252 | - `think`: Deep thinking model for complex reasoning tasks
253 | - `long_context`: Long context processing model for handling long text
254 |   - `web_search`: Model used for web search requests; currently only Gemini models are supported
255 |
256 | - **parameter_overrides**: Parameter override configuration, allows administrators to set model parameter override values in the configuration file
257 | - `max_tokens`: Maximum token count override, when set, will override the max_tokens parameter in client requests
258 | - `temperature`: Temperature parameter override, controls the randomness of output, range 0.0-2.0
259 | - `top_p`: top_p sampling parameter override, controls the probability threshold of candidate words, range 0.0-1.0
260 | - `top_k`: top_k sampling parameter override, controls the number of candidate words, range >=0
261 |
262 | ## 🧪 Testing
263 |
264 | ```bash
265 | # Run all tests
266 | pytest
267 |
268 | # Run unit tests
269 | pytest tests/unit
270 |
271 | # Run integration tests
272 | pytest tests/integration
273 |
274 | # Generate coverage report
275 | pytest --cov=src --cov-report=html
276 | ```
277 |
278 | ## 📊 API Endpoints
279 |
280 | - `POST /v1/messages` - Anthropic Messages API
281 | - `GET /health` - Health check endpoint
282 | - `GET /` - Welcome page
283 |
284 | ## 🛡️ Security
285 |
286 | - API key authentication
287 | - Request rate limiting (planned)
288 | - Input validation and sanitization
289 | - Structured logging
290 |
291 | ## 📈 Performance Monitoring
292 |
293 | - Request/response time monitoring
294 | - Memory usage tracking
295 | - Error rate statistics
296 |
297 | ## 🤝 Contributing
298 |
299 | Issues and Pull Requests are welcome!
300 |
301 | ## 📄 License
302 |
303 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
304 |
305 | ## 🙏 Acknowledgements
306 |
307 | - [claude-code-router](https://github.com/musistudio/claude-code-router) - An excellent project whose design this proxy references in many places
308 | - [FastAPI](https://fastapi.tiangolo.com/) - Modern high-performance web framework
309 | - [Anthropic](https://www.anthropic.com/) - Claude AI models
310 | - [OpenAI](https://openai.com/) - OpenAI API specification
311 |
312 | ## 🤖 Claude Code Usage
313 |
314 | This project can be used with [Claude Code](https://claude.ai/code) for development and testing. To configure Claude Code to work with this proxy service, create a `.claude/settings.json` file with the following configuration:
315 |
316 | ### Example Configuration
317 |
318 | ```json
319 | {
320 | "env": {
321 | "ANTHROPIC_API_KEY": "sk-chen0v0...",
322 | "ANTHROPIC_BASE_URL": "http://127.0.0.1:8100",
323 | "DISABLE_TELEMETRY": "1",
324 | "CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC": "1"
325 | },
326 | "apiKeyHelper": "echo 'sk-chen0v0...'",
327 | "permissions": {
328 | "allow": [],
329 | "deny": []
330 | }
331 | }
332 | ```
333 |
334 | ### Configuration Notes
335 |
336 | - Set `ANTHROPIC_API_KEY` to the proxy `api_key` configured in `config/settings.json` (the proxy's key, not an Anthropic key)
337 | - Set `ANTHROPIC_BASE_URL` to the URL where this proxy service is actually running
338 | - Update `apiKeyHelper` to echo the same API key
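339 |
340 | ## 🔍 Quick Verification
341 |
342 | A minimal sketch for checking a running proxy from Python, assuming the default `localhost:8000` address and the proxy `api_key` from `config/settings.json` (exact response contents depend on the backend models you configured):
343 |
344 | ```python
345 | import httpx
346 |
347 | BASE_URL = "http://localhost:8000"
348 | API_KEY = "your-proxy-api-key-here"  # the proxy api_key, not an OpenAI key
349 |
350 | # 1. Health check (no authentication required)
351 | health = httpx.get(f"{BASE_URL}/health", timeout=10)
352 | print(health.json()["status"])  # "healthy", "degraded", or "unhealthy"
353 |
354 | # 2. Minimal Anthropic-format message (authenticated via the x-api-key header)
355 | resp = httpx.post(
356 |     f"{BASE_URL}/v1/messages",
357 |     headers={"x-api-key": API_KEY, "content-type": "application/json"},
358 |     json={
359 |         "model": "gpt-4o",
360 |         "max_tokens": 64,
361 |         "messages": [{"role": "user", "content": "ping"}],
362 |     },
363 |     timeout=60,
364 | )
365 | resp.raise_for_status()
366 | print(resp.json()["content"][0]["text"])
367 | ```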
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
1 | # OpenAI-to-Claude API 代理服务
2 |
3 | [](https://www.python.org/downloads/)
4 | [](https://fastapi.tiangolo.com/)
5 | [](LICENSE)
6 |
7 | 对外提供 Anthropic API 兼容接口、内部对接 OpenAI 兼容后端的高性能代理服务,允许开发者使用现有的 Anthropic 客户端代码无缝调用 OpenAI 模型。
8 |
9 | ## 🌟 核心特性
10 |
11 | - ✅ **无缝兼容**: 使用标准 Anthropic 客户端调用 OpenAI 模型
12 | - ✅ **完整功能**: 支持文本、工具调用、流式响应等功能
13 | - ✅ **智能路由**: 根据请求内容自动选择最适合的 OpenAI 模型
14 | - ✅ **热重载**: 配置文件修改后自动重载,无需重启服务
15 | - ✅ **结构化日志**: 详细的请求/响应日志,便于调试和监控
16 | - ✅ **错误映射**: 完善的错误处理和映射机制
17 |
18 | ## 🚀 快速开始
19 |
20 | ### 环境要求
21 |
22 | - Python 3.11+
23 | - uv (推荐的包管理器)
24 |
25 | ### 安装依赖
26 |
27 | ```bash
28 | # 使用 uv 安装依赖(推荐)
29 | uv sync
30 | ```
31 |
32 | ### 配置
33 |
34 | 1. 复制示例配置文件:
35 | ```bash
36 | cp config/example.json config/settings.json
37 | ```
38 |
39 | 2. 编辑 `config/settings.json`:
40 | ```json
41 | {
42 | "openai": {
43 | "api_key": "your-openai-api-key-here", // 替换为你的 OpenAI API 密钥
44 | "base_url": "https://api.openai.com/v1" // OpenAI API 地址
45 | },
46 | "api_key": "your-proxy-api-key-here", // 代理服务的 API 密钥
47 | // 其他配置...
48 | }
49 | ```
50 |
51 | ### 启动服务
52 |
53 | ```bash
54 | # 开发模式
55 | uv run main.py --config config/settings.json
56 |
57 | # 生产模式
58 | uv run main.py
59 | ```
60 |
61 | ### 使用 Docker 启动
62 |
63 | ```bash
64 | # 构建并启动服务
65 | docker-compose up --build
66 |
67 | # 后台运行
68 | docker-compose up --build -d
69 |
70 | # 停止服务
71 | docker-compose down
72 | ```
73 |
74 | 服务将在 `http://localhost:8000` 启动。
75 |
76 | ## 🛠️ 使用方法
77 |
78 | ### Claude Code 使用方法
79 |
80 | 本项目可以与 [Claude Code](https://claude.ai/code) 一起使用进行开发和测试。要配置 Claude Code 以使用此代理服务,请创建一个 `.claude/settings.json` 文件,配置如下:
81 |
82 | ```json
83 | {
84 | "env": {
85 | "ANTHROPIC_API_KEY": "your-api-key",
86 | "ANTHROPIC_BASE_URL": "http://127.0.0.1:8000",
87 | "DISABLE_TELEMETRY": "1",
88 | "CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC": "1"
89 | },
90 | "apiKeyHelper": "echo 'your-api-key'",
91 | "permissions": {
92 | "allow": [],
93 | "deny": []
94 | }
95 | }
96 | ```
97 |
98 | 配置说明:
99 | - 将 `ANTHROPIC_API_KEY` 替换为 `config/settings.json` 中配置的代理 `api_key`
100 | - 将 `ANTHROPIC_BASE_URL` 替换为此代理服务实际运行的 URL
101 | - `apiKeyHelper` 也应更新为同一个 API 密钥
102 |
103 | ### 使用 Anthropic Python 客户端
104 |
105 | ```python
106 | from anthropic import Anthropic
107 |
108 | # 初始化客户端,指向代理服务
109 | client = Anthropic(
110 | base_url="http://localhost:8000/v1",
111 | api_key="your-proxy-api-key-here" # 使用配置文件中的 api_key
112 | )
113 |
114 | # 发送消息请求
115 | response = client.messages.create(
116 | model="gpt-4o",
117 | messages=[
118 | {"role": "user", "content": "你好,GPT!"}
119 | ],
120 | max_tokens=1024
121 | )
122 |
123 | print(response.content[0].text)
124 | ```
125 |
126 | ### 流式响应
127 |
128 | ```python
129 | # 流式响应
130 | stream = client.messages.create(
131 | model="gpt-4o",
132 | messages=[
133 | {"role": "user", "content": "给我讲一个关于 AI 的故事"}
134 | ],
135 | max_tokens=1024,
136 | stream=True
137 | )
138 |
139 | for chunk in stream:
140 | if chunk.type == "content_block_delta":
141 | print(chunk.delta.text, end="", flush=True)
142 | ```
143 |
144 | ### 工具调用
145 |
146 | ```python
147 | # 工具调用
148 | tools = [
149 | {
150 | "name": "get_current_weather",
151 | "description": "获取指定城市的当前天气",
152 | "input_schema": {
153 | "type": "object",
154 | "properties": {
155 | "city": {
156 | "type": "string",
157 | "description": "城市名称"
158 | }
159 | },
160 | "required": ["city"]
161 | }
162 | }
163 | ]
164 |
165 | response = client.messages.create(
166 | model="gpt-4o",
167 | messages=[
168 | {"role": "user", "content": "北京现在的天气怎么样?"}
169 | ],
170 | tools=tools,
171 | tool_choice={"type": "auto"}
172 | )
173 | ```
174 |
175 | ## 📁 项目结构
176 |
177 | ```
178 | openai-to-claude/
179 | ├── src/
180 | │ ├── api/ # API 端点和中间件
181 | │ ├── config/ # 配置管理
182 | │ ├── core/ # 核心业务逻辑
183 | │ │ ├── clients/ # HTTP 客户端
184 | │ │ └── converters/ # 数据格式转换器
185 | │ ├── models/ # Pydantic 数据模型
186 | │   └── common/       # 公共工具(日志、token 计数等)
187 | ├── config/ # 配置文件
188 | ├── tests/ # 测试套件
189 | ├── CLAUDE.md # Claude Code 项目指导
190 | └── pyproject.toml # 项目依赖和配置
191 | ```
192 |
193 | ## 🤖 Claude Code 使用方法
194 |
195 | 本项目可以与 [Claude Code](https://claude.ai/code) 一起使用进行开发和测试。要配置 Claude Code 以使用此代理服务,请创建一个 `.claude/settings.json` 文件,配置如下:
196 |
197 | ### 示例配置文件
198 |
199 | ```json
200 | {
201 | "env": {
202 | "ANTHROPIC_API_KEY": "sk-chen0v0...",
203 | "ANTHROPIC_BASE_URL": "http://127.0.0.1:8100",
204 | "DISABLE_TELEMETRY": "1",
205 | "CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC": "1"
206 | },
207 | "apiKeyHelper": "echo 'sk-chen0v0...'",
208 | "permissions": {
209 | "allow": [],
210 | "deny": []
211 | }
212 | }
213 | ```
214 |
215 | ### 配置说明
216 |
217 | - 将 `ANTHROPIC_API_KEY` 替换为 `config/settings.json` 中配置的代理 `api_key`(而非 Anthropic 官方密钥)
218 | - 将 `ANTHROPIC_BASE_URL` 替换为此代理服务实际运行的 URL
219 | - `apiKeyHelper` 字段也应更新为同一个 API 密钥
220 |
221 | ## 🔧 配置说明
222 |
223 | ### 环境变量
224 |
225 | - `CONFIG_PATH`: 配置文件路径 (默认: `config/settings.json`)
226 | - `LOG_LEVEL`: 日志级别 (默认: `INFO`)
227 |
228 | ### 配置文件 (`config/settings.json`)
229 |
230 | ```json
231 | {
232 | "openai": {
233 | "api_key": "your-openai-api-key-here",
234 | "base_url": "https://api.openai.com/v1"
235 | },
236 | "server": {
237 | "host": "0.0.0.0",
238 | "port": 8000
239 | },
240 | "api_key": "your-proxy-api-key-here",
241 | "logging": {
242 | "level": "INFO"
243 | },
244 | "models": {
245 | "default": "Qwen/Qwen3-Coder",
246 | "small": "deepseek-ai/DeepSeek-V3-0324",
247 | "think": "deepseek-ai/DeepSeek-R1-0528",
248 | "long_context": "gemini-2.5-pro",
249 | "web_search": "gemini-2.5-flash"
250 | },
251 | "parameter_overrides": {
252 | "max_tokens": null,
253 | "temperature": null,
254 | "top_p": null,
255 | "top_k": null
256 | }
257 | }
258 | ```
259 |
260 | #### 配置项说明
261 |
262 | - **openai**: OpenAI API 配置
263 | - `api_key`: OpenAI API 密钥,用于访问 OpenAI 服务
264 | - `base_url`: OpenAI API 基础 URL,默认为 `https://api.openai.com/v1`
265 |
266 | - **server**: 服务器配置
267 | - `host`: 服务监听主机地址,默认为 `0.0.0.0`(监听所有网络接口)
268 | - `port`: 服务监听端口,默认为 `8000`
269 |
270 | - **api_key**: 代理服务的 API 密钥,用于验证访问 `/v1/messages` 端点的请求
271 |
272 | - **logging**: 日志配置
273 | - `level`: 日志级别,可选值为 `DEBUG`、`INFO`、`WARNING`、`ERROR`、`CRITICAL`,默认为 `INFO`
274 |
275 | - **models**: 模型配置,定义不同使用场景下的模型选择
276 | - `default`: 默认通用模型,用于一般请求
277 | - `small`: 轻量级模型,用于简单任务
278 | - `think`: 深度思考模型,用于复杂推理任务
279 | - `long_context`: 长上下文处理模型,用于处理长文本
280 |   - `web_search`: 网络搜索模型,用于网络搜索,目前仅支持 Gemini 模型
281 |
282 | - **parameter_overrides**: 参数覆盖配置,允许管理员在配置文件中设置模型参数的覆盖值
283 | - `max_tokens`: 最大 token 数覆盖,设置后会覆盖客户端请求中的 max_tokens 参数
284 | - `temperature`: 温度参数覆盖,控制输出的随机程度,范围为 0.0-2.0
285 | - `top_p`: top_p 采样参数覆盖,控制候选词汇的概率阈值,范围为 0.0-1.0
286 | - `top_k`: top_k 采样参数覆盖,控制候选词汇的数量,范围为 >=0
287 |
288 | ## 🧪 测试
289 |
290 | ```bash
291 | # 运行所有测试
292 | pytest
293 |
294 | # 运行单元测试
295 | pytest tests/unit
296 |
297 | # 运行集成测试
298 | pytest tests/integration
299 |
300 | # 生成覆盖率报告
301 | pytest --cov=src --cov-report=html
302 | ```
303 |
304 | ## 📊 API 端点
305 |
306 | - `POST /v1/messages` - Anthropic 消息 API
307 | - `GET /health` - 健康检查端点
308 | - `GET /` - 欢迎页面
309 |
310 | ## 🛡️ 安全性
311 |
312 | - API 密钥验证
313 | - 请求频率限制(计划中)
314 | - 输入验证和清理
315 | - 结构化日志记录
316 |
317 | ## 📈 性能监控
318 |
319 | - 请求/响应时间监控
320 | - 内存使用情况跟踪
321 | - 错误率统计
322 |
323 | ## 🤝 贡献
324 |
325 | 欢迎提交 Issue 和 Pull Request!
326 |
327 | ## 📄 许可证
328 |
329 | 本项目采用 MIT 许可证 - 查看 [LICENSE](LICENSE) 文件了解详情。
330 |
331 | ## 🙏 致谢
332 |
333 | - [claude-code-router](https://github.com/musistudio/claude-code-router) - 非常优秀的项目,本项目的许多设计都参考了它
334 | - [FastAPI](https://fastapi.tiangolo.com/) - 现代高性能 Web 框架
335 | - [Anthropic](https://www.anthropic.com/) - Claude AI 模型
336 | - [OpenAI](https://openai.com/) - OpenAI API 规范
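337 |
338 | ## 🔍 快速验证
339 |
340 | 下面是用 Python 验证代理是否正常运行的最小示意(假设服务运行在默认的 `localhost:8000`,`API_KEY` 为 `config/settings.json` 中配置的代理 `api_key`;具体响应内容取决于所配置的后端模型):
341 |
342 | ```python
343 | import httpx
344 |
345 | BASE_URL = "http://localhost:8000"
346 | API_KEY = "your-proxy-api-key-here"  # 代理服务的 api_key,而非 OpenAI 密钥
347 |
348 | # 1. 健康检查(无需认证)
349 | health = httpx.get(f"{BASE_URL}/health", timeout=10)
350 | print(health.json()["status"])  # "healthy"、"degraded" 或 "unhealthy"
351 |
352 | # 2. 最小的 Anthropic 格式消息请求(通过 x-api-key 请求头认证)
353 | resp = httpx.post(
354 |     f"{BASE_URL}/v1/messages",
355 |     headers={"x-api-key": API_KEY, "content-type": "application/json"},
356 |     json={
357 |         "model": "gpt-4o",
358 |         "max_tokens": 64,
359 |         "messages": [{"role": "user", "content": "ping"}],
360 |     },
361 |     timeout=60,
362 | )
363 | resp.raise_for_status()
364 | print(resp.json()["content"][0]["text"])
365 | ```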
--------------------------------------------------------------------------------
/config/example.json:
--------------------------------------------------------------------------------
1 | {
2 | "openai": {
3 | "api_key": "your-openai-api-key-here",
4 | "base_url": "https://api.openai.com/v1"
5 | },
6 | "server": {
7 | "host": "0.0.0.0",
8 | "port": 8000
9 | },
10 | "api_key": "your-proxy-api-key-here",
11 | "logging": {
12 | "level": "INFO"
13 | },
14 | "models": {
15 | "default": "Qwen/Qwen3-Coder",
16 | "small": "deepseek-ai/DeepSeek-V3-0324",
17 | "think": "deepseek-ai/DeepSeek-R1-0528",
18 | "long_context": "gemini-2.5-pro",
19 | "web_search": "gemini-2.5-flash"
20 | },
21 | "parameter_overrides": {
22 | "max_tokens": null,
23 | "temperature": null,
24 | "top_p": null,
25 | "top_k": null
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '3.8'
2 |
3 | services:
4 | proxy:
5 | container_name: openai-to-claude # 指定容器名称
6 | build: .
7 | ports:
8 | - "8000:8000" # 将主机的8000端口映射到容器的8000端口
9 | volumes:
10 | - ./config:/app/config # 挂载配置文件目录
11 | - ./logs:/app/logs # 挂载日志目录
12 | restart: unless-stopped # 容器在停止后会自动重启,除非手动停止
13 |
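14 | # 可选:基于 /health 端点的容器健康检查(示意,按需取消注释启用)
15 | #    healthcheck:
16 | #      test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
17 | #      interval: 30s
18 | #      timeout: 5s
19 | #      retries: 3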
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Anthropic-OpenAI Proxy 启动脚本
4 |
5 | 使用 JSON 配置文件中的 host 和 port 启动服务器,而不是命令行参数。
6 | 配置优先级:
7 | 1. 命令行指定的 config 参数
8 | 2. 环境变量 CONFIG_PATH 指定的路径
9 | 3. ./config/settings.json (默认)
10 | 4. ./config/example.json (模板)
11 | """
12 |
13 | import argparse
14 | import os
15 | import sys
16 | from pathlib import Path
17 |
18 | import uvicorn
19 |
20 | # 添加 src 目录到 Python 路径
21 | sys.path.insert(0, str(Path(__file__).parent / "src"))
22 |
23 | from src.config.settings import Config
24 |
25 |
26 | def main():
27 | """主启动函数"""
28 | try:
29 | parser = argparse.ArgumentParser(description="启动 Anthropic-OpenAI Proxy")
30 | parser.add_argument(
31 | "--config", type=str, help="JSON 配置文件路径 (默认为 config/settings.json)"
32 | )
33 | parser.add_argument(
34 | "--config-path",
35 | type=str,
36 | default="config/settings.json",
37 | help="配置文件路径,可通过 CONFIG_PATH 环境变量指定",
38 | )
39 |
40 | args = parser.parse_args()
41 |
42 | # 确保从项目根目录启动
43 | project_root = Path(__file__).parent
44 | os.chdir(project_root)
45 |
46 | # 确定配置文件路径
47 | config_path = args.config or os.getenv("CONFIG_PATH", args.config_path)
48 |
49 | # 同步加载配置
50 | config = Config.from_file_sync(config_path)
51 |
52 | # 获取服务器配置
53 | host, port = config.get_server_config()
54 |
55 |         print("🚀 启动 OpenAI To Claude Server...")
56 | print(f" 配置文件: {config_path}")
57 | print(f" 监听地址: {host}:{port}")
58 | print()
59 | print("📋 重要端点:")
60 | print(f" 健康检查: http://{host}:{port}/health")
61 | print(f" API文档: http://{host}:{port}/docs")
62 | print(f" OpenAPI: http://{host}:{port}/openapi.json")
63 | print()
64 |
65 | # 启动 Uvicorn 服务器
66 | uvicorn.run(
67 | "src.main:app",
68 | host=host,
69 | port=port,
70 | # reload=True,
71 | timeout_keep_alive=60,
72 | log_level=config.logging.level.lower(),
73 | )
74 | except Exception as e:
75 | print(f"❌ 启动失败: {e}")
76 | sys.exit(1)
77 |
78 |
79 | if __name__ == "__main__":
80 | main()
81 |
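82 | # 使用示例(示意):
83 | #   uv run main.py --config config/settings.json
84 | #   CONFIG_PATH=config/settings.json uv run main.py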
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "openai-to-claude"
3 | version = "0.1.0"
4 | description = "OpenAI to Anthropic API 兼容代理服务"
5 | readme = "README.md"
6 | requires-python = ">=3.11"
7 | authors = [
8 | { name = "Dev Team", email = "dev@example.com" }
9 | ]
10 | keywords = ["anthropic", "openai", "api", "proxy", "fastapi"]
11 | classifiers = [
12 | "Development Status :: 3 - Alpha",
13 | "Intended Audience :: Developers",
14 | "Programming Language :: Python :: 3",
15 | "Programming Language :: Python :: 3.11",
16 | "Programming Language :: Python :: 3.12",
17 | "Operating System :: OS Independent",
18 | "Topic :: Software Development :: Libraries :: Python Modules",
19 | "Topic :: Internet :: WWW/HTTP :: HTTP Servers",
20 | ]
21 |
22 | dependencies = [
23 | "fastapi>=0.104.1,<0.105.0",
24 | "uvicorn[standard]>=0.24.0,<0.25.0",
25 | "httpx>=0.25.0,<0.26.0",
26 | "pydantic>=2.5.0,<3.0.0",
27 | "pydantic-settings>=2.1.0,<3.0.0",
28 | "python-multipart>=0.0.6",
29 | "uvloop>=0.19.0,<0.20.0",
30 | "orjson>=3.9.0,<4.0.0",
31 | "loguru>=0.7.0,<1.0.0",
32 | "tenacity>=8.2.0,<9.0.0",
33 | "aiofiles>=23.2.0,<24.0.0",
34 | "pytest>=7.4.4",
35 | "pytest-asyncio>=0.23.8",
36 | "tiktoken>=0.7.0,<1.0.0",
37 | "watchdog>=3.0.0,<4.0.0",
38 | ]
39 |
40 | [project.optional-dependencies]
41 | dev = [
42 | "pytest>=7.4.0,<8.0.0",
43 | "pytest-asyncio>=0.21.0,<1.0.0",
44 | "pytest-cov>=4.1.0,<5.0.0",
45 | "httpx>=0.25.0,<0.26.0", # 用于测试
46 | "mypy>=1.7.0,<2.0.0",
47 | "ruff>=0.2.0,<1.0.0",
48 | "black>=23.11.0,<24.0.0",
49 | ]
50 |
51 | [project.urls]
52 | Homepage = "https://github.com/example/openai-to-claude"
53 | Repository = "https://github.com/example/openai-to-claude"
54 | Issues = "https://github.com/example/openai-to-claude/issues"
55 |
56 | [tool.uv]
57 | package = true
58 |
59 | [tool.hatch.build.targets.wheel]
60 | packages = ["src"]
61 |
62 | [tool.hatch.build.targets.sdist]
63 | packages = ["src"]
64 |
65 | [tool.pytest.ini_options]
66 | testpaths = ["tests"]
67 | pythonpath = ["."]
68 | python_files = ["test_*.py", "*_test.py"]
69 | python_classes = ["Test*"]
70 | python_functions = ["test_*"]
71 | addopts = [
72 | "--cov=src",
73 | "--cov-report=term-missing",
74 | "--cov-report=html",
75 | "--cov-report=xml",
76 | "-v",
77 | "--strict-markers",
78 | ]
79 | asyncio_mode = "auto"
80 | markers = [
81 | "integration: mark test as integration test requiring mock server",
82 | "streaming: mark test as streaming response test",
83 | "timeout: mark test as timeout or error handling test",
84 | "slow: mark test as slow running",
85 | "e2e: mark test as end-to-end integration test"
86 | ]
87 |
88 | [tool.ruff]
89 | line-length = 88
90 | target-version = "py311"
91 |
92 | [tool.ruff.lint]
93 | select = [
94 | "E", # pycodestyle errors
95 | "W", # pycodestyle warnings
96 | "F", # pyflakes
97 | "I", # isort
98 | "B", # flake8-bugbear
99 | "C4", # flake8-comprehensions
100 | "UP", # pyupgrade
101 | ]
102 | ignore = [
103 | "E501", # line too long, handled by black
104 | "B008", # do not perform function calls in argument defaults
105 | "C901", # too complex
106 | ]
107 |
108 | [tool.ruff.lint.per-file-ignores]
109 | "__init__.py" = ["F401"]
110 | "tests/**/*" = ["B011", "B017"]
111 |
112 | [tool.black]
113 | line-length = 88
114 | target-version = ["py311"]
115 | include = '\.pyi?$'
116 | extend-exclude = '''
117 | /(
118 | # directories
119 | \.eggs
120 | | \.git
121 | | \.hg
122 | | \.mypy_cache
123 | | \.tox
124 | | \.venv
125 | | build
126 | | dist
127 | )/
128 | '''
129 |
130 | [tool.mypy]
131 | python_version = "3.11"
132 | warn_return_any = true
133 | warn_unused_configs = true
134 | disallow_untyped_defs = true
135 | disallow_incomplete_defs = true
136 | check_untyped_defs = true
137 | disallow_untyped_decorators = true
138 | no_implicit_optional = true
139 | warn_redundant_casts = true
140 | warn_unused_ignores = true
141 | warn_no_return = true
142 | warn_unreachable = true
143 | strict_equality = true
144 |
145 | [dependency-groups]
146 | dev = [
147 | "requests>=2.32.4",
148 | ]
149 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | OpenAI-Claude Code Proxy
3 |
4 | 一个高性能的 Anthropic Claude API 到 OpenAI API 格式的代理服务。
5 | 提供完整的请求/响应转换、流式处理、错误处理和配置管理功能。
6 |
7 | 主要功能:
8 | - Anthropic Claude API 到 OpenAI API 格式的双向转换
9 | - 支持流式和非流式响应
10 | - 完整的工具调用支持
11 | - 请求ID追踪和日志记录
12 | - 配置文件热重载
13 | - 性能监控和错误处理
14 |
15 | 使用示例:
16 |     from src import app
17 |
18 |     # app 为 FastAPI 实例,可直接交给 uvicorn 等 ASGI 服务器运行
19 | """
20 |
21 | __version__ = "1.0.0"
22 | __author__ = "OpenAI-Claude Code Proxy Team"
23 | __description__ = "Anthropic Claude API to OpenAI API format proxy service"
24 |
25 | # 导出主要的公共API
26 | from .common import configure_logging, request_logger
27 | from .config import get_config, reload_config
28 | from .main import app
29 |
30 | __all__ = [
31 | "app",
32 | "get_config",
33 | "reload_config",
34 | "configure_logging",
35 | "request_logger",
36 | "__version__",
37 | "__author__",
38 | "__description__",
39 | ]
40 |
--------------------------------------------------------------------------------
/src/api/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | API模块
3 |
4 | 提供FastAPI应用的路由、处理器和中间件。
5 |
6 | 主要功能:
7 | - API路由定义
8 | - 请求处理器
9 | - 中间件集成
10 | - 健康检查端点
11 |
12 | 子模块:
13 | - handlers: API请求处理器
14 | - routes: 路由定义
15 | - middleware: 中间件实现
16 | """
17 |
18 | # 导入路由和处理器
19 | from .handlers import MessagesHandler, messages_endpoint
20 | from .handlers import router as handlers_router
21 |
22 | # 导入中间件
23 | from .middleware import (
24 | APIKeyMiddleware,
25 | RequestTimingMiddleware,
26 | setup_middlewares,
27 | )
28 | from .routes import health_check
29 | from .routes import router as routes_router
30 |
31 | __all__ = [
32 | # 路由
33 | "routes_router",
34 | "handlers_router",
35 | "health_check",
36 | "health_check_detailed",
37 | # 处理器
38 | "MessagesHandler",
39 | "messages_endpoint",
40 | # 中间件
41 | "APIKeyMiddleware",
42 | "RequestTimingMiddleware",
43 | "setup_middlewares",
44 | ]
45 |
--------------------------------------------------------------------------------
/src/api/handlers.py:
--------------------------------------------------------------------------------
1 | """
2 | Anthropic /v1/messages 端点处理程序
3 |
4 | 实现Anthropic native messages API与OpenAI API的转换和代理
5 | """
6 |
7 | import asyncio
8 | import json
9 | from collections.abc import AsyncGenerator
10 |
11 | from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
12 | from fastapi.responses import JSONResponse, StreamingResponse
13 | from pydantic import ValidationError
14 |
15 | from src.core.clients.openai_client import OpenAIServiceClient
16 | from src.core.converters.request_converter import (
17 | AnthropicToOpenAIConverter,
18 | )
19 | from src.core.converters.response_converter import OpenAIToAnthropicConverter
20 | from src.models.anthropic import (
21 | AnthropicMessageResponse,
22 | AnthropicRequest,
23 | )
24 | from src.models.errors import get_error_response
25 |
26 | router = APIRouter(prefix="/v1", tags=["messages"])
27 |
28 |
29 | class MessagesHandler:
30 | """处理Anthropic /v1/messages 端点请求"""
31 |
32 | def __init__(self, config):
33 | self.request_converter = AnthropicToOpenAIConverter()
34 | self.response_converter = OpenAIToAnthropicConverter()
35 |         self.config = config
36 |         # 另存一份配置引用,与 create() 工厂方法保持一致
37 |         self._config = config
38 |         self.client = OpenAIServiceClient(
39 |             api_key=config.openai.api_key,
40 |             base_url=config.openai.base_url,
41 |         )
42 |
43 |     @classmethod
44 |     async def create(cls, config=None):
45 |         """异步工厂方法创建 MessagesHandler 实例
46 |
47 |         未显式传入 config 时,先从全局配置异步加载,
48 |         然后复用 __init__ 完成初始化,
49 |         避免用 cls.__new__ 逐字段赋值造成的重复逻辑。
50 |         """
51 |         if config is None:
52 |             from src.config.settings import get_config
53 |
54 |             config = await get_config()
55 |
56 |         # 复用 __init__,保证两条构造路径行为一致
57 |         return cls(config)
60 |
61 | async def process_message(
62 |         self, request: AnthropicRequest, request_id: str | None = None
63 | ) -> AnthropicMessageResponse:
64 | """处理非流式消息请求"""
65 | # 获取绑定了请求ID的logger
66 | from src.common.logging import get_logger_with_request_id
67 |
68 | bound_logger = get_logger_with_request_id(request_id)
69 |
70 | try:
71 | bound_logger.debug("处理非流式请求")
72 | # 验证请求
73 | # await validate_anthropic_request(request, request_id)
74 | # 将 Anthropic 请求转换为 OpenAI 格式(异步)
75 | openai_request = await self.request_converter.convert_anthropic_to_openai(
76 | request, request_id
77 | )
78 |
79 | # 发送到 OpenAI
80 | openai_response = await self.client.send_request(
81 | openai_request, request_id=request_id
82 | )
83 | bound_logger.debug(
84 | f"OpenAI 响应: {json.dumps(openai_response, ensure_ascii=False)}"
85 | )
86 |
87 | # 将 OpenAI 响应转回 Anthropic 格式
88 | anthropic_response = await self.response_converter.convert_response(
89 | openai_response, request.model, request_id
90 | )
91 | # 安全地提取响应文本
92 | response_text = "empty"
93 | if (
94 | anthropic_response.content
95 | and len(anthropic_response.content) > 0
96 | and hasattr(anthropic_response.content[0], "text")
97 | and anthropic_response.content[0].text
98 | ):
99 | response_text = anthropic_response.content[0].text
100 | bound_logger.info(
101 | f"Anthropic 响应生成完成 - Text: {response_text[:100]}..., Usage: {anthropic_response.usage}"
102 | )
103 |
104 | return anthropic_response
105 |
106 | except ValidationError as e:
107 | bound_logger.warning(f"Validation error - Errors: {e.errors()}")
108 | error_response = get_error_response(
109 | 422, details={"validation_errors": e.errors(), "request_id": request_id}
110 | )
111 | raise HTTPException(status_code=422, detail=error_response.model_dump())
112 |
113 | except json.JSONDecodeError as e:
114 | # 专门处理JSON解析错误,这通常发生在OpenAI响应解析时
115 | bound_logger.exception(
116 | f"JSON解析错误 - Error: {str(e)}, Position: {e.pos if hasattr(e, 'pos') else 'unknown'}"
117 | )
118 | error_response = get_error_response(
119 | 502,
120 | message="上游服务返回无效JSON格式",
121 | details={"json_error": str(e), "request_id": request_id},
122 | )
123 | raise HTTPException(status_code=502, detail=error_response.model_dump())
124 | except HTTPException as e:
125 | bound_logger.exception(
126 | f"处理非流式消息请求错误 - Type: {type(e).__name__}, Error: {str(e)}"
127 | )
128 | error_response = get_error_response(
129 | e.status_code, message=str(e.detail), details={"request_id": request_id}
130 | )
131 | raise HTTPException(
132 | status_code=e.status_code,
133 | detail=error_response.model_dump(exclude_none=True),
134 | )
135 |
136 | except Exception as e:
137 | bound_logger.exception(
138 | f"处理非流式消息请求错误 - Type: {type(e).__name__}, Error: {str(e)}"
139 | )
140 | error_response = get_error_response(
141 | 500, message=str(e), details={"request_id": request_id}
142 | )
143 | raise HTTPException(
144 | status_code=500, detail=error_response.model_dump(exclude_none=True)
145 | )
146 |
147 | async def process_stream_message(
148 |         self, request: AnthropicRequest, request_id: str | None = None
149 | ) -> AsyncGenerator[str, None]:
150 | """处理流式消息请求,使用新的流式转换器"""
151 | if not request.stream:
152 | raise ValueError("流式响应参数必须为true")
153 |
154 | # 获取绑定了请求ID的logger
155 | from src.common.logging import get_logger_with_request_id
156 |
157 | bound_logger = get_logger_with_request_id(request_id)
158 |
159 | try:
160 | # await validate_anthropic_request(request, request_id)
161 | openai_request = await self.request_converter.convert_anthropic_to_openai(
162 | request, request_id
163 | )
164 |
165 | # 创建 OpenAI 流式数据源
166 | async def openai_stream_generator():
167 | bound_logger.info("开始OpenAI流式生成")
168 | chunk_count = 0
169 | async for chunk in self.client.send_streaming_request(
170 | openai_request, request_id=request_id
171 | ):
172 | # 跳过被解析器过滤掉的不完整chunk(通常是tool_calls片段)
173 | if chunk is not None:
174 | chunk_count += 1
175 | # 将 OpenAI 响应对象转换为字符串格式
176 | bound_logger.debug(f"OpenAI event: {chunk}")
177 | yield f"{chunk}\n\n"
178 | bound_logger.debug(f"OpenAI流式生成完成,总共{chunk_count}个chunk")
179 |
180 | # 使用新的流式转换器
181 | bound_logger.info("开始流式转换")
182 | async for (
183 | anthropic_event
184 | ) in self.response_converter.convert_openai_stream_to_anthropic_stream(
185 | openai_stream_generator(), model=request.model, request_id=request_id
186 | ):
187 | bound_logger.debug(f"Anthropic event: {anthropic_event}")
188 | yield anthropic_event
189 | bound_logger.info("流式转换完成")
190 |
191 | except (ValidationError, ValueError) as e:
192 | error_detail = e.errors() if hasattr(e, "errors") else str(e)
193 | bound_logger.warning(f"流式请求验证失败 - Errors: {error_detail}")
194 | error_response = get_error_response(422, message=str(error_detail))
195 | # 在错误响应中添加请求ID
196 | error_data = error_response.model_dump()
197 | if request_id:
198 | error_data["request_id"] = request_id
199 | yield f"event: error\ndata: {json.dumps(error_data, ensure_ascii=False)}\n\n"
200 |
201 | except json.JSONDecodeError as e:
202 | # 专门处理流式模式下的JSON解析错误
203 | bound_logger.exception(
204 | f"流式模式JSON解析错误 - Error: {str(e)}, Position: {e.pos if hasattr(e, 'pos') else 'unknown'}"
205 | )
206 | error_response = get_error_response(
207 | 502,
208 | message="流式响应中发现无效JSON格式",
209 | details={"json_error": str(e), "request_id": request_id},
210 | )
211 | error_data = error_response.model_dump()
212 | if request_id:
213 | error_data["request_id"] = request_id
214 | yield f"event: error\ndata: {json.dumps(error_data, ensure_ascii=False)}\n\n"
215 |
216 | except Exception as e:
217 | bound_logger.exception(
218 | f"流式请求处理错误 - Type: {type(e).__name__}, Error: {str(e)}"
219 | )
220 | error_response = get_error_response(500, message=str(e))
221 | # 在错误响应中添加请求ID
222 | error_data = error_response.model_dump()
223 | if request_id:
224 | error_data["request_id"] = request_id
225 | yield f"event: error\ndata: {json.dumps(error_data, ensure_ascii=False)}\n\n"
226 |
227 |
228 | @router.post("/messages")
229 | async def messages_endpoint(request: Request, background_tasks: BackgroundTasks):
230 | """
231 | Anthropic /v1/messages 端点
232 |
233 | 这个端点实现了Anthropic原生messages API的主要功能:
234 | - 接受Anthropic格式的请求
235 | - 转换为OpenAI格式发送到后端
236 | - 返回Anthropic格式的响应
237 | """
238 | # 从应用状态获取消息处理器(已由main.py在启动时初始化)
239 | handler: MessagesHandler = request.app.state.messages_handler
240 |
241 | # 获取请求ID(由中间件生成,如果启用的话)
242 | from src.common.logging import (
243 | get_logger_with_request_id,
244 | get_request_id_from_request,
245 | )
246 |
247 | request_id = get_request_id_from_request(request)
248 | bound_logger = get_logger_with_request_id(request_id)
249 |
250 | # 记录请求
251 | client_ip = request.client.host if request.client else "unknown"
252 | bound_logger.info(
253 | f"收到Anthropic请求 - Method: {request.method}, URL: {str(request.url)}, IP: {client_ip}"
254 | )
255 |
256 | try:
257 | # 解析请求体
258 | body = await request.json()
259 | # 记录请求
260 | log_body = body.copy()
261 | log_body["tools"] = []
262 | bound_logger.debug(
263 | f"Anthropic请求体 - Model: {body.get('model', 'unknown')}, Messages: {len(body.get('messages', []))}, Stream: {body.get('stream', False)}\n{json.dumps(log_body, ensure_ascii=False, indent=2)}"
264 | )
265 |
266 | anthropic_request = AnthropicRequest(**body)
267 |
268 | # 记录清理后的请求信息(移除敏感信息)
269 | # safe_body = sanitize_for_logging(body)
270 | # logger.debug("请求已清理", request_body=safe_body)
271 |
272 | # 根据请求类型处理响应
273 | if anthropic_request.stream:
274 | # 流式响应 - 优化配置确保真正的流式效果
275 | async def stream_wrapper():
276 | """包装器确保流式响应的立即传输"""
277 | try:
278 | async for chunk in handler.process_stream_message(
279 | anthropic_request, request_id=request_id
280 | ):
281 | # 立即传输每个chunk,不缓冲
282 | # chunk已经是完整的SSE格式字符串,编码后返回
283 | yield chunk.encode("utf-8")
284 | # 强制刷新缓冲区(在某些环境中有效)
285 | await asyncio.sleep(0)
286 | except Exception as e:
287 | # 如果流式处理出错,记录完整错误并发送错误事件
288 | bound_logger.exception(f"流式处理出错 - Error: {str(e)}")
289 | error_data = {"error": str(e)}
290 | if request_id:
291 | error_data["request_id"] = request_id
292 | error_event = f"event: error\ndata: {json.dumps(error_data)}\n\n"
293 | yield error_event.encode("utf-8")
294 |
295 | return StreamingResponse(
296 | stream_wrapper(),
297 | media_type="text/event-stream; charset=utf-8",
298 | headers={
299 | "Cache-Control": "no-cache, no-store, must-revalidate",
300 | "Pragma": "no-cache",
301 | "Expires": "0",
302 | "Connection": "keep-alive",
303 | "X-Accel-Buffering": "no", # 禁用nginx缓冲
304 | "X-Content-Type-Options": "nosniff",
305 | "Transfer-Encoding": "chunked",
306 | "Access-Control-Allow-Origin": "*", # CORS支持
307 | "Access-Control-Allow-Headers": "*",
308 | "Access-Control-Allow-Methods": "*",
309 | "X-Proxy-Buffering": "no", # 禁用代理缓冲
310 | "Buffering": "no", # 禁用缓冲
311 | },
312 | )
313 | else:
314 | # 非流式响应
315 | response = await handler.process_message(
316 | anthropic_request, request_id=request_id
317 | )
318 | json_response = JSONResponse(content=response.model_dump(exclude_none=True))
319 | if request_id:
320 | json_response.headers["X-Request-ID"] = request_id
321 | return json_response
322 |
323 | except ValidationError as e:
324 | bound_logger.warning(f"请求验证失败 - Errors: {e.errors()}")
325 | error_response = get_error_response(
326 | 422, details={"validation_errors": e.errors()}
327 | )
328 | error_detail = error_response.model_dump()
329 | error_detail["request_id"] = request_id
330 | raise HTTPException(status_code=422, detail=error_detail)
331 |
332 | except json.JSONDecodeError as e:
333 | bound_logger.warning(f"请求中的JSON格式错误 - Error: {str(e)}")
334 | error_response = get_error_response(400, message="无效的JSON格式")
335 | error_detail = error_response.model_dump()
336 | error_detail["request_id"] = request_id
337 | raise HTTPException(status_code=400, detail=error_detail)
338 |
339 | except Exception as e:
340 | # 检查是否为HTTPException,避免重复记录已处理的错误
341 | if isinstance(e, HTTPException):
342 | # HTTPException已经在内层处理过,直接重新抛出
343 | raise e
344 |
345 | bound_logger.exception(
346 | f"在messages端点发生意外错误 - Type: {type(e).__name__}, Error: {str(e)}"
347 | )
348 | error_response = get_error_response(500, message=str(e))
349 | error_detail = error_response.model_dump()
350 | error_detail["request_id"] = request_id
351 | raise HTTPException(status_code=500, detail=error_detail)
352 |
--------------------------------------------------------------------------------
/src/api/middleware/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | 中间件模块
3 |
4 | 提供FastAPI应用的各种中间件实现。
5 |
6 | 主要功能:
7 | - API密钥认证中间件
8 | - 请求计时中间件
9 | - 请求ID追踪
10 | - 中间件配置和设置
11 |
12 | 使用示例:
13 | from src.api.middleware import setup_middlewares
14 |
15 | setup_middlewares(app)
16 | """
17 |
18 | # 导入中间件实现
19 | from .auth import APIKeyMiddleware
20 | from .timing import RequestTimingMiddleware, setup_middlewares
21 |
22 | __all__ = [
23 | "APIKeyMiddleware",
24 | "RequestTimingMiddleware",
25 | "setup_middlewares",
26 | ]
27 |
--------------------------------------------------------------------------------
/src/api/middleware/auth.py:
--------------------------------------------------------------------------------
1 | from fastapi import Request
2 | from starlette.middleware.base import BaseHTTPMiddleware
3 |
4 | from src.models.errors import get_error_response
5 |
6 |
7 | class APIKeyMiddleware(BaseHTTPMiddleware):
8 | """中间件:验证API密钥"""
9 |
10 | def __init__(self, app, api_key: str):
11 | super().__init__(app)
12 | self.api_key = api_key
13 |
14 | async def dispatch(self, request: Request, call_next):
15 | # 检查是否为需要密钥验证的路径
16 | if not self._requires_auth(request.url.path):
17 | # 跳过认证,直接处理请求
18 | response = await call_next(request)
19 | return response
20 |
21 | # 从请求头中获取API密钥
22 | token = request.headers.get("x-api-key")
23 |
24 | if token != self.api_key:
25 | error_response = get_error_response(401, message="API密钥无效")
26 |
27 | # 直接返回401响应,而不是抛出异常
28 | from fastapi.responses import JSONResponse
29 |
30 |             return JSONResponse(status_code=401, content=error_response.model_dump())
31 |
32 | response = await call_next(request)
33 | return response
34 |
35 | def _requires_auth(self, path: str) -> bool:
36 | """检查路径是否需要API密钥验证"""
37 | # 只有 /v1/messages 需要密钥验证
38 | auth_required_paths = ["/v1/messages"]
39 | return path in auth_required_paths
40 |
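41 | # 使用示意(注册代码不在本文件中;假设应用组装处已持有配置对象 config):
42 | #
43 | #   app.add_middleware(APIKeyMiddleware, api_key=config.api_key)
44 | #
45 | # 此后客户端访问 /v1/messages 时,需在请求头中携带 x-api-key。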
--------------------------------------------------------------------------------
/src/api/middleware/timing.py:
--------------------------------------------------------------------------------
1 | """请求ID中间件"""
2 |
3 | import time
4 | from collections.abc import Callable
5 |
6 | from fastapi import FastAPI, Request
7 | from starlette.middleware.base import BaseHTTPMiddleware
8 | from starlette.responses import Response
9 | from starlette.types import ASGIApp
10 |
11 | # RequestIDMiddleware 已移除 - 如需request_id功能可重新添加
12 |
13 |
14 | class RequestTimingMiddleware(BaseHTTPMiddleware):
15 | """记录请求处理时间的中间件"""
16 |
17 | def __init__(self, app: ASGIApp) -> None:
18 | super().__init__(app)
19 |
20 | async def dispatch(self, request: Request, call_next: Callable) -> Response:
21 | start_time = time.time()
22 |
23 | # 延迟导入
24 | from src.common.logging import (
25 | generate_request_id,
26 | get_logger_with_request_id,
27 | get_request_id_header_name,
28 | )
29 |
30 | # 生成请求ID并添加到请求状态中(默认启用)
31 | request_id = await generate_request_id()
32 | request.state.request_id = request_id
33 |
34 | # 获取绑定了请求ID的logger
35 | bound_logger = get_logger_with_request_id(request_id)
36 |
37 | try:
38 | response = await call_next(request)
39 |
40 | response_time = time.time() - start_time
41 | response_time_ms = round(response_time * 1000, 2)
42 |
43 | # 使用绑定了请求ID的logger记录响应
44 | bound_logger.info(
45 | f"请求完成 - Status: {response.status_code}, Time: {response_time_ms}ms"
46 | )
47 |
48 | response.headers["X-Process-Time"] = f"{response_time:.3f}s"
49 | header_name = await get_request_id_header_name()
50 | response.headers[header_name] = request_id
51 |
52 | return response
53 |
54 | except Exception as exc:
55 | response_time = time.time() - start_time
56 | error_content = (
57 | f'{{"error":"Internal Server Error","request_id":"{request_id}"}}'
58 | )
59 |
60 | response = Response(
61 | content=error_content,
62 | status_code=500,
63 | media_type="application/json",
64 | )
65 | header_name = await get_request_id_header_name()
66 | response.headers[header_name] = request_id
67 |
68 | # 使用绑定了请求ID的logger记录错误
69 | # 安全构造错误日志 - 避免格式字符串问题
70 | safe_url = str(request.url) if hasattr(request, "url") else "unknown"
71 | safe_method = request.method if hasattr(request, "method") else "unknown"
72 |
73 |             # loguru 不支持 stdlib logging 的 exc_info 参数,
74 |             # 这里用 opt(exception=exc) 记录完整堆栈
75 |             bound_logger.opt(exception=exc).error(
76 |                 f"请求处理错误 - Type: {type(exc).__name__}, "
77 |                 f"Error: {str(exc)}, "
78 |                 f"Method: {safe_method}, "
79 |                 f"URL: {safe_url}"
80 |             )
81 |
82 | return response
83 |
84 |
85 | def setup_middlewares(app: FastAPI) -> None:
86 | """设置所有中间件"""
87 | # 只保留请求计时中间件
88 | app.add_middleware(RequestTimingMiddleware)
89 |
--------------------------------------------------------------------------------
/src/api/routes.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timezone
2 | from typing import Any
3 |
4 | from fastapi import APIRouter, Depends
5 |
6 | from src.config.settings import Config
7 | from src.core.clients.openai_client import OpenAIServiceClient
8 |
9 | router = APIRouter()
10 |
11 |
12 | async def get_openai_client() -> OpenAIServiceClient:
13 | """获取OpenAI客户端实例"""
14 | config = await Config.from_file()
15 | return OpenAIServiceClient(
16 | api_key=config.openai.api_key,
17 | base_url=config.openai.base_url,
18 | )
19 |
20 |
21 | @router.get("/health", tags=["health"])
22 | async def health_check(
23 | client: OpenAIServiceClient = Depends(get_openai_client),
24 | ) -> dict[str, Any]:
25 | """健康检查端点 - 验证OpenAI连通性"""
26 |
27 | health_status = {
28 | "status": "healthy",
29 | "service": "openai-to-claude",
30 | "timestamp": datetime.now().isoformat(),
31 | "checks": {},
32 | }
33 |
34 | try:
35 | # 检查OpenAI服务可用性
36 | openai_health = await client.health_check()
37 | health_status["checks"]["openai"] = openai_health
38 |
39 | # 如果任何一个检查失败,状态设为降级
40 | if not all(openai_health.values()):
41 | health_status["status"] = "degraded"
42 |
43 | except Exception as e:
44 | # 如果无法创建客户端或者检查抛出异常
45 | health_status["status"] = "unhealthy"
46 | health_status["checks"]["openai"] = {
47 | "openai_service": False,
48 | "api_accessible": False,
49 | "error": str(e),
50 | }
51 |
52 | return health_status
53 |
--------------------------------------------------------------------------------
/src/common/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | 通用工具模块
3 |
4 | 提供项目中共享的工具和实用功能。
5 |
6 | 主要功能:
7 | - 日志配置和管理
8 | - 请求ID生成和追踪
9 | - Token计数功能
10 | - 异常处理工具
11 |
12 | 使用示例:
13 | from src.common import configure_logging, request_logger
14 |
15 | configure_logging()
16 | logger = request_logger.get_logger()
17 | """
18 |
19 | # 导入日志相关功能
20 | from .logging import (
21 | RequestLogger,
22 | configure_logging,
23 | generate_request_id,
24 | get_logger_with_request_id,
25 | get_request_id_from_request,
26 | get_request_id_header_name,
27 | log_exception,
28 | request_logger,
29 | should_enable_request_id,
30 | )
31 |
32 | # Token-counting exports
33 | from .token_counter import TokenCounter, token_counter
34 |
35 | __all__ = [
36 |     # Logging
37 | "configure_logging",
38 | "RequestLogger",
39 | "request_logger",
40 | "generate_request_id",
41 | "should_enable_request_id",
42 | "get_request_id_header_name",
43 | "get_request_id_from_request",
44 | "log_exception",
45 | "get_logger_with_request_id",
46 |     # Token counting
47 | "TokenCounter",
48 | "token_counter",
49 | ]
50 |
--------------------------------------------------------------------------------
/src/common/logging.py:
--------------------------------------------------------------------------------
1 | """Loguru logging configuration."""
2 |
3 | import sys
4 | import uuid
5 | from pathlib import Path
6 |
7 | from loguru import logger
8 |
9 |
10 | def configure_logging(log_config) -> None:
11 |     """Configure the Loguru logging system.
12 |
13 |     Args:
14 |         log_config: logging configuration object
15 |     """
16 |     # Remove the default handler
17 | logger.remove()
18 |
19 |     # Use a relative path rather than an absolute one
20 | log_path = Path("logs/app.log")
21 |
22 |     # Make sure the log directory exists with the right permissions
23 |     log_path.parent.mkdir(parents=True, exist_ok=True)
24 |     # Set directory permissions to 755 so the current user can write
25 | log_path.parent.chmod(0o755)
26 |
27 |     # If the log file already exists, make sure it is writable
28 | if log_path.exists():
29 | log_path.chmod(0o644)
30 |
31 |     # Console log format (includes the request ID)
32 | console_format = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {extra[request_id]} | {name}:{line} - {message}"
33 |
34 |     # Configure console logging
35 | logger.add(
36 | sys.stdout,
37 | format=console_format,
38 | level=log_config.level,
39 | colorize=True,
40 | filter=lambda record: record["extra"].setdefault("request_id", "---"),
41 | )
42 |
43 |     # Configure file logging (includes truncated exception stacks)
44 | logger.add(
45 | str(log_path),
46 | level=log_config.level,
47 | rotation="10 MB",
48 | retention="1 day",
49 | encoding="utf-8",
50 |         enqueue=True,  # asynchronous writes
51 | filter=lambda record: record["extra"].setdefault("request_id", "---"),
52 | )
53 |
54 |     # Configure global exception handling
55 |     def exception_handler(exc_type, exc_value, exc_traceback):
56 |         """Global exception handler."""
57 |         if issubclass(exc_type, KeyboardInterrupt):
58 |             # Let KeyboardInterrupt exit normally
59 | sys.__excepthook__(exc_type, exc_value, exc_traceback)
60 | return
61 |
62 |         logger.opt(exception=(exc_type, exc_value, exc_traceback)).critical(
63 |             "Uncaught exception"
64 |         )
65 |
66 |     # Install the global exception handler
67 | sys.excepthook = exception_handler
68 |
69 |
70 | class RequestLogger:
71 |     """Request logging helper."""
72 |
73 | async def log_response(
74 | self, status_code: int, response_time: float, request_id: str = None
75 | ):
76 |         """Log the completion of a response."""
77 | bound_logger = get_logger_with_request_id(request_id)
78 |
79 | response_time_ms = round(response_time * 1000, 2)
80 |         bound_logger.info(
81 |             f"Request completed - Status: {status_code}, Time: {response_time_ms}ms"
82 |         )
83 |
84 | async def log_error(
85 | self, error: Exception, context: dict = None, request_id: str = None
86 | ):
87 |         """Log an error."""
88 | bound_logger = get_logger_with_request_id(request_id)
89 |
90 | error_type = type(error).__name__
91 | error_message = str(error)
92 | context_str = f", Context: {context}" if context else ""
93 |
94 |         # Use loguru's exception method to record the full stack trace
95 |         bound_logger.exception(
96 |             f"Request processing error - Type: {error_type}, Message: {error_message}{context_str}"
97 | )
98 |
99 |
100 | # Global logger instance
101 | request_logger = RequestLogger()
102 |
103 |
104 | async def generate_request_id() -> str:
105 |     """Generate a unique request ID.
106 |
107 |     Returns:
108 |         str: request ID of the form "req_" + 32 hex characters (uuid4().hex)
109 | """
110 | return f"req_{uuid.uuid4().hex}"
111 |
112 |
113 | async def should_enable_request_id() -> bool:
114 |     """Whether the request-ID feature should be enabled (always on).
115 |
116 |     Returns:
117 |         bool: always True; request IDs are enabled by default
118 | """
119 | return True
120 |
121 |
122 | async def get_request_id_header_name() -> str:
123 |     """Return the response-header name used for the request ID.
124 |
125 |     Returns:
126 |         str: always "X-Request-ID"
127 | """
128 | return "X-Request-ID"
129 |
130 |
131 | def get_request_id_from_request(request) -> str | None:
132 |     """Safely extract the request ID from a request object.
133 |
134 |     Args:
135 |         request: FastAPI Request object
136 |
137 |     Returns:
138 |         str | None: the request ID, or None if absent
139 | """
140 | try:
141 | return getattr(request.state, "request_id", None)
142 | except AttributeError:
143 | return None
144 |
145 |
146 | async def log_exception(message: str = "An exception occurred", **kwargs):
147 |     """Convenience helper for logging exceptions.
148 |
149 |     Usage example:
150 |         try:
151 |             # some code that may fail
152 |             pass
153 |         except Exception as e:
154 |             await log_exception("Error while handling request", request_id="123", user_id="456")
155 |
156 |     Args:
157 |         message: description of the exception
158 |         **kwargs: additional context fields
159 | """
160 | kwargs_str = ", ".join([f"{k}: {v}" for k, v in kwargs.items()]) if kwargs else ""
161 | full_message = f"{message} - {kwargs_str}" if kwargs_str else message
162 | logger.exception(full_message)
163 |
164 |
165 | def get_logger_with_request_id(request_id: str = None):
166 |     """Return a logger instance bound to the given request ID.
167 |
168 |     Args:
169 |         request_id: the request ID; a "---" placeholder is bound when None
170 |
171 |     Returns:
172 |         logger instance with the request ID bound
173 | """
174 | if request_id:
175 | return logger.bind(request_id=request_id)
176 | else:
177 | return logger.bind(request_id="---")
178 |
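179 |
180 | # --- Illustrative usage sketch (editor's addition, not part of the source) ---
181 | # Binding a request ID so every log line carries it; the ID helper above is
182 | # async, hence asyncio.run:
183 | #
184 | #     import asyncio
185 | #     rid = asyncio.run(generate_request_id())   # e.g. "req_3f2a..." (made-up value)
186 | #     log = get_logger_with_request_id(rid)
187 | #     log.info("hello")   # console format interpolates {extra[request_id]}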
--------------------------------------------------------------------------------
/src/common/token_cache.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple request-token cache module
3 |
4 | Following the KISS principle, a global dict temporarily maps request IDs to
5 | token counts; used mainly as a fallback when an OpenAI response lacks usage info.
6 | """
7 |
8 | from typing import Dict, Optional
9 |
10 | # Global cache dict - KISS principle
11 | _cache: Dict[str, int] = {}
12 |
13 |
14 | def cache_tokens(request_id: str, tokens: int) -> None:
15 | """
16 |     Cache the token count for a request.
17 |
18 |     Args:
19 |         request_id: request ID
20 |         tokens: token count
21 | """
22 | if request_id and tokens > 0:
23 | _cache[request_id] = tokens
24 |
25 |
26 | def get_cached_tokens(request_id: str, delete=False) -> Optional[int]:
27 | """
28 |     Get the cached token count for a request.
29 |
30 |     Args:
31 |         request_id: request ID
32 |         delete: when True, remove the entry after reading it
33 |
34 |     Returns:
35 |         the cached token count, or None if absent
36 |     Note:
37 |         Pass delete=True to pop() the entry after reading, avoiding leaks.
38 | """
39 | if not request_id:
40 | return None
41 | if delete:
42 | return _cache.pop(request_id, None)
43 | else:
44 | return _cache.get(request_id, None)
45 |
46 |
47 | def get_cache_size() -> int:
48 | """
49 |     Return the current cache size, for debugging and monitoring.
50 |
51 |     Returns:
52 |         number of entries in the cache
53 | """
54 | return len(_cache)
55 |
56 |
57 | def clear_cache() -> None:
58 | """
59 |     Clear the entire cache, for tests or resets.
60 | """
61 | _cache.clear()
62 |
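63 |
64 | # --- Illustrative usage sketch (editor's addition, not part of the source) ---
65 | # Round trip through the cache; the request ID is a made-up example:
66 | #
67 | #     cache_tokens("req_abc", 42)
68 | #     get_cached_tokens("req_abc")                # -> 42 (entry kept)
69 | #     get_cached_tokens("req_abc", delete=True)   # -> 42 (entry popped)
70 | #     get_cached_tokens("req_abc")                # -> None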
--------------------------------------------------------------------------------
/src/common/token_counter.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Any
3 |
4 | import tiktoken
5 |
6 |
7 | class TokenCounter:
8 |     """Token counter; a full port of the Node.js implementation."""
9 |
10 | def __init__(self):
11 | self.encoder = tiktoken.get_encoding("o200k_base")
12 |
13 | def _extract_text_content(self, obj, field_name: str) -> str:
14 |         """Uniform helper for extracting text content; reduces duplication."""
15 | if hasattr(obj, field_name):
16 | return str(getattr(obj, field_name, ""))
17 | elif isinstance(obj, dict):
18 | return str(obj.get(field_name, ""))
19 | return ""
20 |
21 | def _process_content_part(self, content_part) -> list[str]:
22 |         """Process a message content part and return its text pieces."""
23 | texts = []
24 | part_type = (
25 | content_part.type
26 | if hasattr(content_part, "type")
27 | else content_part.get("type") if isinstance(content_part, dict) else None
28 | )
29 |
30 | if part_type == "text":
31 | text = self._extract_text_content(content_part, "text")
32 | if text:
33 | texts.append(text)
34 | elif part_type == "tool_use":
35 | input_data = (
36 | getattr(content_part, "input", {})
37 | if hasattr(content_part, "input")
38 | else content_part.get("input", {})
39 | )
40 | if input_data:
41 | texts.append(json.dumps(input_data, ensure_ascii=False))
42 |
43 | return texts
44 |
45 | async def count_tokens(
46 | self,
47 | messages: list[Any] = None,
48 | system: Any = None,
49 | tools: list[Any] = None,
50 | ) -> int:
51 |         """Count the total number of tokens for a complete request.
52 |
53 |         Args:
54 |             messages: message list
55 |             system: system prompt
56 |             tools: tool definitions
57 |
58 |         Returns:
59 |             int: total token count
60 | """
61 |         # Collect all text into a single list (KISS principle)
62 | text_parts = []
63 |
64 |         # Process message content
65 | if messages:
66 | for message in messages:
67 |                 # Extract the message content
68 | content = (
69 | message.content
70 | if hasattr(message, "content")
71 | else message.get("content", "") if isinstance(message, dict) else ""
72 | )
73 |
74 | if isinstance(content, str):
75 | text_parts.append(content)
76 | elif isinstance(content, list):
77 | for content_part in content:
78 | text_parts.extend(self._process_content_part(content_part))
79 |
80 |         # Process the system prompt
81 | if system:
82 | if isinstance(system, str):
83 | text_parts.append(system)
84 | elif isinstance(system, list):
85 | for item in system:
86 | item_type = (
87 | item.type
88 | if hasattr(item, "type")
89 | else item.get("type") if isinstance(item, dict) else None
90 | )
91 | if item_type == "text":
92 | text_content = self._extract_text_content(item, "text")
93 | if text_content:
94 | text_parts.append(text_content)
95 |
96 |         # Process tool definitions
97 | if tools:
98 | for tool in tools:
99 |                 # Fetch name and description uniformly
100 | name = self._extract_text_content(tool, "name")
101 | description = self._extract_text_content(tool, "description")
102 |
103 | if name:
104 | text_parts.append(name)
105 | if description:
106 | text_parts.append(description)
107 |
108 |                 # Process the input schema
109 | schema = (
110 | getattr(tool, "input_schema", None)
111 | if hasattr(tool, "input_schema")
112 | else tool.get("input_schema") if isinstance(tool, dict) else None
113 | )
114 | if schema:
115 | text_parts.append(json.dumps(schema, ensure_ascii=False))
116 |
117 |         # Join all text once and count tokens (KISS principle)
118 | combined_text = "".join(text_parts)
119 | return len(self.encoder.encode(combined_text))
120 |
121 | def count_response_tokens(self, content_blocks: list) -> int:
122 |         """Count the tokens in response content.
123 |
124 |         Args:
125 |             content_blocks: list of Anthropic-format content blocks
126 |
127 |         Returns:
128 |             int: token count of the response content
129 | """
130 |         # Collect all text into a single list (KISS principle)
131 | text_parts = []
132 |
133 |         # Process content blocks
134 | if content_blocks:
135 | for block in content_blocks:
136 |                 # Text content
137 | if hasattr(block, "text") and block.text:
138 | text_parts.append(str(block.text))
139 | elif isinstance(block, dict) and block.get("text"):
140 | text_parts.append(str(block["text"]))
141 |
142 |                 # Thinking content
143 | if hasattr(block, "thinking") and block.thinking:
144 | text_parts.append(str(block.thinking))
145 | elif isinstance(block, dict) and block.get("thinking"):
146 | text_parts.append(str(block["thinking"]))
147 |
148 |                 # Tool-call input
149 | if hasattr(block, "input") and block.input:
150 | text_parts.append(json.dumps(block.input, ensure_ascii=False))
151 | elif isinstance(block, dict) and block.get("input"):
152 | text_parts.append(json.dumps(block["input"], ensure_ascii=False))
153 |
154 |                 # Tool name
155 | if hasattr(block, "name") and block.name:
156 | text_parts.append(str(block.name))
157 | elif isinstance(block, dict) and block.get("name"):
158 | text_parts.append(str(block["name"]))
159 |
160 |         # Join all text once and count tokens (KISS principle)
161 | combined_text = "".join(text_parts)
162 | return len(self.encoder.encode(combined_text))
163 |
164 |
165 | # Global instance
166 | token_counter = TokenCounter()
167 |
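168 |
169 | # --- Illustrative usage sketch (editor's addition, not part of the source) ---
170 | # count_tokens is async, so drive it with asyncio.run; the message uses the
171 | # dict shape accepted above:
172 | #
173 | #     import asyncio
174 | #     n = asyncio.run(
175 | #         token_counter.count_tokens(
176 | #             messages=[{"content": "Hello, world"}], system="Be brief."
177 | #         )
178 | #     )
179 | #     print(n)   # token count under the o200k_base encoding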
--------------------------------------------------------------------------------
/src/config/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Configuration management module
3 |
4 | Provides loading, validation, and hot reloading of application configuration.
5 |
6 | Main features:
7 | - Config file loading and validation
8 | - Hot-reload watching of the config file
9 | - Config model definitions
10 | - Config instance management
11 |
12 | Usage example:
13 |     from src.config import get_config, reload_config
14 |
15 |     config = await get_config()
16 |     print(config.server.host)
17 | """
18 |
19 | # Configuration exports
20 | from .settings import (
21 | Config,
22 | LoggingConfig,
23 | ModelConfig,
24 | OpenAIConfig,
25 | ParameterOverridesConfig,
26 | ServerConfig,
27 | get_config,
28 | get_config_file_path,
29 | reload_config,
30 | )
31 | from .watcher import ConfigFileHandler, ConfigWatcher
32 |
33 | __all__ = [
34 |     # Config management functions
35 | "get_config",
36 | "reload_config",
37 | "get_config_file_path",
38 |     # Config models
39 | "Config",
40 | "OpenAIConfig",
41 | "ServerConfig",
42 | "LoggingConfig",
43 | "ModelConfig",
44 | "ParameterOverridesConfig",
45 |     # Config watcher
46 | "ConfigWatcher",
47 | "ConfigFileHandler",
48 | ]
49 |
--------------------------------------------------------------------------------
/src/config/settings.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 |
5 | import aiofiles
6 | from loguru import logger
7 | from pydantic import BaseModel, Field, field_validator
8 |
9 | # Global config cache
10 | _config_instance = None
11 |
12 |
13 | async def get_config() -> "Config":
14 | """
15 |     Return the global config object (cached singleton).
16 | """
17 | global _config_instance
18 | if _config_instance is None:
19 | try:
20 | _config_instance = await Config.from_file()
21 | except Exception:
22 |             # If the config file cannot be read, fall back to a default config
23 | _config_instance = Config(
24 | openai={
25 | "api_key": "your-openai-api-key-here",
26 | "base_url": "https://api.openai.com/v1",
27 | }
28 | )
29 | return _config_instance
30 |
31 |
32 | async def reload_config(config_path: str | None = None) -> "Config":
33 |     """Reload the global config object.
34 |
35 |     Args:
36 |         config_path: config file path; the default path is used when None
37 |
38 |     Returns:
39 |         Config: the reloaded config instance
40 |
41 |     Note:
42 |         On load failure the previous config is kept; no exception is raised.
43 | """
44 | global _config_instance
45 |
46 | try:
47 |         # Try to load the new config
48 | new_config = await Config.from_file(config_path)
49 | _config_instance = new_config
50 |         logger.info(f"Config reloaded successfully: {new_config.model_dump_json()}")
51 | return _config_instance
52 | except Exception as e:
53 |         logger.error(f"Config reload failed; keeping previous config: {e}")
54 |         if _config_instance is None:
55 |             # With no previous config, create a default one
56 | _config_instance = Config(
57 | openai={
58 | "api_key": "your-openai-api-key-here",
59 | "base_url": "https://api.openai.com/v1",
60 | }
61 | )
62 | return _config_instance
63 |
64 |
65 | def get_config_file_path() -> str:
66 |     """Return the config file path currently in use.
67 |
68 |     Returns:
69 |         str: config file path
70 | """
71 | import os
72 |
73 | return os.getenv("CONFIG_PATH", "config/settings.json")
74 |
75 |
76 | class OpenAIConfig(BaseModel):
77 |     """OpenAI API configuration."""
78 |
79 |     api_key: str = Field(..., description="OpenAI API key")
80 |     base_url: str = Field("https://api.openai.com/v1", description="OpenAI API base URL")
81 |
82 |
83 | class ServerConfig(BaseModel):
84 |     """Server configuration."""
85 |
86 |     host: str = Field("0.0.0.0", description="host the service listens on")
87 |     port: int = Field(8000, gt=0, lt=65536, description="port the service listens on")
88 |
89 |
90 | class LoggingConfig(BaseModel):
91 |     """Logging configuration."""
92 |
93 | level: str = Field(
94 |         "INFO", description="log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)"
95 | )
96 |
97 | def __init__(self, **data):
98 |         """Support environment-variable overrides at init time."""
99 |         # Environment-variable override
100 | if "LOG_LEVEL" in os.environ:
101 | data["level"] = os.environ["LOG_LEVEL"]
102 |
103 | super().__init__(**data)
104 |
105 | @field_validator("level")
106 | @classmethod
107 | def validate_log_level(cls, v: str) -> str:
108 |         """Validate the log level."""
109 | valid_levels = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
110 | if v.upper() not in valid_levels:
111 |             raise ValueError(f"Log level must be one of: {', '.join(valid_levels)}")
112 | return v.upper()
113 |
114 |
115 | class ModelConfig(BaseModel):
116 |     """Model configuration.
117 |
118 |     Defines the model choice for each usage scenario.
119 | """
120 |
121 |     default: str = Field(
122 |         description="default general-purpose model", default="claude-3-5-sonnet-20241022"
123 |     )
124 |     small: str = Field(
125 |         description="lightweight model for simple tasks", default="claude-3-5-haiku-20241022"
126 |     )
127 |     tool: str = Field(
128 |         description="model dedicated to tool use", default="claude-3-5-sonnet-20241022"
129 |     )
130 |     think: str = Field(
131 |         description="deep-thinking model for complex reasoning tasks",
132 |         default="claude-3-7-sonnet-20250219",
133 |     )
134 |     long_context: str = Field(
135 |         description="long-context model", default="claude-3-7-sonnet-20250219"
136 |     )
137 |     web_search: str = Field(description="web-search model", default="gemini-2.5-flash")
138 |
139 |
140 | class ParameterOverridesConfig(BaseModel):
141 |     """Parameter-override configuration.
142 |
143 |     Lets an administrator set override values for model parameters in the config file.
144 |     When set, these values replace the corresponding parameters in client requests.
145 | """
146 |
147 |     max_tokens: int | None = Field(
148 |         None,
149 |         gt=0,
150 |         description="max_tokens override; when set it replaces the client's max_tokens",
151 |     )
152 |     temperature: float | None = Field(
153 |         None, ge=0.0, le=2.0, description="temperature override; controls output randomness"
154 |     )
155 |     top_p: float | None = Field(
156 |         None, ge=0.0, le=1.0, description="top_p sampling override; probability threshold for candidate tokens"
157 |     )
158 |     top_k: int | None = Field(
159 |         None, ge=0, description="top_k sampling override; number of candidate tokens"
160 |     )
161 |
162 |
163 | class Config(BaseModel):
164 |     """Application config root.
165 |
166 |     Configuration is loaded from a JSON config file.
167 |     Config file priority:
168 |     1. path given on the command line
169 |     2. path from the CONFIG_PATH environment variable
170 |     3. ./config/settings.json (default)
171 |     4. ./config/example.json (sample config)
172 |     5. built-in defaults
173 | """
174 |
175 |     # Per-module configs
176 | openai: OpenAIConfig
177 | server: ServerConfig = ServerConfig()
178 |     api_key: str = Field(..., description="API key for the /v1/messages endpoint")
179 | logging: LoggingConfig = LoggingConfig()
180 | models: ModelConfig = ModelConfig()
181 | parameter_overrides: ParameterOverridesConfig = ParameterOverridesConfig()
182 |
183 | @classmethod
184 | async def from_file(cls, config_path: str | None = None) -> "Config":
185 | """
186 |         Load configuration from a JSON config file.
187 |         Args:
188 |             config_path: JSON config file path; the default path is used when None
189 |
190 |         Returns:
191 |             Config: config instance
192 |
193 |         Raises:
194 |             json.JSONDecodeError: malformed JSON
195 |             ValidationError: config data failed validation
196 |             (a missing file falls back to the example config or defaults)
197 | """
198 | import os
199 |
200 | if config_path is None:
201 |             # Prefer the path from the environment variable
202 | config_path = os.getenv("CONFIG_PATH", "config/settings.json")
203 |
204 | config_file = Path(config_path)
205 |
206 | if config_file.exists():
207 | try:
208 | async with aiofiles.open(config_file, encoding="utf-8") as f:
209 | config_data = await f.read()
210 | config_data = json.loads(config_data)
211 | except json.JSONDecodeError as e:
212 |                 print(f"❌ Malformed config file: {e}")
213 | raise
214 | else:
215 |             print(f"⚠️ Config file {config_file.absolute()} does not exist")
216 |             print("📦 Using config/example.json as a template")
217 |
218 |             # Fall back to the example config
219 | example_file = Path("config/example.json")
220 | if example_file.exists():
221 | try:
222 | async with aiofiles.open(example_file, encoding="utf-8") as f:
223 | config_data = await f.read()
224 | config_data = json.loads(config_data)
225 |                     # Create settings.json as the actual config file
226 | async with aiofiles.open(config_file, "w", encoding="utf-8") as f:
227 | await f.write(
228 | json.dumps(config_data, indent=2, ensure_ascii=False)
229 | )
230 |                     print(f"✅ Created {config_file} from the template")
231 |
232 | except (json.JSONDecodeError, OSError) as e:
233 |                     print(f"❌ Could not create config file: {e}")
234 | config_data = {}
235 | else:
236 | config_data = {}
237 |
238 |         # Ensure the required openai section exists
239 | if "openai" not in config_data:
240 | config_data["openai"] = {
241 | "api_key": "your-openai-api-key-here",
242 | "base_url": "https://api.openai.com/v1",
243 | }
244 |
245 |         # Ensure api_key exists (it is a required field)
246 | if "api_key" not in config_data:
247 | config_data["api_key"] = "your-proxy-api-key-here"
248 |
249 | return cls(**config_data)
250 |
251 | @classmethod
252 | def from_file_sync(cls, config_path: str | None = None) -> "Config":
253 | """
254 |         Load configuration from a JSON config file (synchronously).
255 |         Args:
256 |             config_path: JSON config file path; the default path is used when None
257 |
258 |         Returns:
259 |             Config: config instance
260 |
261 |         Raises:
262 |             json.JSONDecodeError: malformed JSON
263 |             ValidationError: config data failed validation
264 |             (a missing file falls back to the example config or defaults)
265 | """
266 | import os
267 |
268 | if config_path is None:
269 |             # Prefer the path from the environment variable
270 | config_path = os.getenv("CONFIG_PATH", "config/settings.json")
271 |
272 | config_file = Path(config_path)
273 |
274 | if config_file.exists():
275 | try:
276 | with open(config_file, encoding="utf-8") as f:
277 | config_data = json.load(f)
278 | except json.JSONDecodeError as e:
279 |                 print(f"❌ Malformed config file: {e}")
280 | raise
281 | else:
282 |             print(f"⚠️ Config file {config_file.absolute()} does not exist")
283 |             print("📦 Using config/example.json as a template")
284 |
285 |             # Fall back to the example config
286 | example_file = Path("config/example.json")
287 | if example_file.exists():
288 | try:
289 | with open(example_file, encoding="utf-8") as f:
290 | config_data = json.load(f)
291 |                     # Create settings.json as the actual config file
292 | with open(config_file, "w", encoding="utf-8") as f:
293 | f.write(json.dumps(config_data, indent=2, ensure_ascii=False))
294 |                     print(f"✅ Created {config_file} from the template")
295 | except (json.JSONDecodeError, OSError) as e:
296 |                     print(f"❌ Could not create config file: {e}")
297 | config_data = {}
298 | else:
299 | config_data = {}
300 |
301 |         # Ensure the required openai section exists
302 | if "openai" not in config_data:
303 | config_data["openai"] = {
304 | "api_key": "your-openai-api-key-here",
305 | "base_url": "https://api.openai.com/v1",
306 | }
307 |
308 |         # Ensure api_key exists (it is a required field)
309 | if "api_key" not in config_data:
310 | config_data["api_key"] = "your-proxy-api-key-here"
311 |
312 | return cls(**config_data)
313 |
314 | def get_server_config(self) -> tuple[str, int]:
315 |         """Return the server config (host, port).
316 |
317 |         Returns:
318 |             tuple[str, int]: (host, port)
319 | """
320 | return self.server.host, self.server.port
321 |
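322 |
323 | # --- Illustrative usage sketch (editor's addition, not part of the source) ---
324 | # Loading the config synchronously at startup and asynchronously elsewhere:
325 | #
326 | #     cfg = Config.from_file_sync()      # blocking; used at module import time
327 | #     host, port = cfg.get_server_config()
328 | #
329 | #     # inside async code:
330 | #     cfg = await get_config()           # cached singleton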
--------------------------------------------------------------------------------
/src/config/watcher.py:
--------------------------------------------------------------------------------
1 | """Config-file watching and hot-reload module.
2 |
3 | Watches the config file for changes and reloads configuration automatically when it is modified.
4 | Uses the watchdog library to listen for filesystem events.
5 | """
6 |
7 | import asyncio
8 | import json
9 | import os
10 | import threading
11 | from collections.abc import Callable
12 | from concurrent.futures import ThreadPoolExecutor
13 | from pathlib import Path
14 | from typing import Any
15 |
16 | from loguru import logger
17 | from watchdog.events import FileSystemEventHandler
18 | from watchdog.observers import Observer
19 |
20 |
21 | class ConfigFileHandler(FileSystemEventHandler):
22 |     """Event handler for config-file changes."""
23 |
24 | def __init__(self, config_path: Path, callback: Callable[[], None]):
25 | """
26 |         Initialize the config-file handler.
27 |
28 |         Args:
29 |             config_path: path of the config file to watch
30 |             callback: callback invoked when the config file changes
31 | """
32 | self.config_path = config_path.resolve()
33 | self.callback = callback
34 | self._last_modified = 0
35 |
36 | def on_modified(self, event) -> None:
37 |         """Handle file-modification events."""
38 | if event.is_directory:
39 | return
40 |
41 |         # Check whether this is the config file we are watching
42 | event_path = Path(event.src_path).resolve()
43 | if event_path != self.config_path:
44 | return
45 |
46 |         # Debounce duplicate events
47 | try:
48 | current_modified = event_path.stat().st_mtime
49 | if current_modified == self._last_modified:
50 | return
51 | self._last_modified = current_modified
52 | except OSError:
53 | return
54 |
55 |         logger.info(f"Config file modified: {self.config_path}")
56 |
57 |         # Delay slightly so the file write can finish
58 | threading.Timer(0.1, self._execute_callback).start()
59 |
60 | def _execute_callback(self) -> None:
61 |         """Run the callback."""
62 | try:
63 | self.callback()
64 | except Exception as e:
65 |             logger.error(f"Config-reload callback failed: {e}")
66 |
67 |
68 | class ConfigWatcher:
69 |     """Config-file watcher.
70 |
71 |     Watches the given config file and triggers a reload when it changes.
72 | """
73 |
74 | def __init__(self, config_path: str | None = None):
75 | """
76 |         Initialize the config watcher.
77 |
78 |         Args:
79 |             config_path: config file path; the default path is used when None
80 | """
81 | if config_path is None:
82 | config_path = os.getenv("CONFIG_PATH", "config/settings.json")
83 |
84 | self.config_path = Path(config_path).resolve()
85 | self.observer: Observer | None = None
86 | self.handler: ConfigFileHandler | None = None
87 | self._reload_callbacks: list[Callable[[], None]] = []
88 | self._async_reload_callbacks: list[Callable[[], Any]] = []
89 | self._executor: ThreadPoolExecutor | None = None
90 |
91 | def add_reload_callback(self, callback: Callable[[], Any]) -> None:
92 | """
93 |         Register an async config-reload callback.
94 |
95 |         Args:
96 |             callback: async callback function
97 | """
98 | self._async_reload_callbacks.append(callback)
99 |
100 | async def start_watching(self) -> None:
101 |         """Start watching the config file for changes."""
102 | if self.observer is not None:
103 |             logger.warning("Config watcher is already running")
104 | return
105 |
106 | if not self.config_path.exists():
107 |             logger.warning(f"Config file does not exist; skipping watch: {self.config_path}")
108 | return
109 |
110 |         # Create a thread-pool executor for async callbacks
111 | if self._executor is None:
112 | self._executor = ThreadPoolExecutor(
113 | max_workers=1, thread_name_prefix="config-watcher"
114 | )
115 |
116 |         # Create the event handler
117 | self.handler = ConfigFileHandler(self.config_path, self._on_config_changed)
118 |
119 |         # Create the observer and start watching
120 | self.observer = Observer()
121 | watch_dir = self.config_path.parent
122 | self.observer.schedule(self.handler, str(watch_dir), recursive=False)
123 | self.observer.start()
124 |
125 |         logger.info(f"Started watching config file: {self.config_path}")
126 |
127 | def stop_watching(self) -> None:
128 |         """Stop watching the config file."""
129 | if self.observer is None:
130 | return
131 |
132 |         logger.info("Stopping config-file watcher")
133 | self.observer.stop()
134 | self.observer.join()
135 | self.observer = None
136 | self.handler = None
137 |
138 |         # Shut down the thread pool
139 | if self._executor is not None:
140 | self._executor.shutdown(wait=True)
141 | self._executor = None
142 |
143 | def _on_config_changed(self) -> None:
144 |         """Handle a config-file change."""
145 |         logger.info("Config file change detected; reloading...")
146 |
147 |         # Run async validation and callbacks in the thread pool
148 | if self._executor is not None:
149 | self._executor.submit(self._handle_config_change)
150 | else:
151 |             logger.error("Thread-pool executor not initialized; skipping config reload")
152 |
153 | async def _process_config_change(self) -> None:
154 |         """Async logic for handling a config change."""
155 |         # Validate the config file format
156 |         if not await self._validate_config_file():
157 |             logger.error("Invalid config file format; skipping reload")
158 | return
159 |
160 |         # Run the synchronous callbacks
161 | for callback in self._reload_callbacks:
162 | try:
163 | callback()
164 |                 logger.debug(f"Sync config-reload callback succeeded: {callback.__name__}")
165 | except Exception as e:
166 |                 logger.error(f"Sync config-reload callback failed {callback.__name__}: {e}")
167 |
168 |         # Run the async callbacks
169 | await self._execute_async_callbacks()
170 |
171 |         logger.info("Config reload complete")
172 |
173 | def _handle_config_change(self) -> None:
174 |         """Process the config change inside the thread pool."""
175 | try:
176 |             # Create an event loop in the new thread
177 | loop = asyncio.new_event_loop()
178 | asyncio.set_event_loop(loop)
179 | try:
180 |                 # Run the async validation and callbacks
181 | loop.run_until_complete(self._process_config_change())
182 | finally:
183 | loop.close()
184 | except Exception as e:
185 |             logger.error(f"Config-change handling failed: {e}")
186 |
187 | async def _execute_async_callbacks(self) -> None:
188 |         """Run the async callback functions."""
189 | for callback in self._async_reload_callbacks:
190 | try:
191 | if asyncio.iscoroutinefunction(callback):
192 | await callback()
193 | else:
194 | callback()
195 | except Exception as e:
196 |                 logger.error(f"Async config-reload callback failed {callback.__name__}: {e}")
197 |
198 | async def _validate_config_file(self) -> bool:
199 |         """Check that the config file contains valid JSON."""
200 | try:
201 | import aiofiles
202 |
203 | async with aiofiles.open(self.config_path, encoding="utf-8") as f:
204 | content = await f.read()
205 | json.loads(content)
206 | return True
207 | except (json.JSONDecodeError, OSError) as e:
208 |             logger.error(f"Config file validation failed: {e}")
209 | return False
210 |
211 |     async def __aenter__(self):
212 |         """Async context-manager entry (start_watching is a coroutine, so it must be awaited)."""
213 |         await self.start_watching()
214 |         return self
215 |
216 |     async def __aexit__(self, exc_type, exc_val, exc_tb):
217 |         """Async context-manager exit."""
218 |         # Ignore any exception info and simply stop watching
219 |         self.stop_watching()
220 |
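221 |
222 | # --- Illustrative usage sketch (editor's addition, not part of the source) ---
223 | # Wiring the watcher to an async reload callback (src/main.py does the real
224 | # wiring):
225 | #
226 | #     watcher = ConfigWatcher("config/settings.json")
227 | #     watcher.add_reload_callback(reload_config)   # async callback
228 | #     await watcher.start_watching()
229 | #     ...
230 | #     watcher.stop_watching()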
--------------------------------------------------------------------------------
/src/core/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Core functionality module
3 |
4 | Provides the proxy service's core features, including:
5 | - OpenAI client wrapper
6 | - request/response format converters
7 | - data model definitions
8 |
9 | Submodules:
10 | - clients: OpenAI API client implementation
11 | - converters: Anthropic ↔ OpenAI format converters
12 | """
13 |
14 | # Core client exports
15 | from .clients import OpenAIServiceClient
16 |
17 | # Converter exports
18 | from .converters import (
19 | AnthropicToOpenAIConverter,
20 | OpenAIToAnthropicConverter,
21 | )
22 |
23 | __all__ = [
24 |     # Clients
25 | "OpenAIServiceClient",
26 |     # Converters
27 | "AnthropicToOpenAIConverter",
28 | "OpenAIToAnthropicConverter",
29 | ]
30 |
--------------------------------------------------------------------------------
/src/core/clients/__init__.py:
--------------------------------------------------------------------------------
1 | from .openai_client import OpenAIServiceClient
2 |
3 | __all__ = ["OpenAIServiceClient"]
4 |
--------------------------------------------------------------------------------
/src/core/clients/openai_client.py:
--------------------------------------------------------------------------------
1 | """OpenAI API client for making asynchronous requests to OpenAI endpoints."""
2 |
3 | import json
4 | from collections.abc import AsyncGenerator
5 | from typing import Any
6 |
7 | import httpx
8 | from loguru import logger
9 |
10 | from src.models.errors import StandardErrorResponse, get_error_response
11 | from src.models.openai import OpenAIRequest, OpenAIStreamResponse
12 |
13 |
14 | class OpenAIClientError(Exception):
15 | """Base exception for OpenAI client errors."""
16 |
17 | def __init__(self, error_response: StandardErrorResponse):
18 | self.error_response = error_response
19 | super().__init__(str(error_response))
20 |
21 |
22 | class OpenAIServiceClient:
23 | """Async OpenAI API client with connection pooling and retry logic."""
24 |
25 | def __init__(
26 | self,
27 | api_key: str,
28 | base_url: str = "https://api.openai.com/v1",
29 | timeout: float = 60.0,
30 | ):
31 | """Initialize OpenAI client with connection pool.
32 |
33 | Args:
34 |             api_key: OpenAI API key
35 |             base_url: OpenAI API base URL
36 |             timeout: request timeout in seconds
37 | """
38 | self.api_key = api_key
39 | self.base_url = base_url.rstrip("/")
40 | self.timeout = timeout
41 |
42 | self.client = httpx.AsyncClient(
43 | headers={
44 | "Authorization": f"Bearer {api_key}",
45 | "Content-Type": "application/json",
46 | "Connection": "keep-alive",
47 | },
48 |             # httpx decompresses responses automatically; also follow redirects
49 | follow_redirects=True,
50 | timeout=timeout,
51 | )
52 |
53 | async def __aenter__(self):
54 | """Async context manager entry."""
55 | return self
56 |
57 | async def __aexit__(self, exc_type, exc_val, exc_tb):
58 | """Async context manager exit."""
59 | await self.aclose()
60 |
61 | async def aclose(self):
62 | """Close the HTTP client."""
63 | await self.client.aclose()
64 |
65 | async def send_request(
66 | self,
67 | request: OpenAIRequest,
68 | endpoint: str = "/chat/completions",
69 | request_id: str = None,
70 | ) -> dict[str, Any]:
71 | """Send synchronous request to OpenAI API.
72 |
73 | Args:
74 | request: OpenAI request object
75 | endpoint: API endpoint path
76 |             request_id: request ID for log tracing
77 |
78 | Returns:
79 |             OpenAI API response
80 |
81 | Raises:
82 |             OpenAIClientError: when the API returns an error
83 | """
84 |         # Get a logger bound to the request ID
85 | from src.common.logging import get_logger_with_request_id
86 |
87 | bound_logger = get_logger_with_request_id(request_id)
88 |
89 | url = f"{self.base_url}{endpoint}"
90 | request_data = request.model_dump(exclude_none=True)
91 |
92 |         # Log request details
93 |         bound_logger.info(
94 |             f"Sending OpenAI request - URL: {url}, Model: {request_data.get('model', 'unknown')}, Messages: {len(request_data.get('messages', []))}"
95 | )
96 |
97 | try:
98 | response = await self.client.post(
99 | url,
100 | json=request_data,
101 | )
102 | response.raise_for_status()
103 |             # Log response status
104 |             content_type = response.headers.get("content-type", "unknown")
105 |             bound_logger.info(
106 |                 f"Received OpenAI response - Status: {response.status_code}, Content-Type: {content_type}, Size: {len(response.content)} bytes"
107 | )
108 |
109 |             # Use response.text so httpx handles encoding and decompression
110 | try:
111 | text = response.text
112 | result = json.loads(text)
113 |
114 |                 # Log the response content (when debug logging is enabled)
115 |                 bound_logger.debug(
116 |                     f"OpenAI response content - ID: {result.get('id', 'unknown')}, Model: {result.get('model', 'unknown')}, Usage: {result.get('usage', {})}"
117 | )
118 |
119 | except json.JSONDecodeError as e:
120 |                 # Log detailed JSON-parse error information
121 | response_preview = (
122 | response.text[:500] if response.text else "Empty response"
123 | )
124 | content_type = response.headers.get("content-type", "unknown")
125 |                 bound_logger.exception(
126 |                     f"OpenAI JSON parse failure - Status: {response.status_code}, Content-Type: {content_type}, "
127 |                     f"Error: {str(e)}, Response Preview: {response_preview}"
128 | )
129 |                 # Re-raise with additional context
130 | raise json.JSONDecodeError(
131 | f"Failed to parse OpenAI response (Status: {response.status_code}): {str(e)}",
132 | response.text,
133 | e.pos,
134 | )
135 | return result
136 | except httpx.HTTPStatusError as e:
137 |             # Safely read the response body (non-streaming mode)
138 | response_body = ""
139 | try:
140 | response_body = e.response.text
141 | except httpx.ResponseNotRead:
142 |                 # If the response body was never read, fall back to the error string
143 | response_body = str(e)
144 |
145 |             bound_logger.error(
146 |                 f"OpenAI API returned an error - Endpoint: {endpoint}, Status: {e.response.status_code}, Response: {response_body[:200]}"
147 | )
148 |
149 | raise OpenAIClientError(
150 | get_error_response(
151 | status_code=e.response.status_code,
152 | message=response_body,
153 | details={"type": "http_error"},
154 | )
155 | )
156 |
157 | except httpx.TimeoutException as e:
158 | bound_logger.error(
159 | f"OpenAI API request timeout - Endpoint: {endpoint}, Timeout: {self.timeout}s"
160 | )
161 | raise OpenAIClientError(
162 | get_error_response(
163 | status_code=504,
164 | message=str(e),
165 | details={"type": "timeout_error", "original_error": str(e)},
166 | )
167 | )
168 |
169 | except httpx.ConnectError as e:
170 | bound_logger.error(
171 | f"OpenAI API connection error - Endpoint: {endpoint}, Error: {str(e)}"
172 | )
173 | raise OpenAIClientError(
174 | get_error_response(
175 | status_code=502,
176 | message=str(e),
177 | details={"type": "connection_error", "original_error": str(e)},
178 | )
179 | )
180 |
181 | async def send_streaming_request(
182 | self,
183 | request: OpenAIRequest,
184 | endpoint: str = "/chat/completions",
185 | request_id: str = None,
186 | ) -> AsyncGenerator[str, None]:
187 | """Send streaming request to OpenAI API.
188 |
189 | Args:
190 | request: OpenAI request object
191 | endpoint: API endpoint path
192 |             request_id: request ID for log tracing
193 |
194 | Yields:
195 |             raw Server-Sent Events data lines
196 |
197 | Raises:
198 |             OpenAIClientError: when the API returns an error
199 | """
200 |         # Get a logger bound to the request ID
201 | from src.common.logging import get_logger_with_request_id
202 |
203 | bound_logger = get_logger_with_request_id(request_id)
204 |
205 | url = f"{self.base_url}{endpoint}"
206 |
207 | # Ensure streaming is enabled
208 | request_dict = request.model_dump(exclude_none=True)
209 | request_dict["stream"] = True
210 |
211 |         # Log streaming request details
212 |         bound_logger.info(
213 |             f"Sending OpenAI streaming request - URL: {url}, Model: {request_dict.get('model', 'unknown')}, Messages: {len(request_dict.get('messages', []))}, Stream: True"
214 | )
215 |
216 | try:
217 | async with self.client.stream(
218 | "POST",
219 | url,
220 | json=request_dict,
221 | ) as response:
222 | response.raise_for_status()
223 |
224 |                 # Log the start of the streaming response
225 |                 content_type = response.headers.get("content-type", "unknown")
226 |                 bound_logger.info(
227 |                     f"Receiving OpenAI streaming response - Status: {response.status_code}, Content-Type: {content_type}"
228 | )
229 |
230 | buffer = ""
231 |
232 | async for chunk_bytes in response.aiter_bytes(chunk_size=1024):
233 | chunk_text = chunk_bytes.decode("utf-8", errors="ignore")
234 | buffer += chunk_text
235 |
236 |                     # Process complete lines
237 | while "\n" in buffer:
238 | line, buffer = buffer.split("\n", 1)
239 | line = line.strip()
240 |
241 |                         # Forward non-empty lines as-is
242 | if line:
243 | # logger.debug(f"Forwarding line: {line}")
244 | yield line
245 |                             # Check for end of stream
246 | if line == "data: [DONE]":
247 | return
248 |
249 |                 # Flush any remaining buffered data
250 | if buffer.strip():
251 | yield buffer.strip()
252 |
253 | except httpx.HTTPStatusError as e:
254 | error_body = ""
255 | try:
256 |                 # Try to read the full error-response body
257 | error_body = await e.response.aread()
258 | error_body = error_body.decode("utf-8", errors="ignore")
259 | except Exception as read_error:
260 |                 error_body = f"Could not read error response: {str(read_error)}"
261 |
262 |             # Log the full error, truncating overly long bodies
263 | error_summary = (
264 | error_body[:500] + "..." if len(error_body) > 500 else error_body
265 | )
266 | bound_logger.error(
267 |                 f"OpenAI API error - Status: {e.response.status_code}, URL: {url}"
268 | )
269 | bound_logger.error(f"Error Response: {error_summary}")
270 | raise OpenAIClientError(
271 | get_error_response(
272 | status_code=e.response.status_code,
273 | message=f"HTTP {e.response.status_code} error",
274 | details={"type": "http_error"},
275 | )
276 | )
277 |
278 | except httpx.TimeoutException as e:
279 |             bound_logger.error(f"OpenAI API timeout - Error: {str(e)}")
280 | raise OpenAIClientError(
281 | get_error_response(
282 | status_code=504,
283 | message="Request timeout",
284 | details={"type": "timeout_error"},
285 | )
286 | )
287 |
288 | except httpx.ConnectError as e:
289 |             bound_logger.error(f"OpenAI API connection error - Error: {str(e)}")
290 | raise OpenAIClientError(
291 | get_error_response(
292 | status_code=502,
293 | message="Connection error",
294 | details={"type": "connection_error"},
295 | )
296 | )
297 |
298 | async def _parse_streaming_chunk(
299 | self, chunk_data: str, tool_calls_state: dict
300 | ) -> OpenAIStreamResponse | None:
301 |         """Parse a streaming response chunk, handling incomplete JSON gracefully.
302 |
303 |         Args:
304 |             chunk_data: response chunk as a JSON string
305 |             tool_calls_state: reserved parameter (unused)
306 |
307 |         Returns:
308 |             the parsed response object, or None if the data is incomplete
309 | """
310 | import json
311 |
312 | try:
313 |             # Try to parse the JSON data
314 | raw_data = json.loads(chunk_data)
315 |
316 | result = OpenAIStreamResponse.model_validate(raw_data)
317 | return result
318 |
319 | except json.JSONDecodeError as e:
320 |             # JSON parse failed, usually because the data was split; skip silently
321 | logger.debug(
322 | f"Skipping incomplete JSON chunk - Error: {str(e)}, Data: {chunk_data[:100]}"
323 | )
324 | return None
325 | except Exception as e:
326 |             # Pydantic validation failed; tool_calls delta data may be incomplete
327 | logger.debug(
328 | f"Skipping chunk due to validation error - Error: {str(e)}, Data: {chunk_data[:100]}"
329 | )
330 | return None
331 |
332 | async def health_check(self) -> dict[str, bool]:
333 | """Check OpenAI API availability.
334 |
335 | Returns:
336 |             health-check results
337 | """
338 | try:
339 | url = f"{self.base_url}/models"
340 | response = await self.client.get(url)
341 |
342 | return {
343 | "openai_service": response.status_code == 200,
344 | "api_accessible": True,
345 | "last_check": True,
346 | }
347 |
348 | except Exception as e:
349 | logger.exception(f"OpenAI health check failed - Error: {str(e)}")
350 | return {
351 | "openai_service": False,
352 | "api_accessible": False,
353 | "last_check": True,
354 | }
355 |
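356 |
357 | # --- Illustrative usage sketch (editor's addition, not part of the source) ---
358 | # The client is an async context manager, so the connection pool is closed
359 | # automatically; `request` is assumed to be a populated OpenAIRequest:
360 | #
361 | #     async with OpenAIServiceClient(api_key="sk-...", base_url="https://api.openai.com/v1") as client:
362 | #         checks = await client.health_check()
363 | #         result = await client.send_request(request, request_id="req_abc")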
--------------------------------------------------------------------------------
/src/core/converters/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Converters module
3 |
4 | Provides data conversion between the Anthropic and OpenAI API formats.
5 | """
6 |
7 | from .request_converter import AnthropicToOpenAIConverter
8 | from .response_converter import OpenAIToAnthropicConverter
9 |
10 | __all__ = ["AnthropicToOpenAIConverter", "OpenAIToAnthropicConverter"]
11 |
--------------------------------------------------------------------------------
/src/core/converters/response_converter.py:
--------------------------------------------------------------------------------
1 | """
2 | OpenAI-to-Anthropic response converter
3 |
4 | Converts OpenAI-format responses into Anthropic format.
5 | """
6 |
7 | import json
8 | import time
9 | import traceback
10 | from collections.abc import AsyncIterator
11 | from typing import Any
12 |
13 | from src.common.token_cache import get_cached_tokens
14 |
15 | from .stream_converters import (
16 | StreamState,
17 | _log_stream_completion_details,
18 | format_event,
19 | process_finish_event,
20 | process_regular_content,
21 | process_thinking_content,
22 | process_tool_calls,
23 | safe_json_parse,
24 | )
25 |
26 |
27 | from src.models.anthropic import (
28 | AnthropicContentBlock,
29 | AnthropicContentTypes,
30 | AnthropicMessageResponse,
31 | AnthropicMessageTypes,
32 | AnthropicRoles,
33 | AnthropicStreamEventTypes,
34 | AnthropicStreamMessage,
35 | AnthropicStreamMessageStartMessage,
36 | AnthropicUsage,
37 | )
38 | from src.models.openai import (
39 | OpenAIChoice,
40 | OpenAIMessage,
41 | )
42 |
43 |
44 | class OpenAIToAnthropicConverter:
45 |     """Converter from OpenAI responses to Anthropic format."""
46 |
47 | @staticmethod
48 | async def convert_response(
49 | openai_response: dict[str, Any],
50 | original_model: str = None,
51 | request_id: str = None,
52 | ) -> AnthropicMessageResponse:
53 | """
54 |         Convert a non-streaming OpenAI response to Anthropic format.
55 |
56 |         Args:
57 |             openai_response: OpenAI response dict
58 |             original_model: the Anthropic model from the original request
59 |             request_id: request ID, used to look up cached token counts
60 |
61 |         Returns:
62 |             AnthropicMessageResponse: the converted Anthropic-format response
63 | """
64 | choices = openai_response.get("choices")
65 | if not choices:
66 |             raise ValueError("OpenAI response contains no valid choices")
67 |
68 |         # Use the first choice as the primary response
69 | first_choice_data = choices[0]
70 | message_data = first_choice_data.get("message", {})
71 | choice = OpenAIChoice(
72 | message=(
73 | OpenAIMessage(
74 | role=message_data.get("role"),
75 | content=message_data.get("content", ""),
76 | tool_calls=message_data.get("tool_calls"),
77 | )
78 | if message_data
79 | else None
80 | ),
81 | finish_reason=first_choice_data.get("finish_reason"),
82 | index=first_choice_data.get("index", 0),
83 | )
84 |
85 |         # Extract content blocks
86 | content_blocks = (
87 | OpenAIToAnthropicConverter._extract_content_blocks_with_reasoning(
88 | choice, first_choice_data
89 | )
90 | )
91 |
92 |         # Convert usage statistics
93 | usage_data = openai_response.get("usage", {})
94 | usage = OpenAIToAnthropicConverter._convert_usage(
95 | usage_data, request_id, content_blocks
96 | )
97 |
98 |         # Determine the model ID
99 | model = original_model if original_model else openai_response.get("model")
100 |
101 | mapping = {
102 | "stop": "end_turn",
103 | "length": "max_tokens",
104 | "content_filter": "content_filter",
105 | "tool_calls": "tool_use",
106 | "function_call": "tool_use",
107 | }
108 |
109 |         # Map the finish reason
110 | stop_reason = mapping.get(choice.finish_reason, "end_turn")
111 |
112 | return AnthropicMessageResponse(
113 | id=openai_response.get("id", ""),
114 | type=AnthropicMessageTypes.MESSAGE,
115 | role=AnthropicRoles.ASSISTANT,
116 | content=content_blocks,
117 | model=model,
118 | stop_reason=stop_reason,
119 | stop_sequence=None,
120 | usage=usage,
121 | )
122 |
123 | @staticmethod
124 | def _extract_content_blocks_with_reasoning(
125 | choice, first_choice_data
126 | ) -> list[AnthropicContentBlock]:
127 | """
128 |         Extract content blocks from an OpenAI choice, including reasoning content.
129 |
130 |         Args:
131 |             choice: OpenAI choice object
132 |             first_choice_data: raw data of the OpenAI choice
133 |
134 |         Returns:
135 |             List[AnthropicContentBlock]: list of content blocks
136 | """
137 | if not choice.message:
138 | return []
139 |
140 | content_blocks = []
141 | message_data = first_choice_data.get("message", {})
142 |
143 |         # Reasoning content becomes a standalone thinking-type block
144 | reasoning_content = message_data.get("reasoning_content")
145 | if (
146 | reasoning_content
147 | and isinstance(reasoning_content, str)
148 | and reasoning_content.strip()
149 | ):
150 | content_blocks.append(
151 | AnthropicContentBlock(
152 | type=AnthropicContentTypes.THINKING,
153 | thinking=reasoning_content.strip(),
154 | signature=f"{int(time.time()*1000)}",
155 | )
156 | )
157 |
158 |         # Regular content becomes a standalone text-type block
159 |         content_str = message_data.get("content", "")
160 |         if content_str and isinstance(content_str, str) and content_str.strip():
161 |             # Check whether the content contains <think> tags
162 |             if "<think>" in content_str and "</think>" in content_str:
163 |                 # Separate thinking content from regular content
164 |                 import re
165 |
166 |                 think_pattern = r"<think>(.*?)</think>"
167 | think_matches = re.findall(think_pattern, content_str, re.DOTALL)
168 |
169 |                 # Add a thinking block if thinking content was found and none exists yet
170 | if think_matches and not any(
171 | block.type == AnthropicContentTypes.THINKING
172 | for block in content_blocks
173 | ):
174 | thinking_content = think_matches[0].strip()
175 | if thinking_content:
176 | content_blocks.append(
177 | AnthropicContentBlock(
178 | type=AnthropicContentTypes.THINKING,
179 | thinking=thinking_content,
180 | signature=f"{int(time.time()*1000)}",
181 | )
182 | )
183 |
184 |                 # Strip the <think> tags, keeping the regular content
185 | clean_content = re.sub(
186 | think_pattern, "", content_str, flags=re.DOTALL
187 | ).strip()
188 | if clean_content:
189 | content_blocks.append(
190 | AnthropicContentBlock(
191 | type=AnthropicContentTypes.TEXT, text=clean_content
192 | )
193 | )
194 | else:
195 |                 # No <think> tags; treat everything as regular content
196 | content_blocks.append(
197 | AnthropicContentBlock(
198 | type=AnthropicContentTypes.TEXT, text=content_str.strip()
199 | )
200 | )
201 |
202 |         # Handle tool calls
203 | if choice.message.tool_calls:
204 | from src.models.openai import OpenAIToolCall
205 |
206 | for tool_call_data in choice.message.tool_calls:
207 | tool_call = OpenAIToolCall.model_validate(tool_call_data)
208 | if hasattr(tool_call, "function") and tool_call.function:
209 | content_blocks.append(
210 | AnthropicContentBlock(
211 | type=AnthropicContentTypes.TOOL_USE,
212 | id=tool_call.id,
213 | name=tool_call.function.name,
214 | input=(
215 | safe_json_parse(tool_call.function.arguments)
216 | if tool_call.function.arguments
217 | else {}
218 | ),
219 | )
220 | )
221 |
222 |         # If there is no content at all, return an empty text block
223 | if not content_blocks:
224 | content_blocks = [
225 | AnthropicContentBlock(type=AnthropicContentTypes.TEXT, text="")
226 | ]
227 |
228 | return content_blocks
229 |
230 | @staticmethod
231 | def _convert_usage(
232 | usage_data: dict[str, Any], request_id: str = None, content_blocks: list = None
233 | ) -> AnthropicUsage:
234 | """
235 |         Convert OpenAI usage statistics to Anthropic format, with a cache fallback.
236 |
237 |         Args:
238 |             usage_data: OpenAI usage statistics
239 |             request_id: request ID, used to look up cached token counts
240 |             content_blocks: content blocks, used to compute output tokens
241 |
242 |         Returns:
243 |             AnthropicUsage: Anthropic-format usage statistics
244 | """
245 | prompt_tokens = usage_data.get("prompt_tokens", 0) if usage_data else 0
246 | completion_tokens = usage_data.get("completion_tokens", 0) if usage_data else 0
247 |
248 |         # If OpenAI did not return prompt_tokens, use the cached value
249 | if not prompt_tokens and request_id:
250 | from src.common.token_cache import get_cached_tokens
251 |
252 | cached_tokens = get_cached_tokens(request_id)
253 | if cached_tokens:
254 | prompt_tokens = cached_tokens
255 |
256 |         # If OpenAI did not return completion_tokens, compute them ourselves
257 | if not completion_tokens and content_blocks:
258 | from src.common.token_counter import token_counter
259 |
260 |             # Use the synchronous version to keep things simple (KISS)
261 | completion_tokens = token_counter.count_response_tokens(content_blocks)
262 |
263 | return AnthropicUsage(
264 | input_tokens=prompt_tokens,
265 | output_tokens=completion_tokens,
266 | )
267 |
268 | @staticmethod
269 | async def convert_openai_stream_to_anthropic_stream(
270 | openai_stream: AsyncIterator[str],
271 | model: str = "unknown",
272 | request_id: str = None,
273 | ) -> AsyncIterator[str]:
274 |         """Convert an OpenAI streaming response into Anthropic streaming format.
275 |
276 |         Args:
277 |             openai_stream: OpenAI streaming data source
278 |             model: model name
279 |             request_id: request ID for log tracing
280 |
281 |         Yields:
282 |             str: Anthropic-format streaming event strings
283 | """
284 |         # Get a logger bound to the request ID
285 | from src.common.logging import get_logger_with_request_id
286 |
287 | bound_logger = get_logger_with_request_id(request_id)
288 |
289 | state = StreamState()
290 |
291 | try:
292 | async for chunk in openai_stream:
293 | if state.has_finished:
294 | break
295 |
296 | state.buffer += chunk
297 | lines = state.buffer.split("\n")
298 | state.buffer = lines.pop() if lines else ""
299 |
300 | for line in lines:
301 | if state.has_finished:
302 | break
303 |
304 | if not line.startswith("data: "):
305 | continue
306 |
307 | data = line[6:]
308 | if data == "[DONE]":
309 | continue
310 | try:
311 | chunk_data = json.loads(data)
312 | state.total_chunks += 1
313 |                         # Handle errors reported in the stream payload
314 | if "error" in chunk_data:
315 | error_event = {
316 | "type": "error",
317 | "message": {
318 | "type": "api_error",
319 | "message": json.dumps(chunk_data["error"]),
320 | },
321 | }
322 | yield format_event("error", error_event)
323 | continue
324 |
325 |                         # Emit the message_start event
326 |                         if not state.has_started and not state.has_finished:
327 |                             # Read the cached input-token count
328 | input_tokens = 0
329 | cached_tokens = get_cached_tokens(request_id)
330 | if cached_tokens:
331 | input_tokens = cached_tokens
332 | state.has_started = True
333 | message_start = AnthropicStreamMessage(
334 | message=AnthropicStreamMessageStartMessage(
335 | id=state.message_id,
336 | model=model,
337 | usage=AnthropicUsage(input_tokens=input_tokens),
338 | ),
339 | )
340 | yield format_event(
341 | AnthropicStreamEventTypes.MESSAGE_START,
342 |                                 message_start.model_dump(exclude={"delta", "usage"}),
343 | )
344 |
345 | choices = chunk_data.get("choices", [])
346 | if not choices:
347 | continue
348 |
349 | choice = choices[0]
350 | delta = choice.get("delta", None)
351 | if delta is None:
352 | continue
353 |
354 | content = delta.get("content", None)
355 | reasoning_content = delta.get("reasoning_content", None)
356 | tool_calls = delta.get("tool_calls", None)
357 |
358 |                         # Check whether there is any content to process
359 | has_content = content is not None and content != ""
360 | has_reasoning = (
361 | reasoning_content is not None and reasoning_content != ""
362 | )
363 | has_tool_calls = tool_calls is not None and len(tool_calls) > 0
364 |
365 | if not has_content and not has_reasoning and not has_tool_calls:
366 | if choice.get("finish_reason") is None:
367 | continue
368 |
369 |                         # Handle thinking content
370 | events = process_thinking_content(delta, state)
371 | if events:
372 | for event in events:
373 | yield event
374 |
375 |                         # Handle regular text content
376 | events = process_regular_content(delta, state)
377 | if events:
378 | for event in events:
379 | yield event
380 | continue
381 |
382 |                         # Handle tool calls
383 | if has_tool_calls:
384 | events = process_tool_calls(delta, state)
385 | if events:
386 | for event in events:
387 | yield event
388 | continue
389 |
390 |                         # Handle the finish event
391 | finish_reason = choice.get("finish_reason")
392 | if finish_reason:
393 | finish_events = process_finish_event(
394 | chunk_data, state, request_id
395 | )
396 | for event in finish_events:
397 | yield event
398 |
399 |             # Log details once all events have been emitted (KISS)
400 | _log_stream_completion_details(
401 | state,
402 | request_id,
403 | model,
404 | )
405 |
406 | except json.JSONDecodeError as parse_error:
407 | bound_logger.error(
408 | f"Parse error - Error: {str(parse_error.args[0])}, Data: {data[:100]}",
409 | exc_info=True,
410 | )
411 | except Exception as e:
412 | bound_logger.error(
413 | f"Unexpected error processing chunk - Error: {str(e)}",
414 | exc_info=True,
415 | )
416 | traceback.print_exc()
417 |
418 | except Exception as error:
419 | bound_logger.error(
420 | f"Stream conversion error - Error: {str(error)}", exc_info=True
421 | )
422 | error_event = {
423 | "type": "error",
424 | "message": {"type": "api_error", "message": str(error)},
425 | }
426 | yield format_event("error", error_event)
427 |
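428 |
429 | # --- Illustrative usage sketch (editor's addition, not part of the source) ---
430 | # Converting a minimal, assumed OpenAI payload; convert_response is async:
431 | #
432 | #     import asyncio
433 | #     openai_resp = {
434 | #         "id": "chatcmpl-1",
435 | #         "model": "gpt-4o",
436 | #         "choices": [{"index": 0, "finish_reason": "stop",
437 | #                      "message": {"role": "assistant", "content": "Hi"}}],
438 | #         "usage": {"prompt_tokens": 3, "completion_tokens": 1},
439 | #     }
440 | #     msg = asyncio.run(OpenAIToAnthropicConverter.convert_response(openai_resp))
441 | #     print(msg.stop_reason)   # "end_turn"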
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | from contextlib import asynccontextmanager
2 |
3 | from fastapi import FastAPI
4 | from fastapi.middleware.cors import CORSMiddleware
5 | from loguru import logger
6 |
7 | from src.api.handlers import router as messages_router
8 | from src.api.middleware.auth import APIKeyMiddleware
9 | from src.api.middleware.timing import setup_middlewares
10 | from src.api.routes import router as health_router
11 |
12 | # Load the config synchronously at startup (module level, runs at app start)
13 | from src.common.logging import get_request_id_from_request
14 | from src.config.settings import Config
15 |
16 | config = Config.from_file_sync()
17 |
18 |
19 | @asynccontextmanager
20 | async def lifespan(app: FastAPI):
21 |     """Application lifespan management."""
22 | from src.api.handlers import MessagesHandler
23 | from src.common.logging import configure_logging
24 | from src.config.settings import get_config_file_path, reload_config
25 | from src.config.watcher import ConfigWatcher
26 |
27 | host, port = config.get_server_config()
28 |
29 |     # Configure Loguru logging
30 | configure_logging(config.logging)
31 |
32 |     # Create the configured messages handler and cache it on the app
33 | app.state.messages_handler = await MessagesHandler.create(config)
34 |
35 |     # Config-reload callback
36 |     async def on_config_reload():
37 |         """Callback invoked when the config is reloaded."""
38 | try:
39 |             # Reload the config
40 | new_config = await reload_config()
41 |
42 |             # Reconfigure logging
43 | configure_logging(new_config.logging)
44 |
45 |             # Recreate the messages handler
46 | app.state.messages_handler = await MessagesHandler.create(new_config)
47 |
48 |             logger.info("Config hot-reload complete; service updated")
49 | except Exception as e:
50 |             logger.error(f"Config hot-reload failed: {e}")
51 |
52 |     # Start the config-file watcher
53 | config_watcher = ConfigWatcher(get_config_file_path())
54 | config_watcher.add_reload_callback(on_config_reload)
55 | await config_watcher.start_watching()
56 |
57 |     # Cache the watcher on the app state
58 | app.state.config_watcher = config_watcher
59 |
60 |     logger.info(
61 |         f"Starting OpenAI To Claude server - Host: {host}, Port: {port}, LogLevel: {config.logging.level}"
62 | )
63 |     logger.info(f"Config-file watching enabled: {get_config_file_path()}")
64 |
65 | yield
66 |
67 |     # Cleanup on shutdown
68 |     logger.info("Stopping config-file watcher...")
69 | if hasattr(app.state, "config_watcher"):
70 | app.state.config_watcher.stop_watching()
71 |     logger.info("Server stopped")
72 |
73 |
74 | app = FastAPI(
75 | title="OpenAI To Claude Server",
76 | version="0.1.0",
77 | description="A server to convert OpenAI API calls to Claude format.",
78 | lifespan=lifespan,
79 | )
80 |
81 | # Set up the CORS middleware
82 | app.add_middleware(
83 | CORSMiddleware,
84 |     allow_origins=["*"],  # allow all origins; specify concrete domains in production
85 | allow_credentials=True,
86 |     allow_methods=["*"],  # allow all HTTP methods
87 |     allow_headers=["*"],  # allow all request headers
88 | )
89 |
90 | # Set up the remaining middlewares
91 | setup_middlewares(app)
92 | app.add_middleware(APIKeyMiddleware, api_key=config.api_key)
93 |
94 | app.include_router(health_router)
95 | app.include_router(messages_router)
96 |
97 |
98 | @app.get("/")
99 | async def root():
100 | return {"message": "Welcome to the OpenAI To Claude Server"}
101 |
102 |
103 | from fastapi.responses import JSONResponse
104 | from src.models.errors import get_error_response
105 |
106 |
107 | # Global exception handler
108 | @app.exception_handler(Exception)
109 | async def global_exception_handler(request, exc):
110 |     """Catch-all handler so raw Internal Server Errors never reach the client."""
111 | from src.common.logging import (
112 | get_logger_with_request_id,
113 | get_request_id_from_request,
114 | )
115 |
116 | request_id = get_request_id_from_request(request)
117 | bound_logger = get_logger_with_request_id(request_id)
118 |
119 |     bound_logger.exception(
120 |         "Unhandled server exception caught",
121 | error_type=type(exc).__name__,
122 | error_message=str(exc),
123 | request_path=str(request.url),
124 | request_method=request.method,
125 | )
126 |
127 |     # Return a standard error-response body to the client
128 |     error_response = get_error_response(500, message="Internal server error, please retry later")
129 | return JSONResponse(status_code=500, content=error_response.model_dump())
130 |
131 |
132 | # Validation-error handler (registered for status code 422)
133 | @app.exception_handler(422)
134 | async def validation_exception_handler(request, exc):
135 |     """Handle request-validation errors."""
136 | from src.common.logging import get_logger_with_request_id
137 |
138 | request_id = get_request_id_from_request(request)
139 | bound_logger = get_logger_with_request_id(request_id)
140 |
141 |     bound_logger.warning("Request validation failed")
142 | return JSONResponse(status_code=422, content=exc.detail)
143 |
144 |
145 | # 404 handler
146 | @app.exception_handler(404)
147 | async def not_found_handler(request, exc):
148 |     """Handle 404 errors."""
149 |     error_response = get_error_response(404, message="The requested resource does not exist")
150 |
151 | return JSONResponse(status_code=404, content=error_response.model_dump())
152 |
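153 |
154 | # --- Illustrative usage sketch (editor's addition, not part of the source) ---
155 | # Running the app with uvicorn; the module path follows this file's location:
156 | #
157 | #     uvicorn src.main:app --host 0.0.0.0 --port 8000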
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Data-model module - defines all API-related data models
2 |
3 | # Anthropic models
4 | from .anthropic import (
5 | AnthropicContentBlock,
6 | AnthropicContentTypes,
7 | AnthropicErrorDetail,
8 | AnthropicErrorResponse,
9 | AnthropicMessage,
10 |     # data models
11 | AnthropicMessageContent,
12 | AnthropicMessageResponse,
13 | AnthropicMessageTypes,
14 | AnthropicPing,
15 | AnthropicRequest,
16 | AnthropicRoles,
17 |     # streaming-response related
18 | AnthropicStreamContentBlock,
19 |     # constant classes
20 | AnthropicStreamEventTypes,
21 | AnthropicSystemMessage,
22 | AnthropicTextContent,
23 | AnthropicToolDefinition,
24 | AnthropicToolUse,
25 | AnthropicUsage,
26 | ContentBlock,
27 | Delta,
28 | )
29 |
30 | # Error models
31 | from .errors import (
32 | ERROR_CODE_MAPPING,
33 | BadRequestError,
34 | ErrorDetail,
35 | ExternalServiceError,
36 | NotFoundError,
37 | RateLimitError,
38 | ServerError,
39 | ServiceUnavailableError,
40 | StandardErrorResponse,
41 | TimeoutError,
42 | UnauthorizedError,
43 | ValidationError,
44 | ValidationErrorItem,
45 | ValidationErrorResponse,
46 | get_error_response,
47 | )
48 |
49 | # OpenAI models
50 | from .openai import (
51 | OpenAIChoice,
52 | OpenAIChoiceDelta,
53 | OpenAICompletionUsage,
54 | OpenAIErrorDetail,
55 | OpenAIErrorResponse,
56 | OpenAIImageUrl,
57 | OpenAIMessage,
58 | OpenAIMessageContent,
59 | OpenAIRequest,
60 | OpenAIResponse,
61 | OpenAIStreamResponse,
62 | OpenAITool,
63 | OpenAIToolCall,
64 | OpenAIToolCallFunction,
65 | OpenAIToolFunction,
66 | OpenAIUsage,
67 | )
68 |
69 | __all__ = [
70 |     # Anthropic constant classes
71 | "AnthropicStreamEventTypes",
72 | "AnthropicContentTypes",
73 | "AnthropicMessageTypes",
74 | "AnthropicRoles",
75 |     # Anthropic data models
76 | "AnthropicMessageContent",
77 | "AnthropicMessage",
78 | "AnthropicSystemMessage",
79 | "AnthropicToolDefinition",
80 | "AnthropicRequest",
81 | "AnthropicToolUse",
82 | "AnthropicTextContent",
83 | "AnthropicContentBlock",
84 | "AnthropicUsage",
85 | "AnthropicMessageResponse",
86 | "AnthropicErrorDetail",
87 | "AnthropicErrorResponse",
88 | "AnthropicPing",
89 | "AnthropicStreamContentBlock",
90 | "Delta",
91 | "ContentBlock",
92 |     # OpenAI models
93 | "OpenAIMessageContent",
94 | "OpenAIMessage",
95 | "OpenAIImageUrl",
96 | "OpenAIToolCallFunction",
97 | "OpenAIToolCall",
98 | "OpenAIToolFunction",
99 | "OpenAITool",
100 | "OpenAIRequest",
101 | "OpenAIChoiceDelta",
102 | "OpenAIChoice",
103 | "OpenAIUsage",
104 | "OpenAICompletionUsage",
105 | "OpenAIResponse",
106 | "OpenAIStreamResponse",
107 | "OpenAIErrorDetail",
108 | "OpenAIErrorResponse",
109 |     # Error handling
110 | "ErrorDetail",
111 | "StandardErrorResponse",
112 | "ValidationErrorItem",
113 | "ValidationError",
114 | "ValidationErrorResponse",
115 | "UnauthorizedError",
116 | "RateLimitError",
117 | "ServerError",
118 | "TimeoutError",
119 | "NotFoundError",
120 | "BadRequestError",
121 | "ServiceUnavailableError",
122 | "ExternalServiceError",
123 | "ERROR_CODE_MAPPING",
124 | "get_error_response",
125 | ]
126 |
--------------------------------------------------------------------------------
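Since everything above is re-exported at package level, downstream code (converters, handlers) can import both API surfaces and the error helpers from one namespace:

from src.models import AnthropicRequest, OpenAIRequest, get_error_response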
/src/models/anthropic.py:
--------------------------------------------------------------------------------
1 | """Anthropic API 数据模型定义"""
2 |
3 | from typing import Any, Literal, Optional, Union
4 |
5 | from pydantic import BaseModel, Field
6 |
7 |
8 | class AnthropicStreamEventTypes:
9 | """Anthropic流式响应事件类型常量"""
10 |
11 | # 消息相关事件
12 | MESSAGE_START = "message_start"
13 | MESSAGE_DELTA = "message_delta"
14 | MESSAGE_STOP = "message_stop"
15 |
16 | # 内容块相关事件
17 | CONTENT_BLOCK_START = "content_block_start"
18 | CONTENT_BLOCK_DELTA = "content_block_delta"
19 | CONTENT_BLOCK_STOP = "content_block_stop"
20 |
21 | # 其他事件
22 | PING = "ping"
23 |
24 |
25 | class AnthropicContentTypes:
26 | """Anthropic内容类型常量"""
27 |
28 | # 基础内容类型
29 | TEXT = "text"
30 | IMAGE = "image"
31 | TOOL_USE = "tool_use"
32 | TOOL_RESULT = "tool_result"
33 | THINKING = "thinking"
34 |
35 | # 增量类型
36 | TEXT_DELTA = "text_delta"
37 | INPUT_JSON_DELTA = "input_json_delta"
38 | THINKING_DELTA = "thinking_delta"
39 | SIGNATURE_DELTA = "signature_delta"
40 |
41 |
42 | class AnthropicMessageTypes:
43 | """Anthropic消息类型常量"""
44 |
45 | MESSAGE = "message"
46 | ERROR = "error"
47 |
48 |
49 | class AnthropicRoles:
50 | """Anthropic角色常量"""
51 |
52 | USER = "user"
53 | ASSISTANT = "assistant"
54 |
55 |
56 | class AnthropicMessageContent(BaseModel):
57 | """Anthropic消息内容项"""
58 |
59 | type: Literal["text", "thinking", "image", "tool_use", "tool_result"] = Field(
60 | description="内容类型"
61 | )
62 | text: str | None = Field(None, description="文本内容(当type为text时)")
63 | source: dict[str, Any] | None = Field(None, description="当type为image时的源信息")
64 | id: str | None = Field(None, description="工具调用ID(当type为tool_use时)")
65 | name: str | None = Field(None, description="工具名称(当type为tool_use时)")
66 | input: dict[str, Any] | None = Field(
67 | None, description="工具输入参数(当type为tool_use时)"
68 | )
69 | tool_use_id: str | None = Field(
70 | None, description="工具使用ID(当type为tool_result时)"
71 | )
72 | content: str | list[dict[str, Any]] | None = Field(
73 | None, description="工具结果内容(当type为tool_result时)"
74 | )
75 | is_error: bool | None = Field(
76 | None, description="工具调用是否为错误结果(当type为tool_result时)"
77 | )
78 |
79 |
80 | class AnthropicMessage(BaseModel):
81 |     """Anthropic message format."""
82 | 
83 |     role: Literal["user", "assistant"] = Field(description="Message role")
84 |     content: str | list[AnthropicMessageContent] = Field(description="Message content")
85 | 
86 | 
87 | class AnthropicSystemMessage(BaseModel):
88 |     """Anthropic system message."""
89 | 
90 |     type: Literal["text"] = Field(
91 |         default=AnthropicContentTypes.TEXT, description="System message type, always text"
92 |     )
93 |     text: str = Field(description="System message text")
94 | 
95 | 
96 | class AnthropicToolDefinition(BaseModel):
97 |     """Anthropic tool definition."""
98 | 
99 |     name: str = Field(description="Tool name")
100 |     description: Optional[str] = Field(None, description="Tool description")
101 |     input_schema: Optional[dict[str, Any]] = Field(
102 |         None, description="Input parameter definition in JSON Schema format"
103 |     )
104 |     type: Optional[str] = Field(None, description="Tool type")
105 |     max_uses: Optional[int] = Field(None, description="Maximum number of tool uses")
106 | 
107 | 
108 | class AnthropicRequest(BaseModel):
109 |     """Anthropic API request model."""
110 | 
111 |     model: str = Field(description="Model ID to use, e.g. claude-3-5-sonnet-20241022")
112 |     messages: list[AnthropicMessage] = Field(description="Conversation messages")
113 |     max_tokens: int = Field(description="Maximum number of output tokens")
114 |     system: str | list[AnthropicSystemMessage] | None = Field(
115 |         None, description="System prompt"
116 |     )
117 |     tools: list[AnthropicToolDefinition] | None = Field(
118 |         None, description="Available tool definitions"
119 |     )
120 |     tool_choice: str | dict[str, Any] | None = Field(None, description="Tool choice configuration")
121 |     metadata: dict[str, Any] | None = Field(None, description="Optional metadata")
122 |     stop_sequences: list[str] | None = Field(None, description="Stop sequences")
123 |     stream: bool | None = Field(False, description="Whether to stream the response")
124 |     temperature: float | None = Field(None, ge=0.0, le=1.0, description="Sampling temperature")
125 |     top_p: float | None = Field(None, ge=0.0, le=1.0, description="Top-p sampling parameter")
126 |     top_k: int | None = Field(None, ge=1, le=1000, description="Top-k sampling parameter")
127 |     thinking: bool | dict[str, Any] | None = Field(
128 |         None, description="Whether to enable extended-thinking mode, or a configuration object"
129 |     )
130 |
131 |
132 | class AnthropicToolUse(BaseModel):
133 |     """Tool use response."""
134 | 
135 |     id: str = Field(description="Tool call ID")
136 |     type: str = Field(default=AnthropicContentTypes.TOOL_USE, description="Response type")
137 |     name: str = Field(description="Tool name")
138 |     input: dict[str, Any] = Field(description="Tool input parameters")
139 | 
140 | 
141 | class AnthropicTextContent(BaseModel):
142 |     """Text content response."""
143 | 
144 |     type: str = Field(default=AnthropicContentTypes.TEXT, description="Content type")
145 |     text: str = Field(description="Text content")
146 | 
147 | 
148 | class AnthropicContentBlock(BaseModel):
149 |     """Anthropic content block."""
150 | 
151 |     type: Literal["text", "tool_use", "thinking"]
152 |     text: str | None = Field(None, description="Text content, when type is text")
153 |     id: str | None = Field(None, description="Tool call ID, when type is tool_use")
154 |     name: str | None = Field(None, description="Tool name, when type is tool_use")
155 |     input: dict[str, Any] | None = Field(
156 |         None, description="Tool input, when type is tool_use"
157 |     )
158 |     thinking: str | None = Field(None, description="Thinking content, when type is thinking")
159 |     signature: str | None = Field(None, description="Thinking content signature, when type is thinking")
160 | 
161 | 
162 | class AnthropicUsage(BaseModel):
163 |     """Anthropic usage statistics."""
164 | 
165 |     input_tokens: int = Field(1, description="Number of input tokens")
166 |     output_tokens: int | None = Field(1, description="Number of output tokens")
167 |     cache_creation_input_tokens: int | None = Field(
168 |         0, description="Number of cache-creation input tokens"
169 |     )
170 |     cache_read_input_tokens: int | None = Field(0, description="Number of cache-read input tokens")
171 |     service_tier: str | None = Field("standard", description="Service tier")
172 | 
173 | 
174 | class MessageDelta(BaseModel):
175 |     """Message delta."""
176 | 
177 |     stop_reason: str | None = Field(None, description="Stop reason")
178 |     stop_sequence: str | None = Field(None, description="Stop sequence")
179 | 
180 | 
181 | class AnthropicStreamMessageStartMessage(BaseModel):
182 |     """Message details inside an Anthropic message_start stream event."""
183 | 
184 |     id: str = Field(description="Message ID")
185 |     type: str = Field(default=AnthropicMessageTypes.MESSAGE, description="Message type")
186 |     role: Literal["assistant"] = Field(
187 |         default=AnthropicRoles.ASSISTANT, description="Message role"
188 |     )
189 |     model: str = Field(description="Model ID in use")
190 |     content: list[Any] = Field([], description="Content blocks, usually empty")
191 |     stop_reason: str | None = Field(None, description="Stop reason")
192 |     stop_sequence: str | None = Field(None, description="Stop sequence")
193 |     usage: AnthropicUsage = Field(description="Usage statistics")
194 |
195 |
196 | class AnthropicStreamMessage(BaseModel):
197 |     """Stream message_start event."""
198 | 
199 |     type: str = Field(
200 |         default=AnthropicStreamEventTypes.MESSAGE_START, description="Event type"
201 |     )
202 |     message: AnthropicStreamMessageStartMessage | None = Field(None, description="Message details")
203 |     delta: MessageDelta | None = Field(None, description="Message delta")
204 |     usage: AnthropicUsage | None = Field(None, description="Usage statistics")
205 |
206 |
207 | class Delta(BaseModel):
208 |     """Text delta."""
209 | 
210 |     type: str = Field(default=AnthropicContentTypes.TEXT_DELTA, description="Delta type")
211 |     text: str | None = Field(None, description="Incremental text content")
212 |     thinking: str | None = Field(None, description="Thinking content")
213 |     signature: str | None = Field(None, description="Signature content")
214 |     partial_json: str | None = Field(None, description="Partial JSON string")
215 | 
216 | 
217 | class InputJsonDelta(BaseModel):
218 |     """Input JSON delta."""
219 | 
220 |     type: str = Field(
221 |         default=AnthropicContentTypes.INPUT_JSON_DELTA, description="Delta type"
222 |     )
223 | 
224 | 
225 | class AnthropicUsageDelta(BaseModel):
226 |     """Usage statistics delta."""
227 | 
228 |     output_tokens: int = Field(description="Incremental output token count")
229 | 
230 | 
231 | class AnthropicMessageResponse(BaseModel):
232 |     """Anthropic message response."""
233 | 
234 |     id: str = Field(description="Unique response ID")
235 |     type: str = Field(default=AnthropicMessageTypes.MESSAGE, description="Response type")
236 |     role: Literal["assistant"] = Field(
237 |         default=AnthropicRoles.ASSISTANT, description="Message role"
238 |     )
239 |     content: list[AnthropicContentBlock] = Field(description="Message content blocks")
240 |     model: str = Field(description="Model ID in use")
241 |     stop_reason: str | None = Field(None, description="Stop reason")
242 |     stop_sequence: str | None = Field(None, description="Stop sequence")
243 |     usage: AnthropicUsage = Field(description="Usage statistics")
244 | 
245 | 
246 | class AnthropicErrorDetail(BaseModel):
247 |     """Anthropic error details."""
248 | 
249 |     type: str = Field(description="Error type")
250 |     message: str = Field(description="Error message")
251 | 
252 | 
253 | class AnthropicErrorResponse(BaseModel):
254 |     """Anthropic error response model."""
255 | 
256 |     type: str = Field(default=AnthropicMessageTypes.ERROR, description="Response type")
257 |     error: AnthropicErrorDetail = Field(description="Error details")
258 | 
259 | 
260 | class ContentBlock(BaseModel):
261 |     """Content block."""
262 | 
263 |     type: str = Field(default=AnthropicContentTypes.TEXT, description="Content block type")
264 |     text: str | None = Field(None, description="Text content")
265 |     thinking: str | None = Field(None, description="Thinking content")
266 |     signature: str | None = Field(None, description="Signature content")
267 |     # tool_use-related fields
268 |     id: str | None = Field(None, description="Tool call ID, when type is tool_use")
269 |     name: str | None = Field(None, description="Tool name, when type is tool_use")
270 |     input: dict[str, Any] | None = Field(
271 |         None, description="Tool input, when type is tool_use"
272 |     )
273 |
274 |
275 | class AnthropicStreamContentBlockStart(BaseModel):
276 |     """Stream content_block_start event."""
277 | 
278 |     type: str = Field(
279 |         default=AnthropicStreamEventTypes.CONTENT_BLOCK_START, description="Event type"
280 |     )
281 |     index: int = Field(default=0, description="Content block index")
282 |     content_block: ContentBlock = Field(
283 |         default_factory=lambda: ContentBlock(type=AnthropicContentTypes.TEXT, text="")
284 |     )
285 | 
286 | 
287 | class AnthropicStreamContentBlock(BaseModel):
288 |     """Stream content_block_delta event."""
289 | 
290 |     type: str = Field(
291 |         default=AnthropicStreamEventTypes.CONTENT_BLOCK_DELTA, description="Event type"
292 |     )
293 |     index: int = Field(0, description="Content block index")
294 |     delta: Delta | None = Field(None, description="Delta content")
295 |     content_block: ContentBlock | None = Field(None, description="Content block")
296 |     usage: AnthropicUsageDelta | None = Field(None, description="Usage statistics")
297 | 
298 | 
299 | class AnthropicStreamContentBlockStop(BaseModel):
300 |     """Stream content_block_stop event."""
301 | 
302 |     type: str = Field(
303 |         default=AnthropicStreamEventTypes.CONTENT_BLOCK_STOP, description="Event type"
304 |     )
305 |     index: int = Field(0, description="Content block index")
306 | 
307 | 
308 | class AnthropicPing(BaseModel):
309 |     """Stream ping message."""
310 | 
311 |     type: str = Field(default=AnthropicStreamEventTypes.PING, description="Event type")
312 | 
313 | 
314 | AnthropicStreamResponse = Union[
315 |     AnthropicStreamMessage,
316 |     AnthropicStreamContentBlockStart,
317 |     AnthropicStreamContentBlock,
318 |     AnthropicStreamContentBlockStop,
319 |     AnthropicPing,
320 | ]
321 |
--------------------------------------------------------------------------------
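A minimal construction sketch for these models (illustrative values only; exclude_none mirrors how a converter would keep the upstream payload free of nulls):

from src.models.anthropic import (
    AnthropicMessage,
    AnthropicRequest,
    AnthropicSystemMessage,
)

req = AnthropicRequest(
    model="claude-3-5-sonnet-20241022",
    max_tokens=256,
    system=[AnthropicSystemMessage(text="You are terse.")],
    messages=[AnthropicMessage(role="user", content="Hello")],
    temperature=0.2,
)

payload = req.model_dump(exclude_none=True)
assert payload["system"][0]["type"] == "text"  # default filled in by the model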
/src/models/errors.py:
--------------------------------------------------------------------------------
1 | """标准化错误响应模型"""
2 |
3 | from typing import Any
4 |
5 | from pydantic import BaseModel, Field
6 |
7 |
8 | class ErrorDetail(BaseModel):
9 | """错误详细信息"""
10 |
11 | code: str = Field(description="错误代码")
12 | message: str = Field(description="错误消息")
13 | param: str | None = Field(None, description="相关参数名称")
14 | type: str | None = Field(None, description="错误类型")
15 | details: dict[str, Any] | None = Field(None, description="额外错误详情")
16 | request_id: str | None = Field(None, description="请求ID用于追踪")
17 |
18 |
19 | class StandardErrorResponse(BaseModel):
20 | """标准化错误响应模型"""
21 |
22 | type: str = Field("error", description="响应类型")
23 | error: ErrorDetail = Field(description="错误详情")
24 |
25 |
26 | class ValidationErrorItem(BaseModel):
27 |     """A single validation error item."""
28 | 
29 |     loc: list[str] = Field(description="Field path of the error location")
30 |     msg: str = Field(description="Error message")
31 |     type: str = Field(description="Error type")
32 | 
33 | 
34 | class ValidationError(BaseModel):
35 |     """Validation error details."""
36 | 
37 |     code: str = Field("validation_error", description="Error code")
38 |     message: str = Field("Request parameter validation failed", description="Error message")
39 |     details: list[ValidationErrorItem] = Field(..., description="List of validation errors")
40 | 
41 | 
42 | class ValidationErrorResponse(BaseModel):
43 |     """Validation error response model."""
44 | 
45 |     type: str = Field("validation_error", description="Response type")
46 |     error: ValidationError = Field(description="Validation error details")
47 | 
48 | 
49 | class UnauthorizedError(BaseModel):
50 |     """Unauthorized error."""
51 | 
52 |     code: str = Field("unauthorized", description="Error code")
53 |     message: str = Field("Invalid API key or unauthorized access", description="Error message")
54 |     type: str = Field("authentication_error", description="Error type")
55 | 
56 | 
57 | class RateLimitError(BaseModel):
58 |     """Rate limit error."""
59 | 
60 |     code: str = Field("rate_limit_exceeded", description="Error code")
61 |     message: str = Field("Request rate limit exceeded, please retry later", description="Error message")
62 |     type: str = Field("rate_limit_error", description="Error type")
63 |     retry_after: int | None = Field(None, description="Seconds to wait before retrying")
64 | 
65 | 
66 | class ServerError(BaseModel):
67 |     """Internal server error."""
68 | 
69 |     code: str = Field("internal_server_error", description="Error code")
70 |     message: str = Field("Internal server error, please try again later", description="Error message")
71 |     type: str = Field("server_error", description="Error type")
72 |     request_id: str | None = Field(None, description="Request ID for troubleshooting")
73 | 
74 | 
75 | class TimeoutError(BaseModel):
76 |     """Timeout error."""
77 | 
78 |     code: str = Field("timeout", description="Error code")
79 |     message: str = Field("Request timed out, please try again later", description="Error message")
80 |     type: str = Field("timeout_error", description="Error type")
81 |     timeout: int | None = Field(None, description="Timeout in seconds")
82 | 
83 | 
84 | class NotFoundError(BaseModel):
85 |     """Resource not found error."""
86 | 
87 |     code: str = Field("not_found", description="Error code")
88 |     message: str = Field("The requested resource does not exist", description="Error message")
89 |     type: str = Field("not_found_error", description="Error type")
90 | 
91 | 
92 | class BadRequestError(BaseModel):
93 |     """Bad request error."""
94 | 
95 |     code: str = Field("bad_request", description="Error code")
96 |     message: str = Field("Malformed request or invalid parameters", description="Error message")
97 |     type: str = Field("invalid_request_error", description="Error type")
98 | 
99 | 
100 | class ServiceUnavailableError(BaseModel):
101 |     """Service unavailable error."""
102 | 
103 |     code: str = Field("service_unavailable", description="Error code")
104 |     message: str = Field("Service temporarily unavailable, please try again later", description="Error message")
105 |     type: str = Field("server_error", description="Error type")
106 |     retry_after: int | None = Field(None, description="Suggested seconds before retrying")
107 | 
108 | 
109 | class ExternalServiceError(BaseModel):
110 |     """External service error."""
111 | 
112 |     code: str = Field("external_service_error", description="Error code")
113 |     message: str = Field("External service error, please try again later", description="Error message")
114 |     type: str = Field("api_error", description="Error type")
115 |     service: str | None = Field(None, description="Name of the failing external service")
116 |     original_error: dict[str, Any] | None = Field(
117 |         None, description="Original error info (may be omitted in production)"
118 |     )
119 |
120 |
121 | # Error code mapping table
122 | ERROR_CODE_MAPPING = {
123 | 400: BadRequestError,
124 | 401: UnauthorizedError,
125 | 404: NotFoundError,
126 | 422: ValidationError,
127 | 429: RateLimitError,
128 | 500: ServerError,
129 | 502: ExternalServiceError,
130 | 503: ServiceUnavailableError,
131 | 504: TimeoutError,
132 | }
133 |
134 |
135 | def format_compact_traceback(error: Exception, max_lines: int = 10) -> str:
136 |     """Format a compact traceback, keeping only the project-relevant parts."""
137 |     import traceback
138 | 
139 |     error_traceback = "".join(
140 |         traceback.format_exception(type(error), error, error.__traceback__)
141 |     )
142 |     lines = error_traceback.split("\n")
143 | 
144 |     # Keep only the project-relevant lines
145 |     filtered_lines = []
146 |     for line in lines:
147 |         if (
148 |             "openai-claude-code-proxy/src" in line
149 |             or line.strip().startswith('File "/Users')
150 |             or any(keyword in line for keyword in ["Error:", "Exception:", "    "])
151 |             or line.strip() == ""
152 |         ):
153 |             filtered_lines.append(line)
154 | 
155 |     # Keep only the last max_lines lines
156 |     return "\n".join(filtered_lines[-max_lines:]) if filtered_lines else str(error)
157 |
158 |
159 | def get_error_response(
160 |     status_code: int,
161 |     message: str | None = None,
162 |     details: dict[str, Any] | None = None,
163 | ) -> StandardErrorResponse:
164 |     """Map an HTTP status code to its error model and return a Pydantic response instance."""
165 |     error_class = ERROR_CODE_MAPPING.get(status_code, ServerError)
166 | 
167 |     # Build the error detail
168 |     if error_class == ValidationError:
169 |         # Validation errors need special handling
170 |         if details and "validation_errors" in details:
171 |             validation_items = [
172 |                 {
173 |                     "loc": error.get("loc", []),
174 |                     "msg": error.get("msg", ""),
175 |                     "type": error.get("type", "value_error"),
176 |                 }
177 |                 for error in details["validation_errors"]
178 |             ]
179 |             error_detail_data = {
180 |                 "code": "validation_error",
181 |                 "message": message or "Request parameter validation failed",
182 |                 "details": validation_items,
183 |             }
184 |         else:
185 |             error_detail_data = {
186 |                 "code": "validation_error",
187 |                 "message": message or "Request parameter validation failed",
188 |                 "details": [],
189 |             }
190 |     else:
191 |         # All other error types
192 |         error_detail_data = {
193 |             "code": error_class.model_fields["code"].default,
194 |             "message": message or error_class.model_fields["message"].default,
195 |         }
196 | 
197 |     if details:
198 |         # Add optional fields if present in details
199 |         for key in [
200 |             "param",
201 |             "type",
202 |             "retry_after",
203 |             "request_id",
204 |             "service",
205 |             "original_error",
206 |         ]:
207 |             if key in details:
208 |                 error_detail_data[key] = details[key]
209 |         # Treat whatever remains as the details field
210 |         other_details = {
211 |             k: v
212 |             for k, v in details.items()
213 |             if k
214 |             not in [
215 |                 "param",
216 |                 "type",
217 |                 "retry_after",
218 |                 "request_id",
219 |                 "service",
220 |                 "original_error",
221 |             ]
222 |         }
223 |         if other_details:
224 |             error_detail_data["details"] = other_details
225 | 
226 |     # Build and return the StandardErrorResponse instance
227 |     return StandardErrorResponse(error=ErrorDetail(**error_detail_data))
228 |
--------------------------------------------------------------------------------
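get_error_response is the single funnel the handlers in src/main.py use; a short sketch of its behavior (the expected strings match the RateLimitError defaults above):

from src.models.errors import get_error_response

resp = get_error_response(429, details={"request_id": "req_123"})
body = resp.model_dump(exclude_none=True)

assert body["type"] == "error"
assert body["error"]["code"] == "rate_limit_exceeded"
assert body["error"]["request_id"] == "req_123"  # lifted into the typed field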
/src/models/openai.py:
--------------------------------------------------------------------------------
1 | """OpenAI API 数据模型定义"""
2 |
3 | from typing import Any, Literal
4 |
5 | from pydantic import BaseModel, Field
6 |
7 |
8 | class OpenAIMessageContent(BaseModel):
9 | """OpenAI消息内容项"""
10 |
11 | type: Literal["text", "image_url"] = Field(description="内容类型")
12 | text: str | None = Field(None, description="文本内容")
13 | image_url: dict[str, str] | None = Field(None, description="图像URL配置")
14 |
15 |
16 | class OpenAIMessage(BaseModel):
17 | """OpenAI消息格式"""
18 |
19 | role: Literal["system", "user", "assistant", "tool"] = Field(description="消息角色")
20 | content: str | list[OpenAIMessageContent] | None = Field(description="消息内容")
21 | name: str | None = Field(None, description="消息作者名称")
22 | tool_calls: list[dict[str, Any]] | None = Field(
23 | None, description="工具调用信息(当role为assistant时)"
24 | )
25 | tool_call_id: str | None = Field(
26 | None, description="工具调用ID(当role为tool时)"
27 | )
28 | refusal: str | None = Field(None, description="拒绝服务的详细信息")
29 | reasoning_content: str | None = Field(None, description="推理内容")
30 | annotations: list[dict[str, Any]] | None = Field(
31 | None, description="模型生成的标注"
32 | )
33 |
34 |
35 | class OpenAIImageUrl(BaseModel):
36 | """OpenAI图像URL"""
37 |
38 | url: str = Field(description="图像URL")
39 | detail: Literal["auto", "low", "high"] | None = Field(
40 | "auto", description="图像细节级别"
41 | )
42 |
43 |
44 | class OpenAIToolCallFunction(BaseModel):
45 | """工具调用函数"""
46 |
47 | name: str | None = Field(None, description="函数名称") # 改为可选,因为流式响应中可能缺失
48 | arguments: str | None = Field(None, description="JSON格式的函数参数") # 改为可选,支持增量传输
49 |
50 |
51 | class OpenAIToolCall(BaseModel):
52 | """工具调用"""
53 |
54 | id: str = Field(description="工具调用ID")
55 | type: Literal["function"] = Field("function", description="调用类型")
56 | function: OpenAIToolCallFunction = Field(description="函数详情")
57 |
58 |
59 | class OpenAIDeltaToolCall(BaseModel):
60 | """工具调用增量"""
61 |
62 | index: int | None = Field(None, description="工具调用索引") # 改为可选,因为流式响应中可能缺失
63 | id: str | None = Field(None, description="工具调用ID")
64 | type: Literal["function"] | None = Field(None, description="调用类型")
65 | function: OpenAIToolCallFunction | None = Field(None, description="函数详情增量")
66 |
67 |
68 | class OpenAIToolFunction(BaseModel):
69 |     """OpenAI tool function definition."""
70 | 
71 |     name: str = Field(description="Function name")
72 |     description: str | None = Field(None, description="Function description")
73 |     parameters: dict[str, Any] | None = Field(
74 |         None, description="Function parameters in JSON Schema format"
75 |     )
76 | 
77 | 
78 | class OpenAITool(BaseModel):
79 |     """OpenAI tool definition."""
80 | 
81 |     type: Literal["function"] = Field("function", description="Tool type")
82 |     function: OpenAIToolFunction = Field(description="Function definition")
83 | 
84 | 
85 | class OpenAIRequest(BaseModel):
86 |     """OpenAI API request model."""
87 | 
88 |     model: str = Field(description="Model ID to use, e.g. gpt-4o or claude-3-5-sonnet")
89 |     messages: list[OpenAIMessage] = Field(description="Conversation messages")
90 |     max_tokens: int | None = Field(None, description="Maximum number of output tokens")
91 |     max_completion_tokens: int | None = Field(
92 |         None, description="Maximum number of completion tokens (newer format)"
93 |     )
94 |     temperature: float | None = Field(None, ge=0.0, le=2.0, description="Sampling temperature")
95 |     top_p: float | None = Field(None, ge=0.0, le=1.0, description="Top-p sampling parameter")
96 |     top_k: int | None = Field(None, ge=0, description="Top-k sampling parameter")
97 |     stream: bool | None = Field(False, description="Whether to stream the response")
98 |     stream_options: dict[str, Any] | None = Field(None, description="Streaming options")
99 |     stop: str | list[str] | None = Field(None, description="Stop sequences")
100 |     frequency_penalty: float | None = Field(None, description="Frequency penalty")
101 |     presence_penalty: float | None = Field(None, description="Presence penalty")
102 |     logprobs: bool | None = Field(False, description="Whether to return log probabilities")
103 |     top_logprobs: int | None = Field(
104 |         None, description="Number of top log probabilities to return"
105 |     )
106 |     logit_bias: dict[str, int] | None = Field(None, description="Logit bias")
107 |     n: int | None = Field(None, ge=1, le=128, description="Number of completions to generate")
108 |     seed: int | None = Field(None, description="Random seed")
109 |     response_format: dict[str, Any] | None = Field(None, description="Response format configuration")
110 |     tools: list[OpenAITool] | None = Field(None, description="Available tool definitions")
111 |     tool_choice: str | dict[str, Any] | None = Field(
112 |         None, description="Tool choice configuration"
113 |     )
114 |     parallel_tool_calls: bool | None = Field(
115 |         None, description="Whether parallel tool calls are allowed"
116 |     )
117 |     user: str | None = Field(None, description="End-user identifier")
118 |     think: bool | None = Field(None, description="Whether to enable reasoning mode")
119 |
120 |
121 | class OpenAIChoiceDelta(BaseModel):
122 |     """Streaming response delta content."""
123 | 
124 |     role: str | None = Field(None, description="Message role")
125 |     content: str | None = Field(None, description="Incremental content")
126 |     reasoning_content: str | None = Field(None, description="Incremental reasoning content")
127 |     tool_calls: list[OpenAIDeltaToolCall] | None = Field(
128 |         None, description="Tool call deltas"
129 |     )
130 | 
131 | 
132 | class OpenAIChoice(BaseModel):
133 |     """OpenAI response choice."""
134 | 
135 |     index: int = Field(description="Choice index")
136 |     message: OpenAIMessage | None = Field(None, description="Complete message response")
137 |     delta: OpenAIChoiceDelta | None = Field(None, description="Streaming delta content")
138 |     finish_reason: str | None = Field(
139 |         None,
140 |         description="Finish reason: stop, length, content_filter, tool_calls, function_call",
141 |     )
142 |     logprobs: dict[str, Any] | None = Field(None, description="Log probability info")
143 | 
144 | 
145 | class OpenAIUsage(BaseModel):
146 |     """OpenAI usage statistics."""
147 | 
148 |     prompt_tokens: int = Field(description="Number of prompt tokens")
149 |     completion_tokens: int = Field(description="Number of completion tokens")
150 |     total_tokens: int = Field(description="Total number of tokens")
151 |     completion_tokens_details: dict[str, Any] | None = Field(
152 |         None, description="Detailed completion token info"
153 |     )
154 |     prompt_tokens_details: dict[str, Any] | None = Field(
155 |         None, description="Detailed prompt token info"
156 |     )
157 | 
158 | 
159 | class OpenAICompletionUsage(BaseModel):
160 |     """OpenAI completion usage statistics."""
161 | 
162 |     completion_tokens: int = Field(description="Number of completion tokens")
163 |     prompt_tokens: int = Field(description="Number of prompt tokens")
164 |     total_tokens: int = Field(description="Total number of tokens")
165 | 
166 | 
167 | class OpenAIResponse(BaseModel):
168 |     """OpenAI API response model."""
169 | 
170 |     id: str = Field(description="Unique response ID")
171 |     object: Literal["chat.completion"] = Field(description="Object type")
172 |     created: int = Field(description="Creation timestamp")
173 |     model: str = Field(description="Model ID in use")
174 |     choices: list[OpenAIChoice] = Field(description="Response choices")
175 |     usage: OpenAIUsage = Field(description="Usage statistics")
176 |     system_fingerprint: str | None = Field(None, description="System fingerprint")
177 | 
178 | 
179 | class OpenAIStreamResponse(BaseModel):
180 |     """OpenAI streaming response model."""
181 | 
182 |     id: str = Field(description="Unique response ID")
183 |     object: Literal["chat.completion.chunk"] = Field(description="Object type")
184 |     created: int = Field(description="Creation timestamp")
185 |     model: str = Field(description="Model ID in use")
186 |     choices: list[OpenAIChoice] = Field(description="Response choices")
187 |     usage: OpenAIUsage | None = Field(
188 |         None, description="Usage statistics (only present on the final stream chunk)"
189 |     )
190 |     system_fingerprint: str | None = Field(None, description="System fingerprint")
191 | 
192 | 
193 | class OpenAIErrorDetail(BaseModel):
194 |     """OpenAI error details."""
195 | 
196 |     message: str = Field(description="Error message")
197 |     type: str = Field(description="Error type")
198 |     param: str | None = Field(None, description="Related parameter")
199 |     code: str | None = Field(None, description="Error code")
200 | 
201 | 
202 | class OpenAIErrorResponse(BaseModel):
203 |     """OpenAI error response model."""
204 | 
205 |     error: OpenAIErrorDetail = Field(description="Error details")
206 |
--------------------------------------------------------------------------------
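A construction sketch for the request model with a tool attached; get_weather and its schema are made-up values for illustration:

from src.models.openai import (
    OpenAIMessage,
    OpenAIRequest,
    OpenAITool,
    OpenAIToolFunction,
)

req = OpenAIRequest(
    model="gpt-4o",
    max_tokens=128,
    messages=[OpenAIMessage(role="user", content="Weather in Paris?")],
    tools=[
        OpenAITool(
            function=OpenAIToolFunction(
                name="get_weather",
                description="Look up current weather for a city",
                parameters={
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            )
        )
    ],
)

# The "function" literal is supplied by the OpenAITool default
assert req.model_dump(exclude_none=True)["tools"][0]["type"] == "function"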
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | 测试模块
3 |
4 | 包含项目的单元测试和集成测试。
5 |
6 | 测试结构:
7 | - unit/: 单元测试
8 | - integration/: 集成测试
9 | - fixtures/: 测试夹具和工具
10 |
11 | 测试覆盖:
12 | - API端点测试
13 | - 数据转换测试
14 | - 错误处理测试
15 | - 流式响应测试
16 | - 配置管理测试
17 | """
18 |
19 | # 导入测试夹具
20 | from .fixtures import *
21 |
22 | __all__ = [
23 | # 测试夹具将通过 fixtures 模块的 __all__ 自动导出
24 | ]
--------------------------------------------------------------------------------
/tests/fixtures.py:
--------------------------------------------------------------------------------
1 | """Mock OpenAI server for end-to-end testing."""
2 |
3 | from typing import Dict, Any, List, Optional, AsyncGenerator
4 | from fastapi import FastAPI, HTTPException, Header
5 | from pydantic import BaseModel, Field
6 | import time
7 | import asyncio
8 | from contextlib import asynccontextmanager
9 |
10 |
11 | class MockMessage(BaseModel):
12 | role: str
13 | content: str
14 |
15 |
16 | class MockChoice(BaseModel):
17 | index: int = 0
18 | message: MockMessage
19 | finish_reason: str = "stop"
20 |
21 |
22 | class MockDelta(BaseModel):
23 | content: Optional[str] = None
24 | role: Optional[str] = None
25 |
26 |
27 | class MockStreamChoice(BaseModel):
28 | index: int = 0
29 | delta: MockDelta
30 | finish_reason: Optional[str] = None
31 |
32 |
33 | class MockUsage(BaseModel):
34 | prompt_tokens: int
35 | completion_tokens: int
36 | total_tokens: int
37 |
38 |
39 | class MockCompletionResponse(BaseModel):
40 | id: str = Field(default="mock-completion-id")
41 | object: str = "chat.completion"
42 | created: int = Field(default_factory=lambda: int(time.time()))
43 | model: str
44 | choices: List[MockChoice]
45 | usage: MockUsage
46 |
47 |
48 | class MockStreamResponse(BaseModel):
49 | id: str = Field(default="mock-completion-id")
50 | object: str = "chat.completion.chunk"
51 | created: int = Field(default_factory=lambda: int(time.time()))
52 | model: str
53 | choices: List[MockStreamChoice]
54 |
55 |
56 | class MockChatCompletionRequest(BaseModel):
57 | model: str
58 | messages: List[MockMessage]
59 | max_tokens: Optional[int] = None
60 | temperature: Optional[float] = None
61 | stream: bool = False
62 |
63 |
64 | mock_responses: Dict[str, Dict[str, Any]] = {}
65 | error_trigger: str = ""
66 | delay_ms: int = 0
67 |
68 |
69 | @asynccontextmanager
70 | async def lifespan(app: FastAPI):
71 | """Manage mock server lifecycle."""
72 | print("Mock OpenAI server starting...")
73 | yield
74 | print("Mock OpenAI server shutting down...")
75 |
76 |
77 | app = FastAPI(lifespan=lifespan)
78 |
79 |
80 | @app.middleware("http")
81 | async def add_delay(request, call_next):
82 | """Add configurable delay to all responses."""
83 | if delay_ms > 0:
84 | await asyncio.sleep(delay_ms / 1000)
85 | response = await call_next(request)
86 | return response
87 |
88 |
89 | @app.post("/v1/chat/completions")
90 | async def mock_chat_completions(
91 | request: MockChatCompletionRequest,
92 | authorization: str = Header(...),
93 | ):
94 | """Mock the OpenAI chat completions endpoint."""
95 |
96 | if not authorization.startswith("Bearer mock-key"):
97 | raise HTTPException(status_code=401, detail="Invalid API key")
98 |
99 | if error_trigger == "invalid_model":
100 | raise HTTPException(status_code=400, detail="Invalid model")
101 |
102 | if error_trigger == "rate_limit":
103 | raise HTTPException(status_code=429, detail="Rate limit exceeded")
104 |
105 | if error_trigger == "server_error":
106 | raise HTTPException(status_code=500, detail="Internal server error")
107 |
108 | model = request.model
109 | if model == "invalid-model":
110 | raise HTTPException(status_code=404, detail="Model not found")
111 |
112 | last_message = request.messages[-1]
113 |
114 |     if request.stream:
115 |         from fastapi.responses import StreamingResponse
116 |         return StreamingResponse(mock_stream_response(model, last_message.content or ""))
117 |     return mock_non_stream_response(model, last_message.content or "")
118 |
119 |
120 | def mock_non_stream_response(model: str, content: str) -> MockCompletionResponse:
121 | """Generate a non-streaming chat completion response."""
122 | prompt_tokens = len(content.split())
123 | completion_tokens = 10 # Fixed for testing
124 |
125 | return MockCompletionResponse(
126 | model=model,
127 | choices=[
128 | MockChoice(
129 | message=MockMessage(
130 | role="assistant",
131 | content=f"Mock response for: {content}"
132 | ),
133 | finish_reason="stop"
134 | )
135 | ],
136 | usage=MockUsage(
137 | prompt_tokens=prompt_tokens,
138 | completion_tokens=completion_tokens,
139 | total_tokens=prompt_tokens + completion_tokens
140 | )
141 | )
142 |
143 |
144 | async def mock_stream_response(model: str, content: str) -> AsyncGenerator[str, None]:
145 | """Generate a streaming chat completion response."""
146 | words = ["This", "is", "a", "mock", "response", "for:", content[:10], "..."]
147 |
148 | # First chunk with role
149 | yield MockStreamResponse(
150 | model=model,
151 | choices=[
152 | MockStreamChoice(
153 | delta=MockDelta(role="assistant", content=""),
154 | finish_reason=None
155 | )
156 | ]
157 | ).model_dump_json() + "\n"
158 |
159 | # Stream content chunks
160 | for word in words:
161 | chunk = MockStreamResponse(
162 | model=model,
163 | choices=[
164 | MockStreamChoice(
165 | delta=MockDelta(content=f"{word}"),
166 | finish_reason=None
167 | )
168 | ]
169 | )
170 | yield chunk.model_dump_json() + "\n"
171 | await asyncio.sleep(0.05) # Simulate streaming delay
172 |
173 | # Final chunk
174 | chunk = MockStreamResponse(
175 | model=model,
176 | choices=[
177 | MockStreamChoice(
178 | delta=MockDelta(content=""),
179 | finish_reason="stop"
180 | )
181 | ]
182 | )
183 | yield chunk.model_dump_json() + "\n"
184 |
185 |
186 | @app.get("/health")
187 | async def health_check():
188 | """Health check endpoint."""
189 | return {"status": "healthy"}
190 |
191 |
192 | @app.get("/mock/config")
193 | async def get_config():
194 | """Get current mock configuration."""
195 | return {
196 | "error_trigger": error_trigger,
197 | "delay_ms": delay_ms,
198 | "responses": list(mock_responses.keys())
199 | }
200 |
201 |
202 | @app.post("/mock/configure")
203 | async def configure_mock(config: Dict[str, Any]):
204 | """Configure mock server behavior."""
205 | global error_trigger, delay_ms, mock_responses
206 |
207 | if "error_trigger" in config:
208 | error_trigger = config["error_trigger"]
209 | if "delay_ms" in config:
210 | delay_ms = config["delay_ms"]
211 | if "responses" in config:
212 | mock_responses.update(config["responses"])
213 |
214 | return {"message": "Mock server configured"}
215 |
216 |
217 | @app.delete("/mock/reset")
218 | async def reset_mock():
219 | """Reset mock server to default state."""
220 | global error_trigger, delay_ms, mock_responses
221 | error_trigger = ""
222 | delay_ms = 0
223 | mock_responses.clear()
224 | return {"message": "Mock server reset"}
--------------------------------------------------------------------------------
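Because the mock's failure modes live in module-level globals behind HTTP control endpoints, a test session can steer it without restarting. A sketch, assuming the server was started separately on port 8001 (the port the integration tests expect), e.g. with `uvicorn tests.fixtures:app --port 8001`:

import requests

# Make completion calls fail with 429 and add a 50 ms delay to every response...
requests.post(
    "http://localhost:8001/mock/configure",
    json={"error_trigger": "rate_limit", "delay_ms": 50},
)
print(requests.get("http://localhost:8001/mock/config").json())

# ...then restore the defaults so later tests start clean.
requests.delete("http://localhost:8001/mock/reset")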
/tests/integration/test_end_to_end.py:
--------------------------------------------------------------------------------
1 | """End-to-end integration tests for the OpenAI-to-Claude proxy."""
2 |
3 | import asyncio
4 | import json
5 | import os
6 | import requests
7 | import pytest
8 | from typing import Dict, Any, List
9 | import httpx
10 | from fastapi.testclient import TestClient
11 | from unittest.mock import patch, AsyncMock
12 |
13 | from src.main import app
14 | from tests.fixtures import app as mock_app
15 |
16 |
17 | class TestEndToEndIntegration:
18 | """End-to-end integration tests covering the full proxy flow."""
19 |
20 | @pytest.fixture(autouse=True)
21 | def setup_client(self):
22 | """Set up test environment and create a test client with auth headers."""
23 | api_key = "mock-key"
24 | os.environ["API_KEY"] = api_key
25 | os.environ["OPENAI_API_KEY"] = "mock-openai-key"
26 | os.environ["OPENAI_BASE_URL"] = "http://localhost:8001"
27 |
28 | # Reset mock server configuration
29 | try:
30 | requests.delete("http://localhost:8001/mock/reset")
31 | except requests.ConnectionError:
32 | # Mock server might not be running, which is fine for some tests
33 | pass
34 |
35 | self.client = TestClient(app)
36 | self.client.headers = {"Authorization": f"Bearer {api_key}"}
37 |
38 | yield
39 |
40 | # Clean up environment
41 | if "API_KEY" in os.environ:
42 | del os.environ["API_KEY"]
43 | if "OPENAI_API_KEY" in os.environ:
44 | del os.environ["OPENAI_API_KEY"]
45 | if "OPENAI_BASE_URL" in os.environ:
46 | del os.environ["OPENAI_BASE_URL"]
47 |
48 | def test_basic_chat_completion(self):
49 | """Test complete flow from Anthropic request to OpenAI proxy."""
50 | client = self.client
51 |
52 | payload = {
53 | "model": "claude-3-sonnet-20240229",
54 | "max_tokens": 1024,
55 | "messages": [
56 | {"role": "user", "content": "Hello, world!"}
57 | ]
58 | }
59 |
60 | response = client.post("/v1/messages", json=payload)
61 |
62 | assert response.status_code == 200
63 | data = response.json()
64 |
65 | # Verify Anthropic response structure
66 | assert "id" in data
67 | assert "type" in data
68 | assert "role" in data
69 | assert "content" in data
70 | assert "model" in data
71 | assert "stop_reason" in data
72 | assert "usage" in data
73 |
74 | # Verify content appeared in mock response
75 | assert "Hello, world!" in data["content"][0]["text"]
76 |
77 | def test_chat_completion_with_system_message(self):
78 | """Test chat completion with system message included."""
79 | client = self.client
80 |
81 |         payload = {
82 |             "model": "claude-3-sonnet-20240229",
83 |             "max_tokens": 1024,
84 |             "system": "You are a helpful assistant.",
85 |             "messages": [
86 |                 {"role": "user", "content": "Tell me a joke"}
87 |             ]
88 |         }
89 |
90 | response = client.post("/v1/messages", json=payload)
91 |
92 | assert response.status_code == 200
93 | data = response.json()
94 |
95 | assert data["role"] == "assistant"
96 | assert "Tell me a joke" in data["content"][0]["text"]
97 |
98 | def test_chat_completion_with_temperature(self):
99 | """Test parameter conversion including temperature."""
100 | client = self.client
101 |
102 | payload = {
103 | "model": "claude-3-sonnet-20240229",
104 | "max_tokens": 500,
105 | "temperature": 0.7,
106 | "messages": [
107 | {"role": "user", "content": "Give me a creative response"}
108 | ]
109 | }
110 |
111 | response = client.post("/v1/messages", json=payload)
112 |
113 | assert response.status_code == 200
114 | data = response.json()
115 |
116 |         assert data["type"] == "message"
117 | assert "creative response" in data["content"][0]["text"]
118 |
119 | def test_invalid_model_error(self):
120 | """Test error handling for invalid model."""
121 | client = self.client
122 |
123 | # Configure mock to return error
124 | requests.post(
125 | "http://localhost:8001/mock/configure",
126 | json={"error_trigger": "invalid_model"}
127 | )
128 |
129 | payload = {
130 | "model": "invalid-model-name",
131 | "max_tokens": 1024,
132 | "messages": [{"role": "user", "content": "test"}]
133 | }
134 |
135 | response = client.post("/v1/messages", json=payload)
136 |
137 | assert response.status_code == 400
138 | data = response.json()
139 |
140 | assert "type" in data
141 | assert "error" in data
142 | assert data["error"]["type"] == "invalid_request_error"
143 |
144 | def test_rate_limit_error(self):
145 | """Test rate limit error response."""
146 | client = self.client
147 |
148 | # Configure mock to return rate limit error
149 | requests.post(
150 | "http://localhost:8001/mock/configure",
151 | json={"error_trigger": "rate_limit"}
152 | )
153 |
154 | payload = {
155 | "model": "claude-3-sonnet-20240229",
156 | "max_tokens": 1024,
157 | "messages": [{"role": "user", "content": "test"}]
158 | }
159 |
160 | response = client.post("/v1/messages", json=payload)
161 |
162 | assert response.status_code == 429
163 | data = response.json()
164 |
165 | assert data["error"]["type"] == "rate_limit_error"
166 |
167 | def test_server_error(self):
168 | """Test server error response."""
169 | client = self.client
170 |
171 | # Configure mock to return server error
172 | requests.post(
173 | "http://localhost:8001/mock/configure",
174 | json={"error_trigger": "server_error"}
175 | )
176 |
177 | payload = {
178 | "model": "claude-3-sonnet-20240229",
179 | "max_tokens": 1024,
180 | "messages": [{"role": "user", "content": "test"}]
181 | }
182 |
183 | response = client.post("/v1/messages", json=payload)
184 |
185 | assert response.status_code == 500
186 | data = response.json()
187 |
188 | assert data["error"]["type"].startswith("api")
189 |
190 | def test_invalid_api_key_error(self):
191 | """Test unauthorized error with invalid API key."""
192 | os.environ["OPENAI_API_KEY"] = "invalid-key"
193 |
194 | client = self.client
195 |
196 | payload = {
197 | "model": "claude-3-sonnet-20240229",
198 | "max_tokens": 1024,
199 | "messages": [{"role": "user", "content": "test"}]
200 | }
201 |
202 | response = client.post("/v1/messages", json=payload)
203 |
204 | assert response.status_code == 401
205 | data = response.json()
206 |
207 | assert data["error"]["type"] == "authentication_error"
208 |
209 | def test_request_validation_error(self):
210 | """Test request validation error handling."""
211 | client = self.client
212 |
213 | # Invalid payload structure
214 | payload = {
215 | "model": "claude-3-sonnet-20240229",
216 | # Missing required 'messages'
217 | "max_tokens": 1024
218 | }
219 |
220 | response = client.post("/v1/messages", json=payload)
221 |
222 | assert response.status_code == 422
223 | data = response.json()
224 |
225 | assert "messages" in data["error"]["type"] or "messages" in str(data)
226 |
227 | def test_empty_message_content(self):
228 | """Test handling of empty message content."""
229 | client = self.client
230 |
231 | payload = {
232 | "model": "claude-3-sonnet-20240229",
233 | "max_tokens": 1024,
234 | "messages": [{"role": "user", "content": ""}]
235 | }
236 |
237 | response = client.post("/v1/messages", json=payload)
238 |
239 | assert response.status_code == 200
240 | data = response.json()
241 |
242 | assert "content" in data
243 | assert len(data["content"]) > 0
244 |
245 | def test_concurrent_requests(self):
246 | """Test handling multiple concurrent requests."""
247 | client = self.client
248 |
249 | def make_request(content: str):
250 | payload = {
251 | "model": "claude-3-sonnet-20240229",
252 | "max_tokens": 100,
253 | "messages": [{"role": "user", "content": content}]
254 | }
255 | return client.post("/v1/messages", json=payload)
256 |
257 | # Make 5 concurrent requests
258 | import concurrent.futures
259 |
260 | with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
261 | futures = [
262 | executor.submit(make_request, f"Request {i}")
263 | for i in range(5)
264 | ]
265 |
266 | results = [future.result() for future in concurrent.futures.as_completed(futures)]
267 |
268 | # All should succeed
269 | for response in results:
270 | assert response.status_code == 200
271 | data = response.json()
272 | assert "content" in data
273 |
274 | def test_response_timing(self):
275 | """Test response timing is under expected threshold."""
276 | import time
277 |
278 | client = self.client
279 |
280 | # Configure mock for 100ms delay
281 | requests.post(
282 | "http://localhost:8001/mock/configure",
283 | json={"delay_ms": 100}
284 | )
285 |
286 | payload = {
287 | "model": "claude-3-sonnet-20240229",
288 | "max_tokens": 1024,
289 | "messages": [{"role": "user", "content": "Why is the sky blue?"}]
290 | }
291 |
292 | start_time = time.time()
293 | response = client.post("/v1/messages", json=payload)
294 | end_time = time.time()
295 |
296 | assert response.status_code == 200
297 | duration = (end_time - start_time) * 1000 # Convert to milliseconds
298 |
299 |         # Proxy overhead should be modest on top of the 100ms mock delay
300 |         assert duration < 200  # Conservative threshold including mock overhead
301 |
302 | @pytest.mark.skip
303 | def test_health_check_endpoint(self):
304 | """Test health check endpoint connectivity."""
305 | client = self.client
306 |
307 | response = client.get("/health")
308 |
309 | assert response.status_code == 200
310 | data = response.json()
311 |         assert data["status"] in ["healthy", "degraded", "unhealthy"]
312 |
313 |
314 | if __name__ == "__main__":
315 | import uvicorn
316 | uvicorn.run(mock_app, host="0.0.0.0", port=8001)
--------------------------------------------------------------------------------
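These tests assume the mock OpenAI server is already listening on localhost:8001; the module's __main__ block above doubles as a launcher. Running it programmatically amounts to:

import uvicorn

from tests.fixtures import app as mock_app

# Blocks the current process; run in a separate terminal from pytest.
uvicorn.run(mock_app, host="0.0.0.0", port=8001)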
/tests/integration/test_error_handling.py:
--------------------------------------------------------------------------------
1 | """Comprehensive error handling tests for the proxy."""
2 |
3 | import pytest
4 | import asyncio
5 | import json
6 | import os
7 | import requests
8 | from fastapi.testclient import TestClient
9 | from unittest.mock import patch, MagicMock, AsyncMock
10 |
11 | from src.main import app
12 | from src.core.clients.openai_client import OpenAIServiceClient
13 |
14 |
15 | class TestErrorHandlingIntegration:
16 | """Integration tests for various error paths and edge cases."""
17 |
18 | @pytest.fixture(autouse=True)
19 | def setup(self):
20 | """Set up test environment."""
21 | os.environ["OPENAI_API_KEY"] = "mock-key"
22 | os.environ["OPENAI_BASE_URL"] = "http://localhost:8001"
23 | os.environ["REQUEST_TIMEOUT"] = "2" # Short timeout for testing
24 |
25 | # Reset mock server
26 | requests.delete("http://localhost:8001/mock/reset")
27 |
28 | yield
29 |
30 | # Clean up environment
31 | if "OPENAI_API_KEY" in os.environ:
32 | del os.environ["OPENAI_API_KEY"]
33 | if "OPENAI_BASE_URL" in os.environ:
34 | del os.environ["OPENAI_BASE_URL"]
35 | if "REQUEST_TIMEOUT" in os.environ:
36 | del os.environ["REQUEST_TIMEOUT"]
37 |
38 | def test_network_timeout(self):
39 | """Test network timeout handling."""
40 | # Configure mock server with long delay
41 | requests.post(
42 | "http://localhost:8001/mock/configure",
43 | json={"delay_ms": 3000} # 3s delay, longer than 2s timeout
44 | )
45 |
46 | client = TestClient(app)
47 |
48 | payload = {
49 | "model": "claude-3-sonnet-20240229",
50 | "max_tokens": 100,
51 | "messages": [{"role": "user", "content": "Timeout test"}]
52 | }
53 |
54 | response = client.post("/v1/messages", json=payload)
55 |
56 | assert response.status_code == 504 # Gateway timeout
57 | data = response.json()
58 |
59 | assert data["error"]["type"] == "timeout_error"
60 | assert "timeout" in data["error"]["message"].lower()
61 |
62 | def test_connection_refused(self):
63 | """Test connection refused error handling."""
64 | # Configure invalid base URL
65 | os.environ["OPENAI_BASE_URL"] = "http://localhost:9999" # Invalid port
66 |
67 | client = TestClient(app)
68 |
69 | payload = {
70 | "model": "claude-3-sonnet-20240229",
71 | "max_tokens": 100,
72 | "messages": [{"role": "user", "content": "Connection test"}]
73 | }
74 |
75 | response = client.post("/v1/messages", json=payload)
76 |
77 | assert response.status_code == 503 # Service unavailable
78 | data = response.json()
79 |
80 | assert "connection" in data["error"]["type"] or "service_unavailable" in data["error"]["type"]
81 |
82 | def test_malformed_openai_response(self):
83 | """Test handling of malformed JSON from OpenAI."""
84 | from unittest.mock import patch
85 | import httpx
86 |
87 | client = TestClient(app)
88 |
89 | # Mock httpx to return invalid JSON
90 | with patch('httpx.AsyncClient.post') as mock_post:
91 | mock_response = MagicMock()
92 | mock_response.status_code = 200
93 | mock_response.json.side_effect = ValueError("Invalid JSON")
94 | mock_response.text = "invalid json response"
95 | mock_post.return_value.__aenter__.return_value = mock_response
96 |
97 | payload = {
98 | "model": "claude-3-sonnet-20240229",
99 | "max_tokens": 100,
100 | "messages": [{"role": "user", "content": "Malformed response test"}]
101 | }
102 |
103 | response = client.post("/v1/messages", json=payload)
104 |
105 | assert response.status_code == 502 # Bad gateway
106 | data = response.json()
107 |
108 | assert data["error"]["type"] == "api_error"
109 |
110 | def test_openai_500_error(self):
111 | """Test handling of 500 errors from OpenAI."""
112 | client = TestClient(app)
113 |
114 | requests.post(
115 | "http://localhost:8001/mock/configure",
116 | json={"error_trigger": "server_error"}
117 | )
118 |
119 | payload = {
120 | "model": "claude-3-sonnet-20240229",
121 | "max_tokens": 100,
122 | "messages": [{"role": "user", "content": "Server error test"}]
123 | }
124 |
125 | response = client.post("/v1/messages", json=payload)
126 |
127 | assert response.status_code == 500
128 | data = response.json()
129 |
130 | assert data["error"]["type"] == "api_error"
131 | assert "server" in data["error"]["message"].lower()
132 |
133 | def test_openai_429_rate_limit(self):
134 | """Test handling of rate limit errors from OpenAI."""
135 | client = TestClient(app)
136 |
137 | requests.post(
138 | "http://localhost:8001/mock/configure",
139 | json={"error_trigger": "rate_limit"}
140 | )
141 |
142 | payload = {
143 | "model": "claude-3-sonnet-20240229",
144 | "max_tokens": 100,
145 | "messages": [{"role": "user", "content": "Rate limit test"}]
146 | }
147 |
148 | response = client.post("/v1/messages", json=payload)
149 |
150 | assert response.status_code == 429
151 | data = response.json()
152 |
153 | assert data["error"]["type"] == "rate_limit_error"
154 |
155 | def test_openai_401_unauthorized(self):
156 | """Test handling of unauthorized errors from OpenAI."""
157 | os.environ["OPENAI_API_KEY"] = "invalid-mock-key"
158 |
159 | client = TestClient(app)
160 |
161 | payload = {
162 | "model": "claude-3-sonnet-20240229",
163 | "max_tokens": 100,
164 | "messages": [{"role": "user", "content": "Unauthorized test"}]
165 | }
166 |
167 | response = client.post("/v1/messages", json=payload)
168 |
169 | assert response.status_code == 401
170 | data = response.json()
171 |
172 | assert data["error"]["type"] == "authentication_error"
173 |
174 | def test_invalid_json_payload(self):
175 | """Test handling of invalid JSON payload."""
176 | client = TestClient(app)
177 |
178 | # Send raw text instead of JSON
179 | response = client.post(
180 | "/v1/messages",
181 | data="invalid json payload",
182 | headers={"content-type": "application/json"}
183 | )
184 |
185 | assert response.status_code == 422
186 | data = response.json()
187 |
188 | assert "json" in str(data).lower() or "invalid" in str(data).lower()
189 |
190 | def test_missing_required_fields(self):
191 | """Test validation errors for missing required fields."""
192 | client = TestClient(app)
193 |
194 | test_cases = [
195 | # Missing messages
196 | {"model": "claude-3-sonnet-20240229", "max_tokens": 100},
197 | # Empty messages
198 | {"model": "claude-3-sonnet-20240229", "max_tokens": 100, "messages": []},
199 | # Missing model
200 | {"max_tokens": 100, "messages": [{"role": "user", "content": "test"}]},
201 | # Invalid message structure
202 | {
203 | "model": "claude-3-sonnet-20240229",
204 | "max_tokens": 100,
205 | "messages": [{"invalid_role": "test", "content": "test"}]
206 | }
207 | ]
208 |
209 | for payload in test_cases:
210 | response = client.post("/v1/messages", json=payload)
211 | assert response.status_code == 422, f"Expected validation error for {payload}"
212 |
213 | def test_invalid_temperature_range(self):
214 | """Test validation of temperature parameter."""
215 | client = TestClient(app)
216 |
217 | test_cases = [
218 | -0.1, # Below 0.0
219 | 2.1, # Above 2.0
220 | "abc", # Non-numeric
221 | None # Null (optional should be fine)
222 | ]
223 |
224 | for temp in test_cases:
225 | payload = {
226 | "model": "claude-3-sonnet-20240229",
227 | "max_tokens": 100,
228 | "temperature": temp,
229 | "messages": [{"role": "user", "content": "test"}]
230 | }
231 |
232 |             if temp is None or temp == "abc":
233 |                 # None is valid (optional field); "abc" fails type, not range, validation
234 |                 continue
235 |
236 | response = client.post("/v1/messages", json=payload)
237 |
238 | if temp < 0.0 or temp > 2.0:
239 | assert response.status_code == 422
240 |
241 | def test_invalid_max_tokens(self):
242 | """Test validation of max_tokens parameter."""
243 | client = TestClient(app)
244 |
245 | test_cases = [
246 | -1, # Negative
247 | 0, # Zero
248 | 1_000_000, # Too large
249 | "abc", # Non-integer
250 | None # Should be required
251 | ]
252 |
253 | for max_tokens in test_cases:
254 | payload = {
255 | "model": "claude-3-sonnet-20240229",
256 | "max_tokens": max_tokens,
257 | "messages": [{"role": "user", "content": "test"}]
258 | }
259 |
260 |             response = client.post("/v1/messages", json=payload)
261 |             # None, non-integer, and non-positive values should all be rejected
262 |             if max_tokens is None or max_tokens == "abc" or (isinstance(max_tokens, int) and max_tokens <= 0):
263 |                 assert response.status_code == 422
264 |
265 | def test_large_payload_handling(self):
266 | """Test handling of very large payloads."""
267 | client = TestClient(app)
268 |
269 | # Create large message content
270 | large_content = "x" * 100_000 # 100KB of content
271 |
272 | payload = {
273 | "model": "claude-3-sonnet-20240229",
274 | "max_tokens": 100,
275 | "messages": [
276 | {"role": "user", "content": large_content},
277 | {"role": "assistant", "content": "That's a large payload"},
278 | {"role": "user", "content": "Yes it is"}
279 | ]
280 | }
281 |
282 | response = client.post("/v1/messages", json=payload)
283 |
284 | # Should handle large content gracefully (might succeed or validation error)
285 | assert response.status_code in [200, 422]
286 |
287 | def test_special_characters_in_content(self):
288 | """Test handling of special characters in message content."""
289 | client = TestClient(app)
290 |
291 | special_contents = [
292 | "Unicode: 你好世界 🌍",
293 | "Emojis: 🚀✈️🚁🛸",
294 | "Newlines: Hello\nWorld\nTest",
295 | "Quotes: \"Hello\" and 'World'",
296 | "HTML: test
",
297 | "JSON: {'key': 'value'}",
298 | "Code: `print('hello')`",
299 | "Long text: " + "word " * 100
300 | ]
301 |
302 | for content in special_contents:
303 | payload = {
304 | "model": "claude-3-sonnet-20240229",
305 | "max_tokens": 100,
306 | "messages": [{"role": "user", "content": content}]
307 | }
308 |
309 | response = client.post("/v1/messages", json=payload)
310 | assert response.status_code == 200, f"Failed for content: {content[:50]}..."
311 |
312 | def test_streaming_timeout(self):
313 | """Test streaming timeout handling."""
314 | # Configure mock server with long delay for streaming
315 | requests.post(
316 | "http://localhost:8001/mock/configure",
317 | json={"delay_ms": 3000}
318 | )
319 |
320 | client = TestClient(app)
321 |
322 | payload = {
323 | "model": "claude-3-sonnet-20240229",
324 | "max_tokens": 100,
325 | "stream": True,
326 | "messages": [{"role": "user", "content": "Streaming timeout test"}]
327 | }
328 |
329 | with client.stream("POST", "/v1/messages", json=payload) as response:
330 | # Streaming timeout might not be caught until we read
331 | assert response.status_code == 200 # Headers sent immediately
332 |
333 | # Collect chunks
334 | chunks = []
335 | for line in response.iter_lines():
336 | if line and line.startswith("data: "):
337 | data_str = line[6:]
338 | if data_str.strip() == "[DONE]":
339 | break
340 |
341 | try:
342 | data = json.loads(data_str)
343 | chunks.append(data)
344 | except Exception:
345 | # Might get timeout error in streaming chunks
346 | pass
347 |
348 | def test_concurrent_error_handling(self):
349 | """Test error handling with concurrent requests."""
350 | import concurrent.futures
351 | import time
352 |
353 | client = TestClient(app)
354 |
355 | def make_request(request_num):
356 | if request_num % 2 == 0:
357 | # Valid request
358 | payload = {
359 | "model": "claude-3-sonnet-20240229",
360 | "max_tokens": 100,
361 | "messages": [{"role": "user", "content": f"Valid request {request_num}"}]
362 | }
363 | else:
364 | # Invalid request (missing messages)
365 | payload = {
366 | "model": "claude-3-sonnet-20240229",
367 | "max_tokens": 100
368 | }
369 |
370 | return client.post("/v1/messages", json=payload)
371 |
372 | # Make concurrent requests with mix of valid and invalid
373 | with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
374 | futures = [executor.submit(make_request, i) for i in range(10)]
375 | results = [future.result() for future in concurrent.futures.as_completed(futures)]
376 |
377 | # Count success vs error responses
378 | success_count = sum(1 for r in results if r.status_code == 200)
379 | error_count = sum(1 for r in results if r.status_code != 200)
380 |
381 | # Should have both success and error responses
382 | assert success_count == 5 # Even-numbered requests
383 | assert error_count == 5 # Odd-numbered requests
--------------------------------------------------------------------------------
/tests/integration/test_health_endpoint.py:
--------------------------------------------------------------------------------
1 | """
2 | Health check endpoint integration tests.
3 | """
4 | import pytest
5 | from httpx import AsyncClient
6 | from fastapi import FastAPI
7 | from unittest.mock import AsyncMock, patch
8 |
9 | from src.main import app
10 |
11 |
12 | class TestHealthEndpoint:
13 | """Test cases for the health check endpoint."""
14 |
15 | @pytest.fixture
16 | def test_app(self) -> FastAPI:
17 | """Create test FastAPI app."""
18 | return app
19 |
20 | @pytest.mark.asyncio
21 | async def test_health_check_success(self, test_app):
22 | """Test successful health check with OpenAI service available."""
23 |
24 | mock_health_result = {
25 | "openai_service": True,
26 | "api_accessible": True,
27 | "last_check": True
28 | }
29 |
30 | with patch(
31 |             "src.core.clients.openai_client.OpenAIServiceClient.health_check",
32 | new_callable=AsyncMock
33 | ) as mock_health_check:
34 | mock_health_check.return_value = mock_health_result
35 |
36 | async with AsyncClient(app=test_app, base_url="http://test") as ac:
37 | response = await ac.get("/health")
38 |
39 | assert response.status_code == 200
40 | data = response.json()
41 |
42 | assert data["status"] == "healthy"
43 | assert data["service"] == "openai-to-claude"
44 | assert "timestamp" in data
45 | assert data["checks"]["openai"] == mock_health_result
46 |
47 | @pytest.mark.asyncio
48 | async def test_health_check_degraded(self, test_app):
49 | """Test health check when OpenAI service is degraded."""
50 |
51 | mock_health_result = {
52 | "openai_service": False,
53 | "api_accessible": True,
54 | "last_check": True
55 | }
56 |
57 | with patch(
58 |             "src.core.clients.openai_client.OpenAIServiceClient.health_check",
59 | new_callable=AsyncMock
60 | ) as mock_health_check:
61 | mock_health_check.return_value = mock_health_result
62 |
63 | async with AsyncClient(app=test_app, base_url="http://test") as ac:
64 | response = await ac.get("/health")
65 |
66 | assert response.status_code == 200
67 | data = response.json()
68 |
69 | assert data["status"] == "degraded"
70 | assert data["service"] == "openai-to-claude"
71 | assert data["checks"]["openai"] == mock_health_result
72 |
73 | @pytest.mark.asyncio
74 | async def test_health_check_unhealthy(self, test_app):
75 | """Test health check when OpenAI service is unavailable."""
76 |
77 | with patch(
78 |             "src.core.clients.openai_client.OpenAIServiceClient.health_check",
79 | new_callable=AsyncMock
80 | ) as mock_health_check:
81 | mock_health_check.side_effect = Exception("Connection failed")
82 |
83 | async with AsyncClient(app=test_app, base_url="http://test") as ac:
84 | response = await ac.get("/health")
85 |
86 | assert response.status_code == 200
87 | data = response.json()
88 |
89 | assert data["status"] == "unhealthy"
90 | assert data["service"] == "openai-to-claude"
91 | assert data["checks"]["openai"]["openai_service"] is False
92 | assert data["checks"]["openai"]["api_accessible"] is False
93 | assert "error" in data["checks"]["openai"]
94 |
95 | @pytest.mark.asyncio
96 | async def test_health_check_response_structure(self, test_app):
97 | """Test health check response structure and required fields."""
98 |
99 | mock_health_result = {
100 | "openai_service": True,
101 | "api_accessible": True,
102 | "last_check": True
103 | }
104 |
105 | with patch(
106 |             "src.core.clients.openai_client.OpenAIServiceClient.health_check",
107 | new_callable=AsyncMock
108 | ) as mock_health_check:
109 | mock_health_check.return_value = mock_health_result
110 |
111 | async with AsyncClient(app=test_app, base_url="http://test") as ac:
112 | response = await ac.get("/health")
113 |
114 | assert response.status_code == 200
115 | assert response.headers["content-type"] == "application/json"
116 |
117 | data = response.json()
118 |
119 | # Check required top-level fields
120 | assert "status" in data
121 | assert "service" in data
122 | assert "timestamp" in data
123 | assert "checks" in data
124 |
125 | # Check status values
126 | assert data["status"] in ["healthy", "degraded", "unhealthy"]
127 | assert data["service"] == "openai-to-claude"
128 |
129 | # Check timestamp format
130 | from datetime import datetime
131 | datetime.fromisoformat(data["timestamp"].replace('Z', '+00:00'))
132 |
133 | # Check checks structure
134 | assert isinstance(data["checks"], dict)
135 | assert "openai" in data["checks"]
--------------------------------------------------------------------------------
/tests/integration/test_messages_endpoint.py:
--------------------------------------------------------------------------------
1 | """
2 | _integration tests for /v1/messages endpoint
3 |
4 | ٛKՌ�t*�B-͔A�pnlb�OpenAI API�
5 | """
6 |
7 | import pytest
8 | from fastapi.testclient import TestClient
9 | from unittest.mock import AsyncMock, patch, MagicMock
10 | import json
11 |
12 | from src.main import app
13 | from src.models.anthropic import AnthropicRequest
14 | from src.models.openai import OpenAIRequest, OpenAIResponse
15 |
16 |
17 | class TestMessagesEndpoint:
18 |     """Tests for the /v1/messages endpoint."""
19 |
20 | @pytest.fixture
21 | def client(self):
22 |         """Create a test client."""
23 | return TestClient(app)
24 |
25 | @pytest.fixture
26 | def valid_anthropic_request(self):
27 |         """A valid example Anthropic request."""
28 | return {
29 | "model": "claude-3-5-sonnet-20241022",
30 | "max_tokens": 100,
31 | "messages": [
32 | {
33 | "role": "user",
34 | "content": "Hello, how are you?"
35 | }
36 | ]
37 | }
38 |
39 | @pytest.fixture
40 | def mock_openai_response(self):
41 |         """A mocked OpenAI response."""
42 | return {
43 | "id": "chatcmpl-123",
44 | "object": "chat.completion",
45 | "created": 1700000000,
46 | "model": "gpt-3.5-turbo",
47 | "choices": [
48 | {
49 | "index": 0,
50 | "message": {
51 | "role": "assistant",
52 | "content": "I'm doing well! How can I help you today?"
53 | },
54 | "finish_reason": "stop"
55 | }
56 | ],
57 | "usage": {
58 | "prompt_tokens": 10,
59 | "completion_tokens": 15,
60 | "total_tokens": 25
61 | }
62 | }
63 |
64 | def test_endpoint_exists(self, client):
65 |         """Verify the endpoint exists."""
66 | response = client.post("/v1/messages", json={"model": "test"})
67 | assert response.status_code != 404
68 |
69 | def test_invalid_request_format(self, client):
70 |         """Test an invalid request format."""
71 | response = client.post("/v1/messages", json={"invalid": "format"})
72 | assert response.status_code == 422
73 |
74 | def test_missing_required_fields(self, client):
75 |         """Test missing required fields."""
76 | response = client.post("/v1/messages", json={"model": "test"})
77 | assert response.status_code == 422
78 |
79 | def test_empty_messages(self, client):
80 |         """Test an empty messages list."""
81 | request = {
82 | "model": "claude-3-5-sonnet-20241022",
83 | "messages": [],
84 | "max_tokens": 100
85 | }
86 | response = client.post("/v1/messages", json=request)
87 | assert response.status_code == 422
88 |
89 |     @patch('src.api.handlers.MessagesHandler.process_message')
90 | @pytest.mark.asyncio
91 | async def test_successful_non_streaming_request(
92 | self, mock_process, client, valid_anthropic_request, mock_openai_response
93 | ):
94 |         """Test a successful non-streaming request."""
95 | from src.models.anthropic import AnthropicMessageResponse
96 |
97 |         # Expected Anthropic response
98 | expected_response = AnthropicMessageResponse(
99 | id="msg_123",
100 | type="message",
101 | role="assistant",
102 | content=[{"type": "text", "text": "I'm doing well! How can I help you today?"}],
103 | model="claude-3-5-sonnet-20241022",
104 | stop_reason="end_turn",
105 | usage={"input_tokens": 10, "output_tokens": 15}
106 | )
107 |
108 | mock_process.return_value = expected_response
109 |
110 | response = client.post("/v1/messages", json=valid_anthropic_request)
111 |
112 | assert response.status_code == 200
113 | response_data = response.json()
114 |
115 | assert response_data["type"] == "message"
116 | assert response_data["role"] == "assistant"
117 | assert len(response_data["content"]) > 0
118 | assert response_data["content"][0]["type"] == "text"
119 | assert response_data["model"] == "claude-3-5-sonnet-20241022"
120 | assert "usage" in response_data
121 |
122 |     @patch('src.api.handlers.MessagesHandler.process_stream_message')
123 | @pytest.mark.asyncio
124 | async def test_successful_streaming_request(
125 | self, mock_stream, client, valid_anthropic_request, mock_openai_response
126 | ):
127 |         """Test a successful streaming request."""
128 | valid_anthropic_request["stream"] = True
129 |
130 |         # Mock the streaming response
131 | async def mock_generator():
132 | yield 'event: content_block_start\ndata: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}}\n\n'
133 | yield 'event: content_block_delta\ndata: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}\n\n'
134 | yield 'event: message_stop\ndata: {"type": "message_stop"}\n\n'
135 |
136 | mock_stream.return_value = mock_generator()
137 |
138 | response = client.post("/v1/messages", json=valid_anthropic_request)
139 |
140 | assert response.status_code == 200
141 | assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
142 |
143 |         lines = response.content.decode('utf-8').split('\n')
144 | assert any('content_block_start' in line for line in lines)
145 |
146 | def test_system_message_support(self, client):
147 |         """Test system message support."""
148 | request = {
149 | "model": "claude-3-5-sonnet-20241022",
150 | "messages": [{"role": "user", "content": "Hello"}],
151 | "system": [{
152 | "type": "text",
153 | "text": "You are a helpful assistant"
154 | }],
155 | "max_tokens": 100
156 | }
157 |
158 | response = client.post("/v1/messages", json=request)
159 | assert response.status_code != 400
160 |
161 | def test_tool_definition_support(self, client):
162 |         """Test tool definition support."""
163 | request = {
164 | "model": "claude-3-5-sonnet-20241022",
165 | "messages": [{"role": "user", "content": "Hello"}],
166 | "tools": [
167 | {
168 | "name": "get_weather",
169 | "description": "Get weather information",
170 | "input_schema": {
171 | "type": "object",
172 | "properties": {
173 | "location": {"type": "string"}
174 | },
175 | "required": ["location"]
176 | }
177 | }
178 | ],
179 | "max_tokens": 100
180 | }
181 |
182 | response = client.post("/v1/messages", json=request)
183 | assert response.status_code != 400
184 |
185 | def test_temperature_parameter(self, client):
186 |         """Test the temperature parameter."""
187 | request = {
188 | "model": "claude-3-5-sonnet-20241022",
189 | "messages": [{"role": "user", "content": "Hello"}],
190 | "temperature": 0.5,
191 | "max_tokens": 100
192 | }
193 |
194 | response = client.post("/v1/messages", json=request)
195 |         # Value handling is delegated to the converter; only schema errors return 422
196 | assert response.status_code != 422
197 |
198 | def test_max_tokens_parameter(self, client):
199 |         """Test the max_tokens parameter."""
200 | request = {
201 | "model": "claude-3-5-sonnet-20241022",
202 | "messages": [{"role": "user", "content": "Hello"}],
203 | "max_tokens": 1000
204 | }
205 |
206 | response = client.post("/v1/messages", json=request)
207 | assert response.status_code != 422
208 |
209 | def test_top_p_parameter(self, client):
210 |         """Test the top_p parameter."""
211 | request = {
212 | "model": "claude-3-5-sonnet-20241022",
213 | "messages": [{"role": "user", "content": "Hello"}],
214 | "top_p": 0.9,
215 | "max_tokens": 100
216 | }
217 |
218 | response = client.post("/v1/messages", json=request)
219 | assert response.status_code != 422
220 |
221 | def test_stop_sequences_parameter(self, client):
222 |         """Test the stop_sequences parameter."""
223 | request = {
224 | "model": "claude-3-5-sonnet-20241022",
225 | "messages": [{"role": "user", "content": "Hello"}],
226 | "stop_sequences": ["\n\n"],
227 | "max_tokens": 100
228 | }
229 |
230 | response = client.post("/v1/messages", json=request)
231 | assert response.status_code != 422
232 |
233 | def test_metadata_parameter(self, client):
234 |         """Test the metadata parameter."""
235 | request = {
236 | "model": "claude-3-5-sonnet-20241022",
237 | "messages": [{"role": "user", "content": "Hello"}],
238 | "metadata": {"trace_id": "123456"},
239 | "max_tokens": 100
240 | }
241 |
242 | response = client.post("/v1/messages", json=request)
243 | assert response.status_code != 422
244 |
245 | def test_image_content_support(self, client):
246 |         """Test image content support."""
247 | request = {
248 | "model": "claude-3-5-sonnet-20241022",
249 | "messages": [
250 | {
251 | "role": "user",
252 | "content": [
253 | {"type": "text", "text": "What's in this image?"},
254 | {
255 | "type": "image",
256 | "source": {
257 | "type": "base64",
258 | "media_type": "image/jpeg",
259 | "data": "base64_data_here"
260 | }
261 | }
262 | ]
263 | }
264 | ],
265 | "max_tokens": 100
266 | }
267 |
268 | response = client.post("/v1/messages", json=request)
269 | assert response.status_code != 422
270 |
271 | @patch.dict('src.models.config.Config', {'openai': {'api_key': 'test-key'}})
272 | def test_environment_variables(self, client):
273 |         """Test environment-variable-driven configuration."""
274 |         # Verify the configuration loads correctly
275 | response = client.post("/v1/messages", json={
276 | "model": "claude-3-5-sonnet-20241022",
277 | "messages": [{"role": "user", "content": "Hello"}],
278 | "max_tokens": 100
279 | })
280 |
281 |         # Configuration should load without a server error
282 | assert response.status_code != 500
283 |
284 | def test_error_response_format(self, client):
285 |         """Test that error responses follow the Anthropic error format."""
286 | response = client.post("/v1/messages", json={"invalid": "format"})
287 |
288 | assert response.status_code == 422
289 | error_data = response.json()
290 |
291 |         # The error body should contain the required fields
292 | assert "type" in error_data
293 | assert "error" in error_data
294 | assert "type" in error_data["error"]
295 | assert "message" in error_data["error"]
296 |
297 | def test_all_required_fields_present(self, client):
298 |         """Test that each required field is enforced."""
299 | required_fields = [
300 | "model",
301 | "messages",
302 | "max_tokens"
303 | ]
304 |
305 | for field in required_fields:
306 | base_request = {
307 | "model": "claude-3-5-sonnet-20241022",
308 | "messages": [{"role": "user", "content": "Hello"}],
309 | "max_tokens": 100
310 | }
311 |
312 |             # Remove the required field
313 | del base_request[field]
314 |
315 | response = client.post("/v1/messages", json=base_request)
316 | assert response.status_code == 422
--------------------------------------------------------------------------------
/tests/integration/test_model_mapping_integration.py:
--------------------------------------------------------------------------------
1 | """
2 | Integration tests: verify the complete model-mapping functionality.
3 | """
4 | import json
5 | from src.models.anthropic import AnthropicRequest, AnthropicMessage
6 | from src.models.openai import OpenAIRequest
7 | from src.core.converters.request_converter import AnthropicToOpenAIConverter
8 |
9 |
10 | class TestModelMappingIntegration:
11 |     """Integration tests: verify the complete Anthropic-to-OpenAI conversion."""
12 |
13 | def test_convert_with_thinking_model(self):
14 |         """Full conversion test using a thinking model."""
15 | anthropic_request = AnthropicRequest(
16 | model="claude-3-5-sonnet-20241022",
17 | messages=[AnthropicMessage(role="user", content="hello")],
18 | max_tokens=100,
19 | thinking=True
20 | )
21 |
22 | openai_request = AnthropicToOpenAIConverter.convert_anthropic_to_openai(anthropic_request)
23 |
24 | assert isinstance(openai_request, OpenAIRequest)
25 | assert openai_request.model == "claude-3-7-sonnet-thinking"
26 | assert len(openai_request.messages) == 1
27 | assert openai_request.messages[0].content == "hello"
28 | assert openai_request.max_tokens == 100
29 |
30 | def test_convert_with_sonnet_model(self):
31 |         """Full conversion test using a sonnet model."""
32 | anthropic_request = AnthropicRequest(
33 | model="claude-3-5-sonnet-20241022",
34 | messages=[AnthropicMessage(role="user", content="hello")],
35 | max_tokens=100,
36 | thinking=None
37 | )
38 |
39 | openai_request = AnthropicToOpenAIConverter.convert_anthropic_to_openai(anthropic_request)
40 |
41 | assert isinstance(openai_request, OpenAIRequest)
42 | assert openai_request.model == "claude-3-5-sonnet"
43 |
44 | def test_convert_with_haiku_model(self):
45 |         """Full conversion test using a haiku model."""
46 | anthropic_request = AnthropicRequest(
47 | model="claude-3-5-haiku",
48 | messages=[AnthropicMessage(role="user", content="hello")],
49 | max_tokens=100,
50 | thinking=None
51 | )
52 |
53 | openai_request = AnthropicToOpenAIConverter.convert_anthropic_to_openai(anthropic_request)
54 |
55 | assert isinstance(openai_request, OpenAIRequest)
56 | assert openai_request.model == "claude-3-5-haiku"
57 |
58 | def test_convert_with_system_support(self):
59 |         """Full conversion test with a system prompt."""
60 | anthropic_request = AnthropicRequest(
61 | model="claude-3-5-sonnet",
62 | messages=[AnthropicMessage(role="user", content="hello")],
63 | max_tokens=100,
64 | system="你是一个有用的助手",
65 | thinking=True
66 | )
67 |
68 | openai_request = AnthropicToOpenAIConverter.convert_anthropic_to_openai(anthropic_request)
69 |
70 | assert openai_request.model == "claude-3-7-sonnet-thinking"
71 | assert openai_request.messages[0].role == "system"
72 | assert openai_request.messages[0].content == "你是一个有用的助手"
73 |
74 | def test_convert_with_tools(self):
75 |         """Full conversion test with tool definitions."""
76 | tools = [
77 | {
78 | "name": "get_weather",
79 | "description": "获取天气信息",
80 | "input_schema": {
81 | "type": "object",
82 | "properties": {
83 | "location": {"type": "string"},
84 | "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
85 | },
86 | "required": ["location"]
87 | }
88 | }
89 | ]
90 |
91 | anthropic_request = AnthropicRequest(
92 | model="claude-3-5-haiku",
93 | messages=[AnthropicMessage(role="user", content="天气如何")],
94 | max_tokens=100,
95 | tools=tools,
96 | thinking=None
97 | )
98 |
99 | openai_request = AnthropicToOpenAIConverter.convert_anthropic_to_openai(anthropic_request)
100 |
101 | assert openai_request.model == "claude-3-5-haiku"
102 | assert openai_request.tools is not None
103 | assert len(openai_request.tools) == 1
104 | assert openai_request.tools[0].function.name == "get_weather"
105 |
106 | def test_model_passthrough(self):
107 |         """Test that a model name containing a comma is not converted."""
108 | anthropic_request = AnthropicRequest(
109 | model="claude-custom,variant",
110 | messages=[AnthropicMessage(role="user", content="hello")],
111 | max_tokens=100
112 | )
113 |
114 |         # Even if a thinking field is present, a model name containing a comma should be kept as-is
115 | openai_request = AnthropicToOpenAIConverter.convert_anthropic_to_openai(anthropic_request)
116 |
117 | assert openai_request.model == "claude-custom,variant"
118 |
119 | def test_serialization_roundtrip(self):
120 |         """Test serialization and deserialization."""
121 | original_request = {
122 | "model": "claude-3-5-sonnet",
123 | "messages": [{"role": "user", "content": "hello"}],
124 | "max_tokens": 100,
125 | "thinking": True
126 | }
127 |
128 | anthropic_request = AnthropicRequest(**original_request)
129 | openai_request = AnthropicToOpenAIConverter.convert_anthropic_to_openai(anthropic_request)
130 |
131 |         # Ensure the request can be serialized to JSON
132 | openai_dict = openai_request.dict()
133 | assert openai_dict["model"] == "claude-3-7-sonnet-thinking"
134 | assert openai_dict["messages"][0]["role"] == "user"
135 | assert openai_dict["messages"][0]["content"] == "hello"
--------------------------------------------------------------------------------
/tests/integration/test_streaming.py:
--------------------------------------------------------------------------------
1 | """Streaming response tests for end-to-end testing."""
2 | import json
3 | import pytest
4 | import requests
5 | from fastapi.testclient import TestClient
6 | import os
7 |
8 | from src.main import app
9 |
10 |
11 | class TestStreamingIntegration:
12 | """Tests for streaming response handling in the proxy."""
13 |
14 | @pytest.fixture(autouse=True)
15 | def setup(self):
16 | """Set up test environment."""
17 | os.environ["OPENAI_API_KEY"] = "mock-key"
18 | os.environ["OPENAI_BASE_URL"] = "http://localhost:8001"
19 |
20 | # Reset mock server
21 | requests.delete("http://localhost:8001/mock/reset")
22 |
23 | yield
24 |
25 | # Clean up environment
26 | if "OPENAI_API_KEY" in os.environ:
27 | del os.environ["OPENAI_API_KEY"]
28 | if "OPENAI_BASE_URL" in os.environ:
29 | del os.environ["OPENAI_BASE_URL"]
30 |
31 | def test_streaming_chat_completion(self):
32 | """Test complete streaming flow from Anthropic to OpenAI proxy."""
33 | client = TestClient(app)
34 |
35 | payload = {
36 | "model": "claude-3-sonnet-20240229",
37 | "max_tokens": 1024,
38 | "stream": True,
39 | "messages": [
40 | {"role": "user", "content": "Write a short poem about Python"}
41 | ]
42 | }
43 |
44 | with client.stream("POST", "/v1/messages", json=payload) as response:
45 | assert response.status_code == 200
46 | assert response.headers["content-type"] == "text/plain; charset=utf-8"
47 |
48 | chunks = []
49 | for line in response.iter_lines():
50 | if line and line.startswith("data: "):
51 | data_str = line[6:] # Remove "data: " prefix
52 | if data_str.strip() == "[DONE]":
53 | break
54 |
55 | try:
56 | data = json.loads(data_str)
57 | chunks.append(data)
58 | except json.JSONDecodeError:
59 | continue
60 |
61 | # Verify we received some chunks
62 | assert len(chunks) > 0
63 |
64 | # Check structure of first chunk
65 | first_chunk = chunks[0]
66 | assert "type" in first_chunk
67 | assert first_chunk["type"] == "content_block_start"
68 | assert "content_block" in first_chunk
69 |
70 | # Check structure of content chunks
71 | content_chunks = [c for c in chunks if c.get("type") == "content_block_delta"]
72 | assert len(content_chunks) > 0
73 |
74 | # Check final chunk
75 | final_chunks = [c for c in chunks if c.get("type") == "message_stop"]
76 | assert len(final_chunks) == 1
77 |
78 | def test_streaming_with_system_message(self):
79 | """Test streaming with system message included."""
80 | client = TestClient(app)
81 |
82 | payload = {
83 | "model": "claude-3-sonnet-20240229",
84 | "max_tokens": 512,
85 | "stream": True,
86 | "temperature": 0.8,
87 | "messages": [
88 | {"role": "system", "content": "You are a Python programming assistant."},
89 | {"role": "user", "content": "Explain decorators"}
90 | ]
91 | }
92 |
93 | with client.stream("POST", "/v1/messages", json=payload) as response:
94 | assert response.status_code == 200
95 |
96 | chunks = []
97 | for line in response.iter_lines():
98 | if line and line.startswith("data: "):
99 | data_str = line[6:]
100 | if data_str.strip() == "[DONE]":
101 | break
102 |
103 | try:
104 | data = json.loads(data_str)
105 | chunks.append(data)
106 | except json.JSONDecodeError:
107 | continue
108 |
109 | # Should still receive chunks even with system message
110 | assert len(chunks) > 2
111 |
112 | # Verify we got the message start
113 | start_chunks = [c for c in chunks if c.get("type") == "message_start"]
114 | assert len(start_chunks) == 1
115 |
116 | # Verify we got content blocks
117 | content_start = [c for c in chunks if c.get("type") == "content_block_start"]
118 | assert len(content_start) == 1
119 |
120 | def test_streaming_empty_content_handling(self):
121 | """Test streaming with empty or minimal content."""
122 | client = TestClient(app)
123 |
124 | payload = {
125 | "model": "claude-3-sonnet-20240229",
126 | "max_tokens": 10,
127 | "stream": True,
128 | "messages": [
129 | {"role": "user", "content": ""}
130 | ]
131 | }
132 |
133 | with client.stream("POST", "/v1/messages", json=payload) as response:
134 | assert response.status_code == 200
135 |
136 | chunks = []
137 | for line in response.iter_lines():
138 | if line and line.startswith("data: "):
139 | data_str = line[6:]
140 | if data_str.strip() == "[DONE]":
141 | break
142 |
143 | try:
144 | data = json.loads(data_str)
145 | chunks.append(data)
146 | except json.JSONDecodeError:
147 | continue
148 |
149 | # Should still receive proper streaming structure
150 |             assert len(chunks) >= 3  # at least message_start, content_block_start, and message_stop
151 |
152 | def test_streaming_usage_information(self):
153 | """Test that usage information is included in streaming response."""
154 | client = TestClient(app)
155 |
156 | payload = {
157 | "model": "claude-3-sonnet-20240229",
158 | "max_tokens": 100,
159 | "stream": True,
160 | "messages": [
161 | {"role": "user", "content": "Count to 5"}
162 | ]
163 | }
164 |
165 | with client.stream("POST", "/v1/messages", json=payload) as response:
166 | assert response.status_code == 200
167 |
168 | # Collect all chunks to ensure complete flow
169 | chunks = []
170 | for line in response.iter_lines():
171 | if line and line.startswith("data: "):
172 | data_str = line[6:]
173 | if data_str.strip() == "[DONE]":
174 | break
175 |
176 | try:
177 | data = json.loads(data_str)
178 | chunks.append(data)
179 | except json.JSONDecodeError:
180 | continue
181 |
182 | # Look for message_delta with usage info
183 | usage_chunks = [c for c in chunks if c.get("type") == "message_delta"]
184 | assert len(usage_chunks) >= 1
185 |
186 | final_chunk = usage_chunks[0]
187 | assert "usage" in final_chunk
188 | usage = final_chunk["usage"]
189 | assert "input_tokens" in usage
190 | assert "output_tokens" in usage
191 |
192 | def test_streaming_response_timing(self):
193 | """Test streaming response timing performance."""
194 | import time
195 |
196 | client = TestClient(app)
197 |
198 | payload = {
199 | "model": "claude-3-sonnet-20240229",
200 | "max_tokens": 50,
201 | "stream": True,
202 | "messages": [
203 | {"role": "user", "content": "Say 'hello' three times"}
204 | ]
205 | }
206 |
207 | start_time = time.time()
208 |
209 | with client.stream("POST", "/v1/messages", json=payload) as response:
210 | assert response.status_code == 200
211 |
212 | chunks = []
213 | for line in response.iter_lines():
214 | if line and line.startswith("data: "):
215 | data_str = line[6:]
216 | if data_str.strip() == "[DONE]":
217 | break
218 |
219 | try:
220 | data = json.loads(data_str)
221 | chunks.append(data)
222 | except json.JSONDecodeError:
223 | continue
224 |
225 | end_time = time.time()
226 | duration = (end_time - start_time) * 1000
227 |
228 | # Ensure reasonable performance for streaming
229 | assert duration < 1000 # Should complete within 1 second for test
230 | assert len(chunks) > 3 # Should have multiple streaming chunks
231 |
232 | def test_streaming_with_stream_false_comparison(self):
233 | """Test that streaming works differently from non-streaming."""
234 | client = TestClient(app)
235 |
236 | request_data = {
237 | "model": "claude-3-sonnet-20240229",
238 | "max_tokens": 50,
239 | "messages": [
240 | {"role": "user", "content": "Reply with a single word: Python"}
241 | ]
242 | }
243 |
244 | # Non-streaming version
245 | non_stream_payload = request_data.copy()
246 | non_stream_payload["stream"] = False
247 | sync_response = client.post("/v1/messages", json=non_stream_payload)
248 | assert sync_response.status_code == 200
249 |
250 | sync_data = sync_response.json()
251 | assert "content" in sync_data
252 |
253 | # Streaming version
254 | stream_payload = request_data.copy()
255 | stream_payload["stream"] = True
256 |
257 | with client.stream("POST", "/v1/messages", json=stream_payload) as response:
258 | assert response.status_code == 200
259 |
260 |             # SSE-style events, served here with a text/plain content type
261 | assert response.headers["content-type"] == "text/plain; charset=utf-8"
262 |
263 | # Collect streaming content
264 | collected_content = ""
265 | chunks = []
266 |
267 | for line in response.iter_lines():
268 | if line and line.startswith("data: "):
269 | data_str = line[6:]
270 | if data_str.strip() == "[DONE]":
271 | break
272 |
273 | try:
274 | data = json.loads(data_str)
275 | chunks.append(data)
276 |
277 | # Extract text content from delta chunks
278 | if data.get("type") == "content_block_delta":
279 | delta = data.get("delta", {})
280 | if "text" in delta:
281 | collected_content += delta["text"]
282 | except json.JSONDecodeError:
283 | continue
284 |
285 | # Both should return similar content, streaming breaks it into chunks
286 | assert len(collected_content) > 0
287 | assert "Python" in collected_content or "python" in collected_content.lower()
--------------------------------------------------------------------------------