├── .github ├── assets │ ├── aoa.jpg │ ├── dockerhub.png │ └── invoke-easy.jpg └── workflows │ ├── ci.yml │ └── release.yml ├── .gitignore ├── .goreleaser.yaml ├── Dockerfile ├── LICENSE ├── README.md ├── cmd ├── cmd.go └── flags.go ├── example ├── docker-compose.azure-gpt4v.yml ├── docker-compose.azure.yml ├── docker-compose.gemini.yml ├── docker-compose.yi.yml ├── gpt-vision.md ├── openai-chat-completion.py └── openai-chat-stream.py ├── go.mod ├── go.sum ├── internal ├── define │ └── define.go ├── fn │ ├── cmd.go │ ├── cmd_test.go │ ├── gunzip.go │ ├── models.go │ └── models_test.go ├── model │ └── flags.go ├── network │ ├── http_proxy.go │ ├── http_proxy_test.go │ ├── response_err.go │ └── response_err_test.go ├── router │ ├── misc.go │ └── router.go └── version │ └── version.go ├── main.go ├── models ├── azure │ ├── azure.go │ ├── azure_test.go │ ├── define.go │ ├── model.go │ ├── model_test.go │ ├── proxy.go │ └── proxy_test.go ├── gemini │ ├── define.go │ ├── gemini.go │ ├── model.go │ └── proxy.go └── yi │ ├── define.go │ ├── model.go │ ├── model_test.go │ ├── proxy.go │ ├── proxy_test.go │ ├── yi.go │ └── yi_test.go └── pkg └── logger ├── gin-logrus.go └── logger.go /.github/assets/aoa.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soulteary/amazing-openai-api/a3f0c4848a8c8ce21df6cb3c258f8c28a7a612e5/.github/assets/aoa.jpg -------------------------------------------------------------------------------- /.github/assets/dockerhub.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soulteary/amazing-openai-api/a3f0c4848a8c8ce21df6cb3c258f8c28a7a612e5/.github/assets/dockerhub.png -------------------------------------------------------------------------------- /.github/assets/invoke-easy.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/soulteary/amazing-openai-api/a3f0c4848a8c8ce21df6cb3c258f8c28a7a612e5/.github/assets/invoke-easy.jpg -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | paths-ignore: 7 | - 'docs/**' 8 | - 'assets/**' 9 | - '**/*.gitignore' 10 | - '**/*.md' 11 | pull_request: 12 | branches: [ "main" ] 13 | paths-ignore: 14 | - 'docs/**' 15 | - 'assets/**' 16 | - '**/*.gitignore' 17 | - '**/*.md' 18 | 19 | jobs: 20 | build: 21 | runs-on: ubuntu-latest 22 | steps: 23 | - name: Checkout 24 | uses: actions/checkout@v3 25 | with: 26 | fetch-depth: '0' 27 | - name: Set up Go 28 | uses: actions/setup-go@v3 29 | with: 30 | cache: false 31 | go-version-file: go.mod 32 | 33 | - name: Verify gofmt 34 | run: | 35 | go fmt ./... && git add cmd internal models pkg && 36 | git diff --cached --exit-code || (echo 'Please run "make fmt" to verify gofmt' && exit 1); 37 | - name: Verify govet 38 | run: | 39 | go vet ./... && git add cmd internal models pkg && 40 | git diff --cached --exit-code || (echo 'Please run "make vet" to verify govet' && exit 1); 41 | 42 | - name: Build 43 | run: CGO_ENABLED=0 go build -trimpath -ldflags "-s -w" -o aoa . 
-------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | tag: 7 | default: 'latest' 8 | required: true 9 | description: 'Docker image tag' 10 | push: 11 | tags: 12 | - 'v*' 13 | 14 | permissions: 15 | contents: read 16 | packages: write 17 | 18 | jobs: 19 | build-image: 20 | runs-on: ubuntu-latest 21 | steps: 22 | - name: Set up QEMU 23 | uses: docker/setup-qemu-action@v2 24 | - name: Set up Docker Buildx 25 | uses: docker/setup-buildx-action@v2 26 | 27 | - name: Login to Docker Hub 28 | uses: docker/login-action@v2 29 | with: 30 | username: ${{ secrets.DOCKERHUB_USERNAME }} 31 | password: ${{ secrets.DOCKERHUB_TOKEN }} 32 | 33 | - name: Login to the GPR 34 | uses: docker/login-action@v2 35 | with: 36 | registry: ghcr.io 37 | username: ${{ github.repository_owner }} 38 | password: ${{ secrets.GITHUB_TOKEN }} 39 | 40 | - name: Parse Tag Name 41 | run: | 42 | if [ x${{ github.event.inputs.tag }} == x"" ]; then 43 | echo "TAG_NAME=${{ github.ref_name }}" >> $GITHUB_ENV 44 | else 45 | echo "TAG_NAME=${{ github.event.inputs.tag }}" >> $GITHUB_ENV 46 | fi 47 | 48 | - name: Build and push 49 | uses: docker/build-push-action@v4 50 | env: 51 | BUILDX_NO_DEFAULT_ATTESTATIONS: 1 # https://github.com/orgs/community/discussions/45969 52 | with: 53 | platforms: linux/amd64,linux/arm64 54 | push: true 55 | pull: true 56 | labels: | 57 | org.opencontainers.image.source=${{ github.server_url }}/${{ github.repository }} 58 | org.opencontainers.image.licenses=Apache-2.0 59 | tags: | 60 | ${{ github.repository }}:${{ env.TAG_NAME }} 61 | ghcr.io/${{ github.repository }}:${{ env.TAG_NAME }} 62 | cache-from: type=gha # https://docs.docker.com/build/cache/backends/gha/ 63 | cache-to: type=gha,mode=max 64 | 65 | goreleaser: 66 | permissions: write-all 67 | runs-on: ubuntu-latest 68 | 
if: ${{ github.event.inputs.tag == '' }} 69 | steps: 70 | - name: Checkout 71 | uses: actions/checkout@v3 72 | with: 73 | fetch-depth: 0 74 | - name: Fetch all tags 75 | run: git fetch --force --tags 76 | - name: Set up Go 77 | uses: actions/setup-go@v4 78 | with: 79 | cache: false 80 | go-version-file: go.mod 81 | - name: Run GoReleaser 82 | uses: goreleaser/goreleaser-action@v4 83 | with: 84 | distribution: goreleaser 85 | version: latest 86 | args: release --clean 87 | env: 88 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | 17 | .idea/ 18 | bin/ 19 | release/ 20 | docker-compose.yml 21 | dist/ 22 | config/config.yaml 23 | test*.sh 24 | 25 | -------------------------------------------------------------------------------- /.goreleaser.yaml: -------------------------------------------------------------------------------- 1 | # Make sure to check the documentation at https://goreleaser.com 2 | env: 3 | - GIT_URL=https://github.com/soulteary/amazing-openai-api 4 | before: 5 | hooks: 6 | - go mod tidy 7 | builds: 8 | - id: amazing-openai-api 9 | env: 10 | - CGO_ENABLED=0 11 | goos: 12 | - linux 13 | - windows 14 | - darwin 15 | goarch: 16 | - amd64 17 | main: ./ 18 | binary: aoa 19 | flags: 20 | - -trimpath 21 | ldflags: 22 | - -s -w 23 | - -X github.com/soulteary/amazing-openai-api/internal/version.Version={{ .Version }} 24 | - -X github.com/soulteary/amazing-openai-api/internal/version.BuildDate={{ .Date }} 25 | - -X 
github.com/soulteary/amazing-openai-api/internal/version.GitCommit={{ .Commit }} 26 | 27 | archives: 28 | - format: tar.gz 29 | # this name template makes the OS and Arch compatible with the results of uname. 30 | name_template: >- 31 | {{ .ProjectName }}_ 32 | {{- .Version }}_ 33 | {{- .Os }}_ 34 | {{- if eq .Arch "amd64" }}x86_64 35 | {{- else if eq .Arch "386" }}i386 36 | {{- else }}{{ .Arch }}{{ end }} 37 | {{- if .Arm }}v{{ .Arm }}{{ end }} 38 | # use zip for windows archives 39 | format_overrides: 40 | - goos: windows 41 | format: zip 42 | checksum: 43 | name_template: 'checksums.txt' 44 | snapshot: 45 | name_template: "{{ incpatch .Version }}-next" 46 | 47 | # https://goreleaser.com/customization/changelog/ 48 | changelog: 49 | sort: asc 50 | use: github 51 | filters: 52 | exclude: 53 | - '^build:' 54 | - '^ci:' 55 | # - '^docs:' 56 | - '^test:' 57 | - '^chore:' 58 | - '^feat(deps):' 59 | - 'merge conflict' 60 | - Merge pull request 61 | - Merge remote-tracking branch 62 | - Merge branch 63 | - go mod tidy 64 | - '^Update' 65 | groups: 66 | - title: Dependency updates 67 | regexp: '^.*?(feat|fix)\(deps\)!?:.+$' 68 | order: 300 69 | - title: 'New Features' 70 | regexp: '^.*?feat(\([[:word:]]+\))??!?:.+$' 71 | order: 100 72 | - title: 'Security updates' 73 | regexp: '^.*?sec(\([[:word:]]+\))??!?:.+$' 74 | order: 150 75 | - title: 'Bug fixes' 76 | regexp: '^.*?fix(\([[:word:]]+\))??!?:.+$' 77 | order: 200 78 | - title: 'Documentation updates' 79 | regexp: '^.*?doc(\([[:word:]]+\))??!?:.+$' 80 | order: 400 81 | # - title: 'Build process updates' 82 | # regexp: '^.*?build(\([[:word:]]+\))??!?:.+$' 83 | # order: 400 84 | - title: Other work 85 | order: 9999 86 | release: 87 | footer: | 88 | **Full Changelog**: https://github.com/soulteary/amazing-openai-api/compare/{{ .PreviousTag }}...{{ .Tag }} 89 | -------------------------------------------------------------------------------- /Dockerfile: 
-------------------------------------------------------------------------------- 1 | FROM golang:1.22.0-alpine AS builder 2 | RUN apk update && apk upgrade \ 3 | && apk add --no-cache ca-certificates tzdata \ 4 | && update-ca-certificates 2>/dev/null || true 5 | RUN apk add --no-cache make git gcc g++ libc-dev 6 | ENV GO111MODULE=on 7 | ENV CGO_ENABLED=0 8 | ENV GOOS=linux 9 | WORKDIR /build 10 | ADD go.mod go.sum ./ 11 | RUN go mod download 12 | COPY . . 13 | RUN go build -trimpath -ldflags "-s -w" -o aoa . 14 | 15 | FROM alpine:3.18.0 16 | RUN apk update && apk upgrade \ 17 | && apk add --no-cache ca-certificates tzdata \ 18 | && update-ca-certificates 2>/dev/null || true 19 | WORKDIR /app 20 | EXPOSE 8080 21 | COPY --from=builder /build/aoa . 22 | ENTRYPOINT ["/app/aoa"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Amazing OpenAI API 2 | 3 | ![](.github/assets/aoa.jpg) 4 | 5 | Convert different model APIs into the OpenAI API format out of the box. 
6 | 7 | 10MB+的小工具,能够将各种不同的模型 API 转换为开箱即用的 OpenAI API 格式。 8 | 9 | 当前支持模型: 10 | 11 | - Azure OpenAI API (GPT 3.5/4), GPT4 Vision (GPT4v) 12 | - YI 34B API 13 | - Google Gemini Pro 14 | 15 | 16 | ## 下载 📦 17 | 18 | 访问 [GitHub Release 页面](https://github.com/soulteary/amazing-openai-api/releases),下载适合你的操作系统的执行文件。 19 | 20 | ![](.github/assets/dockerhub.png) 21 | 22 | 或者使用 Docker Pull,下载指定版本的镜像文件: 23 | 24 | ```bash 25 | docker pull soulteary/amazing-openai-api:v0.7.0 26 | ``` 27 | 28 | ## 快速上手 29 | 30 | `AOA` 不需要编写任何配置文件,通过指定环境变量就能够完成应用行为的调整,包括“选择工作模型”、“设置模型运行需要的参数”、“设置模型兼容别名”。 31 | 32 | 默认执行 `./aoa` ,程序会将工作模型设置为 `azure`,此时我们设置环境变量 `AZURE_ENDPOINT=https://你的部署名称.openai.azure.com/` 然后就可以正常使用服务啦。 33 | 34 | ```bash 35 | AZURE_ENDPOINT=https://你的部署名称.openai.azure.com/ ./aoa 36 | ``` 37 | 38 | 如果你更喜欢 Docker,可以用下面的命令: 39 | 40 | ```bash 41 | docker run --rm -it -e AZURE_ENDPOINT=https://你的部署名称.openai.azure.com/ -p 8080:8080 soulteary/amazing-openai-api:v0.7.0 42 | ``` 43 | 44 | 当服务启动之后,我们就可以通过访问 `http://localhost:8080/v1` 来访问和 OpenAI 一样的 API 服务啦。 45 | 46 | 你可以使用 `curl` 来进行一个快速测试: 47 | 48 | ```bash 49 | curl -v http://127.0.0.1:8080/v1/chat/completions \ 50 | -H "Content-Type: application/json" \ 51 | -H "Authorization: Bearer 123" \ 52 | -d '{ 53 | "model": "gpt-4", 54 | "messages": [ 55 | { 56 | "role": "system", 57 | "content": "You are a poetic assistant, skilled in explaining complex programming concepts with creative flair." 58 | }, 59 | { 60 | "role": "user", 61 | "content": "Compose a poem that explains the concept of recursion in programming." 
62 | } 63 | ] 64 | }' 65 | ``` 66 | 67 | 也可以使用 OpenAI 官方 SDK 进行调用,或者任意兼容 OpenAI 的开源软件进行使用(更多例子,参考 [example](./example/)): 68 | 69 | ```python 70 | from openai import OpenAI 71 | 72 | client = OpenAI( 73 | api_key="your-key-or-input-something-as-you-like", 74 | base_url="http://127.0.0.1:8080/v1" 75 | ) 76 | 77 | chat_completion = client.chat.completions.create( 78 | messages=[ 79 | { 80 | "role": "user", 81 | "content": "Say this is a test", 82 | } 83 | ], 84 | model="gpt-3.5-turbo", 85 | ) 86 | 87 | print(chat_completion) 88 | ``` 89 | 90 | 如果你希望不要将 API Key 暴露给应用,或者不放心各种复杂的开源软件是否有 API Key 泄漏风险,我们可以多配置一个 `AZURE_API_KEY=你的 API Key` 环境变量,然后各种开源软件在请求的时候就无需再填写 API key 了(或者随便填写也行)。 91 | 92 | 当然,因为 Azure 的一些限制,以及一些开源软件中的模型调用名称不好调整,我们可以通过下面的方式,来将原始请求中的模型,映射为我们真实的模型名称。比如,将 GPT 3.5/4 都替换为 `yi-34b-chat`: 93 | 94 | ```bash 95 | YI_MODEL_ALIAS=gpt-3.5-turbo:yi-34b-chat,gpt-4:yi-34b-chat 96 | ``` 97 | 98 | 如果你希望使用 `yi-34b-chat`,或者 `gemini-pro`,我们需要设置 `AOA_TYPE=yi` 或者 `AOA_TYPE=gemini`,除此之外,没有任何差别。 99 | 100 | ## 容器快速上手 101 | 102 | 项目中包含当前支持的三种模型接口的 `docker compose` 示例文件,我们将 `example` 目录中的不同的文件,按需选择使用,将必填的信息填写完毕后,将文件重命名为 `docker-compose.yml`。 103 | 104 | 然后使用 `docker compose up` 启动服务,就能够快速使用啦。 105 | 106 | - [docker-compose.azure.yml](./example/docker-compose.azure.yml) 107 | - [docker-compose.azure-gpt4v.yml](./example/docker-compose.azure-gpt4v.yml) 108 | - [docker-compose.yi.yml](./example/docker-compose.yi.yml) 109 | - [docker-compose.gemini.yml](./example/docker-compose.gemini.yml) 110 | 111 | ## 详细配置使用 112 | 113 | 调整工作模型 `AOA_TYPE`,可选参数,默认为 `azure`: 114 | 115 | ```bash 116 | # 选择一个服务, "azure", "yi", "gemini" 117 | AOA_TYPE: "azure" 118 | ``` 119 | 120 | 程序服务地址,可选参数,默认为 `8080` 和 `0.0.0.0`: 121 | 122 | ```bash 123 | # 服务端口,默认 `8080` 124 | AOA_PORT: 8080 125 | # 服务地址,默认 `0.0.0.0` 126 | AOA_HOST: "0.0.0.0" 127 | ``` 128 | 129 | ## Azure 使用 130 | 131 | 如果我们想将 Azure 上部署的 OpenAI 服务转换为标准的 OpenAI 调用,可以用下面的命令: 132 | 133 | ```bash 134 | AZURE_ENDPOINT=https://<你的 Endpoint 地址>.openai.azure.com/
AZURE_API_KEY=<你的 API KEY> AZURE_MODEL_ALIAS=gpt-3.5-turbo:gpt-35 ./amazing-openai-api 135 | ``` 136 | 137 | 在上面的命令中 `AZURE_ENDPOINT` 和 `AZURE_API_KEY` 包含了 Azure OpenAI 服务中的核心要素,因为 Azure 部署 GPT 3.5 / GPT 4 的部署名称不允许包含 `.`,所以我们使用 `AZURE_MODEL_ALIAS` 将我们**请求内容中的模型名称**替换为真实的 Azure 部署名称。甚至可以使用这个技巧将各种开源、闭源软件使用的模型自动映射为我们希望的模型: 138 | 139 | ```bash 140 | # 比如不论是 3.5 还是 4 都映射为 `gpt-35` 141 | AZURE_MODEL_ALIAS=gpt-3.5-turbo:gpt-35,gpt-4:gpt-35 142 | ``` 143 | 144 | 因为我们已经配置了 `AZURE_API_KEY`,所以开源软件也好,使用 `curl` 调用也罢,都不需要添加 `Authorization: Bearer <你的 API Key>` (也可以随便写),这样就起到了严格的 API Key 隔离,提升了 API Key 的安全性。 145 | 146 | ![](.github/assets/invoke-easy.jpg) 147 | 148 | 如果你还是习惯在请求头参数中添加认证内容,可以使用下面的不包含 `AZURE_API_KEY` 的命令,程序将透传验证到 Azure 服务: 149 | 150 | ```bash 151 | AZURE_ENDPOINT=https://<你的 Endpoint 地址>.openai.azure.com/ AZURE_MODEL_ALIAS=gpt-3.5-turbo:gpt-35 ./amazing-openai-api 152 | ``` 153 | 154 | 如果你希望自己指定特别的 API Version,可以指定 `AZURE_IGNORE_API_VERSION_CHECK=true` 来强制忽略程序本身的 API Version 有效性验证。 155 | 156 | ### GPT4 Vision 157 | 158 | 如果你已经拥有了 Azure GPT Vision,除了使用 SDK 调用之外,你也可以参考这篇文档,使用 `curl` 进行调用:[GPT Vision](./example/gpt-vision.md)。 159 | 160 | ### 模型参数设置 161 | 162 | ```bash 163 | # (必选) Azure Deployment Endpoint URL 164 | AZURE_ENDPOINT 165 | # (必选) Azure API Key 166 | AZURE_API_KEY 167 | # (可选) 模型名称,默认 GPT-4 168 | AZURE_MODEL 169 | # (可选) API Version 170 | AZURE_API_VER 171 | # (可选) 是否是 Vision 实例 172 | ENV_AZURE_VISION 173 | # (可选) 模型映射别名 174 | AZURE_MODEL_ALIAS 175 | # (可选) Azure 网络代理 176 | AZURE_HTTP_PROXY 177 | AZURE_SOCKS_PROXY 178 | # (可选) 忽略 Azure API Version 检查,默认 false,始终检查 179 | AZURE_IGNORE_API_VERSION_CHECK 180 | ``` 181 | 182 | ## YI (零一万物) 183 | 184 | 如果我们想将 YI 官方的 API 转换为标准的 OpenAI 调用,可以用下面的命令: 185 | 186 | ```bash 187 | AOA_TYPE=yi YI_API_KEY=<你的 API KEY> ./amazing-openai-api 188 | ``` 189 | 190 | 和使用 Azure 服务类似,我们可以使用一个技巧将各种开源、闭源软件使用的模型自动映射为我们希望的模型: 191 | 192 | ```bash 193 | # 比如不论是 3.5 还是 4 都映射为 `yi-34b-chat` 194 |
YI_MODEL_ALIAS=gpt-3.5-turbo:yi-34b-chat,gpt-4:yi-34b-chat 195 | ``` 196 | 197 | 如果我们在启动服务的时候配置了 `YI_API_KEY` 的话,不论是开源软件也好,使用 `curl` 调用也罢,我们都不需要添加 `Authorization: Bearer <你的 API Key>` (也可以随便写),这样就起到了严格的 API Key 隔离,提升了 API Key 的安全性。 198 | 199 | 如果你还是习惯在请求头参数中添加认证内容,可以使用下面的不包含 `YI_API_KEY` 的命令,程序将透传验证到 Yi API 服务: 200 | 201 | ```bash 202 | AOA_TYPE=yi ./amazing-openai-api 203 | ``` 204 | 205 | ### 模型参数设置 206 | 207 | ```bash 208 | # (必选) YI API Key 209 | YI_API_KEY 210 | # (可选) 模型名称,默认 yi-34b-chat 211 | YI_MODEL 212 | # (可选) YI Deployment Endpoint URL 213 | YI_ENDPOINT 214 | # (可选) API Version,默认 v1beta,可选 v1 215 | YI_API_VER 216 | # (可选) 模型映射别名 217 | YI_MODEL_ALIAS 218 | # (可选) YI 网络代理 219 | YI_HTTP_PROXY 220 | YI_SOCKS_PROXY 221 | ``` 222 | 223 | ## Gemini PRO 224 | 225 | 如果我们想将 Google 官方的 Gemini API 转换为标准的 OpenAI 调用,可以用下面的命令: 226 | 227 | ```bash 228 | AOA_TYPE=gemini GEMINI_API_KEY=<你的 API KEY> ./amazing-openai-api 229 | ``` 230 | 231 | 和使用 Azure 服务类似,我们可以使用一个技巧将各种开源、闭源软件使用的模型自动映射为我们希望的模型: 232 | 233 | ```bash 234 | # 比如不论是 3.5 还是 4 都映射为 `gemini-pro` 235 | GEMINI_MODEL_ALIAS=gpt-3.5-turbo:gemini-pro,gpt-4:gemini-pro 236 | ``` 237 | 238 | 如果我们在启动服务的时候配置了 `GEMINI_API_KEY` 的话,不论是开源软件也好,使用 `curl` 调用也罢,我们都不需要添加 `Authorization: Bearer <你的 API Key>` (也可以随便写),这样就起到了严格的 API Key 隔离,提升了 API Key 的安全性。 239 | 240 | 如果你还是习惯在请求头参数中添加认证内容,可以使用下面的不包含 `GEMINI_API_KEY` 的命令,程序将透传验证到 Google AI 服务: 241 | 242 | ```bash 243 | AOA_TYPE=gemini ./amazing-openai-api 244 | ``` 245 | 246 | ### 模型参数设置 247 | 248 | ```bash 249 | # (必选) Gemini API Key 250 | GEMINI_API_KEY 251 | 252 | # (可选) Gemini 安全设置,可选 `BLOCK_NONE` / `BLOCK_ONLY_HIGH` / `BLOCK_MEDIUM_AND_ABOVE` / `BLOCK_LOW_AND_ABOVE` / `HARM_BLOCK_THRESHOLD_UNSPECIFIED` 253 | GEMINI_SAFETY 254 | # (可选) Gemini 模型 版本,默认 `gemini-pro` 255 | GEMINI_MODEL 256 | # (可选) Gemini API 版本,默认 `v1beta` 257 | GEMINI_API_VER 258 | # (可选) Gemini API 接口地址 259 | GEMINI_ENDPOINT 260 | # (可选) 模型映射别名 261 | GEMINI_MODEL_ALIAS 262 | # (可选) Gemini 网络代理 263 | GEMINI_HTTP_PROXY 264 | GEMINI_SOCKS_PROXY
265 | ``` 266 | -------------------------------------------------------------------------------- /cmd/cmd.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "os/signal" 8 | "strconv" 9 | "syscall" 10 | "time" 11 | 12 | "github.com/gin-gonic/gin" 13 | AoaModel "github.com/soulteary/amazing-openai-api/internal/model" 14 | AoaRouter "github.com/soulteary/amazing-openai-api/internal/router" 15 | "github.com/soulteary/amazing-openai-api/internal/version" 16 | "github.com/soulteary/amazing-openai-api/models/azure" 17 | "github.com/soulteary/amazing-openai-api/models/gemini" 18 | "github.com/soulteary/amazing-openai-api/models/yi" 19 | "github.com/soulteary/amazing-openai-api/pkg/logger" 20 | ) 21 | 22 | const ( 23 | _DEFAULT_PORT = 8080 24 | _DEFAULT_HOST = "0.0.0.0" 25 | _DEFAULT_TYPE = "azure" 26 | _DEFAULT_VISION = false 27 | 28 | _ENV_KEY_NAME_PORT = "AOA_PORT" 29 | _ENV_KEY_USE_VISION = "AOA_VISION" 30 | _ENV_KEY_NAME_HOST = "AOA_HOST" 31 | _ENV_KEY_SERVICE_TYPE = "AOA_TYPE" 32 | ) 33 | 34 | // refs: https://github.com/soulteary/flare/blob/main/cmd/cmd.go 35 | func Parse() { 36 | // 1. First try to get the environment variables 37 | flags := parseEnvVars() 38 | // 2. 
Then try to get the command line flags, overwrite the environment variables 39 | // flags := parseCLI(envs) 40 | 41 | log := logger.GetLogger() 42 | log.Println("程序启动中 🚀") 43 | log.Println("程序版本", version.Version) 44 | log.Println("程序构建日期", version.BuildDate) 45 | log.Println("程序 Git Commit", version.GitCommit) 46 | log.Println("程序服务地址", fmt.Sprintf("%s:%d", flags.Host, flags.Port)) 47 | 48 | startDaemon(&flags) 49 | } 50 | 51 | // refs: https://github.com/soulteary/flare/blob/main/cmd/daemon.go 52 | func startDaemon(flags *AoaModel.Flags) { 53 | ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) 54 | defer stop() 55 | 56 | gin.SetMode(gin.ReleaseMode) 57 | 58 | router := gin.Default() 59 | log := logger.GetLogger() 60 | 61 | router.Use(logger.Logger(log), gin.Recovery()) 62 | 63 | AoaRouter.RegisterMiscRoute(router) 64 | 65 | switch flags.Type { 66 | case "azure": 67 | err := azure.Init() 68 | if err != nil { 69 | log.Fatalf("初始化 Azure OpenAI API 出错: %s\n", err) 70 | } 71 | case "yi": 72 | err := yi.Init() 73 | if err != nil { 74 | log.Fatalf("初始化 Yi API 出错: %s\n", err) 75 | } 76 | case "gemini": 77 | err := gemini.Init() 78 | if err != nil { 79 | log.Fatalf("初始化 Gemini API 出错: %s\n", err) 80 | } 81 | } 82 | AoaRouter.RegisterModelRoute(router, flags.Type) 83 | 84 | srv := &http.Server{ 85 | Addr: ":" + strconv.Itoa(flags.Port), 86 | Handler: router, 87 | ReadHeaderTimeout: 5 * time.Second, 88 | ReadTimeout: 5 * time.Second, 89 | } 90 | 91 | go func() { 92 | if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { 93 | log.Fatalf("程序启动出错: %s\n", err) 94 | } 95 | }() 96 | log.Println("程序已启动完毕 🚀") 97 | 98 | <-ctx.Done() 99 | 100 | stop() 101 | log.Println("程序正在关闭中,如需立即结束请按 CTRL+C") 102 | 103 | ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 104 | defer cancel() 105 | if err := srv.Shutdown(ctx); err != nil { 106 | log.Fatal("程序强制关闭: ", err) 107 | } 108 | 109 | log.Println("期待与你的再次相遇 
❤️") 110 | } 111 | -------------------------------------------------------------------------------- /cmd/flags.go: -------------------------------------------------------------------------------- 1 | package cmd 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/soulteary/amazing-openai-api/internal/fn" 7 | AoaModel "github.com/soulteary/amazing-openai-api/internal/model" 8 | ) 9 | 10 | // refs: https://github.com/soulteary/flare/blob/main/cmd/flags.go 11 | func parseEnvVars() AoaModel.Flags { 12 | // use default values 13 | flags := AoaModel.Flags{ 14 | DebugMode: false, 15 | ShowVersion: false, 16 | ShowHelp: false, 17 | 18 | Type: _DEFAULT_TYPE, 19 | Vision: _DEFAULT_VISION, 20 | Port: _DEFAULT_PORT, 21 | Host: _DEFAULT_HOST, 22 | } 23 | 24 | // check and set port 25 | flags.Port = fn.GetIntOrDefaultFromEnv(_ENV_KEY_NAME_PORT, _DEFAULT_PORT) 26 | if flags.Port <= 0 || flags.Port > 65535 { 27 | flags.Port = _DEFAULT_PORT 28 | } 29 | 30 | // check and set host 31 | flags.Host = fn.GetStringOrDefaultFromEnv(_ENV_KEY_NAME_HOST, _DEFAULT_HOST) 32 | if !fn.IsValidIPAddress(flags.Host) { 33 | flags.Host = _DEFAULT_HOST 34 | } 35 | 36 | // check and set vision 37 | flags.Vision = fn.GetBoolOrDefaultFromEnv(_ENV_KEY_USE_VISION, _DEFAULT_VISION) 38 | 39 | // check and set type 40 | flags.Type = strings.ToLower(fn.GetStringOrDefaultFromEnv(_ENV_KEY_SERVICE_TYPE, _DEFAULT_TYPE)) 41 | // TODO support all types 42 | if flags.Type != "azure" && 43 | flags.Type != "yi" && 44 | flags.Type != "gemini" { 45 | flags.Type = _DEFAULT_TYPE 46 | } 47 | return flags 48 | } 49 | 50 | // func parseCLI() { 51 | // TODO: parse command line flags 52 | // } 53 | -------------------------------------------------------------------------------- /example/docker-compose.azure-gpt4v.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | services: 4 | amazing-openai-api: 5 | image: soulteary/amazing-openai-api:v0.8.0 6 | restart: always 7 | 
ports: 8 | - 8080:8080 9 | environment: 10 | - AZURE_ENDPOINT=https://<修改为你的部署名称>.openai.azure.com/ 11 | - AZURE_API_KEY=<修改为你的API KEY> 12 | - AZURE_VISION=true 13 | - AZURE_MODEL=gpt-4v 14 | # 模型名称映射,比如将请求中的 GPT 3.5 Turbo 映射为 GPT 4v 15 | - AZURE_MODEL_ALIAS=gpt-3.5-turbo:gpt-4v,gpt-4:gpt4v 16 | logging: 17 | options: 18 | max-size: 1m 19 | -------------------------------------------------------------------------------- /example/docker-compose.azure.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | services: 4 | amazing-openai-api: 5 | image: soulteary/amazing-openai-api:v0.8.0 6 | restart: always 7 | ports: 8 | - 8080:8080 9 | environment: 10 | - AZURE_ENDPOINT=https://<修改为你的部署名称>.openai.azure.com/ 11 | - AZURE_API_KEY=<修改为你的API KEY> 12 | - AZURE_MODEL=gpt-4 13 | # 模型名称映射,比如将请求中的 GPT 3.5 Turbo 映射为 GPT 4 14 | - AZURE_MODEL_ALIAS=gpt-3.5-turbo:gpt-4 15 | logging: 16 | options: 17 | max-size: 1m 18 | -------------------------------------------------------------------------------- /example/docker-compose.gemini.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | services: 4 | amazing-openai-api: 5 | image: soulteary/amazing-openai-api:v0.8.0 6 | restart: always 7 | ports: 8 | - 8080:8080 9 | environment: 10 | # 设置工作模型为 Gemini 11 | - AOA_TYPE=gemini 12 | # 设置 Gemini API Key 13 | - GEMINI_API_KEY=<修改为你的API KEY> 14 | # 模型名称映射,比如将请求中的 GPT 3.5 Turbo,GPT-4 都映射为 gemini-pro 15 | - GEMINI_MODEL_ALIAS=gpt-3.5-turbo:gemini-pro,gpt-4:gemini-pro 16 | # 限制国内请求,需要使用服务器进行代理中转,或者跑在国外服务器上 17 | - https_proxy=http://10.11.12.90:7890 18 | logging: 19 | options: 20 | max-size: 1m 21 | -------------------------------------------------------------------------------- /example/docker-compose.yi.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | services: 4 | amazing-openai-api: 5 | image: 
soulteary/amazing-openai-api:v0.8.0 6 | restart: always 7 | ports: 8 | - 8080:8080 9 | environment: 10 | # 设置工作模型为 YI 11 | - AOA_TYPE=yi 12 | # 设置 YI API 服务器地址 13 | - YI_ENDPOINT=<修改为你申请或搭建的服务地址> 14 | # 设置 YI API Key 15 | - YI_API_KEY=<修改为你的API KEY> 16 | # 模型名称映射,比如将请求中的 GPT 3.5 Turbo,GPT-4 都映射为 yi-34b-chat 17 | - YI_MODEL_ALIAS=gpt-3.5-turbo:yi-34b-chat,gpt-4:yi-34b-chat 18 | logging: 19 | options: 20 | max-size: 1m 21 | -------------------------------------------------------------------------------- /example/gpt-vision.md: -------------------------------------------------------------------------------- 1 | # GPT Vision 2 | 3 | 如果你已经拥有了 Azure GPT4 Vision,并且想要使用 OpenAI API 的接口格式来进行调用,我们可以在使用 `azure` 服务类型时,设置 `AZURE_VISION` 的数值为 `true|1|on|yes` 任意值,激活 Vision API。 4 | 5 | ```bash 6 | AZURE_VISION=true 7 | ``` 8 | 9 | 调用方法很简单,除了使用 SDK 之外,同样可以使用 `curl`: 10 | 11 | ```bash 12 | curl -v http://127.0.0.1:8080/v1/chat/completions \ 13 | -H "Content-Type: application/json" \ 14 | -H "Authorization: Bearer 123" \ 15 | -d '{ 16 | "model": "gpt-4v", 17 | "messages":[ 18 | {"role":"system","content":"You are a helpful assistant."}, 19 | {"role":"user","content":[ 20 | {"type":"text","text":"Describe this picture:"}, 21 | { "type": "image_url", "image_url": { "url": "https://learn.microsoft.com/azure/ai-services/computer-vision/media/quickstarts/presentation.png", "detail": "high" }} 22 | ]} 23 | ] 24 | }' 25 | ``` 26 | 27 | 当然,你也可以将本地的图片 Base64 处理后,再调用中进行传递: 28 | 29 | ```bash 30 | curl -v http://127.0.0.1:8080/v1/chat/completions \ 31 | -H "Content-Type: application/json" \ 32 | -H "Authorization: Bearer 123" \ 33 | -d '{ 34 | "model": "gpt-4v", 35 | "messages":[ 36 | {"role":"system","content":"You are a helpful assistant."}, 37 | {"role":"user","content":[ 38 | {"type":"text","text":"Describe this picture:"}, 39 | { "type": "image_url", "image_url": { "url": "", "detail": "high" }} 40 | ]} 41 | ] 42 | }' 43 | ``` 
-------------------------------------------------------------------------------- /example/openai-chat-completion.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | 3 | client = OpenAI( 4 | api_key="your-key-or-input-something-as-you-like", 5 | base_url="http://127.0.0.1:8080/v1" 6 | ) 7 | 8 | chat_completion = client.chat.completions.create( 9 | messages=[ 10 | { 11 | "role": "user", 12 | "content": "Say this is a test", 13 | } 14 | ], 15 | model="gpt-3.5-turbo", 16 | ) 17 | 18 | print(chat_completion) -------------------------------------------------------------------------------- /example/openai-chat-stream.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | 3 | client = OpenAI( 4 | api_key="your-key-or-input-something-as-you-like", 5 | base_url="http://127.0.0.1:8080/v1" 6 | ) 7 | 8 | stream = client.chat.completions.create( 9 | model="gpt-4", 10 | messages=[{"role": "user", "content": "Write a romantic poem and talk about League of Legends"}], 11 | stream=True, 12 | ) 13 | for chunk in stream: 14 | print(chunk.choices[0].delta.content or "", end="") -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/soulteary/amazing-openai-api 2 | 3 | go 1.22 4 | 5 | require ( 6 | github.com/gin-gonic/gin v1.9.1 7 | github.com/pkg/errors v0.9.1 8 | github.com/sirupsen/logrus v1.9.3 9 | github.com/stretchr/testify v1.8.4 10 | golang.org/x/net v0.21.0 11 | ) 12 | 13 | require ( 14 | github.com/bytedance/sonic v1.11.0 // indirect 15 | github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect 16 | github.com/chenzhuoyu/iasm v0.9.1 // indirect 17 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 18 | github.com/gabriel-vasile/mimetype v1.4.3 // indirect 19 | 
github.com/gin-contrib/sse v0.1.0 // indirect 20 | github.com/go-playground/locales v0.14.1 // indirect 21 | github.com/go-playground/universal-translator v0.18.1 // indirect 22 | github.com/go-playground/validator/v10 v10.18.0 // indirect 23 | github.com/goccy/go-json v0.10.2 // indirect 24 | github.com/google/go-cmp v0.5.9 // indirect 25 | github.com/json-iterator/go v1.1.12 // indirect 26 | github.com/klauspost/cpuid/v2 v2.2.6 // indirect 27 | github.com/kr/pretty v0.3.1 // indirect 28 | github.com/leodido/go-urn v1.4.0 // indirect 29 | github.com/mattn/go-isatty v0.0.20 // indirect 30 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 31 | github.com/modern-go/reflect2 v1.0.2 // indirect 32 | github.com/pelletier/go-toml/v2 v2.1.1 // indirect 33 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 34 | github.com/stretchr/objx v0.5.0 // indirect 35 | github.com/twitchyliquid64/golang-asm v0.15.1 // indirect 36 | github.com/ugorji/go/codec v1.2.12 // indirect 37 | golang.org/x/arch v0.7.0 // indirect 38 | golang.org/x/crypto v0.19.0 // indirect 39 | golang.org/x/sys v0.17.0 // indirect 40 | golang.org/x/text v0.14.0 // indirect 41 | google.golang.org/protobuf v1.32.0 // indirect 42 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect 43 | gopkg.in/yaml.v3 v3.0.1 // indirect 44 | ) 45 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= 2 | github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM= 3 | github.com/bytedance/sonic v1.11.0 h1:FwNNv6Vu4z2Onf1++LNzxB/QhitD8wuTdpZzMTGITWo= 4 | github.com/bytedance/sonic v1.11.0/go.mod h1:iZcSUejdk5aukTND/Eu/ivjQuEL0Cu9/rf50Hi0u/g4= 5 | github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod 
h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= 6 | github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= 7 | github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d h1:77cEq6EriyTZ0g/qfRdp61a3Uu/AWrgIq2s0ClJV1g0= 8 | github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d/go.mod h1:8EPpVsBuRksnlj1mLy4AWzRNQYxauNi62uWcE3to6eA= 9 | github.com/chenzhuoyu/iasm v0.9.0/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog= 10 | github.com/chenzhuoyu/iasm v0.9.1 h1:tUHQJXo3NhBqw6s33wkGn9SP3bvrWLdlVIJ3hQBL7P0= 11 | github.com/chenzhuoyu/iasm v0.9.1/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog= 12 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 13 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 14 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 15 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= 16 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 17 | github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= 18 | github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= 19 | github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= 20 | github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= 21 | github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= 22 | github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= 23 | github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= 24 | github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= 25 | github.com/go-playground/locales 
v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= 26 | github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= 27 | github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= 28 | github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= 29 | github.com/go-playground/validator/v10 v10.18.0 h1:BvolUXjp4zuvkZ5YN5t7ebzbhlUtPsPm2S9NAZ5nl9U= 30 | github.com/go-playground/validator/v10 v10.18.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= 31 | github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= 32 | github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= 33 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 34 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 35 | github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= 36 | github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= 37 | github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= 38 | github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= 39 | github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= 40 | github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= 41 | github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= 42 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 43 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 44 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 45 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 46 | github.com/leodido/go-urn v1.4.0 
h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= 47 | github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= 48 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 49 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 50 | github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 51 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= 52 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 53 | github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= 54 | github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= 55 | github.com/pelletier/go-toml/v2 v2.1.1 h1:LWAJwfNvjQZCFIDKWYQaM62NcYeYViCmWIwmOStowAI= 56 | github.com/pelletier/go-toml/v2 v2.1.1/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= 57 | github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= 58 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 59 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 60 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 61 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= 62 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 63 | github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= 64 | github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= 65 | github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= 66 | 
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= 67 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 68 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 69 | github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= 70 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 71 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 72 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 73 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 74 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 75 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 76 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 77 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 78 | github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= 79 | github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= 80 | github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= 81 | github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= 82 | golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= 83 | golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc= 84 | golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= 85 | golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= 86 | golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= 87 | golang.org/x/net v0.21.0 
// ---- /internal/define/define.go ----

// ModelConfig holds the settings of one upstream model service.
type ModelConfig struct {
	Name     string `yaml:"name" json:"name"`
	Endpoint string `yaml:"endpoint" json:"endpoint"`
	Model    string `yaml:"model" json:"model"`
	Version  string `yaml:"version" json:"version"`
	Key      string `yaml:"key" json:"key"`
	URL      *url.URL
	Alias    string
	Vision   bool
}

// ModelAlias is a list of two-element string pairs, as produced from a
// "from:to,from:to" alias specification.
type ModelAlias [][]string

// openai api payload

// OpenAI_Payload_Model decodes only the "model" field of a request body.
type OpenAI_Payload_Model struct {
	Model string `json:"model"`
}

// OpenAI_Payload is a chat-completion request whose messages carry plain
// string content.
type OpenAI_Payload struct {
	MaxTokens       int       `json:"max_tokens,omitempty"`
	Model           string    `json:"model"`
	Temperature     float64   `json:"temperature,omitempty"`
	TopP            float64   `json:"top_p,omitempty"`
	PresencePenalty float64   `json:"presence_penalty,omitempty"`
	Messages        []Message `json:"messages"`
	Stream          bool      `json:"stream,omitempty"`
}

// OpenAI_Vision_Payload is a chat-completion request whose messages are left
// untyped, so plain and multi-part (vision) messages can be mixed.
type OpenAI_Vision_Payload struct {
	MaxTokens       int     `json:"max_tokens,omitempty"`
	Model           string  `json:"model"`
	Temperature     float64 `json:"temperature,omitempty"`
	TopP            float64 `json:"top_p,omitempty"`
	PresencePenalty float64 `json:"presence_penalty,omitempty"`
	Stream          bool    `json:"stream,omitempty"`
	Messages        []any   `json:"messages"`
}

// VisionMessage is a message whose content is a list of typed parts.
type VisionMessage struct {
	Role    string                 `json:"role"`
	Content []VisionMessageContent `json:"content"`
}

// VisionMessageContent is one part of a vision message: text or an image URL.
type VisionMessageContent struct {
	Type     string                       `json:"type"`
	Text     string                       `json:"text,omitempty"`
	ImageURL VisionMessageContentImageURL `json:"image_url"`
}

// VisionMessageContentImageURL references an image and the requested detail level.
type VisionMessageContentImageURL struct {
	URL    string `json:"url"`
	Detail string `json:"detail"`
}

// Message is a chat message with plain string content.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// OpenAI_Usage carries the token counters of a response.
type OpenAI_Usage struct {
	CompletionTokens int `json:"completion_tokens"`
	PromptTokens     int `json:"prompt_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// OpenAI_Choices is one completion choice of a response.
type OpenAI_Choices struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason string  `json:"finish_reason"`
}

// OpeAI_Response is a chat-completion response. The identifier is spelled
// "OpeAI"; it is kept as is because other packages may reference it.
type OpeAI_Response struct {
	ID      string           `json:"id"`
	Object  string           `json:"object"`
	Created int              `json:"created"`
	Model   string           `json:"model"`
	Usage   OpenAI_Usage     `json:"usage"`
	Choices []OpenAI_Choices `json:"choices"`
	// openai extra fields
	SystemFingerprint string `json:"system_fingerprint"`
}

// ---- /internal/fn/cmd.go ----

// GetIntOrDefaultFromEnv returns the environment variable key parsed as a
// base-10 integer. Surrounding whitespace is ignored; an unset, empty or
// unparsable value yields defaultValue.
func GetIntOrDefaultFromEnv(key string, defaultValue int) int {
	num, err := strconv.Atoi(strings.TrimSpace(os.Getenv(key)))
	if err != nil {
		return defaultValue
	}
	return num
}

// GetStringOrDefaultFromEnv returns the whitespace-trimmed value of the
// environment variable key, or defaultValue when it is unset or blank.
func GetStringOrDefaultFromEnv(key string, defaultValue string) string {
	if value := strings.TrimSpace(os.Getenv(key)); value != "" {
		return value
	}
	return defaultValue
}

// GetBoolOrDefaultFromEnv interprets the environment variable key as a
// boolean. An unset or blank value yields defaultValue. "true", "on", "yes"
// and "1" (case-insensitive) yield true. Any other non-empty value yields
// false, regardless of defaultValue.
func GetBoolOrDefaultFromEnv(key string, defaultValue bool) bool {
	value := strings.TrimSpace(os.Getenv(key))
	if value == "" {
		return defaultValue
	}

	switch strings.ToLower(value) {
	case "true", "on", "yes", "1":
		return true
	default:
		return false
	}
}

// IsValidIPAddress reports whether ip is a textual IPv4 or IPv6 address.
func IsValidIPAddress(ip string) bool {
	return net.ParseIP(ip) != nil
}
envKey = "TEST_INT_ENV_VAR" 13 | 14 | t.Run("ReturnsDefaultValueForUnset", func(t *testing.T) { 15 | os.Unsetenv(envKey) 16 | if got := fn.GetIntOrDefaultFromEnv(envKey, defaultVal); got != defaultVal { 17 | t.Errorf("Expected default value %d, got %d", defaultVal, got) 18 | } 19 | }) 20 | 21 | t.Run("ReturnsParsedValue", func(t *testing.T) { 22 | expected := 42 23 | os.Setenv(envKey, "42") 24 | defer os.Unsetenv(envKey) 25 | if got := fn.GetIntOrDefaultFromEnv(envKey, defaultVal); got != expected { 26 | t.Errorf("Expected parsed value %d, got %d", expected, got) 27 | } 28 | }) 29 | 30 | t.Run("IgnoresInvalidValue", func(t *testing.T) { 31 | os.Setenv(envKey, "invalid") 32 | defer os.Unsetenv(envKey) 33 | if got := fn.GetIntOrDefaultFromEnv(envKey, defaultVal); got != defaultVal { 34 | t.Errorf("Expected default value %d when variable is invalid, got %d", defaultVal, got) 35 | } 36 | }) 37 | } 38 | 39 | func TestGetStringOrDefaultFromEnv(t *testing.T) { 40 | const defaultVal = "default" 41 | const envKey = "TEST_STRING_ENV_VAR" 42 | 43 | t.Run("ReturnsDefaultValueForUnset", func(t *testing.T) { 44 | os.Unsetenv(envKey) 45 | if got := fn.GetStringOrDefaultFromEnv(envKey, defaultVal); got != defaultVal { 46 | t.Errorf("Expected default value %s, got %s", defaultVal, got) 47 | } 48 | }) 49 | 50 | t.Run("ReturnsNonEmptyValue", func(t *testing.T) { 51 | expected := "test value" 52 | os.Setenv(envKey, expected) 53 | defer os.Unsetenv(envKey) 54 | if got := fn.GetStringOrDefaultFromEnv(envKey, defaultVal); got != expected { 55 | t.Errorf("Expected non-empty value %s, got %s", expected, got) 56 | } 57 | }) 58 | 59 | t.Run("TrimsWhitespace", func(t *testing.T) { 60 | expected := "test value" 61 | os.Setenv(envKey, " "+expected+" ") 62 | defer os.Unsetenv(envKey) 63 | if got := fn.GetStringOrDefaultFromEnv(envKey, defaultVal); got != expected { 64 | t.Errorf("Expected trimmed value %s, got %s", expected, got) 65 | } 66 | }) 67 | } 68 | 69 | func TestIsValidIPAddress(t 
*testing.T) { 70 | testCases := []struct { 71 | ip string 72 | valid bool 73 | }{ 74 | {"192.168.1.1", true}, 75 | {"255.255.255.255", true}, 76 | {"0.0.0.0", true}, 77 | {"256.1.1.1", false}, 78 | {"192.168.1", false}, 79 | {"not an ip", false}, 80 | {"::1", true}, // IPv6 81 | } 82 | 83 | for _, tc := range testCases { 84 | t.Run(tc.ip, func(t *testing.T) { 85 | if got := fn.IsValidIPAddress(tc.ip); got != tc.valid { 86 | t.Errorf("IsValidIPAddress(%q) = %v; want %v", tc.ip, got, tc.valid) 87 | } 88 | }) 89 | } 90 | } 91 | 92 | func TestGetBoolOrDefaultFromEnv(t *testing.T) { 93 | const envKey = "TEST_BOOL_ENV_VAR" 94 | 95 | t.Run("ReturnsDefaultValueForUnset", func(t *testing.T) { 96 | os.Unsetenv(envKey) 97 | if got := fn.GetBoolOrDefaultFromEnv(envKey, false); got != false { 98 | t.Errorf("Expected default value %v, got %v", false, got) 99 | } 100 | 101 | os.Unsetenv(envKey) 102 | if got := fn.GetBoolOrDefaultFromEnv(envKey, true); got != true { 103 | t.Errorf("Expected default value %v, got %v", true, got) 104 | } 105 | }) 106 | 107 | t.Run("test on", func(t *testing.T) { 108 | expected := "on" 109 | os.Setenv(envKey, expected) 110 | defer os.Unsetenv(envKey) 111 | if got := fn.GetBoolOrDefaultFromEnv(envKey, false); got != true { 112 | t.Errorf("Expected non-empty value %v, got %v", expected, got) 113 | } 114 | 115 | os.Setenv(envKey, expected) 116 | defer os.Unsetenv(envKey) 117 | if got := fn.GetBoolOrDefaultFromEnv(envKey, true); got != true { 118 | t.Errorf("Expected non-empty value %v, got %v", expected, got) 119 | } 120 | }) 121 | 122 | t.Run("test true", func(t *testing.T) { 123 | expected := "true" 124 | os.Setenv(envKey, expected) 125 | defer os.Unsetenv(envKey) 126 | if got := fn.GetBoolOrDefaultFromEnv(envKey, false); got != true { 127 | t.Errorf("Expected non-empty value %v, got %v", expected, got) 128 | } 129 | 130 | os.Setenv(envKey, expected) 131 | defer os.Unsetenv(envKey) 132 | if got := fn.GetBoolOrDefaultFromEnv(envKey, true); got != true 
{ 133 | t.Errorf("Expected non-empty value %v, got %v", expected, got) 134 | } 135 | }) 136 | 137 | t.Run("test 1", func(t *testing.T) { 138 | expected := "1" 139 | os.Setenv(envKey, expected) 140 | defer os.Unsetenv(envKey) 141 | if got := fn.GetBoolOrDefaultFromEnv(envKey, false); got != true { 142 | t.Errorf("Expected non-empty value %v, got %v", expected, got) 143 | } 144 | 145 | os.Setenv(envKey, expected) 146 | defer os.Unsetenv(envKey) 147 | if got := fn.GetBoolOrDefaultFromEnv(envKey, true); got != true { 148 | t.Errorf("Expected non-empty value %v, got %v", expected, got) 149 | } 150 | }) 151 | 152 | t.Run("test yes", func(t *testing.T) { 153 | expected := "yes" 154 | os.Setenv(envKey, expected) 155 | defer os.Unsetenv(envKey) 156 | if got := fn.GetBoolOrDefaultFromEnv(envKey, false); got != true { 157 | t.Errorf("Expected non-empty value %v, got %v", expected, got) 158 | } 159 | 160 | os.Setenv(envKey, expected) 161 | defer os.Unsetenv(envKey) 162 | if got := fn.GetBoolOrDefaultFromEnv(envKey, true); got != true { 163 | t.Errorf("Expected non-empty value %v, got %v", expected, got) 164 | } 165 | }) 166 | 167 | t.Run("test 0", func(t *testing.T) { 168 | expected := "0" 169 | os.Setenv(envKey, expected) 170 | defer os.Unsetenv(envKey) 171 | if got := fn.GetBoolOrDefaultFromEnv(envKey, false); got != false { 172 | t.Errorf("Expected non-empty value %v, got %v", expected, got) 173 | } 174 | 175 | os.Setenv(envKey, expected) 176 | defer os.Unsetenv(envKey) 177 | if got := fn.GetBoolOrDefaultFromEnv(envKey, true); got != false { 178 | t.Errorf("Expected non-empty value %v, got %v", expected, got) 179 | } 180 | }) 181 | 182 | } 183 | -------------------------------------------------------------------------------- /internal/fn/gunzip.go: -------------------------------------------------------------------------------- 1 | package fn 2 | 3 | import ( 4 | "compress/gzip" 5 | "io" 6 | ) 7 | 8 | func Gunzip(r io.Reader) (reader io.ReadCloser, err error) { 9 | reader, 
err = gzip.NewReader(r) 10 | if err != nil { 11 | return nil, err 12 | } 13 | defer reader.Close() 14 | return reader, nil 15 | } 16 | -------------------------------------------------------------------------------- /internal/fn/models.go: -------------------------------------------------------------------------------- 1 | package fn 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/soulteary/amazing-openai-api/internal/define" 7 | ) 8 | 9 | func ExtractModelAlias(alias string) define.ModelAlias { 10 | var result define.ModelAlias 11 | if alias == "" { 12 | return result 13 | } 14 | pairs := strings.Split(alias, ",") 15 | for _, pair := range pairs { 16 | alias := strings.Split(pair, ":") 17 | if len(alias) != 2 { 18 | continue 19 | } 20 | result = append(result, alias) 21 | } 22 | return result 23 | } 24 | -------------------------------------------------------------------------------- /internal/fn/models_test.go: -------------------------------------------------------------------------------- 1 | package fn 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/soulteary/amazing-openai-api/internal/define" 8 | ) 9 | 10 | func TestExtractModelAlias(t *testing.T) { 11 | var result define.ModelAlias 12 | 13 | tests := []struct { 14 | name string 15 | alias string 16 | expected define.ModelAlias 17 | }{ 18 | { 19 | name: "empty string", 20 | alias: "", 21 | expected: result, 22 | }, 23 | { 24 | name: "single valid alias pair", 25 | alias: "key1:value1", 26 | expected: define.ModelAlias{{"key1", "value1"}}, 27 | }, 28 | { 29 | name: "multiple valid alias pairs", 30 | alias: "key1:value1,key2:value2", 31 | expected: define.ModelAlias{{"key1", "value1"}, {"key2", "value2"}}, 32 | }, 33 | { 34 | name: "invalid alias pair", 35 | alias: "singleword", 36 | expected: result, 37 | }, 38 | { 39 | name: "mixed valid and invalid alias pairs", 40 | alias: "key1:value1,singleword,key2:value2", 41 | expected: define.ModelAlias{{"key1", "value1"}, {"key2", "value2"}}, 42 
| }, 43 | { 44 | name: "valid alias with extra colon", 45 | alias: "key1:value1:extra", 46 | expected: result, 47 | }, 48 | } 49 | 50 | for _, tt := range tests { 51 | t.Run(tt.name, func(t *testing.T) { 52 | result := ExtractModelAlias(tt.alias) 53 | if !reflect.DeepEqual(result, tt.expected) { 54 | t.Errorf("ExtractModelAlias(%q) = %v, want %v", tt.alias, result, tt.expected) 55 | } 56 | }) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /internal/model/flags.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | type Flags struct { 4 | DebugMode bool 5 | ShowVersion bool 6 | ShowHelp bool 7 | 8 | Type string 9 | Vision bool 10 | Port int 11 | Host string 12 | } 13 | -------------------------------------------------------------------------------- /internal/network/http_proxy.go: -------------------------------------------------------------------------------- 1 | package network 2 | 3 | import ( 4 | "context" 5 | "encoding/base64" 6 | "fmt" 7 | "net" 8 | "net/http" 9 | "net/url" 10 | 11 | "golang.org/x/net/proxy" 12 | ) 13 | 14 | func NewProxyFromEnv(socksProxy string, httpProxy string) (*http.Transport, error) { 15 | if socksProxy != "" { 16 | return NewSocksProxy(socksProxy) 17 | } 18 | 19 | if httpProxy != "" { 20 | return NewHttpProxy(httpProxy) 21 | } 22 | return nil, nil 23 | } 24 | 25 | func NewHttpProxy(proxyAddress string) (*http.Transport, error) { 26 | proxyURL, err := url.Parse(proxyAddress) 27 | if err != nil { 28 | return nil, fmt.Errorf("error parsing proxy URL: %v", err) 29 | } 30 | 31 | transport := &http.Transport{ 32 | Proxy: http.ProxyURL(proxyURL), 33 | } 34 | 35 | if proxyURL.User != nil { 36 | proxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(proxyURL.User.String())) 37 | 38 | transport.ProxyConnectHeader = http.Header{ 39 | "Proxy-Authorization": []string{proxyAuth}, 40 | } 41 | } 42 | return transport, nil 43 | } 44 | 
45 | func NewSocksProxy(proxyAddress string) (*http.Transport, error) { 46 | // proxyAddress: socks5://user:password@127.0.0.1:1080 47 | proxyURL, err := url.Parse(proxyAddress) 48 | if err != nil { 49 | return nil, fmt.Errorf("error parsing proxy URL: %v", err) 50 | } 51 | 52 | dialer, err := proxy.FromURL(proxyURL, proxy.Direct) 53 | if err != nil { 54 | return nil, fmt.Errorf("error creating proxy dialer: %v", err) 55 | } 56 | 57 | transport := &http.Transport{ 58 | DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { 59 | return dialer.Dial(network, address) 60 | }, 61 | } 62 | return transport, nil 63 | } 64 | -------------------------------------------------------------------------------- /internal/network/http_proxy_test.go: -------------------------------------------------------------------------------- 1 | package network 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestHttpProxy(t *testing.T) { 11 | proxyAddress := "http://127.0.0.1:1087" 12 | transport, err := NewHttpProxy(proxyAddress) 13 | 14 | assert.NoError(t, err) 15 | assert.NotNil(t, transport) 16 | 17 | client := &http.Client{ 18 | Transport: transport, 19 | } 20 | 21 | resp, err := client.Get("https://www.google.com") 22 | assert.NoError(t, err) 23 | assert.NotNil(t, resp) 24 | assert.Equal(t, 200, resp.StatusCode) 25 | } 26 | 27 | func TestSocksProxy(t *testing.T) { 28 | proxyAddress := "socks5://127.0.0.1:1080" 29 | transport, err := NewSocksProxy(proxyAddress) 30 | 31 | assert.NoError(t, err) 32 | assert.NotNil(t, transport) 33 | 34 | client := &http.Client{ 35 | Transport: transport, 36 | } 37 | 38 | resp, err := client.Get("https://www.google.com") 39 | assert.NoError(t, err) 40 | assert.NotNil(t, resp) 41 | assert.Equal(t, 200, resp.StatusCode) 42 | } 43 | -------------------------------------------------------------------------------- /internal/network/response_err.go: 
-------------------------------------------------------------------------------- 1 | package network 2 | 3 | import ( 4 | "github.com/gin-gonic/gin" 5 | ) 6 | 7 | type ApiResponse struct { 8 | Error ErrorDescription `json:"error"` 9 | } 10 | 11 | type ErrorDescription struct { 12 | Code string `json:"code"` 13 | Message string `json:"message"` 14 | } 15 | 16 | func SendError(c *gin.Context, err error) { 17 | c.JSON(500, ApiResponse{ 18 | Error: ErrorDescription{ 19 | Code: "500", 20 | Message: err.Error(), 21 | }, 22 | }) 23 | } 24 | -------------------------------------------------------------------------------- /internal/network/response_err_test.go: -------------------------------------------------------------------------------- 1 | package network_test 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/gin-gonic/gin" 10 | "github.com/soulteary/amazing-openai-api/internal/network" 11 | ) 12 | 13 | func TestSendError(t *testing.T) { 14 | gin.SetMode(gin.TestMode) 15 | w := httptest.NewRecorder() 16 | c, _ := gin.CreateTestContext(w) 17 | 18 | testError := errors.New("internal server error") 19 | network.SendError(c, testError) 20 | 21 | if w.Code != 500 { 22 | t.Errorf("Expected status code 500, got %d", w.Code) 23 | } 24 | 25 | var apiResponse network.ApiResponse 26 | err := json.Unmarshal(w.Body.Bytes(), &apiResponse) 27 | if err != nil { 28 | t.Fatalf("Error unmarshalling response: %v", err) 29 | } 30 | 31 | if apiResponse.Error.Code != "500" { 32 | t.Errorf("Expected error code '500', got '%s'", apiResponse.Error.Code) 33 | } 34 | 35 | expectedErrorMessage := testError.Error() 36 | if apiResponse.Error.Message != expectedErrorMessage { 37 | t.Errorf("Expected error message '%s', got '%s'", expectedErrorMessage, apiResponse.Error.Message) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /internal/router/misc.go: 
-------------------------------------------------------------------------------- 1 | package router 2 | 3 | import "github.com/gin-gonic/gin" 4 | 5 | func Hi(c *gin.Context) { 6 | c.Status(200) 7 | } 8 | 9 | func RegisterMiscRoute(r *gin.Engine) { 10 | r.GET("/", Hi) 11 | r.GET("/health", Hi) 12 | r.GET("/ping", Hi) 13 | } 14 | -------------------------------------------------------------------------------- /internal/router/router.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "github.com/gin-gonic/gin" 5 | "github.com/soulteary/amazing-openai-api/models/azure" 6 | "github.com/soulteary/amazing-openai-api/models/gemini" 7 | "github.com/soulteary/amazing-openai-api/models/yi" 8 | ) 9 | 10 | func RegisterModelRoute(r *gin.Engine, serviceType string) { 11 | // https://platform.openai.com/docs/api-reference 12 | apiBase := "/v1" 13 | 14 | switch serviceType { 15 | case "azure": 16 | stripPrefixConverter := azure.NewStripPrefixConverter(apiBase) 17 | r.GET(stripPrefixConverter.Prefix+"/models", azure.ModelProxy) 18 | apiBasedRouter := r.Group(apiBase) 19 | { 20 | apiBasedRouter.Any("/completions", azure.ProxyWithConverter(stripPrefixConverter)) 21 | apiBasedRouter.Any("/chat/completions", azure.ProxyWithConverter(stripPrefixConverter)) 22 | } 23 | case "yi": 24 | stripPrefixConverter := yi.NewStripPrefixConverter(apiBase) 25 | apiBasedRouter := r.Group(apiBase) 26 | { 27 | apiBasedRouter.Any("/completions", yi.ProxyWithConverter(stripPrefixConverter)) 28 | apiBasedRouter.Any("/chat/completions", yi.ProxyWithConverter(stripPrefixConverter)) 29 | } 30 | case "gemini": 31 | stripPrefixConverter := gemini.NewStripPrefixConverter(apiBase) 32 | apiBasedRouter := r.Group(apiBase) 33 | { 34 | apiBasedRouter.Any("/completions", gemini.ProxyWithConverter(stripPrefixConverter)) 35 | apiBasedRouter.Any("/chat/completions", gemini.ProxyWithConverter(stripPrefixConverter)) 36 | } 37 | } 38 | } 39 | 
-------------------------------------------------------------------------------- /internal/version/version.go: -------------------------------------------------------------------------------- 1 | package version 2 | 3 | var ( 4 | Version = "" 5 | BuildDate = "" 6 | GitCommit = "" 7 | ) 8 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/soulteary/amazing-openai-api/cmd" 5 | ) 6 | 7 | // refs: https://github.com/soulteary/flare/blob/main/main.go 8 | func main() { 9 | cmd.Parse() 10 | } 11 | -------------------------------------------------------------------------------- /models/azure/azure.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | import ( 4 | "fmt" 5 | "net/url" 6 | "strings" 7 | 8 | "github.com/soulteary/amazing-openai-api/internal/define" 9 | "github.com/soulteary/amazing-openai-api/internal/fn" 10 | ) 11 | 12 | var ( 13 | ModelConfig = map[string]define.ModelConfig{} 14 | ) 15 | 16 | func Init() (err error) { 17 | var modelConfig define.ModelConfig 18 | 19 | // azure openai api endpoint 20 | endpoint := fn.GetStringOrDefaultFromEnv(ENV_AZURE_ENDPOINT, "") 21 | if endpoint == "" { 22 | return fmt.Errorf("missing environment variable %s", ENV_AZURE_ENDPOINT) 23 | } 24 | // Use a URL starting with `https://` and ending with `.openai.azure.com/` 25 | if !(strings.HasPrefix(endpoint, "https://") && strings.HasSuffix(endpoint, ".openai.azure.com/")) { 26 | return fmt.Errorf("invalid environment variable %s", ENV_AZURE_ENDPOINT) 27 | } 28 | u, err := url.Parse(endpoint) 29 | if err != nil { 30 | return fmt.Errorf("parse endpoint error: %w", err) 31 | } 32 | modelConfig.URL = u 33 | modelConfig.Endpoint = endpoint 34 | 35 | // azure openai api version 36 | apiVersion := fn.GetStringOrDefaultFromEnv(ENV_AZURE_API_VER, 
DEFAULT_AZURE_API_VER) 37 | 38 | ignoreAPIVersionCheck := fn.GetBoolOrDefaultFromEnv(ENV_IGNORE_API_VERSION_CHECK, false) 39 | if !ignoreAPIVersionCheck { 40 | // azure openai api versions supported 41 | // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference 42 | if apiVersion != "2022-12-01" && 43 | apiVersion != "2023-03-15-preview" && 44 | apiVersion != "2023-05-15" && 45 | apiVersion != "2023-06-01-preview" && 46 | apiVersion != "2023-07-01-preview" && 47 | apiVersion != "2023-08-01-preview" && 48 | apiVersion != "2023-09-01-preview" && 49 | apiVersion != "2023-12-01-preview" && 50 | apiVersion != "2024-02-15-preview" { 51 | apiVersion = DEFAULT_AZURE_API_VER 52 | } 53 | } 54 | modelConfig.Version = apiVersion 55 | 56 | // azure openai api key, allow override by request header 57 | apikey := fn.GetStringOrDefaultFromEnv(ENV_AZURE_API_KEY, "") 58 | modelConfig.Key = apikey 59 | 60 | // azure openai api model 61 | model := fn.GetStringOrDefaultFromEnv(ENV_AZURE_MODEL, DEFAULT_AZURE_MODEL) 62 | if model == "" { 63 | model = DEFAULT_AZURE_MODEL 64 | } 65 | modelConfig.Model = model 66 | 67 | modelConfig.Vision = fn.GetBoolOrDefaultFromEnv(ENV_AZURE_VISION, false) 68 | 69 | ModelConfig[model] = modelConfig 70 | 71 | // azure openai api model alias 72 | alias := fn.ExtractModelAlias(fn.GetStringOrDefaultFromEnv(ENV_AZURE_MODEL_ALIAS, "")) 73 | for _, pair := range alias { 74 | modelConfig.Alias = pair[1] 75 | ModelConfig[pair[0]] = modelConfig 76 | } 77 | return nil 78 | } 79 | -------------------------------------------------------------------------------- /models/azure/azure_test.go: -------------------------------------------------------------------------------- 1 | package azure_test 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "testing" 7 | 8 | "github.com/soulteary/amazing-openai-api/models/azure" 9 | ) 10 | 11 | // Helper function to set environment variables for testing. 
12 | func setEnv(envMap map[string]string) error { 13 | for key, value := range envMap { 14 | err := os.Setenv(key, value) 15 | if err != nil { 16 | return err 17 | } 18 | } 19 | return nil 20 | } 21 | 22 | // Helper function to unset environment variables after testing. 23 | func unsetEnv(keys []string) { 24 | for _, key := range keys { 25 | os.Unsetenv(key) 26 | } 27 | } 28 | 29 | // TestInitMissingEndpoint tests if Init returns an error when ENV_AZURE_ENDPOINT is missing. 30 | func TestInitMissingEndpoint(t *testing.T) { 31 | unsetEnv([]string{azure.ENV_AZURE_ENDPOINT}) // Ensure the environment variable is not set. 32 | 33 | err := azure.Init() 34 | if err == nil || err.Error() != fmt.Sprintf("missing environment variable %s", azure.ENV_AZURE_ENDPOINT) { 35 | t.Errorf("Expected missing endpoint error, got %v", err) 36 | } 37 | } 38 | 39 | // TestInitInvalidEndpoint tests if Init returns an error when ENV_AZURE_ENDPOINT is invalid. 40 | func TestInitInvalidEndpoint(t *testing.T) { 41 | envMap := map[string]string{ 42 | azure.ENV_AZURE_ENDPOINT: "http://invalid-endpoint", // Invalid schema or format 43 | } 44 | setEnv(envMap) 45 | defer unsetEnv([]string{azure.ENV_AZURE_ENDPOINT}) 46 | 47 | err := azure.Init() 48 | if err == nil || err.Error() != fmt.Sprintf("invalid environment variable %s", azure.ENV_AZURE_ENDPOINT) { 49 | t.Errorf("Expected invalid endpoint error, got %v", err) 50 | } 51 | } 52 | 53 | // TestInitUnsupportedVersion tests if Init sets the default version when an unsupported version is passed. 
54 | func TestInitUnsupportedVersion(t *testing.T) { 55 | envMap := map[string]string{ 56 | azure.ENV_AZURE_ENDPOINT: "https://valid-endpoint.openai.azure.com/", 57 | azure.ENV_AZURE_API_VER: "unsupported-version", 58 | } 59 | setEnv(envMap) 60 | defer unsetEnv([]string{azure.ENV_AZURE_ENDPOINT, azure.ENV_AZURE_API_VER}) 61 | 62 | err := azure.Init() 63 | if err != nil { 64 | t.Fatalf("Init failed with error: %v", err) 65 | } 66 | if azure.ModelConfig[azure.DEFAULT_AZURE_MODEL].Version != azure.DEFAULT_AZURE_API_VER { 67 | t.Errorf("Expected version to be set to default, got %v", azure.ModelConfig[azure.DEFAULT_AZURE_MODEL].Version) 68 | } 69 | } 70 | 71 | // TestInitSuccess tests if Init successfully initializes ModelConfig with the right values. 72 | func TestInitSuccess(t *testing.T) { 73 | envMap := map[string]string{ 74 | azure.ENV_AZURE_ENDPOINT: "https://valid-endpoint.openai.azure.com/", 75 | azure.ENV_AZURE_API_VER: "2023-03-15-preview", 76 | azure.ENV_AZURE_API_KEY: "test-api-key", 77 | azure.ENV_AZURE_MODEL: "test-model", 78 | azure.ENV_AZURE_MODEL_ALIAS: "alias1:test-model-alias", 79 | } 80 | setEnv(envMap) 81 | defer unsetEnv([]string{azure.ENV_AZURE_ENDPOINT, azure.ENV_AZURE_API_VER, azure.ENV_AZURE_API_KEY, azure.ENV_AZURE_MODEL, azure.ENV_AZURE_MODEL_ALIAS}) 82 | 83 | err := azure.Init() 84 | if err != nil { 85 | t.Fatalf("Init failed with error: %v", err) 86 | } 87 | 88 | modelConfig, ok := azure.ModelConfig["test-model"] 89 | if !ok { 90 | t.Fatalf("Model 'test-model' not found in ModelConfig") 91 | } 92 | 93 | if modelConfig.Endpoint != "https://valid-endpoint.openai.azure.com/" { 94 | t.Errorf("Expected endpoint to match, got %v", modelConfig.Endpoint) 95 | } 96 | if modelConfig.Version != "2023-03-15-preview" { 97 | t.Errorf("Expected API version to match, got %v", modelConfig.Version) 98 | } 99 | 100 | if modelConfig.Model != "test-model" { // The alias should override the original model name. 
101 | t.Errorf("Expected model to use alias, got %v", modelConfig.Model) 102 | } 103 | 104 | config, ok := azure.ModelConfig["alias1"] 105 | if !ok { 106 | t.Fatalf("Model 'alias1' not found in ModelConfig") 107 | } 108 | if config.Alias != "test-model-alias" { 109 | t.Errorf("Expected model to match, got %v", config.Alias) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /models/azure/define.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | const ( 4 | ENV_AZURE_ENDPOINT = "AZURE_ENDPOINT" 5 | ENV_AZURE_API_VER = "AZURE_API_VER" 6 | ENV_AZURE_MODEL_ALIAS = "AZURE_MODEL_ALIAS" 7 | ENV_AZURE_API_KEY = "AZURE_API_KEY" 8 | ENV_AZURE_MODEL = "AZURE_MODEL" 9 | ENV_AZURE_VISION = "AZURE_VISION" 10 | 11 | ENV_IGNORE_API_VERSION_CHECK = "AZURE_IGNORE_API_VERSION_CHECK" 12 | 13 | ENV_AZURE_HTTP_PROXY = "AZURE_HTTP_PROXY" 14 | ENV_AZURE_SOCKS_PROXY = "AZURE_SOCKS_PROXY" 15 | ) 16 | 17 | const ( 18 | DEFAULT_AZURE_API_VER = "2023-05-15" 19 | DEFAULT_AZURE_MODEL = "gpt-3.5-turbo" 20 | ) 21 | -------------------------------------------------------------------------------- /models/azure/model.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "path" 7 | "strings" 8 | 9 | "github.com/soulteary/amazing-openai-api/internal/define" 10 | ) 11 | 12 | type RequestConverter interface { 13 | Name() string 14 | Convert(req *http.Request, config *define.ModelConfig) (*http.Request, error) 15 | } 16 | 17 | type StripPrefixConverter struct { 18 | Prefix string 19 | } 20 | 21 | func (c *StripPrefixConverter) Name() string { 22 | return "StripPrefix" 23 | } 24 | 25 | func (c *StripPrefixConverter) Convert(req *http.Request, config *define.ModelConfig) (*http.Request, error) { 26 | req.Host = config.URL.Host 27 | req.URL.Scheme = config.URL.Scheme 28 | req.URL.Host = config.URL.Host 29 
| req.URL.Path = path.Join(fmt.Sprintf("/openai/deployments/%s", config.Model), strings.Replace(req.URL.Path, c.Prefix+"/", "/", 1)) 30 | req.URL.RawPath = req.URL.EscapedPath() 31 | 32 | query := req.URL.Query() 33 | query.Add(HeaderAPIVer, config.Version) 34 | req.URL.RawQuery = query.Encode() 35 | return req, nil 36 | } 37 | 38 | func NewStripPrefixConverter(prefix string) *StripPrefixConverter { 39 | return &StripPrefixConverter{ 40 | Prefix: prefix, 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /models/azure/model_test.go: -------------------------------------------------------------------------------- 1 | package azure_test 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | 7 | "testing" 8 | 9 | "github.com/soulteary/amazing-openai-api/internal/define" 10 | "github.com/soulteary/amazing-openai-api/models/azure" 11 | ) 12 | 13 | func TestStripPrefixConverter_Convert(t *testing.T) { 14 | prefix := "/api/v1" 15 | converter := azure.NewStripPrefixConverter(prefix) 16 | 17 | u, _ := url.Parse("https://example.com") 18 | 19 | modelConfig := &define.ModelConfig{ 20 | Model: "test-model", 21 | Version: "2023-04-01", 22 | URL: u, 23 | } 24 | 25 | reqURL, _ := url.Parse("http://localhost:8080/api/v1/model/predict") 26 | req := &http.Request{ 27 | URL: reqURL, 28 | Header: http.Header{}, 29 | } 30 | 31 | convertedReq, err := converter.Convert(req, modelConfig) 32 | if err != nil { 33 | t.Fatalf("Convert failed with error: %v", err) 34 | } 35 | 36 | expectedPath := "/openai/deployments/test-model/model/predict" 37 | if convertedReq.URL.Path != expectedPath { 38 | t.Errorf("Expected path '%s', but got '%s'", expectedPath, convertedReq.URL.Path) 39 | } 40 | 41 | if convertedReq.URL.Host != modelConfig.URL.Host { 42 | t.Errorf("Expected host '%s', but got '%s'", modelConfig.URL.Host, convertedReq.URL.Host) 43 | } 44 | 45 | if convertedReq.URL.Scheme != modelConfig.URL.Scheme { 46 | t.Errorf("Expected scheme '%s', but got 
'%s'", modelConfig.URL.Scheme, convertedReq.URL.Scheme) 47 | } 48 | 49 | expectedVersion := modelConfig.Version 50 | queryValues := convertedReq.URL.Query() 51 | if queryValues.Get(azure.HeaderAPIVer) != expectedVersion { 52 | t.Errorf("Expected API version query parameter '%s', but got '%s'", expectedVersion, queryValues.Get(azure.HeaderAPIVer)) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /models/azure/proxy.go: -------------------------------------------------------------------------------- 1 | package azure 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "log" 9 | "net/http" 10 | "net/http/httputil" 11 | "strings" 12 | 13 | "github.com/soulteary/amazing-openai-api/internal/define" 14 | "github.com/soulteary/amazing-openai-api/internal/fn" 15 | "github.com/soulteary/amazing-openai-api/internal/network" 16 | 17 | "github.com/gin-gonic/gin" 18 | "github.com/pkg/errors" 19 | ) 20 | 21 | const ( 22 | HeaderAuthKey = "api-key" 23 | HeaderAPIVer = "api-version" 24 | ) 25 | 26 | func ProxyWithConverter(requestConverter RequestConverter) gin.HandlerFunc { 27 | return func(c *gin.Context) { 28 | if c.Request.Method == http.MethodOptions { 29 | c.Header("Access-Control-Allow-Origin", "*") 30 | c.Header("Access-Control-Allow-Methods", "GET, OPTIONS, POST") 31 | c.Header("Access-Control-Allow-Headers", "Authorization, Content-Type, x-requested-with") 32 | c.Status(200) 33 | return 34 | } 35 | Proxy(c, requestConverter) 36 | } 37 | } 38 | 39 | type DeploymentInfo struct { 40 | Data []map[string]interface{} `json:"data"` 41 | Object string `json:"object"` 42 | } 43 | 44 | func ModelProxy(c *gin.Context) { 45 | // Create a channel to receive the results of each request 46 | results := make(chan []map[string]interface{}, len(ModelConfig)) 47 | 48 | // Send a request for each deployment in the map 49 | for _, deployment := range ModelConfig { 50 | go func(deployment define.ModelConfig) { 51 | // 
Create the request 52 | req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/openai/deployments?api-version=%s", deployment.Endpoint, deployment.Version), nil) 53 | if err != nil { 54 | log.Printf("error parsing response body for deployment %s: %v", deployment.Name, err) 55 | results <- nil 56 | return 57 | } 58 | 59 | // Set the auth header 60 | req.Header.Set(HeaderAuthKey, deployment.Key) 61 | 62 | // Send the request 63 | client := &http.Client{} 64 | resp, err := client.Do(req) 65 | if err != nil { 66 | log.Printf("error sending request for deployment %s: %v", deployment.Name, err) 67 | results <- nil 68 | return 69 | } 70 | defer resp.Body.Close() 71 | if resp.StatusCode != http.StatusOK { 72 | log.Printf("unexpected status code %d for deployment %s", resp.StatusCode, deployment.Name) 73 | results <- nil 74 | return 75 | } 76 | 77 | // Read the response body 78 | body, err := io.ReadAll(resp.Body) 79 | if err != nil { 80 | log.Printf("error reading response body for deployment %s: %v", deployment.Name, err) 81 | results <- nil 82 | return 83 | } 84 | 85 | // Parse the response body as JSON 86 | var deplotmentInfo DeploymentInfo 87 | err = json.Unmarshal(body, &deplotmentInfo) 88 | if err != nil { 89 | log.Printf("error parsing response body for deployment %s: %v", deployment.Name, err) 90 | results <- nil 91 | return 92 | } 93 | results <- deplotmentInfo.Data 94 | }(deployment) 95 | } 96 | 97 | // Wait for all requests to finish and collect the results 98 | var allResults []map[string]interface{} 99 | for i := 0; i < len(ModelConfig); i++ { 100 | result := <-results 101 | if result != nil { 102 | allResults = append(allResults, result...) 
103 | } 104 | } 105 | var info = DeploymentInfo{Data: allResults, Object: "list"} 106 | combinedResults, err := json.Marshal(info) 107 | if err != nil { 108 | log.Printf("error marshalling results: %v", err) 109 | network.SendError(c, err) 110 | return 111 | } 112 | 113 | // Set the response headers and body 114 | c.Header("Content-Type", "application/json") 115 | c.String(http.StatusOK, string(combinedResults)) 116 | } 117 | 118 | // Proxy Azure OpenAI 119 | func Proxy(c *gin.Context, requestConverter RequestConverter) { 120 | // preserve request body for error logging 121 | var buf bytes.Buffer 122 | tee := io.TeeReader(c.Request.Body, &buf) 123 | bodyBytes, err := io.ReadAll(tee) 124 | if err != nil { 125 | log.Printf("Error reading request body: %v", err) 126 | return 127 | } 128 | c.Request.Body = io.NopCloser(&buf) 129 | 130 | director := func(req *http.Request) { 131 | if req.Body == nil { 132 | network.SendError(c, errors.New("request body is empty")) 133 | return 134 | } 135 | 136 | // extract model from request url 137 | model := c.Param("model") 138 | if model == "" { 139 | // extract model from request body 140 | body, err := io.ReadAll(req.Body) 141 | defer req.Body.Close() 142 | if err != nil { 143 | network.SendError(c, errors.Wrap(err, "read request body error")) 144 | return 145 | } 146 | 147 | var modelPayload define.OpenAI_Payload_Model 148 | err = json.Unmarshal(body, &modelPayload) 149 | if err != nil { 150 | network.SendError(c, errors.Wrap(err, "parse model payload error")) 151 | return 152 | } 153 | 154 | model = modelPayload.Model 155 | model := strings.TrimSpace(modelPayload.Model) 156 | if model == "" { 157 | model = DEFAULT_AZURE_MODEL 158 | } 159 | 160 | config, ok := ModelConfig[model] 161 | if ok { 162 | fmt.Println("rewrite model ", model, "to", config.Model) 163 | if !config.Vision { 164 | var payload define.OpenAI_Payload 165 | err = json.Unmarshal(body, &payload) 166 | if err != nil { 167 | network.SendError(c, errors.Wrap(err, 
"parse payload error")) 168 | return 169 | } 170 | 171 | payload.Model = config.Model 172 | 173 | repack, err := json.Marshal(payload) 174 | if err != nil { 175 | network.SendError(c, errors.Wrap(err, "repack payload error")) 176 | return 177 | } 178 | body = repack 179 | } else { 180 | var visionPayload define.OpenAI_Vision_Payload 181 | err = json.Unmarshal(body, &visionPayload) 182 | if err != nil { 183 | network.SendError(c, errors.Wrap(err, "parse vision payload error")) 184 | return 185 | } 186 | visionPayload.Model = config.Model 187 | 188 | repack, err := json.Marshal(visionPayload) 189 | if err != nil { 190 | network.SendError(c, errors.Wrap(err, "repack vision payload error")) 191 | return 192 | } 193 | body = repack 194 | } 195 | } 196 | 197 | req.Body = io.NopCloser(bytes.NewBuffer(body)) 198 | req.ContentLength = int64(len(body)) 199 | } 200 | 201 | // get deployment from request 202 | deployment, err := GetDeploymentByModel(model) 203 | if err != nil { 204 | network.SendError(c, err) 205 | return 206 | } 207 | 208 | // get auth token from header or deployemnt config 209 | token := deployment.Key 210 | if token == "" { 211 | rawToken := req.Header.Get("Authorization") 212 | token = strings.TrimPrefix(rawToken, "Bearer ") 213 | } 214 | if token == "" { 215 | network.SendError(c, errors.New("token is empty")) 216 | return 217 | } 218 | req.Header.Set(HeaderAuthKey, token) 219 | req.Header.Del("Authorization") 220 | 221 | originURL := req.URL.String() 222 | req, err = requestConverter.Convert(req, deployment) 223 | if err != nil { 224 | network.SendError(c, errors.Wrap(err, "convert request error")) 225 | return 226 | } 227 | log.Printf("proxying request [%s] %s -> %s", model, originURL, req.URL.String()) 228 | } 229 | 230 | proxy := &httputil.ReverseProxy{Director: director} 231 | transport, err := network.NewProxyFromEnv( 232 | fn.GetStringOrDefaultFromEnv("ENV_AZURE_SOCKS_PROXY", ""), 233 | fn.GetStringOrDefaultFromEnv("ENV_AZURE_HTTP_PROXY", ""), 234 
| ) 235 | if err != nil { 236 | network.SendError(c, errors.Wrap(err, "get proxy error")) 237 | return 238 | } 239 | if transport != nil { 240 | proxy.Transport = transport 241 | } 242 | 243 | proxy.ServeHTTP(c.Writer, c.Request) 244 | 245 | // issue: https://github.com/Chanzhaoyu/chatgpt-web/issues/831 246 | if c.Writer.Header().Get("Content-Type") == "text/event-stream" { 247 | if _, err := c.Writer.Write([]byte{'\n'}); err != nil { 248 | log.Printf("rewrite response error: %v", err) 249 | } 250 | } 251 | 252 | if c.Writer.Status() != 200 { 253 | log.Printf("encountering error with body: %s", string(bodyBytes)) 254 | } 255 | } 256 | 257 | func GetDeploymentByModel(model string) (*define.ModelConfig, error) { 258 | deploymentConfig, exist := ModelConfig[model] 259 | if !exist { 260 | return nil, errors.New(fmt.Sprintf("deployment config for %s not found", model)) 261 | } 262 | return &deploymentConfig, nil 263 | } 264 | -------------------------------------------------------------------------------- /models/azure/proxy_test.go: -------------------------------------------------------------------------------- 1 | package azure_test 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/gin-gonic/gin" 9 | "github.com/soulteary/amazing-openai-api/internal/define" 10 | "github.com/soulteary/amazing-openai-api/models/azure" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/mock" 13 | ) 14 | 15 | type MockedRequestConverter struct { 16 | mock.Mock 17 | } 18 | 19 | func (m *MockedRequestConverter) Convert(req *http.Request, deployment *define.ModelConfig) (*http.Request, error) { 20 | args := m.Called(req, deployment) 21 | return args.Get(0).(*http.Request), args.Error(1) 22 | } 23 | 24 | func (m *MockedRequestConverter) Name() string { 25 | args := m.Called() 26 | return args.String(0) 27 | } 28 | 29 | func TestProxyMiddlewareWithOptionsMethod(t *testing.T) { 30 | gin.SetMode(gin.TestMode) 31 | r := gin.New() 
// TestModelProxySuccess is a placeholder: it would require mocking the
// outgoing deployment requests and checking the collected results.
func TestModelProxySuccess(t *testing.T) {
	// Report the gap explicitly instead of passing without testing anything.
	t.Skip("not implemented: requires mocking the external deployment requests")
}

