├── web ├── favicon.ico ├── favicon-192.png ├── favicon-512.png ├── apple-touch-icon.png ├── assets │ ├── fonts │ │ ├── inter-400.ttf │ │ ├── inter-500.ttf │ │ ├── inter-600.ttf │ │ └── inter-700.ttf │ ├── css │ │ ├── inter.css │ │ ├── logs.css │ │ ├── tokens.css │ │ └── channels.css │ └── js │ │ ├── date-range-selector.js │ │ ├── template-engine.js │ │ ├── channels-import-export.js │ │ ├── channels-data.js │ │ ├── login.js │ │ ├── channels-init.js │ │ ├── channels-state.js │ │ └── settings.js ├── manifest.json ├── favicon.svg └── settings.html ├── internal ├── app │ ├── socket_unix.go │ ├── socket_windows.go │ ├── admin_settings_response_test.go │ ├── admin_cooldown.go │ ├── proxy_gemini.go │ ├── request_context.go │ ├── admin_response_contract_test.go │ ├── token_stats_shutdown_test.go │ ├── proxy_forward_soft_error_test.go │ ├── proxy_stream.go │ ├── health_cache.go │ ├── admin_testing_test.go │ ├── admin_auth_tokens_test.go │ ├── proxy_util_test.go │ ├── config_service.go │ ├── static.go │ ├── key_selector.go │ └── proxy_handler_test.go ├── util │ ├── serialize.go │ ├── apikeys.go │ ├── apikeys_test.go │ ├── serialize_test.go │ ├── time_additional_test.go │ ├── time_bench_test.go │ ├── channel_types_bench_test.go │ ├── time.go │ ├── channel_types.go │ ├── classifier_1308_test.go │ └── rate_limiter.go ├── storage │ ├── sqlite │ │ └── test_store_helpers_test.go │ ├── schema │ │ ├── builder_test.go │ │ └── builder.go │ ├── sql │ │ ├── metrics_finalize.go │ │ ├── admin_sessions.go │ │ ├── system_settings.go │ │ ├── metrics_aggregate_rows.go │ │ ├── store_impl.go │ │ ├── transaction_deadline_test.go │ │ └── auth_token_stats.go │ ├── health_success_rate_test.go │ └── cache_metrics_test.go ├── version │ ├── version.go │ └── banner.go ├── model │ ├── health.go │ ├── system_setting.go │ ├── config.go │ ├── log.go │ ├── auth_token.go │ └── stats.go ├── testutil │ └── types.go ├── config │ ├── defaults.go │ └── defaults_test.go └── validator │ └── validator.go ├── 
.dockerignore ├── .gitignore ├── docker-compose.yml ├── com.ccload.service.plist.template ├── test └── integration │ ├── setup_test.go │ └── csv_import_export_test.go ├── docker-compose.build.yml ├── .env.docker.example ├── .env.example ├── go.mod ├── CLAUDE.md ├── .github └── workflows │ └── docker.yml └── Dockerfile /web/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caidaoli/ccLoad/HEAD/web/favicon.ico -------------------------------------------------------------------------------- /web/favicon-192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caidaoli/ccLoad/HEAD/web/favicon-192.png -------------------------------------------------------------------------------- /web/favicon-512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caidaoli/ccLoad/HEAD/web/favicon-512.png -------------------------------------------------------------------------------- /web/apple-touch-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caidaoli/ccLoad/HEAD/web/apple-touch-icon.png -------------------------------------------------------------------------------- /web/assets/fonts/inter-400.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caidaoli/ccLoad/HEAD/web/assets/fonts/inter-400.ttf -------------------------------------------------------------------------------- /web/assets/fonts/inter-500.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caidaoli/ccLoad/HEAD/web/assets/fonts/inter-500.ttf -------------------------------------------------------------------------------- /web/assets/fonts/inter-600.ttf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/caidaoli/ccLoad/HEAD/web/assets/fonts/inter-600.ttf -------------------------------------------------------------------------------- /web/assets/fonts/inter-700.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caidaoli/ccLoad/HEAD/web/assets/fonts/inter-700.ttf -------------------------------------------------------------------------------- /internal/app/socket_unix.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | package app 4 | 5 | import "syscall" 6 | 7 | // setTCPNoDelay 在 Unix 系统上设置 TCP_NODELAY 8 | func setTCPNoDelay(fd uintptr) error { 9 | return syscall.SetsockoptInt(int(fd), syscall.IPPROTO_TCP, syscall.TCP_NODELAY, 1) 10 | } 11 | -------------------------------------------------------------------------------- /internal/app/socket_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package app 4 | 5 | import "syscall" 6 | 7 | // setTCPNoDelay 在 Windows 上设置 TCP_NODELAY 8 | func setTCPNoDelay(fd uintptr) error { 9 | return syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_TCP, syscall.TCP_NODELAY, 1) 10 | } 11 | -------------------------------------------------------------------------------- /internal/util/serialize.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import "github.com/bytedance/sonic" 4 | 5 | // SerializeJSON 序列化任意类型为JSON字符串,失败时返回默认值 6 | // 自动处理空值:nil返回默认值,空切片/map正常序列化为[]或{} 7 | func SerializeJSON(v any, defaultValue string) (string, error) { 8 | // 检查空值 9 | if v == nil { 10 | return defaultValue, nil 11 | } 12 | 13 | bytes, err := sonic.Marshal(v) 14 | if err != nil { 15 | return defaultValue, err 16 | } 17 | return string(bytes), nil 18 | } 19 | 
-------------------------------------------------------------------------------- /internal/util/apikeys.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import "strings" 4 | 5 | // ParseAPIKeys 解析 API Key 字符串(支持逗号分隔的多个 Key) 6 | // 设计原则(DRY):统一的Key解析逻辑,供多个模块复用 7 | func ParseAPIKeys(apiKey string) []string { 8 | if apiKey == "" { 9 | return []string{} 10 | } 11 | parts := strings.Split(apiKey, ",") 12 | keys := make([]string, 0, len(parts)) 13 | for _, k := range parts { 14 | k = strings.TrimSpace(k) 15 | if k != "" { 16 | keys = append(keys, k) 17 | } 18 | } 19 | return keys 20 | } 21 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # Git 相关 2 | .git 3 | .gitignore 4 | 5 | # 开发工具 6 | .vscode 7 | .idea 8 | *.swp 9 | *.swo 10 | 11 | # 日志文件 12 | logs/ 13 | *.log 14 | 15 | # 数据文件 16 | data/ 17 | *.db 18 | *.db-journal 19 | 20 | # 构建产物 21 | ccload 22 | /tmp/ 23 | 24 | # macOS LaunchAgent 相关 25 | *.plist 26 | *.plist.template 27 | com.ccload.service.plist 28 | 29 | # 测试文件 30 | *_test.go 31 | test_* 32 | 33 | # 文档 34 | README.md 35 | CLAUDE.md 36 | 37 | # 依赖和模块缓存 38 | vendor/ 39 | 40 | # 环境配置 41 | .env 42 | .env.local 43 | .env.*.local 44 | 45 | # Makefile(容器内不需要) 46 | Makefile -------------------------------------------------------------------------------- /internal/storage/sqlite/test_store_helpers_test.go: -------------------------------------------------------------------------------- 1 | package sqlite_test 2 | 3 | import ( 4 | "ccLoad/internal/storage" 5 | "testing" 6 | ) 7 | 8 | func setupSQLiteTestStore(t *testing.T, dbFile string) (storage.Store, func()) { 9 | t.Helper() 10 | 11 | tmpDB := t.TempDir() + "/" + dbFile 12 | store, err := storage.CreateSQLiteStore(tmpDB, nil) 13 | if err != nil { 14 | t.Fatalf("创建测试数据库失败: %v", err) 15 | } 16 | 17 | cleanup := func() { 18 | if 
err := store.Close(); err != nil { 19 | t.Logf("关闭测试数据库失败: %v", err) 20 | } 21 | } 22 | 23 | return store, cleanup 24 | } 25 | -------------------------------------------------------------------------------- /internal/version/version.go: -------------------------------------------------------------------------------- 1 | // Package version 提供应用版本信息 2 | // 版本号通过 go build -ldflags 注入,用于静态资源缓存控制 3 | package version 4 | 5 | // 构建信息变量,通过 ldflags 注入 6 | // 构建命令示例: 7 | // go build -ldflags "-X ccLoad/internal/version.Version=$(git describe --tags --always) \ 8 | // -X ccLoad/internal/version.Commit=$(git rev-parse --short HEAD) \ 9 | // -X 'ccLoad/internal/version.BuildTime=$(date +%Y-%m-%d\ %H:%M:%S\ %z)' \ 10 | // -X ccLoad/internal/version.BuiltBy=$(whoami)" 11 | var ( 12 | Version = "dev" 13 | Commit = "unknown" 14 | BuildTime = "unknown" 15 | BuiltBy = "unknown" 16 | ) 17 | -------------------------------------------------------------------------------- /web/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Claude Code & Codex Proxy", 3 | "short_name": "ccLoad", 4 | "description": "Claude API代理管理服务", 5 | "start_url": "/web/index.html", 6 | "display": "standalone", 7 | "background_color": "#ffffff", 8 | "theme_color": "#3b82f6", 9 | "icons": [ 10 | { 11 | "src": "/web/favicon-192.png", 12 | "sizes": "192x192", 13 | "type": "image/png", 14 | "purpose": "any maskable" 15 | }, 16 | { 17 | "src": "/web/favicon-512.png", 18 | "sizes": "512x512", 19 | "type": "image/png", 20 | "purpose": "any maskable" 21 | } 22 | ] 23 | } 24 | -------------------------------------------------------------------------------- /internal/model/health.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | // HealthScoreConfig 健康度排序配置 4 | type HealthScoreConfig struct { 5 | Enabled bool // 是否启用健康度排序 6 | SuccessRatePenaltyWeight float64 // 成功率惩罚权重(乘以失败率) 7 | WindowMinutes int // 
成功率统计时间窗口(分钟) 8 | UpdateIntervalSeconds int // 成功率缓存更新间隔(秒) 9 | } 10 | 11 | // DefaultHealthScoreConfig 返回默认健康度配置 12 | func DefaultHealthScoreConfig() HealthScoreConfig { 13 | return HealthScoreConfig{ 14 | Enabled: false, 15 | SuccessRatePenaltyWeight: 100, 16 | WindowMinutes: 5, 17 | UpdateIntervalSeconds: 30, 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /internal/model/system_setting.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "errors" 4 | 5 | // ErrSettingNotFound 系统设置未找到错误 6 | var ErrSettingNotFound = errors.New("setting not found") 7 | 8 | // SystemSetting 系统配置项 9 | type SystemSetting struct { 10 | Key string `json:"key"` // 配置键(如log_retention_days) 11 | Value string `json:"value"` // 配置值(字符串存储,运行时解析) 12 | ValueType string `json:"value_type"` // 值类型(int/bool/string/duration) 13 | Description string `json:"description"` // 配置说明(用于前端显示) 14 | DefaultValue string `json:"default_value"` // 默认值(用于重置功能) 15 | UpdatedAt int64 `json:"updated_at"` // 更新时间(Unix秒) 16 | } 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE and editor files 2 | .gocache 3 | .idea 4 | .vscode 5 | *.swp 6 | *.swo 7 | *~ 8 | 9 | # Data and database files 10 | data/ 11 | *.db 12 | *.sqlite 13 | *.sqlite3 14 | 15 | # Build artifacts 16 | ccLoad 17 | ccload 18 | /tmp/ccload* 19 | *.exe 20 | *.dll 21 | *.so 22 | *.dylib 23 | 24 | # Test files 25 | *.test 26 | *.out 27 | test*.sh 28 | 29 | # Environment files 30 | .env 31 | .env.local 32 | .env.*.local 33 | 34 | # OS files 35 | .DS_Store 36 | Thumbs.db 37 | 38 | # Temporary files 39 | *.tmp 40 | *.temp 41 | /tmp/ 42 | 43 | # Log files 44 | *.log 45 | 46 | # Playwright MCP 47 | .playwright-mcp 48 | .claude 49 | com.ccload.service.plist 50 | .dataX 51 | .serena 52 | .gocache 53 | 
.gomodcache 54 | AGENTS.md 55 | dist 56 | *.bak 57 | .docs -------------------------------------------------------------------------------- /web/assets/css/inter.css: -------------------------------------------------------------------------------- 1 | @font-face { 2 | font-family: 'Inter'; 3 | font-style: normal; 4 | font-weight: 400; 5 | font-display: swap; 6 | src: url(../fonts/inter-400.ttf) format('truetype'); 7 | } 8 | @font-face { 9 | font-family: 'Inter'; 10 | font-style: normal; 11 | font-weight: 500; 12 | font-display: swap; 13 | src: url(../fonts/inter-500.ttf) format('truetype'); 14 | } 15 | @font-face { 16 | font-family: 'Inter'; 17 | font-style: normal; 18 | font-weight: 600; 19 | font-display: swap; 20 | src: url(../fonts/inter-600.ttf) format('truetype'); 21 | } 22 | @font-face { 23 | font-family: 'Inter'; 24 | font-style: normal; 25 | font-weight: 700; 26 | font-display: swap; 27 | src: url(../fonts/inter-700.ttf) format('truetype'); 28 | } 29 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | ccload: 5 | image: ghcr.io/caidaoli/ccload:latest 6 | container_name: ccload 7 | user: root 8 | restart: unless-stopped 9 | ports: 10 | - "8080:8080" 11 | environment: 12 | - PORT=8080 13 | - SQLITE_PATH=/app/data/ccload.db 14 | - GIN_MODE=release 15 | # 必填:未设置将无法启动 16 | - CCLOAD_PASS=your_admin_password 17 | - TZ=Asia/Shanghai 18 | # API访问令牌通过Web界面管理: http://localhost:8080/web/tokens.html 19 | volumes: 20 | - ./data:/app/data 21 | healthcheck: 22 | test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:8080/health"] 23 | interval: 30s 24 | timeout: 10s 25 | retries: 3 26 | start_period: 40s 27 | 28 | 29 | -------------------------------------------------------------------------------- /com.ccload.service.plist.template: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Label 6 | com.ccload.service 7 | ProgramArguments 8 | 9 | {{PROJECT_DIR}}/ccload 10 | 11 | WorkingDirectory 12 | {{PROJECT_DIR}} 13 | RunAtLoad 14 | 15 | KeepAlive 16 | 17 | StandardOutPath 18 | /dev/null 19 | StandardErrorPath 20 | /dev/null 21 | EnvironmentVariables 22 | 23 | PATH 24 | /usr/local/bin:/usr/bin:/bin 25 | 26 | 27 | -------------------------------------------------------------------------------- /internal/testutil/types.go: -------------------------------------------------------------------------------- 1 | package testutil 2 | 3 | import "fmt" 4 | 5 | // TestChannelRequest 渠道测试请求结构 6 | type TestChannelRequest struct { 7 | Model string `json:"model" binding:"required"` 8 | MaxTokens int `json:"max_tokens,omitempty"` // 可选,默认512 9 | Stream bool `json:"stream,omitempty"` // 可选,流式响应 10 | Content string `json:"content,omitempty"` // 可选,测试内容,默认"test" 11 | Headers map[string]string `json:"headers,omitempty"` // 可选,自定义请求头 12 | ChannelType string `json:"channel_type,omitempty"` // 可选,渠道类型:anthropic(默认)、codex、gemini 13 | KeyIndex int `json:"key_index,omitempty"` // 可选,指定测试的Key索引,默认0(第一个) 14 | } 15 | 16 | // Validate 实现RequestValidator接口 17 | func (tr *TestChannelRequest) Validate() error { 18 | if tr.Model == "" { 19 | return fmt.Errorf("model cannot be empty") 20 | } 21 | return nil 22 | } 23 | -------------------------------------------------------------------------------- /test/integration/setup_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "ccLoad/internal/storage" 5 | "context" 6 | "path/filepath" 7 | "testing" 8 | ) 9 | 10 | // setupTestStore 创建测试用的 SQLite Store 11 | func setupTestStore(t *testing.T) (storage.Store, func()) { 12 | t.Helper() 13 | 14 | // 创建临时目录和数据库文件 15 | tmpDir := t.TempDir() 16 | dbPath := filepath.Join(tmpDir, "test.db") 17 | 18 | store, err 
:= storage.CreateSQLiteStore(dbPath, nil) 19 | if err != nil { 20 | t.Fatalf("创建测试数据库失败: %v", err) 21 | } 22 | 23 | // 返回清理函数 24 | cleanup := func() { 25 | if err := store.Close(); err != nil { 26 | t.Logf("⚠️ 关闭数据库失败: %v", err) 27 | } 28 | } 29 | 30 | return store, cleanup 31 | } 32 | 33 | // setupTestStoreWithContext 创建测试用的 Store 和 Context 34 | func setupTestStoreWithContext(t *testing.T) (storage.Store, context.Context, func()) { 35 | t.Helper() 36 | 37 | store, cleanup := setupTestStore(t) 38 | ctx := context.Background() 39 | 40 | return store, ctx, cleanup 41 | } 42 | 43 | -------------------------------------------------------------------------------- /docker-compose.build.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | ccload: 5 | build: 6 | context: . 7 | dockerfile: Dockerfile 8 | args: 9 | # 版本号:用于静态资源缓存控制 10 | # - dev(默认):开发环境,静态资源不缓存 11 | # - v1.x.x:生产环境,静态资源长缓存 12 | # 生产构建: VERSION=$(git describe --tags --always) docker-compose -f docker-compose.build.yml build 13 | VERSION: ${VERSION:-dev} 14 | image: ccload:local 15 | container_name: ccload 16 | restart: unless-stopped 17 | ports: 18 | - "8080:8080" 19 | environment: 20 | - PORT=8080 21 | - SQLITE_PATH=/app/data/ccload.db 22 | - GIN_MODE=release 23 | # 必填:未设置将无法启动 24 | - CCLOAD_PASS=your_admin_password 25 | # API访问令牌通过Web界面管理: http://localhost:8080/web/tokens.html 26 | volumes: 27 | - ccload_data:/app/data 28 | healthcheck: 29 | test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:8080/health"] 30 | interval: 30s 31 | timeout: 10s 32 | retries: 3 33 | start_period: 40s 34 | 35 | volumes: 36 | ccload_data: 37 | driver: local 38 | -------------------------------------------------------------------------------- /web/favicon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 
21 | -------------------------------------------------------------------------------- /internal/app/admin_settings_response_test.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/gin-gonic/gin" 10 | ) 11 | 12 | func TestAdminAPI_ListSettings_ResponseShape(t *testing.T) { 13 | server, store, cleanup := setupAdminTestServer(t) 14 | defer cleanup() 15 | 16 | server.configService = NewConfigService(store) 17 | 18 | w := httptest.NewRecorder() 19 | c, _ := gin.CreateTestContext(w) 20 | c.Request = httptest.NewRequest(http.MethodGet, "/admin/settings", nil) 21 | 22 | server.AdminListSettings(c) 23 | 24 | if w.Code != http.StatusOK { 25 | t.Fatalf("Expected 200, got %d", w.Code) 26 | } 27 | 28 | var resp map[string]any 29 | if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { 30 | t.Fatalf("Parse error: %v", err) 31 | } 32 | 33 | if resp["success"] != true { 34 | t.Fatalf("Expected success=true, got %v", resp["success"]) 35 | } 36 | 37 | data := resp["data"] 38 | if data == nil { 39 | t.Fatalf("Expected data to be [], got null") 40 | } 41 | if _, ok := data.([]any); !ok { 42 | t.Fatalf("Expected data to be array, got %T", data) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /internal/util/apikeys_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | // TestParseAPIKeys 测试API Key解析 8 | func TestParseAPIKeys(t *testing.T) { 9 | tests := []struct { 10 | name string 11 | input string 12 | expected []string 13 | }{ 14 | { 15 | name: "单个Key", 16 | input: "sk-test-key", 17 | expected: []string{"sk-test-key"}, 18 | }, 19 | { 20 | name: "多个Key (逗号分隔)", 21 | input: "sk-key1,sk-key2,sk-key3", 22 | expected: []string{"sk-key1", "sk-key2", "sk-key3"}, 23 | }, 24 | { 25 
| name: "带空格的Key", 26 | input: " sk-key1 , sk-key2 , sk-key3 ", 27 | expected: []string{"sk-key1", "sk-key2", "sk-key3"}, 28 | }, 29 | { 30 | name: "空字符串", 31 | input: "", 32 | expected: []string{}, 33 | }, 34 | { 35 | name: "仅空格", 36 | input: " ", 37 | expected: []string{}, 38 | }, 39 | { 40 | name: "包含空项", 41 | input: "sk-key1,,sk-key3", 42 | expected: []string{"sk-key1", "sk-key3"}, 43 | }, 44 | } 45 | 46 | for _, tt := range tests { 47 | t.Run(tt.name, func(t *testing.T) { 48 | result := ParseAPIKeys(tt.input) 49 | if len(result) != len(tt.expected) { 50 | t.Errorf("期望 %d 个key, 实际 %d 个", len(tt.expected), len(result)) 51 | return 52 | } 53 | for i, key := range result { 54 | if key != tt.expected[i] { 55 | t.Errorf("索引 %d: 期望 %q, 实际 %q", i, tt.expected[i], key) 56 | } 57 | } 58 | }) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /internal/util/serialize_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | // TestSerializeJSON 测试JSON序列化 8 | func TestSerializeJSON(t *testing.T) { 9 | tests := []struct { 10 | name string 11 | input any 12 | defaultValue string 13 | expected string 14 | expectError bool 15 | }{ 16 | { 17 | name: "空数组", 18 | input: []string{}, 19 | defaultValue: "[]", 20 | expected: "[]", 21 | expectError: false, 22 | }, 23 | { 24 | name: "单个元素", 25 | input: []string{"test"}, 26 | defaultValue: "[]", 27 | expected: `["test"]`, 28 | expectError: false, 29 | }, 30 | { 31 | name: "多个元素", 32 | input: []string{"a", "b", "c"}, 33 | defaultValue: "[]", 34 | expected: `["a","b","c"]`, 35 | expectError: false, 36 | }, 37 | { 38 | name: "nil值返回默认值", 39 | input: nil, 40 | defaultValue: "default", 41 | expected: "default", 42 | expectError: false, 43 | }, 44 | { 45 | name: "map对象", 46 | input: map[string]string{"key": "value"}, 47 | defaultValue: "{}", 48 | expected: `{"key":"value"}`, 49 | expectError: false, 
50 | }, 51 | } 52 | 53 | for _, tt := range tests { 54 | t.Run(tt.name, func(t *testing.T) { 55 | result, err := SerializeJSON(tt.input, tt.defaultValue) 56 | if (err != nil) != tt.expectError { 57 | t.Errorf("期望错误=%v, 实际错误=%v", tt.expectError, err) 58 | } 59 | if result != tt.expected { 60 | t.Errorf("期望 %q, 实际 %q", tt.expected, result) 61 | } 62 | }) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /internal/storage/schema/builder_test.go: -------------------------------------------------------------------------------- 1 | package schema 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestChannelsTableGeneration(t *testing.T) { 8 | channels := DefineChannelsTable() 9 | 10 | t.Run("MySQL DDL", func(t *testing.T) { 11 | sql := channels.BuildMySQL() 12 | t.Logf("MySQL DDL:\n%s", sql) 13 | 14 | // 验证关键字 15 | if !contains(sql, "INT PRIMARY KEY AUTO_INCREMENT") { 16 | t.Error("Missing AUTO_INCREMENT") 17 | } 18 | if !contains(sql, "VARCHAR(191)") { 19 | t.Error("Missing VARCHAR") 20 | } 21 | }) 22 | 23 | t.Run("SQLite DDL", func(t *testing.T) { 24 | sql := channels.BuildSQLite() 25 | t.Logf("SQLite DDL:\n%s", sql) 26 | 27 | // 验证类型转换 28 | if !contains(sql, "INTEGER PRIMARY KEY AUTOINCREMENT") { 29 | t.Error("Missing AUTOINCREMENT") 30 | } 31 | if !contains(sql, "TEXT") { 32 | t.Error("Missing TEXT type") 33 | } 34 | if contains(sql, "VARCHAR") { 35 | t.Error("VARCHAR not converted to TEXT") 36 | } 37 | }) 38 | 39 | t.Run("Indexes", func(t *testing.T) { 40 | mysqlIndexes := channels.GetIndexesMySQL() 41 | sqliteIndexes := channels.GetIndexesSQLite() 42 | 43 | if len(mysqlIndexes) != 4 { 44 | t.Errorf("Expected 4 MySQL indexes, got %d", len(mysqlIndexes)) 45 | } 46 | 47 | // 验证SQLite索引包含IF NOT EXISTS 48 | for _, idx := range sqliteIndexes { 49 | if !contains(idx.SQL, "IF NOT EXISTS") { 50 | t.Errorf("SQLite index missing IF NOT EXISTS: %s", idx.SQL) 51 | } 52 | } 53 | 54 | t.Logf("MySQL indexes: %d", len(mysqlIndexes)) 
55 | t.Logf("SQLite indexes: %d", len(sqliteIndexes)) 56 | }) 57 | } 58 | 59 | func contains(s, substr string) bool { 60 | return len(s) > 0 && len(substr) > 0 && stringContains(s, substr) 61 | } 62 | 63 | func stringContains(s, substr string) bool { 64 | for i := 0; i <= len(s)-len(substr); i++ { 65 | if s[i:i+len(substr)] == substr { 66 | return true 67 | } 68 | } 69 | return false 70 | } 71 | -------------------------------------------------------------------------------- /internal/util/time_additional_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | // TestCalculateCooldownDuration 测试冷却持续时间计算 9 | func TestCalculateCooldownDuration(t *testing.T) { 10 | now := time.Now() 11 | 12 | tests := []struct { 13 | name string 14 | until time.Time 15 | now time.Time 16 | expected int64 17 | }{ 18 | { 19 | name: "正常冷却(60秒)", 20 | until: now.Add(60 * time.Second), 21 | now: now, 22 | expected: 60000, // 60秒 = 60000毫秒 23 | }, 24 | { 25 | name: "已过期冷却", 26 | until: now.Add(-10 * time.Second), 27 | now: now, 28 | expected: 0, 29 | }, 30 | { 31 | name: "零时间", 32 | until: time.Time{}, 33 | now: now, 34 | expected: 0, 35 | }, 36 | { 37 | name: "1分钟冷却", 38 | until: now.Add(1 * time.Minute), 39 | now: now, 40 | expected: 60000, 41 | }, 42 | { 43 | name: "30分钟冷却", 44 | until: now.Add(30 * time.Minute), 45 | now: now, 46 | expected: 1800000, 47 | }, 48 | } 49 | 50 | for _, tt := range tests { 51 | t.Run(tt.name, func(t *testing.T) { 52 | result := CalculateCooldownDuration(tt.until, tt.now) 53 | 54 | // 允许小幅误差(±100毫秒) 55 | diff := result - tt.expected 56 | if diff < 0 { 57 | diff = -diff 58 | } 59 | if diff > 100 { 60 | t.Errorf("期望 %d 毫秒, 实际 %d 毫秒", tt.expected, result) 61 | } 62 | }) 63 | } 64 | } 65 | 66 | // TestCalculateCooldownDuration_Precision 测试精度 67 | func TestCalculateCooldownDuration_Precision(t *testing.T) { 68 | now := time.Now() 69 | 70 | // 测试毫秒级精度 71 | 
until := now.Add(1234 * time.Millisecond) 72 | result := CalculateCooldownDuration(until, now) 73 | 74 | expected := int64(1234) 75 | diff := result - expected 76 | if diff < 0 { 77 | diff = -diff 78 | } 79 | 80 | // 允许±1毫秒误差 81 | if diff > 1 { 82 | t.Errorf("精度测试失败: 期望 %d ms, 实际 %d ms", expected, result) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /internal/version/banner.go: -------------------------------------------------------------------------------- 1 | package version 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "golang.org/x/term" 8 | ) 9 | 10 | const banner = ` 11 | ██████╗ ██████╗ ██╗ ██████╗ █████╗ ██████╗ 12 | ██╔════╝ ██╔════╝ ██║ ██╔═══██╗ ██╔══██╗ ██╔══██╗ 13 | ██║ ██║ ██║ ██║ ██║ ███████║ ██║ ██║ 14 | ██║ ██║ ██║ ██║ ██║ ██╔══██║ ██║ ██║ 15 | ╚██████╗ ╚██████╗ ███████╗ ╚██████╔╝ ██║ ██║ ██████╔╝ 16 | ╚═════╝ ╚═════╝ ╚══════╝ ╚═════╝ ╚═╝ ╚═╝ ╚═════╝ 17 | ` 18 | 19 | const repoURL = "https://github.com/caidaoli/ccLoad" 20 | 21 | // ANSI 颜色码 22 | const ( 23 | colorReset = "\033[0m" 24 | colorCyan = "\033[36m" 25 | colorGreen = "\033[32m" 26 | colorYellow = "\033[33m" 27 | colorBlue = "\033[34m" 28 | ) 29 | 30 | // PrintBanner 打印启动 Banner 和版本信息到 stderr 31 | func PrintBanner() { 32 | // 检测是否为终端,非终端不输出颜色 33 | isTTY := term.IsTerminal(int(os.Stderr.Fd())) 34 | 35 | if isTTY { 36 | fmt.Fprintf(os.Stderr, "%s%s%s", colorCyan, banner, colorReset) 37 | fmt.Fprintf(os.Stderr, " %sAPI Load Balancer & Proxy%s\n\n", colorYellow, colorReset) 38 | fmt.Fprintf(os.Stderr, "%-14s %s%s%s\n", "Version:", colorGreen, Version, colorReset) 39 | fmt.Fprintf(os.Stderr, "%-14s %s%s%s\n", "Commit:", colorGreen, Commit, colorReset) 40 | fmt.Fprintf(os.Stderr, "%-14s %s%s%s\n", "Build Time:", colorGreen, BuildTime, colorReset) 41 | fmt.Fprintf(os.Stderr, "%-14s %s%s%s\n", "Built By:", colorGreen, BuiltBy, colorReset) 42 | fmt.Fprintf(os.Stderr, "%-14s %s%s%s\n\n", "Repo:", colorBlue, repoURL, colorReset) 43 | } else { 44 | 
fmt.Fprint(os.Stderr, banner) 45 | fmt.Fprintf(os.Stderr, " API Load Balancer & Proxy\n\n") 46 | fmt.Fprintf(os.Stderr, "%-14s %s\n", "Version:", Version) 47 | fmt.Fprintf(os.Stderr, "%-14s %s\n", "Commit:", Commit) 48 | fmt.Fprintf(os.Stderr, "%-14s %s\n", "Build Time:", BuildTime) 49 | fmt.Fprintf(os.Stderr, "%-14s %s\n", "Built By:", BuiltBy) 50 | fmt.Fprintf(os.Stderr, "%-14s %s\n\n", "Repo:", repoURL) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /.env.docker.example: -------------------------------------------------------------------------------- 1 | # ccLoad Docker 环境配置示例 2 | # 复制此文件为 .env 并根据需要修改配置 3 | 4 | # ======================================== 5 | # 核心配置(必需) 6 | # ======================================== 7 | 8 | # 管理后台密码(必需,未设置将导致程序退出) 9 | CCLOAD_PASS=your_secure_admin_password 10 | 11 | # API 访问令牌通过 Web 管理界面动态配置 12 | # 访问 http://localhost:8080/web/tokens.html 进行令牌管理 13 | 14 | # ======================================== 15 | # 数据库配置 16 | # ======================================== 17 | 18 | # 数据库文件路径(容器内路径,通常不需要修改) 19 | SQLITE_PATH=/app/data/ccload.db 20 | 21 | # SQLite Journal 模式(可选,默认: WAL) 22 | # 可选值: WAL | DELETE | TRUNCATE | PERSIST | MEMORY | OFF 23 | # - WAL(默认):Write-Ahead Logging,高性能,适合本地文件系统 24 | # - TRUNCATE:传统回滚日志,适合 Docker/K8s 环境或网络存储(NFS等) 25 | # - DELETE:与 TRUNCATE 类似,但删除日志文件而非截断 26 | # ⚠️ 容器环境建议:SQLITE_JOURNAL_MODE=TRUNCATE(避免WAL文件损坏风险) 27 | # SQLITE_JOURNAL_MODE=TRUNCATE 28 | 29 | # ======================================== 30 | # 网络配置 31 | # ======================================== 32 | 33 | # HTTP 服务端口(容器内端口,通常不需要修改) 34 | PORT=8080 35 | 36 | # TLS 证书验证(可选,默认: false) 37 | # 仅开发环境使用,生产环境严禁禁用 TLS 验证 38 | # CCLOAD_SKIP_TLS_VERIFY=true 39 | 40 | # ======================================== 41 | # 性能优化配置 42 | # ======================================== 43 | 44 | # 单个渠道内最大 Key 重试次数(可选,默认: 3) 45 | # CCLOAD_MAX_KEY_RETRIES=3 46 | 47 | # 最大并发请求数(可选,默认: 1000) 48 | # 限制同时处理的代理请求数量,防止goroutine爆炸 49 
| # CCLOAD_MAX_CONCURRENCY=1000 50 | 51 | # 请求体最大字节数(可选,默认: 2097152,即 2MB) 52 | # 限制单个API请求体的大小,防止大包打爆内存 53 | # CCLOAD_MAX_BODY_BYTES=2097152 54 | 55 | # 上游首字节超时(可选,单位: 秒,默认: 不设置) 56 | # 设置后,如果上游API在指定时间内未返回首字节,则请求超时 57 | # CCLOAD_UPSTREAM_FIRST_BYTE_TIMEOUT=30 58 | 59 | 60 | 61 | # ======================================== 62 | # Redis 同步配置(推荐) 63 | # ======================================== 64 | 65 | # Redis 连接 URL(推荐配置,用于渠道数据同步) 66 | # Docker Compose 环境下可使用服务名 67 | # REDIS_URL=redis://redis:6379 68 | 69 | # ======================================== 70 | # 运行模式配置 71 | # ======================================== 72 | 73 | # Gin 运行模式(release/debug) 74 | GIN_MODE=release 75 | -------------------------------------------------------------------------------- /internal/app/admin_cooldown.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "strconv" 7 | "time" 8 | 9 | "github.com/gin-gonic/gin" 10 | ) 11 | 12 | // ==================== Cooldown management ==================== 13 | // Split out of admin.go so each file keeps a single responsibility. 14 | 15 | // HandleSetChannelCooldown puts the channel given by the "id" path parameter into cooldown for req.DurationMs milliseconds from now. 16 | func (s *Server) HandleSetChannelCooldown(c *gin.Context) { 17 | id, err := ParseInt64Param(c, "id") 18 | if err != nil { 19 | RespondErrorMsg(c, http.StatusBadRequest, "invalid channel ID") 20 | return 21 | } 22 | 23 | var req CooldownRequest 24 | if err := c.ShouldBindJSON(&req); err != nil { 25 | RespondError(c, http.StatusBadRequest, err) 26 | return 27 | } 28 | 29 | until := time.Now().Add(time.Duration(req.DurationMs) * time.Millisecond) 30 | err = s.store.SetChannelCooldown(c.Request.Context(), id, until) 31 | if err != nil { 32 | RespondError(c, http.StatusInternalServerError, err) 33 | return 34 | } 35 | 36 | // NOTE(review): a truncated comment here read "precise count (manual channel cooldown"; nothing is counted or invalidated at this point, unlike HandleSetKeyCooldown - confirm whether that is intended. 37 | 38 | RespondJSON(c, http.StatusOK, gin.H{"message": fmt.Sprintf("渠道已冷却 %d 毫秒", req.DurationMs)}) 39 | } 40 | 41 | // HandleSetKeyCooldown puts key number keyIndex (0-based "keyIndex" path parameter) of channel "id" into cooldown for req.DurationMs milliseconds and invalidates that channel's cached API keys. 42 | func (s *Server) HandleSetKeyCooldown(c *gin.Context) 
{ 43 | id, err := ParseInt64Param(c, "id") 44 | if err != nil { 45 | RespondErrorMsg(c, http.StatusBadRequest, "invalid channel ID") 46 | return 47 | } 48 | 49 | keyIndexStr := c.Param("keyIndex") 50 | keyIndex, err := strconv.Atoi(keyIndexStr) 51 | if err != nil || keyIndex < 0 { 52 | RespondErrorMsg(c, http.StatusBadRequest, "invalid key index") 53 | return 54 | } 55 | 56 | var req CooldownRequest 57 | if err := c.ShouldBindJSON(&req); err != nil { 58 | RespondError(c, http.StatusBadRequest, err) 59 | return 60 | } 61 | 62 | until := time.Now().Add(time.Duration(req.DurationMs) * time.Millisecond) 63 | err = s.store.SetKeyCooldown(c.Request.Context(), id, keyIndex, until) 64 | if err != nil { 65 | RespondError(c, http.StatusInternalServerError, err) 66 | return 67 | } 68 | 69 | // Invalidate this channel's cached API keys so the frontend sees the cooldown immediately. 70 | s.InvalidateAPIKeysCache(id) 71 | 72 | RespondJSON(c, http.StatusOK, gin.H{"message": fmt.Sprintf("Key #%d 已冷却 %d 毫秒", keyIndex+1, req.DurationMs)}) 73 | } 74 | -------------------------------------------------------------------------------- /internal/util/time_bench_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | // BenchmarkCalculateBackoffDuration_AuthError benchmarks the first cooldown after a 401 authentication error. 9 | func BenchmarkCalculateBackoffDuration_AuthError(b *testing.B) { 10 | statusCode := 401 11 | now := time.Now() 12 | 13 | b.ReportAllocs() 14 | for b.Loop() { 15 | _ = CalculateBackoffDuration(0, time.Time{}, now, &statusCode) 16 | } 17 | } 18 | 19 | // BenchmarkCalculateBackoffDuration_OtherError benchmarks the first cooldown after a 500 server error. 20 | func BenchmarkCalculateBackoffDuration_OtherError(b *testing.B) { 21 | statusCode := 500 22 | now := time.Now() 23 | 24 | b.ReportAllocs() 25 | for b.Loop() { 26 | _ = CalculateBackoffDuration(0, time.Time{}, now, &statusCode) 27 | } 28 | } 29 | 30 | // BenchmarkCalculateBackoffDuration_ExponentialBackoff benchmarks the exponential backoff step from a 5-minute previous cooldown. 31 | func 
BenchmarkCalculateBackoffDuration_ExponentialBackoff(b *testing.B) { 32 | statusCode := 401 33 | now := time.Now() 34 | prevMs := int64(5 * time.Minute / time.Millisecond) 35 | 36 | b.ReportAllocs() 37 | for b.Loop() { 38 | _ = CalculateBackoffDuration(prevMs, time.Unix(0, 0), now, &statusCode) 39 | } 40 | } 41 | 42 | // BenchmarkCalculateBackoffDuration_NilStatusCode benchmarks the no-status-code case (network error); it uses b.Loop like the other benchmarks in this file. 43 | func BenchmarkCalculateBackoffDuration_NilStatusCode(b *testing.B) { 44 | now := time.Now() 45 | 46 | b.ReportAllocs() 47 | for b.Loop() { 48 | _ = CalculateBackoffDuration(0, time.Time{}, now, nil) 49 | } 50 | } 51 | 52 | // BenchmarkCalculateBackoffDuration_MaxLimit benchmarks the case that hits the 30-minute cap. 53 | func BenchmarkCalculateBackoffDuration_MaxLimit(b *testing.B) { 54 | statusCode := 401 55 | now := time.Now() 56 | prevMs := int64(20 * time.Minute / time.Millisecond) // 20 min * 2 = 40 min (above the cap) 57 | 58 | b.ReportAllocs() 59 | for b.Loop() { 60 | _ = CalculateBackoffDuration(prevMs, time.Unix(0, 0), now, &statusCode) 61 | } 62 | } 63 | 64 | // BenchmarkCalculateCooldownDuration benchmarks CalculateCooldownDuration for a cooldown ending 5 minutes from now. 65 | func BenchmarkCalculateCooldownDuration(b *testing.B) { 66 | now := time.Now() 67 | until := now.Add(5 * time.Minute) 68 | 69 | b.ReportAllocs() 70 | for b.Loop() { 71 | _ = CalculateCooldownDuration(until, now) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /internal/app/proxy_gemini.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/gin-gonic/gin" 7 | ) 8 | 9 | // ============================================================================ 10 | // Model-list endpoints (Gemini and OpenAI) 11 | // ============================================================================ 12 | 13 | // handleListGeminiModels handles GET /v1beta/models and returns the deduplicated model list of all gemini channels. 14 | // Extracted from proxy.go (single responsibility). 15 | func (s *Server) handleListGeminiModels(c *gin.Context) { 16 | 
ctx := c.Request.Context() 17 | 18 | // Deduplicated model list across all gemini channels. 19 | models, err := s.getModelsByChannelType(ctx, "gemini") 20 | if err != nil { 21 | c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load models"}) 22 | return 23 | } 24 | 25 | // Build the Gemini API response format. 26 | type ModelInfo struct { 27 | Name string `json:"name"` 28 | DisplayName string `json:"displayName"` 29 | } 30 | 31 | modelList := make([]ModelInfo, 0, len(models)) 32 | for _, model := range models { 33 | modelList = append(modelList, ModelInfo{ 34 | Name: "models/" + model, 35 | DisplayName: formatModelDisplayName(model), 36 | }) 37 | } 38 | 39 | c.JSON(http.StatusOK, gin.H{ 40 | "models": modelList, 41 | }) 42 | } 43 | 44 | // handleListOpenAIModels handles GET /v1/models and returns the deduplicated model list of all openai channels. 45 | func (s *Server) handleListOpenAIModels(c *gin.Context) { 46 | ctx := c.Request.Context() 47 | 48 | // Deduplicated model list across all openai channels. 49 | models, err := s.getModelsByChannelType(ctx, "openai") 50 | if err != nil { 51 | c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load models"}) 52 | return 53 | } 54 | 55 | // Build the OpenAI API response format. 56 | type ModelInfo struct { 57 | ID string `json:"id"` 58 | Object string `json:"object"` 59 | Created int64 `json:"created"` 60 | OwnedBy string `json:"owned_by"` 61 | } 62 | 63 | modelList := make([]ModelInfo, 0, len(models)) 64 | for _, model := range models { 65 | modelList = append(modelList, ModelInfo{ 66 | ID: model, 67 | Object: "model", 68 | Created: 0, 69 | OwnedBy: "system", 70 | }) 71 | } 72 | 73 | c.JSON(http.StatusOK, gin.H{ 74 | "object": "list", 75 | "data": modelList, 76 | }) 77 | } 78 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # ccLoad 环境配置示例文件 2 | # 复制此文件为 .env 并根据需要修改配置值 3 | 4 | # ======================================== 5 | # 核心配置(必需) 6 | # ======================================== 7 | 8 | # 
管理后台密码(必需,未设置将导致程序退出) 9 | CCLOAD_PASS=your_strong_password_here 10 | 11 | # API 访问令牌通过 Web 管理界面动态配置 12 | # 访问 http://localhost:8080/web/tokens.html 进行令牌管理 13 | 14 | # ======================================== 15 | # 数据库配置 16 | # ======================================== 17 | 18 | # SQLite 数据库路径(可选,默认: data/ccload.db) 19 | SQLITE_PATH=./data/ccload.db 20 | 21 | # SQLite Journal 模式(可选,默认: WAL) 22 | # 可选值: WAL | DELETE | TRUNCATE | PERSIST | MEMORY | OFF 23 | # - WAL(默认):Write-Ahead Logging,高性能,适合本地文件系统 24 | # - TRUNCATE:传统回滚日志,适合 Docker/K8s 环境或网络存储(NFS等) 25 | # - DELETE:与 TRUNCATE 类似,但删除日志文件而非截断 26 | # ⚠️ 容器环境建议:SQLITE_JOURNAL_MODE=TRUNCATE(避免WAL文件损坏风险) 27 | # SQLITE_JOURNAL_MODE=WAL 28 | 29 | # ======================================== 30 | # 网络配置 31 | # ======================================== 32 | 33 | # HTTP 服务端口(可选,默认: 8080) 34 | PORT=8080 35 | 36 | # ======================================== 37 | # 性能优化配置 38 | # ======================================== 39 | 40 | # 最大并发请求数(可选,默认: 1000) 41 | # 限制同时处理的代理请求数量,防止goroutine爆炸 42 | # CCLOAD_MAX_CONCURRENCY=1000 43 | 44 | # 请求体最大字节数(可选,默认: 2097152,即 2MB) 45 | # 限制单个API请求体的大小,防止大包打爆内存 46 | # CCLOAD_MAX_BODY_BYTES=2097152 47 | 48 | # ======================================== 49 | # Redis 同步配置(可选) 50 | # ======================================== 51 | 52 | # Redis 连接 URL(可选,用于渠道数据同步备份) 53 | # 格式: redis://localhost:6379 或 redis://user:password@localhost:6379/0 54 | # 启用内存数据库模式时强烈推荐配置,用于故障恢复 55 | # REDIS_URL=redis://localhost:6379 56 | 57 | # ======================================== 58 | # 运行模式配置 59 | # ======================================== 60 | 61 | # Gin 运行模式(可选,默认: release) 62 | # 生产环境建议设置为 release 63 | # GIN_MODE=release 64 | 65 | # ======================================== 66 | # 系统配置(已迁移到 Web 管理界面) 67 | # ======================================== 68 | # 以下配置项已迁移到数据库,通过 Web 界面管理,支持热重载: 69 | # - 日志保留天数 (log_retention_days) 70 | # - 单渠道最大Key重试次数 (max_key_retries) 71 | # - 上游首字节超时 (upstream_first_byte_timeout) 72 | # - 
88code免费套餐限制 (88code_free_only) 73 | # - TLS证书验证 (skip_tls_verify) 74 | # 75 | # 访问 http://localhost:8080/web/settings.html 进行配置管理 76 | -------------------------------------------------------------------------------- /internal/storage/sql/metrics_finalize.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "ccLoad/internal/model" 5 | "context" 6 | "fmt" 7 | "log" 8 | "time" 9 | ) 10 | 11 | // metricAggregationHelper accumulates per-bucket timing sums and sample counts used to compute averages. 12 | type metricAggregationHelper struct { 13 | totalFirstByteTime float64 14 | firstByteCount int 15 | totalDuration float64 16 | durationCount int 17 | } 18 | 19 | // finalizeMetricPoints renames each bucket's "ch_<id>" channel keys to channel names (unresolvable keys go to the unknown-channel entry), 20 | // fills in average first-byte time and duration from helperMap, and returns one point per bucket in 21 | // [since.Truncate(bucket), until.Truncate(bucket)+bucket), using empty points for buckets absent from mapp. 22 | // Entries of mapp are modified in place. A non-positive bucket yields an empty, non-nil slice. 23 | // 24 | // NOTE(review): keys that resolve to the same name (several unknown channels, or channels sharing a name) 25 | // overwrite each other instead of being merged - confirm whether their ChannelMetric values should be summed. 26 | func (s *SQLStore) finalizeMetricPoints(ctx context.Context, mapp map[int64]*model.MetricPoint, helperMap map[int64]*metricAggregationHelper, channelIDsToFetch map[int64]bool, since, until time.Time, bucket time.Duration) []model.MetricPoint { 27 | channelNames, err := s.fetchChannelNamesBatch(ctx, channelIDsToFetch) 28 | if err != nil { 29 | log.Printf("[WARN] 批量查询渠道名称失败: %v", err) 30 | channelNames = make(map[int64]string) 31 | } 32 | 33 | for bucketTs, mp := range mapp { 34 | newChannels := make(map[string]model.ChannelMetric, len(mp.Channels)) 35 | for key, metric := range mp.Channels { 36 | if key == "未知渠道" { 37 | newChannels[key] = metric 38 | continue 39 | } 40 | // A key that does not parse as "ch_<id>" is reported as unknown instead of silently looking up channel ID 0. 41 | var channelID int64 42 | _, scanErr := fmt.Sscanf(key, "ch_%d", &channelID) 43 | if name, ok := channelNames[channelID]; ok && scanErr == nil { 44 | newChannels[name] = metric 45 | } else { 46 | newChannels["未知渠道"] = metric 47 | } 48 | } 49 | mp.Channels = newChannels 50 | 51 | if helper, ok := helperMap[bucketTs]; ok { 52 | if helper.firstByteCount > 0 { 53 | avgFBT := helper.totalFirstByteTime / float64(helper.firstByteCount) 54 | mp.AvgFirstByteTimeSeconds = &avgFBT 55 | mp.FirstByteSampleCount = helper.firstByteCount 56 | } 57 | if helper.durationCount > 0 { 58 | avgDur := helper.totalDuration / float64(helper.durationCount) 59 | 
mp.AvgDurationSeconds = &avgDur 60 | mp.DurationSampleCount = helper.durationCount 61 | } 62 | } 63 | } 64 | 65 | // Always non-nil, even when no buckets are produced. 66 | out := []model.MetricPoint{} 67 | if bucket <= 0 { 68 | // t.Add(bucket) below would never advance t toward endTime, so stop instead of looping forever. 69 | return out 70 | } 71 | endTime := until.Truncate(bucket).Add(bucket) 72 | startTime := since.Truncate(bucket) 73 | 74 | // Emit every bucket in range so the series has no gaps. 75 | for t := startTime; t.Before(endTime); t = t.Add(bucket) { 76 | ts := t.Unix() 77 | if mp, ok := mapp[ts]; ok { 78 | out = append(out, *mp) 79 | } else { 80 | out = append(out, model.MetricPoint{ 81 | Ts: t, 82 | Channels: make(map[string]model.ChannelMetric), 83 | }) 84 | } 85 | } 86 | 87 | return out 88 | } 89 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module ccLoad 2 | 3 | go 1.25.0 4 | 5 | require ( 6 | github.com/bytedance/sonic v1.14.1 7 | github.com/gin-gonic/gin v1.10.1 8 | github.com/redis/go-redis/v9 v9.7.0 9 | modernc.org/sqlite v1.38.2 10 | ) 11 | 12 | require golang.org/x/term v0.38.0 13 | 14 | require ( 15 | filippo.io/edwards25519 v1.1.0 // indirect 16 | github.com/go-sql-driver/mysql v1.9.3 17 | ) 18 | 19 | require ( 20 | github.com/bytedance/gopkg v0.1.3 // indirect 21 | github.com/bytedance/sonic/loader v0.3.0 // indirect 22 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 23 | github.com/cloudwego/base64x v0.1.6 // indirect 24 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 25 | github.com/gabriel-vasile/mimetype v1.4.10 // indirect 26 | github.com/gin-contrib/sse v1.1.0 // indirect 27 | github.com/go-playground/locales v0.14.1 // indirect 28 | github.com/go-playground/universal-translator v0.18.1 // indirect 29 | github.com/go-playground/validator/v10 v10.27.0 // indirect 30 | github.com/goccy/go-json v0.10.5 // indirect 31 | github.com/google/uuid v1.6.0 // indirect 32 | github.com/json-iterator/go v1.1.12 // indirect 33 | github.com/klauspost/cpuid/v2 v2.3.0 // indirect 34 | github.com/leodido/go-urn v1.4.0 // indirect 35 | 
github.com/mattn/go-isatty v0.0.20 // indirect 36 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 37 | github.com/modern-go/reflect2 v1.0.2 // indirect 38 | github.com/ncruces/go-strftime v0.1.9 // indirect 39 | github.com/pelletier/go-toml/v2 v2.2.4 // indirect 40 | github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect 41 | github.com/twitchyliquid64/golang-asm v0.15.1 // indirect 42 | github.com/ugorji/go/codec v1.3.0 // indirect 43 | golang.org/x/arch v0.21.0 // indirect 44 | golang.org/x/crypto v0.46.0 45 | golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect 46 | golang.org/x/net v0.47.0 // indirect 47 | golang.org/x/text v0.32.0 // indirect 48 | google.golang.org/protobuf v1.36.8 // indirect 49 | gopkg.in/yaml.v3 v3.0.1 // indirect 50 | modernc.org/libc v1.66.3 // indirect 51 | modernc.org/mathutil v1.7.1 // indirect 52 | modernc.org/memory v1.11.0 // indirect 53 | ) 54 | 55 | require ( 56 | github.com/dustin/go-humanize v1.0.1 // indirect 57 | github.com/joho/godotenv v1.5.1 58 | golang.org/x/sys v0.39.0 // indirect 59 | ) 60 | -------------------------------------------------------------------------------- /CLAUDE.md: -------------------------------------------------------------------------------- 1 | # CLAUDE.md 2 | 3 | ## 构建与测试 4 | 5 | ```bash 6 | # 构建(必须 -tags go_json,注入版本号用于静态资源缓存) 7 | go build -tags go_json -ldflags "\ 8 | -X ccLoad/internal/version.Version=$(git describe --tags --always) \ 9 | -X ccLoad/internal/version.Commit=$(git rev-parse --short HEAD) \ 10 | -X 'ccLoad/internal/version.BuildTime=$(date '+%Y-%m-%d %H:%M:%S %z')' \ 11 | -X ccLoad/internal/version.BuiltBy=$(whoami)" -o ccload . 12 | 13 | # 测试(必须 -tags go_json) 14 | go test -tags go_json ./internal/... -v 15 | go test -tags go_json -race ./internal/... # 竞态检测 16 | 17 | # 开发运行(版本号为dev) 18 | export CCLOAD_PASS=test123 # 必填 19 | go run -tags go_json . 
20 | ``` 21 | 22 | ## 核心架构 23 | 24 | ``` 25 | internal/ 26 | ├── app/ # HTTP层+业务逻辑 (proxy_*.go, admin_*.go, selector.go, key_selector.go) 27 | ├── cooldown/ # 冷却决策引擎 (manager.go) 28 | ├── storage/sql/ # 数据持久层 (SQLite/MySQL统一实现) 29 | ├── validator/ # 渠道验证器 30 | └── util/ # 工具库 (classifier.go错误分类, models_fetcher.go) 31 | ``` 32 | 33 | **故障切换策略**: 34 | - Key级错误(401/403/429) → 重试同渠道其他Key 35 | - 渠道级错误(5xx/520/524) → 切换到其他渠道 36 | - 客户端错误(404/405) → 不重试,直接返回 37 | - 指数退避: 2min → 4min → 8min → 16min → 30min(上限) 38 | 39 | **关键入口**: 40 | - `cooldown.Manager.HandleError()` - 冷却决策引擎 41 | - `util.ClassifyHTTPStatus()` - 错误分类 42 | - `app.KeySelector.SelectAvailableKey()` - Key负载均衡 43 | 44 | ## 开发指南 45 | 46 | ### Serena MCP 工具规范 47 | 48 | **代码浏览**: 49 | - 优先用符号化工具: `get_symbols_overview` → `find_symbol` 50 | - **禁止**直接读取整文件,先了解结构 51 | - 查找引用: `find_referencing_symbols` 52 | 53 | **代码编辑**: 54 | - 替换符号: `replace_symbol_body` 55 | - 插入代码: `insert_after_symbol` / `insert_before_symbol` 56 | - 小改动用 `Edit` 工具 57 | 58 | ### Playwright MCP 工具策略 59 | 60 | - 截图**必须** JPEG: `type: "jpeg"` 61 | - 优先 `browser_snapshot`(文本),视觉验证才截图 62 | - **避免** `fullPage: true` 63 | 64 | ### 添加 Admin API 65 | 1. `admin_types.go` - 定义类型 66 | 2. `admin_<feature>.go` - 实现Handler 67 | 3. 
`server.go:SetupRoutes()` - 注册路由 68 | 69 | ### 数据库操作 70 | - Schema更新: `storage/migrate.go` 启动自动执行 71 | - 事务: `(*SQLStore).WithTransaction(ctx, func(tx) error)` 72 | - 缓存失效: `InvalidateChannelListCache()` / `InvalidateAPIKeysCache()` 73 | 74 | ## 代码规范 75 | 76 | - **必须** `-tags go_json` 构建和测试 77 | - **必须** `any` 替代 `interface{}` 78 | - **禁止** 过度工程,YAGNI原则 79 | - **Fail-Fast**: 配置错误直接 `log.Fatal()` 退出 80 | - **Context**: `defer cancel()` 必须无条件调用,用 `context.AfterFunc` 监听取消 81 | -------------------------------------------------------------------------------- /.github/workflows/docker.yml: -------------------------------------------------------------------------------- 1 | name: Build and Push Docker Image 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | workflow_dispatch: 8 | inputs: 9 | tag: 10 | description: '手动指定镜像标签' 11 | required: false 12 | default: 'manual' 13 | 14 | env: 15 | REGISTRY: ghcr.io 16 | IMAGE_NAME: ${{ github.repository }} 17 | 18 | jobs: 19 | build-and-push: 20 | runs-on: ubuntu-latest 21 | permissions: 22 | contents: read 23 | packages: write 24 | attestations: write 25 | id-token: write 26 | 27 | steps: 28 | - name: Checkout repository 29 | uses: actions/checkout@v4 30 | 31 | - name: Set up Docker Buildx 32 | uses: docker/setup-buildx-action@v3 33 | 34 | - name: Log in to Container Registry 35 | uses: docker/login-action@v3 36 | with: 37 | registry: ${{ env.REGISTRY }} 38 | username: ${{ github.actor }} 39 | password: ${{ secrets.GITHUB_TOKEN }} 40 | 41 | - name: Extract metadata 42 | id: meta 43 | uses: docker/metadata-action@v5 44 | with: 45 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 46 | tags: | 47 | type=semver,pattern={{version}} 48 | type=semver,pattern={{major}}.{{minor}} 49 | type=semver,pattern={{major}} 50 | type=raw,value=latest,enable=${{ github.event_name == 'push' && contains(github.ref, 'refs/tags/') }} 51 | type=raw,value=${{ github.event.inputs.tag }},enable=${{ github.event_name == 'workflow_dispatch' && 
github.event.inputs.tag != '' }} 52 | 53 | - name: Build and push Docker image 54 | id: push 55 | uses: docker/build-push-action@v6 56 | with: 57 | context: . 58 | platforms: linux/amd64,linux/arm64 59 | push: true 60 | tags: ${{ steps.meta.outputs.tags }} 61 | labels: ${{ steps.meta.outputs.labels }} 62 | cache-from: type=gha 63 | cache-to: type=gha,mode=max 64 | build-args: | 65 | VERSION=${{ github.ref_name }} 66 | COMMIT=${{ github.sha }} 67 | 68 | - name: Generate artifact attestation 69 | uses: actions/attest-build-provenance@v1 70 | with: 71 | subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}} 72 | subject-digest: ${{ steps.push.outputs.digest }} 73 | push-to-registry: true -------------------------------------------------------------------------------- /internal/app/request_context.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "context" 5 | "sync/atomic" 6 | "time" 7 | ) 8 | 9 | // requestContext bundles the context and timeout control of a single request. 10 | // Extracted from forwardOnceAsync (single responsibility). 11 | // It also carries the optional first-byte timeout. 12 | type requestContext struct { 13 | ctx context.Context 14 | cancel context.CancelFunc // always non-nil, so callers never need a nil check 15 | startTime time.Time 16 | isStreaming bool 17 | firstByteTimer *time.Timer 18 | firstByteTimedOut atomic.Bool 19 | } 20 | 21 | // newRequestContext creates the request context and its timeout control. 22 | // Timeout policy: 23 | // - streaming requests: firstByteTimeout bounds the wait for the first byte; nothing here limits the stream afterwards 24 | // - non-streaming requests: nonStreamTimeout bounds the whole request; on expiry the request context is canceled 25 | // The returned cancel is always non-nil (idiomatic since Go 1.21), so callers need no nil check. 26 | func (s *Server) newRequestContext(parentCtx context.Context, requestPath string, body []byte) *requestContext { 27 | isStreaming := isStreamingRequest(requestPath, body) 28 | 29 | // Always wrap in WithCancel so the request can be canceled even when no timeout is configured. 30 | ctx, cancel := context.WithCancel(parentCtx) 31 | 32 | // Non-streaming: layer the overall timeout on top of the base cancel. 33 | if !isStreaming && s.nonStreamTimeout > 0 { 34 | var timeoutCancel context.CancelFunc 35 | ctx, timeoutCancel = 
context.WithTimeout(ctx, s.nonStreamTimeout) 36 | // Chain the cancels: one call releases both the timeout context and the base cancel context. 37 | originalCancel := cancel 38 | cancel = func() { 39 | timeoutCancel() 40 | originalCancel() 41 | } 42 | } 43 | 44 | reqCtx := &requestContext{ 45 | ctx: ctx, 46 | cancel: cancel, // always non-nil 47 | startTime: time.Now(), 48 | isStreaming: isStreaming, 49 | } 50 | 51 | // Streaming only: when the first-byte timer fires it records the timeout and cancels the request. 52 | if isStreaming && s.firstByteTimeout > 0 { 53 | reqCtx.firstByteTimer = time.AfterFunc(s.firstByteTimeout, func() { 54 | reqCtx.firstByteTimedOut.Store(true) 55 | cancel() // never nil here 56 | }) 57 | } 58 | 59 | return reqCtx 60 | } 61 | 62 | // stopFirstByteTimer stops the first-byte timer if one was started. 63 | func (rc *requestContext) stopFirstByteTimer() { 64 | if rc.firstByteTimer != nil { 65 | rc.firstByteTimer.Stop() 66 | } 67 | } 68 | 69 | // firstByteTimeoutTriggered reports whether the first-byte timer has fired. 70 | func (rc *requestContext) firstByteTimeoutTriggered() bool { 71 | return rc.firstByteTimedOut.Load() 72 | } 73 | 74 | // Duration returns the time elapsed since the request started, in seconds. 75 | func (rc *requestContext) Duration() float64 { 76 | return time.Since(rc.startTime).Seconds() 77 | } 78 | 79 | // cleanup releases the request's resources (first-byte timer and context). 80 | // Intended usage: defer reqCtx.cleanup() 81 | func (rc *requestContext) cleanup() { 82 | rc.stopFirstByteTimer() // stop the first-byte timer 83 | rc.cancel() // cancel the context (always non-nil) 84 | } 85 | -------------------------------------------------------------------------------- /internal/util/channel_types_bench_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import "testing" 4 | 5 | // BenchmarkDetectChannelTypeFromPath measures channel type detection from request paths. 6 | func BenchmarkDetectChannelTypeFromPath(b *testing.B) { 7 | testCases := []struct { 8 | name string 9 | path string 10 | }{ 11 | {"Anthropic", "/v1/messages"}, 12 | {"Codex", "/v1/responses"}, 13 | {"OpenAI_Chat", "/v1/chat/completions"}, 14 | {"OpenAI_Embeddings", "/v1/embeddings"}, 15 | {"Gemini", "/v1beta/models/gemini-pro:streamGenerateContent"}, 16 | {"Unknown", "/unknown/path"}, 17 | } 18 | 19 | for _, tc := 
range testCases { 20 | b.Run(tc.name, func(b *testing.B) { 21 | b.ReportAllocs() 22 | for b.Loop() { 23 | _ = DetectChannelTypeFromPath(tc.path) 24 | } 25 | }) 26 | } 27 | } 28 | 29 | // BenchmarkDetectChannelTypeFromPath_Parallel measures detection under parallel load. 30 | func BenchmarkDetectChannelTypeFromPath_Parallel(b *testing.B) { 31 | path := "/v1/messages" 32 | b.RunParallel(func(pb *testing.PB) { 33 | for pb.Next() { 34 | _ = DetectChannelTypeFromPath(path) 35 | } 36 | }) 37 | } 38 | 39 | // BenchmarkNormalizeChannelType measures channel type normalization. 40 | func BenchmarkNormalizeChannelType(b *testing.B) { 41 | testCases := []struct { 42 | name string 43 | value string 44 | }{ 45 | {"Lowercase", "anthropic"}, 46 | {"Uppercase", "ANTHROPIC"}, 47 | {"MixedCase", "AnThRoPiC"}, 48 | {"WithSpaces", " anthropic "}, 49 | {"Empty", ""}, 50 | } 51 | 52 | for _, tc := range testCases { 53 | b.Run(tc.name, func(b *testing.B) { 54 | b.ReportAllocs() 55 | for b.Loop() { 56 | _ = NormalizeChannelType(tc.value) 57 | } 58 | }) 59 | } 60 | } 61 | 62 | // BenchmarkMatchPath measures path matching. 63 | func BenchmarkMatchPath(b *testing.B) { 64 | testCases := []struct { 65 | name string 66 | path string 67 | patterns []string 68 | matchType string 69 | }{ 70 | {"Prefix_Match", "/v1/messages", []string{"/v1/messages"}, MatchTypePrefix}, 71 | {"Prefix_NoMatch", "/v2/messages", []string{"/v1/messages"}, MatchTypePrefix}, 72 | {"Contains_Match", "/v1beta/models/gemini", []string{"/v1beta/"}, MatchTypeContains}, 73 | {"Contains_NoMatch", "/v1/models/gemini", []string{"/v1beta/"}, MatchTypeContains}, 74 | {"MultiPattern", "/v1/embeddings", []string{"/v1/chat", "/v1/completions", "/v1/embeddings"}, MatchTypePrefix}, 75 | } 76 | 77 | for _, tc := range testCases { 78 | b.Run(tc.name, func(b *testing.B) { 79 | b.ReportAllocs() 80 | for b.Loop() { 81 | _ = matchPath(tc.path, tc.patterns, tc.matchType) 82 | } 83 | }) 84 | } 85 | } 86 | 
-------------------------------------------------------------------------------- /internal/storage/sql/admin_sessions.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | "time" 8 | 9 | "ccLoad/internal/model" 10 | ) 11 | 12 | // CreateAdminSession stores (or replaces) an admin session that expires at expiresAt. 13 | // Security: stores the SHA-256 hash of the token rather than the plaintext token (2025-12). 14 | func (s *SQLStore) CreateAdminSession(ctx context.Context, token string, expiresAt time.Time) error { 15 | tokenHash := model.HashToken(token) 16 | now := timeToUnix(time.Now()) 17 | _, err := s.db.ExecContext(ctx, ` 18 | REPLACE INTO admin_sessions (token, expires_at, created_at) 19 | VALUES (?, ?, ?) 20 | `, tokenHash, timeToUnix(expiresAt), now) 21 | return err 22 | } 23 | 24 | // GetAdminSession returns the session's expiry; exists is false (with a nil error) when no session matches. 25 | // Security: looks the session up by token hash (2025-12). 26 | func (s *SQLStore) GetAdminSession(ctx context.Context, token string) (expiresAt time.Time, exists bool, err error) { 27 | tokenHash := model.HashToken(token) 28 | var expiresUnix int64 29 | err = s.db.QueryRowContext(ctx, ` 30 | SELECT expires_at FROM admin_sessions WHERE token = ?
31 | `, tokenHash).Scan(&expiresUnix) 32 | 33 | if err != nil { 34 | if errors.Is(err, sql.ErrNoRows) { 35 | return time.Time{}, false, nil 36 | } 37 | return time.Time{}, false, err 38 | } 39 | 40 | return unixToTime(expiresUnix), true, nil 41 | } 42 | 43 | // DeleteAdminSession deletes an admin session. 44 | // Security: deletes by token hash (2025-12). 45 | func (s *SQLStore) DeleteAdminSession(ctx context.Context, token string) error { 46 | tokenHash := model.HashToken(token) 47 | _, err := s.db.ExecContext(ctx, `DELETE FROM admin_sessions WHERE token = ?`, tokenHash) 48 | return err 49 | } 50 | 51 | // CleanExpiredSessions deletes sessions with expires_at <= now, exactly the complement of what LoadAllSessions loads (expires_at > now). 52 | func (s *SQLStore) CleanExpiredSessions(ctx context.Context) error { 53 | now := timeToUnix(time.Now()) 54 | _, err := s.db.ExecContext(ctx, `DELETE FROM admin_sessions WHERE expires_at <= ?`, now) 55 | return err 56 | } 57 | 58 | // LoadAllSessions loads all unexpired sessions (called at startup). 59 | // Security: returns a tokenHash -> expiry map (2025-12). 60 | func (s *SQLStore) LoadAllSessions(ctx context.Context) (map[string]time.Time, error) { 61 | now := timeToUnix(time.Now()) 62 | rows, err := s.db.QueryContext(ctx, ` 63 | SELECT token, expires_at FROM admin_sessions WHERE expires_at > ?
64 | `, now) 65 | if err != nil { 66 | return nil, err 67 | } 68 | defer rows.Close() 69 | 70 | sessions := make(map[string]time.Time) 71 | for rows.Next() { 72 | var tokenHash string 73 | var expiresUnix int64 74 | if err := rows.Scan(&tokenHash, &expiresUnix); err != nil { 75 | return nil, err 76 | } 77 | sessions[tokenHash] = unixToTime(expiresUnix) 78 | } 79 | 80 | return sessions, rows.Err() 81 | } 82 | -------------------------------------------------------------------------------- /internal/model/config.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | // Config is a channel configuration. 8 | type Config struct { 9 | ID int64 `json:"id"` 10 | Name string `json:"name"` 11 | ChannelType string `json:"channel_type"` // channel type: "anthropic" | "codex" | "openai" | "gemini"; defaults to anthropic 12 | URL string `json:"url"` 13 | Priority int `json:"priority"` 14 | Models []string `json:"models"` 15 | ModelRedirects map[string]string `json:"model_redirects,omitempty"` // model redirect map: requested model -> model actually forwarded 16 | Enabled bool `json:"enabled"` 17 | 18 | // Channel-level cooldown (migrated from the cooldowns table) 19 | CooldownUntil int64 `json:"cooldown_until"` // Unix timestamp in seconds; 0 means no cooldown 20 | CooldownDurationMs int64 `json:"cooldown_duration_ms"` // cooldown duration in milliseconds 21 | 22 | CreatedAt JSONTime `json:"created_at"` // JSONTime encodes as Unix seconds (0 for the zero time) 23 | UpdatedAt JSONTime `json:"updated_at"` // JSONTime encodes as Unix seconds (0 for the zero time) 24 | 25 | // Cached key count, avoiding N+1 queries in cooldown checks 26 | KeyCount int `json:"key_count"` // number of API keys (computed with a JOIN at query time) 27 | } 28 | 29 | // GetChannelType returns the channel type, defaulting to "anthropic" (Claude API) when unset. 30 | func (c *Config) GetChannelType() string { 31 | if c.ChannelType == "" { 32 | return "anthropic" 33 | } 34 | return c.ChannelType 35 | } 36 | 37 | // IsCoolingDown reports whether the channel cooldown ends after now (second precision). 38 | func (c *Config) IsCoolingDown(now time.Time) bool { 39 | return c.CooldownUntil > now.Unix() 40 | } 41 | 42 | // KeyStrategy constants 43 | const ( 44 | KeyStrategySequential = "sequential" // sequential: try keys in index order 45 | KeyStrategyRoundRobin = "round_robin" // 
round robin: spread requests evenly across keys 46 | ) 47 | 48 | // IsValidKeyStrategy reports whether s is a valid KeyStrategy (the empty string is accepted). 49 | func IsValidKeyStrategy(s string) bool { 50 | return s == "" || s == KeyStrategySequential || s == KeyStrategyRoundRobin 51 | } 52 | 53 | // APIKey is one API key of a channel, with its selection strategy and key-level cooldown. 54 | type APIKey struct { 55 | ID int64 `json:"id"` 56 | ChannelID int64 `json:"channel_id"` 57 | KeyIndex int `json:"key_index"` 58 | APIKey string `json:"api_key"` 59 | 60 | KeyStrategy string `json:"key_strategy"` // "sequential" | "round_robin" 61 | 62 | // Key-level cooldown (migrated from the key_cooldowns table) 63 | CooldownUntil int64 `json:"cooldown_until"` 64 | CooldownDurationMs int64 `json:"cooldown_duration_ms"` 65 | 66 | CreatedAt JSONTime `json:"created_at"` 67 | UpdatedAt JSONTime `json:"updated_at"` 68 | } 69 | 70 | // IsCoolingDown reports whether the key cooldown ends after now (second precision). 71 | func (k *APIKey) IsCoolingDown(now time.Time) bool { 72 | return k.CooldownUntil > now.Unix() 73 | } 74 | 75 | // ChannelWithKeys is used for full Redis synchronization. 76 | // Purpose: channels restored from Redis must not lose their API keys. 77 | type ChannelWithKeys struct { 78 | Config *Config `json:"config"` 79 | APIKeys []APIKey `json:"api_keys"` // values, not pointers, to avoid extra allocations 80 | } 81 | -------------------------------------------------------------------------------- /internal/model/log.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "strconv" 5 | "time" 6 | ) 7 | 8 | // JSONTime is a time type that JSON-encodes as a Unix timestamp in seconds. 9 | // Design: matches the database storage format to keep conversions simple (KISS). 10 | type JSONTime struct { 11 | time.Time 12 | } 13 | 14 | // MarshalJSON encodes the time as Unix seconds, or 0 for the zero time. 15 | func (jt JSONTime) MarshalJSON() ([]byte, error) { 16 | if jt.Time.IsZero() { 17 | return []byte("0"), nil 18 | } 19 | return []byte(strconv.FormatInt(jt.Time.Unix(), 10)), nil 20 | } 21 | 22 | // UnmarshalJSON decodes Unix seconds; null and 0 decode to the zero time. 23 | func (jt *JSONTime) UnmarshalJSON(data []byte) error { 24 | if string(data) == "null" || string(data) == "0" { 25 | jt.Time = time.Time{} 26 | return nil 27 | } 28 | ts, err := strconv.ParseInt(string(data), 10, 64) 29 | if err != nil { 30 | return err 31 | } 32 | jt.Time = time.Unix(ts, 0) 33 | return nil 
34 | } 35 | 36 | // LogEntry is a single request log entry. 37 | type LogEntry struct { 38 | ID int64 `json:"id"` 39 | Time JSONTime `json:"time"` 40 | Model string `json:"model"` 41 | ChannelID int64 `json:"channel_id"` 42 | ChannelName string `json:"channel_name,omitempty"` 43 | StatusCode int `json:"status_code"` 44 | Message string `json:"message"` 45 | Duration float64 `json:"duration"` // total duration (seconds) 46 | IsStreaming bool `json:"is_streaming"` // whether the request was streaming 47 | FirstByteTime float64 `json:"first_byte_time"` // time to first byte (seconds) 48 | APIKeyUsed string `json:"api_key_used"` // API key used (masked as abcd...klmn when queried) 49 | AuthTokenID int64 `json:"auth_token_id"` // ID of the API token the client used (added 2025-12; 0 means no token) 50 | ClientIP string `json:"client_ip"` // client IP address (added 2025-12) 51 | 52 | // Token statistics (added 2025-11; from the Claude API usage field) 53 | InputTokens int `json:"input_tokens"` 54 | OutputTokens int `json:"output_tokens"` 55 | CacheReadInputTokens int `json:"cache_read_input_tokens"` 56 | CacheCreationInputTokens int `json:"cache_creation_input_tokens"` // sum of 5m + 1h cache writes (compatibility field) 57 | Cache5mInputTokens int `json:"cache_5m_input_tokens"` // 5-minute cache write tokens (added 2025-12) 58 | Cache1hInputTokens int `json:"cache_1h_input_tokens"` // 1-hour cache write tokens (added 2025-12) 59 | Cost float64 `json:"cost"` // request cost (USD) 60 | } 61 | 62 | // LogFilter holds log query filter conditions. 63 | type LogFilter struct { 64 | ChannelID *int64 65 | ChannelName string 66 | ChannelNameLike string 67 | Model string 68 | ModelLike string 69 | StatusCode *int 70 | ChannelType string // channel type filter (anthropic/openai/gemini/codex) 71 | AuthTokenID *int64 // API token ID filter 72 | } 73 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1.4 2 | # ccLoad Docker image build file (the syntax directive must stay on line 1: BuildKit ignores parser directives that follow any comment) 3 | # Multi-platform build: cross-compiles with tonistiigi/xx + Clang/LLVM 4 | 5 | # Build stage - runs on BUILDPLATFORM (the native architecture) 6 | FROM --platform=$BUILDPLATFORM golang:alpine AS builder 7 | 8 | # Version arguments (prefer --build-arg; otherwise fall back to git) 9 | 
ARG VERSION 10 | ARG COMMIT 11 | 12 | # Install the cross-compilation toolchain 13 | # tonistiigi/xx provides cross-architecture build helpers 14 | COPY --from=tonistiigi/xx:1.6.1 / / 15 | RUN apk add --no-cache git ca-certificates tzdata clang lld 16 | 17 | # Working directory 18 | WORKDIR /app 19 | 20 | # Configure the cross-compilation toolchain for the target platform 21 | ARG TARGETPLATFORM 22 | RUN xx-apk add musl-dev gcc 23 | 24 | # Go module proxy 25 | ENV GOPROXY=https://proxy.golang.org,direct 26 | 27 | # Copy the go mod files 28 | COPY go.mod go.sum ./ 29 | 30 | # Download dependencies natively (fast) into a cache mount on the golang image's module cache (/go/pkg/mod, GOPATH=/go) 31 | RUN --mount=type=cache,target=/go/pkg/mod \ 32 | go mod download 33 | 34 | # Copy the source code 35 | COPY . . 36 | 37 | # Cross-compile the binary (CGO enabled to support bytedance/sonic) 38 | # xx-go sets GOOS/GOARCH/CC and related variables automatically 39 | # If VERSION is empty it comes from git tags; if neither is available it defaults to "dev" 40 | ENV CGO_ENABLED=1 41 | RUN --mount=type=cache,target=/root/.cache/go-build \ 42 | --mount=type=cache,target=/go/pkg/mod \ 43 | BUILD_VERSION=${VERSION:-$(git describe --tags --always 2>/dev/null || echo "dev")} && \ 44 | BUILD_COMMIT=${COMMIT:-$(git rev-parse --short HEAD 2>/dev/null || echo "unknown")} && \ 45 | BUILD_COMMIT=$(echo $BUILD_COMMIT | cut -c1-7) && \ 46 | BUILD_TIME=$(date '+%Y-%m-%d %H:%M:%S %z') && \ 47 | xx-go build \ 48 | -tags go_json \ 49 | -ldflags="-s -w \ 50 | -X ccLoad/internal/version.Version=${BUILD_VERSION} \ 51 | -X ccLoad/internal/version.Commit=${BUILD_COMMIT} \ 52 | -X 'ccLoad/internal/version.BuildTime=${BUILD_TIME}' \ 53 | -X ccLoad/internal/version.BuiltBy=docker" \ 54 | -o ccload . && \ 55 | xx-verify ccload 56 | 57 | # Runtime stage (pinned version to avoid QEMU emulation compatibility issues) 58 | FROM alpine:3.20 59 | 60 | # Install runtime dependencies 61 | RUN apk --no-cache add ca-certificates tzdata 62 | 63 | # Create a non-root user 64 | RUN addgroup -g 1001 -S ccload && \ 65 | adduser -u 1001 -S ccload -G ccload 66 | 67 | # Working directory 68 | WORKDIR /app 69 | 70 | # Copy the binary from the build stage 71 | COPY --from=builder /app/ccload . 
72 | 73 | # Copy the web static files 74 | COPY --from=builder /app/web ./web 75 | 76 | # Create the data directory and give /app to the ccload user 77 | RUN mkdir -p /app/data && \ 78 | chown -R ccload:ccload /app 79 | 80 | # Switch to the non-root user 81 | USER ccload 82 | 83 | # Expose the port 84 | EXPOSE 8080 85 | 86 | # Environment variables 87 | ENV PORT=8080 88 | ENV SQLITE_PATH=/app/data/ccload.db 89 | ENV GIN_MODE=release 90 | 91 | # Health check (lightweight endpoint, <5ms response) 92 | HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ 93 | CMD wget --no-verbose --tries=1 --spider http://localhost:8080/health || exit 1 94 | 95 | # Start the application 96 | CMD ["./ccload"] 97 | -------------------------------------------------------------------------------- /internal/app/admin_response_contract_test.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "go/parser" 7 | "go/token" 8 | "os" 9 | "sort" 10 | "strings" 11 | "testing" 12 | ) 13 | // TestAdminHandlers_DoNotUseGinJSONDirectly parses every non-test admin_*.go file in the current (package) directory plus auth_service.go, and fails if any function calls a selector named JSON, IndentedJSON, SecureJSON, AsciiJSON, PureJSON, JSONP or AbortWithStatusJSON, except RequireAPIAuth in auth_service.go. 14 | func TestAdminHandlers_DoNotUseGinJSONDirectly(t *testing.T) { 15 | t.Helper() // NOTE(review): has no effect in a top-level test function 16 | 17 | // These calls bypass the unified APIResponse format (success/data/error/count). 18 | banned := map[string]bool{ 19 | "JSON": true, 20 | "IndentedJSON": true, 21 | "SecureJSON": true, 22 | "AsciiJSON": true, 23 | "PureJSON": true, 24 | "JSONP": true, 25 | "AbortWithStatusJSON": true, 26 | } 27 | 28 | var files []string 29 | entries, err := os.ReadDir(".") 30 | if err != nil { 31 | t.Fatalf("ReadDir: %v", err) 32 | } 33 | for _, e := range entries { 34 | name := e.Name() 35 | if e.IsDir() { 36 | continue 37 | } 38 | if !strings.HasSuffix(name, ".go") || strings.HasSuffix(name, "_test.go") { 39 | continue 40 | } 41 | if strings.HasPrefix(name, "admin_") { 42 | files = append(files, name) 43 | } 44 | } 45 | 46 | // RequireTokenAuth is part of the Admin API auth chain; RequireAPIAuth belongs to the proxy API (which is not required to use the APIResponse format). 47 | files = append(files, "auth_service.go") 48 | sort.Strings(files) 49 | 50 | var offenders []string 51 | for _, filename := range files { 52 | allowInFunc := map[string]bool{} 53 | if filename == 
"auth_service.go" { 54 | allowInFunc["RequireAPIAuth"] = true 55 | } 56 | 57 | fset := token.NewFileSet() 58 | f, err := parser.ParseFile(fset, filename, nil, 0) 59 | if err != nil { 60 | t.Fatalf("ParseFile %s: %v", filename, err) 61 | } 62 | 63 | for _, decl := range f.Decls { 64 | fn, ok := decl.(*ast.FuncDecl) 65 | if !ok || fn.Body == nil { 66 | continue 67 | } 68 | 69 | funcName := "" 70 | if fn.Name != nil { 71 | funcName = fn.Name.Name 72 | } 73 | 74 | ast.Inspect(fn.Body, func(n ast.Node) bool { 75 | call, ok := n.(*ast.CallExpr) 76 | if !ok { 77 | return true 78 | } 79 | sel, ok := call.Fun.(*ast.SelectorExpr) 80 | if !ok || sel.Sel == nil { 81 | return true 82 | } 83 | if !banned[sel.Sel.Name] { 84 | return true 85 | } 86 | if allowInFunc[funcName] { 87 | return true 88 | } 89 | pos := fset.Position(sel.Sel.Pos()) 90 | offenders = append(offenders, fmt.Sprintf("%s:%d:%d %s.%s()", filename, pos.Line, pos.Column, funcName, sel.Sel.Name)) 91 | return true 92 | }) 93 | } 94 | } 95 | 96 | if len(offenders) > 0 { 97 | t.Fatalf("发现绕过APIResponse的直接JSON输出(请改用 RespondJSON/RespondError*):\n- %s", strings.Join(offenders, "\n- ")) 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /internal/app/token_stats_shutdown_test.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | "testing" 7 | "time" 8 | 9 | "ccLoad/internal/model" 10 | "ccLoad/internal/storage" 11 | ) 12 | 13 | func TestUpdateTokenStatsDuringShutdown(t *testing.T) { 14 | store, err := storage.CreateSQLiteStore(":memory:", nil) 15 | if err != nil { 16 | t.Fatalf("CreateSQLiteStore failed: %v", err) 17 | } 18 | 19 | srv := NewServer(store) 20 | 21 | ctx := context.Background() 22 | tokenHash := strings.Repeat("a", 64) 23 | if err := store.CreateAuthToken(ctx, &model.AuthToken{ 24 | Token: tokenHash, 25 | Description: "test", 26 | IsActive: true, 27 | }); err != nil
{ 28 | t.Fatalf("CreateAuthToken failed: %v", err) 29 | } 30 | 31 | // 阻塞wg.Wait,避免Shutdown过快走到store.Close,从而与“在途请求结束后写入统计”的场景失真 32 | blockCh := make(chan struct{}) 33 | srv.wg.Add(1) 34 | go func() { 35 | defer srv.wg.Done() 36 | <-blockCh 37 | }() 38 | 39 | shutdownErrCh := make(chan error, 1) 40 | shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 41 | defer cancel() 42 | go func() { 43 | shutdownErrCh <- srv.Shutdown(shutdownCtx) 44 | }() 45 | defer func() { 46 | close(blockCh) 47 | <-shutdownErrCh 48 | }() 49 | 50 | // 等待Shutdown进入“shutting down”状态 51 | deadline := time.Now().Add(1 * time.Second) 52 | for !srv.isShuttingDown.Load() { 53 | if time.Now().After(deadline) { 54 | t.Fatal("server did not enter shutting down state in time") 55 | } 56 | time.Sleep(1 * time.Millisecond) 57 | } 58 | 59 | // 模拟:shutdown开始后,一个在途请求完成并尝试写入计费/用量统计 60 | srv.updateTokenStatsAsync(tokenHash, true, 1.23, false, &fwResult{ 61 | FirstByteTime: 0.2, 62 | InputTokens: 10, 63 | OutputTokens: 20, 64 | CacheReadInputTokens: 5, 65 | CacheCreationInputTokens: 3, 66 | }, "gpt-5.1-codex") 67 | 68 | got, err := store.GetAuthTokenByValue(ctx, tokenHash) 69 | if err != nil { 70 | t.Fatalf("GetAuthTokenByValue failed: %v", err) 71 | } 72 | if got.SuccessCount != 1 { 73 | t.Fatalf("SuccessCount = %d, want %d", got.SuccessCount, 1) 74 | } 75 | if got.PromptTokensTotal != 10 { 76 | t.Fatalf("PromptTokensTotal = %d, want %d", got.PromptTokensTotal, 10) 77 | } 78 | if got.CompletionTokensTotal != 20 { 79 | t.Fatalf("CompletionTokensTotal = %d, want %d", got.CompletionTokensTotal, 20) 80 | } 81 | if got.CacheReadTokensTotal != 5 { 82 | t.Fatalf("CacheReadTokensTotal = %d, want %d", got.CacheReadTokensTotal, 5) 83 | } 84 | if got.CacheCreationTokensTotal != 3 { 85 | t.Fatalf("CacheCreationTokensTotal = %d, want %d", got.CacheCreationTokensTotal, 3) 86 | } 87 | if got.TotalCostUSD <= 0 { 88 | t.Fatalf("TotalCostUSD = %f, want > 0", got.TotalCostUSD) 89 | } 90 | } 91 | 
-------------------------------------------------------------------------------- /web/assets/js/date-range-selector.js: -------------------------------------------------------------------------------- 1 | /** 2 | * 时间范围选择器 - 共享组件 3 | * 用于 logs/stats/trend 页面的统一时间范围选择 4 | * 5 | * 使用方式: 6 | * 1. 在HTML中引入: 7 | * 2. 调用 initDateRangeSelector(elementId, defaultRange, onChangeCallback) 8 | * 9 | * 后端API参数: range=today|yesterday|day_before_yesterday|this_week|last_week|this_month|last_month 10 | */ 11 | 12 | (function(window) { 13 | 'use strict'; 14 | 15 | // 时间范围预设 (key → 显示标签) 16 | // key与后端GetTimeRange()支持的range参数一致 17 | const DATE_RANGES = { 18 | 'today': { label: '本日' }, 19 | 'yesterday': { label: '昨日' }, 20 | 'day_before_yesterday': { label: '前日' }, 21 | 'this_week': { label: '本周' }, 22 | 'last_week': { label: '上周' }, 23 | 'this_month': { label: '本月' }, 24 | 'last_month': { label: '上月' } 25 | }; 26 | 27 | /** 28 | * 初始化时间范围选择器 29 | * @param {string} elementId - select元素的ID 30 | * @param {string} defaultRange - 默认选中的范围key (如'today') 31 | * @param {function} onChangeCallback - 值变化时的回调函数,接收range key参数 32 | */ 33 | window.initDateRangeSelector = function(elementId, defaultRange, onChangeCallback) { 34 | const selectEl = document.getElementById(elementId); 35 | if (!selectEl) { 36 | console.error(`时间范围选择器初始化失败: 未找到元素 #${elementId}`); 37 | return; 38 | } 39 | 40 | // 清空并重新生成选项 41 | selectEl.innerHTML = ''; 42 | Object.keys(DATE_RANGES).forEach(key => { 43 | const range = DATE_RANGES[key]; 44 | const option = document.createElement('option'); 45 | option.value = key; // 使用range key作为value 46 | option.textContent = range.label; 47 | selectEl.appendChild(option); 48 | }); 49 | 50 | // 设置默认值 51 | const validDefault = DATE_RANGES[defaultRange] ? 
defaultRange : 'today'; 52 | selectEl.value = validDefault; 53 | 54 | // 绑定change事件 55 | if (typeof onChangeCallback === 'function') { 56 | selectEl.addEventListener('change', function() { 57 | onChangeCallback(this.value); 58 | }); 59 | } 60 | }; 61 | 62 | /** 63 | * 获取范围的显示标签 64 | * @param {string} rangeKey - 范围key 65 | * @returns {string} 显示标签 66 | */ 67 | window.getRangeLabel = function(rangeKey) { 68 | return DATE_RANGES[rangeKey]?.label || '本日'; 69 | }; 70 | 71 | /** 72 | * 获取范围对应的大致小时数(用于metrics API的分桶计算) 73 | * @param {string} rangeKey - 范围key 74 | * @returns {number} 小时数 75 | */ 76 | window.getRangeHours = function(rangeKey) { 77 | const hoursMap = { 78 | 'today': 24, 79 | 'yesterday': 24, 80 | 'day_before_yesterday': 24, 81 | 'this_week': 168, 82 | 'last_week': 168, 83 | 'this_month': 720, 84 | 'last_month': 720 85 | }; 86 | return hoursMap[rangeKey] || 24; 87 | }; 88 | 89 | })(window); 90 | -------------------------------------------------------------------------------- /internal/util/time.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | // 冷却时间常量定义 8 | const ( 9 | // AuthErrorInitialCooldown 认证错误(401/403)的初始冷却时间 10 | // 设计目标:减少认证失败的无效重试,避免API配额浪费 11 | AuthErrorInitialCooldown = 5 * time.Minute 12 | 13 | // TimeoutErrorCooldown 超时错误的固定冷却时间 14 | // 设计目标:上游服务响应超时或完全无响应时,直接冷却避免资源浪费和级联故障 15 | // 适用场景:网络超时、上游服务无响应等(状态码598) 16 | TimeoutErrorCooldown = time.Minute 17 | 18 | // ServerErrorInitialCooldown 服务器错误(500/502/503/504)的初始冷却时间 19 | // 设计目标:指数退避策略,起始2分钟(2min → 4min → 8min → 16min → 30min上限) 20 | ServerErrorInitialCooldown = 2 * time.Minute 21 | 22 | // OtherErrorInitialCooldown 其他错误(429等)的初始冷却时间 23 | OtherErrorInitialCooldown = 10 * time.Second 24 | 25 | // MaxCooldownDuration 最大冷却时长(指数退避上限) 26 | MaxCooldownDuration = 30 * time.Minute 27 | 28 | // MinCooldownDuration 最小冷却时长(指数退避下限) 29 | MinCooldownDuration = 10 * time.Second 30 | ) 31 | 32 | // 
calculateBackoffDuration 计算指数退避冷却时间 33 | // 统一冷却策略: 34 | // - 认证错误(401/402/403): 起始5分钟,后续翻倍,上限30分钟 35 | // - 服务器错误(500/502/503/504): 起始2分钟,后续翻倍,上限30分钟 36 | // - 其他错误(429等): 起始10秒,后续翻倍,上限30分钟 37 | // 38 | // 参数: 39 | // - prevMs: 上次冷却持续时间(毫秒) 40 | // - until: 上次冷却截止时间 41 | // - now: 当前时间 42 | // - statusCode: HTTP状态码(可选,用于首次错误时确定初始冷却时间) 43 | // 44 | // 返回: 新的冷却持续时间 45 | // CalculateBackoffDuration 计算指数退避冷却时间 46 | func CalculateBackoffDuration(prevMs int64, until time.Time, now time.Time, statusCode *int) time.Duration { 47 | // 转换上次冷却持续时间 48 | prev := time.Duration(prevMs) * time.Millisecond 49 | 50 | // 如果没有历史记录,检查until字段 51 | if prev <= 0 { 52 | if !until.IsZero() && until.After(now) { 53 | prev = until.Sub(now) 54 | } else { 55 | // 首次错误:根据状态码确定初始冷却时间(直接返回,不翻倍) 56 | // 597/598:SSE错误/首字节超时,1分钟冷却,指数退避(1min → 2min → 4min → ...) 57 | if statusCode != nil && (*statusCode == 597 || *statusCode == 598) { 58 | return TimeoutErrorCooldown 59 | } 60 | // 服务器错误(500/502/503/504/520/521/524/599):2分钟冷却,指数退避(2min → 4min → 8min → ...) 
61 | // 599:流式响应不完整,归类为服务器错误(上游服务问题) 62 | if statusCode != nil && (*statusCode == 500 || *statusCode == 502 || *statusCode == 503 || *statusCode == 504 || *statusCode == 520 || *statusCode == 521 || *statusCode == 524 || *statusCode == 599) { 63 | return ServerErrorInitialCooldown 64 | } 65 | // 认证错误(401/402/403):5分钟冷却,减少无效重试 66 | if statusCode != nil && (*statusCode == 401 || *statusCode == 402 || *statusCode == 403) { 67 | return AuthErrorInitialCooldown 68 | } 69 | // 其他错误(429等):10秒冷却,允许快速恢复 70 | return OtherErrorInitialCooldown 71 | } 72 | } 73 | 74 | // 后续错误:指数退避翻倍 // 边界限制(使用常量) 75 | 76 | next := min(max(prev*2, MinCooldownDuration), MaxCooldownDuration) 77 | 78 | return next 79 | } 80 | 81 | // CalculateCooldownDuration 计算冷却持续时间(毫秒) 82 | func CalculateCooldownDuration(until time.Time, now time.Time) int64 { 83 | if until.IsZero() || !until.After(now) { 84 | return 0 85 | } 86 | return int64(until.Sub(now) / time.Millisecond) 87 | } 88 | -------------------------------------------------------------------------------- /internal/app/proxy_forward_soft_error_test.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import "testing" 4 | 5 | func TestCheckSoftError(t *testing.T) { 6 | t.Parallel() 7 | 8 | tests := []struct { 9 | name string 10 | contentType string 11 | data []byte 12 | want bool 13 | }{ 14 | { 15 | name: "json_top_level_type_error", 16 | contentType: "application/json; charset=utf-8", 17 | data: []byte(`{"type":"error","message":"boom"}`), 18 | want: true, 19 | }, 20 | { 21 | name: "json_top_level_error_field", 22 | contentType: "application/json", 23 | data: []byte(`{"error":{"message":"boom"}}`), 24 | want: true, 25 | }, 26 | { 27 | name: "json_success_contains_keywords_should_not_match", 28 | contentType: "application/json", 29 | data: []byte(`{"type":"message","content":"api_error 当前模型负载过高"}`), 30 | want: false, 31 | }, 32 | { 33 | name: "json_truncated_object_should_not_guess", 34 | 
contentType: "application/json", 35 | data: []byte(`{"type":"error"`), 36 | want: false, 37 | }, 38 | { 39 | name: "json_content_type_but_plain_text_prefix_can_match", 40 | contentType: "application/json", 41 | data: []byte("当前模型负载过高,请稍后再试"), 42 | want: true, 43 | }, 44 | { 45 | name: "text_plain_prefix_short_match", 46 | contentType: "text/plain; charset=utf-8", 47 | data: []byte("当前模型负载过高,请稍后再试"), 48 | want: true, 49 | }, 50 | { 51 | name: "text_plain_contains_but_not_prefix_should_not_match", 52 | contentType: "text/plain", 53 | data: []byte("回答里提到 当前模型负载过高 但这不是错误"), 54 | want: false, 55 | }, 56 | { 57 | name: "text_plain_sse_should_not_match", 58 | contentType: "text/plain", 59 | data: []byte("data: {\"type\":\"message\"}\n\n"), 60 | want: false, 61 | }, 62 | } 63 | 64 | for _, tt := range tests { 65 | t.Run(tt.name, func(t *testing.T) { 66 | t.Parallel() 67 | if got := checkSoftError(tt.data, tt.contentType); got != tt.want { 68 | t.Fatalf("checkSoftError()=%v, want %v", got, tt.want) 69 | } 70 | }) 71 | } 72 | } 73 | 74 | func TestShouldCheckSoftErrorForChannelType(t *testing.T) { 75 | t.Parallel() 76 | 77 | tests := []struct { 78 | name string 79 | channelType string 80 | want bool 81 | }{ 82 | {name: "anthropic", channelType: "anthropic", want: true}, 83 | {name: "codex", channelType: "codex", want: true}, 84 | {name: "anthropic_default_empty", channelType: "", want: true}, 85 | {name: "openai", channelType: "openai", want: false}, 86 | {name: "gemini", channelType: "gemini", want: false}, 87 | {name: "unknown", channelType: "something", want: false}, 88 | } 89 | 90 | for _, tt := range tests { 91 | t.Run(tt.name, func(t *testing.T) { 92 | t.Parallel() 93 | if got := shouldCheckSoftErrorForChannelType(tt.channelType); got != tt.want { 94 | t.Fatalf("shouldCheckSoftErrorForChannelType()=%v, want %v", got, tt.want) 95 | } 96 | }) 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /internal/app/proxy_stream.go: 
-------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net/http" 7 | ) 8 | 9 | // ============================================================================ 10 | // Streaming data structures 11 | // ============================================================================ 12 | 13 | // streamReadStats holds streaming read statistics. 14 | type streamReadStats struct { 15 | readCount int 16 | totalBytes int64 17 | } 18 | 19 | // firstByteDetector wraps a ReadCloser to detect the first read and record transfer statistics. 20 | type firstByteDetector struct { 21 | io.ReadCloser 22 | stats *streamReadStats 23 | onFirstRead func() 24 | } 25 | 26 | // Read implements io.Reader, recording read statistics and firing onFirstRead on the first non-empty read. 27 | func (r *firstByteDetector) Read(p []byte) (n int, err error) { 28 | n, err = r.ReadCloser.Read(p) 29 | if n > 0 { 30 | // Record statistics. 31 | if r.stats != nil { 32 | r.stats.readCount++ 33 | r.stats.totalBytes += int64(n) 34 | } 35 | // Fire the first-read callback. 36 | if r.onFirstRead != nil { 37 | r.onFirstRead() 38 | r.onFirstRead = nil // fire only once 39 | } 40 | } 41 | return 42 | } 43 | 44 | // ============================================================================ 45 | // Core streaming functions 46 | // ============================================================================ 47 | 48 | // streamCopyWithBufferSize copies src to dst in chunks of up to bufSize bytes. 49 | // After each chunk is written, dst is flushed (when it implements http.Flusher) 50 | // and the chunk is passed to onData (when non-nil); onData's error is ignored. 51 | // It returns nil on io.EOF, ctx.Err() when ctx is cancelled before or during a 52 | // read, and otherwise the first write or read error. 53 | func streamCopyWithBufferSize(ctx context.Context, src io.Reader, dst http.ResponseWriter, onData func([]byte) error, bufSize int) error { 54 | buf := make([]byte, bufSize) 55 | // Resolve the optional Flusher once instead of type-asserting on every chunk. 56 | flusher, _ := dst.(http.Flusher) 57 | for { 58 | select { 59 | case <-ctx.Done(): 60 | return ctx.Err() 61 | default: 62 | } 63 | 64 | n, err := src.Read(buf) 65 | if n > 0 { 66 | if _, writeErr := dst.Write(buf[:n]); writeErr != nil { 67 | return writeErr 68 | } 69 | if flusher != nil { 70 | flusher.Flush() 71 | } 72 | if onData != nil { 73 | // Hook errors deliberately do not interrupt the stream (fault-tolerant design); 74 | // the hook is responsible for logging its own errors. 75 | _ = onData(buf[:n]) 76 | } 77 | } 78 | if err != nil { 79 | if err == io.EOF { 80 | return nil 81 | } 82 | // [FIX] Check whether ctx was cancelled while Read was in progress. 83 | // Scenario: client cancels the request → HTTP/2 stream closes → Read returns "http2: response body closed". 84 | // Returning ctx.Err() (context.Canceled) lets callers classify it as a client disconnect (499) rather than an upstream error (502). 85 | if ctx.Err() != nil { 86 | return ctx.Err() 87 | } 88 | return err 89 | } 90 | } 91 | } 92 | 
93 | // streamCopy streams src to dst (with flushing and ctx cancellation) using StreamBufferSize. 94 | // Extracted from proxy.go (single responsibility). 95 | // Simplified implementation: a direct read/write loop, avoiding the goroutine-per-read that caused leaks. 96 | // First-byte timeouts rely on the upstream handshake/response-header timeouts (Transport config) and are not re-implemented here. 97 | func streamCopy(ctx context.Context, src io.Reader, dst http.ResponseWriter, onData func([]byte) error) error { 98 | return streamCopyWithBufferSize(ctx, src, dst, onData, StreamBufferSize) 99 | } 100 | 101 | // streamCopySSE is the SSE-specific stream copy, using the smaller SSEBufferSize to reduce latency. 102 | // [INFO] SSE optimization (2025-10-17): a 4KB buffer reduced first-token latency by 60~80%. 103 | // [INFO] Data hook support (2025-11): lets the SSE usage parser process the stream incrementally. 104 | // Design rationale: SSE events are usually 200B-2KB, so a small buffer keeps events from piling up. 105 | func streamCopySSE(ctx context.Context, src io.Reader, dst http.ResponseWriter, onData func([]byte) error) error { 106 | return streamCopyWithBufferSize(ctx, src, dst, onData, SSEBufferSize) 107 | } 108 | -------------------------------------------------------------------------------- /internal/config/defaults.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import "time" 4 | 5 | // HTTP服务器配置常量 6 | const ( 7 | // DefaultMaxConcurrency 默认最大并发请求数 8 | DefaultMaxConcurrency = 1000 9 | 10 | // DefaultMaxKeyRetries 单个渠道内最大Key重试次数 11 | DefaultMaxKeyRetries = 3 12 | 13 | // DefaultMaxBodyBytes 默认最大请求体字节数(用于代理入口的解析) 14 | DefaultMaxBodyBytes = 2 * 1024 * 1024 // 2MB 15 | ) 16 | 17 | // HTTP客户端配置常量 18 | const ( 19 | // HTTPDialTimeout DNS解析+TCP连接建立超时 20 | HTTPDialTimeout = 30 * time.Second 21 | 22 | // HTTPKeepAliveInterval TCP keepalive间隔 23 | // 15秒:快速检测僵死连接(上游进程崩溃、网络中断) 24 | // 配合Linux默认重试(9次×3s),总检测时间42秒 25 | HTTPKeepAliveInterval = 15 * time.Second 26 | 27 | // HTTPTLSHandshakeTimeout TLS握手超时 28 | HTTPTLSHandshakeTimeout = 30 * time.Second 29 | 30 | // HTTPMaxIdleConns 全局空闲连接池大小 31 | HTTPMaxIdleConns = 100 32 | 33 | // 
HTTPMaxIdleConnsPerHost 单host空闲连接数 34 | HTTPMaxIdleConnsPerHost = 5 35 | 36 | // HTTPMaxConnsPerHost 单host最大连接数 37 | HTTPMaxConnsPerHost = 50 38 | 39 | // TLSSessionCacheSize TLS会话缓存大小 40 | TLSSessionCacheSize = 1024 41 | ) 42 | 43 | // 日志系统配置常量 44 | const ( 45 | // DefaultLogBufferSize 默认日志缓冲区大小(条数) 46 | DefaultLogBufferSize = 1000 47 | 48 | // DefaultLogWorkers 默认日志Worker协程数 49 | // 改为1以保证日志写入顺序(FIFO) 50 | // 多worker会导致竞争消费logChan,打乱日志顺序 51 | // 性能影响: 单worker仍支持批量写入,性能足够(1000条/秒+) 52 | DefaultLogWorkers = 1 53 | 54 | // LogBatchSize 批量写入日志的大小(条数) 55 | LogBatchSize = 100 56 | 57 | // LogBatchTimeout 批量写入超时时间 58 | LogBatchTimeout = 1 * time.Second 59 | 60 | // LogFlushTimeoutMs 单次日志刷盘的超时时间(毫秒) 61 | // 关停期间需要尽快完成,避免测试和生产关停卡顿 62 | LogFlushTimeoutMs = 300 63 | ) 64 | 65 | // Token认证配置常量 66 | const ( 67 | // TokenRandomBytes Token随机字节数(生成64字符十六进制) 68 | TokenRandomBytes = 32 69 | 70 | // TokenExpiry Token有效期 71 | TokenExpiry = 24 * time.Hour 72 | 73 | // TokenCleanupInterval Token清理间隔 74 | TokenCleanupInterval = 1 * time.Hour 75 | ) 76 | 77 | // Token统计配置常量 78 | const ( 79 | // DefaultTokenStatsBufferSize 默认Token统计更新队列大小(条数) 80 | // 设计原则:有界队列,避免每请求起goroutine导致资源失控 81 | DefaultTokenStatsBufferSize = 1000 82 | ) 83 | 84 | // SQLite连接池配置常量 85 | const ( 86 | // SQLiteMaxOpenConnsFile 文件模式最大连接数(WAL写并发瓶颈) 87 | // 保持5:1写 + 4读 = 充分利用WAL模式并发能力 88 | SQLiteMaxOpenConnsFile = 5 89 | 90 | // SQLiteMaxIdleConnsFile 文件模式最大空闲连接数 91 | // [INFO] 从2提升到5:避免高并发时频繁创建/销毁连接 92 | // 设计原则:空闲连接数 = 最大连接数,减少连接重建开销 93 | SQLiteMaxIdleConnsFile = 5 94 | 95 | // SQLiteConnMaxLifetime 连接最大生命周期 96 | // [INFO] 从1分钟提升到5分钟:降低连接过期频率 97 | // 权衡:更长的生命周期 vs 更低的连接重建开销 98 | SQLiteConnMaxLifetime = 5 * time.Minute 99 | ) 100 | 101 | // 性能优化配置常量 102 | const ( 103 | // LogCleanupInterval 日志清理间隔 104 | LogCleanupInterval = 1 * time.Hour 105 | ) 106 | 107 | // Redis同步配置常量 108 | const ( 109 | // RedisSyncShutdownTimeoutMs 优雅关闭等待时间(毫秒) 110 | RedisSyncShutdownTimeoutMs = 100 111 | ) 112 | 113 | // 
启动超时配置(Fail-Fast:启动阶段网络问题应快速失败,避免卡死) 114 | const ( 115 | // StartupDBPingTimeout 数据库连接测试超时 116 | StartupDBPingTimeout = 10 * time.Second 117 | // StartupMigrationTimeout 数据库迁移超时 118 | StartupMigrationTimeout = 30 * time.Second 119 | // StartupRedisRestoreTimeout Redis数据恢复超时 120 | StartupRedisRestoreTimeout = 30 * time.Second 121 | ) 122 | -------------------------------------------------------------------------------- /internal/util/channel_types.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import "strings" 4 | 5 | // ChannelTypeConfig 渠道类型配置(元数据定义) 6 | type ChannelTypeConfig struct { 7 | Value string `json:"value"` // 内部值(数据库存储) 8 | DisplayName string `json:"display_name"` // 显示名称(前端展示) 9 | Description string `json:"description"` // 描述信息 10 | PathPatterns []string `json:"path_patterns"` // 路径匹配模式列表 11 | MatchType string `json:"match_type"` // 匹配类型: "prefix"(前缀) 或 "contains"(包含) 12 | } 13 | 14 | // ChannelTypes 全局渠道类型配置(单一数据源 - Single Source of Truth) 15 | var ChannelTypes = []ChannelTypeConfig{ 16 | { 17 | Value: ChannelTypeAnthropic, 18 | DisplayName: "Claude Code", 19 | Description: "Claude Code兼容API", 20 | PathPatterns: []string{"/v1/messages"}, 21 | MatchType: MatchTypePrefix, 22 | }, 23 | { 24 | Value: ChannelTypeCodex, 25 | DisplayName: "Codex", 26 | Description: "Codex兼容API", 27 | PathPatterns: []string{"/v1/responses"}, 28 | MatchType: MatchTypePrefix, 29 | }, 30 | { 31 | Value: ChannelTypeOpenAI, 32 | DisplayName: "OpenAI", 33 | Description: "OpenAI API (GPT系列)", 34 | PathPatterns: []string{"/v1/chat/completions", "/v1/completions", "/v1/embeddings"}, 35 | MatchType: MatchTypePrefix, 36 | }, 37 | { 38 | Value: ChannelTypeGemini, 39 | DisplayName: "Google Gemini", 40 | Description: "Google Gemini API", 41 | PathPatterns: []string{"/v1beta/"}, 42 | MatchType: MatchTypeContains, 43 | }, 44 | } 45 | 46 | // IsValidChannelType 验证渠道类型是否有效(替代models.go中的硬编码) 47 | func IsValidChannelType(value 
string) bool { 48 | for _, ct := range ChannelTypes { 49 | if ct.Value == value { 50 | return true 51 | } 52 | } 53 | return false 54 | } 55 | 56 | // NormalizeChannelType normalizes a channel type value (compatibility handling): 57 | // - trims leading/trailing whitespace 58 | // - lowercases 59 | // - empty → ChannelTypeAnthropic (the default) 60 | func NormalizeChannelType(value string) string { 61 | // Trim leading/trailing whitespace. 62 | value = strings.TrimSpace(value) 63 | 64 | // Empty input falls back to the default type. 65 | if value == "" { 66 | return ChannelTypeAnthropic 67 | } 68 | 69 | // Lowercase. 70 | return strings.ToLower(value) 71 | } 72 | 73 | // 渠道类型常量(导出供其他包使用,遵循DRY原则) 74 | const ( 75 | ChannelTypeAnthropic = "anthropic" 76 | ChannelTypeCodex = "codex" 77 | ChannelTypeOpenAI = "openai" 78 | ChannelTypeGemini = "gemini" 79 | ) 80 | 81 | // 匹配类型常量(路径匹配方式) 82 | const ( 83 | MatchTypePrefix = "prefix" // 前缀匹配(strings.HasPrefix) 84 | MatchTypeContains = "contains" // 包含匹配(strings.Contains) 85 | ) 86 | 87 | // DetectChannelTypeFromPath 根据请求路径自动检测渠道类型 88 | // 使用 ChannelTypes 配置进行统一检测,遵循DRY原则 89 | func DetectChannelTypeFromPath(path string) string { 90 | for _, ct := range ChannelTypes { 91 | if matchPath(path, ct.PathPatterns, ct.MatchType) { 92 | return ct.Value 93 | } 94 | } 95 | return "" // 未匹配到任何类型 96 | } 97 | 98 | // matchPath 辅助函数:根据匹配类型检查路径是否匹配模式列表 99 | func matchPath(path string, patterns []string, matchType string) bool { 100 | for _, pattern := range patterns { 101 | switch matchType { 102 | case MatchTypePrefix: 103 | if strings.HasPrefix(path, pattern) { 104 | return true 105 | } 106 | case MatchTypeContains: 107 | if strings.Contains(path, pattern) { 108 | return true 109 | } 110 | } 111 | } 112 | return false 113 | } 114 | -------------------------------------------------------------------------------- /web/assets/js/template-engine.js: -------------------------------------------------------------------------------- 1 | /** 2 | * 轻量级模板引擎 3 | * 使用原生 HTML 元素实现 HTML/JS 分离 4 | * 5 | * 用法: 6 | * 1. 在 HTML 中定义 ... 7 | * 2. 模板内使用 {{key}} 或 {{obj.key}} 语法绑定数据 8 | * 3. 
JS 中调用 TemplateEngine.render('tpl-xxx', data) 9 | * 10 | * 特性: 11 | * - 自动 HTML 转义防止 XSS 12 | * - 支持嵌套属性访问 (obj.nested.value) 13 | * - 支持 {{{raw}}} 语法插入原始 HTML (慎用) 14 | * - 模板缓存提升性能 15 | */ 16 | const TemplateEngine = { 17 | // 模板缓存 18 | _cache: new Map(), 19 | 20 | /** 21 | * 获取模板内容 (带缓存) 22 | * @param {string} id - 模板ID (含或不含#前缀均可) 23 | * @returns {string} 模板HTML字符串 24 | */ 25 | _getTemplate(id) { 26 | const templateId = id.startsWith('#') ? id.slice(1) : id; 27 | 28 | if (!this._cache.has(templateId)) { 29 | const tpl = document.getElementById(templateId); 30 | if (!tpl) { 31 | console.error(`[TemplateEngine] Template not found: ${templateId}`); 32 | return ''; 33 | } 34 | // 缓存模板HTML字符串 35 | this._cache.set(templateId, tpl.innerHTML.trim()); 36 | } 37 | return this._cache.get(templateId); 38 | }, 39 | 40 | /** 41 | * HTML转义 (防XSS) 42 | * @param {string} str - 原始字符串 43 | * @returns {string} 转义后的字符串 44 | */ 45 | _escape(str) { 46 | if (str === null || str === undefined) return ''; 47 | return String(str).replace(/[&<>"']/g, c => ({ 48 | '&': '&', 49 | '<': '<', 50 | '>': '>', 51 | '"': '"', 52 | "'": ''' 53 | }[c])); 54 | }, 55 | 56 | /** 57 | * 从对象中获取嵌套属性值 58 | * @param {Object} obj - 数据对象 59 | * @param {string} path - 属性路径 (如 "user.name") 60 | * @returns {*} 属性值 61 | */ 62 | _getValue(obj, path) { 63 | return path.split('.').reduce((o, k) => o?.[k], obj); 64 | }, 65 | 66 | /** 67 | * 渲染单个模板 68 | * @param {string} id - 模板ID 69 | * @param {Object} data - 数据对象 70 | * @returns {HTMLElement|null} 渲染后的DOM元素 71 | */ 72 | render(id, data) { 73 | let html = this._getTemplate(id); 74 | if (!html) return null; 75 | 76 | // 处理 {{{raw}}} 语法 (原始HTML,不转义) 77 | html = html.replace(/\{\{\{(\w+(?:\.\w+)*)\}\}\}/g, (_, path) => { 78 | const value = this._getValue(data, path); 79 | return value !== undefined ? 
String(value) : ''; 80 | }); 81 | 82 | // 处理 {{key}} 语法 (自动转义) 83 | html = html.replace(/\{\{(\w+(?:\.\w+)*)\}\}/g, (_, path) => { 84 | const value = this._getValue(data, path); 85 | return value !== undefined ? this._escape(value) : ''; 86 | }); 87 | 88 | // 创建DOM元素 - 表格元素需要正确的父容器才能被浏览器正确解析 89 | const trimmed = html.trim().toLowerCase(); 90 | let temp; 91 | if (trimmed.startsWith(' 0 { 88 | chCost = new(float64) 89 | *chCost = totalCost 90 | } 91 | 92 | mp.Channels[channelKey] = model.ChannelMetric{ 93 | Success: success, 94 | Error: errorCount, 95 | AvgFirstByteTimeSeconds: avgFBT, 96 | AvgDurationSeconds: avgDur, 97 | TotalCost: chCost, 98 | InputTokens: inputTokens, 99 | OutputTokens: outputTokens, 100 | CacheReadTokens: cacheReadTokens, 101 | CacheCreationTokens: cacheCreationTokens, 102 | } 103 | } 104 | 105 | if err := rows.Err(); err != nil { 106 | return nil, nil, nil, err 107 | } 108 | 109 | return mapp, helperMap, channelIDsToFetch, nil 110 | } 111 | -------------------------------------------------------------------------------- /internal/storage/sql/store_impl.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "sync" 7 | "sync/atomic" 8 | "time" 9 | 10 | "ccLoad/internal/model" 11 | ) 12 | 13 | // syncType 定义同步类型(位标记,支持组合) 14 | // 包内私有:仅在 sql 包内使用,无需导出 (YAGNI) 15 | type syncType uint32 16 | 17 | const ( 18 | syncChannels syncType = 1 << iota // 同步渠道配置和 API Keys 19 | syncAuthTokens // 同步认证令牌 20 | 21 | syncAll = syncChannels | syncAuthTokens // 全量同步 22 | ) 23 | 24 | // RedisSync Redis同步接口 25 | // 支持渠道配置和Auth Tokens的双向同步 26 | type RedisSync interface { 27 | IsEnabled() bool 28 | LoadChannelsWithKeysFromRedis(ctx context.Context) ([]*model.ChannelWithKeys, error) 29 | SyncAllChannelsWithKeys(ctx context.Context, channels []*model.ChannelWithKeys) error 30 | // Auth Tokens同步 31 | SyncAllAuthTokens(ctx context.Context, tokens []*model.AuthToken) error 32 | 
LoadAuthTokensFromRedis(ctx context.Context) ([]*model.AuthToken, error) 33 | } 34 | 35 | // SQLStore 通用SQL存储实现 36 | // 支持 SQLite 和 MySQL(时间/布尔值存储格式完全一致,SQL语法按驱动分支) 37 | type SQLStore struct { 38 | db *sql.DB 39 | driverName string // "sqlite" 或 "mysql" 40 | 41 | // 异步Redis同步机制(性能优化: 避免同步等待) 42 | syncCh chan struct{} // 同步触发信号(缓冲1,去重合并多个请求) 43 | pendingSyncTypes atomic.Uint32 // 待同步类型(位标记,支持合并) 44 | done chan struct{} // 优雅关闭信号 45 | 46 | redisSync RedisSync // Redis同步接口(依赖注入,支持测试和扩展) 47 | 48 | // 优雅关闭:等待后台worker 49 | wg sync.WaitGroup 50 | 51 | // [FIX] 2025-12:保证 StartRedisSync 幂等性,防止多次调用启动多个 worker 52 | startOnce sync.Once 53 | // [FIX] 2025-12:保证 Close 幂等性,防止重复关闭 channel 导致 panic 54 | closeOnce sync.Once 55 | } 56 | 57 | // NewSQLStore 创建通用SQL存储实例 58 | // db: 数据库连接(由调用方初始化) 59 | // driverName: "sqlite" 或 "mysql" 60 | // redisSync: Redis同步器(可选,测试时可传nil) 61 | func NewSQLStore(db *sql.DB, driverName string, redisSync RedisSync) *SQLStore { 62 | return &SQLStore{ 63 | db: db, 64 | driverName: driverName, 65 | syncCh: make(chan struct{}, 1), 66 | done: make(chan struct{}), 67 | redisSync: redisSync, 68 | } 69 | } 70 | 71 | // StartRedisSync 显式启动 Redis 同步 worker 72 | // 必须在迁移完成且恢复逻辑执行后调用,避免空数据覆盖 Redis 备份 73 | // [FIX] 2025-12:使用 sync.Once 保证幂等性,防止多次调用启动多个 worker 74 | func (s *SQLStore) StartRedisSync() { 75 | if s.redisSync == nil || !s.redisSync.IsEnabled() { 76 | return 77 | } 78 | s.startOnce.Do(func() { 79 | s.wg.Add(1) 80 | go s.redisSyncWorker() 81 | // 启动时触发全量同步,确保所有存量数据备份到 Redis 82 | s.triggerAsyncSync(syncAll) 83 | }) 84 | } 85 | 86 | // IsRedisEnabled 检查Redis是否启用 87 | func (s *SQLStore) IsRedisEnabled() bool { 88 | return s.redisSync != nil && s.redisSync.IsEnabled() 89 | } 90 | 91 | // IsSQLite 检查是否为SQLite驱动 92 | func (s *SQLStore) IsSQLite() bool { 93 | return s.driverName == "sqlite" 94 | } 95 | 96 | // Close 关闭存储(优雅关闭) 97 | // Ping 检查数据库连接是否活跃(用于健康检查,<1ms) 98 | func (s *SQLStore) Ping(ctx context.Context) error { 99 | return s.db.PingContext(ctx) 100 | } 
101 | 102 | func (s *SQLStore) Close() error { 103 | var err error 104 | s.closeOnce.Do(func() { 105 | // 1. 通知后台worker退出 106 | close(s.done) 107 | 108 | // 2. 等待worker完成 109 | s.wg.Wait() 110 | 111 | // 3. 关闭数据库连接 112 | if s.db != nil { 113 | err = s.db.Close() 114 | } 115 | }) 116 | return err 117 | } 118 | 119 | // CleanupLogsBefore 清理指定时间之前的日志 120 | func (s *SQLStore) CleanupLogsBefore(ctx context.Context, cutoff time.Time) error { 121 | query := "DELETE FROM logs WHERE timestamp < ?" 122 | _, err := s.db.ExecContext(ctx, query, timeToUnix(cutoff)) 123 | return err 124 | } 125 | -------------------------------------------------------------------------------- /internal/app/admin_auth_tokens_test.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | 11 | "ccLoad/internal/model" 12 | 13 | "github.com/gin-gonic/gin" 14 | ) 15 | 16 | func TestAuthToken_MaskToken(t *testing.T) { 17 | tests := []struct { 18 | name string 19 | token string 20 | expected string 21 | }{ 22 | { 23 | name: "Long token", 24 | token: "sk-ant-1234567890abcdefghijklmnop", 25 | expected: "sk-a****mnop", 26 | }, 27 | { 28 | name: "Short token", 29 | token: "short", 30 | expected: "****", 31 | }, 32 | } 33 | 34 | for _, tt := range tests { 35 | t.Run(tt.name, func(t *testing.T) { 36 | masked := model.MaskToken(tt.token) 37 | if masked != tt.expected { 38 | t.Errorf("Expected '%s', got '%s'", tt.expected, masked) 39 | } 40 | }) 41 | } 42 | } 43 | 44 | func TestAdminAPI_CreateAuthToken_Basic(t *testing.T) { 45 | server, cleanup := setupTestServer(t) 46 | defer cleanup() 47 | 48 | requestBody := map[string]any{ 49 | "description": "Test Token", 50 | } 51 | 52 | body, _ := json.Marshal(requestBody) 53 | w := httptest.NewRecorder() 54 | c, _ := gin.CreateTestContext(w) 55 | c.Request = httptest.NewRequest(http.MethodPost, 
"/admin/auth-tokens", bytes.NewBuffer(body)) 56 | c.Request.Header.Set("Content-Type", "application/json") 57 | 58 | server.HandleCreateAuthToken(c) 59 | 60 | if w.Code != http.StatusOK { 61 | t.Fatalf("Expected 200, got %d", w.Code) 62 | } 63 | 64 | var response struct { 65 | Success bool `json:"success"` 66 | Data struct { 67 | ID int64 `json:"id"` 68 | Token string `json:"token"` 69 | } `json:"data"` 70 | } 71 | if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { 72 | t.Fatalf("Parse error: %v", err) 73 | } 74 | 75 | if !response.Success || len(response.Data.Token) == 0 { 76 | t.Error("Token creation failed") 77 | } 78 | 79 | ctx := context.Background() 80 | stored, err := server.store.GetAuthToken(ctx, response.Data.ID) 81 | if err != nil { 82 | t.Fatalf("DB error: %v", err) 83 | } 84 | 85 | expectedHash := model.HashToken(response.Data.Token) 86 | if stored.Token != expectedHash { 87 | t.Error("Hash mismatch") 88 | } 89 | } 90 | 91 | func TestAdminAPI_ListAuthTokens_ResponseShape(t *testing.T) { 92 | server, cleanup := setupTestServer(t) 93 | defer cleanup() 94 | 95 | w := httptest.NewRecorder() 96 | c, _ := gin.CreateTestContext(w) 97 | c.Request = httptest.NewRequest(http.MethodGet, "/admin/auth-tokens", nil) 98 | 99 | server.HandleListAuthTokens(c) 100 | 101 | if w.Code != http.StatusOK { 102 | t.Fatalf("Expected 200, got %d", w.Code) 103 | } 104 | 105 | var resp map[string]any 106 | if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { 107 | t.Fatalf("Parse error: %v", err) 108 | } 109 | 110 | if resp["success"] != true { 111 | t.Fatalf("Expected success=true, got %v", resp["success"]) 112 | } 113 | 114 | data, ok := resp["data"].(map[string]any) 115 | if !ok { 116 | t.Fatalf("Expected data object, got %T", resp["data"]) 117 | } 118 | 119 | if _, ok := data["is_today"]; !ok { 120 | t.Fatalf("Expected data.is_today to exist") 121 | } 122 | 123 | tokens, ok := data["tokens"] 124 | if !ok { 125 | t.Fatalf("Expected data.tokens to exist") 
126 | } 127 | if tokens == nil { 128 | t.Fatalf("Expected data.tokens to be [], got null") 129 | } 130 | if _, ok := tokens.([]any); !ok { 131 | t.Fatalf("Expected data.tokens to be array, got %T", tokens) 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /web/settings.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 系统设置 - Claude Code & Codex Proxy 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 系统设置 27 | 配置运行时参数 · 实时生效 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 配置项 38 | 当前值 39 | 操作 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 保存所有更改 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | {{description}} 61 | {{{inputHtml}}} 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /internal/model/auth_token.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "crypto/sha256" 5 | "encoding/hex" 6 | "time" 7 | ) 8 | 9 | // AuthToken 表示一个API访问令牌 10 | // 用于代理API (/v1/*) 的认证授权 11 | type AuthToken struct { 12 | ID int64 `json:"id"` 13 | Token string `json:"token"` // SHA256哈希值(存储时)或明文(创建时返回) 14 | Description string `json:"description"` // 令牌用途描述 15 | CreatedAt time.Time `json:"created_at"` // 创建时间 16 | ExpiresAt *int64 `json:"expires_at,omitempty"` // 过期时间(Unix毫秒时间戳),nil表示永不过期 17 | LastUsedAt *int64 `json:"last_used_at,omitempty"` // 最后使用时间(Unix毫秒时间戳) 18 | IsActive bool `json:"is_active"` // 是否启用 19 | 20 | // 统计字段(2025-11新增) 21 | SuccessCount int64 `json:"success_count"` // 成功调用次数 22 | FailureCount int64 `json:"failure_count"` // 失败调用次数 23 | StreamAvgTTFB float64 `json:"stream_avg_ttfb"` // 流式请求平均首字节时间(秒) 24 | NonStreamAvgRT float64 `json:"non_stream_avg_rt"` // 非流式请求平均响应时间(秒) 25 | StreamCount int64 `json:"stream_count"` 
// 流式请求计数(用于计算平均值) 26 | NonStreamCount int64 `json:"non_stream_count"` // 非流式请求计数(用于计算平均值) 27 | 28 | // Token成本统计(2025-12新增) 29 | PromptTokensTotal int64 `json:"prompt_tokens_total"` // 累计输入Token数 30 | CompletionTokensTotal int64 `json:"completion_tokens_total"` // 累计输出Token数 31 | CacheReadTokensTotal int64 `json:"cache_read_tokens_total"` // 累计缓存读Token数 32 | CacheCreationTokensTotal int64 `json:"cache_creation_tokens_total"` // 累计缓存写Token数 33 | TotalCostUSD float64 `json:"total_cost_usd"` // 累计成本(美元) 34 | 35 | // RPM统计(2025-12新增,用于tokens.html显示) 36 | PeakRPM float64 `json:"peak_rpm,omitempty"` // 峰值RPM 37 | AvgRPM float64 `json:"avg_rpm,omitempty"` // 平均RPM 38 | RecentRPM float64 `json:"recent_rpm,omitempty"` // 最近一分钟RPM 39 | } 40 | 41 | // AuthTokenRangeStats 某个时间范围内的token统计(从logs表聚合,2025-12新增) 42 | type AuthTokenRangeStats struct { 43 | SuccessCount int64 `json:"success_count"` // 成功次数 44 | FailureCount int64 `json:"failure_count"` // 失败次数 45 | PromptTokens int64 `json:"prompt_tokens"` // 输入Token总数 46 | CompletionTokens int64 `json:"completion_tokens"` // 输出Token总数 47 | CacheReadTokens int64 `json:"cache_read_tokens"` // 缓存读Token总数 48 | CacheCreationTokens int64 `json:"cache_creation_tokens"` // 缓存写Token总数 49 | TotalCost float64 `json:"total_cost"` // 总费用(美元) 50 | StreamAvgTTFB float64 `json:"stream_avg_ttfb"` // 流式请求平均首字节时间 51 | NonStreamAvgRT float64 `json:"non_stream_avg_rt"` // 非流式请求平均响应时间 52 | StreamCount int64 `json:"stream_count"` // 流式请求计数 53 | NonStreamCount int64 `json:"non_stream_count"` // 非流式请求计数 54 | // RPM统计(2025-12新增) 55 | PeakRPM float64 `json:"peak_rpm"` // 峰值RPM(每分钟最大请求数) 56 | AvgRPM float64 `json:"avg_rpm"` // 平均RPM 57 | RecentRPM float64 `json:"recent_rpm"` // 最近一分钟RPM(仅本日有效) 58 | } 59 | 60 | // HashToken 计算令牌的SHA256哈希值 61 | // 用于安全存储令牌到数据库 62 | func HashToken(token string) string { 63 | hash := sha256.Sum256([]byte(token)) 64 | return hex.EncodeToString(hash[:]) 65 | } 66 | 67 | // IsExpired 检查令牌是否已过期 68 | func (t *AuthToken) IsExpired() 
bool { 69 | if t.ExpiresAt == nil { 70 | return false 71 | } 72 | return time.Now().UnixMilli() > *t.ExpiresAt 73 | } 74 | 75 | // IsValid 检查令牌是否有效(启用且未过期) 76 | func (t *AuthToken) IsValid() bool { 77 | return t.IsActive && !t.IsExpired() 78 | } 79 | 80 | // MaskToken 脱敏显示令牌(仅显示前4后4字符) 81 | // 例如: "sk-ant-1234567890abcdef" -> "sk-a****cdef" 82 | func MaskToken(token string) string { 83 | if len(token) <= 8 { 84 | return "****" 85 | } 86 | return token[:4] + "****" + token[len(token)-4:] 87 | } 88 | 89 | // UpdateLastUsed 更新最后使用时间为当前时间 90 | func (t *AuthToken) UpdateLastUsed() { 91 | now := time.Now().UnixMilli() 92 | t.LastUsedAt = &now 93 | } 94 | -------------------------------------------------------------------------------- /test/integration/csv_import_export_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "ccLoad/internal/model" 5 | "ccLoad/internal/util" 6 | "testing" 7 | ) 8 | 9 | // ==================== CSV导出默认值测试 ==================== 10 | // 注意:新架构中APIKey和KeyStrategy已从Config移除,CSV导出从api_keys表查询 11 | // 此测试简化为仅验证channel_type的默认值处理 12 | 13 | // ==================== CSV导入默认值测试 ==================== 14 | 15 | func TestCSVImport_DefaultValues(t *testing.T) { 16 | // 测试渠道类型规范化 17 | tests := []struct { 18 | input string 19 | expected string 20 | }{ 21 | {"", "anthropic"}, // 空值 → 默认值 22 | {" ", "anthropic"}, // 空白 → 默认值 23 | {"anthropic", "anthropic"}, // 有效值保持 24 | {"gemini", "gemini"}, // 有效值保持 25 | {"codex", "codex"}, // 有效值保持 26 | } 27 | 28 | for _, tt := range tests { 29 | result := util.NormalizeChannelType(tt.input) 30 | if result != tt.expected { 31 | t.Errorf("util.NormalizeChannelType(%q) = %q, 期望 %q", tt.input, result, tt.expected) 32 | } 33 | } 34 | 35 | // 测试Key策略默认值处理 36 | keyStrategy := "" 37 | if keyStrategy == "" { 38 | keyStrategy = "sequential" 39 | } 40 | if keyStrategy != "sequential" { 41 | t.Errorf("空key_strategy应填充为sequential,实际为: %s", keyStrategy) 
42 | } 43 | } 44 | 45 | // ==================== CSV导出导入循环测试 ==================== 46 | 47 | func TestCSVExportImportCycle(t *testing.T) { 48 | // 测试channel_type的导出导入循环 49 | // 场景:数据库中有空channel_type的Config 50 | original := &model.Config{ 51 | ID: 1, 52 | Name: "test-cycle", 53 | ChannelType: "", // 数据库中的空值 54 | URL: "https://api.example.com", 55 | Priority: 10, 56 | Models: []string{"test-model"}, 57 | Enabled: true, 58 | } 59 | 60 | // 步骤1:导出CSV(使用GetChannelType()) 61 | exportedChannelType := original.GetChannelType() 62 | if exportedChannelType != "anthropic" { 63 | t.Fatalf("导出channel_type应为anthropic,实际为: %s", exportedChannelType) 64 | } 65 | 66 | // 步骤2:导入CSV(规范化channel_type) 67 | importedChannelType := util.NormalizeChannelType(exportedChannelType) 68 | if importedChannelType != "anthropic" { 69 | t.Fatalf("导入channel_type应为anthropic,实际为: %s", importedChannelType) 70 | } 71 | 72 | t.Log("✅ CSV导出导入可以修复空channel_type为默认值") 73 | } 74 | 75 | // ==================== CSV时间字段缺失测试 ==================== 76 | 77 | func TestCSVExport_NoTimeFields(t *testing.T) { 78 | // 验证CSV导出不包含时间字段 79 | header := []string{"id", "name", "api_key", "url", "priority", "models", "model_redirects", "channel_type", "key_strategy", "enabled"} 80 | 81 | hasCreatedAt := false 82 | hasUpdatedAt := false 83 | 84 | for _, col := range header { 85 | if col == "created_at" { 86 | hasCreatedAt = true 87 | } 88 | if col == "updated_at" { 89 | hasUpdatedAt = true 90 | } 91 | } 92 | 93 | if hasCreatedAt { 94 | t.Error("CSV不应包含created_at字段(设计决定:导入时使用当前时间)") 95 | } 96 | 97 | if hasUpdatedAt { 98 | t.Error("CSV不应包含updated_at字段(设计决定:导入时使用当前时间)") 99 | } 100 | 101 | t.Log("✅ CSV正确省略了时间字段,导入时将使用当前时间") 102 | } 103 | 104 | // ==================== util.NormalizeChannelType 边界条件测试 ==================== 105 | 106 | func TestNormalizeChannelType(t *testing.T) { 107 | tests := []struct { 108 | input string 109 | expected string 110 | }{ 111 | {"", "anthropic"}, // 空值 → 默认值 112 | {" ", "anthropic"}, // 空白 → 默认值 113 | 
{"anthropic", "anthropic"}, // 有效值保持 114 | {"gemini", "gemini"}, // 有效值保持 115 | {"codex", "codex"}, // 有效值保持 116 | {"openai", "openai"}, // 有效值保持(openai是有效的渠道类型) 117 | {"ANTHROPIC", "anthropic"}, // 大写转小写 118 | {" gemini ", "gemini"}, // 去除空格并转小写 119 | } 120 | 121 | for _, tt := range tests { 122 | t.Run(tt.input, func(t *testing.T) { 123 | result := util.NormalizeChannelType(tt.input) 124 | if result != tt.expected { 125 | t.Errorf("util.NormalizeChannelType(%q) = %q, 期望 %q", tt.input, result, tt.expected) 126 | } 127 | }) 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /internal/storage/health_success_rate_test.go: -------------------------------------------------------------------------------- 1 | package storage_test 2 | 3 | import ( 4 | "context" 5 | "path/filepath" 6 | "testing" 7 | "time" 8 | 9 | "ccLoad/internal/model" 10 | "ccLoad/internal/storage" 11 | ) 12 | 13 | func TestGetChannelSuccessRates_IgnoresClientNoise(t *testing.T) { 14 | ctx := context.Background() 15 | tmpDir := t.TempDir() 16 | 17 | dbPath := filepath.Join(tmpDir, "success_rate.db") 18 | store, err := storage.CreateSQLiteStore(dbPath, nil) 19 | if err != nil { 20 | t.Fatalf("failed to create sqlite store: %v", err) 21 | } 22 | defer store.Close() 23 | 24 | cfg := &model.Config{ 25 | Name: "test-channel", 26 | URL: "https://example.com", 27 | Priority: 10, 28 | Models: []string{"model-a"}, 29 | Enabled: true, 30 | } 31 | created, err := store.CreateConfig(ctx, cfg) 32 | if err != nil { 33 | t.Fatalf("failed to create config: %v", err) 34 | } 35 | 36 | now := time.Now() 37 | logs := []*model.LogEntry{ 38 | {Time: model.JSONTime{Time: now.Add(-10 * time.Second)}, ChannelID: created.ID, StatusCode: 200, Message: "ok"}, 39 | {Time: model.JSONTime{Time: now.Add(-9 * time.Second)}, ChannelID: created.ID, StatusCode: 204, Message: "ok"}, 40 | {Time: model.JSONTime{Time: now.Add(-8 * time.Second)}, ChannelID: created.ID, StatusCode: 502, 
Message: "bad gateway"}, 41 | {Time: model.JSONTime{Time: now.Add(-7 * time.Second)}, ChannelID: created.ID, StatusCode: 597, Message: "sse error"}, 42 | {Time: model.JSONTime{Time: now.Add(-6 * time.Second)}, ChannelID: created.ID, StatusCode: 404, Message: "client not found"}, // 应被忽略 43 | {Time: model.JSONTime{Time: now.Add(-5 * time.Second)}, ChannelID: created.ID, StatusCode: 499, Message: "client canceled"}, // 应被忽略 44 | } 45 | for _, e := range logs { 46 | if err := store.AddLog(ctx, e); err != nil { 47 | t.Fatalf("failed to add log: %v", err) 48 | } 49 | } 50 | 51 | rates, err := store.GetChannelSuccessRates(ctx, now.Add(-time.Minute)) 52 | if err != nil { 53 | t.Fatalf("GetChannelSuccessRates error: %v", err) 54 | } 55 | 56 | // eligible: 200/204/502/597 -> 2 successes / 4 total = 0.5 57 | got, ok := rates[created.ID] 58 | if !ok { 59 | t.Fatalf("expected channel %d in rates", created.ID) 60 | } 61 | if got < 0.49 || got > 0.51 { 62 | t.Fatalf("expected success rate ~0.5, got %v", got) 63 | } 64 | } 65 | 66 | func TestGetChannelSuccessRates_NoEligibleResults(t *testing.T) { 67 | ctx := context.Background() 68 | tmpDir := t.TempDir() 69 | 70 | dbPath := filepath.Join(tmpDir, "success_rate_empty.db") 71 | store, err := storage.CreateSQLiteStore(dbPath, nil) 72 | if err != nil { 73 | t.Fatalf("failed to create sqlite store: %v", err) 74 | } 75 | defer store.Close() 76 | 77 | cfg := &model.Config{ 78 | Name: "test-channel", 79 | URL: "https://example.com", 80 | Priority: 10, 81 | Models: []string{"model-a"}, 82 | Enabled: true, 83 | } 84 | created, err := store.CreateConfig(ctx, cfg) 85 | if err != nil { 86 | t.Fatalf("failed to create config: %v", err) 87 | } 88 | 89 | now := time.Now() 90 | // 全部是应被忽略的客户端噪声 91 | logs := []*model.LogEntry{ 92 | {Time: model.JSONTime{Time: now.Add(-10 * time.Second)}, ChannelID: created.ID, StatusCode: 404, Message: "not found"}, 93 | {Time: model.JSONTime{Time: now.Add(-9 * time.Second)}, ChannelID: created.ID, StatusCode: 
415, Message: "unsupported"}, 94 | {Time: model.JSONTime{Time: now.Add(-8 * time.Second)}, ChannelID: created.ID, StatusCode: 499, Message: "client canceled"}, 95 | } 96 | for _, e := range logs { 97 | if err := store.AddLog(ctx, e); err != nil { 98 | t.Fatalf("failed to add log: %v", err) 99 | } 100 | } 101 | 102 | rates, err := store.GetChannelSuccessRates(ctx, now.Add(-time.Minute)) 103 | if err != nil { 104 | t.Fatalf("GetChannelSuccessRates error: %v", err) 105 | } 106 | if _, ok := rates[created.ID]; ok { 107 | t.Fatalf("expected no rate for channel %d when only client noise exists", created.ID) 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /web/assets/js/channels-import-export.js: -------------------------------------------------------------------------------- 1 | function setupImportExport() { 2 | const exportBtn = document.getElementById('exportCsvBtn'); 3 | const importBtn = document.getElementById('importCsvBtn'); 4 | const importInput = document.getElementById('importCsvInput'); 5 | 6 | if (exportBtn) { 7 | exportBtn.addEventListener('click', () => exportChannelsCSV(exportBtn)); 8 | } 9 | 10 | if (importBtn && importInput) { 11 | importBtn.addEventListener('click', () => { 12 | if (window.pauseBackgroundAnimation) window.pauseBackgroundAnimation(); 13 | importInput.click(); 14 | }); 15 | 16 | importInput.addEventListener('change', (event) => { 17 | if (window.resumeBackgroundAnimation) window.resumeBackgroundAnimation(); 18 | handleImportCSV(event, importBtn); 19 | }); 20 | 21 | importInput.addEventListener('cancel', () => { 22 | if (window.resumeBackgroundAnimation) window.resumeBackgroundAnimation(); 23 | }); 24 | } 25 | } 26 | 27 | async function exportChannelsCSV(buttonEl) { 28 | try { 29 | if (buttonEl) buttonEl.disabled = true; 30 | const res = await fetchWithAuth('/admin/channels/export'); 31 | if (!res.ok) { 32 | const errorText = await res.text(); 33 | throw new Error(errorText || `导出失败 
(HTTP ${res.status})`); 34 | } 35 | 36 | const blob = await res.blob(); 37 | const url = URL.createObjectURL(blob); 38 | const link = document.createElement('a'); 39 | link.href = url; 40 | link.download = `channels-${formatTimestampForFilename()}.csv`; 41 | document.body.appendChild(link); 42 | link.click(); 43 | document.body.removeChild(link); 44 | URL.revokeObjectURL(url); 45 | 46 | if (window.showSuccess) window.showSuccess('导出成功'); 47 | } catch (err) { 48 | console.error('导出CSV失败', err); 49 | if (window.showError) window.showError(err.message || '导出失败'); 50 | } finally { 51 | if (buttonEl) buttonEl.disabled = false; 52 | } 53 | } 54 | 55 | async function handleImportCSV(event, importBtn) { 56 | const input = event.target; 57 | if (!input.files || input.files.length === 0) { 58 | return; 59 | } 60 | 61 | const file = input.files[0]; 62 | const formData = new FormData(); 63 | formData.append('file', file); 64 | 65 | if (importBtn) importBtn.disabled = true; 66 | 67 | try { 68 | const resp = await fetchAPIWithAuth('/admin/channels/import', { 69 | method: 'POST', 70 | body: formData 71 | }); 72 | 73 | const summary = resp.data; 74 | if (!resp.success) { 75 | throw new Error(resp.error || '导入失败'); 76 | } 77 | if (summary) { 78 | let msg = `导入完成:新增 ${summary.created || 0},更新 ${summary.updated || 0},跳过 ${summary.skipped || 0}`; 79 | 80 | if (summary.redis_sync_enabled) { 81 | if (summary.redis_sync_success) { 82 | msg += `,已同步 ${summary.redis_synced_channels || 0} 个渠道到Redis`; 83 | } else { 84 | msg += ',Redis同步失败'; 85 | } 86 | } 87 | 88 | if (window.showSuccess) window.showSuccess(msg); 89 | 90 | if (summary.errors && summary.errors.length) { 91 | const preview = summary.errors.slice(0, 3).join(';'); 92 | const extra = summary.errors.length > 3 ? 
` 等${summary.errors.length}条记录` : ''; 93 | if (window.showError) window.showError(`部分记录导入失败:${preview}${extra}`); 94 | } 95 | 96 | if (summary.redis_sync_enabled && !summary.redis_sync_success && summary.redis_sync_error) { 97 | if (window.showError) window.showError(`Redis同步失败:${summary.redis_sync_error}`); 98 | } 99 | } else if (window.showSuccess) { 100 | window.showSuccess('导入完成'); 101 | } 102 | 103 | clearChannelsCache(); 104 | await loadChannels(filters.channelType); 105 | } catch (err) { 106 | console.error('导入CSV失败', err); 107 | if (window.showError) window.showError(err.message || '导入失败'); 108 | } finally { 109 | if (importBtn) importBtn.disabled = false; 110 | input.value = ''; 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /web/assets/js/channels-data.js: -------------------------------------------------------------------------------- 1 | async function loadChannels(type = 'all') { 2 | try { 3 | if (channelsCache[type]) { 4 | channels = channelsCache[type]; 5 | updateModelOptions(); 6 | filterChannels(); 7 | return; 8 | } 9 | 10 | const url = type === 'all' ? 
'/admin/channels' : `/admin/channels?type=${encodeURIComponent(type)}`; 11 | const data = await fetchDataWithAuth(url); 12 | 13 | channelsCache[type] = data || []; 14 | channels = channelsCache[type]; 15 | 16 | updateModelOptions(); 17 | filterChannels(); 18 | } catch (e) { 19 | console.error('加载渠道失败', e); 20 | if (window.showError) window.showError('加载渠道失败'); 21 | } 22 | } 23 | 24 | async function loadChannelStatsRange() { 25 | try { 26 | const setting = await fetchDataWithAuth('/admin/settings/channel_stats_range'); 27 | if (setting && setting.value) { 28 | channelStatsRange = setting.value; 29 | } 30 | } catch (e) { 31 | console.error('加载统计范围设置失败', e); 32 | } 33 | } 34 | 35 | async function loadChannelStats(range = channelStatsRange) { 36 | try { 37 | const params = new URLSearchParams({ range, limit: '500', offset: '0' }); 38 | const data = await fetchDataWithAuth(`/admin/stats?${params.toString()}`); 39 | channelStatsById = aggregateChannelStats((data && data.stats) || []); 40 | filterChannels(); 41 | } catch (err) { 42 | console.error('加载渠道统计数据失败', err); 43 | } 44 | } 45 | 46 | function aggregateChannelStats(statsEntries = []) { 47 | const result = {}; 48 | 49 | for (const entry of statsEntries) { 50 | const channelId = Number(entry.channel_id || entry.channelID); 51 | if (!Number.isFinite(channelId) || channelId <= 0) continue; 52 | 53 | if (!result[channelId]) { 54 | result[channelId] = { 55 | success: 0, 56 | error: 0, 57 | total: 0, 58 | totalInputTokens: 0, 59 | totalOutputTokens: 0, 60 | totalCacheReadInputTokens: 0, 61 | totalCacheCreationInputTokens: 0, 62 | totalCost: 0, 63 | _firstByteWeightedSum: 0, 64 | _firstByteWeight: 0 65 | }; 66 | } 67 | 68 | const stats = result[channelId]; 69 | const success = toSafeNumber(entry.success); 70 | const error = toSafeNumber(entry.error); 71 | const total = toSafeNumber(entry.total); 72 | 73 | stats.success += success; 74 | stats.error += error; 75 | stats.total += total; 76 | 77 | const avgFirstByte = 
Number(entry.avg_first_byte_time_seconds); 78 | const weight = success || total || 0; 79 | if (Number.isFinite(avgFirstByte) && avgFirstByte > 0 && weight > 0) { 80 | stats._firstByteWeightedSum += avgFirstByte * weight; 81 | stats._firstByteWeight += weight; 82 | } 83 | 84 | stats.totalInputTokens += toSafeNumber(entry.total_input_tokens); 85 | stats.totalOutputTokens += toSafeNumber(entry.total_output_tokens); 86 | stats.totalCacheReadInputTokens += toSafeNumber(entry.total_cache_read_input_tokens); 87 | stats.totalCacheCreationInputTokens += toSafeNumber(entry.total_cache_creation_input_tokens); 88 | stats.totalCost += toSafeNumber(entry.total_cost); 89 | } 90 | 91 | for (const id of Object.keys(result)) { 92 | const stats = result[id]; 93 | if (stats._firstByteWeight > 0) { 94 | stats.avgFirstByteTimeSeconds = stats._firstByteWeightedSum / stats._firstByteWeight; 95 | } 96 | delete stats._firstByteWeightedSum; 97 | delete stats._firstByteWeight; 98 | } 99 | 100 | return result; 101 | } 102 | 103 | function toSafeNumber(value) { 104 | const num = Number(value); 105 | return Number.isFinite(num) ? 
num : 0; 106 | } 107 | 108 | // 加载默认测试内容(从系统设置) 109 | async function loadDefaultTestContent() { 110 | try { 111 | const setting = await fetchDataWithAuth('/admin/settings/channel_test_content'); 112 | if (setting && setting.value) { 113 | defaultTestContent = setting.value; 114 | } 115 | } catch (e) { 116 | console.warn('加载默认测试内容失败,使用内置默认值', e); 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /internal/model/stats.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import "time" 4 | 5 | // MetricPoint 指标数据点(用于趋势图) 6 | type MetricPoint struct { 7 | Ts time.Time `json:"ts"` 8 | Success int `json:"success"` 9 | Error int `json:"error"` 10 | AvgFirstByteTimeSeconds *float64 `json:"avg_first_byte_time_seconds,omitempty"` // 平均首字响应时间(秒) 11 | AvgDurationSeconds *float64 `json:"avg_duration_seconds,omitempty"` // 平均总耗时(秒) 12 | TotalCost *float64 `json:"total_cost,omitempty"` // 总费用(美元) 13 | FirstByteSampleCount int `json:"first_byte_count,omitempty"` // 首字响应样本数(流式成功且有首字时间) 14 | DurationSampleCount int `json:"duration_count,omitempty"` // 总耗时样本数(成功且有耗时) 15 | InputTokens int64 `json:"input_tokens,omitempty"` // 输入Token 16 | OutputTokens int64 `json:"output_tokens,omitempty"` // 输出Token 17 | CacheReadTokens int64 `json:"cache_read_tokens,omitempty"` // 缓存读取Token 18 | CacheCreationTokens int64 `json:"cache_creation_tokens,omitempty"` // 缓存创建Token 19 | Channels map[string]ChannelMetric `json:"channels,omitempty"` 20 | } 21 | 22 | // ChannelMetric 单个渠道的指标 23 | type ChannelMetric struct { 24 | Success int `json:"success"` 25 | Error int `json:"error"` 26 | AvgFirstByteTimeSeconds *float64 `json:"avg_first_byte_time_seconds,omitempty"` // 平均首字响应时间(秒) 27 | AvgDurationSeconds *float64 `json:"avg_duration_seconds,omitempty"` // 平均总耗时(秒) 28 | TotalCost *float64 `json:"total_cost,omitempty"` // 总费用(美元) 29 | InputTokens int64 `json:"input_tokens,omitempty"` // 输入Token 30 | 
OutputTokens int64 `json:"output_tokens,omitempty"` // 输出Token 31 | CacheReadTokens int64 `json:"cache_read_tokens,omitempty"` // 缓存读取Token 32 | CacheCreationTokens int64 `json:"cache_creation_tokens,omitempty"` // 缓存创建Token 33 | } 34 | 35 | // StatsEntry 统计数据条目 36 | type StatsEntry struct { 37 | ChannelID *int `json:"channel_id,omitempty"` 38 | ChannelName string `json:"channel_name"` 39 | ChannelPriority *int `json:"channel_priority,omitempty"` // 渠道优先级(用于前端排序) 40 | Model string `json:"model"` 41 | Success int `json:"success"` 42 | Error int `json:"error"` 43 | Total int `json:"total"` 44 | AvgFirstByteTimeSeconds *float64 `json:"avg_first_byte_time_seconds,omitempty"` // 流式请求平均首字响应时间(秒) 45 | AvgDurationSeconds *float64 `json:"avg_duration_seconds,omitempty"` // 平均总耗时(秒) 46 | 47 | // RPM/QPS统计(基于分钟级数据) 48 | PeakRPM *float64 `json:"peak_rpm,omitempty"` // 峰值RPM(该渠道+模型的最大每分钟请求数) 49 | AvgRPM *float64 `json:"avg_rpm,omitempty"` // 平均RPM 50 | RecentRPM *float64 `json:"recent_rpm,omitempty"` // 最近一分钟RPM(仅本日有效) 51 | 52 | // Token统计(2025-11新增) 53 | TotalInputTokens *int64 `json:"total_input_tokens,omitempty"` // 总输入Token 54 | TotalOutputTokens *int64 `json:"total_output_tokens,omitempty"` // 总输出Token 55 | TotalCacheReadInputTokens *int64 `json:"total_cache_read_input_tokens,omitempty"` // 总缓存读取Token 56 | TotalCacheCreationInputTokens *int64 `json:"total_cache_creation_input_tokens,omitempty"` // 总缓存创建Token 57 | TotalCost *float64 `json:"total_cost,omitempty"` // 总成本(美元) 58 | } 59 | 60 | // RPMStats 包含RPM/QPS相关的统计数据 61 | type RPMStats struct { 62 | PeakRPM float64 `json:"peak_rpm"` // 峰值RPM(每分钟最大请求数) 63 | PeakQPS float64 `json:"peak_qps"` // 峰值QPS(每秒最大请求数) 64 | AvgRPM float64 `json:"avg_rpm"` // 平均RPM 65 | AvgQPS float64 `json:"avg_qps"` // 平均QPS 66 | RecentRPM float64 `json:"recent_rpm"` // 最近一分钟RPM(仅本日有效) 67 | RecentQPS float64 `json:"recent_qps"` // 最近一分钟QPS(仅本日有效) 68 | } 69 | 70 | -------------------------------------------------------------------------------- 
/web/assets/js/login.js: -------------------------------------------------------------------------------- 1 | (function() { 2 | const form = document.getElementById('login-form'); 3 | const errorMessage = document.getElementById('error-message'); 4 | const errorText = document.getElementById('error-text'); 5 | const loginButton = document.getElementById('login-button'); 6 | const passwordInput = document.getElementById('password'); 7 | 8 | function showError(message) { 9 | if (window.showError) try { window.showError(message); } catch (_) {} 10 | errorText.textContent = message; 11 | errorMessage.style.display = 'flex'; 12 | 13 | // 添加摇晃动画 14 | errorMessage.style.animation = 'none'; 15 | errorMessage.offsetHeight; // 触发重绘 16 | errorMessage.style.animation = 'slideInUp 0.3s ease-out'; 17 | } 18 | 19 | function hideError() { 20 | errorMessage.style.display = 'none'; 21 | } 22 | 23 | function setLoading(loading) { 24 | if (loading) { 25 | loginButton.classList.add('loading'); 26 | loginButton.disabled = true; 27 | passwordInput.disabled = true; 28 | } else { 29 | loginButton.classList.remove('loading'); 30 | loginButton.disabled = false; 31 | passwordInput.disabled = false; 32 | } 33 | } 34 | 35 | // 表单提交处理 36 | form.addEventListener('submit', async (e) => { 37 | e.preventDefault(); 38 | hideError(); 39 | setLoading(true); 40 | 41 | const password = passwordInput.value; 42 | 43 | try { 44 | const resp = await fetchAPI('/login', { 45 | method: 'POST', 46 | headers: { 47 | 'Content-Type': 'application/json', 48 | }, 49 | body: JSON.stringify({ password }), 50 | }); 51 | 52 | if (resp.success) { 53 | const data = resp.data || {}; 54 | 55 | // 存储Token到localStorage 56 | localStorage.setItem('ccload_token', data.token); 57 | localStorage.setItem('ccload_token_expiry', Date.now() + data.expiresIn * 1000); 58 | 59 | // 登录成功,添加成功动画 60 | loginButton.style.background = 'linear-gradient(135deg, var(--success-500), var(--success-600))'; 61 | 62 | setTimeout(() => { 63 | const 
urlParams = new URLSearchParams(window.location.search); 64 | const redirect = urlParams.get('redirect') || '/web/index.html'; 65 | window.location.href = redirect; 66 | }, 500); 67 | } else { 68 | showError(resp.error || '密码错误,请重试'); 69 | 70 | // 添加输入框摇晃动画 71 | passwordInput.style.animation = 'none'; 72 | passwordInput.offsetHeight; 73 | passwordInput.style.animation = 'shake 0.5s ease-in-out'; 74 | 75 | setTimeout(() => { 76 | passwordInput.style.animation = ''; 77 | }, 500); 78 | } 79 | } catch (error) { 80 | console.error('Login error:', error); 81 | showError('网络连接错误,请检查网络后重试'); 82 | } finally { 83 | setLoading(false); 84 | } 85 | }); 86 | 87 | // 输入框焦点处理 88 | passwordInput.addEventListener('focus', hideError); 89 | 90 | // 键盘快捷键 91 | document.addEventListener('keydown', (e) => { 92 | if (e.key === 'Escape') { 93 | hideError(); 94 | } 95 | }); 96 | 97 | // 检查URL参数中的错误信息 98 | const urlParams = new URLSearchParams(window.location.search); 99 | const errorParam = urlParams.get('error'); 100 | if (errorParam) { 101 | showError(decodeURIComponent(errorParam)); 102 | } 103 | 104 | // 页面加载完成后的初始化 105 | document.addEventListener('DOMContentLoaded', function() { 106 | // 聚焦到密码输入框 107 | setTimeout(() => { 108 | passwordInput.focus(); 109 | }, 500); 110 | 111 | // 添加输入框摇晃动画关键帧 112 | const style = document.createElement('style'); 113 | style.textContent = ` 114 | @keyframes shake { 115 | 0%, 100% { transform: translateX(0); } 116 | 10%, 30%, 50%, 70%, 90% { transform: translateX(-8px); } 117 | 20%, 40%, 60%, 80% { transform: translateX(8px); } 118 | } 119 | `; 120 | document.head.appendChild(style); 121 | }); 122 | })(); 123 | -------------------------------------------------------------------------------- /internal/app/proxy_util_test.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | ) 8 | 9 | func TestBuildLogEntry_StreamDiagMsg(t *testing.T) { 10 
| channelID := int64(1) 11 | 12 | t.Run("正常成功响应", func(t *testing.T) { 13 | res := &fwResult{ 14 | Status: 200, 15 | InputTokens: 10, 16 | OutputTokens: 20, 17 | } 18 | entry := buildLogEntry("claude-3", channelID, 200, 1.5, true, "sk-test", 0, "", res, "") 19 | if entry.Message != "ok" { 20 | t.Errorf("expected Message='ok', got %q", entry.Message) 21 | } 22 | }) 23 | 24 | t.Run("流传输中断诊断", func(t *testing.T) { 25 | res := &fwResult{ 26 | Status: 200, 27 | StreamDiagMsg: "[WARN] 流传输中断: 错误=unexpected EOF | 已读取=1024字节(分5次)", 28 | } 29 | entry := buildLogEntry("claude-3", channelID, 200, 1.5, true, "sk-test", 0, "", res, "") 30 | if entry.Message != res.StreamDiagMsg { 31 | t.Errorf("expected Message=%q, got %q", res.StreamDiagMsg, entry.Message) 32 | } 33 | }) 34 | 35 | t.Run("流响应不完整诊断", func(t *testing.T) { 36 | res := &fwResult{ 37 | Status: 200, 38 | StreamDiagMsg: "[WARN] 流响应不完整: 正常EOF但无usage | 已读取=512字节(分3次)", 39 | } 40 | entry := buildLogEntry("claude-3", channelID, 200, 1.5, true, "sk-test", 0, "", res, "") 41 | if entry.Message != res.StreamDiagMsg { 42 | t.Errorf("expected Message=%q, got %q", res.StreamDiagMsg, entry.Message) 43 | } 44 | }) 45 | 46 | t.Run("errMsg优先于StreamDiagMsg", func(t *testing.T) { 47 | res := &fwResult{ 48 | Status: 200, 49 | StreamDiagMsg: "[WARN] 流传输中断", 50 | } 51 | errMsg := "network error" 52 | entry := buildLogEntry("claude-3", channelID, 200, 1.5, true, "sk-test", 0, "", res, errMsg) 53 | if entry.Message != errMsg { 54 | t.Errorf("expected Message=%q, got %q", errMsg, entry.Message) 55 | } 56 | }) 57 | } 58 | 59 | func TestCopyRequestHeaders_StripsHopByHopAndAuth(t *testing.T) { 60 | req, err := http.NewRequest(http.MethodGet, "https://example.com", nil) 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | 65 | src := http.Header{} 66 | src.Set("Connection", "Upgrade, X-Hop") 67 | src.Set("Upgrade", "websocket") 68 | src.Set("X-Hop", "1") 69 | src.Set("Keep-Alive", "timeout=5") 70 | src.Set("TE", "trailers") 71 | 
src.Set("Trailer", "X-Trailer") 72 | src.Set("Proxy-Authorization", "secret") 73 | src.Set("Authorization", "Bearer client-token") 74 | src.Set("X-API-Key", "client-token2") 75 | src.Set("x-goog-api-key", "client-goog") 76 | src.Set("Accept-Encoding", "br") 77 | src.Set("X-Pass", "ok") 78 | 79 | copyRequestHeaders(req, src) 80 | 81 | if got := req.Header.Get("X-Pass"); got != "ok" { 82 | t.Fatalf("expected X-Pass=ok, got %q", got) 83 | } 84 | if got := req.Header.Get("Accept"); got != "application/json" { 85 | t.Fatalf("expected default Accept=application/json, got %q", got) 86 | } 87 | 88 | for _, k := range []string{ 89 | "Connection", 90 | "Upgrade", 91 | "X-Hop", 92 | "Keep-Alive", 93 | "TE", 94 | "Trailer", 95 | "Proxy-Authorization", 96 | "Authorization", 97 | "X-API-Key", 98 | "x-goog-api-key", 99 | "Accept-Encoding", 100 | } { 101 | if v := req.Header.Get(k); v != "" { 102 | t.Fatalf("expected header %q stripped, got %q", k, v) 103 | } 104 | } 105 | } 106 | 107 | func TestFilterAndWriteResponseHeaders_StripsHopByHop(t *testing.T) { 108 | w := httptest.NewRecorder() 109 | 110 | hdr := http.Header{} 111 | hdr.Set("Connection", "Upgrade, X-Hop") 112 | hdr.Set("Upgrade", "websocket") 113 | hdr.Set("X-Hop", "1") 114 | hdr.Set("Transfer-Encoding", "chunked") 115 | hdr.Set("Trailer", "X-Trailer") 116 | hdr.Set("Content-Length", "123") 117 | hdr.Set("Content-Encoding", "br") 118 | hdr.Set("X-Pass", "ok") 119 | 120 | filterAndWriteResponseHeaders(w, hdr) 121 | 122 | if got := w.Header().Get("X-Pass"); got != "ok" { 123 | t.Fatalf("expected X-Pass=ok, got %q", got) 124 | } 125 | if got := w.Header().Get("Content-Encoding"); got != "br" { 126 | t.Fatalf("expected Content-Encoding=br, got %q", got) 127 | } 128 | 129 | for _, k := range []string{ 130 | "Connection", 131 | "Upgrade", 132 | "X-Hop", 133 | "Transfer-Encoding", 134 | "Trailer", 135 | "Content-Length", 136 | } { 137 | if v := w.Header().Get(k); v != "" { 138 | t.Fatalf("expected header %q stripped, got %q", 
k, v) 139 | } 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /internal/app/config_service.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "strconv" 8 | "time" 9 | 10 | "ccLoad/internal/model" 11 | "ccLoad/internal/storage" 12 | ) 13 | 14 | // ConfigService 配置管理服务 15 | // 职责: 启动时从数据库加载配置,提供只读访问 16 | // 配置修改后程序会自动重启,无需热重载 17 | type ConfigService struct { 18 | store storage.Store 19 | cache map[string]*model.SystemSetting // 启动时加载,运行期间只读 20 | loaded bool 21 | } 22 | 23 | // NewConfigService 创建配置服务 24 | func NewConfigService(store storage.Store) *ConfigService { 25 | return &ConfigService{ 26 | store: store, 27 | cache: make(map[string]*model.SystemSetting), 28 | } 29 | } 30 | 31 | // LoadDefaults 启动时从数据库加载配置到内存(只调用一次) 32 | func (cs *ConfigService) LoadDefaults(ctx context.Context) error { 33 | if cs.loaded { 34 | return nil 35 | } 36 | 37 | settings, err := cs.store.ListAllSettings(ctx) 38 | if err != nil { 39 | return fmt.Errorf("load settings from db: %w", err) 40 | } 41 | 42 | for _, s := range settings { 43 | cs.cache[s.Key] = s 44 | } 45 | cs.loaded = true 46 | 47 | log.Printf("[INFO] ConfigService loaded %d settings", len(settings)) 48 | return nil 49 | } 50 | 51 | // GetInt 获取整数配置 52 | func (cs *ConfigService) GetInt(key string, defaultValue int) int { 53 | if setting, ok := cs.cache[key]; ok { 54 | if intVal, err := strconv.Atoi(setting.Value); err == nil { 55 | return intVal 56 | } 57 | } 58 | return defaultValue 59 | } 60 | 61 | // GetBool 获取布尔配置 62 | func (cs *ConfigService) GetBool(key string, defaultValue bool) bool { 63 | if setting, ok := cs.cache[key]; ok { 64 | return setting.Value == "true" || setting.Value == "1" 65 | } 66 | return defaultValue 67 | } 68 | 69 | // GetString 获取字符串配置 70 | func (cs *ConfigService) GetString(key string, defaultValue string) string { 71 | if setting, ok := 
cs.cache[key]; ok { 72 | return setting.Value 73 | } 74 | return defaultValue 75 | } 76 | 77 | // GetFloat 获取浮点数配置 78 | func (cs *ConfigService) GetFloat(key string, defaultValue float64) float64 { 79 | if setting, ok := cs.cache[key]; ok { 80 | if floatVal, err := strconv.ParseFloat(setting.Value, 64); err == nil { 81 | return floatVal 82 | } 83 | } 84 | return defaultValue 85 | } 86 | 87 | // GetDuration 获取时长配置(秒转Duration) 88 | func (cs *ConfigService) GetDuration(key string, defaultValue time.Duration) time.Duration { 89 | seconds := cs.GetInt(key, int(defaultValue.Seconds())) 90 | return time.Duration(seconds) * time.Second 91 | } 92 | 93 | // GetIntMin 获取整数配置(带最小值约束) 94 | // 如果值小于 min,记录警告并返回 defaultValue 95 | func (cs *ConfigService) GetIntMin(key string, defaultValue, min int) int { 96 | val := cs.GetInt(key, defaultValue) 97 | if val < min { 98 | log.Printf("[WARN] 无效的 %s=%d(必须 >= %d),已使用默认值 %d", key, val, min, defaultValue) 99 | return defaultValue 100 | } 101 | return val 102 | } 103 | 104 | // GetDurationNonNegative 获取非负时长配置 105 | // 如果值为负,记录警告并返回 0(禁用) 106 | func (cs *ConfigService) GetDurationNonNegative(key string, defaultValue time.Duration) time.Duration { 107 | val := cs.GetDuration(key, defaultValue) 108 | if val < 0 { 109 | log.Printf("[WARN] 无效的 %s=%v(必须 >= 0),已设为 0(禁用)", key, val) 110 | return 0 111 | } 112 | return val 113 | } 114 | 115 | // GetDurationPositive 获取正时长配置 116 | // 如果值 <= 0,记录警告并返回 defaultValue 117 | func (cs *ConfigService) GetDurationPositive(key string, defaultValue time.Duration) time.Duration { 118 | val := cs.GetDuration(key, defaultValue) 119 | if val <= 0 { 120 | log.Printf("[WARN] 无效的 %s=%v(必须 > 0),已使用默认值 %v", key, val, defaultValue) 121 | return defaultValue 122 | } 123 | return val 124 | } 125 | 126 | // GetSetting 获取完整配置对象(用于验证等场景) 127 | func (cs *ConfigService) GetSetting(key string) *model.SystemSetting { 128 | return cs.cache[key] 129 | } 130 | 131 | // UpdateSetting 更新配置(仅写数据库,不更新缓存,因为会重启) 132 | func (cs 
*ConfigService) UpdateSetting(ctx context.Context, key, value string) error { 133 | return cs.store.UpdateSetting(ctx, key, value) 134 | } 135 | 136 | // ListAllSettings 获取所有配置(用于前端展示) 137 | func (cs *ConfigService) ListAllSettings(ctx context.Context) ([]*model.SystemSetting, error) { 138 | return cs.store.ListAllSettings(ctx) 139 | } 140 | 141 | // BatchUpdateSettings 批量更新配置(仅写数据库,不更新缓存,因为会重启) 142 | func (cs *ConfigService) BatchUpdateSettings(ctx context.Context, updates map[string]string) error { 143 | return cs.store.BatchUpdateSettings(ctx, updates) 144 | } 145 | -------------------------------------------------------------------------------- /internal/util/rate_limiter.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "log" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | // LoginRateLimiter 登录速率限制器(防暴力破解) 10 | // 设计原则: 11 | // - 基于IP地址限制:防止单个IP暴力破解 12 | // - 指数退避:失败次数越多,锁定时间越长 13 | // - 自动清理:1小时后重置计数器 14 | // 支持优雅关闭 15 | type LoginRateLimiter struct { 16 | attempts map[string]*attemptRecord // IP -> 尝试记录 17 | mu sync.RWMutex 18 | 19 | // 配置参数 20 | maxAttempts int // 最大尝试次数(默认5次) 21 | lockoutDuration time.Duration // 锁定时长(默认15分钟) 22 | resetInterval time.Duration // 计数重置间隔(默认1小时) 23 | 24 | // 优雅关闭机制 25 | stopCh chan struct{} // 关闭信号 26 | } 27 | 28 | // attemptRecord 尝试记录 29 | type attemptRecord struct { 30 | count int // 失败次数 31 | lastAttempt time.Time // 最后尝试时间 32 | lockUntil time.Time // 锁定截止时间 33 | } 34 | 35 | // NewLoginRateLimiter 创建登录速率限制器 36 | func NewLoginRateLimiter() *LoginRateLimiter { 37 | limiter := &LoginRateLimiter{ 38 | attempts: make(map[string]*attemptRecord), 39 | maxAttempts: 5, // 最大5次尝试 40 | lockoutDuration: 15 * time.Minute, // 锁定15分钟 41 | resetInterval: 1 * time.Hour, // 1小时后重置 42 | stopCh: make(chan struct{}), // 初始化关闭信号 43 | } 44 | 45 | // 启动后台清理协程(每小时清理过期记录) 46 | // 支持优雅关闭 47 | go limiter.cleanupLoop() 48 | 49 | return limiter 50 | } 51 | 52 | // AllowAttempt 检查是否允许尝试登录 53 
| // Returns true if the attempt is allowed, false if it is rejected (IP locked out). Each call made while not locked out increments the counter. 54 | func (rl *LoginRateLimiter) AllowAttempt(ip string) bool { 55 | rl.mu.Lock() 56 | defer rl.mu.Unlock() 57 | 58 | now := time.Now() 59 | record, exists := rl.attempts[ip] 60 | 61 | // First attempt from this IP: create its record 62 | if !exists { 63 | rl.attempts[ip] = &attemptRecord{ 64 | count: 1, 65 | lastAttempt: now, 66 | } 67 | return true 68 | } 69 | 70 | // Reject while a lockout is in effect (rejected calls do not update lastAttempt) 71 | if now.Before(record.lockUntil) { 72 | return false 73 | } 74 | 75 | // Start a fresh count if the previous counted attempt is older than resetInterval 76 | if now.Sub(record.lastAttempt) > rl.resetInterval { 77 | record.count = 0 78 | } 79 | 80 | // Count this attempt 81 | record.count++ 82 | record.lastAttempt = now 83 | 84 | // Over the limit: lock out. NOTE(review): the count is not reset when a lockout expires, so unless resetInterval has elapsed since the last counted attempt, the first attempt after expiry re-locks immediately - confirm this is intended 85 | if record.count > rl.maxAttempts { 86 | record.lockUntil = now.Add(rl.lockoutDuration) 87 | return false 88 | } 89 | 90 | return true 91 | } 92 | 93 | // RecordSuccess records a successful login by deleting the IP's attempt record. 94 | func (rl *LoginRateLimiter) RecordSuccess(ip string) { 95 | rl.mu.Lock() 96 | defer rl.mu.Unlock() 97 | 98 | // Clear this IP's attempt history after a successful login 99 | delete(rl.attempts, ip) 100 | } 101 | 102 | // GetLockoutTime returns the remaining lockout time for ip, in seconds. 103 | // Returns 0 when not locked; otherwise the remaining time rounded up, so an active lockout always reports at least 1. 104 | func (rl *LoginRateLimiter) GetLockoutTime(ip string) int { 105 | rl.mu.RLock() 106 | defer rl.mu.RUnlock() 107 | 108 | record, exists := rl.attempts[ip] 109 | if !exists { 110 | return 0 111 | } 112 | 113 | now := time.Now() 114 | if now.Before(record.lockUntil) { 115 | return int((record.lockUntil.Sub(now) + time.Second - 1) / time.Second) // round up: truncation would report a sub-second remainder as 0 ("not locked") while AllowAttempt still rejects 116 | } 117 | 118 | return 0 119 | } 120 | 121 | // GetAttemptCount returns the current attempt count for ip (0 if unknown or older than resetInterval). 122 | func (rl *LoginRateLimiter) GetAttemptCount(ip string) int { 123 | rl.mu.RLock() 124 | defer rl.mu.RUnlock() 125 | 126 | record, exists := rl.attempts[ip] 127 | if !exists { 128 | return 0 129 | } 130 | 131 | // Treat counts older than resetInterval as expired 132 | if time.Since(record.lastAttempt) > rl.resetInterval { 133 | return 0 134 | } 135 | 136 | return record.count 137 | } 138 | 139 | // cleanupLoop periodically purges expired records (background goroutine). 140 | // It keeps running until stopCh is closed. 141 | func (rl *LoginRateLimiter) cleanupLoop() { 142 | ticker := time.NewTicker(1 * time.Hour) 143 | defer ticker.Stop()
144 | 145 | for { 146 | select { 147 | case <-ticker.C: 148 | rl.cleanup() 149 | case <-rl.stopCh: 150 | // 收到关闭信号,执行最后一次清理后退出 151 | rl.cleanup() 152 | return 153 | } 154 | } 155 | } 156 | 157 | // cleanup 清理过期记录 158 | func (rl *LoginRateLimiter) cleanup() { 159 | rl.mu.Lock() 160 | defer rl.mu.Unlock() 161 | 162 | now := time.Now() 163 | toDelete := make([]string, 0) 164 | 165 | for ip, record := range rl.attempts { 166 | // 清理条件: 167 | // 1. 超过重置间隔且未被锁定 168 | // 2. 锁定已过期且超过重置间隔 169 | if now.Sub(record.lastAttempt) > rl.resetInterval && now.After(record.lockUntil) { 170 | toDelete = append(toDelete, ip) 171 | } 172 | } 173 | 174 | for _, ip := range toDelete { 175 | delete(rl.attempts, ip) 176 | } 177 | 178 | if len(toDelete) > 0 { 179 | log.Printf("🧹 登录速率限制器:清理 %d 条过期记录", len(toDelete)) 180 | } 181 | } 182 | 183 | // 优雅关闭LoginRateLimiter 184 | // Stop 停止cleanupLoop后台协程 185 | func (rl *LoginRateLimiter) Stop() { 186 | close(rl.stopCh) 187 | } 188 | -------------------------------------------------------------------------------- /web/assets/js/channels-init.js: -------------------------------------------------------------------------------- 1 | function highlightFromHash() { 2 | const m = (location.hash || '').match(/^#channel-(\d+)$/); 3 | if (!m) return; 4 | const el = document.getElementById(`channel-${m[1]}`); 5 | if (!el) return; 6 | el.scrollIntoView({ behavior: 'smooth', block: 'center' }); 7 | const prev = el.style.boxShadow; 8 | el.style.transition = 'box-shadow 0.3s ease, background 0.3s ease'; 9 | el.style.boxShadow = '0 0 0 3px rgba(59,130,246,0.35), 0 10px 25px rgba(59,130,246,0.20)'; 10 | el.style.background = 'rgba(59,130,246,0.06)'; 11 | setTimeout(() => { 12 | el.style.boxShadow = prev || ''; 13 | el.style.background = ''; 14 | }, 1600); 15 | } 16 | 17 | // 从URL参数获取目标渠道ID,查询其类型并返回 18 | async function getTargetChannelType() { 19 | const params = new URLSearchParams(location.search); 20 | const channelId = params.get('id'); 21 | if (!channelId) 
return null; 22 | 23 | try { 24 | const channel = await fetchDataWithAuth(`/admin/channels/${channelId}`); 25 | return channel.channel_type || 'anthropic'; 26 | } catch (e) { 27 | console.error('获取渠道类型失败:', e); 28 | return null; 29 | } 30 | } 31 | 32 | // localStorage key for channels page filters 33 | const CHANNELS_FILTER_KEY = 'channels.filters'; 34 | 35 | function saveChannelsFilters() { 36 | try { 37 | localStorage.setItem(CHANNELS_FILTER_KEY, JSON.stringify({ 38 | channelType: filters.channelType, 39 | status: filters.status, 40 | model: filters.model, 41 | search: filters.search, 42 | id: filters.id 43 | })); 44 | } catch (_) {} 45 | } 46 | 47 | function loadChannelsFilters() { 48 | try { 49 | const saved = localStorage.getItem(CHANNELS_FILTER_KEY); 50 | if (saved) return JSON.parse(saved); 51 | } catch (_) {} 52 | return null; 53 | } 54 | 55 | document.addEventListener('DOMContentLoaded', async () => { 56 | if (window.initTopbar) initTopbar('channels'); 57 | setupFilterListeners(); 58 | setupImportExport(); 59 | setupKeyImportPreview(); 60 | 61 | await window.ChannelTypeManager.renderChannelTypeRadios('channelTypeRadios'); 62 | 63 | // 优先从 localStorage 恢复,其次检查 URL 参数,最后默认 all 64 | const savedFilters = loadChannelsFilters(); 65 | const targetChannelType = await getTargetChannelType(); 66 | const initialType = targetChannelType || (savedFilters?.channelType) || 'all'; 67 | 68 | filters.channelType = initialType; 69 | if (savedFilters) { 70 | filters.status = savedFilters.status || 'all'; 71 | filters.model = savedFilters.model || 'all'; 72 | filters.search = savedFilters.search || ''; 73 | filters.id = savedFilters.id || ''; 74 | document.getElementById('statusFilter').value = filters.status; 75 | document.getElementById('modelFilter').value = filters.model; 76 | document.getElementById('searchInput').value = filters.search; 77 | document.getElementById('idFilter').value = filters.id; 78 | } 79 | 80 | // 初始化渠道类型筛选器(替换原Tab逻辑) 81 | await 
initChannelTypeFilter(initialType); 82 | 83 | await loadDefaultTestContent(); 84 | await loadChannelStatsRange(); 85 | 86 | await loadChannels(initialType); 87 | await loadChannelStats(); 88 | highlightFromHash(); 89 | window.addEventListener('hashchange', highlightFromHash); 90 | }); 91 | 92 | // 初始化渠道类型筛选器 93 | async function initChannelTypeFilter(initialType) { 94 | const select = document.getElementById('channelTypeFilter'); 95 | if (!select) return; 96 | 97 | const types = await window.ChannelTypeManager.getChannelTypes(); 98 | 99 | // 添加"全部"选项 100 | select.innerHTML = '全部'; 101 | types.forEach(type => { 102 | const option = document.createElement('option'); 103 | option.value = type.value; 104 | option.textContent = type.display_name; 105 | if (type.value === initialType) { 106 | option.selected = true; 107 | } 108 | select.appendChild(option); 109 | }); 110 | 111 | // 绑定change事件 112 | select.addEventListener('change', (e) => { 113 | const type = e.target.value; 114 | filters.channelType = type; 115 | filters.model = 'all'; 116 | document.getElementById('modelFilter').value = 'all'; 117 | saveChannelsFilters(); 118 | loadChannels(type); 119 | }); 120 | } 121 | 122 | document.addEventListener('keydown', (e) => { 123 | if (e.key === 'Escape') { 124 | closeModal(); 125 | closeDeleteModal(); 126 | closeTestModal(); 127 | closeKeyImportModal(); 128 | } 129 | }); 130 | -------------------------------------------------------------------------------- /internal/app/static.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | "os" 7 | "path/filepath" 8 | "strings" 9 | 10 | "ccLoad/internal/version" 11 | 12 | "github.com/gin-gonic/gin" 13 | ) 14 | 15 | // webRoot 是 web 目录的真实绝对路径,启动时初始化 16 | var webRoot string 17 | 18 | // setupStaticFiles 配置静态文件服务 19 | // - HTML 文件:不缓存,动态替换版本号占位符 20 | // - CSS/JS/字体:长缓存(1年),依赖版本号刷新 21 | // - dev 版本:不缓存,方便开发调试 22 | func setupStaticFiles(r 
*gin.Engine) { 23 | // 初始化 web 目录真实绝对路径(解析符号链接,用于安全检查) 24 | absPath, err := filepath.Abs("./web") 25 | if err != nil { 26 | log.Fatalf("[FATAL] 无法解析 web 目录路径: %v", err) 27 | } 28 | 29 | // 解析符号链接获取真实路径 30 | webRoot, err = filepath.EvalSymlinks(absPath) 31 | if err != nil { 32 | // web 目录不存在:生产环境 Fatal,测试环境警告 33 | if isTestMode() { 34 | log.Printf("[WARN] web 目录不存在: %v(测试环境忽略)", err) 35 | webRoot = absPath // 请求时会返回 404 36 | } else { 37 | log.Fatalf("[FATAL] web 目录不存在或无法访问: %v", err) 38 | } 39 | } 40 | 41 | r.GET("/web/*filepath", serveStaticFile) 42 | } 43 | 44 | // isTestMode 检测是否在 Go 测试环境中运行 45 | func isTestMode() bool { 46 | for _, arg := range os.Args { 47 | if strings.HasPrefix(arg, "-test.") { 48 | return true 49 | } 50 | } 51 | return false 52 | } 53 | 54 | // serveStaticFile 处理静态文件请求 55 | func serveStaticFile(c *gin.Context) { 56 | // Gin wildcard 参数带前导斜杠,如 "/index.html" 57 | reqPath := c.Param("filepath") 58 | 59 | // 去除前导斜杠,确保是相对路径 60 | reqPath = strings.TrimPrefix(reqPath, "/") 61 | 62 | // Clean 处理 .. 和多余的斜杠 63 | reqPath = filepath.Clean(reqPath) 64 | 65 | // 防止路径遍历:Clean 后仍以 .. 开头说明试图逃逸 66 | if reqPath == ".." 
|| strings.HasPrefix(reqPath, ".."+string(filepath.Separator)) { 67 | c.Status(http.StatusForbidden) 68 | return 69 | } 70 | 71 | // Build the full file path under webRoot 72 | filePath := filepath.Join(webRoot, reqPath) 73 | 74 | // Any Stat failure (missing file, permission error, ...) is reported as 404 75 | info, err := os.Stat(filePath) 76 | if err != nil { 77 | c.Status(http.StatusNotFound) 78 | return 79 | } 80 | 81 | // For a directory, serve its index.html if present 82 | if info.IsDir() { 83 | filePath = filepath.Join(filePath, "index.html") 84 | if _, err = os.Stat(filePath); err != nil { 85 | c.Status(http.StatusNotFound) 86 | return 87 | } 88 | } 89 | 90 | // Last line of defense: resolve symlinks, then verify the real path is still inside webRoot 91 | realPath, err := filepath.EvalSymlinks(filePath) 92 | if err != nil { 93 | c.Status(http.StatusForbidden) 94 | return 95 | } 96 | 97 | // Escape check via isPathUnder (filepath.Rel), which - unlike a bare strings.HasPrefix - rejects siblings that merely share a name prefix 98 | if !isPathUnder(realPath, webRoot) { 99 | c.Status(http.StatusForbidden) 100 | return 101 | } 102 | 103 | ext := strings.ToLower(filepath.Ext(realPath)) 104 | 105 | // Choose the caching strategy by file type: HTML gets version substitution, everything else the static cache policy 106 | if ext == ".html" { 107 | serveHTMLWithVersion(c, realPath) 108 | } else { 109 | serveStaticWithCache(c, realPath, ext) 110 | } 111 | } 112 | 113 | // isPathUnder reports whether path is base itself or lies inside base. 114 | // It uses filepath.Rel rather than strings.HasPrefix so siblings sharing a name prefix ("/srv/web2" vs "/srv/web") are rejected. Rel compares path elements case-insensitively only on Windows; on other platforms (including macOS) the check is case-sensitive. 115 | func isPathUnder(path, base string) bool { 116 | rel, err := filepath.Rel(base, path) 117 | if err != nil { 118 | return false 119 | } 120 | // A rel of ".." or one starting with ".."+separator means path lies outside base 121 | return rel != ".."
&& !strings.HasPrefix(rel, ".."+string(filepath.Separator)) 122 | } 123 | 124 | // serveHTMLWithVersion 处理 HTML 文件,替换版本号占位符 125 | func serveHTMLWithVersion(c *gin.Context, filePath string) { 126 | content, err := os.ReadFile(filePath) 127 | if err != nil { 128 | c.Status(http.StatusInternalServerError) 129 | return 130 | } 131 | 132 | // 替换版本号占位符 133 | html := strings.ReplaceAll(string(content), "__VERSION__", version.Version) 134 | 135 | // HTML 不缓存,确保用户总能获取最新版本号引用 136 | c.Header("Cache-Control", "no-cache, must-revalidate") 137 | c.Header("Content-Type", "text/html; charset=utf-8") 138 | c.String(http.StatusOK, html) 139 | } 140 | 141 | // serveStaticWithCache 处理静态资源,设置缓存策略 142 | func serveStaticWithCache(c *gin.Context, filePath, ext string) { 143 | // 缓存策略: 144 | // - dev 版本:不缓存,方便开发调试 145 | // - manifest.json/favicon:短缓存(无版本号控制) 146 | // - 其他静态资源:长缓存(通过 URL 版本号刷新) 147 | fileName := filepath.Base(filePath) 148 | 149 | if version.Version == "dev" { 150 | // 开发环境:不缓存,避免前端修改看不到 151 | c.Header("Cache-Control", "no-cache, must-revalidate") 152 | } else if fileName == "manifest.json" || ext == ".ico" { 153 | // 元数据文件:1小时缓存 + 必须验证 154 | c.Header("Cache-Control", "public, max-age=3600, must-revalidate") 155 | } else { 156 | // 静态资源:1年缓存,immutable 表示内容不会变化(通过版本号刷新) 157 | c.Header("Cache-Control", "public, max-age=31536000, immutable") 158 | } 159 | 160 | // 使用 c.File() 替代手写 Open+io.Copy 161 | // 自动处理:Content-Type、Content-Length、HEAD、Range、If-Modified-Since/304 162 | c.File(filePath) 163 | } 164 | -------------------------------------------------------------------------------- /internal/config/defaults_test.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | // TestDefaultConstants 测试默认常量值的合理性 9 | func TestDefaultConstants(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | value int 13 | min int 14 | max int 15 | }{ 16 | // HTTP配置 17 | 
{"DefaultMaxConcurrency", DefaultMaxConcurrency, 1, 10000}, 18 | {"DefaultMaxKeyRetries", DefaultMaxKeyRetries, 1, 10}, 19 | {"HTTPMaxIdleConns", HTTPMaxIdleConns, 1, 1000}, 20 | {"HTTPMaxIdleConnsPerHost", HTTPMaxIdleConnsPerHost, 1, 1000}, 21 | {"HTTPMaxConnsPerHost", HTTPMaxConnsPerHost, 0, 1000}, 22 | 23 | // 日志配置 24 | {"DefaultLogBufferSize", DefaultLogBufferSize, 100, 100000}, 25 | {"DefaultLogWorkers", DefaultLogWorkers, 1, 10}, 26 | {"LogBatchSize", LogBatchSize, 1, 1000}, 27 | 28 | // Token配置 29 | {"TokenRandomBytes", TokenRandomBytes, 16, 64}, 30 | {"DefaultTokenStatsBufferSize", DefaultTokenStatsBufferSize, 100, 100000}, 31 | 32 | // SQLite配置 33 | {"SQLiteMaxOpenConnsFile", SQLiteMaxOpenConnsFile, 1, 100}, 34 | {"SQLiteMaxIdleConnsFile", SQLiteMaxIdleConnsFile, 1, 100}, 35 | 36 | // 日志超时配置 37 | {"LogFlushTimeoutMs", LogFlushTimeoutMs, 100, 60000}, // 毫秒 38 | {"RedisSyncShutdownTimeoutMs", RedisSyncShutdownTimeoutMs, 100, 10000}, 39 | } 40 | 41 | for _, tt := range tests { 42 | t.Run(tt.name, func(t *testing.T) { 43 | if tt.value < tt.min || tt.value > tt.max { 44 | t.Errorf("%s=%d 超出合理范围 [%d, %d]", tt.name, tt.value, tt.min, tt.max) 45 | } 46 | }) 47 | } 48 | } 49 | 50 | // TestBufferSizeConstants 测试缓冲区大小常量 51 | func TestBufferSizeConstants(t *testing.T) { 52 | tests := []struct { 53 | name string 54 | value int 55 | min int 56 | max int 57 | }{ 58 | {"TLSSessionCacheSize", TLSSessionCacheSize, 0, 10000}, 59 | {"DefaultMaxBodyBytes", DefaultMaxBodyBytes, 1024, 100 * 1024 * 1024}, 60 | } 61 | 62 | for _, tt := range tests { 63 | t.Run(tt.name, func(t *testing.T) { 64 | if tt.value < tt.min || tt.value > tt.max { 65 | t.Errorf("%s=%d 超出合理范围 [%d, %d]", tt.name, tt.value, tt.min, tt.max) 66 | } 67 | }) 68 | } 69 | } 70 | 71 | // TestConfigRelationships 测试配置项之间的关系 72 | func TestConfigRelationships(t *testing.T) { 73 | // SQLite连接池配置: MaxOpenConns >= MaxIdleConns 74 | if SQLiteMaxOpenConnsFile < SQLiteMaxIdleConnsFile { 75 | t.Errorf("文件模式: MaxOpenConns(%d) < 
MaxIdleConns(%d)", 76 | SQLiteMaxOpenConnsFile, SQLiteMaxIdleConnsFile) 77 | } 78 | 79 | // HTTP连接池配置: MaxIdleConns >= MaxIdleConnsPerHost 80 | if HTTPMaxIdleConns < HTTPMaxIdleConnsPerHost { 81 | t.Errorf("HTTP: MaxIdleConns(%d) < MaxIdleConnsPerHost(%d)", 82 | HTTPMaxIdleConns, HTTPMaxIdleConnsPerHost) 83 | } 84 | 85 | // 日志配置: BufferSize >= BatchSize 86 | if DefaultLogBufferSize < LogBatchSize { 87 | t.Errorf("日志: BufferSize(%d) < BatchSize(%d)", 88 | DefaultLogBufferSize, LogBatchSize) 89 | } 90 | 91 | // 日志清理: CleanupInterval < 最小保留天数(1天) 92 | // log_retention_days 最小值为1天(24h), 清理间隔必须小于它 93 | cleanupHours := int(LogCleanupInterval.Hours()) 94 | minRetentionHours := 24 // 最小保留1天 95 | if cleanupHours >= minRetentionHours { 96 | t.Errorf("日志清理: CleanupInterval(%dh) >= MinRetention(%dh)", 97 | cleanupHours, minRetentionHours) 98 | } 99 | } 100 | 101 | // TestRedisSyncShutdownTimeout 测试Redis同步关闭超时 102 | func TestRedisSyncShutdownTimeout(t *testing.T) { 103 | // 关闭超时应该在合理范围内 (100ms - 10s) 104 | if RedisSyncShutdownTimeoutMs < 100 { 105 | t.Errorf("RedisSyncShutdownTimeout=%dms 太短", RedisSyncShutdownTimeoutMs) 106 | } 107 | if RedisSyncShutdownTimeoutMs > 10000 { 108 | t.Errorf("RedisSyncShutdownTimeout=%dms 太长", RedisSyncShutdownTimeoutMs) 109 | } 110 | } 111 | 112 | // TestHTTPTimeoutValues 测试HTTP超时值的合理性 113 | func TestHTTPTimeoutValues(t *testing.T) { 114 | // 所有HTTP超时应该大于0 115 | timeouts := map[string]time.Duration{ 116 | "HTTPDialTimeout": HTTPDialTimeout, 117 | "HTTPKeepAliveInterval": HTTPKeepAliveInterval, 118 | "HTTPTLSHandshakeTimeout": HTTPTLSHandshakeTimeout, 119 | } 120 | 121 | for name, value := range timeouts { 122 | if value <= 0 { 123 | t.Errorf("%s=%v 应该大于0", name, value) 124 | } 125 | } 126 | } 127 | 128 | // TestLogConfigValues 测试日志配置值的合理性 129 | func TestLogConfigValues(t *testing.T) { 130 | // 日志Worker数量应该合理 131 | if DefaultLogWorkers < 1 { 132 | t.Error("DefaultLogWorkers应该至少为1") 133 | } 134 | if DefaultLogWorkers > 10 { 135 | 
t.Logf("DefaultLogWorkers=%d 可能过多", DefaultLogWorkers) 136 | } 137 | 138 | // 日志批次大小应该小于缓冲区大小 139 | if LogBatchSize > DefaultLogBufferSize { 140 | t.Errorf("LogBatchSize(%d) > DefaultLogBufferSize(%d)", 141 | LogBatchSize, DefaultLogBufferSize) 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /web/assets/js/channels-state.js: -------------------------------------------------------------------------------- 1 | // 全局状态与通用工具函数 2 | let channels = []; 3 | let channelStatsById = {}; 4 | let editingChannelId = null; 5 | let deletingChannelId = null; 6 | let testingChannelId = null; 7 | let currentChannelKeyCooldowns = []; // 当前编辑渠道的Key冷却信息 8 | let redirectTableData = []; // 模型重定向表格数据: [{from: '', to: ''}] 9 | let defaultTestContent = 'sonnet 4.0的发布日期是什么'; // 默认测试内容(从设置加载) 10 | let channelStatsRange = 'today'; // 渠道统计时间范围(从设置加载) 11 | let channelsCache = {}; // 按类型缓存渠道数据: {type: channels[]} 12 | 13 | // Filter state 14 | let filters = { 15 | search: '', 16 | id: '', 17 | channelType: 'all', 18 | status: 'all', 19 | model: 'all' 20 | }; 21 | 22 | // 内联Key表格状态 23 | let inlineKeyTableData = []; 24 | let inlineKeyVisible = false; // 密码可见性状态 25 | let selectedKeyIndices = new Set(); // 选中的Key索引集合 26 | let currentKeyStatusFilter = 'all'; // 当前状态筛选:all/normal/cooldown 27 | let channelFormDirty = false; // 表单是否有未保存的更改 28 | 29 | // 虚拟滚动实现:优化大量Key时的渲染性能 30 | const VIRTUAL_SCROLL_CONFIG = { 31 | ROW_HEIGHT: 40, // 每行高度(像素) 32 | BUFFER_SIZE: 5, // 上下缓冲区行数(减少滚动时的闪烁) 33 | ENABLE_THRESHOLD: 50, // 启用虚拟滚动的阈值(Key数量) 34 | CONTAINER_HEIGHT: 250 // 容器固定高度(像素) 35 | }; 36 | 37 | let virtualScrollState = { 38 | enabled: false, 39 | scrollTop: 0, 40 | visibleStart: 0, 41 | visibleEnd: 0, 42 | rafId: null, 43 | filteredIndices: [] // 存储筛选后的索引列表(支持状态筛选) 44 | }; 45 | 46 | // 清除渠道缓存(在增删改操作后调用) 47 | function clearChannelsCache() { 48 | channelsCache = {}; 49 | } 50 | 51 | function humanizeMS(ms) { 52 | let s = Math.ceil(ms / 1000); 53 | const h = 
Math.floor(s / 3600); 54 | s = s % 3600; 55 | const m = Math.floor(s / 60); 56 | s = s % 60; 57 | 58 | if (h > 0) return `${h}小时${m}分`; 59 | if (m > 0) return `${m}分${s}秒`; 60 | return `${s}秒`; 61 | } 62 | 63 | function formatMetricNumber(value) { 64 | if (value === null || value === undefined) return '--'; 65 | const num = Number(value); 66 | if (!Number.isFinite(num)) return '--'; 67 | return formatCompactNumber(num); 68 | } 69 | 70 | function formatCompactNumber(num) { 71 | const abs = Math.abs(num); 72 | if (abs >= 1_000_000) return (num / 1_000_000).toFixed(1).replace(/\.0$/, '') + 'M'; 73 | if (abs >= 1_000) return (num / 1_000).toFixed(1).replace(/\.0$/, '') + 'K'; 74 | return num.toString(); 75 | } 76 | 77 | function formatSuccessRate(success, total) { 78 | if (success === null || success === undefined || total === null || total === undefined) return '--'; 79 | const succ = Number(success); 80 | const ttl = Number(total); 81 | if (!Number.isFinite(succ) || !Number.isFinite(ttl) || ttl <= 0) return '--'; 82 | return ((succ / ttl) * 100).toFixed(1) + '%'; 83 | } 84 | 85 | function formatAvgFirstByte(value) { 86 | if (value === null || value === undefined) return '--'; 87 | const num = Number(value); 88 | if (!Number.isFinite(num) || num <= 0) return '--'; 89 | return num.toFixed(2) + '秒'; 90 | } 91 | 92 | function formatCostValue(cost) { 93 | if (cost === null || cost === undefined) return '--'; 94 | const num = Number(cost); 95 | if (!Number.isFinite(num)) return '--'; 96 | if (num === 0) return '$0.00'; 97 | if (num < 0) return '--'; 98 | return formatCost(num); 99 | } 100 | 101 | function getStatsRangeLabel(range) { 102 | const labels = { 103 | 'today': '本日', 104 | 'this_week': '本周', 105 | 'this_month': '本月', 106 | 'all': '全部' 107 | }; 108 | return labels[range] || '本日'; 109 | } 110 | 111 | function formatTimestampForFilename() { 112 | const pad = (n) => String(n).padStart(2, '0'); 113 | const now = new Date(); 114 | return 
`${now.getFullYear()}${pad(now.getMonth() + 1)}${pad(now.getDate())}-${pad(now.getHours())}${pad(now.getMinutes())}${pad(now.getSeconds())}`; 115 | } 116 | 117 | // 遮罩Key显示(保留前后各4个字符) 118 | function maskKey(key) { 119 | if (key.length <= 8) return '***'; 120 | return key.slice(0, 4) + '***' + key.slice(-4); 121 | } 122 | 123 | // 标记表单有未保存的更改 124 | function markChannelFormDirty() { 125 | channelFormDirty = true; 126 | const saveBtn = document.getElementById('channelSaveBtn'); 127 | if (saveBtn && !saveBtn.classList.contains('btn-warning')) { 128 | saveBtn.classList.remove('btn-primary'); 129 | saveBtn.classList.add('btn-warning'); 130 | saveBtn.textContent = '保存 *'; 131 | } 132 | } 133 | 134 | // 重置表单dirty状态 135 | function resetChannelFormDirty() { 136 | channelFormDirty = false; 137 | const saveBtn = document.getElementById('channelSaveBtn'); 138 | if (saveBtn) { 139 | saveBtn.classList.remove('btn-warning'); 140 | saveBtn.classList.add('btn-primary'); 141 | saveBtn.textContent = '保存'; 142 | } 143 | } 144 | 145 | // 通知系统统一由 ui.js 提供(showNotification/showSuccess/showError) 146 | -------------------------------------------------------------------------------- /web/assets/css/tokens.css: -------------------------------------------------------------------------------- 1 | .token-table { 2 | width: 100%; 3 | background: white; 4 | border-radius: 12px; 5 | overflow: hidden; 6 | box-shadow: 0 1px 3px rgba(0,0,0,0.1); 7 | } 8 | .token-table table { 9 | width: 100%; 10 | border-collapse: collapse; 11 | } 12 | .token-table thead { 13 | background: var(--neutral-100); 14 | } 15 | .token-table th { 16 | padding: 12px 16px; 17 | text-align: left; 18 | font-size: 13px; 19 | font-weight: 600; 20 | color: var(--neutral-700); 21 | border-bottom: 1px solid var(--neutral-200); 22 | } 23 | .token-table td { 24 | padding: 12px 16px; 25 | border-bottom: 1px solid var(--neutral-100); 26 | font-size: 14px; 27 | } 28 | .token-table tbody tr:hover { 29 | background: var(--neutral-50); 30 | 
} 31 | .token-table tbody tr:last-child td { 32 | border-bottom: none; 33 | } 34 | .token-display { 35 | font-family: monospace; 36 | background: var(--neutral-100); 37 | padding: 4px 8px; 38 | border-radius: 4px; 39 | font-size: 13px; 40 | display: inline-block; 41 | } 42 | /* Token状态颜色 */ 43 | .token-display-active { 44 | background: var(--success-50); 45 | color: var(--success-800); 46 | border: 1px solid var(--success-300); 47 | } 48 | .token-display-inactive { 49 | background: var(--neutral-100); 50 | color: var(--neutral-600); 51 | border: 1px solid var(--neutral-300); 52 | } 53 | .token-display-expired { 54 | background: var(--error-50); 55 | color: var(--error-800); 56 | border: 1px solid var(--error-300); 57 | } 58 | .status-badge { 59 | display: inline-block; 60 | padding: 4px 12px; 61 | border-radius: 12px; 62 | font-size: 12px; 63 | font-weight: 600; 64 | } 65 | .status-active { 66 | background: var(--success-100); 67 | color: var(--success-700); 68 | } 69 | .status-inactive { 70 | background: var(--neutral-200); 71 | color: var(--neutral-700); 72 | } 73 | .status-expired { 74 | background: var(--error-100); 75 | color: var(--error-700); 76 | } 77 | .stats-badge { 78 | display: inline-block; 79 | padding: 2px 8px; 80 | border-radius: 4px; 81 | font-size: 12px; 82 | font-weight: 600; 83 | background: var(--neutral-100); 84 | color: var(--neutral-700); 85 | } 86 | .success-rate-high { 87 | background: var(--success-100); 88 | color: var(--success-700); 89 | } 90 | .success-rate-medium { 91 | background: var(--warning-100); 92 | color: var(--warning-700); 93 | } 94 | .success-rate-low { 95 | background: var(--error-100); 96 | color: var(--error-700); 97 | } 98 | .metric-value { 99 | font-weight: 600; 100 | color: var(--neutral-900); 101 | } 102 | .metric-label { 103 | font-size: 11px; 104 | color: var(--neutral-600); 105 | margin-top: 2px; 106 | } 107 | /* 响应时间颜色 */ 108 | .response-fast { 109 | color: var(--success-700); 110 | font-weight: 600; 111 | } 112 
| .response-medium { 113 | color: var(--warning-700); 114 | font-weight: 600; 115 | } 116 | .response-slow { 117 | color: var(--error-700); 118 | font-weight: 600; 119 | } 120 | .modal { 121 | display: none; 122 | position: fixed; 123 | z-index: 1000; 124 | left: 0; 125 | top: 0; 126 | width: 100%; 127 | height: 100%; 128 | background-color: rgba(0,0,0,0.5); 129 | animation: fadeIn 0.2s; 130 | } 131 | .modal-content { 132 | background-color: white; 133 | margin: 5% auto; 134 | padding: 0; 135 | border-radius: 12px; 136 | max-width: 600px; 137 | box-shadow: 0 10px 40px rgba(0,0,0,0.2); 138 | animation: slideUp 0.3s; 139 | } 140 | .modal-header { 141 | padding: 20px 24px; 142 | border-bottom: 1px solid var(--neutral-200); 143 | display: flex; 144 | justify-content: space-between; 145 | align-items: center; 146 | } 147 | .modal-body { 148 | padding: 24px; 149 | } 150 | .modal-footer { 151 | padding: 16px 24px; 152 | border-top: 1px solid var(--neutral-200); 153 | display: flex; 154 | justify-content: flex-end; 155 | gap: 12px; 156 | } 157 | @keyframes fadeIn { 158 | from { opacity: 0; } 159 | to { opacity: 1; } 160 | } 161 | @keyframes slideUp { 162 | from { transform: translateY(50px); opacity: 0; } 163 | to { transform: translateY(0); opacity: 1; } 164 | } 165 | -------------------------------------------------------------------------------- /internal/storage/sql/transaction_deadline_test.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | "testing" 8 | "time" 9 | 10 | _ "modernc.org/sqlite" 11 | ) 12 | 13 | // TestWithTransaction_ContextDeadline 验证 context.Deadline 限制总重试时间 14 | // [FIX] 后续优化: 防止事务重试超过 context 的 deadline 15 | func TestWithTransaction_ContextDeadline(t *testing.T) { 16 | t.Run("context 有 deadline 时应该提前退出", func(t *testing.T) { 17 | // 创建临时数据库 18 | db, err := sql.Open("sqlite", ":memory:") 19 | if err != nil { 20 | t.Fatalf("打开数据库失败: 
%v", err) 21 | } 22 | defer db.Close() 23 | 24 | // 创建一个 500ms deadline 的 context 25 | ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) 26 | defer cancel() 27 | 28 | attemptCount := 0 29 | start := time.Now() 30 | 31 | // 模拟一个总是返回 BUSY 错误的事务 32 | err = withTransaction(db, ctx, func(tx *sql.Tx) error { 33 | attemptCount++ 34 | // 模拟 SQLite BUSY 错误 35 | return errors.New("database is locked") 36 | }) 37 | 38 | elapsed := time.Since(start) 39 | 40 | // 验证:应该在 deadline 前退出(不是等到 12 次重试完) 41 | if err == nil { 42 | t.Fatal("期望失败,但成功了") 43 | } 44 | 45 | // 验证:耗时应该接近 500ms,而不是 51.2s(12 次重试的理论最大值) 46 | if elapsed > 1*time.Second { 47 | t.Errorf("重试耗时过长: %v(应该在 deadline 前退出)", elapsed) 48 | } 49 | 50 | // 验证:应该有多次重试(至少 2-3 次) 51 | if attemptCount < 2 { 52 | t.Errorf("重试次数过少: %d(应该至少有几次重试)", attemptCount) 53 | } 54 | 55 | // 验证:不应该达到最大重试次数 12 56 | if attemptCount >= 12 { 57 | t.Errorf("重试次数过多: %d(应该在 deadline 前退出)", attemptCount) 58 | } 59 | 60 | t.Logf("✅ context.Deadline 生效: 耗时 %v, 重试 %d 次后提前退出", elapsed, attemptCount) 61 | }) 62 | 63 | t.Run("没有 deadline 时应该正常重试到最大次数", func(t *testing.T) { 64 | // 创建临时数据库 65 | db, err := sql.Open("sqlite", ":memory:") 66 | if err != nil { 67 | t.Fatalf("打开数据库失败: %v", err) 68 | } 69 | defer db.Close() 70 | 71 | // 使用 background context(无 deadline) 72 | ctx := context.Background() 73 | 74 | attemptCount := 0 75 | 76 | // 模拟一个总是返回 BUSY 错误的事务 77 | err = withTransaction(db, ctx, func(tx *sql.Tx) error { 78 | attemptCount++ 79 | return errors.New("database is locked") 80 | }) 81 | 82 | // 验证:应该重试到最大次数 83 | if attemptCount != 12 { 84 | t.Errorf("重试次数不符合预期: got %d, want 12", attemptCount) 85 | } 86 | 87 | // 验证:错误信息应该包含"after 12 retries" 88 | if err == nil || err.Error() == "" { 89 | t.Fatal("期望失败,但成功了或错误为空") 90 | } 91 | 92 | t.Logf("✅ 无 deadline 时正常重试到最大次数: %d 次", attemptCount) 93 | }) 94 | 95 | t.Run("context 取消时应该立即退出", func(t *testing.T) { 96 | // 创建临时数据库 97 | db, err := sql.Open("sqlite", ":memory:") 98 | if err 
!= nil { 99 | t.Fatalf("打开数据库失败: %v", err) 100 | } 101 | defer db.Close() 102 | 103 | // Cancellable context 104 | ctx, cancel := context.WithCancel(context.Background()) 105 | 106 | attemptCount := 0 107 | start := time.Now() 108 | 109 | // Cancel after ~100ms of wall-clock time (not tied to a specific retry count); presumably withTransaction is still retrying at that point 110 | go func() { 111 | time.Sleep(100 * time.Millisecond) 112 | cancel() 113 | }() 114 | 115 | // Transaction body that always fails with a SQLite BUSY-style error 116 | err = withTransaction(db, ctx, func(tx *sql.Tx) error { 117 | attemptCount++ 118 | return errors.New("database is locked") 119 | }) 120 | 121 | elapsed := time.Since(start) 122 | 123 | // Expect a prompt exit after cancellation rather than running all 12 retries 124 | if elapsed > 500*time.Millisecond { 125 | t.Errorf("取消后耗时过长: %v", elapsed) 126 | } 127 | 128 | // Only asserts that an error is returned; the error text (e.g. "cancelled") is not checked 129 | if err == nil { 130 | t.Fatal("期望失败,但成功了") 131 | } 132 | 133 | t.Logf("✅ context 取消时立即退出: 耗时 %v, 重试 %d 次", elapsed, attemptCount) 134 | }) 135 | } 136 | 137 | // TestWithTransaction_DeadlineRealWorld simulates a realistic deadline: an HTTP-request-style timeout propagating into the transaction layer. 138 | func TestWithTransaction_DeadlineRealWorld(t *testing.T) { 139 | t.Run("HTTP 请求超时应该传播到事务层", func(t *testing.T) { 140 | // Temporary in-memory database 141 | db, err := sql.Open("sqlite", ":memory:") 142 | if err != nil { 143 | t.Fatalf("打开数据库失败: %v", err) 144 | } 145 | defer db.Close() 146 | 147 | // Simulate a 1-second HTTP request timeout 148 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) 149 | defer cancel() 150 | 151 | attemptCount := 0 152 | start := time.Now() 153 | 154 | // Transaction body that always fails. NOTE(review): whether withTransaction treats "database is deadlocked" as retryable is not visible here - if it does not, this subtest passes after a single attempt without exercising the deadline 155 | err = withTransaction(db, ctx, func(tx *sql.Tx) error { 156 | attemptCount++ 157 | return errors.New("database is deadlocked") 158 | }) 159 | if err == nil { 160 | t.Fatal("期望事务失败,但成功了") 161 | } 162 | 163 | elapsed := time.Since(start) 164 | 165 | // Expect an exit within 1.5s, i.e. close to the 1s deadline 166 | if elapsed > 1500*time.Millisecond { 167 | t.Errorf("超时控制失效: 耗时 %v(应该约 1s)", elapsed) 168 | } 169 | 170 | // The deadline should stop retrying before 12 attempts 171 | if attemptCount >= 12 { 172 | t.Errorf("重试次数过多: %d(应该被 deadline 提前终止)", attemptCount) 173 | } 174 | 175 | t.Logf("✅ HTTP 超时传播到事务层: 耗时 %v, 重试 %d 次后退出", elapsed,
attemptCount) 176 | }) 177 | } 178 | -------------------------------------------------------------------------------- /internal/app/key_selector.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "ccLoad/internal/model" 5 | "fmt" 6 | "sync" 7 | "sync/atomic" 8 | "time" 9 | ) 10 | 11 | // KeySelector 负责从渠道的多个API Key中选择可用的Key 12 | // 移除store依赖,避免重复查询数据库 13 | type KeySelector struct { 14 | // 轮询计数器:channelID -> *rrCounter 15 | // 渠道删除时需要清理对应计数器,避免rrCounters无界增长。 16 | rrCounters map[int64]*rrCounter 17 | rrMutex sync.RWMutex 18 | } 19 | 20 | // rrCounter 轮询计数器(简化版) 21 | type rrCounter struct { 22 | counter atomic.Uint32 23 | } 24 | 25 | // NewKeySelector 创建Key选择器 26 | func NewKeySelector() *KeySelector { 27 | return &KeySelector{ 28 | rrCounters: make(map[int64]*rrCounter), 29 | } 30 | } 31 | 32 | // SelectAvailableKey 返回 (keyIndex, apiKey, error) 33 | // 策略: sequential顺序尝试 | round_robin轮询选择 34 | // excludeKeys: 避免同一请求内重复尝试 35 | // 移除store依赖,apiKeys由调用方传入,避免重复查询 36 | func (ks *KeySelector) SelectAvailableKey(channelID int64, apiKeys []*model.APIKey, excludeKeys map[int]bool) (int, string, error) { 37 | if len(apiKeys) == 0 { 38 | return -1, "", fmt.Errorf("no API keys configured for channel %d", channelID) 39 | } 40 | 41 | // 单Key场景:检查排除和冷却状态 42 | if len(apiKeys) == 1 { 43 | keyIndex := apiKeys[0].KeyIndex 44 | // [FIX] 使用真实 KeyIndex 检查排除集合,而非硬编码0 45 | if excludeKeys != nil && excludeKeys[keyIndex] { 46 | return -1, "", fmt.Errorf("single key (index=%d) already tried in this request", keyIndex) 47 | } 48 | // [INFO] 修复(2025-12-09): 检查冷却状态,防止单Key渠道冷却后仍被请求 49 | // 原逻辑"不使用Key级别冷却(YAGNI原则)"是错误的,会导致冷却Key持续触发上游错误 50 | if apiKeys[0].IsCoolingDown(time.Now()) { 51 | return -1, "", fmt.Errorf("single key (index=%d) is in cooldown until %s", 52 | keyIndex, 53 | time.Unix(apiKeys[0].CooldownUntil, 0).Format("2006-01-02 15:04:05")) 54 | } 55 | return keyIndex, apiKeys[0].APIKey, nil 56 | } 57 | 58 | // 
多Key场景:根据策略选择 59 | strategy := apiKeys[0].KeyStrategy 60 | if strategy == "" { 61 | strategy = model.KeyStrategySequential 62 | } 63 | 64 | switch strategy { 65 | case model.KeyStrategyRoundRobin: 66 | return ks.selectRoundRobin(channelID, apiKeys, excludeKeys) 67 | case model.KeyStrategySequential: 68 | return ks.selectSequential(apiKeys, excludeKeys) 69 | default: 70 | return ks.selectSequential(apiKeys, excludeKeys) 71 | } 72 | } 73 | 74 | func (ks *KeySelector) selectSequential(apiKeys []*model.APIKey, excludeKeys map[int]bool) (int, string, error) { 75 | now := time.Now() 76 | 77 | for _, apiKey := range apiKeys { 78 | keyIndex := apiKey.KeyIndex 79 | 80 | if excludeKeys != nil && excludeKeys[keyIndex] { 81 | continue 82 | } 83 | 84 | if apiKey.IsCoolingDown(now) { 85 | continue 86 | } 87 | 88 | return keyIndex, apiKey.APIKey, nil 89 | } 90 | 91 | return -1, "", fmt.Errorf("all API keys are in cooldown or already tried") 92 | } 93 | 94 | // getOrCreateCounter 获取或创建渠道的轮询计数器(双重检查锁定) 95 | func (ks *KeySelector) getOrCreateCounter(channelID int64) *rrCounter { 96 | ks.rrMutex.RLock() 97 | counter, ok := ks.rrCounters[channelID] 98 | ks.rrMutex.RUnlock() 99 | 100 | if ok { 101 | return counter 102 | } 103 | 104 | ks.rrMutex.Lock() 105 | defer ks.rrMutex.Unlock() 106 | 107 | // 再次检查,避免多个goroutine同时创建 108 | if counter, ok = ks.rrCounters[channelID]; !ok { 109 | counter = &rrCounter{} 110 | ks.rrCounters[channelID] = counter 111 | } 112 | return counter 113 | } 114 | 115 | // RemoveChannelCounter 删除指定渠道的轮询计数器。 116 | // 在渠道被删除时调用,避免rrCounters长期积累。 117 | func (ks *KeySelector) RemoveChannelCounter(channelID int64) { 118 | ks.rrMutex.Lock() 119 | delete(ks.rrCounters, channelID) 120 | ks.rrMutex.Unlock() 121 | } 122 | 123 | // selectRoundRobin 轮询选择可用Key 124 | // [FIX] 按 slice 索引轮询,返回真实 KeyIndex,不再假设 KeyIndex 连续 125 | func (ks *KeySelector) selectRoundRobin(channelID int64, apiKeys []*model.APIKey, excludeKeys map[int]bool) (int, string, error) { 126 | keyCount := 
len(apiKeys) 127 | now := time.Now() 128 | 129 | counter := ks.getOrCreateCounter(channelID) 130 | startIdx := int(counter.counter.Add(1) % uint32(keyCount)) 131 | 132 | // 从startIdx开始轮询,最多尝试keyCount次 133 | for i := range keyCount { 134 | sliceIdx := (startIdx + i) % keyCount 135 | selectedKey := apiKeys[sliceIdx] 136 | if selectedKey == nil { 137 | continue 138 | } 139 | 140 | keyIndex := selectedKey.KeyIndex // 真实 KeyIndex,可能不连续 141 | 142 | // 检查排除集合(使用真实 KeyIndex) 143 | if excludeKeys != nil && excludeKeys[keyIndex] { 144 | continue 145 | } 146 | 147 | if selectedKey.IsCoolingDown(now) { 148 | continue 149 | } 150 | 151 | // 返回真实 KeyIndex,而非 slice 索引 152 | return keyIndex, selectedKey.APIKey, nil 153 | } 154 | 155 | return -1, "", fmt.Errorf("all API keys are in cooldown or already tried") 156 | } 157 | 158 | // KeySelector 专注于Key选择逻辑,冷却管理已移至 cooldownManager 159 | // 移除的方法: MarkKeyError, MarkKeySuccess, GetKeyCooldownInfo 160 | // 原因: 违反SRP原则,冷却管理应由专门的 cooldownManager 负责 161 | -------------------------------------------------------------------------------- /internal/storage/sql/auth_token_stats.go: -------------------------------------------------------------------------------- 1 | package sql 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "time" 7 | 8 | "ccLoad/internal/model" 9 | ) 10 | 11 | // GetAuthTokenStatsInRange 查询指定时间范围内每个token的统计数据(从logs表聚合) 12 | // 用于tokens.html页面按时间范围筛选显示(2025-12新增) 13 | func (s *SQLStore) GetAuthTokenStatsInRange(ctx context.Context, startTime, endTime time.Time) (map[int64]*model.AuthTokenRangeStats, error) { 14 | sinceMs := startTime.UnixMilli() 15 | untilMs := endTime.UnixMilli() 16 | 17 | query := ` 18 | SELECT 19 | auth_token_id, 20 | SUM(CASE WHEN status_code >= 200 AND status_code < 300 THEN 1 ELSE 0 END) AS success_count, 21 | SUM(CASE WHEN status_code < 200 OR status_code >= 300 THEN 1 ELSE 0 END) AS failure_count, 22 | SUM(input_tokens) AS prompt_tokens, 23 | SUM(output_tokens) AS completion_tokens, 24 | 
SUM(cache_read_input_tokens) AS cache_read_tokens, 25 | SUM(cache_creation_input_tokens) AS cache_creation_tokens, 26 | SUM(cost) AS total_cost, 27 | AVG(CASE WHEN is_streaming = 1 THEN first_byte_time ELSE NULL END) AS stream_avg_ttfb, 28 | AVG(CASE WHEN is_streaming = 0 THEN duration ELSE NULL END) AS non_stream_avg_rt, 29 | SUM(CASE WHEN is_streaming = 1 THEN 1 ELSE 0 END) AS stream_count, 30 | SUM(CASE WHEN is_streaming = 0 THEN 1 ELSE 0 END) AS non_stream_count 31 | FROM logs 32 | WHERE time >= ? AND time <= ? AND auth_token_id > 0 33 | GROUP BY auth_token_id 34 | ` 35 | 36 | rows, err := s.db.QueryContext(ctx, query, sinceMs, untilMs) 37 | if err != nil { 38 | return nil, err 39 | } 40 | defer rows.Close() 41 | 42 | stats := make(map[int64]*model.AuthTokenRangeStats) 43 | for rows.Next() { 44 | var tokenID int64 45 | var stat model.AuthTokenRangeStats 46 | var streamAvgTTFB, nonStreamAvgRT sql.NullFloat64 47 | 48 | if err := rows.Scan(&tokenID, &stat.SuccessCount, &stat.FailureCount, 49 | &stat.PromptTokens, &stat.CompletionTokens, 50 | &stat.CacheReadTokens, &stat.CacheCreationTokens, 51 | &stat.TotalCost, 52 | &streamAvgTTFB, &nonStreamAvgRT, 53 | &stat.StreamCount, &stat.NonStreamCount); err != nil { 54 | return nil, err 55 | } 56 | 57 | // 处理NULL值(当没有该类型请求时AVG返回NULL) 58 | if streamAvgTTFB.Valid { 59 | stat.StreamAvgTTFB = streamAvgTTFB.Float64 60 | } 61 | if nonStreamAvgRT.Valid { 62 | stat.NonStreamAvgRT = nonStreamAvgRT.Float64 63 | } 64 | 65 | stats[tokenID] = &stat 66 | } 67 | 68 | return stats, rows.Err() 69 | } 70 | 71 | // FillAuthTokenRPMStats 计算每个token的RPM统计(峰值、平均、最近) 72 | // 直接修改传入的stats map中的RPM字段 73 | func (s *SQLStore) FillAuthTokenRPMStats(ctx context.Context, stats map[int64]*model.AuthTokenRangeStats, startTime, endTime time.Time, isToday bool) error { 74 | if len(stats) == 0 { 75 | return nil 76 | } 77 | 78 | sinceMs := startTime.UnixMilli() 79 | untilMs := endTime.UnixMilli() 80 | 81 | // 计算时间跨度(秒) 82 | durationSeconds := 
endTime.Sub(startTime).Seconds() 83 | if durationSeconds < 1 { 84 | durationSeconds = 1 85 | } 86 | 87 | // 1. 计算平均RPM = 总请求数 × 60 / 时间范围秒数 88 | for _, stat := range stats { 89 | totalCount := stat.SuccessCount + stat.FailureCount 90 | stat.AvgRPM = float64(totalCount) * 60 / durationSeconds 91 | } 92 | 93 | // 2. 计算峰值RPM(每分钟请求数的最大值) 94 | peakQuery := ` 95 | SELECT auth_token_id, MAX(cnt) AS peak_rpm 96 | FROM ( 97 | SELECT auth_token_id, COUNT(*) AS cnt 98 | FROM logs 99 | WHERE time >= ? AND time <= ? AND auth_token_id > 0 100 | GROUP BY auth_token_id, FLOOR(time / 60000) 101 | ) t 102 | GROUP BY auth_token_id 103 | ` 104 | peakRows, err := s.db.QueryContext(ctx, peakQuery, sinceMs, untilMs) 105 | if err != nil { 106 | return err 107 | } 108 | defer peakRows.Close() 109 | 110 | for peakRows.Next() { 111 | var tokenID int64 112 | var peakRPM float64 113 | if err := peakRows.Scan(&tokenID, &peakRPM); err != nil { 114 | return err 115 | } 116 | if stat, ok := stats[tokenID]; ok { 117 | stat.PeakRPM = peakRPM 118 | } 119 | } 120 | 121 | // 3. 计算最近一分钟RPM(仅本日有效) 122 | if isToday { 123 | now := time.Now() 124 | recentStartMs := now.Add(-60 * time.Second).UnixMilli() 125 | recentEndMs := now.UnixMilli() 126 | 127 | recentQuery := ` 128 | SELECT auth_token_id, COUNT(*) AS cnt 129 | FROM logs 130 | WHERE time >= ? AND time <= ? 
AND auth_token_id > 0 131 | GROUP BY auth_token_id 132 | ` 133 | recentRows, err := s.db.QueryContext(ctx, recentQuery, recentStartMs, recentEndMs) 134 | if err != nil { 135 | return err 136 | } 137 | defer recentRows.Close() 138 | 139 | for recentRows.Next() { 140 | var tokenID int64 141 | var recentRPM float64 142 | if err := recentRows.Scan(&tokenID, &recentRPM); err != nil { 143 | return err 144 | } 145 | if stat, ok := stats[tokenID]; ok { 146 | stat.RecentRPM = recentRPM 147 | // 峰值必须 >= 最近值 148 | if stat.PeakRPM < recentRPM { 149 | stat.PeakRPM = recentRPM 150 | } 151 | } 152 | } 153 | } 154 | 155 | return nil 156 | } 157 | -------------------------------------------------------------------------------- /web/assets/js/settings.js: -------------------------------------------------------------------------------- 1 | // 系统设置页面 2 | initTopbar('settings'); 3 | 4 | let originalSettings = {}; // 保存原始值用于比较 5 | 6 | async function loadSettings() { 7 | try { 8 | const data = await fetchDataWithAuth('/admin/settings'); 9 | if (!Array.isArray(data)) throw new Error('响应不是数组'); 10 | renderSettings(data); 11 | } catch (err) { 12 | console.error('加载配置异常:', err); 13 | showError('加载配置异常: ' + err.message); 14 | } 15 | } 16 | 17 | function renderSettings(settings) { 18 | const tbody = document.getElementById('settings-tbody'); 19 | originalSettings = {}; 20 | tbody.innerHTML = ''; 21 | 22 | // 初始化事件委托(仅一次) 23 | initSettingsEventDelegation(); 24 | 25 | settings.forEach(s => { 26 | originalSettings[s.key] = s.value; 27 | const row = TemplateEngine.render('tpl-setting-row', { 28 | key: s.key, 29 | description: s.description, 30 | inputHtml: renderInput(s) 31 | }); 32 | if (row) tbody.appendChild(row); 33 | }); 34 | } 35 | 36 | // 初始化事件委托(替代 inline onclick) 37 | function initSettingsEventDelegation() { 38 | const tbody = document.getElementById('settings-tbody'); 39 | if (!tbody || tbody.dataset.delegated) return; 40 | tbody.dataset.delegated = 'true'; 41 | 42 | // 重置按钮点击 43 | 
tbody.addEventListener('click', (e) => { 44 | const resetBtn = e.target.closest('.setting-reset-btn'); 45 | if (resetBtn) { 46 | resetSetting(resetBtn.dataset.key); 47 | } 48 | }); 49 | 50 | // 输入变更 51 | tbody.addEventListener('change', (e) => { 52 | const input = e.target.closest('input'); 53 | if (input) markChanged(input); 54 | }); 55 | } 56 | 57 | function renderInput(setting) { 58 | const safeKey = escapeHtml(setting.key); 59 | const safeValue = escapeHtml(setting.value); 60 | const baseStyle = 'padding: 6px 10px; border: 1px solid var(--color-border); border-radius: 6px; background: var(--color-bg-secondary); color: var(--color-text); font-size: 13px;'; 61 | 62 | switch (setting.value_type) { 63 | case 'bool': 64 | const checked = setting.value === 'true' || setting.value === '1'; 65 | return ``; 66 | case 'int': 67 | case 'duration': 68 | return ``; 69 | default: 70 | return ``; 71 | } 72 | } 73 | 74 | function markChanged(input) { 75 | const key = input.id; 76 | const row = input.closest('tr'); 77 | 78 | const currentValue = input.type === 'checkbox' ? (input.checked ? 'true' : 'false') : input.value; 79 | if (currentValue !== originalSettings[key]) { 80 | row.style.background = 'rgba(59, 130, 246, 0.08)'; 81 | } else { 82 | row.style.background = ''; 83 | } 84 | } 85 | 86 | async function saveAllSettings() { 87 | // 收集所有变更 88 | const updates = {}; 89 | const needsRestartKeys = []; 90 | 91 | for (const key of Object.keys(originalSettings)) { 92 | const input = document.getElementById(key); 93 | if (!input) continue; 94 | 95 | const currentValue = input.type === 'checkbox' ? (input.checked ? 
'true' : 'false') : input.value; 96 | if (currentValue !== originalSettings[key]) { 97 | updates[key] = currentValue; 98 | // 检查是否需要重启(从 DOM 中读取 description) 99 | const row = input.closest('tr'); 100 | if (row?.querySelector('td')?.textContent?.includes('[需重启]')) { 101 | needsRestartKeys.push(key); 102 | } 103 | } 104 | } 105 | 106 | if (Object.keys(updates).length === 0) { 107 | showInfo('没有需要保存的更改'); 108 | return; 109 | } 110 | 111 | // 使用批量更新接口(单次请求,事务保护) 112 | try { 113 | await fetchDataWithAuth('/admin/settings/batch', { 114 | method: 'POST', 115 | headers: { 'Content-Type': 'application/json' }, 116 | body: JSON.stringify(updates) 117 | }); 118 | let msg = `已保存 ${Object.keys(updates).length} 项配置`; 119 | if (needsRestartKeys.length > 0) { 120 | msg += `\n\n以下配置需要重启服务才能生效:\n${needsRestartKeys.join(', ')}`; 121 | } 122 | showSuccess(msg); 123 | } catch (err) { 124 | console.error('保存异常:', err); 125 | showError('保存异常: ' + err.message); 126 | } 127 | 128 | loadSettings(); 129 | } 130 | 131 | async function resetSetting(key) { 132 | if (!confirm(`确定要重置 "${key}" 为默认值吗?`)) return; 133 | 134 | try { 135 | await fetchDataWithAuth(`/admin/settings/${key}/reset`, { method: 'POST' }); 136 | showSuccess(`配置 ${key} 已重置为默认值`); 137 | loadSettings(); 138 | } catch (err) { 139 | console.error('重置异常:', err); 140 | showError('重置异常: ' + err.message); 141 | } 142 | } 143 | 144 | // showSuccess/showError 已在 ui.js 中定义(toast 通知),无需重复定义 145 | function showInfo(msg) { 146 | window.showNotification(msg, 'info'); 147 | } 148 | 149 | // 页面加载时执行 150 | loadSettings(); 151 | -------------------------------------------------------------------------------- /web/assets/css/channels.css: -------------------------------------------------------------------------------- 1 | /* 响应式布局样式 */ 2 | @media (max-width: 640px) { 3 | /* 移动端:垂直布局 */ 4 | .form-row-flex { 5 | flex-direction: column !important; 6 | } 7 | .form-row-flex > div { 8 | width: 100% !important; 9 | } 10 | } 11 | 12 | /* 优化表单行布局 */ 13 | 
.form-row-flex {
  display: flex;
  align-items: flex-start;
  gap: 12px;
}

/* Radio buttons */
input[type="radio"] {
  margin: 0;
  width: 16px;
  height: 16px;
  accent-color: var(--primary-500);
  cursor: pointer;
}

input[type="radio"]:hover {
  transform: scale(1.1);
}

/* Labels wrapping the channel-type and key-strategy radio buttons */
#channelTypeRadios label,
#keyStrategyRadios label {
  padding: 4px 8px;
  border-radius: 6px;
  transition: all 0.2s;
}

#channelTypeRadios label:hover,
#keyStrategyRadios label:hover {
  background-color: var(--neutral-100);
}

#channelTypeRadios label:has(input:checked),
#keyStrategyRadios label:has(input:checked) {
  font-weight: 500;
  color: var(--primary-700);
  background-color: var(--primary-100);
}

/* Toast slide-in / slide-out animations */
@keyframes slideIn {
  0% {
    opacity: 0;
    transform: translateX(400px);
  }
  100% {
    opacity: 1;
    transform: translateX(0);
  }
}

@keyframes slideOut {
  0% {
    opacity: 1;
    transform: translateX(0);
  }
  100% {
    opacity: 0;
    transform: translateX(400px);
  }
}

/* Channel statistic badges */
.channel-stat-badge {
  display: inline-flex;
  align-items: center;
  gap: 4px;
  padding: 4px 8px;
  font-size: 12px;
  line-height: 1.3;
  color: var(--neutral-700);
  background: var(--neutral-100);
  border: 1px solid var(--neutral-200);
  border-radius: 8px;
}

.channel-stat-badge strong {
  color: var(--neutral-900);
}

.channel-meta-line {
  display: flex;
  flex-wrap: wrap;
  align-items: center;
  justify-content: space-between;
  gap: 12px;
}

.channel-stats-inline {
  display: flex;
  flex-wrap: wrap;
  align-items: center;
  justify-content: flex-end;
  gap: 6px;
  margin-left: auto;
  margin-right: 32px;
}

/* Table / chart view toggle */
.view-toggle-group {
  display: flex;
  gap: 4px;
  padding: 4px;
  background: var(--neutral-100);
  border-radius: 8px;
}

.view-toggle-btn {
  display: flex;
  align-items: center;
  gap: 6px;
  padding: 6px 12px;
  font-size: 13px;
  font-weight: 500;
  color: var(--neutral-600);
  background: transparent;
  border: none;
  border-radius: 6px;
  cursor: pointer;
  transition: all 0.2s;
}

.view-toggle-btn:hover {
  background: var(--neutral-200);
  color: var(--neutral-800);
}

.view-toggle-btn.active {
  color: var(--primary-600);
  background: var(--white);
  box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
}

.view-toggle-btn svg {
  width: 16px;
  height: 16px;
}

/* Chart grid: two columns, one column on narrow screens */
.charts-grid {
  display: grid;
  grid-template-columns: 1fr 1fr;
  gap: 24px;
  padding: 8px 0;
}

@media (max-width: 768px) {
  .charts-grid {
    grid-template-columns: 1fr;
  }
}

/* Chart card */
.chart-card {
  padding: 16px;
  background: var(--white);
  border: 1px solid var(--neutral-200);
  border-radius: 12px;
}

.chart-title {
  margin: 0 0 12px;
  font-size: 14px;
  font-weight: 600;
  color: var(--neutral-700);
  text-align: center;
}

/* Pie chart container */
.pie-chart-container {
  width: 100%;
  height: 280px;
}

/* Drag and drop for API key rows */
.draggable-key-row {
  cursor: move;
  transition: background-color 0.2s;
}

.draggable-key-row.dragging {
  background-color: var(--neutral-100);
  opacity: 0.5;
}

.draggable-key-row.drag-over {
  border-top: 2px solid var(--primary-500) !important;
}
-------------------------------------------------------------------------------- /internal/storage/cache_metrics_test.go: -------------------------------------------------------------------------------- 1 | package storage_test 2 | 3 | import ( 4 | "context" 5 | "path/filepath" 6 | "testing" 7 | "time" 8 | 9 | "ccLoad/internal/model" 10 | "ccLoad/internal/storage" 11 | ) 12 | 13 | func TestChannelCacheMetrics(t *testing.T) { 14 | ctx := context.Background() 15 | tmpDir := t.TempDir() 16 | 17 | dbPath := filepath.Join(tmpDir, "metrics.db") 18 | store, err := storage.CreateSQLiteStore(dbPath, nil) 19 | if err != nil { 20 | t.Fatalf("failed to create sqlite store: %v", err) 21 | } 22 | defer store.Close() 23 | 24 | cache := storage.NewChannelCache(store, time.Minute) 25 | 26 | cfg := &model.Config{ 27 | Name: "test-channel", 28 | URL: "https://example.com", 29 | Priority: 10, 30 | Models: []string{"model-a"}, 31 | Enabled: true, 32 | } 33 | created, err := store.CreateConfig(ctx, cfg) 34 | if err != nil { 35 | t.Fatalf("failed to create config: %v", err) 36 | } 37 | 38 | now := time.Now() 39 | apiKey := &model.APIKey{ 40 | ChannelID: created.ID, 41 | KeyIndex: 0, 42 | APIKey: "sk-test", 43 | KeyStrategy: model.KeyStrategySequential, 44 | CreatedAt: model.JSONTime{Time: now}, 45 | UpdatedAt: model.JSONTime{Time: now}, 46 | } 47 | if err := store.CreateAPIKey(ctx, apiKey); err != nil { 48 | t.Fatalf("failed to create api key: %v", err) 49 | } 50 | 51 | if _, err := cache.GetAPIKeys(ctx, created.ID); err != nil { 52 | t.Fatalf("unexpected error getting api keys: %v", err) 53 | } 54 | 55 | if _, err := cache.GetAPIKeys(ctx, created.ID); err != nil { 56 | t.Fatalf("unexpected error getting api keys (cached): %v", err) 57 | } 58 | 59 | cache.InvalidateAPIKeysCache(created.ID) 60 | 61 | stats := cache.GetCacheStats() 62 | 63 | if hits, ok := stats["api_keys_hits"].(uint64); !ok || hits != 1 { 64 | t.Fatalf("expected 1 api key hit, got %v", stats["api_keys_hits"]) 65 | } 66 | 
if misses, ok := stats["api_keys_misses"].(uint64); !ok || misses != 1 { 67 | t.Fatalf("expected 1 api key miss, got %v", stats["api_keys_misses"]) 68 | } 69 | if invalidations, ok := stats["api_keys_invalidations"].(uint64); !ok || invalidations != 1 { 70 | t.Fatalf("expected 1 api key invalidation, got %v", stats["api_keys_invalidations"]) 71 | } 72 | } 73 | 74 | // TestChannelCacheDeepCopy 验证缓存返回深拷贝,防止并发污染 75 | // [REGRESSION] 防止回归到浅拷贝实现(只拷贝slice不拷贝对象) 76 | func TestChannelCacheDeepCopy(t *testing.T) { 77 | ctx := context.Background() 78 | tmpDir := t.TempDir() 79 | 80 | dbPath := filepath.Join(tmpDir, "deepcopy.db") 81 | store, err := storage.CreateSQLiteStore(dbPath, nil) 82 | if err != nil { 83 | t.Fatalf("failed to create sqlite store: %v", err) 84 | } 85 | defer store.Close() 86 | 87 | cache := storage.NewChannelCache(store, time.Minute) 88 | 89 | // 创建测试渠道和API Key 90 | cfg := &model.Config{ 91 | Name: "test-channel", 92 | URL: "https://example.com", 93 | Priority: 10, 94 | Models: []string{"model-a"}, 95 | Enabled: true, 96 | } 97 | created, err := store.CreateConfig(ctx, cfg) 98 | if err != nil { 99 | t.Fatalf("failed to create config: %v", err) 100 | } 101 | 102 | now := time.Now() 103 | apiKey := &model.APIKey{ 104 | ChannelID: created.ID, 105 | KeyIndex: 0, 106 | APIKey: "sk-original-key", 107 | KeyStrategy: model.KeyStrategySequential, 108 | CreatedAt: model.JSONTime{Time: now}, 109 | UpdatedAt: model.JSONTime{Time: now}, 110 | } 111 | if err := store.CreateAPIKey(ctx, apiKey); err != nil { 112 | t.Fatalf("failed to create api key: %v", err) 113 | } 114 | 115 | // 第一次获取(填充缓存) 116 | keys1, err := cache.GetAPIKeys(ctx, created.ID) 117 | if err != nil { 118 | t.Fatalf("unexpected error getting api keys: %v", err) 119 | } 120 | if len(keys1) != 1 { 121 | t.Fatalf("expected 1 api key, got %d", len(keys1)) 122 | } 123 | 124 | // 第二次获取(从缓存读取) 125 | keys2, err := cache.GetAPIKeys(ctx, created.ID) 126 | if err != nil { 127 | t.Fatalf("unexpected error getting 
cached api keys: %v", err) 128 | } 129 | if len(keys2) != 1 { 130 | t.Fatalf("expected 1 api key, got %d", len(keys2)) 131 | } 132 | 133 | // 关键验证:修改keys1不应影响keys2 134 | originalKey := keys2[0].APIKey 135 | keys1[0].APIKey = "sk-POLLUTED-KEY" 136 | keys1[0].KeyIndex = 999 137 | 138 | // 第三次获取,验证缓存未被污染 139 | keys3, err := cache.GetAPIKeys(ctx, created.ID) 140 | if err != nil { 141 | t.Fatalf("unexpected error getting api keys after modification: %v", err) 142 | } 143 | if len(keys3) != 1 { 144 | t.Fatalf("expected 1 api key, got %d", len(keys3)) 145 | } 146 | 147 | // 验证:keys3应该保留原始值,不受keys1修改的影响 148 | if keys3[0].APIKey != originalKey { 149 | t.Fatalf("cache pollution detected! expected key=%q, got key=%q", originalKey, keys3[0].APIKey) 150 | } 151 | if keys3[0].KeyIndex != 0 { 152 | t.Fatalf("cache pollution detected! expected KeyIndex=0, got KeyIndex=%d", keys3[0].KeyIndex) 153 | } 154 | 155 | // 额外验证:keys2也不应被修改(因为是深拷贝) 156 | if keys2[0].APIKey != originalKey { 157 | t.Fatalf("shallow copy detected! keys2 was modified by keys1 mutation") 158 | } 159 | if keys2[0].KeyIndex != 0 { 160 | t.Fatalf("shallow copy detected! 
keys2 KeyIndex was modified by keys1 mutation") 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /internal/app/proxy_handler_test.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/gin-gonic/gin" 10 | ) 11 | 12 | func TestHandleProxyRequest_UnknownPathReturns404(t *testing.T) { 13 | gin.SetMode(gin.TestMode) 14 | 15 | srv := &Server{ 16 | concurrencySem: make(chan struct{}, 1), 17 | } 18 | 19 | body := bytes.NewBufferString(`{"model":"gpt-4"}`) 20 | req := httptest.NewRequest(http.MethodPost, "/v1/unknown", body) 21 | req.Header.Set("Content-Type", "application/json") 22 | 23 | w := httptest.NewRecorder() 24 | c, _ := gin.CreateTestContext(w) 25 | c.Request = req 26 | 27 | srv.HandleProxyRequest(c) 28 | 29 | if w.Code != http.StatusNotFound { 30 | t.Fatalf("预期状态码404,实际%d", w.Code) 31 | } 32 | 33 | if body := w.Body.String(); !bytes.Contains([]byte(body), []byte("unsupported path")) { 34 | t.Fatalf("响应内容缺少错误信息,实际: %s", body) 35 | } 36 | } 37 | 38 | // ============================================================================ 39 | // 增加proxy_handler测试覆盖率 40 | // ============================================================================ 41 | 42 | // TestParseIncomingRequest_ValidJSON 测试有效JSON解析 43 | func TestParseIncomingRequest_ValidJSON(t *testing.T) { 44 | gin.SetMode(gin.TestMode) 45 | 46 | tests := []struct { 47 | name string 48 | body string 49 | path string 50 | expectModel string 51 | expectStream bool 52 | expectError bool 53 | }{ 54 | { 55 | name: "有效JSON-claude模型", 56 | body: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":"hello"}]}`, 57 | path: "/v1/messages", 58 | expectModel: "claude-3-5-sonnet-20241022", 59 | expectStream: false, 60 | expectError: false, 61 | }, 62 | { 63 | name: "流式请求-stream=true", 64 | body: 
`{"model":"gpt-4","stream":true,"messages":[]}`, 65 | path: "/v1/chat/completions", 66 | expectModel: "gpt-4", 67 | expectStream: true, 68 | expectError: false, 69 | }, 70 | { 71 | name: "空模型名-从路径提取", 72 | body: `{"messages":[{"role":"user","content":"test"}]}`, 73 | path: "/v1/models/gpt-4/completions", 74 | expectModel: "gpt-4", 75 | expectStream: false, 76 | expectError: false, 77 | }, 78 | { 79 | name: "GET请求-无模型使用通配符", 80 | body: "", 81 | path: "/v1/models", 82 | expectModel: "*", 83 | expectStream: false, 84 | expectError: false, 85 | }, 86 | } 87 | 88 | for _, tt := range tests { 89 | t.Run(tt.name, func(t *testing.T) { 90 | body := bytes.NewBufferString(tt.body) 91 | req := httptest.NewRequest(http.MethodPost, tt.path, body) 92 | if tt.body == "" { 93 | req.Method = http.MethodGet 94 | } 95 | req.Header.Set("Content-Type", "application/json") 96 | 97 | w := httptest.NewRecorder() 98 | c, _ := gin.CreateTestContext(w) 99 | c.Request = req 100 | 101 | model, _, isStreaming, err := parseIncomingRequest(c) 102 | 103 | if tt.expectError && err == nil { 104 | t.Errorf("期望错误但未发生") 105 | } 106 | if !tt.expectError && err != nil { 107 | t.Errorf("不期望错误但发生: %v", err) 108 | } 109 | if model != tt.expectModel { 110 | t.Errorf("模型名错误: 期望%s, 实际%s", tt.expectModel, model) 111 | } 112 | if isStreaming != tt.expectStream { 113 | t.Errorf("流式标志错误: 期望%v, 实际%v", tt.expectStream, isStreaming) 114 | } 115 | }) 116 | } 117 | } 118 | 119 | // TestParseIncomingRequest_BodyTooLarge 测试请求体过大 120 | func TestParseIncomingRequest_BodyTooLarge(t *testing.T) { 121 | gin.SetMode(gin.TestMode) 122 | 123 | // 创建超大请求体(>2MB) 124 | largeBody := make([]byte, 3*1024*1024) // 3MB 125 | for i := range largeBody { 126 | largeBody[i] = 'a' 127 | } 128 | 129 | req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(largeBody)) 130 | req.Header.Set("Content-Type", "application/json") 131 | 132 | w := httptest.NewRecorder() 133 | c, _ := gin.CreateTestContext(w) 134 | c.Request = req 
135 | 136 | _, _, _, err := parseIncomingRequest(c) 137 | 138 | if err != errBodyTooLarge { 139 | t.Errorf("期望errBodyTooLarge错误, 实际: %v", err) 140 | } 141 | } 142 | 143 | // TestAcquireConcurrencySlot 测试并发槽位获取 144 | func TestAcquireConcurrencySlot(t *testing.T) { 145 | gin.SetMode(gin.TestMode) 146 | 147 | srv := &Server{ 148 | concurrencySem: make(chan struct{}, 2), // 最大并发数=2 149 | maxConcurrency: 2, 150 | } 151 | 152 | // 创建有效的gin.Context 153 | req := httptest.NewRequest(http.MethodPost, "/test", nil) 154 | w := httptest.NewRecorder() 155 | c, _ := gin.CreateTestContext(w) 156 | c.Request = req 157 | 158 | // 第一次获取应该成功 159 | release1, acquired1 := srv.acquireConcurrencySlot(c) 160 | if !acquired1 { 161 | t.Fatal("第一次获取应该成功") 162 | } 163 | 164 | // 第二次获取应该成功 165 | release2, acquired2 := srv.acquireConcurrencySlot(c) 166 | if !acquired2 { 167 | t.Fatal("第二次获取应该成功") 168 | } 169 | 170 | // 释放一个槽位 171 | release1() 172 | 173 | // 现在应该可以再次获取 174 | release3, acquired3 := srv.acquireConcurrencySlot(c) 175 | if !acquired3 { 176 | t.Fatal("释放后再次获取应该成功") 177 | } 178 | 179 | // 清理 180 | release2() 181 | release3() 182 | 183 | t.Log("[INFO] 并发控制测试通过:2个槽位正确管理") 184 | } 185 | --------------------------------------------------------------------------------
配置运行时参数 · 实时生效