├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── config.yml │ └── feature_request.md └── workflows │ ├── docker-image.yml │ └── lint.yml ├── .gitignore ├── .golangci.yaml ├── .pre-commit-config.yaml ├── Dockerfile ├── LICENSE ├── Midjourney.md ├── README.md ├── VERSION ├── bin ├── migration_v0.2-v0.3.sql ├── migration_v0.3-v0.4.sql └── time_test.sh ├── common ├── constants.go ├── crypto.go ├── custom-event.go ├── database.go ├── email.go ├── embed-file-system.go ├── gin.go ├── go-channel.go ├── group-ratio.go ├── image.go ├── init.go ├── logger.go ├── model-ratio.go ├── pprof.go ├── rate-limit.go ├── redis.go ├── topup-ratio.go ├── utils.go ├── validate.go └── verification.go ├── controller ├── billing.go ├── channel-billing.go ├── channel-test.go ├── channel.go ├── github.go ├── group.go ├── log.go ├── midjourney.go ├── misc.go ├── model.go ├── option.go ├── redemption.go ├── relay.go ├── telegram.go ├── token.go ├── topup.go ├── usedata.go ├── user.go └── wechat.go ├── docker-compose.yml ├── dto ├── audio.go ├── dalle.go ├── error.go ├── midjourney.go ├── text_request.go └── text_response.go ├── go.mod ├── go.sum ├── i18n ├── en.json └── translate.py ├── main.go ├── makefile ├── middleware ├── auth.go ├── cache.go ├── cors.go ├── distributor.go ├── logger.go ├── rate-limit.go ├── recover.go ├── request-id.go ├── turnstile-check.go └── utils.go ├── model ├── ability.go ├── cache.go ├── channel.go ├── log.go ├── main.go ├── midjourney.go ├── option.go ├── redemption.go ├── token.go ├── topup.go ├── usedata.go ├── user.go ├── user_checkin.go └── utils.go ├── one-api.service ├── relay ├── channel │ ├── adapter.go │ ├── ai360 │ │ └── constants.go │ ├── ali │ │ ├── adaptor.go │ │ ├── constants.go │ │ ├── dto.go │ │ └── relay-ali.go │ ├── api_request.go │ ├── baidu │ │ ├── adaptor.go │ │ ├── constants.go │ │ ├── dto.go │ │ └── relay-baidu.go │ ├── claude │ │ ├── adaptor.go │ │ ├── constants.go │ │ ├── dto.go │ │ └── relay-claude.go │ ├── gemini │ │ ├── 
adaptor.go │ │ ├── constant.go │ │ ├── dto.go │ │ └── relay-gemini.go │ ├── moonshot │ │ └── constants.go │ ├── openai │ │ ├── adaptor.go │ │ ├── constant.go │ │ └── relay-openai.go │ ├── palm │ │ ├── adaptor.go │ │ ├── constants.go │ │ ├── dto.go │ │ └── relay-palm.go │ ├── tencent │ │ ├── adaptor.go │ │ ├── constants.go │ │ ├── dto.go │ │ └── relay-tencent.go │ ├── xunfei │ │ ├── adaptor.go │ │ ├── constants.go │ │ ├── dto.go │ │ └── relay-xunfei.go │ ├── zhipu │ │ ├── adaptor.go │ │ ├── constants.go │ │ ├── dto.go │ │ └── relay-zhipu.go │ └── zhipu_4v │ │ ├── adaptor.go │ │ ├── constants.go │ │ ├── dto.go │ │ └── relay-zhipu_v4.go ├── common │ ├── relay_info.go │ └── relay_utils.go ├── constant │ ├── api_type.go │ └── relay_mode.go ├── relay-audio.go ├── relay-image.go ├── relay-mj.go ├── relay-text.go └── relay_adaptor.go ├── router ├── api-router.go ├── dashboard.go ├── main.go ├── relay-router.go └── web-router.go ├── service ├── channel.go ├── epay.go ├── error.go ├── http_client.go ├── sse.go ├── token_counter.go ├── usage_helpr.go └── user_notify.go └── web ├── .gitignore ├── README.md ├── package.json ├── public ├── favicon.ico ├── index.html ├── logo.png └── robots.txt ├── src ├── App.js ├── components │ ├── ChannelsTable.js │ ├── Footer.js │ ├── GitHubOAuth.js │ ├── HeaderBar.js │ ├── Loading.js │ ├── LoginForm.js │ ├── LogsTable.js │ ├── MjLogsTable.js │ ├── OperationSetting.js │ ├── OtherSetting.js │ ├── PasswordResetConfirm.js │ ├── PasswordResetForm.js │ ├── PersonalSetting.js │ ├── PrivateRoute.js │ ├── RedemptionsTable.js │ ├── RegisterForm.js │ ├── SiderBar.js │ ├── SystemSetting.js │ ├── TokensTable.js │ ├── UsersTable.js │ ├── WeChatIcon.js │ └── utils.js ├── constants │ ├── channel.constants.js │ ├── common.constant.js │ ├── index.js │ ├── toast.constants.js │ └── user.constants.js ├── context │ ├── Status │ │ ├── index.js │ │ └── reducer.js │ └── User │ │ ├── index.js │ │ └── reducer.js ├── helpers │ ├── api.js │ ├── auth-header.js │ ├── 
history.js │ ├── index.js │ ├── render.js │ └── utils.js ├── index.css ├── index.js └── pages │ ├── About │ └── index.js │ ├── Channel │ ├── EditChannel.js │ └── index.js │ ├── Chat │ └── index.js │ ├── Detail │ └── index.js │ ├── Home │ └── index.js │ ├── Log │ └── index.js │ ├── Midjourney │ └── index.js │ ├── NotFound │ └── index.js │ ├── Redemption │ ├── EditRedemption.js │ └── index.js │ ├── Setting │ └── index.js │ ├── Token │ ├── EditToken.js │ └── index.js │ ├── TopUp │ └── index.js │ └── User │ ├── AddUser.js │ ├── EditUser.js │ └── index.js └── vercel.json /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 报告问题 3 | about: 使用简练详细的语言描述你遇到的问题 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **例行检查** 11 | 12 | [//]: # (方框内删除已有的空格,填 x 号) 13 | + [ ] 我已确认目前没有类似 issue 14 | + [ ] 我已确认我已升级到最新版本 15 | + [ ] 我已完整查看过项目 README,尤其是常见问题部分 16 | + [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈 17 | + [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭** 18 | 19 | **问题描述** 20 | 21 | **复现步骤** 22 | 23 | **预期结果** 24 | 25 | **相关截图** 26 | 如果没有的话,请删除此节。 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: 项目群聊 4 | url: 
https://private-user-images.githubusercontent.com/61247483/283011625-de536a8a-0161-47a7-a0a2-66ef6de81266.jpeg?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTEiLCJleHAiOjE3MDIyMjQzOTAsIm5iZiI6MTcwMjIyNDA5MCwicGF0aCI6Ii82MTI0NzQ4My8yODMwMTE2MjUtZGU1MzZhOGEtMDE2MS00N2E3LWEwYTItNjZlZjZkZTgxMjY2LmpwZWc_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBSVdOSllBWDRDU1ZFSDUzQSUyRjIwMjMxMjEwJTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDIzMTIxMFQxNjAxMzBaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT02MGIxYmM3ZDQyYzBkOTA2ZTYyYmVmMzQ1NjY4NjM1YjY0NTUzNTM5NjE1NDZkYTIzODdhYTk4ZjZjODJmYzY2JlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.TJ8CTfOSwR0-CHS1KLfomqgL0e4YH1luy8lSLrkv5Zg 5 | about: QQ 群:629454374 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 功能请求 3 | about: 使用简练详细的语言描述希望加入的新功能 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **例行检查** 11 | 12 | [//]: # (方框内删除已有的空格,填 x 号) 13 | + [ ] 我已确认目前没有类似 issue 14 | + [ ] 我已确认我已升级到最新版本 15 | + [ ] 我已完整查看过项目 README,已确定现有版本无法满足需求 16 | + [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈 17 | + [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭** 18 | 19 | **功能描述** 20 | 21 | **应用场景** 22 | -------------------------------------------------------------------------------- /.github/workflows/docker-image.yml: -------------------------------------------------------------------------------- 1 | name: Publish Docker image (amd64) 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | - '!*-alpha*' 8 | workflow_dispatch: 9 | inputs: 10 | name: 11 | description: 'reason' 12 | required: false 13 | jobs: 14 | push_to_registries: 15 | name: Push Docker image to registries 16 | runs-on: ubuntu-latest 17 | permissions: 18 | 
packages: write 19 | contents: read 20 | steps: 21 | - name: Check out the repo 22 | uses: actions/checkout@v3 23 | 24 | - name: Save version info 25 | run: | 26 | git rev-parse --short HEAD > VERSION 27 | 28 | - name: Log in to Docker Hub 29 | uses: docker/login-action@v3 30 | with: 31 | registry: sjc.vultrcr.com 32 | ecr: false 33 | username: ${{secrets.DOCKER_USERNAME}} 34 | password: ${{secrets.DOCKER_PASSWORD}} 35 | 36 | - name: Extract metadata (tags, labels) for Docker 37 | id: meta 38 | uses: docker/metadata-action@v4 39 | with: 40 | images: | 41 | sjc.vultrcr.com/ehcotest/new-api 42 | 43 | - name: Build and push Docker images 44 | uses: docker/build-push-action@v3 45 | with: 46 | context: . 47 | platforms: linux/amd64 48 | push: true 49 | tags: ${{ steps.meta.outputs.tags }} 50 | labels: ${{ steps.meta.outputs.labels }} 51 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | jobs: 9 | lint: 10 | name: Lint 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: actions/setup-go@v5 15 | 16 | - name: Install tools 17 | run: make install-dev 18 | 19 | - run: | 20 | mkdir -p web/build 21 | touch web/build/index.html 22 | 23 | - uses: pre-commit/action@v3.0.1 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | upload 4 | *.exe 5 | *.db 6 | build 7 | *.db-journal 8 | logs 9 | -------------------------------------------------------------------------------- /.golangci.yaml: -------------------------------------------------------------------------------- 1 | linters: 2 | disable: 3 | - unused 4 | 
-------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.5.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-added-large-files 11 | - repo: https://github.com/dnephin/pre-commit-golang 12 | rev: v0.5.1 13 | hooks: 14 | - id: go-fmt 15 | - id: go-vet 16 | - id: go-imports 17 | # - id: go-cyclo 18 | # args: [-over=15] 19 | - id: validate-toml 20 | - id: no-go-testing 21 | - id: golangci-lint 22 | # - id: go-critic 23 | - id: go-unit-tests 24 | - id: go-build 25 | - id: go-mod-tidy 26 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM node:16 as builder 2 | 3 | WORKDIR /build 4 | COPY web/package.json . 5 | RUN npm install 6 | COPY ./web . 7 | COPY ./VERSION . 8 | RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build 9 | 10 | FROM golang AS builder2 11 | 12 | ENV GO111MODULE=on \ 13 | CGO_ENABLED=1 \ 14 | GOOS=linux 15 | 16 | WORKDIR /build 17 | ADD go.mod go.sum ./ 18 | RUN go mod download 19 | COPY . . 
20 | COPY --from=builder /build/build ./web/build 21 | RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api 22 | 23 | FROM alpine 24 | 25 | RUN apk update \ 26 | && apk upgrade \ 27 | && apk add --no-cache ca-certificates tzdata \ 28 | && update-ca-certificates 2>/dev/null || true 29 | 30 | COPY --from=builder2 /build/one-api / 31 | EXPOSE 3000 32 | WORKDIR /data 33 | ENTRYPOINT ["/one-api"] 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Calcium-Ion 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # New API 3 | 4 | > [!NOTE] 5 | > 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。 6 | > 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 7 | 8 | 9 | > [!WARNING] 10 | > 本项目为个人学习使用,不保证稳定性,且不提供任何技术支持,使用者必须在遵循 OpenAI 的使用条款以及法律法规的情况下使用,不得用于非法用途。 11 | > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 12 | 13 | > [!NOTE] 14 | > 最新版Docker镜像 calciumion/new-api:latest 15 | > 更新指令 docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR 16 | 17 | ## 主要变更 18 | 此分叉版本的主要变更如下: 19 | 20 | 1. 全新的UI界面(部分界面还待更新) 21 | 2. 添加[Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)接口的支持 22 | + [x] /mj/submit/imagine 23 | + [x] /mj/submit/change 24 | + [x] /mj/submit/blend 25 | + [x] /mj/submit/describe 26 | + [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**) 27 | + [x] /mj/task/{id}/fetch (此接口返回的图片地址为经过One API转发的地址) 28 | + [x] /task/list-by-condition 29 | 3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口: 30 | + [x] 易支付 31 | 4. 支持用key查询使用额度: 32 | + 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用 33 | 5. 渠道显示已使用额度,支持指定组织访问 34 | 6. 分页支持选择每页显示数量 35 | 7. 兼容原版One API的数据库,可直接使用原版数据库(one-api.db) 36 | 8. 支持模型按次数收费,可在 系统设置-运营设置 中设置 37 | 9. 支持渠道**加权随机** 38 | 10. 数据看板 39 | 11. 可设置令牌能调用的模型 40 | 12. 支持Telegram授权登录 41 | 42 | ## 模型支持 43 | 此版本额外支持以下模型: 44 | 1. 第三方模型 **gps** (gpt-4-gizmo-*) 45 | 2. 
智谱glm-4v,glm-4v识图 46 | 47 | 您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。 48 | 49 | ## 部署 50 | ### 基于 Docker 进行部署 51 | ```shell 52 | # 使用 SQLite 的部署命令: 53 | docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest 54 | # 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。 55 | # 例如: 56 | docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest 57 | ``` 58 | ### 使用宝塔面板Docker功能部署 59 | ```shell 60 | # 使用 SQLite 的部署命令: 61 | docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest 62 | # 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。 63 | # 例如: 64 | # 注意:数据库要开启远程访问,并且只允许服务器IP访问 65 | docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(宝塔的服务器地址:宝塔数据库端口)/宝塔数据库名称" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest 66 | # 注意:数据库要开启远程访问,并且只允许服务器IP访问 67 | ``` 68 | ## Midjourney接口设置文档 69 | [对接文档](Midjourney.md) 70 | 71 | ## 交流群 72 | 73 | 74 | ## 界面截图 75 |  76 | 77 |  78 | 79 |  80 |  81 |  82 |  83 | 夜间模式 84 |  85 | 86 |  87 |  88 | 89 | ## Star History 90 | 91 | [](https://star-history.com/#Calcium-Ion/new-api&Date) 92 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ehco1996/new-api/2f4e6e4e1d9129c9d5fa0a25a9207254d62c1f06/VERSION -------------------------------------------------------------------------------- /bin/migration_v0.2-v0.3.sql: -------------------------------------------------------------------------------- 1 | UPDATE users 2 | SET quota = quota + ( 3 | 
SELECT SUM(remain_quota) 4 | FROM tokens 5 | WHERE tokens.user_id = users.id 6 | ) 7 | -------------------------------------------------------------------------------- /bin/migration_v0.3-v0.4.sql: -------------------------------------------------------------------------------- 1 | INSERT INTO abilities (`group`, model, channel_id, enabled) 2 | SELECT c.`group`, m.model, c.id, 1 3 | FROM channels c 4 | CROSS JOIN ( 5 | SELECT 'gpt-3.5-turbo' AS model UNION ALL 6 | SELECT 'gpt-3.5-turbo-0301' AS model UNION ALL 7 | SELECT 'gpt-4' AS model UNION ALL 8 | SELECT 'gpt-4-0314' AS model 9 | ) AS m 10 | WHERE c.status = 1 11 | AND NOT EXISTS ( 12 | SELECT 1 13 | FROM abilities a 14 | WHERE a.`group` = c.`group` 15 | AND a.model = m.model 16 | AND a.channel_id = c.id 17 | ); 18 | -------------------------------------------------------------------------------- /bin/time_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ $# -lt 3 ]; then 4 | echo "Usage: time_test.sh []" 5 | exit 1 6 | fi 7 | 8 | domain=$1 9 | key=$2 10 | count=$3 11 | model=${4:-"gpt-3.5-turbo"} # 设置默认模型为 gpt-3.5-turbo 12 | 13 | total_time=0 14 | times=() 15 | 16 | for ((i=1; i<=count; i++)); do 17 | result=$(curl -o /dev/null -s -w "%{http_code} %{time_total}\\n" \ 18 | https://"$domain"/v1/chat/completions \ 19 | -H "Content-Type: application/json" \ 20 | -H "Authorization: Bearer $key" \ 21 | -d '{"messages": [{"content": "echo hi", "role": "user"}], "model": "'"$model"'", "stream": false, "max_tokens": 1}') 22 | http_code=$(echo "$result" | awk '{print $1}') 23 | time=$(echo "$result" | awk '{print $2}') 24 | echo "HTTP status code: $http_code, Time taken: $time" 25 | total_time=$(bc <<< "$total_time + $time") 26 | times+=("$time") 27 | done 28 | 29 | average_time=$(echo "scale=4; $total_time / $count" | bc) 30 | 31 | sum_of_squares=0 32 | for time in "${times[@]}"; do 33 | difference=$(echo "scale=4; $time - $average_time" | bc) 34 | 
square=$(echo "scale=4; $difference * $difference" | bc) 35 | sum_of_squares=$(echo "scale=4; $sum_of_squares + $square" | bc) 36 | done 37 | 38 | standard_deviation=$(echo "scale=4; sqrt($sum_of_squares / $count)" | bc) 39 | 40 | echo "Average time: $average_time±$standard_deviation" 41 | -------------------------------------------------------------------------------- /common/crypto.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import "golang.org/x/crypto/bcrypt" 4 | 5 | func Password2Hash(password string) (string, error) { 6 | passwordBytes := []byte(password) 7 | hashedPassword, err := bcrypt.GenerateFromPassword(passwordBytes, bcrypt.DefaultCost) 8 | return string(hashedPassword), err 9 | } 10 | 11 | func ValidatePasswordAndHash(password string, hash string) bool { 12 | err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) 13 | return err == nil 14 | } 15 | -------------------------------------------------------------------------------- /common/custom-event.go: -------------------------------------------------------------------------------- 1 | // Copyright 2014 Manu Martinez-Almeida. All rights reserved. 2 | // Use of this source code is governed by a MIT style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package common 6 | 7 | import ( 8 | "fmt" 9 | "io" 10 | "net/http" 11 | "strings" 12 | ) 13 | 14 | type stringWriter interface { 15 | io.Writer 16 | writeString(string) (int, error) 17 | } 18 | 19 | type stringWrapper struct { 20 | io.Writer 21 | } 22 | 23 | func (w stringWrapper) writeString(str string) (int, error) { 24 | return w.Writer.Write([]byte(str)) 25 | } 26 | 27 | func checkWriter(writer io.Writer) stringWriter { 28 | if w, ok := writer.(stringWriter); ok { 29 | return w 30 | } else { 31 | return stringWrapper{writer} 32 | } 33 | } 34 | 35 | // Server-Sent Events 36 | // W3C Working Draft 29 October 2009 37 | // http://www.w3.org/TR/2009/WD-eventsource-20091029/ 38 | 39 | var contentType = []string{"text/event-stream"} 40 | var noCache = []string{"no-cache"} 41 | 42 | var fieldReplacer = strings.NewReplacer( 43 | "\n", "\\n", 44 | "\r", "\\r") 45 | 46 | var dataReplacer = strings.NewReplacer( 47 | "\n", "\ndata:", 48 | "\r", "\\r") 49 | 50 | type CustomEvent struct { 51 | Event string 52 | Id string 53 | Retry uint 54 | Data interface{} 55 | } 56 | 57 | func encode(writer io.Writer, event CustomEvent) error { 58 | w := checkWriter(writer) 59 | return writeData(w, event.Data) 60 | } 61 | 62 | func writeData(w stringWriter, data interface{}) error { 63 | _, _ = dataReplacer.WriteString(w, fmt.Sprint(data)) 64 | if strings.HasPrefix(data.(string), "data") { 65 | _, _ = w.writeString("\n\n") 66 | } 67 | return nil 68 | } 69 | 70 | func (r CustomEvent) Render(w http.ResponseWriter) error { 71 | r.WriteContentType(w) 72 | return encode(w, r) 73 | } 74 | 75 | func (r CustomEvent) WriteContentType(w http.ResponseWriter) { 76 | header := w.Header() 77 | header["Content-Type"] = contentType 78 | 79 | if _, exist := header["Cache-Control"]; !exist { 80 | header["Cache-Control"] = noCache 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /common/database.go: 
-------------------------------------------------------------------------------- 1 | package common 2 | 3 | var UsingSQLite = false 4 | var UsingPostgreSQL = false 5 | 6 | var SQLitePath = "one-api.db?_busy_timeout=5000" 7 | -------------------------------------------------------------------------------- /common/email.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "crypto/tls" 5 | "encoding/base64" 6 | "fmt" 7 | "net/smtp" 8 | "strings" 9 | "time" 10 | ) 11 | 12 | func SendEmail(subject string, receiver string, content string) error { 13 | if SMTPFrom == "" { // for compatibility 14 | SMTPFrom = SMTPAccount 15 | } 16 | encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) 17 | mail := []byte(fmt.Sprintf("To: %s\r\n"+ 18 | "From: %s<%s>\r\n"+ 19 | "Subject: %s\r\n"+ 20 | "Date: %s\r\n"+ 21 | "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", 22 | receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), content)) 23 | auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) 24 | addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) 25 | to := strings.Split(receiver, ";") 26 | var err error 27 | if SMTPPort == 465 { 28 | tlsConfig := &tls.Config{ 29 | InsecureSkipVerify: true, 30 | ServerName: SMTPServer, 31 | } 32 | conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig) 33 | if err != nil { 34 | return err 35 | } 36 | client, err := smtp.NewClient(conn, SMTPServer) 37 | if err != nil { 38 | return err 39 | } 40 | defer client.Close() 41 | if err = client.Auth(auth); err != nil { 42 | return err 43 | } 44 | if err = client.Mail(SMTPFrom); err != nil { 45 | return err 46 | } 47 | receiverEmails := strings.Split(receiver, ";") 48 | for _, receiver := range receiverEmails { 49 | if err = client.Rcpt(receiver); err != nil { 50 | return err 51 | } 52 | } 53 | w, err := client.Data() 54 | if 
err != nil { 55 | return err 56 | } 57 | _, err = w.Write(mail) 58 | if err != nil { 59 | return err 60 | } 61 | err = w.Close() 62 | if err != nil { 63 | return err 64 | } 65 | } else { 66 | err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) 67 | } 68 | return err 69 | } 70 | -------------------------------------------------------------------------------- /common/embed-file-system.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "embed" 5 | "io/fs" 6 | "net/http" 7 | 8 | "github.com/gin-contrib/static" 9 | ) 10 | 11 | // Credit: https://github.com/gin-contrib/static/issues/19 12 | 13 | type embedFileSystem struct { 14 | http.FileSystem 15 | } 16 | 17 | func (e embedFileSystem) Exists(prefix string, path string) bool { 18 | _, err := e.Open(path) 19 | return err == nil 20 | } 21 | 22 | func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem { 23 | efs, err := fs.Sub(fsEmbed, targetPath) 24 | if err != nil { 25 | panic(err) 26 | } 27 | return embedFileSystem{ 28 | FileSystem: http.FS(efs), 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /common/gin.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "io" 7 | 8 | "github.com/gin-gonic/gin" 9 | ) 10 | 11 | func UnmarshalBodyReusable(c *gin.Context, v any) error { 12 | requestBody, err := io.ReadAll(c.Request.Body) 13 | if err != nil { 14 | return err 15 | } 16 | err = c.Request.Body.Close() 17 | if err != nil { 18 | return err 19 | } 20 | err = json.Unmarshal(requestBody, &v) 21 | if err != nil { 22 | return err 23 | } 24 | // Reset request body 25 | c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) 26 | return nil 27 | } 28 | -------------------------------------------------------------------------------- /common/go-channel.go: 
-------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "fmt" 5 | "runtime/debug" 6 | ) 7 | 8 | func SafeGoroutine(f func()) { 9 | go func() { 10 | defer func() { 11 | if r := recover(); r != nil { 12 | SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack()))) 13 | } 14 | }() 15 | f() 16 | }() 17 | } 18 | 19 | func SafeSend(ch chan bool, value bool) (closed bool) { 20 | defer func() { 21 | // Recover from panic if one occured. A panic would mean the channel was closed. 22 | if recover() != nil { 23 | closed = true 24 | } 25 | }() 26 | 27 | // This will panic if the channel is closed. 28 | ch <- value 29 | 30 | // If the code reaches here, then the channel was not closed. 31 | return false 32 | } 33 | -------------------------------------------------------------------------------- /common/group-ratio.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import "encoding/json" 4 | 5 | var GroupRatio = map[string]float64{ 6 | "default": 1, 7 | "vip": 1, 8 | "svip": 1, 9 | } 10 | 11 | func GroupRatio2JSONString() string { 12 | jsonBytes, err := json.Marshal(GroupRatio) 13 | if err != nil { 14 | SysError("error marshalling model ratio: " + err.Error()) 15 | } 16 | return string(jsonBytes) 17 | } 18 | 19 | func UpdateGroupRatioByJSONString(jsonStr string) error { 20 | GroupRatio = make(map[string]float64) 21 | return json.Unmarshal([]byte(jsonStr), &GroupRatio) 22 | } 23 | 24 | func GetGroupRatio(name string) float64 { 25 | ratio, ok := GroupRatio[name] 26 | if !ok { 27 | SysError("group ratio not found: " + name) 28 | return 1 29 | } 30 | return ratio 31 | } 32 | -------------------------------------------------------------------------------- /common/image.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "bytes" 5 | 
"encoding/base64" 6 | "fmt" 7 | "image" 8 | "io" 9 | "net/http" 10 | "strings" 11 | 12 | "github.com/chai2010/webp" 13 | ) 14 | 15 | func DecodeBase64ImageData(base64String string) (image.Config, string, error) { 16 | // 去除base64数据的URL前缀(如果有) 17 | if idx := strings.Index(base64String, ","); idx != -1 { 18 | base64String = base64String[idx+1:] 19 | } 20 | 21 | // 将base64字符串解码为字节切片 22 | decodedData, err := base64.StdEncoding.DecodeString(base64String) 23 | if err != nil { 24 | fmt.Println("Error: Failed to decode base64 string") 25 | return image.Config{}, "", err 26 | } 27 | 28 | // 创建一个bytes.Buffer用于存储解码后的数据 29 | reader := bytes.NewReader(decodedData) 30 | config, format, err := getImageConfig(reader) 31 | return config, format, err 32 | } 33 | 34 | func IsImageUrl(url string) (bool, error) { 35 | resp, err := http.Head(url) 36 | if err != nil { 37 | return false, err 38 | } 39 | if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") { 40 | return false, nil 41 | } 42 | return true, nil 43 | } 44 | 45 | func GetImageFromUrl(url string) (mimeType string, data string, err error) { 46 | isImage, err := IsImageUrl(url) 47 | if !isImage { 48 | return 49 | } 50 | resp, err := http.Get(url) 51 | if err != nil { 52 | return 53 | } 54 | defer resp.Body.Close() 55 | buffer := bytes.NewBuffer(nil) 56 | _, err = buffer.ReadFrom(resp.Body) 57 | if err != nil { 58 | return 59 | } 60 | mimeType = resp.Header.Get("Content-Type") 61 | data = base64.StdEncoding.EncodeToString(buffer.Bytes()) 62 | return 63 | } 64 | 65 | func DecodeUrlImageData(imageUrl string) (image.Config, string, error) { 66 | response, err := http.Get(imageUrl) 67 | if err != nil { 68 | SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error())) 69 | return image.Config{}, "", err 70 | } 71 | defer response.Body.Close() 72 | 73 | var readData []byte 74 | for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} { 75 | SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit)) 76 
| 77 | // 从response.Body读取更多的数据直到达到当前的限制 78 | additionalData := make([]byte, limit-int64(len(readData))) 79 | n, _ := io.ReadFull(response.Body, additionalData) 80 | readData = append(readData, additionalData[:n]...) 81 | 82 | // 使用io.MultiReader组合已经读取的数据和response.Body 83 | limitReader := io.MultiReader(bytes.NewReader(readData), response.Body) 84 | 85 | var config image.Config 86 | var format string 87 | config, format, err = getImageConfig(limitReader) 88 | if err == nil { 89 | return config, format, nil 90 | } 91 | } 92 | 93 | return image.Config{}, "", err // 返回最后一个错误 94 | } 95 | 96 | func getImageConfig(reader io.Reader) (image.Config, string, error) { 97 | // 读取图片的头部信息来获取图片尺寸 98 | config, format, err := image.DecodeConfig(reader) 99 | if err != nil { 100 | err = fmt.Errorf("fail to decode image config(gif, jpg, png): %s", err.Error()) 101 | SysLog(err.Error()) 102 | config, err = webp.DecodeConfig(reader) 103 | if err != nil { 104 | err = fmt.Errorf("fail to decode image config(webp): %s", err.Error()) 105 | SysLog(err.Error()) 106 | } 107 | format = "webp" 108 | } 109 | if err != nil { 110 | return image.Config{}, "", err 111 | } 112 | return config, format, nil 113 | } 114 | -------------------------------------------------------------------------------- /common/init.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "log" 7 | "os" 8 | "path/filepath" 9 | ) 10 | 11 | var ( 12 | Port = flag.Int("port", 3000, "the listening port") 13 | PrintVersion = flag.Bool("version", false, "print version and exit") 14 | PrintHelp = flag.Bool("help", false, "print help and exit") 15 | LogDir = flag.String("log-dir", "./logs", "specify the log directory") 16 | ) 17 | 18 | func printHelp() { 19 | fmt.Println("New API " + Version + " - All in one API service for OpenAI API.") 20 | fmt.Println("Copyright (C) 2023 JustSong. 
All rights reserved.") 21 | fmt.Println("GitHub: https://github.com/songquanpeng/one-api") 22 | fmt.Println("Usage: one-api [--port ] [--log-dir ] [--version] [--help]") 23 | } 24 | 25 | func init() { 26 | flag.Parse() 27 | 28 | if *PrintVersion { 29 | fmt.Println(Version) 30 | os.Exit(0) 31 | } 32 | 33 | if *PrintHelp { 34 | printHelp() 35 | os.Exit(0) 36 | } 37 | 38 | if os.Getenv("SESSION_SECRET") != "" { 39 | ss := os.Getenv("SESSION_SECRET") 40 | if ss == "random_string" { 41 | log.Println("WARNING: SESSION_SECRET is set to the default value 'random_string', please change it to a random string.") 42 | log.Println("警告:SESSION_SECRET被设置为默认值'random_string',请修改为随机字符串。") 43 | log.Fatal("Please set SESSION_SECRET to a random string.") 44 | } else { 45 | SessionSecret = ss 46 | } 47 | } 48 | if os.Getenv("SQLITE_PATH") != "" { 49 | SQLitePath = os.Getenv("SQLITE_PATH") 50 | } 51 | if *LogDir != "" { 52 | var err error 53 | *LogDir, err = filepath.Abs(*LogDir) 54 | if err != nil { 55 | log.Fatal(err) 56 | } 57 | if _, err := os.Stat(*LogDir); os.IsNotExist(err) { 58 | err = os.Mkdir(*LogDir, 0777) 59 | if err != nil { 60 | log.Fatal(err) 61 | } 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /common/logger.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "log" 8 | "os" 9 | "path/filepath" 10 | "sync" 11 | "time" 12 | 13 | "github.com/gin-gonic/gin" 14 | ) 15 | 16 | const ( 17 | loggerINFO = "INFO" 18 | loggerWarn = "WARN" 19 | loggerError = "ERR" 20 | ) 21 | 22 | const maxLogCount = 1000000 23 | 24 | var logCount int 25 | var setupLogLock sync.Mutex 26 | var setupLogWorking bool 27 | 28 | func SetupLogger() { 29 | if *LogDir != "" { 30 | ok := setupLogLock.TryLock() 31 | if !ok { 32 | log.Println("setup log is already working") 33 | return 34 | } 35 | defer func() { 36 | setupLogLock.Unlock() 37 | 
setupLogWorking = false 38 | }() 39 | logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) 40 | fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) 41 | if err != nil { 42 | log.Fatal("failed to open log file") 43 | } 44 | gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) 45 | gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) 46 | } 47 | } 48 | 49 | func SysLog(s string) { 50 | t := time.Now() 51 | _, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) 52 | } 53 | 54 | func SysError(s string) { 55 | t := time.Now() 56 | _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) 57 | } 58 | 59 | func LogInfo(ctx context.Context, msg string) { 60 | logHelper(ctx, loggerINFO, msg) 61 | } 62 | 63 | func LogWarn(ctx context.Context, msg string) { 64 | logHelper(ctx, loggerWarn, msg) 65 | } 66 | 67 | func LogError(ctx context.Context, msg string) { 68 | logHelper(ctx, loggerError, msg) 69 | } 70 | 71 | func logHelper(ctx context.Context, level string, msg string) { 72 | writer := gin.DefaultErrorWriter 73 | if level == loggerINFO { 74 | writer = gin.DefaultWriter 75 | } 76 | id := ctx.Value(RequestIdKey) 77 | now := time.Now() 78 | _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) 79 | logCount++ // we don't need accurate count, so no lock here 80 | if logCount > maxLogCount && !setupLogWorking { 81 | logCount = 0 82 | setupLogWorking = true 83 | go func() { 84 | SetupLogger() 85 | }() 86 | } 87 | } 88 | 89 | func FatalLog(v ...any) { 90 | t := time.Now() 91 | _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) 92 | os.Exit(1) 93 | } 94 | 95 | func LogQuota(quota int) string { 96 | if DisplayInCurrencyEnabled { 97 | return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit) 98 | } else { 99 | return fmt.Sprintf("%d 
点额度", quota) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /common/pprof.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "runtime/pprof" 7 | "time" 8 | 9 | "github.com/shirou/gopsutil/cpu" 10 | ) 11 | 12 | // Monitor 定时监控cpu使用率,超过阈值输出pprof文件 13 | func Monitor() { 14 | for { 15 | percent, err := cpu.Percent(time.Second, false) 16 | if err != nil { 17 | panic(err) 18 | } 19 | if percent[0] > 80 { 20 | fmt.Println("cpu usage too high") 21 | // write pprof file 22 | if _, err := os.Stat("./pprof"); os.IsNotExist(err) { 23 | err := os.Mkdir("./pprof", os.ModePerm) 24 | if err != nil { 25 | SysLog("创建pprof文件夹失败 " + err.Error()) 26 | continue 27 | } 28 | } 29 | f, err := os.Create("./pprof/" + fmt.Sprintf("cpu-%s.pprof", time.Now().Format("20060102150405"))) 30 | if err != nil { 31 | SysLog("创建pprof文件失败 " + err.Error()) 32 | continue 33 | } 34 | err = pprof.StartCPUProfile(f) 35 | if err != nil { 36 | SysLog("启动pprof失败 " + err.Error()) 37 | continue 38 | } 39 | time.Sleep(10 * time.Second) // profile for 30 seconds 40 | pprof.StopCPUProfile() 41 | f.Close() 42 | } 43 | time.Sleep(30 * time.Second) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /common/rate-limit.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | ) 7 | 8 | type InMemoryRateLimiter struct { 9 | store map[string]*[]int64 10 | mutex sync.Mutex 11 | expirationDuration time.Duration 12 | } 13 | 14 | func (l *InMemoryRateLimiter) Init(expirationDuration time.Duration) { 15 | if l.store == nil { 16 | l.mutex.Lock() 17 | if l.store == nil { 18 | l.store = make(map[string]*[]int64) 19 | l.expirationDuration = expirationDuration 20 | if expirationDuration > 0 { 21 | go l.clearExpiredItems() 22 | } 23 | } 24 | 
l.mutex.Unlock() 25 | } 26 | } 27 | 28 | func (l *InMemoryRateLimiter) clearExpiredItems() { 29 | for { 30 | time.Sleep(l.expirationDuration) 31 | l.mutex.Lock() 32 | now := time.Now().Unix() 33 | for key := range l.store { 34 | queue := l.store[key] 35 | size := len(*queue) 36 | if size == 0 || now-(*queue)[size-1] > int64(l.expirationDuration.Seconds()) { 37 | delete(l.store, key) 38 | } 39 | } 40 | l.mutex.Unlock() 41 | } 42 | } 43 | 44 | // Request parameter duration's unit is seconds 45 | func (l *InMemoryRateLimiter) Request(key string, maxRequestNum int, duration int64) bool { 46 | l.mutex.Lock() 47 | defer l.mutex.Unlock() 48 | // [old <-- new] 49 | queue, ok := l.store[key] 50 | now := time.Now().Unix() 51 | if ok { 52 | if len(*queue) < maxRequestNum { 53 | *queue = append(*queue, now) 54 | return true 55 | } else { 56 | if now-(*queue)[0] >= duration { 57 | *queue = (*queue)[1:] 58 | *queue = append(*queue, now) 59 | return true 60 | } else { 61 | return false 62 | } 63 | } 64 | } else { 65 | s := make([]int64, 0, maxRequestNum) 66 | l.store[key] = &s 67 | *(l.store[key]) = append(*(l.store[key]), now) 68 | } 69 | return true 70 | } 71 | -------------------------------------------------------------------------------- /common/redis.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "time" 7 | 8 | "github.com/go-redis/redis/v8" 9 | ) 10 | 11 | var RDB *redis.Client 12 | var RedisEnabled = true 13 | 14 | // InitRedisClient This function is called after init() 15 | func InitRedisClient() (err error) { 16 | if os.Getenv("REDIS_CONN_STRING") == "" { 17 | RedisEnabled = false 18 | SysLog("REDIS_CONN_STRING not set, Redis is not enabled") 19 | return nil 20 | } 21 | if os.Getenv("SYNC_FREQUENCY") == "" { 22 | RedisEnabled = false 23 | SysLog("SYNC_FREQUENCY not set, Redis is disabled") 24 | return nil 25 | } 26 | SysLog("Redis is enabled") 27 | opt, err := 
redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) 28 | if err != nil { 29 | FatalLog("failed to parse Redis connection string: " + err.Error()) 30 | } 31 | RDB = redis.NewClient(opt) 32 | 33 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 34 | defer cancel() 35 | 36 | _, err = RDB.Ping(ctx).Result() 37 | if err != nil { 38 | FatalLog("Redis ping test failed: " + err.Error()) 39 | } 40 | return err 41 | } 42 | 43 | func ParseRedisOption() *redis.Options { 44 | opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) 45 | if err != nil { 46 | FatalLog("failed to parse Redis connection string: " + err.Error()) 47 | } 48 | return opt 49 | } 50 | 51 | func RedisSet(key string, value string, expiration time.Duration) error { 52 | ctx := context.Background() 53 | return RDB.Set(ctx, key, value, expiration).Err() 54 | } 55 | 56 | func RedisGet(key string) (string, error) { 57 | ctx := context.Background() 58 | return RDB.Get(ctx, key).Result() 59 | } 60 | 61 | func RedisExpire(key string, expiration time.Duration) error { 62 | ctx := context.Background() 63 | return RDB.Expire(ctx, key, expiration).Err() 64 | } 65 | 66 | func RedisGetEx(key string, expiration time.Duration) (string, error) { 67 | ctx := context.Background() 68 | return RDB.GetSet(ctx, key, expiration).Result() 69 | } 70 | 71 | func RedisDel(key string) error { 72 | ctx := context.Background() 73 | return RDB.Del(ctx, key).Err() 74 | } 75 | 76 | func RedisDecrease(key string, value int64) error { 77 | 78 | // 检查键的剩余生存时间 79 | ttlCmd := RDB.TTL(context.Background(), key) 80 | ttl, err := ttlCmd.Result() 81 | if err != nil { 82 | // 失败则尝试直接减少 83 | return RDB.DecrBy(context.Background(), key, value).Err() 84 | } 85 | 86 | // 如果剩余生存时间大于0,则进行减少操作 87 | if ttl > 0 { 88 | ctx := context.Background() 89 | // 开始一个Redis事务 90 | txn := RDB.TxPipeline() 91 | 92 | // 减少余额 93 | decrCmd := txn.DecrBy(ctx, key, value) 94 | if err := decrCmd.Err(); err != nil { 95 | return err // 如果减少失败,则直接返回错误 96 | } 
97 | 98 | // 重新设置过期时间,使用原来的过期时间 99 | txn.Expire(ctx, key, ttl) 100 | 101 | // 执行事务 102 | _, err = txn.Exec(ctx) 103 | return err 104 | } else { 105 | _ = RedisDel(key) 106 | } 107 | return nil 108 | } 109 | -------------------------------------------------------------------------------- /common/topup-ratio.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import "encoding/json" 4 | 5 | var TopupGroupRatio = map[string]float64{ 6 | "default": 1, 7 | "vip": 1, 8 | "svip": 1, 9 | } 10 | 11 | func TopupGroupRatio2JSONString() string { 12 | jsonBytes, err := json.Marshal(TopupGroupRatio) 13 | if err != nil { 14 | SysError("error marshalling model ratio: " + err.Error()) 15 | } 16 | return string(jsonBytes) 17 | } 18 | 19 | func UpdateTopupGroupRatioByJSONString(jsonStr string) error { 20 | TopupGroupRatio = make(map[string]float64) 21 | return json.Unmarshal([]byte(jsonStr), &TopupGroupRatio) 22 | } 23 | 24 | func GetTopupGroupRatio(name string) float64 { 25 | ratio, ok := TopupGroupRatio[name] 26 | if !ok { 27 | SysError("topup group ratio not found: " + name) 28 | return 1 29 | } 30 | return ratio 31 | } 32 | -------------------------------------------------------------------------------- /common/validate.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import "github.com/go-playground/validator/v10" 4 | 5 | var Validate *validator.Validate 6 | 7 | func init() { 8 | Validate = validator.New() 9 | } 10 | -------------------------------------------------------------------------------- /common/verification.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "strings" 5 | "sync" 6 | "time" 7 | 8 | "github.com/google/uuid" 9 | ) 10 | 11 | type verificationValue struct { 12 | code string 13 | time time.Time 14 | } 15 | 16 | const ( 17 | EmailVerificationPurpose = "v" 18 | 
PasswordResetPurpose = "r" 19 | ) 20 | 21 | var verificationMutex sync.Mutex 22 | var verificationMap map[string]verificationValue 23 | var verificationMapMaxSize = 10 24 | var VerificationValidMinutes = 10 25 | 26 | func GenerateVerificationCode(length int) string { 27 | code := uuid.New().String() 28 | code = strings.Replace(code, "-", "", -1) 29 | if length == 0 { 30 | return code 31 | } 32 | return code[:length] 33 | } 34 | 35 | func RegisterVerificationCodeWithKey(key string, code string, purpose string) { 36 | verificationMutex.Lock() 37 | defer verificationMutex.Unlock() 38 | verificationMap[purpose+key] = verificationValue{ 39 | code: code, 40 | time: time.Now(), 41 | } 42 | if len(verificationMap) > verificationMapMaxSize { 43 | removeExpiredPairs() 44 | } 45 | } 46 | 47 | func VerifyCodeWithKey(key string, code string, purpose string) bool { 48 | verificationMutex.Lock() 49 | defer verificationMutex.Unlock() 50 | value, okay := verificationMap[purpose+key] 51 | now := time.Now() 52 | if !okay || int(now.Sub(value.time).Seconds()) >= VerificationValidMinutes*60 { 53 | return false 54 | } 55 | return code == value.code 56 | } 57 | 58 | func DeleteKey(key string, purpose string) { 59 | verificationMutex.Lock() 60 | defer verificationMutex.Unlock() 61 | delete(verificationMap, purpose+key) 62 | } 63 | 64 | // no lock inside, so the caller must lock the verificationMap before calling! 
65 | func removeExpiredPairs() { 66 | now := time.Now() 67 | for key := range verificationMap { 68 | if int(now.Sub(verificationMap[key].time).Seconds()) >= VerificationValidMinutes*60 { 69 | delete(verificationMap, key) 70 | } 71 | } 72 | } 73 | 74 | func init() { 75 | verificationMutex.Lock() 76 | defer verificationMutex.Unlock() 77 | verificationMap = make(map[string]verificationValue) 78 | } 79 | -------------------------------------------------------------------------------- /controller/billing.go: -------------------------------------------------------------------------------- 1 | package controller 2 | 3 | import ( 4 | "one-api/common" 5 | "one-api/dto" 6 | "one-api/model" 7 | 8 | "github.com/gin-gonic/gin" 9 | ) 10 | 11 | func GetSubscription(c *gin.Context) { 12 | var remainQuota int 13 | var usedQuota int 14 | var err error 15 | var token *model.Token 16 | var expiredTime int64 17 | if common.DisplayTokenStatEnabled { 18 | tokenId := c.GetInt("token_id") 19 | token, err = model.GetTokenById(tokenId) 20 | expiredTime = token.ExpiredTime 21 | remainQuota = token.RemainQuota 22 | usedQuota = token.UsedQuota 23 | } else { 24 | userId := c.GetInt("id") 25 | remainQuota, _ = model.GetUserQuota(userId) 26 | usedQuota, err = model.GetUserUsedQuota(userId) 27 | } 28 | if expiredTime <= 0 { 29 | expiredTime = 0 30 | } 31 | if err != nil { 32 | openAIError := dto.OpenAIError{ 33 | Message: err.Error(), 34 | Type: "upstream_error", 35 | } 36 | c.JSON(200, gin.H{ 37 | "error": openAIError, 38 | }) 39 | return 40 | } 41 | quota := remainQuota + usedQuota 42 | amount := float64(quota) 43 | if common.DisplayInCurrencyEnabled { 44 | amount /= common.QuotaPerUnit 45 | } 46 | if token != nil && token.UnlimitedQuota { 47 | amount = 100000000 48 | } 49 | subscription := OpenAISubscriptionResponse{ 50 | Object: "billing_subscription", 51 | HasPaymentMethod: true, 52 | SoftLimitUSD: amount, 53 | HardLimitUSD: amount, 54 | SystemHardLimitUSD: amount, 55 | AccessUntil: 
expiredTime, 56 | } 57 | c.JSON(200, subscription) 58 | } 59 | 60 | func GetUsage(c *gin.Context) { 61 | var quota int 62 | var err error 63 | var token *model.Token 64 | if common.DisplayTokenStatEnabled { 65 | tokenId := c.GetInt("token_id") 66 | token, err = model.GetTokenById(tokenId) 67 | quota = token.UsedQuota 68 | } else { 69 | userId := c.GetInt("id") 70 | quota, err = model.GetUserUsedQuota(userId) 71 | } 72 | if err != nil { 73 | openAIError := dto.OpenAIError{ 74 | Message: err.Error(), 75 | Type: "new_api_error", 76 | } 77 | c.JSON(200, gin.H{ 78 | "error": openAIError, 79 | }) 80 | return 81 | } 82 | amount := float64(quota) 83 | if common.DisplayInCurrencyEnabled { 84 | amount /= common.QuotaPerUnit 85 | } 86 | usage := OpenAIUsageResponse{ 87 | Object: "list", 88 | TotalUsage: amount * 100, 89 | } 90 | c.JSON(200, usage) 91 | } 92 | -------------------------------------------------------------------------------- /controller/group.go: -------------------------------------------------------------------------------- 1 | package controller 2 | 3 | import ( 4 | "net/http" 5 | "one-api/common" 6 | 7 | "github.com/gin-gonic/gin" 8 | ) 9 | 10 | func GetGroups(c *gin.Context) { 11 | groupNames := make([]string, 0) 12 | for groupName := range common.GroupRatio { 13 | groupNames = append(groupNames, groupName) 14 | } 15 | c.JSON(http.StatusOK, gin.H{ 16 | "success": true, 17 | "message": "", 18 | "data": groupNames, 19 | }) 20 | } 21 | -------------------------------------------------------------------------------- /controller/option.go: -------------------------------------------------------------------------------- 1 | package controller 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "one-api/common" 7 | "one-api/model" 8 | "strings" 9 | 10 | "github.com/gin-gonic/gin" 11 | ) 12 | 13 | func GetOptions(c *gin.Context) { 14 | var options []*model.Option 15 | common.OptionMapRWMutex.Lock() 16 | for k, v := range common.OptionMap { 17 | if 
strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { 18 | continue 19 | } 20 | options = append(options, &model.Option{ 21 | Key: k, 22 | Value: common.Interface2String(v), 23 | }) 24 | } 25 | common.OptionMapRWMutex.Unlock() 26 | c.JSON(http.StatusOK, gin.H{ 27 | "success": true, 28 | "message": "", 29 | "data": options, 30 | }) 31 | } 32 | 33 | func UpdateOption(c *gin.Context) { 34 | var option model.Option 35 | err := json.NewDecoder(c.Request.Body).Decode(&option) 36 | if err != nil { 37 | c.JSON(http.StatusBadRequest, gin.H{ 38 | "success": false, 39 | "message": "无效的参数", 40 | }) 41 | return 42 | } 43 | switch option.Key { 44 | case "GitHubOAuthEnabled": 45 | if option.Value == "true" && common.GitHubClientId == "" { 46 | c.JSON(http.StatusOK, gin.H{ 47 | "success": false, 48 | "message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", 49 | }) 50 | return 51 | } 52 | case "EmailDomainRestrictionEnabled": 53 | if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { 54 | c.JSON(http.StatusOK, gin.H{ 55 | "success": false, 56 | "message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", 57 | }) 58 | return 59 | } 60 | case "WeChatAuthEnabled": 61 | if option.Value == "true" && common.WeChatServerAddress == "" { 62 | c.JSON(http.StatusOK, gin.H{ 63 | "success": false, 64 | "message": "无法启用微信登录,请先填入微信登录相关配置信息!", 65 | }) 66 | return 67 | } 68 | case "TurnstileCheckEnabled": 69 | if option.Value == "true" && common.TurnstileSiteKey == "" { 70 | c.JSON(http.StatusOK, gin.H{ 71 | "success": false, 72 | "message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", 73 | }) 74 | return 75 | } 76 | } 77 | err = model.UpdateOption(option.Key, option.Value) 78 | if err != nil { 79 | c.JSON(http.StatusOK, gin.H{ 80 | "success": false, 81 | "message": err.Error(), 82 | }) 83 | return 84 | } 85 | c.JSON(http.StatusOK, gin.H{ 86 | "success": true, 87 | "message": "", 88 | }) 89 | } 90 | 
-------------------------------------------------------------------------------- /controller/redemption.go: -------------------------------------------------------------------------------- 1 | package controller 2 | 3 | import ( 4 | "net/http" 5 | "one-api/common" 6 | "one-api/model" 7 | "strconv" 8 | 9 | "github.com/gin-gonic/gin" 10 | ) 11 | 12 | func GetAllRedemptions(c *gin.Context) { 13 | p, _ := strconv.Atoi(c.Query("p")) 14 | if p < 0 { 15 | p = 0 16 | } 17 | redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage) 18 | if err != nil { 19 | c.JSON(http.StatusOK, gin.H{ 20 | "success": false, 21 | "message": err.Error(), 22 | }) 23 | return 24 | } 25 | c.JSON(http.StatusOK, gin.H{ 26 | "success": true, 27 | "message": "", 28 | "data": redemptions, 29 | }) 30 | } 31 | 32 | func SearchRedemptions(c *gin.Context) { 33 | keyword := c.Query("keyword") 34 | redemptions, err := model.SearchRedemptions(keyword) 35 | if err != nil { 36 | c.JSON(http.StatusOK, gin.H{ 37 | "success": false, 38 | "message": err.Error(), 39 | }) 40 | return 41 | } 42 | c.JSON(http.StatusOK, gin.H{ 43 | "success": true, 44 | "message": "", 45 | "data": redemptions, 46 | }) 47 | } 48 | 49 | func GetRedemption(c *gin.Context) { 50 | id, err := strconv.Atoi(c.Param("id")) 51 | if err != nil { 52 | c.JSON(http.StatusOK, gin.H{ 53 | "success": false, 54 | "message": err.Error(), 55 | }) 56 | return 57 | } 58 | redemption, err := model.GetRedemptionById(id) 59 | if err != nil { 60 | c.JSON(http.StatusOK, gin.H{ 61 | "success": false, 62 | "message": err.Error(), 63 | }) 64 | return 65 | } 66 | c.JSON(http.StatusOK, gin.H{ 67 | "success": true, 68 | "message": "", 69 | "data": redemption, 70 | }) 71 | } 72 | 73 | func AddRedemption(c *gin.Context) { 74 | redemption := model.Redemption{} 75 | err := c.ShouldBindJSON(&redemption) 76 | if err != nil { 77 | c.JSON(http.StatusOK, gin.H{ 78 | "success": false, 79 | "message": err.Error(), 80 | }) 81 | return 82 | } 83 | 
if len(redemption.Name) == 0 || len(redemption.Name) > 20 { 84 | c.JSON(http.StatusOK, gin.H{ 85 | "success": false, 86 | "message": "兑换码名称长度必须在1-20之间", 87 | }) 88 | return 89 | } 90 | if redemption.Count <= 0 { 91 | c.JSON(http.StatusOK, gin.H{ 92 | "success": false, 93 | "message": "兑换码个数必须大于0", 94 | }) 95 | return 96 | } 97 | if redemption.Count > 100 { 98 | c.JSON(http.StatusOK, gin.H{ 99 | "success": false, 100 | "message": "一次兑换码批量生成的个数不能大于 100", 101 | }) 102 | return 103 | } 104 | var keys []string 105 | for i := 0; i < redemption.Count; i++ { 106 | key := common.GetUUID() 107 | cleanRedemption := model.Redemption{ 108 | UserId: c.GetInt("id"), 109 | Name: redemption.Name, 110 | Key: key, 111 | CreatedTime: common.GetTimestamp(), 112 | Quota: redemption.Quota, 113 | } 114 | err = cleanRedemption.Insert() 115 | if err != nil { 116 | c.JSON(http.StatusOK, gin.H{ 117 | "success": false, 118 | "message": err.Error(), 119 | "data": keys, 120 | }) 121 | return 122 | } 123 | keys = append(keys, key) 124 | } 125 | c.JSON(http.StatusOK, gin.H{ 126 | "success": true, 127 | "message": "", 128 | "data": keys, 129 | }) 130 | } 131 | 132 | func DeleteRedemption(c *gin.Context) { 133 | id, _ := strconv.Atoi(c.Param("id")) 134 | err := model.DeleteRedemptionById(id) 135 | if err != nil { 136 | c.JSON(http.StatusOK, gin.H{ 137 | "success": false, 138 | "message": err.Error(), 139 | }) 140 | return 141 | } 142 | c.JSON(http.StatusOK, gin.H{ 143 | "success": true, 144 | "message": "", 145 | }) 146 | } 147 | 148 | func UpdateRedemption(c *gin.Context) { 149 | statusOnly := c.Query("status_only") 150 | redemption := model.Redemption{} 151 | err := c.ShouldBindJSON(&redemption) 152 | if err != nil { 153 | c.JSON(http.StatusOK, gin.H{ 154 | "success": false, 155 | "message": err.Error(), 156 | }) 157 | return 158 | } 159 | cleanRedemption, err := model.GetRedemptionById(redemption.Id) 160 | if err != nil { 161 | c.JSON(http.StatusOK, gin.H{ 162 | "success": false, 163 | "message": 
err.Error(), 164 | }) 165 | return 166 | } 167 | if statusOnly != "" { 168 | cleanRedemption.Status = redemption.Status 169 | } else { 170 | // If you add more fields, please also update redemption.Update() 171 | cleanRedemption.Name = redemption.Name 172 | cleanRedemption.Quota = redemption.Quota 173 | } 174 | err = cleanRedemption.Update() 175 | if err != nil { 176 | c.JSON(http.StatusOK, gin.H{ 177 | "success": false, 178 | "message": err.Error(), 179 | }) 180 | return 181 | } 182 | c.JSON(http.StatusOK, gin.H{ 183 | "success": true, 184 | "message": "", 185 | "data": cleanRedemption, 186 | }) 187 | } 188 | -------------------------------------------------------------------------------- /controller/telegram.go: -------------------------------------------------------------------------------- 1 | package controller 2 | 3 | import ( 4 | "crypto/hmac" 5 | "crypto/sha256" 6 | "encoding/hex" 7 | "io" 8 | "one-api/common" 9 | "one-api/model" 10 | "sort" 11 | 12 | "github.com/gin-contrib/sessions" 13 | "github.com/gin-gonic/gin" 14 | ) 15 | 16 | func TelegramBind(c *gin.Context) { 17 | if !common.TelegramOAuthEnabled { 18 | c.JSON(200, gin.H{ 19 | "message": "管理员未开启通过 Telegram 登录以及注册", 20 | "success": false, 21 | }) 22 | return 23 | } 24 | params := c.Request.URL.Query() 25 | if !checkTelegramAuthorization(params, common.TelegramBotToken) { 26 | c.JSON(200, gin.H{ 27 | "message": "无效的请求", 28 | "success": false, 29 | }) 30 | return 31 | } 32 | telegramId := params["id"][0] 33 | if model.IsTelegramIdAlreadyTaken(telegramId) { 34 | c.JSON(200, gin.H{ 35 | "message": "该 Telegram 账户已被绑定", 36 | "success": false, 37 | }) 38 | return 39 | } 40 | 41 | session := sessions.Default(c) 42 | id := session.Get("id") 43 | user := model.User{Id: id.(int)} 44 | if err := user.FillUserById(); err != nil { 45 | c.JSON(200, gin.H{ 46 | "message": err.Error(), 47 | "success": false, 48 | }) 49 | return 50 | } 51 | user.TelegramId = telegramId 52 | if err := user.Update(false); err != nil { 
53 | c.JSON(200, gin.H{ 54 | "message": err.Error(), 55 | "success": false, 56 | }) 57 | return 58 | } 59 | 60 | c.Redirect(302, "/setting") 61 | } 62 | 63 | func TelegramLogin(c *gin.Context) { 64 | if !common.TelegramOAuthEnabled { 65 | c.JSON(200, gin.H{ 66 | "message": "管理员未开启通过 Telegram 登录以及注册", 67 | "success": false, 68 | }) 69 | return 70 | } 71 | params := c.Request.URL.Query() 72 | if !checkTelegramAuthorization(params, common.TelegramBotToken) { 73 | c.JSON(200, gin.H{ 74 | "message": "无效的请求", 75 | "success": false, 76 | }) 77 | return 78 | } 79 | 80 | telegramId := params["id"][0] 81 | user := model.User{TelegramId: telegramId} 82 | if err := user.FillUserByTelegramId(); err != nil { 83 | c.JSON(200, gin.H{ 84 | "message": err.Error(), 85 | "success": false, 86 | }) 87 | return 88 | } 89 | setupLogin(&user, c) 90 | } 91 | 92 | func checkTelegramAuthorization(params map[string][]string, token string) bool { 93 | strs := []string{} 94 | var hash = "" 95 | for k, v := range params { 96 | if k == "hash" { 97 | hash = v[0] 98 | continue 99 | } 100 | strs = append(strs, k+"="+v[0]) 101 | } 102 | sort.Strings(strs) 103 | var imploded = "" 104 | for _, s := range strs { 105 | if imploded != "" { 106 | imploded += "\n" 107 | } 108 | imploded += s 109 | } 110 | sha256hash := sha256.New() 111 | _, _ = io.WriteString(sha256hash, token) 112 | hmachash := hmac.New(sha256.New, sha256hash.Sum(nil)) 113 | _, _ = io.WriteString(hmachash, imploded) 114 | ss := hex.EncodeToString(hmachash.Sum(nil)) 115 | return hash == ss 116 | } 117 | -------------------------------------------------------------------------------- /controller/usedata.go: -------------------------------------------------------------------------------- 1 | package controller 2 | 3 | import ( 4 | "net/http" 5 | "one-api/model" 6 | "strconv" 7 | 8 | "github.com/gin-gonic/gin" 9 | ) 10 | 11 | func GetAllQuotaDates(c *gin.Context) { 12 | startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) 
13 | endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) 14 | username := c.Query("username") 15 | dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username) 16 | if err != nil { 17 | c.JSON(http.StatusOK, gin.H{ 18 | "success": false, 19 | "message": err.Error(), 20 | }) 21 | return 22 | } 23 | c.JSON(http.StatusOK, gin.H{ 24 | "success": true, 25 | "message": "", 26 | "data": dates, 27 | }) 28 | } 29 | 30 | func GetUserQuotaDates(c *gin.Context) { 31 | userId := c.GetInt("id") 32 | startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) 33 | endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) 34 | // 判断时间跨度是否超过 1 个月 35 | if endTimestamp-startTimestamp > 2592000 { 36 | c.JSON(http.StatusOK, gin.H{ 37 | "success": false, 38 | "message": "时间跨度不能超过 1 个月", 39 | }) 40 | return 41 | } 42 | dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp) 43 | if err != nil { 44 | c.JSON(http.StatusOK, gin.H{ 45 | "success": false, 46 | "message": err.Error(), 47 | }) 48 | return 49 | } 50 | c.JSON(http.StatusOK, gin.H{ 51 | "success": true, 52 | "message": "", 53 | "data": dates, 54 | }) 55 | } 56 | -------------------------------------------------------------------------------- /controller/wechat.go: -------------------------------------------------------------------------------- 1 | package controller 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | "one-api/common" 9 | "one-api/model" 10 | "strconv" 11 | "time" 12 | 13 | "github.com/gin-gonic/gin" 14 | ) 15 | 16 | type wechatLoginResponse struct { 17 | Success bool `json:"success"` 18 | Message string `json:"message"` 19 | Data string `json:"data"` 20 | } 21 | 22 | func getWeChatIdByCode(code string) (string, error) { 23 | if code == "" { 24 | return "", errors.New("无效的参数") 25 | } 26 | req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil) 27 
| if err != nil { 28 | return "", err 29 | } 30 | req.Header.Set("Authorization", common.WeChatServerToken) 31 | client := http.Client{ 32 | Timeout: 5 * time.Second, 33 | } 34 | httpResponse, err := client.Do(req) 35 | if err != nil { 36 | return "", err 37 | } 38 | defer httpResponse.Body.Close() 39 | var res wechatLoginResponse 40 | err = json.NewDecoder(httpResponse.Body).Decode(&res) 41 | if err != nil { 42 | return "", err 43 | } 44 | if !res.Success { 45 | return "", errors.New(res.Message) 46 | } 47 | if res.Data == "" { 48 | return "", errors.New("验证码错误或已过期") 49 | } 50 | return res.Data, nil 51 | } 52 | 53 | func WeChatAuth(c *gin.Context) { 54 | if !common.WeChatAuthEnabled { 55 | c.JSON(http.StatusOK, gin.H{ 56 | "message": "管理员未开启通过微信登录以及注册", 57 | "success": false, 58 | }) 59 | return 60 | } 61 | code := c.Query("code") 62 | wechatId, err := getWeChatIdByCode(code) 63 | if err != nil { 64 | c.JSON(http.StatusOK, gin.H{ 65 | "message": err.Error(), 66 | "success": false, 67 | }) 68 | return 69 | } 70 | user := model.User{ 71 | WeChatId: wechatId, 72 | } 73 | if model.IsWeChatIdAlreadyTaken(wechatId) { 74 | err := user.FillUserByWeChatId() 75 | if err != nil { 76 | c.JSON(http.StatusOK, gin.H{ 77 | "success": false, 78 | "message": err.Error(), 79 | }) 80 | return 81 | } 82 | } else { 83 | if common.RegisterEnabled { 84 | user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) 85 | user.DisplayName = "WeChat User" 86 | user.Role = common.RoleCommonUser 87 | user.Status = common.UserStatusEnabled 88 | 89 | if err := user.Insert(0); err != nil { 90 | c.JSON(http.StatusOK, gin.H{ 91 | "success": false, 92 | "message": err.Error(), 93 | }) 94 | return 95 | } 96 | } else { 97 | c.JSON(http.StatusOK, gin.H{ 98 | "success": false, 99 | "message": "管理员关闭了新用户注册", 100 | }) 101 | return 102 | } 103 | } 104 | 105 | if user.Status != common.UserStatusEnabled { 106 | c.JSON(http.StatusOK, gin.H{ 107 | "message": "用户已被封禁", 108 | "success": false, 109 | }) 110 | 
return 111 | } 112 | setupLogin(&user, c) 113 | } 114 | 115 | func WeChatBind(c *gin.Context) { 116 | if !common.WeChatAuthEnabled { 117 | c.JSON(http.StatusOK, gin.H{ 118 | "message": "管理员未开启通过微信登录以及注册", 119 | "success": false, 120 | }) 121 | return 122 | } 123 | code := c.Query("code") 124 | wechatId, err := getWeChatIdByCode(code) 125 | if err != nil { 126 | c.JSON(http.StatusOK, gin.H{ 127 | "message": err.Error(), 128 | "success": false, 129 | }) 130 | return 131 | } 132 | if model.IsWeChatIdAlreadyTaken(wechatId) { 133 | c.JSON(http.StatusOK, gin.H{ 134 | "success": false, 135 | "message": "该微信账号已被绑定", 136 | }) 137 | return 138 | } 139 | id := c.GetInt("id") 140 | user := model.User{ 141 | Id: id, 142 | } 143 | err = user.FillUserById() 144 | if err != nil { 145 | c.JSON(http.StatusOK, gin.H{ 146 | "success": false, 147 | "message": err.Error(), 148 | }) 149 | return 150 | } 151 | user.WeChatId = wechatId 152 | err = user.Update(false) 153 | if err != nil { 154 | c.JSON(http.StatusOK, gin.H{ 155 | "success": false, 156 | "message": err.Error(), 157 | }) 158 | return 159 | } 160 | c.JSON(http.StatusOK, gin.H{ 161 | "success": true, 162 | "message": "", 163 | }) 164 | } 165 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.4' 2 | 3 | services: 4 | new-api: 5 | image: calciumion/new-api:latest 6 | # build: . 
7 | container_name: new-api 8 | restart: always 9 | command: --log-dir /app/logs 10 | ports: 11 | - "3000:3000" 12 | volumes: 13 | - ./data:/data 14 | - ./logs:/app/logs 15 | environment: 16 | - SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库 17 | - REDIS_CONN_STRING=redis://redis 18 | - SESSION_SECRET=random_string # 修改为随机字符串 19 | - TZ=Asia/Shanghai 20 | # - NODE_TYPE=slave # 多机部署时从节点取消注释该行 21 | # - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行 22 | # - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行 23 | 24 | depends_on: 25 | - redis 26 | healthcheck: 27 | test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] 28 | interval: 30s 29 | timeout: 10s 30 | retries: 3 31 | 32 | redis: 33 | image: redis:latest 34 | container_name: redis 35 | restart: always 36 | -------------------------------------------------------------------------------- /dto/audio.go: -------------------------------------------------------------------------------- 1 | package dto 2 | 3 | type TextToSpeechRequest struct { 4 | Model string `json:"model" binding:"required"` 5 | Input string `json:"input" binding:"required"` 6 | Voice string `json:"voice" binding:"required"` 7 | Speed float64 `json:"speed"` 8 | ResponseFormat string `json:"response_format"` 9 | } 10 | 11 | type AudioResponse struct { 12 | Text string `json:"text"` 13 | } 14 | -------------------------------------------------------------------------------- /dto/dalle.go: -------------------------------------------------------------------------------- 1 | package dto 2 | 3 | type ImageRequest struct { 4 | Model string `json:"model"` 5 | Prompt string `json:"prompt" binding:"required"` 6 | N int `json:"n,omitempty"` 7 | Size string `json:"size,omitempty"` 8 | Quality string `json:"quality,omitempty"` 9 | ResponseFormat string `json:"response_format,omitempty"` 10 | Style string `json:"style,omitempty"` 11 | User string 
// OpenAIError is the error object exchanged under the "error" key.
type OpenAIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Param   string `json:"param"`
	Code    any    `json:"code"` // untyped: decoded as whatever JSON value is present
}