// TestModelProxyFailures is a placeholder for the failure scenarios (bad
// responses, errors while sending requests, etc.).
func TestModelProxyFailures(t *testing.T) {
	// Report the gap explicitly instead of passing without testing anything.
	t.Skip("not implemented: requires mocking the external deployment requests")
}

// TestProxyFunctionality is a placeholder for tests of the Proxy function,
// with a setup similar to the ModelProxy tests.
func TestProxyFunctionality(t *testing.T) {
	// Report the gap explicitly instead of passing without testing anything.
	t.Skip("not implemented: requires mocking the upstream service")
}
// Names of the environment variables read by the gemini package.
const (
	ENV_GEMINI_ENDPOINT    = "GEMINI_ENDPOINT"
	ENV_GEMINI_API_VER     = "GEMINI_API_VER"
	ENV_GEMINI_MODEL_ALIAS = "GEMINI_MODEL_ALIAS"
	ENV_GEMINI_API_KEY     = "GEMINI_API_KEY"
	ENV_GEMINI_MODEL       = "GEMINI_MODEL"
	ENV_GEMINI_SAFETY      = "GEMINI_SAFETY"

	ENV_GEMINI_HTTP_PROXY  = "GEMINI_HTTP_PROXY"
	ENV_GEMINI_SOCKS_PROXY = "GEMINI_SOCKS_PROXY"
)

// REST API version path segments and the default API entry point.
const (
	DEFAULT_REST_API_VERSION_SHIM = "/v1"
	DEFAULT_REST_API_VERSION      = "/v1beta"
	DEFAULT_REST_API_ENTRYPOINT   = "https://generativelanguage.googleapis.com"
)

// Safety threshold names.
const (
	DEFAULT_SAFETY_THRESHOLD_NONE   = "BLOCK_NONE"
	DEFAULT_SAFETY_THRESHOLD_LESS   = "BLOCK_ONLY_HIGH"
	DEFAULT_SAFETY_THRESHOLD_MEDIUM = "BLOCK_MEDIUM_AND_ABOVE"
	DEFAULT_SAFETY_THRESHOLD_HIGH   = "BLOCK_LOW_AND_ABOVE"
	DEFAULT_SAFETY_THRESHOLD_UNSET  = "HARM_BLOCK_THRESHOLD_UNSPECIFIED"
)

// Defaults used when the API version or model is not configured.
const (
	DEFAULT_GEMINI_API_VER = DEFAULT_REST_API_VERSION
	DEFAULT_GEMINI_MODEL   = "gemini-pro"
)

// OpenAIPayloadMessages is one message (role and content) of an OpenAI-style
// chat payload.
type OpenAIPayloadMessages struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// OpenAIPayload is an OpenAI-style chat completion request body.
type OpenAIPayload struct {
	MaxTokens       int                     `json:"max_tokens"`
	Model           string                  `json:"model"`
	Temperature     float64                 `json:"temperature"`
	TopP            float64                 `json:"top_p"`
	PresencePenalty float64                 `json:"presence_penalty"`
	Messages        []OpenAIPayloadMessages `json:"messages"`
	Stream          bool                    `json:"stream"`
}