// OpenAIErrorWithStatusCode pairs an OpenAIError with a status code.
type OpenAIErrorWithStatusCode struct {
	Error      OpenAIError `json:"error"`
	StatusCode int         `json:"status_code"`
}

// GeneralErrorResponse collects the different places in which an error
// message may appear in a JSON body, so that a single decode can pick up
// whichever one is populated.
type GeneralErrorResponse struct {
	Error    OpenAIError `json:"error"`
	Message  string      `json:"message"`
	Msg      string      `json:"msg"`
	Err      string      `json:"err"`
	ErrorMsg string      `json:"error_msg"`
	Header   struct {
		Message string `json:"message"`
	} `json:"header"`
	Response struct {
		Error struct {
			Message string `json:"message"`
		} `json:"error"`
	} `json:"response"`
}

// ToMessage returns the first non-empty message found, checking in order:
// error.message, message, msg, err, error_msg, header.message and
// response.error.message. It returns "" when none of them is set.
func (e GeneralErrorResponse) ToMessage() string {
	// Ordered by priority; the first populated entry wins.
	candidates := [...]string{
		e.Error.Message,
		e.Message,
		e.Msg,
		e.Err,
		e.ErrorMsg,
		e.Header.Message,
		e.Response.Error.Message,
	}
	for _, text := range candidates {
		if text != "" {
			return text
		}
	}
	return ""
}
// TextResponse groups the "choices", "usage" and "error" members of a JSON
// response body.
//
// Usage is embedded but carries an explicit JSON tag, so encoding/json treats
// it as a nested object under "usage" instead of flattening its fields.
type TextResponse struct {
	Choices []OpenAITextResponseChoice `json:"choices"`
	Usage   `json:"usage"`
	Error   OpenAIError `json:"error"`
}