// GoogleGeminiPayload is a Gemini request body: contents, safety settings and
// generation configuration.
type GoogleGeminiPayload struct {
	Contents         []GeminiPayloadContents `json:"contents"`
	SafetySettings   []GeminiSafetySettings  `json:"safetySettings"`
	GenerationConfig GeminiGenerationConfig  `json:"generationConfig"`
}

// GeminiSafetySettings pairs a category with a threshold.
type GeminiSafetySettings struct {
	Category  string `json:"category"`
	Threshold string `json:"threshold"`
}

// GeminiGenerationConfig holds the generation parameters. All fields except
// StopSequences are omitted from the JSON when they have their zero value.
type GeminiGenerationConfig struct {
	StopSequences   []string `json:"stopSequences"`
	Temperature     float64  `json:"temperature,omitempty"`
	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
	TopP            float64  `json:"topP,omitempty"`
	TopK            int      `json:"topK,omitempty"`
}

// gemini response

// GeminiSafetyRatings pairs a category with a probability.
type GeminiSafetyRatings struct {
	Category    string `json:"category"`
	Probability string `json:"probability"`
}

// PromptFeedback carries the safety ratings of the "promptFeedback" field.
type PromptFeedback struct {
	SafetyRatings []GeminiSafetyRatings `json:"safetyRatings"`
}

// GeminiPayloadParts is one text part of a content entry.
type GeminiPayloadParts struct {
	Text string `json:"text"`
}

// GeminiPayloadContents is one content entry: its parts and its role.
type GeminiPayloadContents struct {
	Parts []GeminiPayloadParts `json:"parts"`
	Role  string               `json:"role"`
}

// GeminiCandidates is one candidate of a Gemini response.
type GeminiCandidates struct {
	Content       GeminiPayloadContents `json:"content"`
	FinishReason  string                `json:"finishReason"`
	Index         int                   `json:"index"`
	SafetyRatings []GeminiSafetyRatings `json:"safetyRatings"`
}

// GeminiResponse is a Gemini response body.
type GeminiResponse struct {
	Candidates     []GeminiCandidates `json:"candidates"`
	PromptFeedback PromptFeedback     `json:"promptFeedback"`
}
| package gemini 2 | 3 | import ( 4 | "fmt" 5 | "net/url" 6 | 7 | "github.com/soulteary/amazing-openai-api/internal/define" 8 | "github.com/soulteary/amazing-openai-api/internal/fn" 9 | ) 10 | 11 | // refs: https://ai.google.dev/models/gemini?hl=zh-cn 12 | var ( 13 | ModelConfig = map[string]define.ModelConfig{} 14 | ) 15 | 16 | func Init() (err error) { 17 | var modelConfig define.ModelConfig 18 | 19 | // gemini openai api endpoint 20 | endpoint := fn.GetStringOrDefaultFromEnv(ENV_GEMINI_ENDPOINT, DEFAULT_REST_API_ENTRYPOINT) 21 | u, err := url.Parse(endpoint) 22 | if err != nil { 23 | return fmt.Errorf("parse endpoint error: %w", err) 24 | } 25 | modelConfig.URL = u 26 | modelConfig.Endpoint = endpoint 27 | 28 | // gemini openai api version 29 | apiVersion := fn.GetStringOrDefaultFromEnv(ENV_GEMINI_API_VER, DEFAULT_GEMINI_API_VER) 30 | // google api versions supported 31 | // https://ai.google.dev/docs/api_versions?hl=zh-cn 32 | if apiVersion != "v1" && apiVersion != "v1beta" { 33 | apiVersion = DEFAULT_GEMINI_API_VER 34 | } else { 35 | apiVersion = "/" + apiVersion 36 | } 37 | modelConfig.Version = apiVersion 38 | 39 | // gemini openai api key, allow override by request header 40 | apikey := fn.GetStringOrDefaultFromEnv(ENV_GEMINI_API_KEY, "") 41 | modelConfig.Key = apikey 42 | 43 | // gemini openai api model 44 | model := fn.GetStringOrDefaultFromEnv(ENV_GEMINI_MODEL, DEFAULT_GEMINI_MODEL) 45 | if model == "" { 46 | model = DEFAULT_GEMINI_MODEL 47 | } 48 | modelConfig.Model = model 49 | 50 | ModelConfig[model] = modelConfig 51 | 52 | // gemini openai api model alias 53 | alias := fn.ExtractModelAlias(fn.GetStringOrDefaultFromEnv(ENV_GEMINI_MODEL_ALIAS, "")) 54 | for _, pair := range alias { 55 | if model == pair[0] { 56 | modelConfig.Model = pair[1] 57 | } 58 | ModelConfig[pair[0]] = modelConfig 59 | } 60 | return nil 61 | } 62 | -------------------------------------------------------------------------------- /models/gemini/model.go: 
-------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | 9 | "github.com/soulteary/amazing-openai-api/internal/define" 10 | ) 11 | 12 | type RequestConverter interface { 13 | Name() string 14 | Convert(req *http.Request, config *define.ModelConfig, payload []byte, openaiPayload define.OpenAI_Payload, apikey string) (*http.Request, error) 15 | } 16 | 17 | type StripPrefixConverter struct { 18 | Prefix string 19 | } 20 | 21 | func (c *StripPrefixConverter) Name() string { 22 | return "StripPrefix" 23 | } 24 | 25 | func (c *StripPrefixConverter) Convert(req *http.Request, config *define.ModelConfig, payload []byte, openaiPayload define.OpenAI_Payload, apikey string) (*http.Request, error) { 26 | req.Host = config.URL.Host 27 | req.URL.Scheme = config.URL.Scheme 28 | req.URL.Host = config.URL.Host 29 | 30 | // if openaiPayload.Stream { 31 | // req.URL.Path = fmt.Sprintf("%s/models/%s:streamGenerateContent", config.Version, config.Model) 32 | // } else { 33 | req.URL.Path = fmt.Sprintf("%s/models/%s:generateContent", config.Version, config.Model) 34 | // } 35 | 36 | req.URL.RawPath = req.URL.EscapedPath() 37 | 38 | query := req.URL.Query() 39 | if config.Key == "" { 40 | if apikey == "" { 41 | return nil, fmt.Errorf("missing api key") 42 | } else { 43 | query.Add("key", apikey) 44 | } 45 | } else { 46 | query.Add("key", config.Key) 47 | } 48 | req.URL.RawQuery = query.Encode() 49 | req.Body = io.NopCloser(bytes.NewBuffer(payload)) 50 | req.ContentLength = int64(len(payload)) 51 | return req, nil 52 | } 53 | 54 | func NewStripPrefixConverter(prefix string) *StripPrefixConverter { 55 | return &StripPrefixConverter{ 56 | Prefix: prefix, 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /models/gemini/proxy.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | 
import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "log" 9 | "net/http" 10 | "net/http/httputil" 11 | "regexp" 12 | "strconv" 13 | "strings" 14 | "time" 15 | 16 | "github.com/gin-gonic/gin" 17 | "github.com/pkg/errors" 18 | "github.com/soulteary/amazing-openai-api/internal/define" 19 | "github.com/soulteary/amazing-openai-api/internal/fn" 20 | "github.com/soulteary/amazing-openai-api/internal/network" 21 | ) 22 | 23 | const ( 24 | HeaderAuthKey = "api-key" 25 | HeaderAPIVer = "api-version" 26 | ) 27 | 28 | func ProxyWithConverter(requestConverter RequestConverter) gin.HandlerFunc { 29 | return func(c *gin.Context) { 30 | if c.Request.Method == http.MethodOptions { 31 | c.Header("Access-Control-Allow-Origin", "*") 32 | c.Header("Access-Control-Allow-Methods", "GET, OPTIONS, POST") 33 | c.Header("Access-Control-Allow-Headers", "Authorization, Content-Type, x-requested-with") 34 | c.Status(200) 35 | return 36 | } 37 | Proxy(c, requestConverter) 38 | } 39 | } 40 | 41 | var maskURL = regexp.MustCompile(`key=.+`) 42 | 43 | func parseRequestBody(reqBody io.ReadCloser) (openaiPayload define.OpenAI_Payload, err error) { 44 | if reqBody == nil { 45 | err = errors.New("request body is empty") 46 | return openaiPayload, err 47 | } 48 | body, _ := io.ReadAll(reqBody) 49 | err = json.Unmarshal(body, &openaiPayload) 50 | return openaiPayload, err 51 | } 52 | 53 | func parseResponseBody(responseBody io.ReadCloser) (GeminiResponse, error) { 54 | var payload GeminiResponse 55 | body, err := io.ReadAll(responseBody) 56 | if err != nil { 57 | return payload, err 58 | } 59 | 60 | err = json.Unmarshal(body, &payload) 61 | if err != nil { 62 | return payload, err 63 | } 64 | return payload, nil 65 | } 66 | 67 | func GetModelNameAndConfig(openaiPayload define.OpenAI_Payload) (string, define.ModelConfig, bool) { 68 | model := strings.TrimSpace(openaiPayload.Model) 69 | if model == "" { 70 | model = DEFAULT_GEMINI_MODEL 71 | } 72 | config, ok := ModelConfig[model] 73 | return 
model, config, ok 74 | } 75 | 76 | func getDirector(req *http.Request, body []byte, c *gin.Context, requestConverter RequestConverter, openaiPayload define.OpenAI_Payload, model string) func(req *http.Request) { 77 | return func(req *http.Request) { 78 | // req.Body = io.NopCloser(bytes.NewBuffer(body)) 79 | 80 | var payload GoogleGeminiPayload 81 | for _, data := range openaiPayload.Messages { 82 | var message GeminiPayloadContents 83 | if strings.ToLower(data.Role) == "user" { 84 | message.Role = "user" 85 | } else { 86 | message.Role = "model" 87 | } 88 | message.Parts = append(message.Parts, GeminiPayloadParts{ 89 | Text: strings.TrimSpace(data.Content), 90 | }) 91 | payload.Contents = append(payload.Contents, message) 92 | } 93 | 94 | // set default safety settings 95 | var safetySettings []GeminiSafetySettings 96 | safetyThreshold := fn.GetStringOrDefaultFromEnv(ENV_GEMINI_SAFETY, DEFAULT_SAFETY_THRESHOLD_UNSET) 97 | if safetyThreshold != DEFAULT_SAFETY_THRESHOLD_NONE && safetyThreshold != DEFAULT_SAFETY_THRESHOLD_UNSET && safetyThreshold != DEFAULT_SAFETY_THRESHOLD_LESS && safetyThreshold != DEFAULT_SAFETY_THRESHOLD_MEDIUM && safetyThreshold != DEFAULT_SAFETY_THRESHOLD_HIGH { 98 | safetyThreshold = DEFAULT_SAFETY_THRESHOLD_UNSET 99 | } 100 | safetySettings = append(safetySettings, GeminiSafetySettings{ 101 | Category: "HARM_CATEGORY_DANGEROUS_CONTENT", 102 | Threshold: safetyThreshold, 103 | }) 104 | payload.SafetySettings = safetySettings 105 | 106 | // set default generation config 107 | payload.GenerationConfig.StopSequences = []string{"Title"} 108 | payload.GenerationConfig.Temperature = openaiPayload.Temperature 109 | payload.GenerationConfig.MaxOutputTokens = openaiPayload.MaxTokens 110 | payload.GenerationConfig.TopP = openaiPayload.TopP 111 | // payload.GenerationConfig.TopK = openaiPayload.TopK 112 | 113 | // get deployment from request 114 | deployment, err := GetDeploymentByModel(model) 115 | if err != nil { 116 | network.SendError(c, err) 117 | 
return 118 | } 119 | // get auth token from header or deployemnt config 120 | token := deployment.Key 121 | if token == "" { 122 | rawToken := req.Header.Get("Authorization") 123 | token = strings.TrimPrefix(rawToken, "Bearer ") 124 | } 125 | if token == "" { 126 | network.SendError(c, errors.New("token is empty")) 127 | return 128 | } 129 | req.Header.Del("Authorization") 130 | 131 | repack, err := json.Marshal(payload) 132 | if err != nil { 133 | network.SendError(c, errors.Wrap(err, "repack payload error")) 134 | return 135 | } 136 | 137 | originURL := req.URL.String() 138 | req, err = requestConverter.Convert(req, deployment, repack, openaiPayload, token) 139 | if err != nil { 140 | network.SendError(c, errors.Wrap(err, "convert request error")) 141 | return 142 | } 143 | 144 | log.Printf("proxying request [%s] %s -> %s", model, originURL, maskURL.ReplaceAllString(req.URL.String(), "key=******")) 145 | } 146 | } 147 | 148 | // Proxy Gemini 149 | func Proxy(c *gin.Context, requestConverter RequestConverter) { 150 | var body []byte 151 | 152 | openaiPayload, err := parseRequestBody(c.Request.Body) 153 | if err != nil { 154 | network.SendError(c, err) 155 | return 156 | } 157 | 158 | model, config, ok := GetModelNameAndConfig(openaiPayload) 159 | if ok { 160 | fmt.Println("rewrite model ", model, "to", config.Model) 161 | openaiPayload.Model = config.Model 162 | } 163 | 164 | proxy := &httputil.ReverseProxy{Director: getDirector(c.Request, body, c, requestConverter, openaiPayload, model)} 165 | transport, err := network.NewProxyFromEnv( 166 | fn.GetStringOrDefaultFromEnv("ENV_GEMINI_SOCKS_PROXY", ""), 167 | fn.GetStringOrDefaultFromEnv("ENV_GEMINI_HTTP_PROXY", ""), 168 | ) 169 | if err != nil { 170 | network.SendError(c, errors.Wrap(err, "get proxy error")) 171 | return 172 | } 173 | if transport != nil { 174 | proxy.Transport = transport 175 | } 176 | 177 | proxy.ModifyResponse = func(response *http.Response) error { 178 | if response.StatusCode == http.StatusOK 
{ 179 | 180 | var reader io.ReadCloser 181 | if strings.ToLower(response.Header.Get("Content-Encoding")) == "gzip" { 182 | reader, err = fn.Gunzip(response.Body) 183 | if err != nil { 184 | return err 185 | } 186 | } else { 187 | reader = response.Body 188 | } 189 | 190 | responsePayload, err := parseResponseBody(reader) 191 | defer reader.Close() 192 | if err != nil { 193 | return err 194 | } 195 | 196 | var openaiResponse define.OpeAI_Response 197 | openaiResponse.ID = "gemini" 198 | // if openaiPayload.Stream { 199 | // openaiResponse.Object = "chat.completion.chunk" 200 | // } else { 201 | openaiResponse.Object = "chat.completion" 202 | // } 203 | openaiResponse.Created = int(time.Now().Unix()) 204 | openaiResponse.Model = model 205 | 206 | var openaiMessage define.Message 207 | var openaiChoice define.OpenAI_Choices 208 | 209 | promptTokens := 0 210 | for _, data := range openaiPayload.Messages { 211 | promptTokens += len(data.Content) 212 | } 213 | 214 | completionTokens := 0 215 | for _, candidates := range responsePayload.Candidates { 216 | for _, part := range candidates.Content.Parts { 217 | openaiMessage.Role = candidates.Content.Role 218 | openaiMessage.Content = part.Text 219 | completionTokens += len(part.Text) 220 | } 221 | if candidates.FinishReason != "" { 222 | openaiChoice.FinishReason = candidates.FinishReason 223 | } 224 | openaiChoice.Index = candidates.Index 225 | } 226 | 227 | openaiChoice.Message = openaiMessage 228 | openaiResponse.Choices = append(openaiResponse.Choices, openaiChoice) 229 | 230 | // stats 231 | openaiResponse.Usage.CompletionTokens = completionTokens 232 | openaiResponse.Usage.PromptTokens = promptTokens 233 | openaiResponse.Usage.TotalTokens = completionTokens + promptTokens 234 | 235 | repack, err := json.Marshal(openaiResponse) 236 | if err != nil { 237 | return err 238 | } 239 | 240 | response.Body = io.NopCloser(bytes.NewBuffer(repack)) 241 | response.ContentLength = int64(len(repack)) 242 | 
response.Header.Set("Content-Length", strconv.Itoa(len(repack))) 243 | } 244 | return nil 245 | } 246 | 247 | proxy.ServeHTTP(c.Writer, c.Request) 248 | 249 | // issue: https://github.com/Chanzhaoyu/chatgpt-web/issues/831 250 | if c.Writer.Header().Get("Content-Type") == "text/event-stream" { 251 | if _, err := c.Writer.Write([]byte{'\n'}); err != nil { 252 | log.Printf("rewrite response error: %v", err) 253 | } 254 | } 255 | 256 | if c.Writer.Status() != 200 { 257 | log.Printf("encountering error with body: %s", string(body)) 258 | } 259 | } 260 | 261 | func GetDeploymentByModel(model string) (*define.ModelConfig, error) { 262 | deploymentConfig, exist := ModelConfig[model] 263 | if !exist { 264 | return nil, errors.New(fmt.Sprintf("deployment config for %s not found", model)) 265 | } 266 | return &deploymentConfig, nil 267 | } 268 | -------------------------------------------------------------------------------- /models/yi/define.go: -------------------------------------------------------------------------------- 1 | package yi 2 | 3 | const ( 4 | ENV_YI_ENDPOINT = "YI_ENDPOINT" 5 | ENV_YI_API_VER = "YI_API_VER" 6 | ENV_YI_MODEL_ALIAS = "YI_MODEL_ALIAS" 7 | ENV_YI_API_KEY = "YI_API_KEY" 8 | ENV_YI_MODEL = "YI_MODEL" 9 | 10 | ENV_YI_HTTP_PROXY = "YI_HTTP_PROXY" 11 | ENV_YI_SOCKS_PROXY = "YI_SOCKS_PROXY" 12 | ) 13 | 14 | const ( 15 | DEFAULT_YI_API_VER = "2023-12-15-preview" 16 | DEFAULT_YI_MODEL = "yi-34b-chat" 17 | ) 18 | -------------------------------------------------------------------------------- /models/yi/model.go: -------------------------------------------------------------------------------- 1 | package yi 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/soulteary/amazing-openai-api/internal/define" 7 | ) 8 | 9 | type RequestConverter interface { 10 | Name() string 11 | Convert(req *http.Request, config *define.ModelConfig) (*http.Request, error) 12 | } 13 | 14 | type StripPrefixConverter struct { 15 | Prefix string 16 | } 17 | 18 | func (c 
*StripPrefixConverter) Name() string { 19 | return "StripPrefix" 20 | } 21 | 22 | func (c *StripPrefixConverter) Convert(req *http.Request, config *define.ModelConfig) (*http.Request, error) { 23 | req.Host = config.URL.Host 24 | req.URL.Scheme = config.URL.Scheme 25 | req.URL.Host = config.URL.Host 26 | req.URL.RawPath = req.URL.EscapedPath() 27 | 28 | query := req.URL.Query() 29 | query.Add(HeaderAPIVer, config.Version) 30 | req.URL.RawQuery = query.Encode() 31 | return req, nil 32 | } 33 | 34 | func NewStripPrefixConverter(prefix string) *StripPrefixConverter { 35 | return &StripPrefixConverter{ 36 | Prefix: prefix, 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /models/yi/model_test.go: -------------------------------------------------------------------------------- 1 | package yi_test 2 | 3 | import ( 4 | "net/http/httptest" 5 | "net/url" 6 | "testing" 7 | 8 | "github.com/soulteary/amazing-openai-api/internal/define" 9 | "github.com/soulteary/amazing-openai-api/models/yi" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | const HeaderAPIVer string = "X-API-Version" 14 | 15 | func TestStripPrefixConverter_Name(t *testing.T) { 16 | converter := yi.NewStripPrefixConverter("/api") 17 | assert.Equal(t, "StripPrefix", converter.Name()) 18 | } 19 | 20 | func TestStripPrefixConverter_Convert(t *testing.T) { 21 | prefix := "/api" 22 | converter := yi.NewStripPrefixConverter(prefix) 23 | modelConfig := &define.ModelConfig{ 24 | URL: &url.URL{ 25 | Scheme: "https", 26 | Host: "example.com", 27 | }, 28 | Version: "v1", 29 | } 30 | 31 | req := httptest.NewRequest("GET", "http://localhost"+prefix+"/endpoint?param=value", nil) 32 | convertedReq, err := converter.Convert(req, modelConfig) 33 | 34 | assert.NoError(t, err) 35 | assert.NotNil(t, convertedReq) 36 | assert.Equal(t, "example.com", convertedReq.Host) 37 | assert.Equal(t, "https", convertedReq.URL.Scheme) 38 | assert.Equal(t, "example.com", 
convertedReq.URL.Host) 39 | 40 | // Ensure original path is maintained without the prefix 41 | assert.Contains(t, convertedReq.URL.Path, prefix) 42 | } 43 | -------------------------------------------------------------------------------- /models/yi/proxy.go: -------------------------------------------------------------------------------- 1 | package yi 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "log" 9 | "net/http" 10 | "net/http/httputil" 11 | "regexp" 12 | "strings" 13 | 14 | "github.com/gin-gonic/gin" 15 | "github.com/pkg/errors" 16 | "github.com/soulteary/amazing-openai-api/internal/define" 17 | "github.com/soulteary/amazing-openai-api/internal/fn" 18 | "github.com/soulteary/amazing-openai-api/internal/network" 19 | ) 20 | 21 | const ( 22 | HeaderAPIVer = "api-version" 23 | ) 24 | 25 | var maskURL = regexp.MustCompile(`https?:\/\/.+\/v1\/`) 26 | 27 | func ProxyWithConverter(requestConverter RequestConverter) gin.HandlerFunc { 28 | return func(c *gin.Context) { 29 | if c.Request.Method == http.MethodOptions { 30 | c.Header("Access-Control-Allow-Origin", "*") 31 | c.Header("Access-Control-Allow-Methods", "GET, OPTIONS, POST") 32 | c.Header("Access-Control-Allow-Headers", "Authorization, Content-Type, x-requested-with") 33 | c.Status(200) 34 | return 35 | } 36 | Proxy(c, requestConverter) 37 | } 38 | } 39 | 40 | // Proxy YI 41 | func Proxy(c *gin.Context, requestConverter RequestConverter) { 42 | // preserve request body for error logging 43 | var buf bytes.Buffer 44 | tee := io.TeeReader(c.Request.Body, &buf) 45 | bodyBytes, err := io.ReadAll(tee) 46 | if err != nil { 47 | log.Printf("Error reading request body: %v", err) 48 | return 49 | } 50 | c.Request.Body = io.NopCloser(&buf) 51 | 52 | director := func(req *http.Request) { 53 | if req.Body == nil { 54 | network.SendError(c, errors.New("request body is empty")) 55 | return 56 | } 57 | 58 | // extract model from request url 59 | model := c.Param("model") 60 | if model == "" { 61 | 
// extract model from request body 62 | body, err := io.ReadAll(req.Body) 63 | defer req.Body.Close() 64 | if err != nil { 65 | network.SendError(c, errors.Wrap(err, "read request body error")) 66 | return 67 | } 68 | 69 | var payload define.OpenAI_Payload 70 | err = json.Unmarshal(body, &payload) 71 | if err != nil { 72 | network.SendError(c, errors.Wrap(err, "parse payload error")) 73 | return 74 | } 75 | 76 | model = payload.Model 77 | model := strings.TrimSpace(payload.Model) 78 | if model == "" { 79 | model = DEFAULT_YI_MODEL 80 | } 81 | 82 | config, ok := ModelConfig[model] 83 | if ok { 84 | fmt.Println("rewrite model ", model, "to", config.Model) 85 | payload.Model = config.Model 86 | repack, err := json.Marshal(payload) 87 | if err != nil { 88 | network.SendError(c, errors.Wrap(err, "repack payload error")) 89 | return 90 | } 91 | body = repack 92 | } 93 | 94 | req.Body = io.NopCloser(bytes.NewBuffer(body)) 95 | req.ContentLength = int64(len(body)) 96 | } 97 | 98 | // get deployment from request 99 | deployment, err := GetDeploymentByModel(model) 100 | if err != nil { 101 | network.SendError(c, err) 102 | return 103 | } 104 | 105 | // get auth token from header or deployemnt config 106 | token := deployment.Key 107 | if token == "" { 108 | rawToken := req.Header.Get("Authorization") 109 | token = strings.TrimPrefix(rawToken, "Bearer ") 110 | } 111 | if token == "" { 112 | network.SendError(c, errors.New("token is empty")) 113 | return 114 | } 115 | req.Header.Set("Authorization", token) 116 | 117 | originURL := req.URL.String() 118 | req, err = requestConverter.Convert(req, deployment) 119 | if err != nil { 120 | network.SendError(c, errors.Wrap(err, "convert request error")) 121 | return 122 | } 123 | 124 | log.Printf("proxying request [%s] %s -> %s", model, originURL, maskURL.ReplaceAllString(req.URL.String(), "${YI-API-SERVER}/v1/")) 125 | } 126 | 127 | proxy := &httputil.ReverseProxy{Director: director} 128 | transport, err := network.NewProxyFromEnv( 
129 | fn.GetStringOrDefaultFromEnv("ENV_YI_SOCKS_PROXY", ""), 130 | fn.GetStringOrDefaultFromEnv("ENV_YI_HTTP_PROXY", ""), 131 | ) 132 | if err != nil { 133 | network.SendError(c, errors.Wrap(err, "get proxy error")) 134 | return 135 | } 136 | if transport != nil { 137 | proxy.Transport = transport 138 | } 139 | 140 | proxy.ServeHTTP(c.Writer, c.Request) 141 | 142 | // issue: https://github.com/Chanzhaoyu/chatgpt-web/issues/831 143 | if c.Writer.Header().Get("Content-Type") == "text/event-stream" { 144 | if _, err := c.Writer.Write([]byte{'\n'}); err != nil { 145 | log.Printf("rewrite response error: %v", err) 146 | } 147 | } 148 | 149 | if c.Writer.Status() != 200 { 150 | log.Printf("encountering error with body: %s", string(bodyBytes)) 151 | } 152 | } 153 | 154 | func GetDeploymentByModel(model string) (*define.ModelConfig, error) { 155 | deploymentConfig, exist := ModelConfig[model] 156 | if !exist { 157 | return nil, errors.New(fmt.Sprintf("deployment config for %s not found", model)) 158 | } 159 | return &deploymentConfig, nil 160 | } 161 | -------------------------------------------------------------------------------- /models/yi/proxy_test.go: -------------------------------------------------------------------------------- 1 | package yi_test 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/gin-gonic/gin" 9 | "github.com/soulteary/amazing-openai-api/internal/define" 10 | "github.com/soulteary/amazing-openai-api/models/yi" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/mock" 13 | ) 14 | 15 | // Mocks 16 | type MockedRequestConverter struct { 17 | mock.Mock 18 | } 19 | 20 | func (m *MockedRequestConverter) Convert(req *http.Request, deployment *define.ModelConfig) (*http.Request, error) { 21 | args := m.Called(req, deployment) 22 | return args.Get(0).(*http.Request), args.Error(1) 23 | } 24 | 25 | func (m *MockedRequestConverter) Name() string { 26 | args := m.Called() 27 | return 
args.String(0) 28 | } 29 | 30 | // Tests 31 | func TestProxyWithConverter(t *testing.T) { 32 | gin.SetMode(gin.TestMode) 33 | r := gin.New() 34 | mockReqConverter := new(MockedRequestConverter) 35 | r.Use(yi.ProxyWithConverter(mockReqConverter)) 36 | 37 | req, _ := http.NewRequest(http.MethodOptions, "/", nil) 38 | 39 | w := httptest.NewRecorder() 40 | r.ServeHTTP(w, req) 41 | 42 | assert.Equal(t, 200, w.Code) 43 | } 44 | 45 | func TestGetDeploymentByModel(t *testing.T) { 46 | // Assuming ModelConfig has been defined with at least one key "test-model" 47 | modelName := "test-model" 48 | expectedConfig := &define.ModelConfig{ 49 | Key: "some-key", 50 | // ... other fields 51 | } 52 | 53 | // Set up the global variable for testing 54 | yi.ModelConfig = map[string]define.ModelConfig{ 55 | modelName: *expectedConfig, 56 | } 57 | 58 | config, err := yi.GetDeploymentByModel(modelName) 59 | assert.NoError(t, err) 60 | assert.Equal(t, expectedConfig, config) 61 | 62 | // Test with a non-existing model 63 | _, err = yi.GetDeploymentByModel("non-existing-model") 64 | assert.Error(t, err) 65 | } 66 | -------------------------------------------------------------------------------- /models/yi/yi.go: -------------------------------------------------------------------------------- 1 | package yi 2 | 3 | import ( 4 | "fmt" 5 | "net/url" 6 | 7 | "github.com/soulteary/amazing-openai-api/internal/define" 8 | "github.com/soulteary/amazing-openai-api/internal/fn" 9 | ) 10 | 11 | var ( 12 | ModelConfig = map[string]define.ModelConfig{} 13 | ) 14 | 15 | func Init() (err error) { 16 | var modelConfig define.ModelConfig 17 | 18 | // yi api endpoint 19 | endpoint := fn.GetStringOrDefaultFromEnv(ENV_YI_ENDPOINT, "") 20 | if endpoint == "" { 21 | return fmt.Errorf("missing environment variable %s", ENV_YI_ENDPOINT) 22 | } 23 | u, err := url.Parse(endpoint) 24 | if err != nil { 25 | return fmt.Errorf("parse endpoint error: %w", err) 26 | } 27 | modelConfig.URL = u 28 | modelConfig.Endpoint = 
endpoint 29 | 30 | // yi api version 31 | apiVersion := fn.GetStringOrDefaultFromEnv(ENV_YI_API_VER, DEFAULT_YI_API_VER) 32 | if apiVersion == "" { 33 | apiVersion = DEFAULT_YI_API_VER 34 | } 35 | modelConfig.Version = apiVersion 36 | 37 | // yi api key, allow override by request header 38 | apikey := fn.GetStringOrDefaultFromEnv(ENV_YI_API_KEY, "") 39 | modelConfig.Key = apikey 40 | 41 | // yi api model 42 | model := fn.GetStringOrDefaultFromEnv(ENV_YI_MODEL, DEFAULT_YI_MODEL) 43 | if model == "" { 44 | model = DEFAULT_YI_MODEL 45 | } 46 | modelConfig.Model = model 47 | 48 | ModelConfig[model] = modelConfig 49 | 50 | // yi api model alias 51 | alias := fn.ExtractModelAlias(fn.GetStringOrDefaultFromEnv(ENV_YI_MODEL_ALIAS, "")) 52 | for _, pair := range alias { 53 | if model == pair[0] { 54 | modelConfig.Model = pair[1] 55 | } 56 | ModelConfig[pair[0]] = modelConfig 57 | } 58 | return nil 59 | } 60 | -------------------------------------------------------------------------------- /models/yi/yi_test.go: -------------------------------------------------------------------------------- 1 | package yi_test 2 | 3 | import ( 4 | "net/url" 5 | "os" 6 | "reflect" 7 | "testing" 8 | 9 | "github.com/soulteary/amazing-openai-api/models/yi" 10 | ) 11 | 12 | func TestInit(t *testing.T) { 13 | // Load environment variables from a .env file for testing if needed 14 | // godotenv.Load("../path/to/your/.env.testing") 15 | 16 | t.Run("it should handle missing endpoint error", func(t *testing.T) { 17 | // Clear environment variable for endpoint 18 | os.Unsetenv(yi.ENV_YI_ENDPOINT) 19 | 20 | err := yi.Init() 21 | if err == nil || err.Error() != "missing environment variable "+yi.ENV_YI_ENDPOINT { 22 | t.Errorf("Expected missing endpoint environment variable error, got %v", err) 23 | } 24 | }) 25 | 26 | t.Run("it should parse and assign endpoint successfully", func(t *testing.T) { 27 | expectedURL := "https://example.com/api" 28 | os.Setenv(yi.ENV_YI_ENDPOINT, expectedURL) 29 | 30 | err 
:= yi.Init() 31 | if err != nil { 32 | t.Fatalf("Unexpected error: %v", err) 33 | } 34 | 35 | modelCfg, exists := yi.ModelConfig[yi.DEFAULT_YI_MODEL] 36 | if !exists { 37 | t.Fatal("Model config does not exist after Init") 38 | } 39 | 40 | parsedURL, _ := url.Parse(expectedURL) 41 | if !reflect.DeepEqual(modelCfg.URL, parsedURL) { 42 | t.Errorf("Expected URL to be parsed correctly, got %+v", modelCfg.URL) 43 | } 44 | }) 45 | 46 | // ... Additional tests for version, api key, model, aliasing, etc. 47 | 48 | // Reset environment variables after testing 49 | t.Cleanup(func() { 50 | os.Unsetenv(yi.ENV_YI_ENDPOINT) 51 | // ... Unset other environment variables used during the tests 52 | }) 53 | } 54 | -------------------------------------------------------------------------------- /pkg/logger/gin-logrus.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | /* 4 | MIT License 5 | 6 | Copyright (c) 2016 Stéphane Depierrepont 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 25 | 26 | @file: https://github.com/toorop/gin-logrus/blob/master/logger.go 27 | */ 28 | 29 | import ( 30 | "fmt" 31 | "math" 32 | "net/http" 33 | "os" 34 | "strings" 35 | "time" 36 | 37 | "github.com/gin-gonic/gin" 38 | "github.com/sirupsen/logrus" 39 | ) 40 | 41 | // 2016-09-27 09:38:21.541541811 +0200 CEST 42 | // 127.0.0.1 - frank [10/Oct/2000:13:55:36 -0700] 43 | // "GET /apache_pb.gif HTTP/1.0" 200 2326 44 | // "http://www.example.com/start.html" 45 | // "Mozilla/4.08 [en] (Win98; I ;Nav)" 46 | 47 | var timeFormat = "02/Jan/2006:15:04:05 -0700" 48 | 49 | // Logger is the logrus logger handler 50 | func Logger(logger logrus.FieldLogger, notLogged ...string) gin.HandlerFunc { 51 | hostname, err := os.Hostname() 52 | if err != nil { 53 | hostname = "unknow" 54 | } 55 | 56 | var skip map[string]struct{} 57 | 58 | if length := len(notLogged); length > 0 { 59 | skip = make(map[string]struct{}, length) 60 | 61 | for _, p := range notLogged { 62 | skip[p] = struct{}{} 63 | } 64 | } 65 | 66 | return func(c *gin.Context) { 67 | // other handler can change c.Path so: 68 | path := inputSanitized(c.Request.URL.Path) 69 | start := time.Now() 70 | c.Next() 71 | stop := time.Since(start) 72 | latency := int(math.Ceil(float64(stop.Nanoseconds()) / 1000000.0)) 73 | statusCode := c.Writer.Status() 74 | clientIP := inputSanitized(c.ClientIP()) 75 | clientUserAgent := inputSanitized(c.Request.UserAgent()) 76 | referer := inputSanitized(c.Request.Referer()) 77 | dataLength := c.Writer.Size() 78 | if dataLength < 0 { 79 | dataLength = 0 80 | } 81 | 82 | if _, ok := skip[path]; ok { 83 | return 84 | } 85 | 86 | entry := logger.WithFields(logrus.Fields{ 87 | "hostname": hostname, 88 | "statusCode": statusCode, 
// inputSanitized returns input with every line feed and carriage return
// removed, so that a value taken from the request cannot break a log entry
// into several lines.
func inputSanitized(input string) string {
	// nothing to strip: hand the string back unchanged
	if !strings.ContainsAny(input, "\r\n") {
		return input
	}

	var cleaned strings.Builder
	cleaned.Grow(len(input))
	// Both characters are single bytes, so a byte-wise copy leaves every
	// other byte of the input exactly as it was.
	for i := 0; i < len(input); i++ {
		if ch := input[i]; ch != '\n' && ch != '\r' {
			cleaned.WriteByte(ch)
		}
	}
	return cleaned.String()
}