// OpenAITextResponseChoice is one entry of a "choices" array.
//
// Message is embedded with an explicit tag and is therefore (de)serialized as
// a nested object under "message".
type OpenAITextResponseChoice struct {
	Index        int `json:"index"`
	Message      `json:"message"`
	FinishReason string `json:"finish_reason"`
}

// OpenAITextResponse is a complete (non-streamed) response object with its
// identifying metadata, choices and usage.
type OpenAITextResponse struct {
	Id      string                     `json:"id"`
	Object  string                     `json:"object"`
	Created int64                      `json:"created"`
	Choices []OpenAITextResponseChoice `json:"choices"`
	Usage   `json:"usage"`
}

// OpenAIEmbeddingResponseItem is one element of the "data" array of an
// embedding response.
type OpenAIEmbeddingResponseItem struct {
	Object    string    `json:"object"`
	Index     int       `json:"index"`
	Embedding []float64 `json:"embedding"`
}

// OpenAIEmbeddingResponse is the envelope of an embedding response.
type OpenAIEmbeddingResponse struct {
	Object string                        `json:"object"`
	Data   []OpenAIEmbeddingResponseItem `json:"data"`
	Model  string                        `json:"model"`
	Usage  `json:"usage"`
}

// ChatCompletionsStreamResponseChoice is one choice inside a streamed chunk.
type ChatCompletionsStreamResponseChoice struct {
	Delta struct {
		Content   string `json:"content"`
		Role      string `json:"role,omitempty"`
		ToolCalls any    `json:"tool_calls,omitempty"`
	} `json:"delta"`
	// Pointer + omitempty: a nil FinishReason is left out when marshalling.
	FinishReason *string `json:"finish_reason,omitempty"`
	Index        int     `json:"index,omitempty"`
}

// ChatCompletionsStreamResponse is one streamed chat-completion chunk.
type ChatCompletionsStreamResponse struct {
	Id      string                                `json:"id"`
	Object  string                                `json:"object"`
	Created int64                                 `json:"created"`
	Model   string                                `json:"model"`
	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}

// ChatCompletionsStreamResponseSimple decodes only the "choices" member of a
// streamed chat-completion chunk.
type ChatCompletionsStreamResponseSimple struct {
	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}

// CompletionsStreamResponse decodes the "choices" member (text and finish
// reason) of a streamed completion chunk.
type CompletionsStreamResponse struct {
	Choices []struct {
		Text         string `json:"text"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
}
github.com/go-ole/go-ole v1.2.6 // indirect 38 | github.com/go-playground/locales v0.14.1 // indirect 39 | github.com/go-playground/universal-translator v0.18.1 // indirect 40 | github.com/go-sql-driver/mysql v1.6.0 // indirect 41 | github.com/goccy/go-json v0.10.2 // indirect 42 | github.com/gorilla/context v1.1.1 // indirect 43 | github.com/gorilla/securecookie v1.1.1 // indirect 44 | github.com/gorilla/sessions v1.2.1 // indirect 45 | github.com/jackc/pgpassfile v1.0.0 // indirect 46 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect 47 | github.com/jackc/pgx/v5 v5.5.1 // indirect 48 | github.com/jackc/puddle/v2 v2.2.1 // indirect 49 | github.com/jinzhu/inflection v1.0.0 // indirect 50 | github.com/jinzhu/now v1.1.5 // indirect 51 | github.com/json-iterator/go v1.1.12 // indirect 52 | github.com/klauspost/cpuid/v2 v2.2.4 // indirect 53 | github.com/leodido/go-urn v1.2.4 // indirect 54 | github.com/mattn/go-isatty v0.0.20 // indirect 55 | github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect 56 | github.com/mitchellh/mapstructure v1.5.0 // indirect 57 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 58 | github.com/modern-go/reflect2 v1.0.2 // indirect 59 | github.com/pelletier/go-toml/v2 v2.0.8 // indirect 60 | github.com/tklauser/go-sysconf v0.3.12 // indirect 61 | github.com/tklauser/numcpus v0.6.1 // indirect 62 | github.com/twitchyliquid64/golang-asm v0.15.1 // indirect 63 | github.com/ugorji/go/codec v1.2.11 // indirect 64 | github.com/yusufpapurcu/wmi v1.2.3 // indirect 65 | golang.org/x/arch v0.3.0 // indirect 66 | golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect 67 | golang.org/x/net v0.17.0 // indirect 68 | golang.org/x/sync v0.1.0 // indirect 69 | golang.org/x/sys v0.15.0 // indirect 70 | golang.org/x/text v0.14.0 // indirect 71 | google.golang.org/protobuf v1.30.0 // indirect 72 | gopkg.in/yaml.v3 v3.0.1 // indirect 73 | ) 74 | 
def list_file_paths(path):
    """Return the path of every translatable file below ``path``.

    Directories named ``node_modules``, ``build`` and ``i18n`` are not
    descended into, and files whose path ends with ``png``, ``ico``, ``db`` or
    ``exe`` are left out.

    ``os.walk`` already visits every sub-directory, so no manual recursion is
    done here. The previous version recursed on top of the walk, which
    returned each nested file several times and caused it to be rewritten
    repeatedly.

    :param path: root directory to scan
    :return: list of file paths, each listed exactly once
    """
    skipped_dirs = ("node_modules", "build", "i18n")
    skipped_suffixes = ("png", "ico", "db", "exe")

    file_paths = []
    for root, dirs, files in os.walk(path):
        # Prune in place so os.walk does not enter the skipped directories.
        dirs[:] = [name for name in dirs if name not in skipped_dirs]
        for file in files:
            file_path = os.path.join(root, file)
            if file_path.endswith(skipped_suffixes):
                continue
            file_paths.append(file_path)
    return file_paths


def replace_keys_in_repository(repo_path, json_file_path):
    """Apply every key -> value replacement from a JSON file to a repository.

    Longer keys are applied first so that a key which is a substring of
    another key cannot pre-empt the longer match.

    :param repo_path: root directory of the repository to rewrite
    :param json_file_path: path of a JSON object mapping source text to its
        replacement
    """
    with open(json_file_path, 'r', encoding="utf-8") as json_file:
        key_value_pairs = json.load(json_file)

    # An empty key would match between every character; ignore such entries.
    pairs = sorted(
        ((key, value) for key, value in key_value_pairs.items() if key),
        key=lambda pair: len(pair[0]),
        reverse=True,
    )

    files = list_file_paths(repo_path)
    print('Total files: {}'.format(len(files)))
    for file_path in files:
        replace_keys_in_file(file_path, pairs)


def replace_keys_in_file(file_path, pairs):
    """Replace every key of ``pairs`` by its value inside one UTF-8 file.

    Files that cannot be decoded as UTF-8 are reported and left untouched.
    The file is written back only when a replacement actually changed its
    content, so unrelated files are not modified.

    :param file_path: file to rewrite in place
    :param pairs: iterable of ``(key, value)`` tuples, applied in order
    """
    try:
        with open(file_path, 'r', encoding="utf-8") as file:
            content = file.read()
    except UnicodeDecodeError:
        print('UnicodeDecodeError: {}'.format(file_path))
        return

    updated = content
    for key, value in pairs:
        updated = updated.replace(key, value)

    if updated != content:
        with open(file_path, 'w', encoding="utf-8") as file:
            file.write(updated)
parser.add_argument('--json_file_path', help='Path to JSON file') 60 | args = parser.parse_args() 61 | replace_keys_in_repository(args.repository_path, args.json_file_path) 62 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "embed" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "net/http" 9 | "one-api/common" 10 | "one-api/controller" 11 | "one-api/middleware" 12 | "one-api/model" 13 | "one-api/router" 14 | "one-api/service" 15 | "os" 16 | "strconv" 17 | 18 | "github.com/gin-contrib/sessions" 19 | "github.com/gin-contrib/sessions/cookie" 20 | "github.com/gin-gonic/gin" 21 | 22 | _ "net/http/pprof" 23 | ) 24 | 25 | //go:embed web/build 26 | var buildFS embed.FS 27 | 28 | //go:embed web/build/index.html 29 | var indexPage []byte 30 | 31 | func main() { 32 | common.SetupLogger() 33 | common.SysLog("New API " + common.Version + " started") 34 | if os.Getenv("GIN_MODE") != "debug" { 35 | gin.SetMode(gin.ReleaseMode) 36 | } 37 | if common.DebugEnabled { 38 | common.SysLog("running in debug mode") 39 | } 40 | // Initialize SQL Database 41 | err := model.InitDB() 42 | if err != nil { 43 | common.FatalLog("failed to initialize database: " + err.Error()) 44 | } 45 | defer func() { 46 | err := model.CloseDB() 47 | if err != nil { 48 | common.FatalLog("failed to close database: " + err.Error()) 49 | } 50 | }() 51 | // 必须在数据库初始化之后 52 | flag.Parse() 53 | if len(os.Args) > 1 { 54 | subCommand := os.Args[1] 55 | switch subCommand { 56 | case "migrate": 57 | model.MustMigrate() 58 | default: 59 | fmt.Printf("未知的子命令: %s\n", subCommand) 60 | os.Exit(1) 61 | } 62 | } 63 | 64 | // Initialize Redis 65 | err = common.InitRedisClient() 66 | if err != nil { 67 | common.FatalLog("failed to initialize Redis: " + err.Error()) 68 | } 69 | 70 | // Initialize options 71 | model.InitOptionMap() 72 | if common.RedisEnabled { 73 | // for 
compatibility with old versions 74 | common.MemoryCacheEnabled = true 75 | } 76 | if common.MemoryCacheEnabled { 77 | common.SysLog("memory cache enabled") 78 | common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) 79 | model.InitChannelCache() 80 | } 81 | if common.RedisEnabled { 82 | go model.SyncTokenCache(common.SyncFrequency) 83 | } 84 | if common.MemoryCacheEnabled { 85 | go model.SyncOptions(common.SyncFrequency) 86 | go model.SyncChannelCache(common.SyncFrequency) 87 | } 88 | 89 | // 数据看板 90 | go model.UpdateQuotaData() 91 | 92 | if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { 93 | frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) 94 | if err != nil { 95 | common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) 96 | } 97 | go controller.AutomaticallyUpdateChannels(frequency) 98 | } 99 | if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { 100 | frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) 101 | if err != nil { 102 | common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) 103 | } 104 | go controller.AutomaticallyTestChannels(frequency) 105 | } 106 | common.SafeGoroutine(func() { 107 | controller.UpdateMidjourneyTaskBulk() 108 | }) 109 | if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { 110 | common.BatchUpdateEnabled = true 111 | common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") 112 | model.InitBatchUpdater() 113 | } 114 | 115 | if os.Getenv("ENABLE_PPROF") == "true" { 116 | go func() { 117 | log.Println(http.ListenAndServe("0.0.0.0:8005", nil)) 118 | }() 119 | go common.Monitor() 120 | common.SysLog("pprof enabled") 121 | } 122 | 123 | service.InitTokenEncoders() 124 | 125 | // Initialize HTTP server 126 | server := gin.New() 127 | server.Use(gin.CustomRecovery(func(c *gin.Context, err any) { 128 | common.SysError(fmt.Sprintf("panic detected: %v", err)) 129 | c.JSON(http.StatusInternalServerError, gin.H{ 
130 | "error": gin.H{ 131 | "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), 132 | "type": "new_api_panic", 133 | }, 134 | }) 135 | })) 136 | // This will cause SSE not to work!!! 137 | //server.Use(gzip.Gzip(gzip.DefaultCompression)) 138 | server.Use(middleware.RequestId()) 139 | middleware.SetUpLogger(server) 140 | // Initialize session store 141 | store := cookie.NewStore([]byte(common.SessionSecret)) 142 | server.Use(sessions.Sessions("session", store)) 143 | 144 | router.SetRouter(server, buildFS, indexPage) 145 | var port = os.Getenv("PORT") 146 | if port == "" { 147 | port = strconv.Itoa(*common.Port) 148 | } 149 | err = server.Run(":" + port) 150 | if err != nil { 151 | common.FatalLog("failed to start HTTP server: " + err.Error()) 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | FRONTEND_DIR = ./web 2 | BACKEND_DIR = . 3 | 4 | .PHONY: all build-frontend start-backend 5 | 6 | all: build-frontend start-backend 7 | 8 | build-frontend: 9 | @echo "Building frontend..." 10 | @cd $(FRONTEND_DIR) && npm install && DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build npm run build 11 | 12 | start-backend: 13 | @echo "Starting backend dev server..." 14 | @cd $(BACKEND_DIR) && go run main.go & 15 | 16 | install-dev: 17 | go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest 18 | go install golang.org/x/tools/cmd/goimports@latest 19 | go install github.com/fzipp/gocyclo/cmd/gocyclo@latest 20 | go install github.com/BurntSushi/toml/cmd/tomlv@master 21 | go install github.com/go-critic/go-critic/cmd/gocritic@latest 22 | 23 | fmt: 24 | @echo "Running gofmt..." 
25 | pre-commit run --all-files 26 | -------------------------------------------------------------------------------- /middleware/auth.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "one-api/common" 6 | "one-api/model" 7 | "strings" 8 | 9 | "github.com/gin-contrib/sessions" 10 | "github.com/gin-gonic/gin" 11 | ) 12 | 13 | func authHelper(c *gin.Context, minRole int) { 14 | session := sessions.Default(c) 15 | username := session.Get("username") 16 | role := session.Get("role") 17 | id := session.Get("id") 18 | status := session.Get("status") 19 | if username == nil { 20 | // Check access token 21 | accessToken := c.Request.Header.Get("Authorization") 22 | if accessToken == "" { 23 | c.JSON(http.StatusUnauthorized, gin.H{ 24 | "success": false, 25 | "message": "无权进行此操作,未登录且未提供 access token", 26 | }) 27 | c.Abort() 28 | return 29 | } 30 | user := model.ValidateAccessToken(accessToken) 31 | if user != nil && user.Username != "" { 32 | // Token is valid 33 | username = user.Username 34 | role = user.Role 35 | id = user.Id 36 | status = user.Status 37 | } else { 38 | c.JSON(http.StatusOK, gin.H{ 39 | "success": false, 40 | "message": "无权进行此操作,access token 无效", 41 | }) 42 | c.Abort() 43 | return 44 | } 45 | } 46 | if status.(int) == common.UserStatusDisabled { 47 | c.JSON(http.StatusOK, gin.H{ 48 | "success": false, 49 | "message": "用户已被封禁", 50 | }) 51 | c.Abort() 52 | return 53 | } 54 | if role.(int) < minRole { 55 | c.JSON(http.StatusOK, gin.H{ 56 | "success": false, 57 | "message": "无权进行此操作,权限不足", 58 | }) 59 | c.Abort() 60 | return 61 | } 62 | c.Set("username", username) 63 | c.Set("role", role) 64 | c.Set("id", id) 65 | c.Next() 66 | } 67 | 68 | func UserAuth() func(c *gin.Context) { 69 | return func(c *gin.Context) { 70 | authHelper(c, common.RoleCommonUser) 71 | } 72 | } 73 | 74 | func AdminAuth() func(c *gin.Context) { 75 | return func(c *gin.Context) { 76 | 
authHelper(c, common.RoleAdminUser) 77 | } 78 | } 79 | 80 | func RootAuth() func(c *gin.Context) { 81 | return func(c *gin.Context) { 82 | authHelper(c, common.RoleRootUser) 83 | } 84 | } 85 | 86 | func TokenAuth() func(c *gin.Context) { 87 | return func(c *gin.Context) { 88 | key := c.Request.Header.Get("Authorization") 89 | var parts []string 90 | key = strings.TrimPrefix(key, "Bearer ") 91 | if key == "" || key == "midjourney-proxy" { 92 | key = c.Request.Header.Get("mj-api-secret") 93 | key = strings.TrimPrefix(key, "Bearer ") 94 | key = strings.TrimPrefix(key, "sk-") 95 | parts = strings.Split(key, "-") 96 | key = parts[0] 97 | } else { 98 | key = strings.TrimPrefix(key, "sk-") 99 | parts = strings.Split(key, "-") 100 | key = parts[0] 101 | } 102 | token, err := model.ValidateUserToken(key) 103 | if err != nil { 104 | abortWithMessage(c, http.StatusUnauthorized, err.Error()) 105 | return 106 | } 107 | userEnabled, err := model.CacheIsUserEnabled(token.UserId) 108 | if err != nil { 109 | abortWithMessage(c, http.StatusInternalServerError, err.Error()) 110 | return 111 | } 112 | if !userEnabled { 113 | abortWithMessage(c, http.StatusForbidden, "用户已被封禁") 114 | return 115 | } 116 | c.Set("id", token.UserId) 117 | c.Set("token_id", token.Id) 118 | c.Set("token_name", token.Name) 119 | c.Set("token_unlimited_quota", token.UnlimitedQuota) 120 | if !token.UnlimitedQuota { 121 | c.Set("token_quota", token.RemainQuota) 122 | } 123 | if token.ModelLimitsEnabled { 124 | c.Set("token_model_limit_enabled", true) 125 | c.Set("token_model_limit", token.GetModelLimitsMap()) 126 | } else { 127 | c.Set("token_model_limit_enabled", false) 128 | } 129 | requestURL := c.Request.URL.String() 130 | consumeQuota := true 131 | if strings.HasPrefix(requestURL, "/v1/models") { 132 | consumeQuota = false 133 | } 134 | c.Set("consume_quota", consumeQuota) 135 | if len(parts) > 1 { 136 | if model.IsAdmin(token.UserId) { 137 | c.Set("channelId", parts[1]) 138 | } else { 139 | 
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") 140 | return 141 | } 142 | } 143 | c.Next() 144 | } 145 | } 146 | -------------------------------------------------------------------------------- /middleware/cache.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "github.com/gin-gonic/gin" 5 | ) 6 | 7 | func Cache() func(c *gin.Context) { 8 | return func(c *gin.Context) { 9 | if c.Request.RequestURI == "/" { 10 | c.Header("Cache-Control", "no-cache") 11 | } else { 12 | c.Header("Cache-Control", "max-age=604800") // one week 13 | } 14 | c.Next() 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /middleware/cors.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "github.com/gin-contrib/cors" 5 | "github.com/gin-gonic/gin" 6 | ) 7 | 8 | func CORS() gin.HandlerFunc { 9 | config := cors.DefaultConfig() 10 | config.AllowAllOrigins = true 11 | config.AllowCredentials = true 12 | config.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"} 13 | config.AllowHeaders = []string{"*"} 14 | return cors.New(config) 15 | } 16 | -------------------------------------------------------------------------------- /middleware/logger.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "one-api/common" 6 | 7 | "github.com/gin-gonic/gin" 8 | ) 9 | 10 | func SetUpLogger(server *gin.Engine) { 11 | server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { 12 | var requestID string 13 | if param.Keys != nil { 14 | requestID = param.Keys[common.RequestIdKey].(string) 15 | } 16 | return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", 17 | param.TimeStamp.Format("2006/01/02 - 15:04:05"), 18 | requestID, 19 | param.StatusCode, 20 | param.Latency, 21 | 
param.ClientIP, 22 | param.Method, 23 | param.Path, 24 | ) 25 | })) 26 | } 27 | -------------------------------------------------------------------------------- /middleware/rate-limit.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "one-api/common" 8 | "time" 9 | 10 | "github.com/gin-gonic/gin" 11 | ) 12 | 13 | var timeFormat = "2006-01-02T15:04:05.000Z" 14 | 15 | var inMemoryRateLimiter common.InMemoryRateLimiter 16 | 17 | func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { 18 | ctx := context.Background() 19 | rdb := common.RDB 20 | key := "rateLimit:" + mark + c.ClientIP() 21 | listLength, err := rdb.LLen(ctx, key).Result() 22 | if err != nil { 23 | fmt.Println(err.Error()) 24 | c.Status(http.StatusInternalServerError) 25 | c.Abort() 26 | return 27 | } 28 | if listLength < int64(maxRequestNum) { 29 | rdb.LPush(ctx, key, time.Now().Format(timeFormat)) 30 | rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) 31 | } else { 32 | oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() 33 | oldTime, err := time.Parse(timeFormat, oldTimeStr) 34 | if err != nil { 35 | fmt.Println(err) 36 | c.Status(http.StatusInternalServerError) 37 | c.Abort() 38 | return 39 | } 40 | nowTimeStr := time.Now().Format(timeFormat) 41 | nowTime, err := time.Parse(timeFormat, nowTimeStr) 42 | if err != nil { 43 | fmt.Println(err) 44 | c.Status(http.StatusInternalServerError) 45 | c.Abort() 46 | return 47 | } 48 | // time.Since will return negative number! 
49 | // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows 50 | if int64(nowTime.Sub(oldTime).Seconds()) < duration { 51 | rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) 52 | c.Status(http.StatusTooManyRequests) 53 | c.Abort() 54 | return 55 | } else { 56 | rdb.LPush(ctx, key, time.Now().Format(timeFormat)) 57 | rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) 58 | rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) 59 | } 60 | } 61 | } 62 | 63 | func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { 64 | key := mark + c.ClientIP() 65 | if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) { 66 | c.Status(http.StatusTooManyRequests) 67 | c.Abort() 68 | return 69 | } 70 | } 71 | 72 | func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) { 73 | if common.RedisEnabled { 74 | return func(c *gin.Context) { 75 | redisRateLimiter(c, maxRequestNum, duration, mark) 76 | } 77 | } else { 78 | // It's safe to call multi times. 
// GlobalWebRateLimit returns the limiter built by rateLimitFactory from
// common.GlobalWebRateLimitNum and common.GlobalWebRateLimitDuration.
// The mark "GW" becomes part of the per-client limiter key.
func GlobalWebRateLimit() func(c *gin.Context) {
	return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW")
}

// GlobalAPIRateLimit returns the limiter built by rateLimitFactory from
// common.GlobalApiRateLimitNum and common.GlobalApiRateLimitDuration,
// keyed with the mark "GA".
func GlobalAPIRateLimit() func(c *gin.Context) {
	return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA")
}

// CriticalRateLimit returns the limiter built by rateLimitFactory from
// common.CriticalRateLimitNum and common.CriticalRateLimitDuration,
// keyed with the mark "CT".
func CriticalRateLimit() func(c *gin.Context) {
	return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT")
}

// DownloadRateLimit returns the limiter built by rateLimitFactory from
// common.DownloadRateLimitNum and common.DownloadRateLimitDuration,
// keyed with the mark "DW".
func DownloadRateLimit() func(c *gin.Context) {
	return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW")
}

// UploadRateLimit returns the limiter built by rateLimitFactory from
// common.UploadRateLimitNum and common.UploadRateLimitDuration,
// keyed with the mark "UP".
func UploadRateLimit() func(c *gin.Context) {
	return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
}
Please submit a issue here: https://github.com/Calcium-Ion/new-api", err), 21 | "type": "new_api_panic", 22 | }, 23 | }) 24 | c.Abort() 25 | } 26 | }() 27 | c.Next() 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /middleware/request-id.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "context" 5 | "one-api/common" 6 | 7 | "github.com/gin-gonic/gin" 8 | ) 9 | 10 | func RequestId() func(c *gin.Context) { 11 | return func(c *gin.Context) { 12 | id := common.GetTimeString() + common.GetRandomString(8) 13 | c.Set(common.RequestIdKey, id) 14 | // nolint:staticcheck 15 | ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) 16 | c.Request = c.Request.WithContext(ctx) 17 | c.Header(common.RequestIdKey, id) 18 | c.Next() 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /middleware/turnstile-check.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "net/url" 7 | "one-api/common" 8 | 9 | "github.com/gin-contrib/sessions" 10 | "github.com/gin-gonic/gin" 11 | ) 12 | 13 | type turnstileCheckResponse struct { 14 | Success bool `json:"success"` 15 | } 16 | 17 | func TurnstileCheck() gin.HandlerFunc { 18 | return func(c *gin.Context) { 19 | if common.TurnstileCheckEnabled { 20 | session := sessions.Default(c) 21 | turnstileChecked := session.Get("turnstile") 22 | if turnstileChecked != nil { 23 | c.Next() 24 | return 25 | } 26 | response := c.Query("turnstile") 27 | if response == "" { 28 | c.JSON(http.StatusOK, gin.H{ 29 | "success": false, 30 | "message": "Turnstile token 为空", 31 | }) 32 | c.Abort() 33 | return 34 | } 35 | rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ 36 | "secret": {common.TurnstileSecretKey}, 37 | 
"response": {response}, 38 | "remoteip": {c.ClientIP()}, 39 | }) 40 | if err != nil { 41 | common.SysError(err.Error()) 42 | c.JSON(http.StatusOK, gin.H{ 43 | "success": false, 44 | "message": err.Error(), 45 | }) 46 | c.Abort() 47 | return 48 | } 49 | defer rawRes.Body.Close() 50 | var res turnstileCheckResponse 51 | err = json.NewDecoder(rawRes.Body).Decode(&res) 52 | if err != nil { 53 | common.SysError(err.Error()) 54 | c.JSON(http.StatusOK, gin.H{ 55 | "success": false, 56 | "message": err.Error(), 57 | }) 58 | c.Abort() 59 | return 60 | } 61 | if !res.Success { 62 | c.JSON(http.StatusOK, gin.H{ 63 | "success": false, 64 | "message": "Turnstile 校验失败,请刷新重试!", 65 | }) 66 | c.Abort() 67 | return 68 | } 69 | session.Set("turnstile", true) 70 | err = session.Save() 71 | if err != nil { 72 | c.JSON(http.StatusOK, gin.H{ 73 | "message": "无法保存会话信息,请重试", 74 | "success": false, 75 | }) 76 | return 77 | } 78 | } 79 | c.Next() 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /middleware/utils.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "one-api/common" 5 | 6 | "github.com/gin-gonic/gin" 7 | ) 8 | 9 | func abortWithMessage(c *gin.Context, statusCode int, message string) { 10 | c.JSON(statusCode, gin.H{ 11 | "error": gin.H{ 12 | "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), 13 | "type": "new_api_error", 14 | }, 15 | }) 16 | c.Abort() 17 | common.LogError(c.Request.Context(), message) 18 | } 19 | -------------------------------------------------------------------------------- /model/main.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "log" 5 | "one-api/common" 6 | "os" 7 | "strings" 8 | "sync" 9 | "time" 10 | 11 | "gorm.io/driver/mysql" 12 | "gorm.io/driver/postgres" 13 | "gorm.io/driver/sqlite" 14 | "gorm.io/gorm" 15 | ) 16 | 17 
| var DB *gorm.DB 18 | 19 | func createRootAccountIfNeed() error { 20 | var user User 21 | //if user.Status != common.UserStatusEnabled { 22 | if err := DB.First(&user).Error; err != nil { 23 | common.SysLog("no user exists, create a root user for you: username is root, password is 123456") 24 | hashedPassword, err := common.Password2Hash("123456") 25 | if err != nil { 26 | return err 27 | } 28 | rootUser := User{ 29 | Username: "root", 30 | Password: hashedPassword, 31 | Role: common.RoleRootUser, 32 | Status: common.UserStatusEnabled, 33 | DisplayName: "Root User", 34 | AccessToken: common.GetUUID(), 35 | Quota: 100000000, 36 | } 37 | DB.Create(&rootUser) 38 | } 39 | return nil 40 | } 41 | 42 | func chooseDB() (*gorm.DB, error) { 43 | if os.Getenv("SQL_DSN") != "" { 44 | dsn := os.Getenv("SQL_DSN") 45 | if strings.HasPrefix(dsn, "postgres://") { 46 | // Use PostgreSQL 47 | common.SysLog("using PostgreSQL as database") 48 | common.UsingPostgreSQL = true 49 | return gorm.Open(postgres.New(postgres.Config{ 50 | DSN: dsn, 51 | PreferSimpleProtocol: true, // disables implicit prepared statement usage 52 | }), &gorm.Config{ 53 | PrepareStmt: true, // precompile SQL 54 | }) 55 | } 56 | // Use MySQL 57 | common.SysLog("using MySQL as database") 58 | // check parseTime 59 | if !strings.Contains(dsn, "parseTime") { 60 | if strings.Contains(dsn, "?") { 61 | dsn += "&parseTime=true" 62 | } else { 63 | dsn += "?parseTime=true" 64 | } 65 | } 66 | return gorm.Open(mysql.Open(dsn), &gorm.Config{ 67 | PrepareStmt: true, // precompile SQL 68 | }) 69 | } 70 | // Use SQLite 71 | common.SysLog("SQL_DSN not set, using SQLite as database") 72 | common.UsingSQLite = true 73 | return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ 74 | PrepareStmt: true, // precompile SQL 75 | }) 76 | } 77 | 78 | func InitDB() (err error) { 79 | db, err := chooseDB() 80 | if err == nil { 81 | if common.DebugEnabled { 82 | db = db.Debug() 83 | } 84 | DB = db 85 | sqlDB, err := DB.DB() 86 | if err 
!= nil { 87 | return err 88 | } 89 | sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) 90 | sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) 91 | sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) 92 | 93 | if !common.IsMasterNode { 94 | return nil 95 | } 96 | err = createRootAccountIfNeed() 97 | return err 98 | } else { 99 | common.FatalLog(err) 100 | } 101 | return err 102 | } 103 | 104 | func CloseDB() error { 105 | sqlDB, err := DB.DB() 106 | if err != nil { 107 | return err 108 | } 109 | err = sqlDB.Close() 110 | return err 111 | } 112 | 113 | func MustMigrate() { 114 | if DB == nil { 115 | common.FatalLog("DB is nil") 116 | } 117 | common.SysLog("database migration started") 118 | TableList := []interface{}{ 119 | &Channel{}, 120 | &Log{}, 121 | &Midjourney{}, 122 | &Option{}, 123 | &Redemption{}, 124 | &Token{}, 125 | &TopUp{}, 126 | &QuotaData{}, 127 | &User{}, 128 | &UserCheckInLog{}, 129 | } 130 | if err := DB.AutoMigrate(TableList...); err != nil { 131 | common.FatalLog("migrate DB meet error", err.Error()) 132 | os.Exit(2) 133 | } 134 | common.SysLog("database migrated") 135 | } 136 | 137 | var ( 138 | lastPingTime time.Time 139 | pingMutex sync.Mutex 140 | ) 141 | 142 | func PingDB() error { 143 | pingMutex.Lock() 144 | defer pingMutex.Unlock() 145 | 146 | if time.Since(lastPingTime) < time.Second*10 { 147 | return nil 148 | } 149 | 150 | sqlDB, err := DB.DB() 151 | if err != nil { 152 | log.Printf("Error getting sql.DB from GORM: %v", err) 153 | return err 154 | } 155 | 156 | err = sqlDB.Ping() 157 | if err != nil { 158 | log.Printf("Error pinging DB: %v", err) 159 | return err 160 | } 161 | 162 | lastPingTime = time.Now() 163 | common.SysLog("Database pinged successfully") 164 | return nil 165 | } 166 | -------------------------------------------------------------------------------- /model/midjourney.go: 
-------------------------------------------------------------------------------- 1 | package model 2 | 3 | type Midjourney struct { 4 | Id int `json:"id"` 5 | Code int `json:"code"` 6 | UserId int `json:"user_id" gorm:"index"` 7 | Action string `json:"action"` 8 | MjId string `json:"mj_id" gorm:"index"` 9 | Prompt string `json:"prompt"` 10 | PromptEn string `json:"prompt_en"` 11 | Description string `json:"description"` 12 | State string `json:"state"` 13 | SubmitTime int64 `json:"submit_time"` 14 | StartTime int64 `json:"start_time"` 15 | FinishTime int64 `json:"finish_time"` 16 | ImageUrl string `json:"image_url"` 17 | Status string `json:"status"` 18 | Progress string `json:"progress"` 19 | FailReason string `json:"fail_reason"` 20 | ChannelId int `json:"channel_id"` 21 | Quota int `json:"quota"` 22 | } 23 | 24 | // TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 25 | type TaskQueryParams struct { 26 | ChannelID string 27 | MjID string 28 | StartTimestamp string 29 | EndTimestamp string 30 | } 31 | 32 | func GetAllUserTask(userId int, startIdx int, num int, queryParams TaskQueryParams) []*Midjourney { 33 | var tasks []*Midjourney 34 | var err error 35 | 36 | // 初始化查询构建器 37 | query := DB.Where("user_id = ?", userId) 38 | 39 | if queryParams.MjID != "" { 40 | query = query.Where("mj_id = ?", queryParams.MjID) 41 | } 42 | if queryParams.StartTimestamp != "" { 43 | // 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析 44 | query = query.Where("submit_time >= ?", queryParams.StartTimestamp) 45 | } 46 | if queryParams.EndTimestamp != "" { 47 | query = query.Where("submit_time <= ?", queryParams.EndTimestamp) 48 | } 49 | 50 | // 获取数据 51 | err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error 52 | if err != nil { 53 | return nil 54 | } 55 | 56 | return tasks 57 | } 58 | 59 | func GetAllTasks(startIdx int, num int, queryParams TaskQueryParams) []*Midjourney { 60 | var tasks []*Midjourney 61 | var err error 62 | 63 | // 初始化查询构建器 64 | query := DB 65 | 66 | // 
添加过滤条件 67 | if queryParams.ChannelID != "" { 68 | query = query.Where("channel_id = ?", queryParams.ChannelID) 69 | } 70 | if queryParams.MjID != "" { 71 | query = query.Where("mj_id = ?", queryParams.MjID) 72 | } 73 | if queryParams.StartTimestamp != "" { 74 | query = query.Where("submit_time >= ?", queryParams.StartTimestamp) 75 | } 76 | if queryParams.EndTimestamp != "" { 77 | query = query.Where("submit_time <= ?", queryParams.EndTimestamp) 78 | } 79 | 80 | // 获取数据 81 | err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error 82 | if err != nil { 83 | return nil 84 | } 85 | 86 | return tasks 87 | } 88 | 89 | func GetAllUnFinishTasks() []*Midjourney { 90 | var tasks []*Midjourney 91 | // get all tasks progress is not 100% 92 | err := DB.Where("progress != ?", "100%").Find(&tasks).Error 93 | if err != nil { 94 | return nil 95 | } 96 | return tasks 97 | } 98 | 99 | func GetByOnlyMJId(mjId string) *Midjourney { 100 | var mj *Midjourney 101 | err := DB.Where("mj_id = ?", mjId).First(&mj).Error 102 | if err != nil { 103 | return nil 104 | } 105 | return mj 106 | } 107 | 108 | func GetByMJId(userId int, mjId string) *Midjourney { 109 | var mj *Midjourney 110 | err := DB.Where("user_id = ? and mj_id = ?", userId, mjId).First(&mj).Error 111 | if err != nil { 112 | return nil 113 | } 114 | return mj 115 | } 116 | 117 | func GetByMJIds(userId int, mjIds []string) []*Midjourney { 118 | var mj []*Midjourney 119 | err := DB.Where("user_id = ? 
and mj_id in (?)", userId, mjIds).Find(&mj).Error 120 | if err != nil { 121 | return nil 122 | } 123 | return mj 124 | } 125 | 126 | func GetMjByuId(id int) *Midjourney { 127 | var mj *Midjourney 128 | err := DB.Where("id = ?", id).First(&mj).Error 129 | if err != nil { 130 | return nil 131 | } 132 | return mj 133 | } 134 | 135 | func UpdateProgress(id int, progress string) error { 136 | return DB.Model(&Midjourney{}).Where("id = ?", id).Update("progress", progress).Error 137 | } 138 | 139 | func (midjourney *Midjourney) Insert() error { 140 | err := DB.Create(midjourney).Error 141 | return err 142 | } 143 | 144 | func (midjourney *Midjourney) Update() error { 145 | err := DB.Save(midjourney).Error 146 | return err 147 | } 148 | 149 | func MjBulkUpdate(mjIds []string, params map[string]any) error { 150 | return DB.Model(&Midjourney{}). 151 | Where("mj_id in (?)", mjIds). 152 | Updates(params).Error 153 | } 154 | 155 | func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error { 156 | return DB.Model(&Midjourney{}). 157 | Where("id in (?)", taskIDs). 
158 | Updates(params).Error 159 | } 160 | -------------------------------------------------------------------------------- /model/redemption.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "one-api/common" 7 | 8 | "gorm.io/gorm" 9 | ) 10 | 11 | type Redemption struct { 12 | Id int `json:"id"` 13 | UserId int `json:"user_id"` 14 | Key string `json:"key" gorm:"type:char(32);uniqueIndex"` 15 | Status int `json:"status" gorm:"default:1"` 16 | Name string `json:"name" gorm:"index"` 17 | Quota int `json:"quota" gorm:"default:100"` 18 | CreatedTime int64 `json:"created_time" gorm:"bigint"` 19 | RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"` 20 | Count int `json:"count" gorm:"-:all"` // only for api request 21 | UsedUserId int `json:"used_user_id"` 22 | } 23 | 24 | func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) { 25 | var redemptions []*Redemption 26 | err := DB.Order("id desc").Limit(num).Offset(startIdx).Find(&redemptions).Error 27 | return redemptions, err 28 | } 29 | 30 | func SearchRedemptions(keyword string) (redemptions []*Redemption, err error) { 31 | err = DB.Where("id = ? 
or name LIKE ?", keyword, keyword+"%").Find(&redemptions).Error 32 | return redemptions, err 33 | } 34 | 35 | func GetRedemptionById(id int) (*Redemption, error) { 36 | if id == 0 { 37 | return nil, errors.New("id 为空!") 38 | } 39 | redemption := Redemption{Id: id} 40 | err := DB.First(&redemption, "id = ?", id).Error 41 | return &redemption, err 42 | } 43 | 44 | func Redeem(key string, userId int) (quota int, err error) { 45 | if key == "" { 46 | return 0, errors.New("未提供兑换码") 47 | } 48 | if userId == 0 { 49 | return 0, errors.New("无效的 user id") 50 | } 51 | redemption := &Redemption{} 52 | 53 | keyCol := "`key`" 54 | if common.UsingPostgreSQL { 55 | keyCol = `"key"` 56 | } 57 | 58 | err = DB.Transaction(func(tx *gorm.DB) error { 59 | err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error 60 | if err != nil { 61 | return errors.New("无效的兑换码") 62 | } 63 | if redemption.Status != common.RedemptionCodeStatusEnabled { 64 | return errors.New("该兑换码已被使用") 65 | } 66 | err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error 67 | if err != nil { 68 | return err 69 | } 70 | redemption.RedeemedTime = common.GetTimestamp() 71 | redemption.Status = common.RedemptionCodeStatusUsed 72 | redemption.UsedUserId = userId 73 | err = tx.Save(redemption).Error 74 | return err 75 | }) 76 | if err != nil { 77 | return 0, errors.New("兑换失败," + err.Error()) 78 | } 79 | RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) 80 | return redemption.Quota, nil 81 | } 82 | 83 | func (redemption *Redemption) Insert() error { 84 | err := DB.Create(redemption).Error 85 | return err 86 | } 87 | 88 | func (redemption *Redemption) SelectUpdate() error { 89 | // This can update zero values 90 | return DB.Model(redemption).Select("redeemed_time", "status").Updates(redemption).Error 91 | } 92 | 93 | // Update Make sure your token's fields is completed, because this will 
update non-zero values 94 | func (redemption *Redemption) Update() error { 95 | err := DB.Model(redemption).Select("name", "status", "quota", "redeemed_time").Updates(redemption).Error 96 | return err 97 | } 98 | 99 | func (redemption *Redemption) Delete() error { 100 | err := DB.Delete(redemption).Error 101 | return err 102 | } 103 | 104 | func DeleteRedemptionById(id int) (err error) { 105 | if id == 0 { 106 | return errors.New("id 为空!") 107 | } 108 | redemption := Redemption{Id: id} 109 | err = DB.Where(redemption).First(&redemption).Error 110 | if err != nil { 111 | return err 112 | } 113 | return redemption.Delete() 114 | } 115 | -------------------------------------------------------------------------------- /model/topup.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | type TopUp struct { 4 | Id int `json:"id"` 5 | UserId int `json:"user_id" gorm:"index"` 6 | Amount int `json:"amount"` 7 | Money float64 `json:"money"` 8 | TradeNo string `json:"trade_no"` 9 | CreateTime int64 `json:"create_time"` 10 | Status string `json:"status"` 11 | } 12 | 13 | func (topUp *TopUp) Insert() error { 14 | err := DB.Create(topUp).Error 15 | return err 16 | } 17 | 18 | func (topUp *TopUp) Update() error { 19 | err := DB.Save(topUp).Error 20 | return err 21 | } 22 | 23 | func GetTopUpById(id int) *TopUp { 24 | var topUp *TopUp 25 | err := DB.Where("id = ?", id).First(&topUp).Error 26 | if err != nil { 27 | return nil 28 | } 29 | return topUp 30 | } 31 | 32 | func GetTopUpByTradeNo(tradeNo string) *TopUp { 33 | var topUp *TopUp 34 | err := DB.Where("trade_no = ?", tradeNo).First(&topUp).Error 35 | if err != nil { 36 | return nil 37 | } 38 | return topUp 39 | } 40 | -------------------------------------------------------------------------------- /model/user_checkin.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "math/rand" 
7 | "one-api/common" 8 | "time" 9 | _ "time/tzdata" 10 | 11 | "gorm.io/gorm" 12 | ) 13 | 14 | const ( 15 | DollarToToken = 500000 16 | ) 17 | 18 | type UserCheckInLog struct { 19 | ID uint `gorm:"primaryKey" json:"id"` 20 | UserID int `gorm:"index;not null;column:user_id" json:"user_id"` // 关联到User模型,明确指定列名为user_id 21 | Date time.Time `gorm:"uniqueIndex:user_date_idx;column:date" json:"date"` // 确保同一个用户每天只能有一条记录,明确指定列名为date 22 | GiftQuota int `gorm:"not null;column:gift_quota" json:"gift_quota"` // 签到赠送的随机quota值,明确指定列名为gift_quota 23 | CreatedAt time.Time `gorm:"autoCreateTime;column:created_at" json:"created_at"` // 记录创建时间,明确指定列名为created_at 24 | } 25 | 26 | func (UserCheckInLog) TableName() string { 27 | return "user_check_in_logs" 28 | } 29 | 30 | func getRandomQuota() (int, error) { 31 | minDollar, err := common.GetIntEnv("MIN_CHECKIN_DOLLAR") 32 | if err != nil { 33 | return 0, err 34 | 35 | } 36 | maxDollar, err := common.GetIntEnv("MAX_CHECKIN_DOLLAR") 37 | if err != nil { 38 | return 0, err 39 | } 40 | if minDollar >= maxDollar { 41 | return 0, fmt.Errorf("MIN_CHECKIN_DOLLAR 必须小于 MAX_CHECKIN_DOLLAR") 42 | } 43 | randomQuota := rand.Intn(maxDollar-minDollar) + minDollar 44 | randomQuota *= DollarToToken 45 | return randomQuota, nil 46 | } 47 | 48 | // CheckIn 用户签到功能 49 | func CheckIn(userID int) (*UserCheckInLog, error) { 50 | var checkInLog *UserCheckInLog 51 | // 确保在一个事务中执行 52 | err := DB.Transaction(func(tx *gorm.DB) error { 53 | location, _ := time.LoadLocation("Asia/Shanghai") 54 | now := time.Now().In(location) 55 | today := now.Format("2006-01-02") // 今天的日期 56 | var count int64 57 | // 检查用户今天是否已经签到,使用 +8:00 时区 58 | if err := tx.Model(&UserCheckInLog{}).Where("user_id = ? 
AND DATE(CONVERT_TZ(date,'+00:00','+08:00')) = ?", userID, today).Count(&count).Error; err != nil { 59 | return err 60 | } 61 | if count > 0 { 62 | // 如果已经签到过,返回错误 63 | return errors.New("今日已签到,请明天再来") 64 | } 65 | 66 | // 生成随机GiftQuota值,比如1到10之间 67 | giftQuota, err := getRandomQuota() 68 | if err != nil { 69 | return err 70 | } 71 | // 增加用户Quota 72 | if err := IncreaseUserQuotaWithTX(tx, userID, giftQuota); err != nil { 73 | return err 74 | } 75 | 76 | // 创建签到记录 77 | checkInLog = &UserCheckInLog{ 78 | UserID: userID, 79 | Date: now, 80 | GiftQuota: giftQuota, 81 | } 82 | if err := tx.Create(checkInLog).Error; err != nil { 83 | return err 84 | } 85 | 86 | RecordLog(checkInLog.UserID, LogTypeTopup, 87 | fmt.Sprintf("签到赠送金额: %v", common.LogQuota(checkInLog.GiftQuota))) 88 | return nil 89 | }) 90 | 91 | // 如果事务执行成功,返回签到记录和nil错误;否则返回nil和相应的错误信息 92 | if err != nil { 93 | return nil, err 94 | } 95 | return checkInLog, nil 96 | } 97 | -------------------------------------------------------------------------------- /model/utils.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "one-api/common" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | const ( 10 | BatchUpdateTypeUserQuota = iota 11 | BatchUpdateTypeTokenQuota 12 | BatchUpdateTypeUsedQuota 13 | BatchUpdateTypeChannelUsedQuota 14 | BatchUpdateTypeRequestCount 15 | BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock 16 | ) 17 | 18 | var batchUpdateStores []map[int]int 19 | var batchUpdateLocks []sync.Mutex 20 | 21 | func init() { 22 | for i := 0; i < BatchUpdateTypeCount; i++ { 23 | batchUpdateStores = append(batchUpdateStores, make(map[int]int)) 24 | batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) 25 | } 26 | } 27 | 28 | func InitBatchUpdater() { 29 | go func() { 30 | for { 31 | time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) 32 | batchUpdate() 33 | } 34 | }() 35 | } 36 | 37 | func 
addNewRecord(type_ int, id int, value int) { 38 | batchUpdateLocks[type_].Lock() 39 | defer batchUpdateLocks[type_].Unlock() 40 | if _, ok := batchUpdateStores[type_][id]; !ok { 41 | batchUpdateStores[type_][id] = value 42 | } else { 43 | batchUpdateStores[type_][id] += value 44 | } 45 | } 46 | 47 | func batchUpdate() { 48 | common.SysLog("batch update started") 49 | for i := 0; i < BatchUpdateTypeCount; i++ { 50 | batchUpdateLocks[i].Lock() 51 | store := batchUpdateStores[i] 52 | batchUpdateStores[i] = make(map[int]int) 53 | batchUpdateLocks[i].Unlock() 54 | // TODO: maybe we can combine updates with same key? 55 | for key, value := range store { 56 | switch i { 57 | case BatchUpdateTypeUserQuota: 58 | err := increaseUserQuota(key, value) 59 | if err != nil { 60 | common.SysError("failed to batch update user quota: " + err.Error()) 61 | } 62 | case BatchUpdateTypeTokenQuota: 63 | err := increaseTokenQuota(key, value) 64 | if err != nil { 65 | common.SysError("failed to batch update token quota: " + err.Error()) 66 | } 67 | case BatchUpdateTypeUsedQuota: 68 | updateUserUsedQuota(key, value) 69 | case BatchUpdateTypeRequestCount: 70 | updateUserRequestCount(key, value) 71 | case BatchUpdateTypeChannelUsedQuota: 72 | updateChannelUsedQuota(key, value) 73 | } 74 | } 75 | } 76 | common.SysLog("batch update finished") 77 | } 78 | -------------------------------------------------------------------------------- /one-api.service: -------------------------------------------------------------------------------- 1 | # File path: /etc/systemd/system/one-api.service 2 | # sudo systemctl daemon-reload 3 | # sudo systemctl start one-api 4 | # sudo systemctl enable one-api 5 | # sudo systemctl status one-api 6 | [Unit] 7 | Description=One API Service 8 | After=network.target 9 | 10 | [Service] 11 | User=ubuntu # 注意修改用户名 12 | WorkingDirectory=/path/to/one-api # 注意修改路径 13 | ExecStart=/path/to/one-api/one-api --port 3000 --log-dir /path/to/one-api/logs # 注意修改路径和端口号 14 | 
Restart=always 15 | RestartSec=5 16 | 17 | [Install] 18 | WantedBy=multi-user.target 19 | -------------------------------------------------------------------------------- /relay/channel/adapter.go: -------------------------------------------------------------------------------- 1 | package channel 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | "one-api/dto" 7 | relaycommon "one-api/relay/common" 8 | 9 | "github.com/gin-gonic/gin" 10 | ) 11 | 12 | type Adaptor interface { 13 | // Init IsStream bool 14 | Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) 15 | GetRequestURL(info *relaycommon.RelayInfo) (string, error) 16 | SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error 17 | ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) 18 | DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) 19 | DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) 20 | GetModelList() []string 21 | GetChannelName() string 22 | } 23 | -------------------------------------------------------------------------------- /relay/channel/ai360/constants.go: -------------------------------------------------------------------------------- 1 | package ai360 2 | 3 | var ModelList = []string{ 4 | "360GPT_S2_V9", 5 | "embedding-bert-512-v1", 6 | "embedding_s1_v1", 7 | "semantic_similarity_s1_v1", 8 | } 9 | -------------------------------------------------------------------------------- /relay/channel/ali/adaptor.go: -------------------------------------------------------------------------------- 1 | package ali 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "one-api/dto" 9 | "one-api/relay/channel" 10 | relaycommon "one-api/relay/common" 11 | "one-api/relay/constant" 12 | 13 | "github.com/gin-gonic/gin" 14 | ) 15 | 16 | type Adaptor struct { 17 | } 18 | 19 | func (a 
*Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { 20 | 21 | } 22 | 23 | func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { 24 | fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl) 25 | if info.RelayMode == constant.RelayModeEmbeddings { 26 | fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl) 27 | } 28 | return fullRequestURL, nil 29 | } 30 | 31 | func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { 32 | channel.SetupApiRequestHeader(info, c, req) 33 | req.Header.Set("Authorization", "Bearer "+info.ApiKey) 34 | if info.IsStream { 35 | req.Header.Set("X-DashScope-SSE", "enable") 36 | } 37 | if c.GetString("plugin") != "" { 38 | req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) 39 | } 40 | return nil 41 | } 42 | 43 | func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { 44 | if request == nil { 45 | return nil, errors.New("request is nil") 46 | } 47 | switch relayMode { 48 | case constant.RelayModeEmbeddings: 49 | baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request) 50 | return baiduEmbeddingRequest, nil 51 | default: 52 | baiduRequest := requestOpenAI2Ali(*request) 53 | return baiduRequest, nil 54 | } 55 | } 56 | 57 | func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { 58 | return channel.DoApiRequest(a, c, info, requestBody) 59 | } 60 | 61 | func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { 62 | if info.IsStream { 63 | err, usage = aliStreamHandler(c, resp) 64 | } else { 65 | switch info.RelayMode { 66 | case constant.RelayModeEmbeddings: 67 | err, usage = aliEmbeddingHandler(c, resp) 68 | default: 69 | err, 
usage = aliHandler(c, resp) 70 | } 71 | } 72 | return 73 | } 74 | 75 | func (a *Adaptor) GetModelList() []string { 76 | return ModelList 77 | } 78 | 79 | func (a *Adaptor) GetChannelName() string { 80 | return ChannelName 81 | } 82 | -------------------------------------------------------------------------------- /relay/channel/ali/constants.go: -------------------------------------------------------------------------------- 1 | package ali 2 | 3 | var ModelList = []string{ 4 | "qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", 5 | "text-embedding-v1", 6 | } 7 | 8 | var ChannelName = "ali" 9 | -------------------------------------------------------------------------------- /relay/channel/ali/dto.go: -------------------------------------------------------------------------------- 1 | package ali 2 | 3 | type AliMessage struct { 4 | Content string `json:"content"` 5 | Role string `json:"role"` 6 | } 7 | 8 | type AliInput struct { 9 | Prompt string `json:"prompt"` 10 | History []AliMessage `json:"history"` 11 | } 12 | 13 | type AliParameters struct { 14 | TopP float64 `json:"top_p,omitempty"` 15 | TopK int `json:"top_k,omitempty"` 16 | Seed uint64 `json:"seed,omitempty"` 17 | EnableSearch bool `json:"enable_search,omitempty"` 18 | IncrementalOutput bool `json:"incremental_output,omitempty"` 19 | } 20 | 21 | type AliChatRequest struct { 22 | Model string `json:"model"` 23 | Input AliInput `json:"input"` 24 | Parameters AliParameters `json:"parameters,omitempty"` 25 | } 26 | 27 | type AliEmbeddingRequest struct { 28 | Model string `json:"model"` 29 | Input struct { 30 | Texts []string `json:"texts"` 31 | } `json:"input"` 32 | Parameters *struct { 33 | TextType string `json:"text_type,omitempty"` 34 | } `json:"parameters,omitempty"` 35 | } 36 | 37 | type AliEmbedding struct { 38 | Embedding []float64 `json:"embedding"` 39 | TextIndex int `json:"text_index"` 40 | } 41 | 42 | type AliEmbeddingResponse struct { 43 | Output struct { 44 | Embeddings []AliEmbedding 
`json:"embeddings"` 45 | } `json:"output"` 46 | Usage AliUsage `json:"usage"` 47 | AliError 48 | } 49 | 50 | type AliError struct { 51 | Code string `json:"code"` 52 | Message string `json:"message"` 53 | RequestId string `json:"request_id"` 54 | } 55 | 56 | type AliUsage struct { 57 | InputTokens int `json:"input_tokens"` 58 | OutputTokens int `json:"output_tokens"` 59 | TotalTokens int `json:"total_tokens"` 60 | } 61 | 62 | type AliOutput struct { 63 | Text string `json:"text"` 64 | FinishReason string `json:"finish_reason"` 65 | } 66 | 67 | type AliChatResponse struct { 68 | Output AliOutput `json:"output"` 69 | Usage AliUsage `json:"usage"` 70 | AliError 71 | } 72 | -------------------------------------------------------------------------------- /relay/channel/api_request.go: -------------------------------------------------------------------------------- 1 | package channel 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "one-api/relay/common" 9 | "one-api/service" 10 | 11 | "github.com/gin-gonic/gin" 12 | ) 13 | 14 | func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) { 15 | req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) 16 | req.Header.Set("Accept", c.Request.Header.Get("Accept")) 17 | if info.IsStream && c.Request.Header.Get("Accept") == "" { 18 | req.Header.Set("Accept", "text/event-stream") 19 | } 20 | } 21 | 22 | func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { 23 | fullRequestURL, err := a.GetRequestURL(info) 24 | if err != nil { 25 | return nil, fmt.Errorf("get request url failed: %w", err) 26 | } 27 | req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) 28 | if err != nil { 29 | return nil, fmt.Errorf("new request failed: %w", err) 30 | } 31 | err = a.SetupRequestHeader(c, req, info) 32 | if err != nil { 33 | return nil, fmt.Errorf("setup request header failed: %w", err) 34 | } 35 | 
resp, err := doRequest(c, req) 36 | if err != nil { 37 | return nil, fmt.Errorf("do request failed: %w", err) 38 | } 39 | return resp, nil 40 | } 41 | 42 | func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) { 43 | resp, err := service.GetHttpClient().Do(req) 44 | if err != nil { 45 | return nil, err 46 | } 47 | if resp == nil { 48 | return nil, errors.New("resp is nil") 49 | } 50 | _ = req.Body.Close() 51 | _ = c.Request.Body.Close() 52 | return resp, nil 53 | } 54 | -------------------------------------------------------------------------------- /relay/channel/baidu/adaptor.go: -------------------------------------------------------------------------------- 1 | package baidu 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "net/http" 7 | "one-api/dto" 8 | "one-api/relay/channel" 9 | relaycommon "one-api/relay/common" 10 | "one-api/relay/constant" 11 | 12 | "github.com/gin-gonic/gin" 13 | ) 14 | 15 | type Adaptor struct { 16 | } 17 | 18 | func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { 19 | 20 | } 21 | 22 | func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { 23 | var fullRequestURL string 24 | switch info.UpstreamModelName { 25 | case "ERNIE-Bot-4": 26 | fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" 27 | case "ERNIE-Bot-8K": 28 | fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k" 29 | case "ERNIE-Bot": 30 | fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" 31 | case "ERNIE-Speed": 32 | fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed" 33 | case "ERNIE-Bot-turbo": 34 | fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" 35 | case "BLOOMZ-7B": 36 | fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" 37 | 
case "Embedding-V1": 38 | fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" 39 | } 40 | var accessToken string 41 | var err error 42 | if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil { 43 | return "", err 44 | } 45 | fullRequestURL += "?access_token=" + accessToken 46 | return fullRequestURL, nil 47 | } 48 | 49 | func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { 50 | channel.SetupApiRequestHeader(info, c, req) 51 | req.Header.Set("Authorization", "Bearer "+info.ApiKey) 52 | return nil 53 | } 54 | 55 | func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { 56 | if request == nil { 57 | return nil, errors.New("request is nil") 58 | } 59 | switch relayMode { 60 | case constant.RelayModeEmbeddings: 61 | baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request) 62 | return baiduEmbeddingRequest, nil 63 | default: 64 | baiduRequest := requestOpenAI2Baidu(*request) 65 | return baiduRequest, nil 66 | } 67 | } 68 | 69 | func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { 70 | return channel.DoApiRequest(a, c, info, requestBody) 71 | } 72 | 73 | func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { 74 | if info.IsStream { 75 | err, usage = baiduStreamHandler(c, resp) 76 | } else { 77 | switch info.RelayMode { 78 | case constant.RelayModeEmbeddings: 79 | err, usage = baiduEmbeddingHandler(c, resp) 80 | default: 81 | err, usage = baiduHandler(c, resp) 82 | } 83 | } 84 | return 85 | } 86 | 87 | func (a *Adaptor) GetModelList() []string { 88 | return ModelList 89 | } 90 | 91 | func (a *Adaptor) GetChannelName() string { 92 | return ChannelName 93 | } 94 | 
-------------------------------------------------------------------------------- /relay/channel/baidu/constants.go: -------------------------------------------------------------------------------- 1 | package baidu 2 | 3 | /* ModelList holds the model identifiers declared for the baidu channel package. */ var ModelList = []string{ 4 | "ERNIE-Bot-4", 5 | "ERNIE-Bot-8K", 6 | "ERNIE-Bot", 7 | "ERNIE-Speed", 8 | "ERNIE-Bot-turbo", 9 | "Embedding-V1", 10 | } 11 | 12 | /* ChannelName is the channel label declared for this package. */ var ChannelName = "baidu" 13 | -------------------------------------------------------------------------------- /relay/channel/baidu/dto.go: -------------------------------------------------------------------------------- 1 | package baidu 2 | 3 | import ( 4 | "one-api/dto" 5 | "time" 6 | ) 7 | 8 | /* BaiduMessage is one role/content message pair, serialized as JSON. */ type BaiduMessage struct { 9 | Role string `json:"role"` 10 | Content string `json:"content"` 11 | } 12 | 13 | /* BaiduChatRequest is the JSON chat request body: messages, a stream flag and an optional user_id. */ type BaiduChatRequest struct { 14 | Messages []BaiduMessage `json:"messages"` 15 | Stream bool `json:"stream"` 16 | UserId string `json:"user_id,omitempty"` 17 | } 18 | 19 | /* Error carries error_code and error_msg. It is embedded untagged in the response structs of this file, so both keys decode from the top level of the same JSON object. */ type Error struct { 20 | ErrorCode int `json:"error_code"` 21 | ErrorMsg string `json:"error_msg"` 22 | } 23 | 24 | /* BaiduChatResponse is the JSON chat response; the generated text is in "result" and token counts in "usage". */ type BaiduChatResponse struct { 25 | Id string `json:"id"` 26 | Object string `json:"object"` 27 | Created int64 `json:"created"` 28 | Result string `json:"result"` 29 | IsTruncated bool `json:"is_truncated"` 30 | NeedClearHistory bool `json:"need_clear_history"` 31 | Usage dto.Usage `json:"usage"` 32 | Error 33 | } 34 | 35 | /* BaiduChatStreamResponse extends BaiduChatResponse with the per-chunk fields sentence_id and is_end. */ type BaiduChatStreamResponse struct { 36 | BaiduChatResponse 37 | SentenceId int `json:"sentence_id"` 38 | IsEnd bool `json:"is_end"` 39 | } 40 | 41 | /* BaiduEmbeddingRequest is the JSON embedding request: a list of input strings. */ type BaiduEmbeddingRequest struct { 42 | Input []string `json:"input"` 43 | } 44 | 45 | /* BaiduEmbeddingData is one embedding vector together with its index. */ type BaiduEmbeddingData struct { 46 | Object string `json:"object"` 47 | Embedding []float64 `json:"embedding"` 48 | Index int `json:"index"` 49 | } 50 | 51 | /* BaiduEmbeddingResponse is the JSON embedding response; like the chat response it embeds Error. */ type BaiduEmbeddingResponse struct { 52 | Id string `json:"id"` 53 | Object string `json:"object"` 54 | Created int64 `json:"created"` 55 | Data []BaiduEmbeddingData `json:"data"` 56 |
Usage dto.Usage `json:"usage"` 57 | Error 58 | } 59 | 60 | /* BaiduAccessToken is the decoded access-token payload. ExpiresAt is tagged json:"-" and is therefore never read from or written to JSON. */ type BaiduAccessToken struct { 61 | AccessToken string `json:"access_token"` 62 | Error string `json:"error,omitempty"` 63 | ErrorDescription string `json:"error_description,omitempty"` 64 | ExpiresIn int64 `json:"expires_in,omitempty"` 65 | ExpiresAt time.Time `json:"-"` 66 | } 67 | 68 | /* BaiduTokenResponse holds only expires_in and access_token. */ type BaiduTokenResponse struct { 69 | ExpiresIn int `json:"expires_in"` 70 | AccessToken string `json:"access_token"` 71 | } 72 | -------------------------------------------------------------------------------- /relay/channel/claude/adaptor.go: -------------------------------------------------------------------------------- 1 | package claude 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "one-api/dto" 9 | "one-api/relay/channel" 10 | relaycommon "one-api/relay/common" 11 | "one-api/service" 12 | 13 | "strings" 14 | 15 | "github.com/gin-gonic/gin" 16 | ) 17 | 18 | /* Values for Adaptor.RequestMode: completion mode targets /v1/complete, message mode targets /v1/messages (see GetRequestURL). */ const ( 19 | RequestModeCompletion = 1 20 | RequestModeMessage = 2 21 | ) 22 | 23 | /* Adaptor is the claude channel adaptor; RequestMode is set by Init. */ type Adaptor struct { 24 | RequestMode int 25 | } 26 | 27 | /* Init selects message mode when the upstream model name starts with "claude-3" and completion mode otherwise. The request argument is not used. */ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { 28 | if strings.HasPrefix(info.UpstreamModelName, "claude-3") { 29 | a.RequestMode = RequestModeMessage 30 | } else { 31 | a.RequestMode = RequestModeCompletion 32 | } 33 | } 34 | 35 | /* GetRequestURL appends /v1/messages or /v1/complete to info.BaseUrl according to RequestMode; the error is always nil. */ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { 36 | if a.RequestMode == RequestModeMessage { 37 | return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil 38 | } else { 39 | return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil 40 | } 41 | } 42 | 43 | /* SetupRequestHeader sets x-api-key from info.ApiKey and forwards the incoming request's anthropic-version header, substituting "2023-06-01" when that header is empty. */ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { 44 | channel.SetupApiRequestHeader(info, c, req) 45 | req.Header.Set("x-api-key", info.ApiKey) 46 | anthropicVersion := c.Request.Header.Get("anthropic-version") 47 | if anthropicVersion == "" { 48 | anthropicVersion = "2023-06-01" 49 | } 50 |
req.Header.Set("anthropic-version", anthropicVersion) 51 | return nil 52 | } 53 | 54 | /* ConvertRequest rejects a nil request and otherwise returns it unchanged; the OpenAI-to-Claude conversion calls below are commented out. */ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { 55 | if request == nil { 56 | return nil, errors.New("request is nil") 57 | } 58 | //if a.RequestMode == RequestModeCompletion { 59 | // return requestOpenAI2ClaudeComplete(*request), nil 60 | //} else { 61 | // return requestOpenAI2ClaudeMessage(*request), nil 62 | //} 63 | return request, nil 64 | } 65 | 66 | /* DoRequest delegates to channel.DoApiRequest. */ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { 67 | return channel.DoApiRequest(a, c, info, requestBody) 68 | } 69 | 70 | /* DoResponse: for streams, claudeStreamHandler returns the response text and usage is derived from it with service.ResponseText2Usage; otherwise claudeHandler returns the usage directly. */ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { 71 | if info.IsStream { 72 | var responseText string 73 | err, responseText = claudeStreamHandler(c, resp) 74 | usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) 75 | } else { 76 | err, usage = claudeHandler(c, resp, info.PromptTokens, info.UpstreamModelName) 77 | } 78 | return 79 | } 80 | 81 | /* GetModelList returns the package-level ModelList. */ func (a *Adaptor) GetModelList() []string { 82 | return ModelList 83 | } 84 | 85 | /* GetChannelName returns the package-level ChannelName. */ func (a *Adaptor) GetChannelName() string { 86 | return ChannelName 87 | } 88 | -------------------------------------------------------------------------------- /relay/channel/claude/constants.go: -------------------------------------------------------------------------------- 1 | package claude 2 | 3 | /* ModelList holds the model identifiers declared for the claude channel. NOTE(review): no "claude-3" names are listed although Adaptor.Init branches on that prefix - confirm whether that is intended. */ var ModelList = []string{ 4 | "claude-instant-1", "claude-2", "claude-2.0", "claude-2.1", 5 | } 6 | 7 | var ChannelName = "claude" 8 | -------------------------------------------------------------------------------- /relay/channel/claude/dto.go: -------------------------------------------------------------------------------- 1 | package claude 2 | 3 | /* ClaudeMetadata carries a user_id; its only reference in this file (inside ClaudeRequest) is commented out. */ type ClaudeMetadata struct { 4 | UserId string
`json:"user_id"` 5 | } 6 | 7 | /* ClaudeRequest is a JSON request body with prompt and max_tokens_to_sample fields; the metadata field is commented out. */ type ClaudeRequest struct { 8 | Model string `json:"model"` 9 | Prompt string `json:"prompt"` 10 | MaxTokensToSample uint `json:"max_tokens_to_sample"` 11 | StopSequences []string `json:"stop_sequences,omitempty"` 12 | Temperature float64 `json:"temperature,omitempty"` 13 | TopP float64 `json:"top_p,omitempty"` 14 | TopK int `json:"top_k,omitempty"` 15 | //ClaudeMetadata `json:"metadata,omitempty"` 16 | Stream bool `json:"stream,omitempty"` 17 | } 18 | 19 | /* ClaudeError is the type/message pair found under the "error" key of ClaudeResponse. */ type ClaudeError struct { 20 | Type string `json:"type"` 21 | Message string `json:"message"` 22 | } 23 | 24 | /* ClaudeResponse is a JSON response with completion, stop_reason, model and error fields. */ type ClaudeResponse struct { 25 | Completion string `json:"completion"` 26 | StopReason string `json:"stop_reason"` 27 | Model string `json:"model"` 28 | Error ClaudeError `json:"error"` 29 | } 30 | -------------------------------------------------------------------------------- /relay/channel/gemini/adaptor.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "one-api/dto" 9 | "one-api/relay/channel" 10 | relaycommon "one-api/relay/common" 11 | "one-api/service" 12 | 13 | "github.com/gin-gonic/gin" 14 | ) 15 | 16 | /* Adaptor is the gemini channel adaptor; it has no fields. */ type Adaptor struct { 17 | } 18 | 19 | /* Init is a no-op for this adaptor. */ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { 20 | } 21 | 22 | /* GetRequestURL builds {BaseUrl}/v1/models/{UpstreamModelName}:{action}; action is streamGenerateContent for streams and generateContent otherwise. */ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { 23 | version := "v1" 24 | action := "generateContent" 25 | if info.IsStream { 26 | action = "streamGenerateContent" 27 | } 28 | return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil 29 | } 30 | 31 | /* SetupRequestHeader sets x-goog-api-key from info.ApiKey after the shared header setup. */ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { 32 | channel.SetupApiRequestHeader(info, c, req) 33 | req.Header.Set("x-goog-api-key", info.ApiKey) 34 | return nil 35 | } 36 | 37 | /* ConvertRequest rejects a nil request and otherwise returns the result of CovertGemini2OpenAI(*request). */ func (a *Adaptor) ConvertRequest(c
*gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { 38 | if request == nil { 39 | return nil, errors.New("request is nil") 40 | } 41 | return CovertGemini2OpenAI(*request), nil 42 | } 43 | 44 | /* DoRequest delegates to channel.DoApiRequest. */ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { 45 | return channel.DoApiRequest(a, c, info, requestBody) 46 | } 47 | 48 | /* DoResponse: for streams, usage is derived from the text returned by geminiChatStreamHandler via service.ResponseText2Usage; otherwise geminiChatHandler returns the usage. */ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { 49 | if info.IsStream { 50 | var responseText string 51 | err, responseText = geminiChatStreamHandler(c, resp) 52 | usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) 53 | } else { 54 | err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName) 55 | } 56 | return 57 | } 58 | 59 | /* GetModelList returns the package-level ModelList. */ func (a *Adaptor) GetModelList() []string { 60 | return ModelList 61 | } 62 | 63 | /* GetChannelName returns the package-level ChannelName. */ func (a *Adaptor) GetChannelName() string { 64 | return ChannelName 65 | } 66 | -------------------------------------------------------------------------------- /relay/channel/gemini/constant.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | /* NOTE(review): GeminiVisionMaxImageNum is presumably the cap on images per vision request - confirm at its use site, which is not in this file. */ const ( 4 | GeminiVisionMaxImageNum = 16 5 | ) 6 | 7 | /* ModelList holds the model identifiers declared for the gemini channel. */ var ModelList = []string{ 8 | "gemini-pro", 9 | "gemini-pro-vision", 10 | } 11 | 12 | var ChannelName = "google gemini" 13 | -------------------------------------------------------------------------------- /relay/channel/gemini/dto.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | /* GeminiChatRequest is the JSON chat request body. NOTE(review): omitempty has no effect on the struct-typed GenerationConfig field, so generation_config is always emitted. */ type GeminiChatRequest struct { 4 | Contents []GeminiChatContent `json:"contents"` 5 | SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` 6 | GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` 7 | Tools []GeminiChatTools
`json:"tools,omitempty"` 8 | } 9 | 10 | /* GeminiInlineData is a mimeType/data pair referenced from GeminiPart.InlineData. */ type GeminiInlineData struct { 11 | MimeType string `json:"mimeType"` 12 | Data string `json:"data"` 13 | } 14 | 15 | /* GeminiPart is one content part: optional text and/or optional inline data (a pointer, so it is omitted when nil). */ type GeminiPart struct { 16 | Text string `json:"text,omitempty"` 17 | InlineData *GeminiInlineData `json:"inlineData,omitempty"` 18 | } 19 | 20 | /* GeminiChatContent is a list of parts with an optional role. */ type GeminiChatContent struct { 21 | Role string `json:"role,omitempty"` 22 | Parts []GeminiPart `json:"parts"` 23 | } 24 | 25 | /* GeminiChatSafetySettings is a category/threshold pair. */ type GeminiChatSafetySettings struct { 26 | Category string `json:"category"` 27 | Threshold string `json:"threshold"` 28 | } 29 | 30 | /* GeminiChatTools wraps function declarations of unspecified shape (any). */ type GeminiChatTools struct { 31 | FunctionDeclarations any `json:"functionDeclarations,omitempty"` 32 | } 33 | 34 | /* GeminiChatGenerationConfig holds the optional sampling settings; every field is omitted from JSON when it has its zero value. */ type GeminiChatGenerationConfig struct { 35 | Temperature float64 `json:"temperature,omitempty"` 36 | TopP float64 `json:"topP,omitempty"` 37 | TopK float64 `json:"topK,omitempty"` 38 | MaxOutputTokens uint `json:"maxOutputTokens,omitempty"` 39 | CandidateCount int `json:"candidateCount,omitempty"` 40 | StopSequences []string `json:"stopSequences,omitempty"` 41 | } 42 | 43 | /* GeminiChatCandidate is one entry of GeminiChatResponse.Candidates. */ type GeminiChatCandidate struct { 44 | Content GeminiChatContent `json:"content"` 45 | FinishReason string `json:"finishReason"` 46 | Index int64 `json:"index"` 47 | SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` 48 | } 49 | 50 | /* GeminiChatSafetyRating is a category/probability pair. */ type GeminiChatSafetyRating struct { 51 | Category string `json:"category"` 52 | Probability string `json:"probability"` 53 | } 54 | 55 | /* GeminiChatPromptFeedback carries the safety ratings found under "promptFeedback". */ type GeminiChatPromptFeedback struct { 56 | SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` 57 | } 58 | 59 | /* GeminiChatResponse is the JSON chat response: candidates plus prompt feedback. */ type GeminiChatResponse struct { 60 | Candidates []GeminiChatCandidate `json:"candidates"` 61 | PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` 62 | } 63 | -------------------------------------------------------------------------------- /relay/channel/moonshot/constants.go: -------------------------------------------------------------------------------- 1 | package moonshot 2 | 3 | /* ModelList holds the moonshot model identifiers; the openai adaptor returns it for moonshot channels. */ var ModelList = []string{ 4
| "moonshot-v1-8k", 5 | "moonshot-v1-32k", 6 | "moonshot-v1-128k", 7 | } 8 | -------------------------------------------------------------------------------- /relay/channel/openai/adaptor.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "one-api/common" 9 | "one-api/dto" 10 | "one-api/relay/channel" 11 | "one-api/relay/channel/ai360" 12 | "one-api/relay/channel/moonshot" 13 | relaycommon "one-api/relay/common" 14 | "one-api/service" 15 | "strings" 16 | 17 | "github.com/gin-gonic/gin" 18 | ) 19 | 20 | /* Adaptor is the openai channel adaptor. ChannelType is recorded by Init so GetModelList can pick the matching model list. */ type Adaptor struct { 21 | ChannelType int 22 | } 23 | 24 | /* Init copies info.ChannelType onto the adaptor; the request argument is not used. */ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { 25 | a.ChannelType = info.ChannelType 26 | } 27 | 28 | /* GetRequestURL: for Azure channels the path becomes /openai/deployments/{model}/{task}, where task is the request path without its query string and without the /v1/ prefix, followed by ?api-version=..., and model is the upstream model name with dots removed and a trailing -0301, -0314 or -0613 stripped. Other channels use info.RequestURLPath unchanged. */ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { 29 | if info.ChannelType == common.ChannelTypeAzure { 30 | // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api 31 | requestURL := strings.Split(info.RequestURLPath, "?")[0] 32 | requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, info.ApiVersion) 33 | task := strings.TrimPrefix(requestURL, "/v1/") 34 | model_ := info.UpstreamModelName 35 | model_ = strings.Replace(model_, ".", "", -1) 36 | // https://github.com/songquanpeng/one-api/issues/67 37 | model_ = strings.TrimSuffix(model_, "-0301") 38 | model_ = strings.TrimSuffix(model_, "-0314") 39 | model_ = strings.TrimSuffix(model_, "-0613") 40 | 41 | requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) 42 | return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil 43 | } 44 | return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil 45 | } 46 | 47 | /* SetupRequestHeader: Azure channels authenticate with the api-key header; every other channel type gets an "Authorization: Bearer <key>" header. */ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { 48 |
channel.SetupApiRequestHeader(info, c, req) 49 | if info.ChannelType == common.ChannelTypeAzure { 50 | req.Header.Set("api-key", info.ApiKey) 51 | return nil 52 | } 53 | req.Header.Set("Authorization", "Bearer "+info.ApiKey) 54 | //if info.ChannelType == common.ChannelTypeOpenRouter { 55 | // req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") 56 | // req.Header.Set("X-Title", "One API") 57 | //} 58 | return nil 59 | } 60 | 61 | /* ConvertRequest rejects a nil request and otherwise returns it unchanged. */ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { 62 | if request == nil { 63 | return nil, errors.New("request is nil") 64 | } 65 | return request, nil 66 | } 67 | 68 | /* DoRequest delegates to channel.DoApiRequest. */ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { 69 | return channel.DoApiRequest(a, c, info, requestBody) 70 | } 71 | 72 | /* DoResponse: for streams, usage is derived from the text returned by OpenaiStreamHandler via service.ResponseText2Usage; otherwise OpenaiHandler returns the usage. */ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { 73 | if info.IsStream { 74 | var responseText string 75 | err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode) 76 | usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) 77 | } else { 78 | err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) 79 | } 80 | return 81 | } 82 | 83 | /* GetModelList returns the ai360 or moonshot list for those channel types and this package's ModelList for every other type. It reads a.ChannelType, which is only set by Init. */ func (a *Adaptor) GetModelList() []string { 84 | switch a.ChannelType { 85 | case common.ChannelType360: 86 | return ai360.ModelList 87 | case common.ChannelTypeMoonshot: 88 | return moonshot.ModelList 89 | default: 90 | return ModelList 91 | } 92 | } 93 | 94 | /* GetChannelName returns the package-level ChannelName regardless of ChannelType. */ func (a *Adaptor) GetChannelName() string { 95 | return ChannelName 96 | } 97 | -------------------------------------------------------------------------------- /relay/channel/openai/constant.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | /* ModelList holds the model identifiers declared for the openai channel; it is the default returned by Adaptor.GetModelList. */ var ModelList = []string{ 4 |
"gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125", 5 | "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", 6 | "gpt-3.5-turbo-instruct", 7 | "gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview", 8 | "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", 9 | "gpt-4-turbo-preview", 10 | "gpt-4-vision-preview", 11 | "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", 12 | "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003", 13 | "text-moderation-latest", "text-moderation-stable", 14 | "text-davinci-edit-001", 15 | "davinci-002", "babbage-002", 16 | "dall-e-2", "dall-e-3", 17 | "whisper-1", 18 | "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", 19 | } 20 | 21 | var ChannelName = "openai" 22 | -------------------------------------------------------------------------------- /relay/channel/palm/adaptor.go: -------------------------------------------------------------------------------- 1 | package palm 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "one-api/dto" 9 | "one-api/relay/channel" 10 | relaycommon "one-api/relay/common" 11 | "one-api/service" 12 | 13 | "github.com/gin-gonic/gin" 14 | ) 15 | 16 | /* Adaptor is the palm channel adaptor; it has no fields. */ type Adaptor struct { 17 | } 18 | 19 | /* Init is a no-op for this adaptor. */ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { 20 | } 21 | 22 | /* GetRequestURL returns {BaseUrl}/v1beta2/models/chat-bison-001:generateMessage. The model segment is fixed; info.UpstreamModelName is not used. */ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { 23 | return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil 24 | } 25 | 26 | /* SetupRequestHeader sets x-goog-api-key from info.ApiKey after the shared header setup. */ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { 27 | channel.SetupApiRequestHeader(info, c, req) 28 | req.Header.Set("x-goog-api-key", info.ApiKey) 29 | return nil 30 | } 31 | 32 | /* ConvertRequest rejects a nil request and otherwise returns it unchanged. */ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { 33 | if request == nil { 34 |
return nil, errors.New("request is nil") 35 | } 36 | return request, nil 37 | } 38 | 39 | /* DoRequest delegates to channel.DoApiRequest. */ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { 40 | return channel.DoApiRequest(a, c, info, requestBody) 41 | } 42 | 43 | /* DoResponse: for streams, usage is derived from the text returned by palmStreamHandler via service.ResponseText2Usage; otherwise palmHandler returns the usage. */ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { 44 | if info.IsStream { 45 | var responseText string 46 | err, responseText = palmStreamHandler(c, resp) 47 | usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) 48 | } else { 49 | err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName) 50 | } 51 | return 52 | } 53 | 54 | /* GetModelList returns the package-level ModelList. */ func (a *Adaptor) GetModelList() []string { 55 | return ModelList 56 | } 57 | 58 | /* GetChannelName returns the package-level ChannelName. */ func (a *Adaptor) GetChannelName() string { 59 | return ChannelName 60 | } 61 | -------------------------------------------------------------------------------- /relay/channel/palm/constants.go: -------------------------------------------------------------------------------- 1 | package palm 2 | 3 | /* ModelList holds the single model identifier declared for the palm channel. */ var ModelList = []string{ 4 | "PaLM-2", 5 | } 6 | 7 | var ChannelName = "google palm" 8 | -------------------------------------------------------------------------------- /relay/channel/palm/dto.go: -------------------------------------------------------------------------------- 1 | package palm 2 | 3 | import "one-api/dto" 4 | 5 | /* PaLMChatMessage is an author/content pair. */ type PaLMChatMessage struct { 6 | Author string `json:"author"` 7 | Content string `json:"content"` 8 | } 9 | 10 | /* PaLMFilter is a reason/message pair listed under "filters" in PaLMChatResponse. */ type PaLMFilter struct { 11 | Reason string `json:"reason"` 12 | Message string `json:"message"` 13 | } 14 | 15 | /* PaLMPrompt wraps the message list of a chat request. */ type PaLMPrompt struct { 16 | Messages []PaLMChatMessage `json:"messages"` 17 | } 18 | 19 | /* PaLMChatRequest is the JSON chat request body: a prompt plus optional sampling settings. */ type PaLMChatRequest struct { 20 | Prompt PaLMPrompt `json:"prompt"` 21 | Temperature float64 `json:"temperature,omitempty"` 22 | CandidateCount int `json:"candidateCount,omitempty"` 23 |
TopP float64 `json:"topP,omitempty"` 24 | TopK uint `json:"topK,omitempty"` 25 | } 26 | 27 | /* PaLMError is the code/message/status triple found under "error" in PaLMChatResponse. */ type PaLMError struct { 28 | Code int `json:"code"` 29 | Message string `json:"message"` 30 | Status string `json:"status"` 31 | } 32 | 33 | /* PaLMChatResponse is the JSON chat response: candidates, messages, filters and an error object. */ type PaLMChatResponse struct { 34 | Candidates []PaLMChatMessage `json:"candidates"` 35 | Messages []dto.Message `json:"messages"` 36 | Filters []PaLMFilter `json:"filters"` 37 | Error PaLMError `json:"error"` 38 | } 39 | -------------------------------------------------------------------------------- /relay/channel/tencent/adaptor.go: -------------------------------------------------------------------------------- 1 | package tencent 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "one-api/dto" 9 | "one-api/relay/channel" 10 | relaycommon "one-api/relay/common" 11 | "one-api/service" 12 | "strings" 13 | 14 | "github.com/gin-gonic/gin" 15 | ) 16 | 17 | /* Adaptor is the tencent channel adaptor. Sign holds the signature computed in ConvertRequest and sent as the Authorization header by SetupRequestHeader. */ type Adaptor struct { 18 | Sign string 19 | } 20 | 21 | /* Init is a no-op for this adaptor. */ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { 22 | } 23 | 24 | /* GetRequestURL returns {BaseUrl}/hyllm/v1/chat/completions. */ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { 25 | return fmt.Sprintf("%s/hyllm/v1/chat/completions", info.BaseUrl), nil 26 | } 27 | 28 | /* SetupRequestHeader sends a.Sign as the Authorization header and the upstream model name as X-TC-Action. a.Sign is only assigned in ConvertRequest, so it is empty if that has not run on this adaptor. */ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { 29 | channel.SetupApiRequestHeader(info, c, req) 30 | req.Header.Set("Authorization", a.Sign) 31 | req.Header.Set("X-TC-Action", info.UpstreamModelName) 32 | return nil 33 | } 34 | 35 | /* ConvertRequest rejects a nil request, parses app id, secret id and secret key from the incoming request's Authorization header (with the "Bearer " prefix removed), fills app id and secret id into the converted request, and stores the computed signature in a.Sign. */ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { 36 | if request == nil { 37 | return nil, errors.New("request is nil") 38 | } 39 | apiKey := c.Request.Header.Get("Authorization") 40 | apiKey = strings.TrimPrefix(apiKey, "Bearer ") 41 | appId, secretId, secretKey, err := parseTencentConfig(apiKey) 42 | if err != nil { 43 | return nil, err 44 | } 45 | tencentRequest :=
requestOpenAI2Tencent(*request) 46 | tencentRequest.AppId = appId 47 | tencentRequest.SecretId = secretId 48 | // we have to calculate the sign here 49 | a.Sign = getTencentSign(*tencentRequest, secretKey) 50 | return tencentRequest, nil 51 | } 52 | 53 | func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { 54 | return channel.DoApiRequest(a, c, info, requestBody) 55 | } 56 | 57 | func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { 58 | if info.IsStream { 59 | var responseText string 60 | err, responseText = tencentStreamHandler(c, resp) 61 | usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) 62 | } else { 63 | err, usage = tencentHandler(c, resp) 64 | } 65 | return 66 | } 67 | 68 | func (a *Adaptor) GetModelList() []string { 69 | return ModelList 70 | } 71 | 72 | func (a *Adaptor) GetChannelName() string { 73 | return ChannelName 74 | } 75 | -------------------------------------------------------------------------------- /relay/channel/tencent/constants.go: -------------------------------------------------------------------------------- 1 | package tencent 2 | 3 | var ModelList = []string{ 4 | "ChatPro", 5 | "ChatStd", 6 | "hunyuan", 7 | } 8 | 9 | var ChannelName = "tencent" 10 | -------------------------------------------------------------------------------- /relay/channel/tencent/dto.go: -------------------------------------------------------------------------------- 1 | package tencent 2 | 3 | import "one-api/dto" 4 | 5 | type TencentMessage struct { 6 | Role string `json:"role"` 7 | Content string `json:"content"` 8 | } 9 | 10 | type TencentChatRequest struct { 11 | AppId int64 `json:"app_id"` // APPID of the Tencent Cloud account 12 | SecretId string `json:"secret_id"` // SecretId from the official website 13 | // Timestamp is the current UNIX timestamp in seconds; it records when the API request was issued. 14 | // For example 1529223702; if it differs too much from the current time, a signature-expired error is raised.
15 | Timestamp int64 `json:"timestamp"` 16 | // Expired is the signature's expiry, a value conforming to the UNIX epoch timestamp format, 17 | // in seconds; Expired must be greater than Timestamp and Expired-Timestamp must be less than 90 days. 18 | Expired int64 `json:"expired"` 19 | QueryID string `json:"query_id"` // request ID, used for troubleshooting 20 | // Temperature: higher values make the output more random, lower values make it more focused and deterministic. 21 | // Default 1.0, range [0.0, 2.0]; not recommended unless necessary, unreasonable values degrade the results. 22 | // Set only one of this parameter and top_p; do not change top_p at the same time. 23 | Temperature float64 `json:"temperature"` 24 | // TopP affects the diversity of the output text; the larger the value, the more diverse the generated text. 25 | // Default 1.0, range [0.0, 1.0]; not recommended unless necessary, unreasonable values degrade the results. 26 | // Set only one of this parameter and temperature; do not change both at the same time. 27 | TopP float64 `json:"top_p"` 28 | // Stream 0: synchronous, 1: streaming (default, protocol: SSE) 29 | // Synchronous requests time out after 60s; streaming is recommended for long content. 30 | Stream int `json:"stream"` 31 | // Messages is the conversation content, at most 40 entries, ordered in the array from oldest to newest. 32 | // The total input content supports at most 3000 tokens. 33 | Messages []TencentMessage `json:"messages"` 34 | } 35 | 36 | type TencentError struct { 37 | Code int `json:"code"` 38 | Message string `json:"message"` 39 | } 40 | 41 | type TencentUsage struct { 42 | InputTokens int `json:"input_tokens"` 43 | OutputTokens int `json:"output_tokens"` 44 | TotalTokens int `json:"total_tokens"` 45 | } 46 | 47 | type TencentResponseChoices struct { 48 | FinishReason string `json:"finish_reason,omitempty"` // end-of-stream flag; "stop" marks the final packet 49 | Messages TencentMessage `json:"messages,omitempty"` // content; returned in synchronous mode, null in streaming mode; output content supports at most 1024 tokens in total. 50 | Delta TencentMessage `json:"delta,omitempty"` // content; returned in streaming mode, null in synchronous mode; output content supports at most 1024 tokens in total. 51 | } 52 | 53 | type TencentChatResponse struct { 54 | Choices []TencentResponseChoices `json:"choices,omitempty"` // results 55 | Created string `json:"created,omitempty"` // UNIX timestamp as a string 56 | Id string `json:"id,omitempty"` // conversation ID 57 | Usage dto.Usage `json:"usage,omitempty"` // token counts 58 | Error TencentError `json:"error,omitempty"` // error information; note: this field may be null, meaning no valid value is available 59 | Note string `json:"note,omitempty"` // note 60 | ReqID string `json:"req_id,omitempty"` // unique request ID, returned with every request; used when reporting the interface's input parameters 61 | } 62 |
-------------------------------------------------------------------------------- /relay/channel/xunfei/adaptor.go: -------------------------------------------------------------------------------- 1 | package xunfei 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "net/http" 7 | "one-api/dto" 8 | "one-api/relay/channel" 9 | relaycommon "one-api/relay/common" 10 | "one-api/service" 11 | "strings" 12 | 13 | "github.com/gin-gonic/gin" 14 | ) 15 | 16 | /* Adaptor is the xunfei channel adaptor. request keeps the pointer handed to ConvertRequest so that DoResponse can pass the request on to the xunfei handlers. */ type Adaptor struct { 17 | request *dto.GeneralOpenAIRequest 18 | } 19 | 20 | /* Init is a no-op for this adaptor. */ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { 21 | } 22 | 23 | /* GetRequestURL always returns an empty URL and a nil error. */ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { 24 | return "", nil 25 | } 26 | 27 | /* SetupRequestHeader only performs the shared header setup. */ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { 28 | channel.SetupApiRequestHeader(info, c, req) 29 | return nil 30 | } 31 | 32 | /* ConvertRequest rejects a nil request, stores the pointer on the adaptor and returns the request unchanged. */ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { 33 | if request == nil { 34 | return nil, errors.New("request is nil") 35 | } 36 | a.request = request 37 | return request, nil 38 | } 39 | 40 | /* DoRequest makes no HTTP call; it returns a placeholder response whose StatusCode is 200. NOTE(review): the placeholder's Body is nil - confirm that no caller reads or closes resp.Body for this channel. */ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { 41 | // xunfei's request is not http request, so we don't need to do anything here 42 | dummyResp := &http.Response{} 43 | dummyResp.StatusCode = http.StatusOK 44 | return dummyResp, nil 45 | } 46 | 47 | /* DoResponse splits info.ApiKey on "|" and requires exactly three parts, which are passed in order to the xunfei handlers together with the request stored by ConvertRequest. A malformed key or a missing stored request yields a 400 error. The resp argument is not used. */ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { 48 | splits := strings.Split(info.ApiKey, "|") 49 | if len(splits) != 3 { 50 | return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) 51 | } 52 | if a.request == nil { 53 | return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil",
http.StatusBadRequest) 54 | } 55 | if info.IsStream { 56 | err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2]) 57 | } else { 58 | err, usage = xunfeiHandler(c, *a.request, splits[0], splits[1], splits[2]) 59 | } 60 | return 61 | } 62 | 63 | /* GetModelList returns the package-level ModelList. */ func (a *Adaptor) GetModelList() []string { 64 | return ModelList 65 | } 66 | 67 | /* GetChannelName returns the package-level ChannelName. */ func (a *Adaptor) GetChannelName() string { 68 | return ChannelName 69 | } 70 | -------------------------------------------------------------------------------- /relay/channel/xunfei/constants.go: -------------------------------------------------------------------------------- 1 | package xunfei 2 | 3 | /* ModelList holds the model identifiers declared for the xunfei channel. */ var ModelList = []string{ 4 | "SparkDesk", 5 | "SparkDesk-v1.1", 6 | "SparkDesk-v2.1", 7 | "SparkDesk-v3.1", 8 | "SparkDesk-v3.5", 9 | } 10 | 11 | var ChannelName = "xunfei" 12 | -------------------------------------------------------------------------------- /relay/channel/xunfei/dto.go: -------------------------------------------------------------------------------- 1 | package xunfei 2 | 3 | import "one-api/dto" 4 | 5 | /* XunfeiMessage is one role/content message pair. */ type XunfeiMessage struct { 6 | Role string `json:"role"` 7 | Content string `json:"content"` 8 | } 9 | 10 | /* XunfeiChatRequest is the JSON request envelope with three sections: header (app_id), parameter.chat (optional settings) and payload.message.text (the message list). */ type XunfeiChatRequest struct { 11 | Header struct { 12 | AppId string `json:"app_id"` 13 | } `json:"header"` 14 | Parameter struct { 15 | Chat struct { 16 | Domain string `json:"domain,omitempty"` 17 | Temperature float64 `json:"temperature,omitempty"` 18 | TopK int `json:"top_k,omitempty"` 19 | MaxTokens uint `json:"max_tokens,omitempty"` 20 | Auditing bool `json:"auditing,omitempty"` 21 | } `json:"chat"` 22 | } `json:"parameter"` 23 | Payload struct { 24 | Message struct { 25 | Text []XunfeiMessage `json:"text"` 26 | } `json:"message"` 27 | } `json:"payload"` 28 | } 29 | 30 | /* XunfeiChatResponseTextItem is one entry of payload.choices.text in XunfeiChatResponse. */ type XunfeiChatResponseTextItem struct { 31 | Content string `json:"content"` 32 | Role string `json:"role"` 33 | Index int `json:"index"` 34 | } 35 | 36 | /* XunfeiChatResponse is the JSON response envelope: header (code, message, sid, status) and payload (choices and usage). payload.usage.text decodes into dto.Usage; the older string-typed variant is kept commented out. */ type XunfeiChatResponse struct { 37 | Header struct { 38 | Code
int `json:"code"` 39 | Message string `json:"message"` 40 | Sid string `json:"sid"` 41 | Status int `json:"status"` 42 | } `json:"header"` 43 | Payload struct { 44 | Choices struct { 45 | Status int `json:"status"` 46 | Seq int `json:"seq"` 47 | Text []XunfeiChatResponseTextItem `json:"text"` 48 | } `json:"choices"` 49 | Usage struct { 50 | //Text struct { 51 | // QuestionTokens string `json:"question_tokens"` 52 | // PromptTokens string `json:"prompt_tokens"` 53 | // CompletionTokens string `json:"completion_tokens"` 54 | // TotalTokens string `json:"total_tokens"` 55 | //} `json:"text"` 56 | Text dto.Usage `json:"text"` 57 | } `json:"usage"` 58 | } `json:"payload"` 59 | } 60 | -------------------------------------------------------------------------------- /relay/channel/zhipu/adaptor.go: -------------------------------------------------------------------------------- 1 | package zhipu 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "one-api/dto" 9 | "one-api/relay/channel" 10 | relaycommon "one-api/relay/common" 11 | 12 | "github.com/gin-gonic/gin" 13 | ) 14 | 15 | /* Adaptor is the zhipu (v3 model-api) channel adaptor; it has no fields. */ type Adaptor struct { 16 | } 17 | 18 | /* Init is a no-op for this adaptor. */ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { 19 | } 20 | 21 | /* GetRequestURL builds {BaseUrl}/api/paas/v3/model-api/{UpstreamModelName}/{method}; method is sse-invoke for streams and invoke otherwise. */ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { 22 | method := "invoke" 23 | if info.IsStream { 24 | method = "sse-invoke" 25 | } 26 | return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil 27 | } 28 | 29 | /* SetupRequestHeader sends the value returned by getZhipuToken(info.ApiKey) as the Authorization header, without a "Bearer " prefix. */ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { 30 | channel.SetupApiRequestHeader(info, c, req) 31 | token := getZhipuToken(info.ApiKey) 32 | req.Header.Set("Authorization", token) 33 | return nil 34 | } 35 | 36 | /* ConvertRequest rejects a nil request and otherwise returns requestOpenAI2Zhipu(*request). */ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { 37 | if request == nil { 38 | return nil, errors.New("request is nil") 39
| } 40 | return requestOpenAI2Zhipu(*request), nil 41 | } 42 | 43 | /* DoRequest delegates to channel.DoApiRequest. */ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { 44 | return channel.DoApiRequest(a, c, info, requestBody) 45 | } 46 | 47 | /* DoResponse takes both error and usage straight from zhipuStreamHandler (streams) or zhipuHandler (non-streams). */ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { 48 | if info.IsStream { 49 | err, usage = zhipuStreamHandler(c, resp) 50 | } else { 51 | err, usage = zhipuHandler(c, resp) 52 | } 53 | return 54 | } 55 | 56 | /* GetModelList returns the package-level ModelList. */ func (a *Adaptor) GetModelList() []string { 57 | return ModelList 58 | } 59 | 60 | /* GetChannelName returns the package-level ChannelName. */ func (a *Adaptor) GetChannelName() string { 61 | return ChannelName 62 | } 63 | -------------------------------------------------------------------------------- /relay/channel/zhipu/constants.go: -------------------------------------------------------------------------------- 1 | package zhipu 2 | 3 | /* ModelList holds the model identifiers declared for the zhipu channel. */ var ModelList = []string{ 4 | "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite", 5 | } 6 | 7 | var ChannelName = "zhipu" 8 | -------------------------------------------------------------------------------- /relay/channel/zhipu/dto.go: -------------------------------------------------------------------------------- 1 | package zhipu 2 | 3 | import ( 4 | "one-api/dto" 5 | "time" 6 | ) 7 | 8 | /* ZhipuMessage is one role/content message pair. */ type ZhipuMessage struct { 9 | Role string `json:"role"` 10 | Content string `json:"content"` 11 | } 12 | 13 | /* ZhipuRequest is the JSON request body; the message list is sent under "prompt". */ type ZhipuRequest struct { 14 | Prompt []ZhipuMessage `json:"prompt"` 15 | Temperature float64 `json:"temperature,omitempty"` 16 | TopP float64 `json:"top_p,omitempty"` 17 | RequestId string `json:"request_id,omitempty"` 18 | Incremental bool `json:"incremental,omitempty"` 19 | } 20 | 21 | /* ZhipuResponseData is the "data" object of ZhipuResponse. dto.Usage is embedded with an explicit tag, so it decodes from the nested "usage" key rather than from the top level. */ type ZhipuResponseData struct { 22 | TaskId string `json:"task_id"` 23 | RequestId string `json:"request_id"` 24 | TaskStatus string `json:"task_status"` 25 | Choices []ZhipuMessage `json:"choices"` 26 | dto.Usage `json:"usage"`
27 | } 28 | 29 | /* ZhipuResponse is the JSON response envelope: code, msg, success and data. */ type ZhipuResponse struct { 30 | Code int `json:"code"` 31 | Msg string `json:"msg"` 32 | Success bool `json:"success"` 33 | Data ZhipuResponseData `json:"data"` 34 | } 35 | 36 | /* ZhipuStreamMetaResponse carries request/task identifiers, task status and usage. NOTE(review): presumably the final metadata event of a stream - confirm in the stream handler, which is not in this file. */ type ZhipuStreamMetaResponse struct { 37 | RequestId string `json:"request_id"` 38 | TaskId string `json:"task_id"` 39 | TaskStatus string `json:"task_status"` 40 | dto.Usage `json:"usage"` 41 | } 42 | 43 | /* zhipuTokenData pairs a token string with its expiry time. */ type zhipuTokenData struct { 44 | Token string 45 | ExpiryTime time.Time 46 | } 47 | -------------------------------------------------------------------------------- /relay/channel/zhipu_4v/adaptor.go: -------------------------------------------------------------------------------- 1 | package zhipu_4v 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "one-api/dto" 9 | "one-api/relay/channel" 10 | "one-api/relay/channel/openai" 11 | relaycommon "one-api/relay/common" 12 | 13 | "one-api/service" 14 | 15 | "github.com/gin-gonic/gin" 16 | ) 17 | 18 | /* Adaptor is the zhipu v4 channel adaptor; it has no fields. */ type Adaptor struct { 19 | } 20 | 21 | /* Init is a no-op for this adaptor. */ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { 22 | } 23 | 24 | /* GetRequestURL returns {BaseUrl}/api/paas/v4/chat/completions. */ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { 25 | return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil 26 | } 27 | 28 | /* SetupRequestHeader sends the value returned by getZhipuToken(info.ApiKey) as the Authorization header, without a "Bearer " prefix. */ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { 29 | channel.SetupApiRequestHeader(info, c, req) 30 | token := getZhipuToken(info.ApiKey) 31 | req.Header.Set("Authorization", token) 32 | return nil 33 | } 34 | 35 | /* ConvertRequest rejects a nil request and otherwise returns requestOpenAI2Zhipu(*request). */ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { 36 | if request == nil { 37 | return nil, errors.New("request is nil") 38 | } 39 | return requestOpenAI2Zhipu(*request), nil 40 | } 41 | 42 | /* DoRequest delegates to channel.DoApiRequest. */ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { 43 | return channel.DoApiRequest(a, c, info,
requestBody) 44 | } 45 | 46 | /* DoResponse reuses the openai package's handlers: for streams, usage is derived from the text returned by openai.OpenaiStreamHandler via service.ResponseText2Usage; otherwise openai.OpenaiHandler returns the usage. */ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { 47 | if info.IsStream { 48 | var responseText string 49 | err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) 50 | usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) 51 | } else { 52 | err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) 53 | } 54 | return 55 | } 56 | 57 | /* GetModelList returns the package-level ModelList. */ func (a *Adaptor) GetModelList() []string { 58 | return ModelList 59 | } 60 | 61 | /* GetChannelName returns the package-level ChannelName. */ func (a *Adaptor) GetChannelName() string { 62 | return ChannelName 63 | } 64 | -------------------------------------------------------------------------------- /relay/channel/zhipu_4v/constants.go: -------------------------------------------------------------------------------- 1 | package zhipu_4v 2 | 3 | /* ModelList holds the model identifiers declared for the zhipu_4v channel. */ var ModelList = []string{ 4 | "glm-4", "glm-4v", "glm-3-turbo", 5 | } 6 | 7 | var ChannelName = "zhipu_4v" 8 | -------------------------------------------------------------------------------- /relay/channel/zhipu_4v/dto.go: -------------------------------------------------------------------------------- 1 | package zhipu_4v 2 | 3 | import ( 4 | "one-api/dto" 5 | "time" 6 | ) 7 | 8 | // type ZhipuMessage struct { 9 | // Role string `json:"role,omitempty"` 10 | // Content string `json:"content,omitempty"` 11 | // ToolCalls any `json:"tool_calls,omitempty"` 12 | // ToolCallId any `json:"tool_call_id,omitempty"` 13 | // } 14 | // 15 | // type ZhipuRequest struct { 16 | // Model string `json:"model"` 17 | // Stream bool `json:"stream,omitempty"` 18 | // Messages []ZhipuMessage `json:"messages"` 19 | // Temperature float64 `json:"temperature,omitempty"` 20 | // TopP float64 `json:"top_p,omitempty"` 21 | // MaxTokens int `json:"max_tokens,omitempty"` 22 | // Stop []string `json:"stop,omitempty"` 23 | // RequestId string
`json:"request_id,omitempty"` 24 | // Tools any `json:"tools,omitempty"` 25 | // ToolChoice any `json:"tool_choice,omitempty"` 26 | // } 27 | // 28 | // type ZhipuV4TextResponseChoice struct { 29 | // Index int `json:"index"` 30 | // ZhipuMessage `json:"message"` 31 | // FinishReason string `json:"finish_reason"` 32 | // } 33 | /* ZhipuV4Response is the non-stream JSON response; choices, usage and error reuse the shared dto types. */ type ZhipuV4Response struct { 34 | Id string `json:"id"` 35 | Created int64 `json:"created"` 36 | Model string `json:"model"` 37 | TextResponseChoices []dto.OpenAITextResponseChoice `json:"choices"` 38 | Usage dto.Usage `json:"usage"` 39 | Error dto.OpenAIError `json:"error"` 40 | } 41 | 42 | // 43 | //type ZhipuV4StreamResponseChoice struct { 44 | // Index int `json:"index,omitempty"` 45 | // Delta ZhipuMessage `json:"delta"` 46 | // FinishReason *string `json:"finish_reason,omitempty"` 47 | //} 48 | 49 | /* ZhipuV4StreamResponse is one JSON stream chunk: id, created, choices and usage. */ type ZhipuV4StreamResponse struct { 50 | Id string `json:"id"` 51 | Created int64 `json:"created"` 52 | Choices []dto.ChatCompletionsStreamResponseChoice `json:"choices"` 53 | Usage dto.Usage `json:"usage"` 54 | } 55 | 56 | /* tokenData pairs a token string with its expiry time. */ type tokenData struct { 57 | Token string 58 | ExpiryTime time.Time 59 | } 60 | -------------------------------------------------------------------------------- /relay/common/relay_info.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "one-api/common" 5 | "one-api/relay/constant" 6 | "strings" 7 | "time" 8 | 9 | "github.com/gin-gonic/gin" 10 | ) 11 | 12 | /* RelayInfo bundles the per-request values passed to the channel adaptors. GenRelayInfo fills every field except IsStream, UpstreamModelName and PromptTokens; IsStream and PromptTokens have setters below. */ type RelayInfo struct { 13 | ChannelType int 14 | ChannelId int 15 | TokenId int 16 | UserId int 17 | Group string 18 | TokenUnlimited bool 19 | StartTime time.Time 20 | ApiType int 21 | IsStream bool 22 | RelayMode int 23 | UpstreamModelName string 24 | RequestURLPath string 25 | ApiVersion string 26 | PromptTokens int 27 | ApiKey string 28 | BaseUrl string 29 | } 30 | 31 | /* GenRelayInfo builds a RelayInfo from values stored on the gin context ("channel", "channel_id", "token_id", "id", "group", "token_unlimited_quota", "base_url", "api_version") and from the request itself. RequestURLPath is the full request URL string including any query; ApiKey is the request's Authorization header with the "Bearer " prefix removed; StartTime is the time of the call. */ func GenRelayInfo(c *gin.Context) *RelayInfo { 32 | channelType := c.GetInt("channel") 33 | channelId :=
c.GetInt("channel_id") 34 | tokenId := c.GetInt("token_id") 35 | userId := c.GetInt("id") 36 | group := c.GetString("group") 37 | tokenUnlimited := c.GetBool("token_unlimited_quota") 38 | startTime := time.Now() 39 | 40 | apiType := constant.ChannelType2APIType(channelType) 41 | 42 | info := &RelayInfo{ 43 | RelayMode: constant.Path2RelayMode(c.Request.URL.Path), 44 | BaseUrl: c.GetString("base_url"), 45 | RequestURLPath: c.Request.URL.String(), 46 | ChannelType: channelType, 47 | ChannelId: channelId, 48 | TokenId: tokenId, 49 | UserId: userId, 50 | Group: group, 51 | TokenUnlimited: tokenUnlimited, 52 | StartTime: startTime, 53 | ApiType: apiType, 54 | ApiVersion: c.GetString("api_version"), 55 | ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), 56 | } 57 | /* An empty base_url falls back to the default for the channel type. */ if info.BaseUrl == "" { 58 | info.BaseUrl = common.ChannelBaseURLs[channelType] 59 | } 60 | /* Azure channels take their API version from GetAzureAPIVersion, replacing the "api_version" context value. */ if info.ChannelType == common.ChannelTypeAzure { 61 | info.ApiVersion = GetAzureAPIVersion(c) 62 | } 63 | return info 64 | } 65 | 66 | /* SetPromptTokens stores the prompt token count on the info. */ func (info *RelayInfo) SetPromptTokens(promptTokens int) { 67 | info.PromptTokens = promptTokens 68 | } 69 | 70 | /* SetIsStream stores the streaming flag on the info. */ func (info *RelayInfo) SetIsStream(isStream bool) { 71 | info.IsStream = isStream 72 | } 73 | -------------------------------------------------------------------------------- /relay/common/relay_utils.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | _ "image/gif" 7 | _ "image/jpeg" 8 | _ "image/png" 9 | "io" 10 | "net/http" 11 | "one-api/common" 12 | "one-api/dto" 13 | "strconv" 14 | "strings" 15 | 16 | "github.com/gin-gonic/gin" 17 | ) 18 | 19 | /* StopFinishReason is the literal "stop". */ var StopFinishReason = "stop" 20 | 21 | func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { 22 | OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{ 23 | StatusCode: resp.StatusCode, 24 | Error: dto.OpenAIError{ 25 | Message: fmt.Sprintf("bad
response status code %d", resp.StatusCode), 26 | Type: "upstream_error", 27 | Code: "bad_response_status_code", 28 | Param: strconv.Itoa(resp.StatusCode), 29 | }, 30 | } 31 | responseBody, err := io.ReadAll(resp.Body) 32 | if err != nil { 33 | return 34 | } 35 | err = resp.Body.Close() 36 | if err != nil { 37 | return 38 | } 39 | var textResponse dto.TextResponse 40 | err = json.Unmarshal(responseBody, &textResponse) 41 | if err != nil { 42 | return 43 | } 44 | OpenAIErrorWithStatusCode.Error = textResponse.Error 45 | return 46 | } 47 | 48 | func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { 49 | fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) 50 | 51 | if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { 52 | switch channelType { 53 | case common.ChannelTypeOpenAI: 54 | fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) 55 | case common.ChannelTypeAzure: 56 | fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) 57 | } 58 | } 59 | return fullRequestURL 60 | } 61 | 62 | func GetAPIVersion(c *gin.Context) string { 63 | query := c.Request.URL.Query() 64 | apiVersion := query.Get("api-version") 65 | if apiVersion == "" { 66 | apiVersion = c.GetString("api_version") 67 | } 68 | return apiVersion 69 | } 70 | 71 | func GetAzureAPIVersion(c *gin.Context) string { 72 | query := c.Request.URL.Query() 73 | apiVersion := query.Get("api-version") 74 | if apiVersion == "" { 75 | apiVersion = c.GetString("api_version") 76 | } 77 | return apiVersion 78 | } 79 | -------------------------------------------------------------------------------- /relay/constant/api_type.go: -------------------------------------------------------------------------------- 1 | package constant 2 | 3 | import ( 4 | "one-api/common" 5 | ) 6 | 7 | const ( 8 | APITypeOpenAI = iota 9 | APITypeAnthropic 10 | APITypePaLM 11 | APITypeBaidu 12 | APITypeZhipu 13 | APITypeAli 
// Relay modes identify which kind of endpoint a request targets.
// The numeric values follow declaration order, so new modes should be
// appended rather than inserted.
const (
	RelayModeUnknown = iota
	RelayModeChatCompletions
	RelayModeCompletions
	RelayModeEmbeddings
	RelayModeModerations
	RelayModeImagesGenerations
	RelayModeEdits
	RelayModeMidjourneyImagine
	RelayModeMidjourneyDescribe
	RelayModeMidjourneyBlend
	RelayModeMidjourneyChange
	RelayModeMidjourneySimpleChange
	RelayModeMidjourneyNotify
	RelayModeMidjourneyTaskFetch
	RelayModeMidjourneyTaskFetchByCondition
	RelayModeAudioSpeech
	RelayModeAudioTranscription
	RelayModeAudioTranslation
)

// Path2RelayMode maps a request path to its relay mode.
//
// Cases are tested top to bottom and the first match wins. Any path ending in
// "embeddings" (for example /v1/engines/:model/embeddings) counts as an
// embeddings request. Paths matching no rule yield RelayModeUnknown.
func Path2RelayMode(path string) int {
	switch {
	case strings.HasPrefix(path, "/v1/chat/completions"):
		return RelayModeChatCompletions
	case strings.HasPrefix(path, "/v1/completions"):
		return RelayModeCompletions
	case strings.HasPrefix(path, "/v1/embeddings"), strings.HasSuffix(path, "embeddings"):
		return RelayModeEmbeddings
	case strings.HasPrefix(path, "/v1/moderations"):
		return RelayModeModerations
	case strings.HasPrefix(path, "/v1/images/generations"):
		return RelayModeImagesGenerations
	case strings.HasPrefix(path, "/v1/edits"):
		return RelayModeEdits
	case strings.HasPrefix(path, "/v1/audio/speech"):
		return RelayModeAudioSpeech
	case strings.HasPrefix(path, "/v1/audio/transcriptions"):
		return RelayModeAudioTranscription
	case strings.HasPrefix(path, "/v1/audio/translations"):
		return RelayModeAudioTranslation
	default:
		return RelayModeUnknown
	}
}
&gemini.Adaptor{} 30 | case constant.APITypeOpenAI: 31 | return &openai.Adaptor{} 32 | case constant.APITypePaLM: 33 | return &palm.Adaptor{} 34 | case constant.APITypeTencent: 35 | return &tencent.Adaptor{} 36 | case constant.APITypeXunfei: 37 | return &xunfei.Adaptor{} 38 | case constant.APITypeZhipu: 39 | return &zhipu.Adaptor{} 40 | case constant.APITypeZhipu_v4: 41 | return &zhipu_4v.Adaptor{} 42 | } 43 | return nil 44 | } 45 | -------------------------------------------------------------------------------- /router/dashboard.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "one-api/controller" 5 | "one-api/middleware" 6 | 7 | "github.com/gin-contrib/gzip" 8 | "github.com/gin-gonic/gin" 9 | ) 10 | 11 | func SetDashboardRouter(router *gin.Engine) { 12 | apiRouter := router.Group("/") 13 | apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) 14 | apiRouter.Use(middleware.GlobalAPIRateLimit()) 15 | apiRouter.Use(middleware.CORS()) 16 | apiRouter.Use(middleware.TokenAuth()) 17 | { 18 | apiRouter.GET("/dashboard/billing/subscription", controller.GetSubscription) 19 | apiRouter.GET("/v1/dashboard/billing/subscription", controller.GetSubscription) 20 | apiRouter.GET("/dashboard/billing/usage", controller.GetUsage) 21 | apiRouter.GET("/v1/dashboard/billing/usage", controller.GetUsage) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /router/main.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "embed" 5 | "fmt" 6 | "net/http" 7 | "one-api/common" 8 | "os" 9 | "strings" 10 | 11 | "github.com/gin-gonic/gin" 12 | ) 13 | 14 | func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { 15 | SetApiRouter(router) 16 | SetDashboardRouter(router) 17 | SetRelayRouter(router) 18 | frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL") 19 | if common.IsMasterNode && 
frontendBaseUrl != "" { 20 | frontendBaseUrl = "" 21 | common.SysLog("FRONTEND_BASE_URL is ignored on master node") 22 | } 23 | if frontendBaseUrl == "" { 24 | SetWebRouter(router, buildFS, indexPage) 25 | } else { 26 | frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/") 27 | router.NoRoute(func(c *gin.Context) { 28 | c.Redirect(http.StatusMovedPermanently, fmt.Sprintf("%s%s", frontendBaseUrl, c.Request.RequestURI)) 29 | }) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /router/relay-router.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "one-api/controller" 5 | "one-api/middleware" 6 | "one-api/relay" 7 | 8 | "github.com/gin-gonic/gin" 9 | ) 10 | 11 | func SetRelayRouter(router *gin.Engine) { 12 | router.Use(middleware.CORS()) 13 | // https://platform.openai.com/docs/api-reference/introduction 14 | modelsRouter := router.Group("/v1/models") 15 | modelsRouter.Use(middleware.TokenAuth()) 16 | { 17 | modelsRouter.GET("", controller.ListModels) 18 | modelsRouter.GET("/:model", controller.RetrieveModel) 19 | } 20 | relayV1Router := router.Group("/v1") 21 | relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) 22 | { 23 | relayV1Router.POST("/completions", controller.Relay) 24 | relayV1Router.POST("/chat/completions", controller.Relay) 25 | relayV1Router.POST("/edits", controller.Relay) 26 | relayV1Router.POST("/images/generations", controller.Relay) 27 | relayV1Router.POST("/images/edits", controller.RelayNotImplemented) 28 | relayV1Router.POST("/images/variations", controller.RelayNotImplemented) 29 | relayV1Router.POST("/embeddings", controller.Relay) 30 | relayV1Router.POST("/engines/:model/embeddings", controller.Relay) 31 | relayV1Router.POST("/audio/transcriptions", controller.Relay) 32 | relayV1Router.POST("/audio/translations", controller.Relay) 33 | relayV1Router.POST("/audio/speech", controller.Relay) 34 | 
relayV1Router.GET("/files", controller.RelayNotImplemented) 35 | relayV1Router.POST("/files", controller.RelayNotImplemented) 36 | relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) 37 | relayV1Router.GET("/files/:id", controller.RelayNotImplemented) 38 | relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented) 39 | relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented) 40 | relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented) 41 | relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented) 42 | relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) 43 | relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) 44 | relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) 45 | relayV1Router.POST("/moderations", controller.Relay) 46 | } 47 | relayMjRouter := router.Group("/mj") 48 | relayMjRouter.GET("/image/:id", relay.RelayMidjourneyImage) 49 | relayMjRouter.Use(middleware.TokenAuth(), middleware.Distribute()) 50 | { 51 | relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney) 52 | relayMjRouter.POST("/submit/change", controller.RelayMidjourney) 53 | relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney) 54 | relayMjRouter.POST("/submit/describe", controller.RelayMidjourney) 55 | relayMjRouter.POST("/submit/blend", controller.RelayMidjourney) 56 | relayMjRouter.POST("/notify", controller.RelayMidjourney) 57 | relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney) 58 | relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney) 59 | } 60 | //relayMjRouter.Use() 61 | } 62 | -------------------------------------------------------------------------------- /router/web-router.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "embed" 5 | "net/http" 6 | "one-api/common" 7 | "one-api/controller" 8 | "one-api/middleware" 9 | "strings" 
10 | 11 | "github.com/gin-contrib/gzip" 12 | "github.com/gin-contrib/static" 13 | "github.com/gin-gonic/gin" 14 | ) 15 | 16 | func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) { 17 | router.Use(gzip.Gzip(gzip.DefaultCompression)) 18 | router.Use(middleware.GlobalWebRateLimit()) 19 | router.Use(middleware.Cache()) 20 | router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/build"))) 21 | router.NoRoute(func(c *gin.Context) { 22 | if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") { 23 | controller.RelayNotFound(c) 24 | return 25 | } 26 | c.Header("Cache-Control", "no-cache") 27 | c.Data(http.StatusOK, "text/html; charset=utf-8", indexPage) 28 | }) 29 | } 30 | -------------------------------------------------------------------------------- /service/channel.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "one-api/common" 7 | relaymodel "one-api/dto" 8 | "one-api/model" 9 | ) 10 | 11 | // disable & notify 12 | func DisableChannel(channelId int, channelName string, reason string) { 13 | model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) 14 | subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) 15 | content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) 16 | notifyRootUser(subject, content) 17 | } 18 | 19 | func EnableChannel(channelId int, channelName string) { 20 | model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) 21 | subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) 22 | content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) 23 | notifyRootUser(subject, content) 24 | } 25 | 26 | func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool { 27 | if !common.AutomaticDisableChannelEnabled { 28 | return false 29 | } 30 | if err == nil { 31 | return false 32 | } 33 | 
if statusCode == http.StatusUnauthorized { 34 | return true 35 | } 36 | if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" || err.Code == "billing_not_active" { 37 | return true 38 | } 39 | return false 40 | } 41 | 42 | func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError) bool { 43 | if !common.AutomaticEnableChannelEnabled { 44 | return false 45 | } 46 | if err != nil { 47 | return false 48 | } 49 | if openAIErr != nil { 50 | return false 51 | } 52 | return true 53 | } 54 | -------------------------------------------------------------------------------- /service/epay.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import "one-api/common" 4 | 5 | func GetCallbackAddress() string { 6 | if common.CustomCallbackAddress == "" { 7 | return common.ServerAddress 8 | } 9 | return common.CustomCallbackAddress 10 | } 11 | -------------------------------------------------------------------------------- /service/error.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | "one-api/common" 9 | "one-api/dto" 10 | "strconv" 11 | "strings" 12 | ) 13 | 14 | // OpenAIErrorWrapper wraps an error into an OpenAIErrorWithStatusCode 15 | func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode { 16 | text := err.Error() 17 | // 定义一个正则表达式匹配URL 18 | if strings.Contains(text, "Post") { 19 | common.SysLog(fmt.Sprintf("error: %s", text)) 20 | text = "请求上游地址失败" 21 | } 22 | //避免暴露内部错误 23 | 24 | openAIError := dto.OpenAIError{ 25 | Message: text, 26 | Type: "new_api_error", 27 | Code: code, 28 | } 29 | return &dto.OpenAIErrorWithStatusCode{ 30 | Error: openAIError, 31 | StatusCode: statusCode, 32 | } 33 | } 34 | 35 | func RelayErrorHandler(resp *http.Response) (errWithStatusCode 
*dto.OpenAIErrorWithStatusCode) { 36 | errWithStatusCode = &dto.OpenAIErrorWithStatusCode{ 37 | StatusCode: resp.StatusCode, 38 | Error: dto.OpenAIError{ 39 | Message: "", 40 | Type: "upstream_error", 41 | Code: "bad_response_status_code", 42 | Param: strconv.Itoa(resp.StatusCode), 43 | }, 44 | } 45 | responseBody, err := io.ReadAll(resp.Body) 46 | if err != nil { 47 | return 48 | } 49 | err = resp.Body.Close() 50 | if err != nil { 51 | return 52 | } 53 | var errResponse dto.GeneralErrorResponse 54 | err = json.Unmarshal(responseBody, &errResponse) 55 | if err != nil { 56 | return 57 | } 58 | if errResponse.Error.Message != "" { 59 | // OpenAI format error, so we override the default one 60 | errWithStatusCode.Error = errResponse.Error 61 | } else { 62 | errWithStatusCode.Error.Message = errResponse.ToMessage() 63 | } 64 | if errWithStatusCode.Error.Message == "" { 65 | errWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) 66 | } 67 | return 68 | } 69 | -------------------------------------------------------------------------------- /service/http_client.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "net/http" 5 | "one-api/common" 6 | "time" 7 | ) 8 | 9 | var httpClient *http.Client 10 | var impatientHTTPClient *http.Client 11 | 12 | func init() { 13 | if common.RelayTimeout == 0 { 14 | httpClient = &http.Client{} 15 | } else { 16 | httpClient = &http.Client{ 17 | Timeout: time.Duration(common.RelayTimeout) * time.Second, 18 | } 19 | } 20 | 21 | impatientHTTPClient = &http.Client{ 22 | Timeout: 5 * time.Second, 23 | } 24 | } 25 | 26 | func GetHttpClient() *http.Client { 27 | return httpClient 28 | } 29 | 30 | func GetImpatientHttpClient() *http.Client { 31 | return impatientHTTPClient 32 | } 33 | -------------------------------------------------------------------------------- /service/sse.go: 
-------------------------------------------------------------------------------- 1 | package service 2 | 3 | import "github.com/gin-gonic/gin" 4 | 5 | func SetEventStreamHeaders(c *gin.Context) { 6 | c.Writer.Header().Set("Content-Type", "text/event-stream") 7 | c.Writer.Header().Set("Cache-Control", "no-cache") 8 | c.Writer.Header().Set("Connection", "keep-alive") 9 | c.Writer.Header().Set("Transfer-Encoding", "chunked") 10 | c.Writer.Header().Set("X-Accel-Buffering", "no") 11 | } 12 | -------------------------------------------------------------------------------- /service/usage_helpr.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "errors" 5 | "one-api/dto" 6 | "one-api/relay/constant" 7 | ) 8 | 9 | func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) { 10 | switch relayMode { 11 | case constant.RelayModeChatCompletions: 12 | return CountTokenMessages(textRequest.Messages, textRequest.Model) 13 | case constant.RelayModeCompletions: 14 | return CountTokenInput(textRequest.Prompt, textRequest.Model), nil 15 | case constant.RelayModeModerations: 16 | return CountTokenInput(textRequest.Input, textRequest.Model), nil 17 | } 18 | return 0, errors.New("unknown relay mode") 19 | } 20 | 21 | func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage { 22 | usage := &dto.Usage{} 23 | usage.PromptTokens = promptTokens 24 | usage.CompletionTokens = CountTokenText(responseText, modeName) 25 | usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens 26 | return usage 27 | } 28 | -------------------------------------------------------------------------------- /service/user_notify.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "fmt" 5 | "one-api/common" 6 | "one-api/model" 7 | ) 8 | 9 | func notifyRootUser(subject string, content string) { 10 | if 
common.RootUserEmail == "" { 11 | common.RootUserEmail = model.GetRootUserEmail() 12 | } 13 | err := common.SendEmail(subject, common.RootUserEmail, content) 14 | if err != nil { 15 | common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /web/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | 8 | # testing 9 | /coverage 10 | 11 | # production 12 | /build 13 | 14 | # misc 15 | .DS_Store 16 | .env.local 17 | .env.development.local 18 | .env.test.local 19 | .env.production.local 20 | 21 | npm-debug.log* 22 | yarn-debug.log* 23 | yarn-error.log* 24 | .idea 25 | package-lock.json 26 | yarn.lock 27 | -------------------------------------------------------------------------------- /web/README.md: -------------------------------------------------------------------------------- 1 | # React Template 2 | 3 | ## Basic Usages 4 | 5 | ```shell 6 | # Runs the app in the development mode 7 | npm start 8 | 9 | # Builds the app for production to the `build` folder 10 | npm run build 11 | ``` 12 | 13 | If you want to change the default server, please set `REACT_APP_SERVER` environment variables before build, 14 | for example: `REACT_APP_SERVER=http://your.domain.com`. 15 | 16 | Before you start editing, make sure your `Actions on Save` options have `Optimize imports` & `Run Prettier` enabled. 17 | 18 | ## Reference 19 | 20 | 1. https://github.com/OIerDb-ng/OIerDb 21 | 2. 
https://github.com/cornflourblue/react-hooks-redux-registration-login-example 22 | -------------------------------------------------------------------------------- /web/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "react-template", 3 | "version": "0.1.0", 4 | "private": true, 5 | "dependencies": { 6 | "@douyinfe/semi-icons": "^2.46.1", 7 | "@douyinfe/semi-ui": "^2.46.1", 8 | "@visactor/react-vchart": "~1.8.8", 9 | "@visactor/vchart": "~1.8.8", 10 | "@visactor/vchart-semi-theme": "~1.8.8", 11 | "axios": "^0.27.2", 12 | "history": "^5.3.0", 13 | "marked": "^4.1.1", 14 | "react": "^18.2.0", 15 | "react-dom": "^18.2.0", 16 | "react-dropzone": "^14.2.3", 17 | "react-fireworks": "^1.0.4", 18 | "react-router-dom": "^6.3.0", 19 | "react-scripts": "5.0.1", 20 | "react-telegram-login": "^1.1.2", 21 | "react-toastify": "^9.0.8", 22 | "react-turnstile": "^1.0.5", 23 | "semantic-ui-css": "^2.5.0", 24 | "semantic-ui-react": "^2.1.3", 25 | "usehooks-ts": "^2.9.1" 26 | }, 27 | "scripts": { 28 | "start": "react-scripts start", 29 | "build": "react-scripts build", 30 | "test": "react-scripts test", 31 | "eject": "react-scripts eject" 32 | }, 33 | "eslintConfig": { 34 | "extends": [ 35 | "react-app", 36 | "react-app/jest" 37 | ] 38 | }, 39 | "browserslist": { 40 | "production": [ 41 | ">0.2%", 42 | "not dead", 43 | "not op_mini all" 44 | ], 45 | "development": [ 46 | "last 1 chrome version", 47 | "last 1 firefox version", 48 | "last 1 safari version" 49 | ] 50 | }, 51 | "devDependencies": { 52 | "prettier": "^2.7.1", 53 | "typescript": "4.4.2" 54 | }, 55 | "prettier": { 56 | "singleQuote": true, 57 | "jsxSingleQuote": true 58 | }, 59 | "proxy": "http://localhost:3000" 60 | } 61 | -------------------------------------------------------------------------------- /web/public/favicon.ico: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ehco1996/new-api/2f4e6e4e1d9129c9d5fa0a25a9207254d62c1f06/web/public/favicon.ico -------------------------------------------------------------------------------- /web/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 12 | New API 13 | 14 | 15 | You need to enable JavaScript to run this app. 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /web/public/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ehco1996/new-api/2f4e6e4e1d9129c9d5fa0a25a9207254d62c1f06/web/public/logo.png -------------------------------------------------------------------------------- /web/public/robots.txt: -------------------------------------------------------------------------------- 1 | # https://www.robotstxt.org/robotstxt.html 2 | User-agent: * 3 | Disallow: 4 | -------------------------------------------------------------------------------- /web/src/components/Footer.js: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState } from 'react'; 2 | 3 | import { getFooterHTML, getSystemName } from '../helpers'; 4 | import {Layout} from "@douyinfe/semi-ui"; 5 | 6 | const Footer = () => { 7 | const systemName = getSystemName(); 8 | const [footer, setFooter] = useState(getFooterHTML()); 9 | let remainCheckTimes = 5; 10 | 11 | const loadFooter = () => { 12 | let footer_html = localStorage.getItem('footer_html'); 13 | if (footer_html) { 14 | setFooter(footer_html); 15 | } 16 | }; 17 | 18 | useEffect(() => { 19 | const timer = setInterval(() => { 20 | if (remainCheckTimes <= 0) { 21 | clearInterval(timer); 22 | return; 23 | } 24 | remainCheckTimes--; 25 | loadFooter(); 26 | }, 200); 27 | return () => clearTimeout(timer); 28 | }, []); 29 | 30 | return ( 31 | 32 | 33 | {footer ? 
( 34 | 38 | ) : ( 39 | 40 | 44 | New API {process.env.REACT_APP_VERSION}{' '} 45 | 46 | 由{' '} 47 | 48 | Calcium-Ion 49 | {' '} 50 | 开发,基于{' '} 51 | 52 | One API v0.5.4 53 | {' '} 54 | ,本项目根据{' '} 55 | 56 | MIT 许可证 57 | {' '} 58 | 授权 59 | 60 | )} 61 | 62 | 63 | ); 64 | }; 65 | 66 | export default Footer; 67 | -------------------------------------------------------------------------------- /web/src/components/GitHubOAuth.js: -------------------------------------------------------------------------------- 1 | import React, { useContext, useEffect, useState } from 'react'; 2 | import { Dimmer, Loader, Segment } from 'semantic-ui-react'; 3 | import { useNavigate, useSearchParams } from 'react-router-dom'; 4 | import { API, showError, showSuccess } from '../helpers'; 5 | import { UserContext } from '../context/User'; 6 | 7 | const GitHubOAuth = () => { 8 | const [searchParams, setSearchParams] = useSearchParams(); 9 | 10 | const [userState, userDispatch] = useContext(UserContext); 11 | const [prompt, setPrompt] = useState('处理中...'); 12 | const [processing, setProcessing] = useState(true); 13 | 14 | let navigate = useNavigate(); 15 | 16 | const sendCode = async (code, state, count) => { 17 | const res = await API.get(`/api/oauth/github?code=${code}&state=${state}`); 18 | const { success, message, data } = res.data; 19 | if (success) { 20 | if (message === 'bind') { 21 | showSuccess('绑定成功!'); 22 | navigate('/setting'); 23 | } else { 24 | userDispatch({ type: 'login', payload: data }); 25 | localStorage.setItem('user', JSON.stringify(data)); 26 | showSuccess('登录成功!'); 27 | navigate('/'); 28 | } 29 | } else { 30 | showError(message); 31 | if (count === 0) { 32 | setPrompt(`操作失败,重定向至登录界面中...`); 33 | navigate('/setting'); // in case this is failed to bind GitHub 34 | return; 35 | } 36 | count++; 37 | setPrompt(`出现错误,第 ${count} 次重试中...`); 38 | await new Promise((resolve) => setTimeout(resolve, count * 2000)); 39 | await sendCode(code, state, count); 40 | } 41 | }; 42 | 43 | 
useEffect(() => { 44 | let code = searchParams.get('code'); 45 | let state = searchParams.get('state'); 46 | sendCode(code, state, 0).then(); 47 | }, []); 48 | 49 | return ( 50 | 51 | 52 | {prompt} 53 | 54 | 55 | ); 56 | }; 57 | 58 | export default GitHubOAuth; 59 | -------------------------------------------------------------------------------- /web/src/components/Loading.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { Segment, Dimmer, Loader } from 'semantic-ui-react'; 3 | 4 | const Loading = ({ prompt: name = 'page' }) => { 5 | return ( 6 | 7 | 8 | 加载{name}中... 9 | 10 | 11 | ); 12 | }; 13 | 14 | export default Loading; 15 | -------------------------------------------------------------------------------- /web/src/components/PasswordResetConfirm.js: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState } from 'react'; 2 | import { Button, Form, Grid, Header, Image, Segment } from 'semantic-ui-react'; 3 | import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers'; 4 | import { useSearchParams } from 'react-router-dom'; 5 | 6 | const PasswordResetConfirm = () => { 7 | const [inputs, setInputs] = useState({ 8 | email: '', 9 | token: '', 10 | }); 11 | const { email, token } = inputs; 12 | 13 | const [loading, setLoading] = useState(false); 14 | 15 | const [disableButton, setDisableButton] = useState(false); 16 | const [countdown, setCountdown] = useState(30); 17 | 18 | const [newPassword, setNewPassword] = useState(''); 19 | 20 | const [searchParams, setSearchParams] = useSearchParams(); 21 | useEffect(() => { 22 | let token = searchParams.get('token'); 23 | let email = searchParams.get('email'); 24 | setInputs({ 25 | token, 26 | email, 27 | }); 28 | }, []); 29 | 30 | useEffect(() => { 31 | let countdownInterval = null; 32 | if (disableButton && countdown > 0) { 33 | countdownInterval = 
setInterval(() => { 34 | setCountdown(countdown - 1); 35 | }, 1000); 36 | } else if (countdown === 0) { 37 | setDisableButton(false); 38 | setCountdown(30); 39 | } 40 | return () => clearInterval(countdownInterval); 41 | }, [disableButton, countdown]); 42 | 43 | async function handleSubmit(e) { 44 | setDisableButton(true); 45 | if (!email) return; 46 | setLoading(true); 47 | const res = await API.post(`/api/user/reset`, { 48 | email, 49 | token, 50 | }); 51 | const { success, message } = res.data; 52 | if (success) { 53 | let password = res.data.data; 54 | setNewPassword(password); 55 | await copy(password); 56 | showNotice(`新密码已复制到剪贴板:${password}`); 57 | } else { 58 | showError(message); 59 | } 60 | setLoading(false); 61 | } 62 | 63 | return ( 64 | 65 | 66 | 67 | 密码重置确认 68 | 69 | 70 | 71 | 80 | {newPassword && ( 81 | { 90 | e.target.select(); 91 | navigator.clipboard.writeText(newPassword); 92 | showNotice(`密码已复制到剪贴板:${newPassword}`); 93 | }} 94 | /> 95 | )} 96 | 104 | {disableButton ? 
`密码重置完成` : '提交'} 105 | 106 | 107 | 108 | 109 | 110 | ); 111 | }; 112 | 113 | export default PasswordResetConfirm; 114 | -------------------------------------------------------------------------------- /web/src/components/PasswordResetForm.js: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState } from 'react'; 2 | import { Button, Form, Grid, Header, Image, Segment } from 'semantic-ui-react'; 3 | import { API, showError, showInfo, showSuccess } from '../helpers'; 4 | import Turnstile from 'react-turnstile'; 5 | 6 | const PasswordResetForm = () => { 7 | const [inputs, setInputs] = useState({ 8 | email: '' 9 | }); 10 | const { email } = inputs; 11 | 12 | const [loading, setLoading] = useState(false); 13 | const [turnstileEnabled, setTurnstileEnabled] = useState(false); 14 | const [turnstileSiteKey, setTurnstileSiteKey] = useState(''); 15 | const [turnstileToken, setTurnstileToken] = useState(''); 16 | const [disableButton, setDisableButton] = useState(false); 17 | const [countdown, setCountdown] = useState(30); 18 | 19 | useEffect(() => { 20 | let countdownInterval = null; 21 | if (disableButton && countdown > 0) { 22 | countdownInterval = setInterval(() => { 23 | setCountdown(countdown - 1); 24 | }, 1000); 25 | } else if (countdown === 0) { 26 | setDisableButton(false); 27 | setCountdown(30); 28 | } 29 | return () => clearInterval(countdownInterval); 30 | }, [disableButton, countdown]); 31 | 32 | function handleChange(e) { 33 | const { name, value } = e.target; 34 | setInputs(inputs => ({ ...inputs, [name]: value })); 35 | } 36 | 37 | async function handleSubmit(e) { 38 | setDisableButton(true); 39 | if (!email) return; 40 | if (turnstileEnabled && turnstileToken === '') { 41 | showInfo('请稍后几秒重试,Turnstile 正在检查用户环境!'); 42 | return; 43 | } 44 | setLoading(true); 45 | const res = await API.get( 46 | `/api/reset_password?email=${email}&turnstile=${turnstileToken}` 47 | ); 48 | const { success, message } 
= res.data; 49 | if (success) { 50 | showSuccess('重置邮件发送成功,请检查邮箱!'); 51 | setInputs({ ...inputs, email: '' }); 52 | } else { 53 | showError(message); 54 | } 55 | setLoading(false); 56 | } 57 | 58 | return ( 59 | 60 | 61 | 62 | 密码重置 63 | 64 | 65 | 66 | 75 | {turnstileEnabled ? ( 76 | { 79 | setTurnstileToken(token); 80 | }} 81 | /> 82 | ) : ( 83 | <>> 84 | )} 85 | 93 | {disableButton ? `重试 (${countdown})` : '提交'} 94 | 95 | 96 | 97 | 98 | 99 | ); 100 | }; 101 | 102 | export default PasswordResetForm; 103 | -------------------------------------------------------------------------------- /web/src/components/PrivateRoute.js: -------------------------------------------------------------------------------- 1 | import { Navigate } from 'react-router-dom'; 2 | 3 | import { history } from '../helpers'; 4 | 5 | 6 | function PrivateRoute({ children }) { 7 | if (!localStorage.getItem('user')) { 8 | return ; 9 | } 10 | return children; 11 | } 12 | 13 | export { PrivateRoute }; 14 | -------------------------------------------------------------------------------- /web/src/components/WeChatIcon.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { Icon } from '@douyinfe/semi-ui'; 3 | 4 | const WeChatIcon = () => { 5 | function CustomIcon() { 6 | return 8 | 11 | 14 | ; 15 | } 16 | 17 | return ( 18 | 19 | } /> 20 | 21 | ); 22 | }; 23 | 24 | export default WeChatIcon; 25 | -------------------------------------------------------------------------------- /web/src/components/utils.js: -------------------------------------------------------------------------------- 1 | import { API, showError } from '../helpers'; 2 | 3 | export async function getOAuthState() { 4 | const res = await API.get('/api/oauth/state'); 5 | const { success, message, data } = res.data; 6 | if (success) { 7 | return data; 8 | } else { 9 | showError(message); 10 | return ''; 11 | } 12 | } 13 | 14 | export async function 
onGitHubOAuthClicked(github_client_id) { 15 | const state = await getOAuthState(); 16 | if (!state) return; 17 | window.open( 18 | `https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email` 19 | ); 20 | } 21 | -------------------------------------------------------------------------------- /web/src/constants/channel.constants.js: -------------------------------------------------------------------------------- 1 | export const CHANNEL_OPTIONS = [ 2 | {key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI'}, 3 | {key: 2, text: 'Midjourney Proxy', value: 2, color: 'light-blue', label: 'Midjourney Proxy'}, 4 | {key: 14, text: 'Anthropic Claude', value: 14, color: 'black', label: 'Anthropic Claude'}, 5 | {key: 3, text: 'Azure OpenAI', value: 3, color: 'olive', label: 'Azure OpenAI'}, 6 | {key: 11, text: 'Google PaLM2', value: 11, color: 'orange', label: 'Google PaLM2'}, 7 | {key: 24, text: 'Google Gemini', value: 24, color: 'orange', label: 'Google Gemini'}, 8 | {key: 15, text: '百度文心千帆', value: 15, color: 'blue', label: '百度文心千帆'}, 9 | {key: 17, text: '阿里通义千问', value: 17, color: 'orange', label: '阿里通义千问'}, 10 | {key: 18, text: '讯飞星火认知', value: 18, color: 'blue', label: '讯飞星火认知'}, 11 | {key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet', label: '智谱 ChatGLM'}, 12 | {key: 16, text: '智谱 GLM-4V', value: 26, color: 'green', label: '智谱 GLM-4V'}, 13 | {key: 16, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot'}, 14 | {key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑'}, 15 | {key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元'}, 16 | {key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道'}, 17 | {key: 22, text: '知识库:FastGPT', value: 22, color: 'blue', label: '知识库:FastGPT'}, 18 | {key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple', label: '知识库:AI Proxy'}, 19 | ]; 20 | -------------------------------------------------------------------------------- 
/web/src/constants/common.constant.js: -------------------------------------------------------------------------------- 1 | export const ITEMS_PER_PAGE = 10; // this value must keep same as the one defined in backend! 2 | -------------------------------------------------------------------------------- /web/src/constants/index.js: -------------------------------------------------------------------------------- 1 | export * from './toast.constants'; 2 | export * from './user.constants'; 3 | export * from './common.constant'; 4 | export * from './channel.constants'; 5 | -------------------------------------------------------------------------------- /web/src/constants/toast.constants.js: -------------------------------------------------------------------------------- 1 | export const toastConstants = { 2 | SUCCESS_TIMEOUT: 1500, 3 | INFO_TIMEOUT: 3000, 4 | ERROR_TIMEOUT: 5000, 5 | WARNING_TIMEOUT: 10000, 6 | NOTICE_TIMEOUT: 20000 7 | }; 8 | -------------------------------------------------------------------------------- /web/src/constants/user.constants.js: -------------------------------------------------------------------------------- 1 | export const userConstants = { 2 | REGISTER_REQUEST: 'USERS_REGISTER_REQUEST', 3 | REGISTER_SUCCESS: 'USERS_REGISTER_SUCCESS', 4 | REGISTER_FAILURE: 'USERS_REGISTER_FAILURE', 5 | 6 | LOGIN_REQUEST: 'USERS_LOGIN_REQUEST', 7 | LOGIN_SUCCESS: 'USERS_LOGIN_SUCCESS', 8 | LOGIN_FAILURE: 'USERS_LOGIN_FAILURE', 9 | 10 | LOGOUT: 'USERS_LOGOUT', 11 | 12 | GETALL_REQUEST: 'USERS_GETALL_REQUEST', 13 | GETALL_SUCCESS: 'USERS_GETALL_SUCCESS', 14 | GETALL_FAILURE: 'USERS_GETALL_FAILURE', 15 | 16 | DELETE_REQUEST: 'USERS_DELETE_REQUEST', 17 | DELETE_SUCCESS: 'USERS_DELETE_SUCCESS', 18 | DELETE_FAILURE: 'USERS_DELETE_FAILURE' 19 | }; 20 | -------------------------------------------------------------------------------- /web/src/context/Status/index.js: -------------------------------------------------------------------------------- 1 | // 
contexts/User/index.jsx 2 | 3 | import React from 'react'; 4 | import { initialState, reducer } from './reducer'; 5 | 6 | export const StatusContext = React.createContext({ 7 | state: initialState, 8 | dispatch: () => null, 9 | }); 10 | 11 | export const StatusProvider = ({ children }) => { 12 | const [state, dispatch] = React.useReducer(reducer, initialState); 13 | 14 | return ( 15 | 16 | {children} 17 | 18 | ); 19 | }; 20 | -------------------------------------------------------------------------------- /web/src/context/Status/reducer.js: -------------------------------------------------------------------------------- 1 | export const reducer = (state, action) => { 2 | switch (action.type) { 3 | case 'set': 4 | return { 5 | ...state, 6 | status: action.payload, 7 | }; 8 | case 'unset': 9 | return { 10 | ...state, 11 | status: undefined, 12 | }; 13 | default: 14 | return state; 15 | } 16 | }; 17 | 18 | export const initialState = { 19 | status: undefined, 20 | }; 21 | -------------------------------------------------------------------------------- /web/src/context/User/index.js: -------------------------------------------------------------------------------- 1 | // contexts/User/index.jsx 2 | 3 | import React from "react" 4 | import { reducer, initialState } from "./reducer" 5 | 6 | export const UserContext = React.createContext({ 7 | state: initialState, 8 | dispatch: () => null 9 | }) 10 | 11 | export const UserProvider = ({ children }) => { 12 | const [state, dispatch] = React.useReducer(reducer, initialState) 13 | 14 | return ( 15 | 16 | { children } 17 | 18 | ) 19 | } 20 | -------------------------------------------------------------------------------- /web/src/context/User/reducer.js: -------------------------------------------------------------------------------- 1 | export const reducer = (state, action) => { 2 | switch (action.type) { 3 | case 'login': 4 | return { 5 | ...state, 6 | user: action.payload 7 | }; 8 | case 'logout': 9 | return { 10 | 
...state, 11 | user: undefined 12 | }; 13 | 14 | default: 15 | return state; 16 | } 17 | }; 18 | 19 | export const initialState = { 20 | user: undefined 21 | }; 22 | -------------------------------------------------------------------------------- /web/src/helpers/api.js: -------------------------------------------------------------------------------- 1 | import { showError } from './utils'; 2 | import axios from 'axios'; 3 | 4 | export const API = axios.create({ 5 | baseURL: process.env.REACT_APP_SERVER ? process.env.REACT_APP_SERVER : '', 6 | }); 7 | 8 | API.interceptors.response.use( 9 | (response) => response, 10 | (error) => { 11 | showError(error); 12 | } 13 | ); 14 | -------------------------------------------------------------------------------- /web/src/helpers/auth-header.js: -------------------------------------------------------------------------------- 1 | export function authHeader() { 2 | // return authorization header with jwt token 3 | let user = JSON.parse(localStorage.getItem('user')); 4 | 5 | if (user && user.token) { 6 | return { 'Authorization': 'Bearer ' + user.token }; 7 | } else { 8 | return {}; 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /web/src/helpers/history.js: -------------------------------------------------------------------------------- 1 | import { createBrowserHistory } from 'history'; 2 | 3 | export const history = createBrowserHistory(); 4 | -------------------------------------------------------------------------------- /web/src/helpers/index.js: -------------------------------------------------------------------------------- 1 | export * from './history'; 2 | export * from './auth-header'; 3 | export * from './utils'; 4 | export * from './api'; 5 | -------------------------------------------------------------------------------- /web/src/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0; 3 | padding-top: 
55px; 4 | overflow-y: scroll; 5 | font-family: Lato, 'Helvetica Neue', Arial, Helvetica, "Microsoft YaHei", sans-serif; 6 | -webkit-font-smoothing: antialiased; 7 | -moz-osx-font-smoothing: grayscale; 8 | scrollbar-width: none; 9 | color: var(--semi-color-text-0) !important; 10 | background-color: var( --semi-color-bg-0) !important; 11 | height: 100%; 12 | } 13 | 14 | #root { 15 | height: 100%; 16 | } 17 | 18 | @media only screen and (max-width: 767px) { 19 | .semi-table-tbody, .semi-table-row, .semi-table-row-cell { 20 | display: block!important; 21 | width: auto!important; 22 | padding: 2px!important; 23 | } 24 | .semi-table-row-cell { 25 | border-bottom: 0!important; 26 | } 27 | .semi-table-tbody>.semi-table-row { 28 | border-bottom: 1px solid rgba(0,0,0,.1); 29 | } 30 | .semi-space { 31 | display: block!important; 32 | } 33 | } 34 | 35 | .semi-layout { 36 | height: 100%; 37 | } 38 | 39 | .tableShow { 40 | display: revert; 41 | } 42 | 43 | .tableHiddle { 44 | display: none !important; 45 | } 46 | 47 | body::-webkit-scrollbar { 48 | display: none; 49 | } 50 | 51 | code { 52 | font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', monospace; 53 | } 54 | 55 | .semi-navigation-vertical { 56 | /*display: flex;*/ 57 | /*flex-direction: column;*/ 58 | } 59 | 60 | .semi-navigation-item { 61 | margin-bottom: 0; 62 | } 63 | 64 | .semi-navigation-vertical { 65 | /*flex: 0 0 auto;*/ 66 | /*display: flex;*/ 67 | /*flex-direction: column;*/ 68 | /*width: 100%;*/ 69 | height: 100%; 70 | overflow: hidden; 71 | } 72 | 73 | .main-content { 74 | padding: 4px; 75 | height: 100%; 76 | } 77 | 78 | .small-icon .icon { 79 | font-size: 1em !important; 80 | } 81 | 82 | .custom-footer { 83 | font-size: 1.1em; 84 | } 85 | 86 | @media only screen and (max-width: 600px) { 87 | .hide-on-mobile { 88 | display: none !important; 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /web/src/index.js: 
-------------------------------------------------------------------------------- 1 | import { initVChartSemiTheme } from '@visactor/vchart-semi-theme'; 2 | import React from 'react'; 3 | import ReactDOM from 'react-dom/client'; 4 | import {BrowserRouter} from 'react-router-dom'; 5 | import App from './App'; 6 | import HeaderBar from './components/HeaderBar'; 7 | import Footer from './components/Footer'; 8 | import 'semantic-ui-css/semantic.min.css'; 9 | import './index.css'; 10 | import {UserProvider} from './context/User'; 11 | import {ToastContainer} from 'react-toastify'; 12 | import 'react-toastify/dist/ReactToastify.css'; 13 | import {StatusProvider} from './context/Status'; 14 | import {Layout} from "@douyinfe/semi-ui"; 15 | import SiderBar from "./components/SiderBar"; 16 | 17 | // initialization 18 | initVChartSemiTheme({ 19 | isWatchingThemeSwitch: true, 20 | }); 21 | 22 | const root = ReactDOM.createRoot(document.getElementById('root')); 23 | const {Sider, Content, Header} = Layout; 24 | root.render( 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | ); 55 | -------------------------------------------------------------------------------- /web/src/pages/About/index.js: -------------------------------------------------------------------------------- 1 | import React, { useEffect, useState } from 'react'; 2 | import { API, showError } from '../../helpers'; 3 | import { marked } from 'marked'; 4 | import {Layout} from "@douyinfe/semi-ui"; 5 | 6 | const About = () => { 7 | const [about, setAbout] = useState(''); 8 | const [aboutLoaded, setAboutLoaded] = useState(false); 9 | 10 | const displayAbout = async () => { 11 | setAbout(localStorage.getItem('about') || ''); 12 | const res = await API.get('/api/about'); 13 | const { success, message, data } = res.data; 14 | if (success) { 15 | let aboutContent = data; 16 | if (!data.startsWith('https://')) { 17 | aboutContent = 
marked.parse(data); 18 | } 19 | setAbout(aboutContent); 20 | localStorage.setItem('about', aboutContent); 21 | } else { 22 | showError(message); 23 | setAbout('加载关于内容失败...'); 24 | } 25 | setAboutLoaded(true); 26 | }; 27 | 28 | useEffect(() => { 29 | displayAbout().then(); 30 | }, []); 31 | 32 | return ( 33 | <> 34 | { 35 | aboutLoaded && about === '' ? <> 36 | 37 | 38 | 关于 39 | 40 | 41 | 42 | 可在设置页面设置关于内容,支持 HTML & Markdown 43 | 44 | new-api项目仓库地址: 45 | 46 | https://github.com/Calcium-Ion/new-api 47 | 48 | 49 | NewAPI © 2023 CalciumIon | 基于 One API v0.5.4 © 2023 JustSong。本项目根据MIT许可证授权。 50 | 51 | 52 | 53 | > : <> 54 | { 55 | about.startsWith('https://') ? : 59 | } 60 | > 61 | } 62 | > 63 | ); 64 | }; 65 | 66 | 67 | export default About; 68 | -------------------------------------------------------------------------------- /web/src/pages/Channel/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ChannelsTable from '../../components/ChannelsTable'; 3 | import {Layout} from "@douyinfe/semi-ui"; 4 | 5 | const File = () => ( 6 | <> 7 | 8 | 9 | 管理渠道 10 | 11 | 12 | 13 | 14 | 15 | > 16 | ); 17 | 18 | export default File; 19 | -------------------------------------------------------------------------------- /web/src/pages/Chat/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | 3 | const Chat = () => { 4 | const chatLink = localStorage.getItem('chat_link'); 5 | 6 | return ( 7 | 11 | ); 12 | }; 13 | 14 | 15 | export default Chat; 16 | -------------------------------------------------------------------------------- /web/src/pages/Log/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import LogsTable from '../../components/LogsTable'; 3 | 4 | const Token = () => ( 5 | <> 6 | 7 | > 8 | ); 9 | 10 | export default Token; 11 | 
-------------------------------------------------------------------------------- /web/src/pages/Midjourney/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import MjLogsTable from '../../components/MjLogsTable'; 3 | 4 | const Midjourney = () => ( 5 | <> 6 | 7 | > 8 | ); 9 | 10 | export default Midjourney; 11 | -------------------------------------------------------------------------------- /web/src/pages/NotFound/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { Message } from 'semantic-ui-react'; 3 | 4 | const NotFound = () => ( 5 | <> 6 | 7 | 页面不存在 8 | 请检查你的浏览器地址是否正确 9 | 10 | > 11 | ); 12 | 13 | export default NotFound; 14 | -------------------------------------------------------------------------------- /web/src/pages/Redemption/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import RedemptionsTable from '../../components/RedemptionsTable'; 3 | import {Layout} from "@douyinfe/semi-ui"; 4 | 5 | const Redemption = () => ( 6 | <> 7 | 8 | 9 | 管理兑换码 10 | 11 | 12 | 13 | 14 | 15 | > 16 | ); 17 | 18 | export default Redemption; 19 | -------------------------------------------------------------------------------- /web/src/pages/Setting/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import SystemSetting from '../../components/SystemSetting'; 3 | import {isRoot} from '../../helpers'; 4 | import OtherSetting from '../../components/OtherSetting'; 5 | import PersonalSetting from '../../components/PersonalSetting'; 6 | import OperationSetting from '../../components/OperationSetting'; 7 | import {Layout, TabPane, Tabs} from "@douyinfe/semi-ui"; 8 | 9 | const Setting = () => { 10 | let panes = [ 11 | { 12 | tab: '个人设置', 13 | content: , 14 | itemKey: '1' 15 | } 16 | 
]; 17 | 18 | if (isRoot()) { 19 | panes.push({ 20 | tab: '运营设置', 21 | content: , 22 | itemKey: '2' 23 | }); 24 | panes.push({ 25 | tab: '系统设置', 26 | content: , 27 | itemKey: '3' 28 | }); 29 | panes.push({ 30 | tab: '其他设置', 31 | content: , 32 | itemKey: '4' 33 | }); 34 | } 35 | 36 | return ( 37 | 38 | 39 | 40 | 41 | {panes.map(pane => ( 42 | 43 | {pane.content} 44 | 45 | ))} 46 | 47 | 48 | 49 | 50 | ); 51 | }; 52 | 53 | export default Setting; 54 | -------------------------------------------------------------------------------- /web/src/pages/Token/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import TokensTable from '../../components/TokensTable'; 3 | import {Layout} from "@douyinfe/semi-ui"; 4 | const Token = () => ( 5 | <> 6 | 7 | 8 | 我的令牌 9 | 10 | 11 | 12 | 13 | 14 | > 15 | ); 16 | 17 | export default Token; 18 | -------------------------------------------------------------------------------- /web/src/pages/User/AddUser.js: -------------------------------------------------------------------------------- 1 | import React, {useState} from 'react'; 2 | import {API, isMobile, showError, showSuccess} from '../../helpers'; 3 | import Title from "@douyinfe/semi-ui/lib/es/typography/title"; 4 | import {Button, SideSheet, Space, Input, Spin} from "@douyinfe/semi-ui"; 5 | 6 | const AddUser = (props) => { 7 | const originInputs = { 8 | username: '', 9 | display_name: '', 10 | password: '', 11 | }; 12 | const [inputs, setInputs] = useState(originInputs); 13 | const [loading, setLoading] = useState(false); 14 | const {username, display_name, password} = inputs; 15 | 16 | const handleInputChange = (name, value) => { 17 | setInputs((inputs) => ({...inputs, [name]: value})); 18 | }; 19 | 20 | const submit = async () => { 21 | setLoading(true); 22 | if (inputs.username === '' || inputs.password === '') return; 23 | const res = await API.post(`/api/user/`, inputs); 24 | const {success, message} = 
res.data; 25 | if (success) { 26 | showSuccess('用户账户创建成功!'); 27 | setInputs(originInputs); 28 | props.refresh(); 29 | props.handleClose(); 30 | } else { 31 | showError(message); 32 | } 33 | setLoading(false); 34 | }; 35 | 36 | const handleCancel = () => { 37 | props.handleClose(); 38 | } 39 | 40 | return ( 41 | <> 42 | {'添加用户'}} 45 | headerStyle={{borderBottom: '1px solid var(--semi-color-border)'}} 46 | bodyStyle={{borderBottom: '1px solid var(--semi-color-border)'}} 47 | visible={props.visible} 48 | footer={ 49 | 50 | 51 | 提交 52 | 取消 53 | 54 | 55 | } 56 | closeIcon={null} 57 | onCancel={() => handleCancel()} 58 | width={isMobile() ? '100%' : 600} 59 | > 60 | 61 | handleInputChange('username', value)} 68 | value={username} 69 | autoComplete="off" 70 | /> 71 | handleInputChange('display_name', value)} 79 | value={display_name} 80 | /> 81 | handleInputChange('password', value)} 89 | value={password} 90 | autoComplete="off" 91 | /> 92 | 93 | 94 | > 95 | ); 96 | }; 97 | 98 | export default AddUser; 99 | -------------------------------------------------------------------------------- /web/src/pages/User/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import UsersTable from '../../components/UsersTable'; 3 | import {Layout} from "@douyinfe/semi-ui"; 4 | 5 | const User = () => ( 6 | <> 7 | 8 | 9 | 管理用户 10 | 11 | 12 | 13 | 14 | 15 | > 16 | ); 17 | 18 | export default User; 19 | -------------------------------------------------------------------------------- /web/vercel.json: -------------------------------------------------------------------------------- 1 | { 2 | "github": { 3 | "silent": true 4 | } 5 | } 6 | --------------------------------------------------------------------------------
42 | 可在设置页面设置关于内容,支持 HTML & Markdown 43 |
49 | NewAPI © 2023 CalciumIon | 基于 One API v0.5.4 © 2023 JustSong。本项目根据MIT许可证授权。 50 |
请检查你的浏览器地址是否正确