├── .github └── workflows │ ├── dev-003.yaml │ ├── dev.yaml │ ├── main.yaml │ └── tools.yaml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── apis ├── account │ ├── account.go │ ├── routes.go │ ├── schemas.go │ ├── token.go │ └── user.go ├── chat │ ├── chat.go │ ├── image.go │ ├── routes.go │ └── schemas.go ├── config │ ├── config.go │ ├── routes.go │ └── schemas.go ├── default.go ├── record │ ├── api_ws.go │ ├── apis.go │ ├── infer.go │ ├── limiter.go │ ├── observe.go │ ├── openai.go │ ├── routes.go │ ├── schemas.go │ ├── utils.go │ └── yocsef.go └── routes.go ├── config ├── cache.go └── config.go ├── data ├── data.go ├── image.html ├── ip2region.xdb └── meta.json ├── docs ├── docs.go ├── swagger.json └── swagger.yaml ├── go.mod ├── go.sum ├── local.conf ├── main.go ├── middlewares └── init.go ├── models ├── active_status.go ├── base.go ├── chat.go ├── config.go ├── email_blacklist.go ├── init.go ├── invite_code.go ├── user.go └── user_offense.go ├── service └── yocsef.go └── utils ├── auth ├── identifier.go └── verification.go ├── errors.go ├── kong ├── jwt.go └── kong.go ├── logger.go ├── region.go ├── region_test.go ├── sender.go ├── sensitive ├── diting │ └── main.go ├── sensitive.go └── shumei │ └── main.go ├── tools ├── calculate.go ├── draw.go ├── main.go ├── schema.go ├── search.go └── solve.go ├── utils.go └── validate.go /.github/workflows/dev-003.yaml: -------------------------------------------------------------------------------- 1 | name: Dev Build 2 | on: 3 | push: 4 | branches: [ dev-0.0.3 ] 5 | 6 | env: 7 | APP_NAME: moss_backend 8 | 9 | jobs: 10 | docker: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@master 15 | - name: Set up QEMU 16 | uses: docker/setup-qemu-action@master 17 | - name: Set up Docker Buildx 18 | uses: docker/setup-buildx-action@master 19 | - name: Login to DockerHub 20 | uses: docker/login-action@master 21 | with: 22 | username: ${{ secrets.DOCKERHUB_USERNAME }} 23 | 
password: ${{ secrets.DOCKERHUB_TOKEN }} 24 | - name: Build and push 25 | id: docker_build 26 | uses: docker/build-push-action@master 27 | with: 28 | push: true 29 | tags: | 30 | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.APP_NAME }}:dev-0.0.3 31 | 32 | -------------------------------------------------------------------------------- /.github/workflows/dev.yaml: -------------------------------------------------------------------------------- 1 | name: Dev Build 2 | on: 3 | push: 4 | branches: [ dev ] 5 | 6 | env: 7 | APP_NAME: moss_backend 8 | 9 | jobs: 10 | docker: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@master 15 | - name: Set up QEMU 16 | uses: docker/setup-qemu-action@master 17 | - name: Set up Docker Buildx 18 | uses: docker/setup-buildx-action@master 19 | - name: Login to DockerHub 20 | uses: docker/login-action@master 21 | with: 22 | username: ${{ secrets.DOCKERHUB_USERNAME }} 23 | password: ${{ secrets.DOCKERHUB_TOKEN }} 24 | - name: Build and push 25 | id: docker_build 26 | uses: docker/build-push-action@master 27 | with: 28 | push: true 29 | tags: | 30 | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.APP_NAME }}:latest 31 | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.APP_NAME }}:dev 32 | 33 | -------------------------------------------------------------------------------- /.github/workflows/main.yaml: -------------------------------------------------------------------------------- 1 | name: Master Build 2 | on: 3 | push: 4 | branches: [ main ] 5 | 6 | env: 7 | APP_NAME: moss_backend 8 | 9 | jobs: 10 | docker: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@master 15 | 16 | - name: Set up QEMU 17 | uses: docker/setup-qemu-action@master 18 | 19 | - name: Set up Docker Buildx 20 | uses: docker/setup-buildx-action@master 21 | 22 | - name: Login to Aliyun ACR 23 | uses: aliyun/acr-login@master 24 | with: 25 | login-server: https://registry.cn-shanghai.aliyuncs.com 26 | username: 
${{ secrets.ACR_USERNAME }} 27 | password: ${{ secrets.ACR_PASSWORD }} 28 | 29 | - name: Build and push Aliyun ACR 30 | uses: docker/build-push-action@master 31 | with: 32 | push: true 33 | tags: | 34 | registry.cn-shanghai.aliyuncs.com/${{ secrets.ACR_NAMESPACE }}/${{ env.APP_NAME }}:latest 35 | registry.cn-shanghai.aliyuncs.com/${{ secrets.ACR_NAMESPACE }}/${{ env.APP_NAME }}:master 36 | 37 | - name: Login to DockerHub 38 | uses: docker/login-action@master 39 | with: 40 | username: ${{ secrets.DOCKERHUB_USERNAME }} 41 | password: ${{ secrets.DOCKERHUB_TOKEN }} 42 | 43 | - name: Build and push to DockerHub 44 | uses: docker/build-push-action@master 45 | with: 46 | push: true 47 | tags: | 48 | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.APP_NAME }}:latest 49 | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.APP_NAME }}:master 50 | 51 | -------------------------------------------------------------------------------- /.github/workflows/tools.yaml: -------------------------------------------------------------------------------- 1 | name: Dev Build 2 | on: 3 | push: 4 | branches: [ tools ] 5 | 6 | env: 7 | APP_NAME: moss_backend 8 | 9 | jobs: 10 | docker: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@master 15 | - name: Set up QEMU 16 | uses: docker/setup-qemu-action@master 17 | - name: Set up Docker Buildx 18 | uses: docker/setup-buildx-action@master 19 | - name: Login to DockerHub 20 | uses: docker/login-action@master 21 | with: 22 | username: ${{ secrets.DOCKERHUB_USERNAME }} 23 | password: ${{ secrets.DOCKERHUB_TOKEN }} 24 | - name: Build and push 25 | id: docker_build 26 | uses: docker/build-push-action@master 27 | with: 28 | push: true 29 | tags: | 30 | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.APP_NAME }}:tools 31 | 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | sqlite.db 3 | 
utils/tools/tools_test.go -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.22-alpine as builder 2 | 3 | WORKDIR /app 4 | 5 | COPY go.mod go.sum ./ 6 | RUN apk add --no-cache --virtual .build-deps \ 7 | ca-certificates \ 8 | tzdata \ 9 | gcc \ 10 | g++ && \ 11 | go mod download 12 | 13 | COPY . . 14 | 15 | RUN go build -ldflags "-s -w" -o auth 16 | 17 | FROM alpine 18 | 19 | # Installs latest Chromium package. 20 | RUN apk add --no-cache \ 21 | chromium-swiftshader \ 22 | ttf-freefont \ 23 | font-noto-emoji \ 24 | && apk add --no-cache \ 25 | --repository=https://dl-cdn.alpinelinux.org/alpine/edge/testing \ 26 | font-wqy-zenhei 27 | 28 | COPY local.conf /etc/fonts/local.conf 29 | 30 | WORKDIR /app 31 | 32 | COPY --from=builder /app/auth /app/ 33 | COPY --from=builder /usr/share/zoneinfo /usr/share/zoneinfo 34 | COPY data data 35 | 36 | ENV TZ=Asia/Shanghai 37 | 38 | ENV MODE=production 39 | 40 | RUN mkdir -p ./screenshots 41 | RUN mkdir -p ./draw 42 | 43 | VOLUME ["/app/screenshots"] 44 | VOLUME ["/app/draw"] 45 | 46 | EXPOSE 8000 47 | 48 | ENTRYPOINT ["./auth"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MOSS backend 2 | 3 | backend for fastnlp MOSS project 4 | 5 | Only open source here. Don't host this project yourself. No inference model is included 6 | 7 | 仅开源,不建议自部署,不包含推断模型。 8 | 9 | ### Powered by 10 | 11 |  12 |  13 | 14 | 15 | ### Contributors 16 | 17 | This project exists thanks to all the people who contribute. 18 | 19 | 20 | 21 | 22 | 23 | ## Licence 24 | 25 | [](https://github.com/OpenTreeHole/auth_next/blob/master/LICENSE) 26 | © JingYiJun 27 | -------------------------------------------------------------------------------- /apis/account/account.go: -------------------------------------------------------------------------------- 1 | package account 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | "time" 7 | 8 | "MOSS_backend/config" 9 | . "MOSS_backend/models" 10 | . 
"MOSS_backend/utils" 11 | "MOSS_backend/utils/auth" 12 | "MOSS_backend/utils/kong" 13 | 14 | "github.com/gofiber/fiber/v2" 15 | "gorm.io/gorm" 16 | "gorm.io/gorm/clause" 17 | ) 18 | 19 | // Register godoc 20 | // 21 | // @Summary register 22 | // @Description register with email or phone, password and verification code 23 | // @Tags account 24 | // @Accept json 25 | // @Produce json 26 | // @Router /register [post] 27 | // @Param json body RegisterRequest true "json" 28 | // @Success 201 {object} TokenResponse 29 | // @Failure 400 {object} utils.MessageResponse "验证码错误、用户已注册" 30 | // @Failure 500 {object} utils.MessageResponse 31 | func Register(c *fiber.Ctx) error { 32 | scope := "register" 33 | var ( 34 | body RegisterRequest 35 | ok bool 36 | ) 37 | err := ValidateBody(c, &body) 38 | if err != nil { 39 | return err 40 | } 41 | 42 | var ( 43 | user User 44 | registered = false 45 | deleted = false 46 | inviteCode InviteCode 47 | ) 48 | 49 | errCollection, messageCollection := GetInfoByIP(GetRealIP(c)) 50 | 51 | // check invite code config 52 | var configObject Config 53 | err = LoadConfig(&configObject) 54 | if err != nil { 55 | return err 56 | } 57 | 58 | // check verification code first 59 | if body.PhoneModel != nil { 60 | ok = auth.CheckVerificationCode(body.Phone, scope, body.Verification) 61 | } else if body.EmailModel != nil { 62 | if IsEmailInBlacklist(body.Email) { 63 | return errCollection.ErrEmailInBlacklist 64 | } 65 | ok = auth.CheckVerificationCode(body.Email, scope, body.Verification) 66 | } 67 | if !ok { 68 | return errCollection.ErrVerificationCodeInvalid 69 | } 70 | 71 | // check Invite code 72 | var inviteRequired = configObject.InviteRequired 73 | if body.EmailModel != nil { 74 | // check email suffix in no need invite code 75 | for _, emailSuffix := range config.Config.NoNeedInviteCodeEmailSuffix { 76 | if strings.HasSuffix(body.Email, emailSuffix) { 77 | inviteRequired = false 78 | break 79 | } 80 | } 81 | } 82 | if inviteRequired { 83 | if 
body.InviteCode == nil { 84 | return errCollection.ErrNeedInviteCode 85 | } 86 | err = DB.Take(&inviteCode, "code = ?", body.InviteCode).Error 87 | if err != nil || !inviteCode.IsSend || inviteCode.IsActivated { 88 | return errCollection.ErrInviteCodeInvalid 89 | } 90 | } 91 | 92 | if body.PhoneModel != nil { 93 | err = DB.Unscoped().Take(&user, "phone = ?", body.Phone).Error 94 | if err != nil { 95 | if !errors.Is(err, gorm.ErrRecordNotFound) { 96 | return err 97 | } 98 | registered = false 99 | user.Phone = body.Phone 100 | } else { 101 | registered = true 102 | deleted = user.DeletedAt.Valid 103 | } 104 | } else if body.EmailModel != nil { 105 | err = DB.Unscoped().Take(&user, "email = ?", body.Email).Error 106 | if err != nil { 107 | if !errors.Is(err, gorm.ErrRecordNotFound) { 108 | return err 109 | } 110 | registered = false 111 | user.Email = body.Email 112 | } else { 113 | registered = true 114 | deleted = user.DeletedAt.Valid 115 | } 116 | } else { 117 | return BadRequest() 118 | } 119 | 120 | user.Password, err = auth.MakePassword(body.Password) 121 | if err != nil { 122 | return err 123 | } 124 | remoteIP := GetRealIP(c) 125 | user.ModelID = config.Config.DefaultModelID 126 | 127 | if registered { 128 | if deleted { 129 | err = DB.Unscoped().Model(&user).Update("DeletedAt", gorm.Expr("NULL")).Error 130 | if err != nil { 131 | return err 132 | } 133 | 134 | user.DeletedAt.Valid = false 135 | user.DeletedAt.Time = time.Unix(0, 0) 136 | 137 | user.JoinedTime = time.Now() 138 | user.RegisterIP = remoteIP 139 | user.LoginIP = []string{} 140 | user.UpdateIP(remoteIP) 141 | user.ShareConsent = true 142 | // set invite code 143 | if inviteRequired { 144 | user.InviteCode = inviteCode.Code 145 | } 146 | err = DB.Save(&user).Error 147 | if err != nil { 148 | return err 149 | } 150 | } else { 151 | return errCollection.ErrRegistered 152 | } 153 | } else { 154 | user.RegisterIP = remoteIP 155 | user.UpdateIP(remoteIP) 156 | user.ShareConsent = true 157 | 158 | // 
set invite code 159 | if inviteRequired { 160 | user.InviteCode = inviteCode.Code 161 | } 162 | 163 | err = DB.Create(&user).Error 164 | if err != nil { 165 | return err 166 | } 167 | 168 | err = kong.CreateUser(user.ID) 169 | if err != nil { 170 | return err 171 | } 172 | } 173 | 174 | // create kong token 175 | accessToken, refreshToken, err := kong.CreateToken(&user) 176 | if err != nil { 177 | return err 178 | } 179 | 180 | // delete verification 181 | if body.EmailModel != nil { 182 | _ = auth.DeleteVerificationCode(body.Email, scope) 183 | } else { 184 | _ = auth.DeleteVerificationCode(body.Phone, scope) 185 | } 186 | 187 | // update inviteCode 188 | if inviteRequired { 189 | inviteCode.IsActivated = true 190 | DB.Save(&inviteCode) 191 | } 192 | 193 | return c.JSON(TokenResponse{ 194 | Access: accessToken, 195 | Refresh: refreshToken, 196 | Message: messageCollection.MessageRegisterSuccess, 197 | }) 198 | } 199 | 200 | // ChangePassword godoc 201 | // 202 | // @Summary reset password 203 | // @Description reset password, reset jwt credential 204 | // @Tags account 205 | // @Accept json 206 | // @Produce json 207 | // @Router /register [put] 208 | // @Param json body RegisterRequest true "json" 209 | // @Success 200 {object} TokenResponse 210 | // @Failure 400 {object} utils.MessageResponse "验证码错误" 211 | // @Failure 500 {object} utils.MessageResponse 212 | func ChangePassword(c *fiber.Ctx) error { 213 | scope := "reset" 214 | var ( 215 | body RegisterRequest 216 | ok bool 217 | ) 218 | err := ValidateBody(c, &body) 219 | if err != nil { 220 | return err 221 | } 222 | 223 | errCollection, messageCollection := GetInfoByIP(GetRealIP(c)) 224 | 225 | if body.PhoneModel != nil { 226 | ok = auth.CheckVerificationCode(body.Phone, scope, body.Verification) 227 | } else if body.EmailModel != nil { 228 | if IsEmailInBlacklist(body.Email) { 229 | return errCollection.ErrEmailInBlacklist 230 | } 231 | ok = auth.CheckVerificationCode(body.Email, scope, body.Verification) 
232 | } 233 | if !ok { 234 | return errCollection.ErrVerificationCodeInvalid 235 | } 236 | 237 | var user User 238 | err = DB.Transaction(func(tx *gorm.DB) error { 239 | querySet := tx.Clauses(clause.Locking{Strength: "UPDATE"}) 240 | if body.PhoneModel != nil { 241 | querySet = querySet.Where("phone = ?", body.Phone) 242 | } else if body.EmailModel != nil { 243 | querySet = querySet.Where("email = ?", body.Email) 244 | } else { 245 | return BadRequest() 246 | } 247 | err = querySet.Take(&user).Error 248 | if err != nil { 249 | return err 250 | } 251 | 252 | user.Password, err = auth.MakePassword(body.Password) 253 | if err != nil { 254 | return err 255 | } 256 | return tx.Save(&user).Error 257 | }) 258 | if err != nil { 259 | return err 260 | } 261 | 262 | err = kong.DeleteJwtCredential(user.ID) 263 | if err != nil { 264 | return err 265 | } 266 | 267 | accessToken, refreshToken, err := kong.CreateToken(&user) 268 | if err != nil { 269 | return err 270 | } 271 | 272 | if body.EmailModel != nil { 273 | err = auth.DeleteVerificationCode(body.Email, scope) 274 | } else { 275 | err = auth.DeleteVerificationCode(body.Phone, scope) 276 | } 277 | if err != nil { 278 | return err 279 | } 280 | 281 | return c.JSON(TokenResponse{ 282 | Access: accessToken, 283 | Refresh: refreshToken, 284 | Message: messageCollection.MessageResetPasswordSuccess, 285 | }) 286 | } 287 | 288 | // VerifyWithEmail godoc 289 | // 290 | // @Summary verify with email in query 291 | // @Description verify with email in query, Send verification email 292 | // @Tags account 293 | // @Produce json 294 | // @Router /verify/email [get] 295 | // @Param email query VerifyEmailRequest true "email" 296 | // @Success 200 {object} VerifyResponse 297 | // @Failure 400 {object} utils.MessageResponse "已注册“ 298 | func VerifyWithEmail(c *fiber.Ctx) error { 299 | var query VerifyEmailRequest 300 | err := ValidateQuery(c, &query) 301 | if err != nil { 302 | return err 303 | } 304 | 305 | errCollection, 
messageCollection := GetInfoByIP(GetRealIP(c)) 306 | if IsEmailInBlacklist(query.Email) { 307 | return errCollection.ErrEmailInBlacklist 308 | } 309 | 310 | var ( 311 | user User 312 | scope string 313 | login bool 314 | inviteCode InviteCode 315 | ) 316 | userID, _ := GetUserID(c) 317 | login = userID != 0 318 | // check invite code config 319 | var configObject Config 320 | err = LoadConfig(&configObject) 321 | if err != nil { 322 | return err 323 | } 324 | 325 | err = DB.Take(&user, "email = ?", query.Email).Error 326 | if err != nil { 327 | if !errors.Is(err, gorm.ErrRecordNotFound) { 328 | return err 329 | } 330 | if !login { 331 | scope = "register" 332 | } else { 333 | scope = "modify" 334 | DeleteUserCacheByID(userID) // 已登录的,清空缓存 335 | } 336 | } else { 337 | if !login { 338 | scope = "reset" 339 | } else { 340 | return errCollection.ErrEmailRegistered 341 | } 342 | } 343 | if query.Scope != "" { 344 | if scope != query.Scope { 345 | switch scope { 346 | case "register": 347 | return errCollection.ErrEmailNotRegistered 348 | case "reset": 349 | switch query.Scope { 350 | case "register": 351 | return errCollection.ErrEmailRegistered 352 | case "modify": 353 | return errCollection.ErrEmailCannotModify 354 | default: 355 | return BadRequest() 356 | } 357 | case "modify": 358 | switch query.Scope { 359 | case "register": 360 | return errCollection.ErrEmailRegistered 361 | case "reset": 362 | return errCollection.ErrEmailCannotReset 363 | default: 364 | return BadRequest() 365 | } 366 | default: 367 | return BadRequest() 368 | } 369 | } 370 | } 371 | 372 | if scope == "register" { 373 | // check Invite code 374 | inviteRequired := configObject.InviteRequired 375 | // check email suffix in no need invite code 376 | for _, emailSuffix := range config.Config.NoNeedInviteCodeEmailSuffix { 377 | if strings.HasSuffix(query.Email, emailSuffix) { 378 | inviteRequired = false 379 | break 380 | } 381 | } 382 | if inviteRequired { 383 | if query.InviteCode == nil { 384 | 
return errCollection.ErrNeedInviteCode 385 | } 386 | err = DB.Take(&inviteCode, "code = ?", query.InviteCode).Error 387 | if err != nil || !inviteCode.IsSend || inviteCode.IsActivated { 388 | return errCollection.ErrInviteCodeInvalid 389 | } 390 | } 391 | } 392 | 393 | code, err := auth.SetVerificationCode(query.Email, scope) 394 | if err != nil { 395 | return err 396 | } 397 | 398 | err = SendCodeEmail(code, query.Email) 399 | if err != nil { 400 | return err 401 | } 402 | 403 | return c.JSON(VerifyResponse{ 404 | Message: messageCollection.MessageVerificationEmailSend, 405 | Scope: scope, 406 | }) 407 | } 408 | 409 | // VerifyWithPhone godoc 410 | // 411 | // @Summary verify with phone in query 412 | // @Description verify with phone in query, Send verification message 413 | // @Tags account 414 | // @Produce json 415 | // @Router /verify/phone [get] 416 | // @Param phone query VerifyPhoneRequest true "phone" 417 | // @Success 200 {object} VerifyResponse 418 | // @Failure 400 {object} utils.MessageResponse "已注册“ 419 | func VerifyWithPhone(c *fiber.Ctx) error { 420 | var query VerifyPhoneRequest 421 | err := ValidateQuery(c, &query) 422 | if err != nil { 423 | return BadRequest("invalid phone number") 424 | } 425 | 426 | errCollection, messageCollection := GetInfoByIP(GetRealIP(c)) 427 | 428 | var ( 429 | user User 430 | scope string 431 | login bool 432 | inviteCode InviteCode 433 | ) 434 | 435 | // check invite code config 436 | var configObject Config 437 | err = LoadConfig(&configObject) 438 | if err != nil { 439 | return err 440 | } 441 | 442 | userID, _ := GetUserID(c) 443 | login = userID != 0 444 | err = DB.Take(&user, "phone = ?", query.Phone).Error 445 | if err != nil { 446 | if !errors.Is(err, gorm.ErrRecordNotFound) { 447 | return err 448 | } 449 | if !login { 450 | scope = "register" // 未注册、未登录 451 | } else { 452 | scope = "modify" // 未注册、已登录 453 | DeleteUserCacheByID(userID) 454 | } 455 | } else { 456 | if !login { 457 | scope = "reset" // 已注册、未登录 
458 | } else { 459 | return errCollection.ErrPhoneRegistered // 已注册、已登录 460 | } 461 | } 462 | 463 | if query.Scope != "" { 464 | if scope != query.Scope { 465 | switch scope { 466 | case "register": 467 | return errCollection.ErrPhoneNotRegistered 468 | case "reset": 469 | switch query.Scope { 470 | case "register": 471 | return errCollection.ErrPhoneRegistered 472 | case "modify": 473 | return errCollection.ErrPhoneCannotModify 474 | default: 475 | return BadRequest() 476 | } 477 | case "modify": 478 | switch query.Scope { 479 | case "register": 480 | return errCollection.ErrPhoneRegistered 481 | case "reset": 482 | return errCollection.ErrPhoneCannotReset 483 | default: 484 | return BadRequest() 485 | } 486 | default: 487 | return BadRequest() 488 | } 489 | } 490 | } 491 | 492 | if scope == "register" { 493 | // check Invite code 494 | if configObject.InviteRequired { 495 | if query.InviteCode == nil { 496 | return errCollection.ErrNeedInviteCode 497 | } 498 | err = DB.Take(&inviteCode, "code = ?", query.InviteCode).Error 499 | if err != nil || !inviteCode.IsSend || inviteCode.IsActivated { 500 | return errCollection.ErrInviteCodeInvalid 501 | } 502 | } 503 | } 504 | code, err := auth.SetVerificationCode(query.Phone, scope) 505 | if err != nil { 506 | return err 507 | } 508 | 509 | err = SendCodeMessage(code, query.Phone) 510 | if err != nil { 511 | return err 512 | } 513 | 514 | return c.JSON(VerifyResponse{ 515 | Message: messageCollection.MessageVerificationPhoneSend, 516 | Scope: scope, 517 | }) 518 | } 519 | 520 | // DeleteUser godoc 521 | // 522 | // @Summary delete user 523 | // @Description delete user and related jwt credentials 524 | // @Tags account 525 | // @Router /users/me [delete] 526 | // @Param json body LoginRequest true "email, password" 527 | // @Success 204 528 | // @Failure 400 {object} utils.MessageResponse "密码错误“ 529 | // @Failure 404 {object} utils.MessageResponse "用户不存在“ 530 | // @Failure 500 {object} utils.MessageResponse 531 | func 
DeleteUser(c *fiber.Ctx) error { 532 | var body LoginRequest 533 | err := ValidateBody(c, &body) 534 | if err != nil { 535 | return err 536 | } 537 | 538 | errCollection, _ := GetInfoByIP(GetRealIP(c)) 539 | 540 | var user User 541 | err = DB.Transaction(func(tx *gorm.DB) error { 542 | querySet := tx.Clauses(clause.Locking{Strength: "UPDATE"}) 543 | if body.PhoneModel != nil { 544 | querySet = querySet.Where("phone = ?", body.Phone) 545 | } else if body.EmailModel != nil { 546 | querySet = querySet.Where("email = ?", body.Email) 547 | } else { 548 | return BadRequest() 549 | } 550 | err = querySet.Take(&user).Error 551 | if err != nil { 552 | return err 553 | } 554 | 555 | ok, err := auth.CheckPassword(body.Password, user.Password) 556 | if err != nil { 557 | return err 558 | } 559 | if !ok { 560 | return errCollection.ErrPasswordIncorrect 561 | } 562 | 563 | return tx.Delete(&user).Error 564 | }) 565 | 566 | DeleteUserCacheByID(user.ID) 567 | 568 | err = kong.DeleteJwtCredential(user.ID) 569 | if err != nil { 570 | return err 571 | } 572 | 573 | return c.SendStatus(204) 574 | } 575 | -------------------------------------------------------------------------------- /apis/account/routes.go: -------------------------------------------------------------------------------- 1 | package account 2 | 3 | import "github.com/gofiber/fiber/v2" 4 | 5 | func RegisterRoutes(routes fiber.Router) { 6 | // token 7 | routes.Post("/login", Login) 8 | routes.Get("/logout", Logout) 9 | routes.Post("/refresh", Refresh) 10 | 11 | // account management 12 | routes.Get("/verify/email", VerifyWithEmail) 13 | routes.Get("/verify/phone", VerifyWithPhone) 14 | routes.Post("/register", Register) 15 | routes.Put("/register", ChangePassword) 16 | routes.Delete("/users/me", DeleteUser) 17 | 18 | // user info 19 | routes.Get("/users/me", GetCurrentUser) 20 | routes.Put("/users/me", ModifyUser) 21 | } 22 | -------------------------------------------------------------------------------- 
/apis/account/schemas.go: -------------------------------------------------------------------------------- 1 | package account 2 | 3 | /* account */ 4 | 5 | type EmailModel struct { 6 | Email string `json:"email" query:"email" validate:"required,email"` 7 | } 8 | 9 | type ScopeModel struct { 10 | Scope string `json:"scope" query:"scope" validate:"omitempty,oneof=register reset modify"` 11 | } 12 | 13 | type PhoneModel struct { 14 | Phone string `json:"phone" query:"phone" validate:"required"` // phone number in e164 mode 15 | } 16 | 17 | type VerifyEmailRequest struct { 18 | EmailModel 19 | ScopeModel 20 | InviteCode *string `json:"invite_code" query:"invite_code" validate:"omitempty,min=1"` 21 | } 22 | 23 | type VerifyPhoneRequest struct { 24 | PhoneModel 25 | ScopeModel 26 | InviteCode *string `json:"invite_code" query:"invite_code" validate:"omitempty,min=1"` 27 | } 28 | 29 | type LoginRequest struct { 30 | *EmailModel `validate:"omitempty"` 31 | *PhoneModel `validate:"omitempty"` 32 | Password string `json:"password" minLength:"8"` 33 | } 34 | 35 | type TokenResponse struct { 36 | Access string `json:"access"` 37 | Refresh string `json:"refresh"` 38 | Message string `json:"message"` 39 | } 40 | 41 | type RegisterRequest struct { 42 | LoginRequest 43 | Verification string `json:"verification" minLength:"6" maxLength:"6" validate:"len=6"` 44 | InviteCode *string `json:"invite_code" validate:"omitempty,min=1"` 45 | } 46 | 47 | type VerifyResponse struct { 48 | Message string `json:"message"` 49 | Scope string `json:"scope" enums:"register,reset"` 50 | } 51 | 52 | type ModifyUserRequest struct { 53 | Nickname *string `json:"nickname" validate:"omitempty,min=1"` 54 | ShareConsent *bool `json:"share_consent"` 55 | *EmailModel `validate:"omitempty"` 56 | *PhoneModel `validate:"omitempty"` 57 | Verification string `json:"verification" minLength:"6" maxLength:"6" validate:"omitempty,len=6"` 58 | DisableSensitiveCheck *bool `json:"disable_sensitive_check"` 59 | ModelID 
*int `json:"model_id" validate:"omitempty,min=1"` 60 | PluginConfig map[string]bool `json:"plugin_config" validate:"omitempty"` 61 | } 62 | -------------------------------------------------------------------------------- /apis/account/token.go: -------------------------------------------------------------------------------- 1 | package account 2 | 3 | import ( 4 | . "MOSS_backend/models" 5 | . "MOSS_backend/utils" 6 | "MOSS_backend/utils/auth" 7 | "MOSS_backend/utils/kong" 8 | "errors" 9 | "github.com/gofiber/fiber/v2" 10 | "gorm.io/gorm" 11 | ) 12 | 13 | // Login godoc 14 | // 15 | // @Summary Login 16 | // @Description Login with email and password, return jwt token, not need jwt 17 | // @Tags token 18 | // @Accept json 19 | // @Produce json 20 | // @Router /login [post] 21 | // @Param json body LoginRequest true "json" 22 | // @Success 200 {object} TokenResponse 23 | // @Failure 400 {object} utils.MessageResponse 24 | // @Failure 404 {object} utils.MessageResponse "User Not Found" 25 | // @Failure 500 {object} utils.MessageResponse 26 | func Login(c *fiber.Ctx) error { 27 | var body LoginRequest 28 | err := ValidateBody(c, &body) 29 | if err != nil { 30 | return err 31 | } 32 | 33 | errCollection, messageCollection := GetInfoByIP(GetRealIP(c)) 34 | 35 | var user User 36 | if body.EmailModel != nil { 37 | err = DB.Where("email = ?", body.Email).Take(&user).Error 38 | } else if body.PhoneModel != nil { 39 | err = DB.Where("phone = ?", body.Phone).Take(&user).Error 40 | } else { 41 | return BadRequest() 42 | } 43 | if err != nil { 44 | if errors.Is(err, gorm.ErrRecordNotFound) { 45 | return NotFound("User Not Found") 46 | } else { 47 | return err 48 | } 49 | } 50 | 51 | ok, err := auth.CheckPassword(body.Password, user.Password) 52 | if err != nil { 53 | return err 54 | } 55 | if !ok { 56 | return errCollection.ErrPasswordIncorrect 57 | } 58 | 59 | // update login time and ip 60 | user.UpdateIP(GetRealIP(c)) 61 | err = DB.Save(&user).Error 62 | if err != nil { 63 | 
return err 64 | } 65 | 66 | access, refresh, err := kong.CreateToken(&user) 67 | if err != nil { 68 | return err 69 | } 70 | 71 | return c.JSON(TokenResponse{ 72 | Access: access, 73 | Refresh: refresh, 74 | Message: messageCollection.MessageLoginSuccess, 75 | }) 76 | } 77 | 78 | // Logout 79 | // 80 | // @Summary Logout 81 | // @Description Logout, clear jwt credential and return successful message, logout, login required 82 | // @Tags token 83 | // @Produce json 84 | // @Router /logout [get] 85 | // @Success 200 {object} utils.MessageResponse 86 | func Logout(c *fiber.Ctx) error { 87 | userID, err := GetUserID(c) 88 | if err != nil { 89 | return err 90 | } 91 | 92 | _, messageCollection := GetInfoByIP(GetRealIP(c)) 93 | 94 | var user User 95 | err = LoadUserByIDFromCache(userID, &user) 96 | if err != nil { 97 | return err 98 | } 99 | 100 | err = kong.DeleteJwtCredential(userID) 101 | if err != nil { 102 | return err 103 | } 104 | 105 | return c.JSON(MessageResponse{Message: messageCollection.MessageLogoutSuccess}) 106 | } 107 | 108 | // Refresh 109 | // 110 | // @Summary Refresh jwt token 111 | // @Description Refresh jwt token with refresh token in header, login required 112 | // @Tags token 113 | // @Produce json 114 | // @Router /refresh [post] 115 | // @Success 200 {object} TokenResponse 116 | func Refresh(c *fiber.Ctx) error { 117 | user, err := GetUserByRefreshToken(c) 118 | if err != nil { 119 | return err 120 | } 121 | 122 | // update login time and ip 123 | user.UpdateIP(GetRealIP(c)) 124 | err = DB.Model(&user).Select("LastLoginIP", "LoginIP").Save(&user).Error 125 | if err != nil { 126 | return err 127 | } 128 | 129 | access, refresh, err := kong.CreateToken(user) 130 | if err != nil { 131 | return err 132 | } 133 | return c.JSON(TokenResponse{ 134 | Access: access, 135 | Refresh: refresh, 136 | Message: "refresh successful", 137 | }) 138 | } 139 | -------------------------------------------------------------------------------- /apis/account/user.go: 
-------------------------------------------------------------------------------- 1 | package account 2 | 3 | import ( 4 | "MOSS_backend/config" 5 | . "MOSS_backend/models" 6 | . "MOSS_backend/utils" 7 | "MOSS_backend/utils/auth" 8 | 9 | "github.com/gofiber/fiber/v2" 10 | "gorm.io/gorm" 11 | ) 12 | 13 | // GetCurrentUser godoc 14 | // 15 | // @Summary get current user 16 | // @Tags user 17 | // @Produce json 18 | // @Router /users/me [get] 19 | // @Success 200 {object} User 20 | // @Failure 404 {object} utils.MessageResponse "User not found" 21 | // @Failure 500 {object} utils.MessageResponse 22 | func GetCurrentUser(c *fiber.Ctx) error { 23 | user, err := LoadUser(c) 24 | if err != nil { 25 | return err 26 | } 27 | return c.JSON(user) 28 | } 29 | 30 | // ModifyUser godoc 31 | // 32 | // @Summary modify user, need login 33 | // @Tags user 34 | // @Produce json 35 | // @Router /users/me [put] 36 | // @Param json body ModifyUserRequest true "json" 37 | // @Success 200 {object} User 38 | // @Failure 500 {object} utils.MessageResponse 39 | func ModifyUser(c *fiber.Ctx) error { 40 | scope := "modify" 41 | var body ModifyUserRequest 42 | err := ValidateBody(c, &body) 43 | if err != nil { 44 | return err 45 | } 46 | 47 | userID, err := GetUserID(c) 48 | if err != nil { 49 | return err 50 | } 51 | 52 | var user User 53 | err = DB.Transaction(func(tx *gorm.DB) error { 54 | err = tx.Clauses(LockingClause).Take(&user, userID).Error 55 | if err != nil { 56 | return err 57 | } 58 | 59 | if body.Nickname != nil { 60 | user.Nickname = *body.Nickname 61 | } 62 | 63 | if body.ShareConsent != nil { 64 | user.ShareConsent = *body.ShareConsent 65 | } 66 | 67 | if body.EmailModel != nil && body.Email != user.Email { 68 | ok := auth.CheckVerificationCode(body.Email, scope, body.Verification) 69 | if !ok { 70 | return BadRequest("verification code error") 71 | } 72 | 73 | user.Email = body.Email 74 | } 75 | 76 | if body.PhoneModel != nil && body.Phone != user.Phone { 77 | ok := 
auth.CheckVerificationCode(body.Phone, scope, body.Verification) 78 | if !ok { 79 | return BadRequest("verification code error") 80 | } 81 | 82 | user.Phone = body.Phone 83 | } 84 | 85 | if body.DisableSensitiveCheck != nil { 86 | if !user.IsAdmin { 87 | return Forbidden() 88 | } 89 | user.DisableSensitiveCheck = *body.DisableSensitiveCheck 90 | } 91 | if body.ModelID != nil { // model switch 92 | user.ModelID = *body.ModelID 93 | } 94 | var defaultPluginConfig map[string]bool 95 | 96 | // model switch or plugin config change => update plugin config 97 | if body.ModelID != nil || body.PluginConfig != nil { 98 | // init ModelID 99 | if user.ModelID == 0 { 100 | user.ModelID = config.Config.DefaultModelID 101 | } 102 | 103 | // model switch 104 | if body.ModelID != nil { 105 | user.ModelID = *body.ModelID 106 | } 107 | 108 | // init plugin config 109 | defaultPluginConfig, err = GetPluginConfig(user.ModelID) 110 | if err != nil { 111 | return InternalServerError("Failed to change plugin config, please try again later") 112 | } 113 | if user.PluginConfig == nil { 114 | user.PluginConfig = defaultPluginConfig 115 | } 116 | 117 | // plugin config change 118 | if body.PluginConfig != nil { 119 | for key, value := range body.PluginConfig { 120 | if _, ok := defaultPluginConfig[key]; ok { 121 | user.PluginConfig[key] = value 122 | } 123 | } 124 | } 125 | } 126 | return tx.Save(&user).Error 127 | }) 128 | 129 | if err != nil { 130 | return err 131 | } 132 | 133 | // redis update 134 | _ = config.SetCache(GetUserCacheKey(user.ID), user, UserCacheExpire) 135 | 136 | return c.JSON(user) 137 | } 138 | -------------------------------------------------------------------------------- /apis/chat/chat.go: -------------------------------------------------------------------------------- 1 | package chat 2 | 3 | import ( 4 | . "MOSS_backend/models" 5 | . 
"MOSS_backend/utils" 6 | "fmt" 7 | "github.com/gofiber/fiber/v2" 8 | "github.com/google/uuid" 9 | "gorm.io/gorm" 10 | "os" 11 | ) 12 | 13 | // ListChats 14 | // @Summary list user's chats 15 | // @Tags chat 16 | // @Router /chats [get] 17 | // @Success 200 {array} models.Chat 18 | func ListChats(c *fiber.Ctx) error { 19 | userID, err := GetUserID(c) 20 | if err != nil { 21 | return err 22 | } 23 | 24 | // delete empty chats 25 | err = DB.Where("user_id = ? and count = 0", userID).Delete(&Chat{}).Error 26 | if err != nil { 27 | return err 28 | } 29 | 30 | // get all chats 31 | var chats = Chats{} 32 | err = DB.Order("updated_at desc").Find(&chats, "user_id = ?", userID).Error 33 | if err != nil { 34 | return err 35 | } 36 | 37 | return c.JSON(chats) 38 | } 39 | 40 | // AddChat 41 | // @Summary add a chat 42 | // @Tags chat 43 | // @Router /chats [post] 44 | // @Success 201 {object} models.Chat 45 | func AddChat(c *fiber.Ctx) error { 46 | userID, err := GetUserID(c) 47 | if err != nil { 48 | return err 49 | } 50 | 51 | chat := Chat{UserID: userID} 52 | err = DB.Create(&chat).Error 53 | if err != nil { 54 | return err 55 | } 56 | 57 | return c.Status(201).JSON(chat) 58 | } 59 | 60 | // ModifyChat 61 | // @Summary modify a chat 62 | // @Tags chat 63 | // @Router /chats/{chat_id} [put] 64 | // @Param chat_id path int true "chat id" 65 | // @Param json body ModifyModel true "json" 66 | // @Success 200 {object} models.Chat 67 | func ModifyChat(c *fiber.Ctx) error { 68 | userID, err := GetUserID(c) 69 | if err != nil { 70 | return err 71 | } 72 | 73 | chatID, err := c.ParamsInt("id") 74 | if err != nil { 75 | return err 76 | } 77 | 78 | var body ModifyModel 79 | err = ValidateBody(c, &body) 80 | if err != nil { 81 | return err 82 | } 83 | 84 | var chat Chat 85 | err = DB.Transaction(func(tx *gorm.DB) error { 86 | err = tx.Clauses(LockingClause).Take(&chat, chatID).Error 87 | if err != nil { 88 | return err 89 | } 90 | 91 | if chat.UserID != userID { 92 | return Forbidden() 
93 | } 94 | 95 | if body.Name != nil { 96 | chat.Name = *body.Name 97 | } 98 | 99 | return tx.Save(&chat).Error 100 | }) 101 | if err != nil { 102 | return err 103 | } 104 | 105 | return c.JSON(chat) 106 | } 107 | 108 | // DeleteChat 109 | // @Summary delete a chat 110 | // @Tags chat 111 | // @Router /chats/{chat_id} [delete] 112 | // @Param chat_id path int true "chat id" 113 | // @Success 204 114 | func DeleteChat(c *fiber.Ctx) error { 115 | userID, err := GetUserID(c) 116 | if err != nil { 117 | return err 118 | } 119 | 120 | chatID, err := c.ParamsInt("id") 121 | if err != nil { 122 | return err 123 | } 124 | 125 | var chat Chat 126 | err = DB.Transaction(func(tx *gorm.DB) error { 127 | err = tx.Clauses(LockingClause).Take(&chat, chatID).Error 128 | if err != nil { 129 | return err 130 | } 131 | 132 | if chat.UserID != userID { 133 | return Forbidden() 134 | } 135 | 136 | return tx.Delete(&chat).Error 137 | }) 138 | if err != nil { 139 | return err 140 | } 141 | 142 | return c.SendStatus(204) 143 | } 144 | 145 | // GenerateChatScreenshot 146 | // @Summary screenshot of a chat 147 | // @Tags record 148 | // @Produce png 149 | // @Router /chats/{chat_id}/screenshots [get] 150 | // @Param chat_id path int true "chat id" 151 | // @Success 200 152 | func GenerateChatScreenshot(c *fiber.Ctx) error { 153 | chatID, err := c.ParamsInt("id") 154 | if err != nil { 155 | return err 156 | } 157 | 158 | userID, err := GetUserID(c) 159 | if err != nil { 160 | return err 161 | } 162 | 163 | var chat Chat 164 | err = DB.Take(&chat, chatID).Error 165 | if err != nil { 166 | return err 167 | } 168 | 169 | if userID != chat.UserID { 170 | return Forbidden() 171 | } 172 | 173 | var records Records 174 | err = DB.Find(&records, "chat_id = ? 
and request_sensitive <> true and response_sensitive <> true", chatID).Error 175 | if err != nil { 176 | return err 177 | } 178 | 179 | buf, err := GenerateImage(records.ToRecordModel()) 180 | if err != nil { 181 | return err 182 | } 183 | 184 | filename := uuid.NewString() + ".png" 185 | err = os.WriteFile(fmt.Sprintf("./screenshots/%s", filename), buf, 0644) 186 | if err != nil { 187 | return err 188 | } 189 | 190 | url := fmt.Sprintf("https://%s/api/screenshots/%s", c.Get("Host"), filename) 191 | return c.JSON(Map{"url": url}) 192 | } 193 | -------------------------------------------------------------------------------- /apis/chat/image.go: -------------------------------------------------------------------------------- 1 | package chat 2 | 3 | import ( 4 | "MOSS_backend/data" 5 | "MOSS_backend/models" 6 | "context" 7 | "github.com/chromedp/chromedp" 8 | "net/http" 9 | "net/http/httptest" 10 | "strings" 11 | "text/template" 12 | ) 13 | 14 | var imageTemplate, _ = template.New("image").Funcs(map[string]any{"replace": ContentProcess}).Parse(string(data.ImageTemplate)) 15 | 16 | func GenerateImage(records []models.RecordModel) ([]byte, error) { 17 | // disable javascript in headless chrome 18 | opts := append(chromedp.DefaultExecAllocatorOptions[:], 19 | chromedp.Flag("blink-settings", "scriptEnabled=false"), 20 | ) 21 | ctx, cancel := chromedp.NewExecAllocator(context.Background(), opts...) 
22 | defer cancel() 23 | 24 | ctx, cancel = chromedp.NewContext(context.Background()) 25 | defer cancel() 26 | 27 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 28 | w.Header().Set("Content-Type", "text/html") 29 | _ = imageTemplate.Execute(w, struct { 30 | Records []models.RecordModel 31 | }{ 32 | Records: records, 33 | }) 34 | })) 35 | defer server.Close() 36 | 37 | var buf []byte 38 | err := chromedp.Run(ctx, 39 | chromedp.EmulateViewport(800, 200), 40 | chromedp.Navigate(server.URL), 41 | chromedp.FullScreenshot(&buf, 100), 42 | ) 43 | return buf, err 44 | } 45 | 46 | func ContentProcess(content string) string { 47 | recordLines := strings.Split(content, "\n") 48 | var builder strings.Builder 49 | for i, recordLine := range recordLines { 50 | prefixSpaceCount := 0 51 | for _, character := range recordLine { 52 | if character == ' ' { 53 | prefixSpaceCount++ 54 | } else { 55 | break 56 | } 57 | } 58 | 59 | if prefixSpaceCount == 0 { 60 | builder.WriteString(recordLine) 61 | } else { 62 | builder.WriteString(strings.Replace(recordLine, " ", " ", prefixSpaceCount)) 63 | } 64 | if i != len(recordLines)-1 { 65 | builder.WriteString("") 66 | } 67 | } 68 | return builder.String() 69 | } 70 | -------------------------------------------------------------------------------- /apis/chat/routes.go: -------------------------------------------------------------------------------- 1 | package chat 2 | 3 | import ( 4 | "MOSS_backend/apis/record" 5 | "github.com/gofiber/fiber/v2" 6 | ) 7 | 8 | func RegisterRoutes(routes fiber.Router) { 9 | // chat 10 | routes.Get("/chats", ListChats) 11 | routes.Post("/chats", AddChat) 12 | routes.Put("/chats/:id/regenerate", record.RetryRecord) 13 | routes.Put("/chats/:id", ModifyChat) 14 | routes.Delete("/chats/:id", DeleteChat) 15 | routes.Get("/chats/:id/screenshots", GenerateChatScreenshot) 16 | 17 | routes.Static("/screenshots", "./screenshots") 18 | } 19 | 
-------------------------------------------------------------------------------- /apis/chat/schemas.go: -------------------------------------------------------------------------------- 1 | package chat 2 | 3 | type ModifyModel struct { 4 | Name *string `json:"name" validate:"omitempty,min=1"` 5 | } 6 | -------------------------------------------------------------------------------- /apis/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/gofiber/fiber/v2" 5 | 6 | . "MOSS_backend/models" 7 | . "MOSS_backend/utils" 8 | ) 9 | 10 | // GetConfig 11 | // @Summary get global config 12 | // @Tags Config 13 | // @Produce json 14 | // @Router /config [get] 15 | // @Success 200 {object} Response 16 | func GetConfig(c *fiber.Ctx) error { 17 | var configObject Config 18 | err := LoadConfig(&configObject) 19 | if err != nil { 20 | return err 21 | } 22 | 23 | //var region string 24 | //ok, err := IsInChina(GetRealIP(c)) 25 | //if err != nil { 26 | // return err 27 | //} 28 | //if ok { 29 | // region = "cn" 30 | //} else { 31 | // region = "global" 32 | //} 33 | var region = "global" 34 | 35 | return c.JSON(Response{ 36 | Region: region, 37 | InviteRequired: configObject.InviteRequired, 38 | Notice: configObject.Notice, 39 | ModelConfig: FromModelConfig(configObject.ModelConfig), 40 | }) 41 | } 42 | 43 | // PatchConfig 44 | // @Summary update global config 45 | // @Tags Config 46 | // @Accept json 47 | // @Produce json 48 | // @Router /config [patch] 49 | // @Param json body ModifyModelConfigRequest true "body" 50 | // @Success 200 {object} Response 51 | // @Failure 400 {object} Response 52 | // @Failure 500 {object} Response 53 | func PatchConfig(c *fiber.Ctx) error { 54 | var configObject Config 55 | err := LoadConfig(&configObject) 56 | if err != nil { 57 | return InternalServerError("Failed to load config") 58 | } 59 | 60 | var body ModifyModelConfigRequest 61 | err = 
ValidateBody(c, &body) 62 | if err != nil { 63 | return BadRequest(err.Error()) 64 | } 65 | 66 | if body.InviteRequired != nil { 67 | configObject.InviteRequired = *body.InviteRequired 68 | } 69 | if body.OffenseCheck != nil { 70 | configObject.OffenseCheck = *body.OffenseCheck 71 | } 72 | if body.Notice != nil { 73 | configObject.Notice = *body.Notice 74 | } 75 | if body.ModelConfig != nil { 76 | newModelCfg := body.ModelConfig 77 | for _, newSingleCfg := range newModelCfg { 78 | modelID := *(newSingleCfg.ID) 79 | for i := range configObject.ModelConfig { 80 | if configObject.ModelConfig[i].ID == modelID { 81 | if newSingleCfg.Description != nil { 82 | configObject.ModelConfig[i].Description = *(newSingleCfg.Description) 83 | } 84 | if newSingleCfg.InnerThoughtsPostprocess != nil { 85 | configObject.ModelConfig[i].InnerThoughtsPostprocess = *(newSingleCfg.InnerThoughtsPostprocess) 86 | } 87 | if newSingleCfg.DefaultPluginConfig != nil { 88 | if configObject.ModelConfig[i].DefaultPluginConfig == nil { 89 | // this means the default plugin config is never set 90 | configObject.ModelConfig[i].DefaultPluginConfig = *(newSingleCfg.DefaultPluginConfig) 91 | } else { 92 | for k, v := range *(newSingleCfg.DefaultPluginConfig) { 93 | if _, ok := configObject.ModelConfig[i].DefaultPluginConfig[k]; ok { 94 | configObject.ModelConfig[i].DefaultPluginConfig[k] = v 95 | } 96 | } 97 | } 98 | 99 | } 100 | } 101 | } 102 | } 103 | } 104 | 105 | // 将更新后的 configObject 保存到数据库中 106 | err = UpdateConfig(&configObject) 107 | if err != nil { 108 | return InternalServerError("Failed to update config") 109 | } 110 | 111 | return c.Status(200).JSON(fiber.Map{ 112 | "success": "Config updated successfully", 113 | }) 114 | } 115 | -------------------------------------------------------------------------------- /apis/config/routes.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import "github.com/gofiber/fiber/v2" 4 | 5 | func 
RegisterRoutes(routes fiber.Router) { 6 | routes.Get("/config", GetConfig) 7 | // redis update & config update 8 | routes.Patch("/config", PatchConfig) 9 | } 10 | -------------------------------------------------------------------------------- /apis/config/schemas.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import "MOSS_backend/models" 4 | 5 | type Response struct { 6 | Region string `json:"region"` 7 | InviteRequired bool `json:"invite_required"` 8 | Notice string `json:"notice"` 9 | ModelConfig []ModelConfigResponse `json:"model_config"` 10 | } 11 | 12 | type ModelConfigResponse struct { 13 | ID int `json:"id"` 14 | Description string `json:"description"` 15 | DefaultPluginConfig map[string]bool `json:"default_plugin_config"` 16 | } 17 | 18 | func FromModelConfig(modelConfig []models.ModelConfig) []ModelConfigResponse { 19 | var response []ModelConfigResponse 20 | for _, config := range modelConfig { 21 | response = append(response, ModelConfigResponse{ 22 | ID: config.ID, 23 | Description: config.Description, 24 | DefaultPluginConfig: config.DefaultPluginConfig, 25 | }) 26 | } 27 | return response 28 | } 29 | 30 | type ModelConfigRequest struct { 31 | ID *int `json:"id" validate:"min=1"` 32 | InnerThoughtsPostprocess *bool `json:"inner_thoughts_postprocess" validate:"omitempty,oneof=true false"` 33 | Description *string `json:"description" validate:"omitempty"` 34 | DefaultPluginConfig *map[string]bool `json:"default_plugin_config" validate:"omitempty"` 35 | } 36 | 37 | type ModifyModelConfigRequest struct { 38 | InviteRequired *bool `json:"invite_required" validate:"omitempty,oneof=true false"` 39 | OffenseCheck *bool `json:"offense_check" validate:"omitempty,oneof=true false"` 40 | Notice *string `json:"notice" validate:"omitempty"` 41 | ModelConfig []*ModelConfigRequest `json:"model_config" validate:"omitempty"` 42 | } 43 | 
-------------------------------------------------------------------------------- /apis/default.go: -------------------------------------------------------------------------------- 1 | package apis 2 | 3 | import ( 4 | "MOSS_backend/data" 5 | "github.com/gofiber/fiber/v2" 6 | ) 7 | 8 | // Index 9 | // 10 | // @Produce application/json 11 | // @Router / [get] 12 | // @Success 200 {object} models.Map 13 | func Index(c *fiber.Ctx) error { 14 | return c.Send(data.MetaData) 15 | } 16 | -------------------------------------------------------------------------------- /apis/record/api_ws.go: -------------------------------------------------------------------------------- 1 | package record 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "log" 8 | "strconv" 9 | "sync" 10 | "sync/atomic" 11 | "time" 12 | 13 | "MOSS_backend/config" 14 | . "MOSS_backend/models" 15 | . "MOSS_backend/utils" 16 | "MOSS_backend/utils/sensitive" 17 | 18 | "github.com/gofiber/websocket/v2" 19 | "go.uber.org/zap" 20 | "gorm.io/gorm" 21 | ) 22 | 23 | var userLockMap sync.Map 24 | 25 | type UserLockValue struct { 26 | LockTime time.Time 27 | } 28 | 29 | func UserLockCheck() { 30 | ticker := time.NewTicker(time.Hour) 31 | for range ticker.C { 32 | userLockMap.Range(func(key, value interface{}) bool { 33 | userLockValue := value.(UserLockValue) 34 | // delete lock before 1 minute 35 | if userLockValue.LockTime.Before(time.Now().Add(-time.Minute)) { 36 | userLockMap.Delete(key) 37 | } 38 | return true 39 | }) 40 | } 41 | } 42 | 43 | // AddRecordAsync 44 | // @Summary add a record 45 | // @Tags Websocket 46 | // @Router /ws/chats/{chat_id}/records [get] 47 | // @Param chat_id path int true "chat id" 48 | // @Param json body CreateModel true "json" 49 | // @Success 201 {object} models.Record 50 | func AddRecordAsync(c *websocket.Conn) { 51 | var ( 52 | chatID int 53 | message []byte 54 | err error 55 | user *User 56 | banned bool 57 | chat Chat 58 | ) 59 | 60 | defer func() { 61 | if err != 
nil { 62 | Logger.Error( 63 | "client websocket return with error", 64 | zap.Error(err), 65 | ) 66 | response := InferResponseModel{Status: -1, Output: err.Error()} 67 | var httpError *HttpError 68 | if errors.As(err, &httpError) { 69 | response.StatusCode = httpError.Code 70 | } 71 | _ = c.WriteJSON(response) 72 | } 73 | }() 74 | 75 | procedure := func() error { 76 | // get chatID 77 | if chatID, err = strconv.Atoi(c.Params("id")); err != nil { 78 | return BadRequest("invalid chat_id") 79 | } 80 | 81 | // read body 82 | if _, message, err = c.ReadMessage(); err != nil { 83 | return fmt.Errorf("error receive message: %v", err) 84 | } 85 | 86 | // unmarshal body 87 | var body CreateModel 88 | err = json.Unmarshal(message, &body) 89 | if err != nil { 90 | return fmt.Errorf("error unmarshal text: %v", err) 91 | } 92 | 93 | if body.Request == "" { 94 | return BadRequest("request is empty") 95 | } 96 | //if len([]rune(body.Request)) > 2048 { 97 | // return maxInputExceededError 98 | //} 99 | 100 | // get user id 101 | user, err = LoadUserFromWs(c) 102 | if err != nil { 103 | return Unauthorized() 104 | } 105 | 106 | // check user lock 107 | if _, ok := userLockMap.LoadOrStore(user.ID, UserLockValue{LockTime: time.Now()}); ok { 108 | return userRequestingError 109 | } 110 | defer userLockMap.Delete(user.ID) 111 | 112 | // infer limiter 113 | if !inferLimiter.Allow() { 114 | return unknownError 115 | } 116 | 117 | banned, err = user.CheckUserOffense() 118 | if err != nil { 119 | return err 120 | } 121 | if banned { 122 | return Forbidden(OffenseMessage) 123 | } 124 | 125 | // load chat 126 | err = DB.Take(&chat, chatID).Error 127 | if err != nil { 128 | return err 129 | } 130 | 131 | // permission 132 | if chat.UserID != user.ID { 133 | return Forbidden() 134 | } 135 | 136 | record := Record{ 137 | ChatID: chatID, 138 | Request: body.Request, 139 | } 140 | 141 | // sensitive request check 142 | if sensitive.IsSensitive(record.Request, user) { 143 | record.RequestSensitive 
= true 144 | record.Response = DefaultResponse 145 | 146 | banned, err = user.AddUserOffense(UserOffensePrompt) 147 | if err != nil { 148 | return err 149 | } 150 | if banned { 151 | err = c.WriteJSON(InferResponseModel{ 152 | Status: -2, // banned 153 | Output: OffenseMessage, 154 | }) 155 | } else { 156 | err = c.WriteJSON(InferResponseModel{ 157 | Status: -2, // sensitive 158 | Output: DefaultResponse, 159 | }) 160 | } 161 | if err != nil { 162 | return fmt.Errorf("write sensitive error: %v", err) 163 | } 164 | } else { 165 | /* infer */ 166 | 167 | // find record prefix to make dialogs, without sensitive content 168 | var oldRecords Records 169 | err = DB.Last(&oldRecords, "chat_id = ? AND request_sensitive = ? AND response_sensitive = ?", chatID, false, false).Error 170 | if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { 171 | return err 172 | } 173 | 174 | // async infer 175 | err = InferAsync( 176 | c, 177 | oldRecords.GetPrefix(), 178 | &record, 179 | oldRecords.ToRecordModel(), 180 | user, 181 | body.Param, 182 | ) 183 | if err != nil && !errors.Is(err, ErrSensitive) { 184 | //if httpError, ok := err.(*HttpError); ok && httpError.MessageType == MaxLength { 185 | // DB.Model(&chat).Update("max_length_exceeded", true) 186 | //} 187 | return err 188 | } 189 | } 190 | 191 | // store into database 192 | err = DB.Transaction(func(tx *gorm.DB) error { 193 | err = tx.Clauses(LockingClause).Take(&chat, chatID).Error 194 | if err != nil { 195 | return err 196 | } 197 | 198 | err = tx.Create(&record).Error 199 | if err != nil { 200 | return err 201 | } 202 | 203 | if chat.Count == 0 { 204 | chat.Name = StripContent(record.Request, config.Config.ChatNameLength) 205 | } 206 | chat.Count += 1 207 | return tx.Save(&chat).Error 208 | }) 209 | if err != nil { 210 | return err 211 | } 212 | 213 | // return a total record structure 214 | err = c.WriteJSON(record) 215 | if err != nil { 216 | return fmt.Errorf("write record error: %v", err) 217 | } 218 | 219 | return 
nil 220 | } 221 | 222 | err = procedure() 223 | } 224 | 225 | // RegenerateAsync 226 | // @Summary regenerate a record 227 | // @Tags Websocket 228 | // @Router /ws/chats/{chat_id}/regenerate [get] 229 | // @Param chat_id path int true "chat id" 230 | // @Success 201 {object} models.Record 231 | func RegenerateAsync(c *websocket.Conn) { 232 | var ( 233 | chatID int 234 | user *User 235 | err error 236 | banned bool 237 | chat Chat 238 | ) 239 | 240 | defer func() { 241 | if err != nil { 242 | Logger.Error( 243 | "client websocket return with error", 244 | zap.Error(err), 245 | ) 246 | response := InferResponseModel{Status: -1, Output: err.Error()} 247 | if httpError, ok := err.(*HttpError); ok { 248 | response.StatusCode = httpError.Code 249 | } 250 | err = c.WriteJSON(response) 251 | if err != nil { 252 | log.Println("write err error: ", err) 253 | } 254 | } 255 | }() 256 | 257 | procedure := func() error { 258 | // get chatID 259 | if chatID, err = strconv.Atoi(c.Params("id")); err != nil { 260 | return BadRequest("invalid chat_id") 261 | } 262 | 263 | // get user id 264 | user, err = LoadUserFromWs(c) 265 | if err != nil { 266 | return Unauthorized() 267 | } 268 | 269 | // check user lock 270 | if _, ok := userLockMap.LoadOrStore(user.ID, UserLockValue{LockTime: time.Now()}); ok { 271 | return userRequestingError 272 | } 273 | defer userLockMap.Delete(user.ID) 274 | 275 | // infer limiter 276 | if !inferLimiter.Allow() { 277 | return unknownError 278 | } 279 | 280 | banned, err = user.CheckUserOffense() 281 | if err != nil { 282 | return err 283 | } 284 | if banned { 285 | return Forbidden(OffenseMessage) 286 | } 287 | 288 | // load chat 289 | err = DB.Take(&chat, chatID).Error 290 | if err != nil { 291 | return err 292 | } 293 | 294 | // permission 295 | if chat.UserID != user.ID { 296 | return Forbidden() 297 | } 298 | 299 | // get the latest record 300 | var oldRecord Record 301 | err = DB.Last(&oldRecord, "chat_id = ?", chatID).Error 302 | if err != nil { 
303 | return err 304 | } 305 | 306 | if !user.IsAdmin || !user.DisableSensitiveCheck { 307 | if oldRecord.RequestSensitive { 308 | banned, err = user.AddUserOffense(UserOffensePrompt) 309 | if err != nil { 310 | return err 311 | } 312 | if banned { 313 | err = c.WriteJSON(InferResponseModel{ 314 | Status: -2, // banned 315 | Output: OffenseMessage, 316 | }) 317 | } else { 318 | err = c.WriteJSON(InferResponseModel{ 319 | Status: -2, // sensitive 320 | Output: DefaultResponse, 321 | }) 322 | } 323 | if err != nil { 324 | return fmt.Errorf("write sensitive error: %v", err) 325 | } 326 | } 327 | } 328 | 329 | record := Record{ 330 | ChatID: chatID, 331 | Request: oldRecord.Request, 332 | } 333 | 334 | /* infer */ 335 | 336 | // find old records to make dialogs, without sensitive content 337 | var oldRecords Records 338 | err = DB.Last(&oldRecords, "chat_id = ? AND request_sensitive = false AND response_sensitive = false AND id < ?", chatID, oldRecord.ID).Error 339 | if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { 340 | return err 341 | } 342 | 343 | // async infer 344 | err = InferAsync( 345 | c, 346 | oldRecords.GetPrefix(), 347 | &record, 348 | oldRecords.ToRecordModel(), 349 | user, 350 | nil, 351 | ) 352 | if err != nil && !errors.Is(err, ErrSensitive) { 353 | // 354 | //if httpError, ok := err.(*HttpError); ok && httpError.MessageType == MaxLength { 355 | // DB.Model(&chat).Update("max_length_exceeded", true) 356 | //} 357 | return err 358 | } 359 | 360 | // store into database 361 | err = DB.Transaction(func(tx *gorm.DB) error { 362 | err = tx.Clauses(LockingClause).Take(&chat, chatID).Error 363 | if err != nil { 364 | return err 365 | } 366 | 367 | err = tx.Delete(&oldRecord).Error 368 | if err != nil { 369 | return err 370 | } 371 | 372 | err = tx.Create(&record).Error 373 | if err != nil { 374 | return err 375 | } 376 | 377 | return tx.Save(&chat).Error 378 | }) 379 | if err != nil { 380 | return err 381 | } 382 | 383 | // return a total record 
structure 384 | err = c.WriteJSON(record) 385 | if err != nil { 386 | return fmt.Errorf("write record error: %v", err) 387 | } 388 | 389 | return nil 390 | } 391 | 392 | err = procedure() 393 | } 394 | 395 | func interrupt(c *websocket.Conn, interruptChan chan any, connectionClosed *atomic.Bool) { 396 | var message []byte 397 | var err error 398 | defer connectionClosed.Store(true) 399 | for { 400 | if connectionClosed.Load() { 401 | return 402 | } 403 | if _, message, err = c.ReadMessage(); err != nil { 404 | if connectionClosed.Load() { 405 | return 406 | } 407 | Logger.Error("receive from client error", zap.Error(err)) 408 | close(interruptChan) 409 | return 410 | } 411 | 412 | if config.Config.Debug { 413 | log.Printf("receive from client: %v\n", string(message)) 414 | } 415 | 416 | var interruptModel InterruptModel 417 | err = json.Unmarshal(message, &interruptModel) 418 | if err != nil { 419 | Logger.Error("fail to unmarshal interrupt", zap.ByteString("request", message)) 420 | continue 421 | } 422 | 423 | if interruptModel.Interrupt { 424 | close(interruptChan) 425 | return 426 | } 427 | } 428 | } 429 | 430 | // InferWithoutLoginAsync 431 | // @Summary infer without login in websocket 432 | // @Tags Websocket 433 | // @Router /ws/inference [get] 434 | // @Param json body InferenceRequest true "json" 435 | // @Success 200 {object} InferenceResponse 436 | func InferWithoutLoginAsync(c *websocket.Conn) { 437 | var ( 438 | message []byte 439 | err error 440 | record Record 441 | ) 442 | 443 | defer func() { 444 | if err != nil { 445 | Logger.Error( 446 | "client websocket return with error", 447 | zap.Error(err), 448 | ) 449 | response := InferResponseModel{Status: -1, Output: err.Error()} 450 | var httpError *HttpError 451 | if errors.As(err, &httpError) { 452 | response.StatusCode = httpError.Code 453 | } 454 | _ = c.WriteJSON(response) 455 | } 456 | }() 457 | 458 | procedure := func() error { 459 | 460 | // read body 461 | if _, message, err = 
c.ReadMessage(); err != nil { 462 | return fmt.Errorf("error receive message: %v", err) 463 | } 464 | 465 | // unmarshal body 466 | var body InferenceRequest 467 | err = json.Unmarshal(message, &body) 468 | if err != nil { 469 | return fmt.Errorf("error unmarshal text: %v", err) 470 | } 471 | 472 | if body.Request == "" { 473 | return BadRequest("request is empty") 474 | } 475 | //if len([]rune(body.Request)) > 2048 { 476 | // return maxInputExceededError 477 | //} 478 | 479 | // infer limiter 480 | if !inferLimiter.Allow() { 481 | return unknownError 482 | } 483 | 484 | // sensitive request check 485 | if sensitive.IsSensitive(body.Context, &User{}) { 486 | 487 | err = c.WriteJSON(InferResponseModel{ 488 | Status: -2, // sensitive 489 | Output: DefaultResponse, 490 | }) 491 | if err != nil { 492 | return fmt.Errorf("write sensitive error: %v", err) 493 | } 494 | } else { 495 | /* infer */ 496 | 497 | record.Request = body.Request 498 | err = InferAsync( 499 | c, 500 | body.Context, 501 | &record, 502 | body.Records, 503 | &User{PluginConfig: body.PluginConfig, ModelID: body.ModelID}, 504 | body.Param, 505 | ) 506 | if err != nil { 507 | return err 508 | } 509 | } 510 | 511 | // store into database 512 | directRecord := DirectRecord{ 513 | Duration: record.Duration, 514 | Context: record.Prefix, 515 | Request: record.Request, 516 | Response: record.Response, 517 | ExtraData: record.ExtraData, 518 | } 519 | _ = DB.Create(&directRecord).Error 520 | 521 | // return response 522 | _ = c.WriteJSON(InferenceResponse{ 523 | Response: record.Response, 524 | Context: record.Prefix, 525 | ExtraData: record.ExtraData, 526 | }) 527 | 528 | return nil 529 | } 530 | 531 | err = procedure() 532 | } 533 | -------------------------------------------------------------------------------- /apis/record/apis.go: -------------------------------------------------------------------------------- 1 | package record 2 | 3 | import ( 4 | "errors" 5 | "time" 6 | 7 | 
"github.com/gofiber/fiber/v2" 8 | "golang.org/x/exp/slices" 9 | "gorm.io/gorm" 10 | 11 | "MOSS_backend/config" 12 | . "MOSS_backend/models" 13 | . "MOSS_backend/utils" 14 | "MOSS_backend/utils/sensitive" 15 | ) 16 | 17 | // ListRecords 18 | // @Summary list records of a chat 19 | // @Tags record 20 | // @Router /chats/{chat_id}/records [get] 21 | // @Param chat_id path int true "chat id" 22 | // @Success 200 {array} models.Record 23 | func ListRecords(c *fiber.Ctx) error { 24 | chatID, err := c.ParamsInt("id") 25 | if err != nil { 26 | return err 27 | } 28 | 29 | userID, err := GetUserID(c) 30 | if err != nil { 31 | return err 32 | } 33 | 34 | var chat Chat 35 | err = DB.Take(&chat, chatID).Error 36 | if err != nil { 37 | return err 38 | } 39 | 40 | if userID != chat.UserID { 41 | return Forbidden() 42 | } 43 | 44 | var records = Records{} 45 | err = DB.Find(&records, "chat_id = ?", chatID).Error 46 | if err != nil { 47 | return err 48 | } 49 | 50 | return Serialize(c, records) 51 | } 52 | 53 | // AddRecord 54 | // @Summary add a record 55 | // @Tags record 56 | // @Router /chats/{chat_id}/records [post] 57 | // @Param chat_id path int true "chat id" 58 | // @Param json body CreateModel true "json" 59 | // @Success 201 {object} models.Record 60 | func AddRecord(c *fiber.Ctx) error { 61 | chatID, err := c.ParamsInt("id") 62 | if err != nil { 63 | return err 64 | } 65 | 66 | // validate body 67 | var body CreateModel 68 | err = ValidateBody(c, &body) 69 | if err != nil { 70 | return err 71 | } 72 | 73 | if body.Request == "" { 74 | return BadRequest("request is empty") 75 | } 76 | //if len([]rune(body.Request)) > 2048 { 77 | // return maxInputExceededError 78 | //} 79 | 80 | user, err := LoadUser(c) 81 | if err != nil { 82 | return err 83 | } 84 | 85 | // check user lock 86 | if _, ok := userLockMap.LoadOrStore(user.ID, UserLockValue{LockTime: time.Now()}); ok { 87 | return userRequestingError 88 | } 89 | defer userLockMap.Delete(user.ID) 90 | 91 | // infer limiter 
92 | if !inferLimiter.Allow() { 93 | return unknownError 94 | } 95 | 96 | banned, err := user.CheckUserOffense() 97 | if err != nil { 98 | return err 99 | } 100 | if banned { 101 | return Forbidden(OffenseMessage) 102 | } 103 | 104 | var chat Chat 105 | err = DB.Take(&chat, chatID).Error 106 | if err != nil { 107 | return err 108 | } // not exists 109 | 110 | // permission 111 | if chat.UserID != user.ID { 112 | return Forbidden() 113 | } 114 | 115 | record := Record{ 116 | ChatID: chatID, 117 | Request: body.Request, 118 | } 119 | 120 | // sensitive request check 121 | if sensitive.IsSensitive(record.Request, user) { 122 | record.RequestSensitive = true 123 | record.Response = DefaultResponse 124 | 125 | banned, err = user.AddUserOffense(UserOffensePrompt) 126 | if err != nil { 127 | return err 128 | } 129 | if banned { 130 | return Forbidden(OffenseMessage) 131 | } 132 | } else { 133 | /* infer */ 134 | 135 | // find old records to make dialogs, without sensitive content 136 | var oldRecords Records 137 | err = DB.Find(&oldRecords, "chat_id = ? AND request_sensitive = ? 
AND response_sensitive = ?", chatID, false, false).Error 138 | if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { 139 | return err 140 | } 141 | 142 | // infer request 143 | err = Infer( 144 | &record, 145 | oldRecords.GetPrefix(), 146 | oldRecords.ToRecordModel(), 147 | user, 148 | body.Param, 149 | ) 150 | if err != nil { 151 | //if errors.Is(err, maxLengthExceededError) { 152 | // chat.MaxLengthExceeded = true 153 | // DB.Save(&chat) 154 | //} 155 | return err 156 | } 157 | 158 | if sensitive.IsSensitive(record.Response, user) { 159 | record.ResponseSensitive = true 160 | 161 | banned, err = user.AddUserOffense(UserOffenseMoss) 162 | if err != nil { 163 | return err 164 | } 165 | if banned { 166 | return Forbidden(OffenseMessage) 167 | } 168 | } 169 | } 170 | 171 | err = DB.Transaction(func(tx *gorm.DB) error { 172 | err = tx.Clauses(LockingClause).Take(&chat, chatID).Error 173 | if err != nil { 174 | return err 175 | } 176 | 177 | err = tx.Create(&record).Error 178 | if err != nil { 179 | return err 180 | } 181 | 182 | if chat.Count == 0 { 183 | chat.Name = StripContent(record.Request, config.Config.ChatNameLength) 184 | } 185 | chat.Count += 1 186 | return tx.Save(&chat).Error 187 | }) 188 | if err != nil { 189 | return err 190 | } 191 | 192 | return Serialize(c.Status(201), &record) 193 | } 194 | 195 | // RetryRecord 196 | // @Summary regenerate the last record of a chat 197 | // @Tags record 198 | // @Router /chats/{chat_id}/regenerate [put] 199 | // @Param chat_id path int true "chat id" 200 | // @Success 200 {object} models.Record 201 | func RetryRecord(c *fiber.Ctx) error { 202 | chatID, err := c.ParamsInt("id") 203 | if err != nil { 204 | return err 205 | } 206 | 207 | user, err := LoadUser(c) 208 | if err != nil { 209 | return err 210 | } 211 | 212 | var chat Chat 213 | err = DB.Take(&chat, chatID).Error 214 | if err != nil { 215 | return err 216 | } 217 | 218 | // check user lock 219 | if _, ok := userLockMap.LoadOrStore(user.ID, 
UserLockValue{LockTime: time.Now()}); ok { 220 | return userRequestingError 221 | } 222 | defer userLockMap.Delete(user.ID) 223 | 224 | // infer limiter 225 | if !inferLimiter.Allow() { 226 | return unknownError 227 | } 228 | 229 | banned, err := user.CheckUserOffense() 230 | if err != nil { 231 | return err 232 | } 233 | if banned { 234 | return Forbidden(OffenseMessage) 235 | } 236 | 237 | // permission 238 | if chat.UserID != user.ID { 239 | return Forbidden() 240 | } 241 | 242 | // get the latest record 243 | var oldRecord Record 244 | err = DB.Last(&oldRecord, "chat_id = ?", chat.ID).Error 245 | if err != nil { 246 | return err 247 | } 248 | 249 | if !user.IsAdmin || !user.DisableSensitiveCheck { 250 | if oldRecord.RequestSensitive { 251 | banned, err = user.AddUserOffense(UserOffensePrompt) 252 | if err != nil { 253 | return err 254 | } 255 | if banned { 256 | return Forbidden(OffenseMessage) 257 | } 258 | 259 | // old record request is sensitive 260 | return Serialize(c, &oldRecord) 261 | } 262 | } 263 | 264 | record := Record{ 265 | ChatID: chatID, 266 | Request: oldRecord.Request, 267 | } 268 | 269 | /* infer */ 270 | 271 | // find ole records to make dialogs, without sensitive content 272 | var oldRecords Records 273 | err = DB.Find(&oldRecords, "chat_id = ? 
AND request_sensitive = false AND response_sensitive = false AND id < ?", chatID, oldRecord.ID).Error 274 | if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { 275 | return err 276 | } 277 | 278 | // infer request 279 | err = Infer( 280 | &record, 281 | oldRecords.GetPrefix(), 282 | oldRecords.ToRecordModel(), 283 | user, 284 | nil, 285 | ) 286 | if err != nil { 287 | //if errors.Is(err, maxLengthExceededError) { 288 | // chat.MaxLengthExceeded = true 289 | // DB.Save(&chat) 290 | //} 291 | return err 292 | } 293 | 294 | if sensitive.IsSensitive(record.Response, user) { 295 | record.ResponseSensitive = true 296 | 297 | banned, err = user.AddUserOffense(UserOffenseMoss) 298 | if err != nil { 299 | return err 300 | } 301 | if banned { 302 | return Forbidden(OffenseMessage) 303 | } 304 | } 305 | 306 | err = DB.Transaction(func(tx *gorm.DB) error { 307 | err = tx.Clauses(LockingClause).Take(&chat, chatID).Error 308 | if err != nil { 309 | return err 310 | } 311 | 312 | err = tx.Delete(&oldRecord).Error 313 | if err != nil { 314 | return err 315 | } 316 | 317 | err = tx.Create(&record).Error 318 | if err != nil { 319 | return err 320 | } 321 | 322 | return tx.Save(&chat).Error 323 | }) 324 | if err != nil { 325 | return err 326 | } 327 | 328 | return Serialize(c, &record) 329 | } 330 | 331 | // ModifyRecord 332 | // @Summary modify a record 333 | // @Tags record 334 | // @Router /records/{record_id} [put] 335 | // @Param record_id path int true "record id" 336 | // @Param json body ModifyModel true "json" 337 | // @Success 201 {object} models.Record 338 | func ModifyRecord(c *fiber.Ctx) error { 339 | recordID, err := c.ParamsInt("id") 340 | if err != nil { 341 | return err 342 | } 343 | 344 | var body ModifyModel 345 | err = ValidateBody(c, &body) 346 | if err != nil { 347 | return err 348 | } 349 | 350 | userID, err := GetUserID(c) 351 | if err != nil { 352 | return err 353 | } 354 | 355 | if body.Feedback == nil && body.Like == nil { 356 | return BadRequest() 
357 | } 358 | 359 | var record Record 360 | err = DB.Transaction(func(tx *gorm.DB) error { 361 | var chat Chat 362 | err = tx.Clauses(LockingClause).Take(&record, recordID).Error 363 | if err != nil { 364 | return err 365 | } 366 | 367 | err = tx.Take(&chat, record.ChatID).Error 368 | if err != nil { 369 | return err 370 | } 371 | 372 | if chat.UserID != userID { 373 | return Forbidden() 374 | } 375 | 376 | if body.Feedback != nil { 377 | record.Feedback = *body.Feedback 378 | } 379 | 380 | if body.Like != nil { 381 | record.LikeData = *body.Like 382 | } 383 | 384 | return tx.Model(&record).Select("Feedback", "LikeData").Updates(&record).Error 385 | }) 386 | 387 | if err != nil { 388 | return err 389 | } 390 | 391 | return Serialize(c, &record) 392 | } 393 | 394 | // InferWithoutLogin 395 | // @Summary infer without login 396 | // @Tags Inference 397 | // @Router /inference [post] 398 | // @Param json body InferenceRequest true "json" 399 | // @Success 200 {object} InferenceResponse 400 | func InferWithoutLogin(c *fiber.Ctx) error { 401 | var body InferenceRequest 402 | err := ValidateBody(c, &body) 403 | if err != nil { 404 | return err 405 | } 406 | 407 | if body.Request == "" { 408 | return BadRequest("request is empty") 409 | } 410 | //if len([]rune(body.Request)) > 2048 { 411 | // return maxInputExceededError 412 | //} 413 | 414 | // infer limiter 415 | if !inferLimiter.Allow() { 416 | return unknownError 417 | } 418 | 419 | consumerUsername := c.Get("X-Consumer-Username") 420 | passSensitiveCheck := slices.Contains(config.Config.PassSensitiveCheckUsername, consumerUsername) 421 | 422 | if !passSensitiveCheck && sensitive.IsSensitive(body.Context, &User{}) { 423 | return BadRequest(DefaultResponse).WithMessageType(Sensitive) 424 | } 425 | 426 | record := Record{Request: body.Request} 427 | 428 | err = Infer( 429 | &record, 430 | body.Context, 431 | body.Records, 432 | &User{PluginConfig: body.PluginConfig, ModelID: body.ModelID}, 433 | body.Param, 434 | ) 435 
| if err != nil { 436 | return err 437 | } 438 | 439 | if !passSensitiveCheck && sensitive.IsSensitive(record.Response, &User{}) { 440 | return BadRequest(DefaultResponse).WithMessageType(Sensitive) 441 | } 442 | 443 | directRecord := DirectRecord{ 444 | Duration: record.Duration, 445 | ConsumerUsername: consumerUsername, 446 | Context: record.Prefix, 447 | Request: record.Request, 448 | Response: record.Response, 449 | ExtraData: record.ExtraData, 450 | } 451 | 452 | _ = DB.Create(&directRecord).Error 453 | 454 | return c.JSON(InferenceResponse{ 455 | Response: record.Response, 456 | Context: record.Prefix, 457 | ExtraData: record.ExtraData, 458 | }) 459 | } 460 | -------------------------------------------------------------------------------- /apis/record/limiter.go: -------------------------------------------------------------------------------- 1 | package record 2 | 3 | import ( 4 | "golang.org/x/time/rate" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | type inferLimiterStruct struct { 10 | Limiter *rate.Limiter 11 | sync.Map // key: timestamp, value: InferPostStats 12 | } 13 | 14 | type InferPostStats struct { 15 | Success bool 16 | time.Time 17 | } 18 | 19 | var inferLimiter = inferLimiterStruct{ 20 | Limiter: rate.NewLimiter(40, 60), 21 | } 22 | 23 | func (i *inferLimiterStruct) Allow() bool { 24 | if i.Limiter != nil && !i.Limiter.Allow() { 25 | return false 26 | } 27 | var success, failure int 28 | i.Range(func(key, value interface{}) bool { 29 | inferValue, ok := value.(InferPostStats) 30 | if !ok { 31 | return false 32 | } 33 | if inferValue.Before(time.Now().Add(-30 * time.Second)) { 34 | i.Delete(key) 35 | return true 36 | } 37 | if inferValue.Success { 38 | success++ 39 | } else { 40 | failure++ 41 | } 42 | return true 43 | }) 44 | if failure > 10 && float64(failure)/float64(success+failure) > 0.5 { 45 | return false 46 | } 47 | return true 48 | } 49 | 50 | func (i *inferLimiterStruct) AddStats(success bool) { 51 | if success { 52 | inferSuccessCounter.Inc() 53 
| } else { 54 | inferFailureCounter.Inc() 55 | } 56 | i.Store(time.Now().UnixNano(), InferPostStats{ 57 | Success: success, 58 | Time: time.Now(), 59 | }) 60 | } 61 | -------------------------------------------------------------------------------- /apis/record/observe.go: -------------------------------------------------------------------------------- 1 | package record 2 | 3 | import ( 4 | "MOSS_backend/config" 5 | "github.com/prometheus/client_golang/prometheus" 6 | "github.com/prometheus/client_golang/prometheus/promauto" 7 | ) 8 | 9 | var inferSuccessCounter = promauto.NewCounter(prometheus.CounterOpts{ 10 | Name: prometheus.BuildFQName(config.AppName, "infer", "success"), 11 | }) 12 | 13 | var inferFailureCounter = promauto.NewCounter(prometheus.CounterOpts{ 14 | Name: prometheus.BuildFQName(config.AppName, "infer", "failure"), 15 | }) 16 | 17 | var inferStatusCounter = promauto.NewCounterVec( 18 | prometheus.CounterOpts{ 19 | Name: prometheus.BuildFQName(config.AppName, "infer", "status"), 20 | }, 21 | []string{"status_code"}, 22 | ) 23 | 24 | var inferOnFlightCounter = promauto.NewGauge(prometheus.GaugeOpts{ 25 | Name: prometheus.BuildFQName(config.AppName, "infer", "on_flight"), 26 | }) 27 | 28 | var userInferRequestOnFlight = promauto.NewGauge(prometheus.GaugeOpts{ 29 | Name: prometheus.BuildFQName(config.AppName, "user_infer_request", "on_flight"), 30 | }) 31 | -------------------------------------------------------------------------------- /apis/record/openai.go: -------------------------------------------------------------------------------- 1 | package record 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/gofiber/fiber/v2" 7 | "github.com/google/uuid" 8 | 9 | . "MOSS_backend/models" 10 | . 
"MOSS_backend/utils" 11 | ) 12 | 13 | // OpenAIListModels 14 | // @Summary List models in OpenAI API protocol 15 | // @Tags openai 16 | // @Router /v1/models [get] 17 | // @Success 200 {object} OpenAIModels 18 | func OpenAIListModels(c *fiber.Ctx) (err error) { 19 | modelConfigs, err := LoadModelConfigs() 20 | if err != nil { 21 | return err 22 | } 23 | 24 | return c.JSON(OpenAIModelsFromModelConfigs(modelConfigs)) 25 | } 26 | 27 | // OpenAIRetrieveModel 28 | // @Summary Retrieve a model in OpenAI API protocol 29 | // @Tags openai 30 | // @Router /v1/models/{name} [get] 31 | // @Success 200 {object} OpenAIModel 32 | func OpenAIRetrieveModel(c *fiber.Ctx) (err error) { 33 | modelName := c.Params("name") 34 | modelConfig, err := LoadModelConfigByName(modelName) 35 | if err != nil { 36 | return err 37 | } 38 | 39 | return c.JSON(OpenAIModelFromModelConfig(modelConfig)) 40 | } 41 | 42 | // OpenAICreateChatCompletion 43 | // @Summary Create a chat completion in OpenAI API protocol 44 | // @Tags openai 45 | // @Router /v1/chat/completions [post] 46 | // @Param json body OpenAIChatCompletionRequest true "json" 47 | // @Success 200 {object} OpenAIChatCompletionResponse 48 | func OpenAICreateChatCompletion(c *fiber.Ctx) (err error) { 49 | var request OpenAIChatCompletionRequest 50 | err = ValidateBody(c, &request) 51 | if err != nil { 52 | return err 53 | } 54 | 55 | modelConfig, err := LoadModelConfigByName(request.Model) 56 | if err != nil { 57 | return err 58 | } 59 | 60 | prefix, requestMessage, err := request.Messages.Build() 61 | if err != nil { 62 | return err 63 | } 64 | 65 | if requestMessage == "" { 66 | return BadRequest("request is empty") 67 | } 68 | //if len([]rune(requestMessage)) > 2048 { 69 | // return maxInputExceededError 70 | //} 71 | 72 | // infer limiter 73 | //if !inferLimiter.Allow() { 74 | // return unknownError 75 | //} 76 | 77 | //consumerUsername := c.Get("X-Consumer-Username") 78 | //passSensitiveCheck := 
slices.Contains(config.Config.PassSensitiveCheckUsername, consumerUsername) 79 | 80 | //if !passSensitiveCheck && sensitive.IsSensitive(prefix+"\n"+requestMessage, &User{}) { 81 | // return BadRequest(DefaultResponse).WithMessageType(Sensitive) 82 | //} 83 | 84 | recordModels, _, err := request.Messages.BuildRecordModels() 85 | if err != nil { 86 | return err 87 | } 88 | 89 | record := Record{Request: requestMessage} 90 | err = Infer(&record, prefix, recordModels, &User{ 91 | PluginConfig: nil, 92 | ModelID: modelConfig.ID, 93 | IsAdmin: true, 94 | DisableSensitiveCheck: true, 95 | }, nil) 96 | if err != nil { 97 | return err 98 | } 99 | 100 | //if !passSensitiveCheck && sensitive.IsSensitive(record.Response, &User{}) { 101 | // return BadRequest(DefaultResponse).WithMessageType(Sensitive) 102 | //} 103 | 104 | return c.JSON(&OpenAIChatCompletionResponse{ 105 | Id: "chatcmpl-" + uuid.Must(uuid.NewUUID()).String(), 106 | Object: "chat.completion", 107 | Created: time.Now().Unix(), 108 | Model: modelConfig.Description, 109 | SystemFingerprint: "", 110 | Choices: []*OpenAIChatCompletionChoice{ 111 | { 112 | Index: 0, 113 | Message: OpenAIMessages{{ 114 | Role: "assistant", 115 | Content: record.Response, 116 | }}, 117 | Logprobs: nil, 118 | FinishReason: "stop", 119 | }, 120 | }, 121 | Usage: OpenAIChatCompletionUsage{}, 122 | }) 123 | } 124 | -------------------------------------------------------------------------------- /apis/record/routes.go: -------------------------------------------------------------------------------- 1 | package record 2 | 3 | import ( 4 | "github.com/gofiber/fiber/v2" 5 | "github.com/gofiber/websocket/v2" 6 | ) 7 | 8 | func RegisterRoutes(routes fiber.Router) { 9 | // record 10 | routes.Get("/chats/:id/records", ListRecords) 11 | routes.Post("/chats/:id/records", AddRecord) 12 | routes.Get("/ws/chats/:id/records", websocket.New(AddRecordAsync)) 13 | routes.Get("/ws/chats/:id/regenerate", websocket.New(RegenerateAsync)) 14 | 
routes.Put("/records/:id", ModifyRecord) 15 | 16 | // infer response 17 | routes.Get("/ws/response", websocket.New(ReceiveInferResponse)) 18 | 19 | // infer without login 20 | routes.Post("/inference", InferWithoutLogin) 21 | routes.Get("/ws/inference", websocket.New(InferWithoutLoginAsync)) 22 | 23 | // OpenAI API protocol 24 | routes.Get("/v1/models", OpenAIListModels) 25 | routes.Get("/v1/models/:name", OpenAIRetrieveModel) 26 | routes.Post("/v1/chat/completions", OpenAICreateChatCompletion) 27 | 28 | // yocsef API 29 | routes.Get("/ws/yocsef/inference", websocket.New(InferYocsefAsyncAPI)) 30 | } 31 | -------------------------------------------------------------------------------- /apis/record/schemas.go: -------------------------------------------------------------------------------- 1 | package record 2 | 3 | import ( 4 | "strings" 5 | 6 | . "MOSS_backend/models" 7 | "MOSS_backend/utils" 8 | ) 9 | 10 | type ParamsModel struct { 11 | Param map[string]float64 `json:"param"` 12 | } 13 | 14 | type CreateModel struct { 15 | ParamsModel 16 | Request string `json:"request" validate:"required"` 17 | } 18 | 19 | type InterruptModel struct { 20 | Interrupt bool `json:"interrupt"` 21 | } 22 | 23 | type ModifyModel struct { 24 | Feedback *string `json:"feedback"` 25 | Like *int `json:"like" validate:"omitempty,oneof=1 0 -1"` // 1 like, -1 dislike, 0 reset 26 | } 27 | 28 | type InferenceRequest struct { 29 | Context string `json:"context"` 30 | Request string `json:"request" validate:"min=1"` 31 | Records RecordModels `json:"records" validate:"omitempty,dive"` 32 | PluginConfig map[string]bool `json:"plugin_config"` 33 | ModelID int `json:"model_id"` 34 | ParamsModel 35 | } 36 | 37 | type InferenceResponse struct { 38 | Response string `json:"response"` 39 | Context string `json:"context,omitempty"` 40 | ExtraData any `json:"extra_data,omitempty"` 41 | } 42 | 43 | // OpenAI 44 | 45 | type OpenAIModel struct { 46 | ID string `json:"id"` 47 | Object string `json:"object"` 48 
| Created int `json:"created"` 49 | OwnedBy string `json:"owned_by"` 50 | } 51 | 52 | func OpenAIModelFromModelConfig(config *ModelConfig) *OpenAIModel { 53 | return &OpenAIModel{ 54 | ID: config.Description, 55 | Object: "model", 56 | Created: 0, 57 | OwnedBy: "moss", 58 | } 59 | } 60 | 61 | type OpenAIModels struct { 62 | Object string `json:"object"` 63 | Data []*OpenAIModel `json:"data"` 64 | } 65 | 66 | func OpenAIModelsFromModelConfigs(modelConfig ModelConfigs) *OpenAIModels { 67 | data := make([]*OpenAIModel, len(modelConfig)) 68 | for i := range modelConfig { 69 | data[i] = OpenAIModelFromModelConfig(modelConfig[i]) 70 | } 71 | return &OpenAIModels{ 72 | Object: "list", 73 | Data: data, 74 | } 75 | } 76 | 77 | type OpenAIMessage struct { 78 | Role string `json:"role" validate:"required,oneof=system user assistant"` 79 | Content string `json:"content" validate:"required"` 80 | } 81 | 82 | type OpenAIMessages []*OpenAIMessage 83 | 84 | func (messages OpenAIMessages) ValidateSequence() error { 85 | if len(messages) == 0 { 86 | return utils.BadRequest("empty messages") 87 | } 88 | currentRole := messages[0].Role 89 | for i := 1; i < len(messages); i++ { 90 | if messages[i] == nil { 91 | return utils.BadRequest("nil message") 92 | } 93 | if messages[i].Role == currentRole { 94 | return utils.BadRequest("consecutive messages with the same role") 95 | } 96 | if messages[i].Role == "system" || messages[i].Role == "tool" { 97 | return utils.BadRequest("unsupported message role " + messages[i].Role) 98 | } 99 | currentRole = messages[i].Role 100 | } 101 | if currentRole != "user" { 102 | return utils.BadRequest("last message must be user") 103 | } 104 | return nil 105 | } 106 | 107 | func (messages OpenAIMessages) Build() (prefix string, request string, err error) { 108 | err = messages.ValidateSequence() 109 | if err != nil { 110 | return "", "", err 111 | } 112 | var builder strings.Builder 113 | for i, message := range messages { 114 | if message == nil { 115 | 
err = utils.BadRequest("nil message") 116 | return 117 | } 118 | if i == len(messages)-1 { 119 | request = message.Content 120 | return builder.String(), request, nil 121 | } 122 | if message.Role == "user" { 123 | builder.WriteString("<|Human|>: ") 124 | builder.WriteString(message.Content) 125 | builder.WriteString("\n") 126 | builder.WriteString("<|Inner Thoughts|>: None\n") 127 | builder.WriteString("<|Commands|>: None\n") 128 | builder.WriteString("<|Results|>: None\n") 129 | } else if message.Role == "assistant" { 130 | builder.WriteString("<|MOSS|>: ") 131 | builder.WriteString(message.Content) 132 | builder.WriteString("\n") 133 | } else { 134 | err = utils.BadRequest("invalid message role") 135 | return 136 | } 137 | } 138 | return builder.String(), request, nil 139 | } 140 | 141 | func (messages OpenAIMessages) BuildRecordModels() (models RecordModels, request string, err error) { 142 | err = messages.ValidateSequence() 143 | if err != nil { 144 | return nil, "", err 145 | } 146 | models = make(RecordModels, 0, len(messages)/2) 147 | for i := 0; i < len(messages)-1; i += 2 { 148 | models = append(models, RecordModel{ 149 | Request: messages[i].Content, 150 | Response: messages[i+1].Content, 151 | }) 152 | } 153 | request = messages[len(messages)-1].Content 154 | return models, request, nil 155 | } 156 | 157 | type OpenAIChatCompletionRequest struct { 158 | Messages OpenAIMessages `json:"messages" validate:"required,min=1,dive"` 159 | Model string `json:"model" validate:"required"` 160 | } 161 | 162 | type OpenAIChatCompletionChoice struct { 163 | Index int `json:"index"` 164 | Message OpenAIMessages `json:"message"` 165 | Logprobs interface{} `json:"logprobs"` 166 | FinishReason string `json:"finish_reason"` 167 | } 168 | 169 | type OpenAIChatCompletionUsage struct { 170 | PromptTokens int `json:"prompt_tokens"` 171 | CompletionTokens int `json:"completion_tokens"` 172 | TotalTokens int `json:"total_tokens"` 173 | } 174 | 175 | type 
OpenAIChatCompletionResponse struct { 176 | Id string `json:"id"` 177 | Object string `json:"object"` 178 | Created int64 `json:"created"` 179 | Model string `json:"model"` 180 | SystemFingerprint string `json:"system_fingerprint"` 181 | Choices []*OpenAIChatCompletionChoice `json:"choices"` 182 | Usage OpenAIChatCompletionUsage `json:"usage"` 183 | } 184 | -------------------------------------------------------------------------------- /apis/record/utils.go: -------------------------------------------------------------------------------- 1 | package record 2 | 3 | import ( 4 | . "MOSS_backend/utils" 5 | "errors" 6 | "regexp" 7 | ) 8 | 9 | // regexps 10 | var ( 11 | endContentRegexp = regexp.MustCompile(`<[es]o\w>`) 12 | mossSpecialTokenRegexp = regexp.MustCompile(``) 13 | innerThoughtsRegexp = regexp.MustCompile(`<\|Inner Thoughts\|>:([\s\S]+?)()`) 14 | commandsRegexp = regexp.MustCompile(`<\|Commands\|>:([\s\S]+?)()`) 15 | resultsRegexp = regexp.MustCompile(`<\|Results\|>:[\s\S]+?`) // not greedy 16 | mossRegexp = regexp.MustCompile(`<\|MOSS\|>:([\s\S]+?)()`) 17 | secondGenerationsFormatRegexp = regexp.MustCompile(`^<\|MOSS\|>:[\s\S]+?$`) 18 | firstGenerationsFormatRegexp = regexp.MustCompile(`^<\|Inner Thoughts\|>:[\s\S]+?\n *?<\|Commands\|>:[\s\S]+?$`) 19 | ) 20 | 21 | //var maxLengthExceededError = BadRequest("The maximum context length is exceeded").WithMessageType(MaxLength) 22 | 23 | // error messages 24 | var ( 25 | userRequestingError = BadRequest("上一次请求还未结束,请稍后再试。User requesting, please wait and try again") 26 | maxInputExceededError = BadRequest("单次输入限长为 2048 字符。Input no more than 2048 characters").WithMessageType(MaxLength) 27 | maxInputExceededFromInferError = BadRequest("单次输入超长,请减少字数并重试。Input max length exceeded, please reduce length and try again").WithMessageType(MaxLength) 28 | unknownError = InternalServerError("未知错误,请刷新或等待一分钟后再试。Unknown error, please refresh or wait a minute and try again") 29 | ErrSensitive = errors.New("sensitive") 30 | 
interruptError = NoStatus("client interrupt") 31 | ) 32 | -------------------------------------------------------------------------------- /apis/record/yocsef.go: -------------------------------------------------------------------------------- 1 | package record 2 | 3 | import ( 4 | . "MOSS_backend/models" 5 | "MOSS_backend/service" 6 | . "MOSS_backend/utils" 7 | "context" 8 | "errors" 9 | "fmt" 10 | "github.com/gofiber/websocket/v2" 11 | "go.uber.org/zap" 12 | ) 13 | 14 | // InferYocsefAsyncAPI reads one InferenceRequest from the websocket, runs service.InferYocsef and replies with the result 15 | // @Summary infer without login in websocket 16 | // @Tags Websocket 17 | // @Router /yocsef/inference [get] 18 | // @Param json body InferenceRequest true "json" 19 | // @Success 200 {object} InferenceResponse 20 | func InferYocsefAsyncAPI(c *websocket.Conn) { 21 | var ( 22 | err error 23 | ) 24 | 25 | defer func() { 26 | if err != nil { 27 | Logger.Error( 28 | "client websocket return with error", 29 | zap.Error(err), 30 | ) 31 | response := InferResponseModel{Status: -1, Output: err.Error()} 32 | var httpError *HttpError 33 | if errors.As(err, &httpError) { 34 | response.StatusCode = httpError.Code 35 | } 36 | _ = c.WriteJSON(response) 37 | } 38 | }() 39 | 40 | procedure := func() error { 41 | 42 | // read body 43 | var body InferenceRequest 44 | if err = c.ReadJSON(&body); err != nil { 45 | return fmt.Errorf("error receive message: %w", err) 46 | } 47 | 48 | if body.Request == "" { 49 | return BadRequest("内容不能为空") 50 | } 51 | 52 | //ctx, cancel := context.WithCancelCause(context.Background()) 53 | //defer cancel(errors.New("procedure finished")) 54 | 55 | // listen to interrupt and connection close 56 | //go func() { 57 | // defer cancel(errors.New("client connection closed or interrupt")) 58 | // _, _, innerErr := c.ReadMessage() 59 | // if innerErr != nil { 60 | // return 61 | // } 62 | //}() 63 | 64 | var record *DirectRecord 65 | record, err = service.InferYocsef( 66 | context.Background(), 67 | c, 68 | body.Request, 69 | body.Records, 70 | ) 71 | if err != nil
{ 72 | return err 73 | } 74 | // saving the record is best-effort: a failed insert is logged and the reply is still sent 75 | if createErr := DB.Create(&record).Error; createErr != nil { 76 | Logger.Error("save yocsef record failed", zap.Error(createErr)) 77 | } 78 | _ = c.WriteJSON(InferenceResponse{Response: record.Response}) 79 | return nil 80 | } 81 | 82 | err = procedure() 83 | } 84 | -------------------------------------------------------------------------------- /apis/routes.go: -------------------------------------------------------------------------------- 1 | package apis 2 | 3 | import ( 4 | "github.com/gofiber/fiber/v2" 5 | "github.com/gofiber/swagger" 6 | 7 | "MOSS_backend/apis/account" 8 | "MOSS_backend/apis/chat" 9 | "MOSS_backend/apis/config" 10 | "MOSS_backend/apis/record" 11 | ) 12 | 13 | func RegisterRoutes(app *fiber.App) { 14 | app.Get("/", func(c *fiber.Ctx) error { 15 | return c.Redirect("/api") 16 | }) 17 | // docs 18 | app.Get("/docs", func(c *fiber.Ctx) error { 19 | return c.Redirect("/docs/index.html") 20 | }) 21 | app.Get("/docs/*", swagger.HandlerDefault) 22 | 23 | // meta 24 | routes := app.Group("/api") 25 | routes.Get("/", Index) 26 | 27 | account.RegisterRoutes(routes) 28 | chat.RegisterRoutes(routes) 29 | record.RegisterRoutes(routes) 30 | config.RegisterRoutes(routes) 31 | 32 | } 33 | -------------------------------------------------------------------------------- /config/cache.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "log" 9 | "math/rand" 10 | "time" 11 | 12 | "github.com/redis/go-redis/v9" 13 | ) 14 | 15 | var RedisClient *redis.Client 16 | 17 | func initCache() { 18 | RedisClient = redis.NewClient(&redis.Options{ 19 | Addr: Config.RedisUrl, 20 | }) 21 | pong, err := RedisClient.Ping(context.Background()).Result() 22 | fmt.Println(pong, err) 23 | } 24 | 25 | // GetCache get cache from redis 26 | func GetCache(key string, modelPtr any) error { 27 | data, err := RedisClient.Get(context.Background(), key).Bytes() 28 | if err != nil { 29 | if !errors.Is(err, redis.Nil) { // err == redis.Nil means
key does not exist, logging that is not necessary 30 | log.Printf("error get cache %s err %v", key, err) 31 | } 32 | return err 33 | } 34 | if len(data) == 0 { 35 | log.Printf("empty value of key %v", key) 36 | return errors.New("empty value") 37 | } 38 | 39 | err = json.Unmarshal(data, modelPtr) 40 | if err != nil { 41 | log.Printf("error during unmarshal %s, data:%v ,err %v", key, string(data), err) 42 | } 43 | return err 44 | } 45 | 46 | // SetCache set cache with random duration(+15min) 47 | // the error can be dropped because it has been logged in the function 48 | func SetCache(key string, model any, duration time.Duration) error { 49 | data, err := json.Marshal(model) 50 | if err != nil { 51 | log.Printf("error during marshal %s|%v, err %v", key, model, err) 52 | return err 53 | } 54 | duration = GenRandomDuration(duration) 55 | err = RedisClient.Set(context.Background(), key, data, duration).Err() 56 | if err != nil { 57 | log.Printf("error set cache %s|%v data(string): %v with duration %s, err %v", 58 | key, model, string(data), duration.String(), err) 59 | } 60 | return err 61 | } 62 | 63 | func DeleteCache(key string) error { 64 | return RedisClient.Del(context.Background(), key).Err() 65 | } 66 | 67 | func ClearCache() error { 68 | return RedisClient.FlushAll(context.Background()).Err() 69 | } 70 | 71 | func GenRandomDuration(delay time.Duration) time.Duration { 72 | return delay + time.Duration(rand.Int63n(int64(900*time.Second))) 73 | } 74 | -------------------------------------------------------------------------------- /config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/caarlos0/env/v8" 7 | ) 8 | 9 | const AppName = "moss_backend" 10 | 11 | var Config struct { 12 | Mode string `env:"MODE" envDefault:"dev"` 13 | Debug bool `env:"DEBUG" envDefault:"false"` 14 | Hostname string `env:"HOSTNAME,required"` 15 | DbUrl string
`env:"DB_URL,required"` 16 | KongUrl string `env:"KONG_URL,required"` 17 | RedisUrl string `env:"REDIS_URL"` 18 | // sending email config 19 | EmailUrl string `env:"EMAIL_URL,required"` 20 | TencentSecretID string `env:"SECRET_ID,required"` 21 | TencentSecretKey string `env:"SECRET_KEY,required"` 22 | TencentTemplateID uint64 `env:"TEMPLATE_ID,required"` 23 | // sending message config 24 | UniAccessID string `env:"UNI_ACCESS_ID,required"` 25 | UniSignature string `env:"UNI_SIGNATURE" envDefault:"fastnlp"` 26 | UniTemplateID string `env:"UNI_TEMPLATE_ID,required"` 27 | 28 | // InferenceUrl string `env:"INFERENCE_URL,required"` // now save it in db 29 | 30 | // sensitive content check 31 | EnableSensitiveCheck bool `env:"ENABLE_SENSITIVE_CHECK" envDefault:"true"` 32 | SensitiveCheckPlatform string `env:"SENSITIVE_CHECK_PLATFORM" envDefault:"ShuMei"` // one of ShuMei or DiTing 33 | 34 | // DiTing platform 35 | DiTingToken string `env:"SENSITIVE_CHECK_TOKEN"` 36 | 37 | // ShuMei platform 38 | ShuMeiAccessKey string `env:"SHU_MEI_ACCESS_KEY"` 39 | ShuMeiAppID string `env:"SHU_MEI_APP_ID"` 40 | ShuMeiEventID string `env:"SHU_MEI_EVENT_ID"` 41 | ShuMeiType string `env:"SHU_MEI_TYPE"` 42 | 43 | VerificationCodeExpires int `env:"VERIFICATION_CODE_EXPIRES" envDefault:"10"` 44 | ChatNameLength int `env:"CHAT_NAME_LENGTH" envDefault:"30"` 45 | 46 | AccessExpireTime int `env:"ACCESS_EXPIRE_TIME" envDefault:"30"` // 30 minutes 47 | RefreshExpireTime int `env:"REFRESH_EXPIRE_TIME" envDefault:"30"` // 30 days 48 | 49 | CallbackUrl string `env:"CALLBACK_URL,required"` // async callback url 50 | 51 | OpenScreenshot bool `env:"OPEN_SCREENSHOT" envDefault:"true"` 52 | 53 | PassSensitiveCheckUsername []string `env:"PASS_SENSITIVE_CHECK_USERNAME"` 54 | 55 | // tools 56 | EnableTools bool `env:"ENABLE_TOOLS" envDefault:"true"` 57 | ToolsSearchUrl string `env:"TOOLS_SEARCH_URL,required"` 58 | ToolsCalculateUrl string `env:"TOOLS_CALCULATE_URL,required"` 59 | ToolsSolveUrl string `env:"TOOLS_SOLVE_URL,required"` 60 |
ToolsDrawUrl string `env:"TOOLS_DRAW_URL,required"` 61 | 62 | // DefaultPluginConfig map[string]bool `env:"DEFAULT_PLUGIN_CONFIG"` 63 | 64 | // InnerThoughtsPostprocess bool `env:"INNER_THOUGHTS_POSTPROCESS" envDefault:"false"` 65 | 66 | DefaultModelID int `env:"DEFAULT_MODEL_ID" envDefault:"1"` 67 | NoNeedInviteCodeEmailSuffix []string `env:"NO_NEED_INVITE_CODE_EMAIL_SUFFIX" envSeparator:"," envDefault:"fudan.edu.cn"` 68 | 69 | // yocsef 70 | YocsefInferenceUrl string `env:"YOCSEF_INFERENCE_URL"` 71 | } 72 | 73 | func InitConfig() { 74 | if err := env.Parse(&Config); err != nil { 75 | panic(err) 76 | } 77 | if Config.Debug { // the dump contains DB_URL, SECRET_KEY and other credentials 78 | fmt.Printf("%+v\n", &Config) 79 | } 80 | initCache() 81 | } 82 | -------------------------------------------------------------------------------- /data/data.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import _ "embed" 4 | 5 | //go:embed ip2region.xdb 6 | var Ip2RegionDBFile []byte 7 | 8 | //go:embed meta.json 9 | var MetaData []byte 10 | 11 | //go:embed image.html 12 | var ImageTemplate []byte 13 | -------------------------------------------------------------------------------- /data/image.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | MOSS 9 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | MOSS 的回答均由语言模型生成,可能包含不完整、误导性的,或者错误的信息。(MOSS 版本: 0.0.3) 81 | 82 | 83 | {{ range .Records }} 84 | 85 | 86 | {{- html .Request | replace -}} 87 | 88 | 89 | 90 | 91 | {{- html .Response | replace -}} 92 | 93 | 94 | {{ end }} 95 | 96 | 97 | 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /data/ip2region.xdb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JingYiJun/MOSS_backend/3cb166195ed38340f87585fe93a219f6b77eca35/data/ip2region.xdb
-------------------------------------------------------------------------------- /data/meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "MOSS backend", 3 | "version": "0.0.1", 4 | "author": "JingYiJun", 5 | "email": "jingyijun@fduhole.com", 6 | "license": "Apache-2.0" 7 | } -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module MOSS_backend 2 | 3 | go 1.22 4 | 5 | require ( 6 | github.com/ansrivas/fiberprometheus/v2 v2.6.1 7 | github.com/apistd/uni-go-sdk v0.0.2 8 | github.com/caarlos0/env/v8 v8.0.0 9 | github.com/chromedp/chromedp v0.9.5 10 | github.com/creasty/defaults v1.7.0 11 | github.com/eko/gocache/lib/v4 v4.1.6 12 | github.com/eko/gocache/store/go_cache/v4 v4.2.1 13 | github.com/eko/gocache/store/redis/v4 v4.2.1 14 | github.com/go-playground/validator/v10 v10.20.0 15 | github.com/gofiber/fiber/v2 v2.52.4 16 | github.com/gofiber/swagger v1.0.0 17 | github.com/gofiber/websocket/v2 v2.2.1 18 | github.com/golang-jwt/jwt/v4 v4.5.0 19 | github.com/google/uuid v1.6.0 20 | github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20240510055607-89e20ab7b6c6 21 | github.com/patrickmn/go-cache v2.1.0+incompatible 22 | github.com/pkg/errors v0.9.1 23 | github.com/prometheus/client_golang v1.19.1 24 | github.com/redis/go-redis/v9 v9.5.1 25 | github.com/robfig/cron/v3 v3.0.1 26 | github.com/sashabaranov/go-openai v1.24.0 27 | github.com/swaggo/swag v1.16.3 28 | github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.927 29 | github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ses v1.0.926 30 | github.com/vmihailenco/msgpack/v5 v5.4.1 31 | go.uber.org/zap v1.27.0 32 | golang.org/x/crypto v0.23.0 33 | golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 34 | golang.org/x/time v0.5.0 35 | gorm.io/driver/mysql v1.5.6 36 | gorm.io/driver/sqlite v1.5.5 37 | gorm.io/gorm
v1.25.10 38 | ) 39 | 40 | require ( 41 | filippo.io/edwards25519 v1.1.0 // indirect 42 | github.com/KyleBanks/depth v1.2.1 // indirect 43 | github.com/andybalholm/brotli v1.1.0 // indirect 44 | github.com/beorn7/perks v1.0.1 // indirect 45 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 46 | github.com/chromedp/cdproto v0.0.0-20240519224452-66462be74baa // indirect 47 | github.com/chromedp/sysutil v1.0.0 // indirect 48 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 49 | github.com/fasthttp/websocket v1.5.8 // indirect 50 | github.com/gabriel-vasile/mimetype v1.4.3 // indirect 51 | github.com/go-openapi/jsonpointer v0.21.0 // indirect 52 | github.com/go-openapi/jsonreference v0.21.0 // indirect 53 | github.com/go-openapi/spec v0.21.0 // indirect 54 | github.com/go-openapi/swag v0.23.0 // indirect 55 | github.com/go-playground/locales v0.14.1 // indirect 56 | github.com/go-playground/universal-translator v0.18.1 // indirect 57 | github.com/go-sql-driver/mysql v1.8.1 // indirect 58 | github.com/gobwas/httphead v0.1.0 // indirect 59 | github.com/gobwas/pool v0.2.1 // indirect 60 | github.com/gobwas/ws v1.4.0 // indirect 61 | github.com/gofiber/adaptor/v2 v2.2.1 // indirect 62 | github.com/golang/mock v1.6.0 // indirect 63 | github.com/jinzhu/inflection v1.0.0 // indirect 64 | github.com/jinzhu/now v1.1.5 // indirect 65 | github.com/josharian/intern v1.0.0 // indirect 66 | github.com/klauspost/compress v1.17.8 // indirect 67 | github.com/leodido/go-urn v1.4.0 // indirect 68 | github.com/mailru/easyjson v0.7.7 // indirect 69 | github.com/mattn/go-colorable v0.1.13 // indirect 70 | github.com/mattn/go-isatty v0.0.20 // indirect 71 | github.com/mattn/go-runewidth v0.0.15 // indirect 72 | github.com/mattn/go-sqlite3 v1.14.22 // indirect 73 | github.com/prometheus/client_model v0.6.1 // indirect 74 | github.com/prometheus/common v0.53.0 // indirect 75 | github.com/prometheus/procfs v0.15.0 // indirect 76 | github.com/rivo/uniseg v0.4.7 //
indirect 77 | github.com/savsgio/gotils v0.0.0-20240303185622-093b76447511 // indirect 78 | github.com/swaggo/files/v2 v2.0.0 // indirect 79 | github.com/valyala/bytebufferpool v1.0.0 // indirect 80 | github.com/valyala/fasthttp v1.53.0 // indirect 81 | github.com/valyala/tcplisten v1.0.0 // indirect 82 | github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect 83 | go.uber.org/multierr v1.11.0 // indirect 84 | golang.org/x/net v0.25.0 // indirect 85 | golang.org/x/sync v0.7.0 // indirect 86 | golang.org/x/sys v0.20.0 // indirect 87 | golang.org/x/text v0.15.0 // indirect 88 | golang.org/x/tools v0.21.0 // indirect 89 | google.golang.org/protobuf v1.34.1 // indirect 90 | gopkg.in/yaml.v3 v3.0.1 // indirect 91 | ) 92 | -------------------------------------------------------------------------------- /local.conf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | sans-serif 7 | 8 | Main sans-serif font name goes here 9 | Noto Color Emoji 10 | Noto Emoji 11 | 12 | 13 | 14 | 15 | serif 16 | 17 | Main serif font name goes here 18 | Noto Color Emoji 19 | Noto Emoji 20 | 21 | 22 | 23 | 24 | monospace 25 | 26 | Main monospace font name goes here 27 | Noto Color Emoji 28 | Noto Emoji 29 | 30 | 31 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | // @title Moss Backend 2 | // @version 0.0.1 3 | // @description Moss Backend 4 | 5 | // @contact.name Maintainer Chen Ke 6 | // @contact.url https://danxi.fduhole.com/about 7 | // @contact.email dev@fduhole.com 8 | 9 | // @license.name Apache 2.0 10 | // @license.url https://www.apache.org/licenses/LICENSE-2.0.html 11 | 12 | // @host localhost:8000 13 | // @BasePath /api 14 | 15 | package main 16 | 17 | import ( 18 | "MOSS_backend/apis" 19 | "MOSS_backend/apis/record" 20 | "MOSS_backend/config" 21 | _ "MOSS_backend/docs" 22 | "MOSS_backend/middlewares" 23
| "MOSS_backend/models" 24 | "MOSS_backend/utils" 25 | "MOSS_backend/utils/auth" 26 | "MOSS_backend/utils/kong" 27 | "encoding/json" 28 | "github.com/gofiber/fiber/v2" 29 | "github.com/robfig/cron/v3" 30 | "log" 31 | "os" 32 | "os/signal" 33 | "syscall" 34 | ) 35 | 36 | func main() { 37 | config.InitConfig() 38 | models.InitDB() 39 | auth.InitCache() 40 | 41 | // connect to kong 42 | err := kong.Ping() 43 | if err != nil { 44 | panic(err) 45 | } 46 | 47 | app := fiber.New(fiber.Config{ 48 | AppName: config.AppName, 49 | ErrorHandler: utils.MyErrorHandler, 50 | JSONDecoder: json.Unmarshal, 51 | JSONEncoder: json.Marshal, 52 | DisableStartupMessage: true, 53 | }) 54 | middlewares.RegisterMiddlewares(app) 55 | apis.RegisterRoutes(app) 56 | 57 | startTasks() 58 | 59 | go func() { 60 | listenErr := app.Listen("0.0.0.0:8000") // own variable: writing main's err here would race with the Shutdown path below 61 | if listenErr != nil { 62 | log.Println(listenErr) 63 | } 64 | }() 65 | 66 | interrupt := make(chan os.Signal, 1) 67 | 68 | // wait for CTRL-C interrupt 69 | signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM) 70 | <-interrupt 71 | 72 | // close app 73 | err = app.Shutdown() 74 | if err != nil { 75 | log.Println(err) 76 | } 77 | 78 | _ = utils.Logger.Sync() 79 | } 80 | 81 | func startTasks() { 82 | c := cron.New() 83 | _, err := c.AddFunc("CRON_TZ=Asia/Shanghai 0 0 * * *", models.ActiveStatusTask) // run every day 00:00 +8:00 84 | if err != nil { 85 | panic(err) 86 | } 87 | c.Start() // Start already runs the scheduler in its own goroutine 88 | go record.UserLockCheck() 89 | } 90 | -------------------------------------------------------------------------------- /middlewares/init.go: -------------------------------------------------------------------------------- 1 | package middlewares 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/ansrivas/fiberprometheus/v2" 7 | "github.com/gofiber/fiber/v2" 8 | "github.com/gofiber/fiber/v2/middleware/cors" 9 | "github.com/gofiber/fiber/v2/middleware/recover" 10 | "go.uber.org/zap" 11 | 12 | 
"MOSS_backend/config" 13 | "MOSS_backend/models" 14 | "MOSS_backend/utils" 15 | ) 16
| 17 | func RegisterMiddlewares(app *fiber.App) { 18 | if config.Config.Mode != "bench" { 19 | app.Use(MyLogger) 20 | } 21 | app.Use(cors.New(cors.Config{AllowOrigins: "*"})) 22 | //app.Use(GetUserID) 23 | 24 | // prometheus 25 | prom := fiberprometheus.NewWith(config.AppName, config.AppName, "http") 26 | prom.RegisterAt(app, "/metrics") 27 | app.Use(prom.Middleware) 28 | 29 | app.Use(recover.New(recover.Config{EnableStackTrace: true})) 30 | } 31 | 32 | func GetUserID(c *fiber.Ctx) error { 33 | userID, err := models.GetUserID(c) 34 | if err == nil { 35 | c.Locals("user_id", userID) 36 | } 37 | 38 | return c.Next() 39 | } 40 | 41 | func MyLogger(c *fiber.Ctx) error { 42 | startTime := time.Now() 43 | chainErr := c.Next() 44 | 45 | if chainErr != nil { 46 | if err := c.App().ErrorHandler(c, chainErr); err != nil { 47 | _ = c.SendStatus(fiber.StatusInternalServerError) 48 | } 49 | } 50 | 51 | // not log for prometheus metrics 52 | if c.Path() == "/metrics" { 53 | return nil 54 | } 55 | 56 | latency := time.Since(startTime).Milliseconds() 57 | userID, ok := c.Locals("user_id").(int) 58 | output := []zap.Field{ 59 | zap.Int("status_code", c.Response().StatusCode()), 60 | zap.String("method", c.Method()), 61 | zap.String("origin_url", c.OriginalURL()), 62 | zap.String("remote_ip", utils.GetRealIP(c)), 63 | zap.Int64("latency", latency), 64 | } 65 | if ok { 66 | output = append(output, zap.Int("user_id", userID)) 67 | } 68 | if chainErr != nil { 69 | output = append(output, zap.Error(chainErr)) 70 | } 71 | utils.Logger.Info("http log", output...)
72 | return nil 73 | } 74 | -------------------------------------------------------------------------------- /models/active_status.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "log" 5 | "time" 6 | ) 7 | 8 | type ActiveStatus struct { 9 | ID int 10 | CreatedAt time.Time 11 | DAU int 12 | MAU int 13 | } 14 | 15 | func ActiveStatusTask() { 16 | var dau, mau int64 17 | err := DB.Model(&User{}). 18 | Where("last_login between ? and ?", time.Now().Add(-24*time.Hour), time.Now()). 19 | Count(&dau).Error 20 | if err != nil { 21 | log.Printf("load dau err: %v", err) 22 | } 23 | err = DB.Model(&User{}). 24 | Where("last_login between ? and ?", time.Now().AddDate(0, -1, 0), time.Now()). 25 | Count(&mau).Error 26 | if err != nil { 27 | log.Printf("load mau err: %v", err) 28 | } 29 | 30 | status := ActiveStatus{ 31 | DAU: int(dau), 32 | MAU: int(mau), 33 | } 34 | err = DB.Create(&status).Error 35 | if err != nil { 36 | log.Printf("save status err: %v", err) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /models/base.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import "gorm.io/gorm/clause" 4 | 5 | type Map = map[string]any 6 | 7 | var LockingClause = clause.Locking{Strength: "UPDATE"} 8 | -------------------------------------------------------------------------------- /models/chat.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/gofiber/fiber/v2" 7 | "github.com/sashabaranov/go-openai" 8 | "gorm.io/gorm" 9 | ) 10 | 11 | type Chat struct { 12 | ID int `json:"id"` 13 | CreatedAt time.Time `json:"created_at"` 14 | UpdatedAt time.Time `json:"updated_at"` 15 | DeletedAt gorm.DeletedAt `json:"-" gorm:"index:idx_chat_user_deleted,priority:2"` 16 | UserID int `json:"user_id"
gorm:"index:idx_chat_user_deleted,priority:1"` 17 | Name string `json:"name"` 18 | Count int `json:"count"` // number of records 19 | Records Records `json:"records,omitempty"` 20 | MaxLengthExceeded bool `json:"max_length_exceeded"` 21 | } 22 | 23 | type Chats []Chat 24 | 25 | type Record struct { 26 | ID int `json:"id"` 27 | CreatedAt time.Time `json:"created_at"` 28 | DeletedAt gorm.DeletedAt `json:"-" gorm:"index:idx_record_chat_deleted,priority:2"` 29 | Duration float64 `json:"duration"` // processing time, in seconds 30 | ChatID int `json:"chat_id" gorm:"index:idx_record_chat_deleted,priority:1"` 31 | Request string `json:"request"` 32 | Response string `json:"response"` 33 | Prefix string `json:"-"` 34 | RawContent string `json:"raw_content"` 35 | ExtraData any `json:"-" gorm:"serializer:json"` //`json:"extra_data" gorm:"serializer:json"` 36 | ProcessedExtraData any `json:"processed_extra_data" gorm:"serializer:json"` 37 | LikeData int `json:"like_data"` // 1 like, -1 dislike 38 | Feedback string `json:"feedback"` 39 | RequestSensitive bool `json:"request_sensitive"` 40 | ResponseSensitive bool `json:"response_sensitive"` 41 | InnerThoughts string `json:"inner_thoughts"` 42 | } 43 | 44 | type Records []Record 45 | 46 | func (record *Record) Preprocess(_ *fiber.Ctx) error { 47 | if record.ResponseSensitive { 48 | record.Response = DefaultResponse 49 | } 50 | return nil 51 | } 52 | 53 | func (records Records) Preprocess(c *fiber.Ctx) error { 54 | for i := range records { 55 | _ = records[i].Preprocess(c) 56 | } 57 | return nil 58 | } 59 | 60 | const DefaultResponse = `Sorry, I have nothing to say. Try another topic.
I will block your account if we continue this topic :)` 61 | 62 | func (records Records) ToRecordModel() (recordModel []RecordModel) { 63 | for _, record := range records { 64 | recordModel = append(recordModel, RecordModel{ 65 | Request: record.Request, 66 | Response: record.Response, 67 | }) 68 | } 69 | return 70 | } 71 | 72 | func (records Records) ToOpenAIMessages() (messages []openai.ChatCompletionMessage) { 73 | for _, record := range records { 74 | messages = append(messages, 75 | openai.ChatCompletionMessage{ 76 | Role: "user", 77 | Content: record.Request, 78 | }, 79 | openai.ChatCompletionMessage{ 80 | Role: "assistant", 81 | Content: record.Response, 82 | }) 83 | } 84 | return 85 | } 86 | 87 | func (records Records) GetPrefix() string { 88 | if len(records) == 0 { 89 | return "" 90 | } 91 | return records[len(records)-1].Prefix 92 | } 93 | 94 | type RecordModel struct { 95 | Request string `json:"request" validate:"required"` 96 | Response string `json:"response" validate:"required"` 97 | } 98 | 99 | type RecordModels []RecordModel 100 | 101 | func (recordModels RecordModels) ToOpenAIMessages() (messages []openai.ChatCompletionMessage) { 102 | for _, record := range recordModels { 103 | messages = append(messages, 104 | openai.ChatCompletionMessage{ 105 | Role: "user", 106 | Content: record.Request, 107 | }, 108 | openai.ChatCompletionMessage{ 109 | Role: "assistant", 110 | Content: record.Response, 111 | }) 112 | } 113 | return 114 | } 115 | 116 | type Param struct { 117 | ID int 118 | Name string 119 | Value float64 120 | } 121 | 122 | func LoadParamToMap(m map[string]any) error { 123 | if DB == nil { 124 | return nil 125 | } 126 | var params []Param 127 | err := DB.Find(&params).Error 128 | if err != nil { 129 | return err 130 | } 131 | for _, param := range params { 132 | m[param.Name] = param.Value 133 | } 134 | return nil 135 | } 136 | 137 | type DirectRecord struct { 138 | ID int 139 | CreatedAt time.Time 140 | Duration float64 141 | ConsumerUsername
string 142 | Context string 143 | Request string 144 | Response string 145 | ExtraData any `json:"extra_data" gorm:"serializer:json"` 146 | } 147 | -------------------------------------------------------------------------------- /models/config.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "time" 5 | 6 | "go.uber.org/zap" 7 | 8 | "MOSS_backend/config" 9 | "MOSS_backend/utils" 10 | ) 11 | 12 | type APIType string 13 | 14 | const ( 15 | APITypeOpenAI APIType = "openai" 16 | ) 17 | 18 | type ModelConfig struct { 19 | ID int `json:"id"` 20 | InnerThoughtsPostprocess bool `json:"inner_thoughts_postprocess" default:"false"` 21 | Description string `json:"description"` 22 | DefaultPluginConfig map[string]bool `json:"default_plugin_config" gorm:"serializer:json"` 23 | Url string `json:"url"` 24 | CallbackUrl string `json:"callback_url"` 25 | APIType APIType `json:"api_type"` 26 | OpenAIModelName string `json:"openai_model_name"` 27 | OpenAISystemPrompt string `json:"openai_system_prompt"` 28 | EnableSensitiveCheck bool `json:"enable_sensitive_check"` 29 | EndDelimiter string `json:"end_delimiter"` 30 | } 31 | 32 | type ModelConfigs = []*ModelConfig 33 | 34 | func (cfg *ModelConfig) TableName() string { 35 | return "language_model_config" 36 | } 37 | 38 | type Config struct { 39 | ID int `json:"id"` 40 | InviteRequired bool `json:"invite_required"` 41 | OffenseCheck bool `json:"offense_check"` 42 | Notice string `json:"notice"` 43 | ModelConfig []ModelConfig `json:"model_config" gorm:"-:all"` 44 | } 45 | 46 | const configCacheName = "moss_backend_config" 47 | const configCacheExpire = 24 * time.Hour 48 | 49 | func LoadConfig(configObjectPtr *Config) error { 50 | if config.GetCache(configCacheName, configObjectPtr) != nil { 51 | if err := DB.First(configObjectPtr).Error; err != nil { 52 | return err 53 | } 54 | if err := DB.Find(&(configObjectPtr.ModelConfig)).Error; err != nil { 55 | return err 56
| } 57 | _ = config.SetCache(configCacheName, *configObjectPtr, configCacheExpire) 58 | } 59 | return nil 60 | } 61 | 62 | func LoadModelConfigs() (ModelConfigs, error) { 63 | var modelConfigs ModelConfigs 64 | if err := DB.Find(&modelConfigs).Error; err != nil { 65 | return nil, err 66 | } 67 | return modelConfigs, nil 68 | } 69 | 70 | func LoadModelConfigByName(name string) (*ModelConfig, error) { 71 | var modelConfig ModelConfig 72 | if err := DB.Where("description = ?", name).First(&modelConfig).Error; err != nil { 73 | return nil, err 74 | } 75 | return &modelConfig, nil 76 | } 77 | 78 | func LoadModelConfigByID(id int) (*ModelConfig, error) { 79 | var modelConfig ModelConfig 80 | if err := DB.Where("id = ?", id).First(&modelConfig).Error; err != nil { 81 | return nil, err 82 | } 83 | return &modelConfig, nil 84 | } 85 | 86 | func UpdateConfig(configObjectPtr *Config) error { 87 | err := DB.Model(&Config{ID: 1}).Updates(configObjectPtr).Error 88 | if err != nil { 89 | utils.Logger.Error("failed to update config", zap.Error(err)) 90 | return err 91 | } 92 | for i := range configObjectPtr.ModelConfig { 93 | err = DB.Model(&configObjectPtr.ModelConfig[i]).Updates(&configObjectPtr.ModelConfig[i]).Error // scope to this row: Model(&slice) would put every element's id into the WHERE 94 | if err != nil { 95 | utils.Logger.Error("failed to update model config", zap.Error(err)) 96 | return err 97 | } 98 | } 99 | _ = config.SetCache(configCacheName, *configObjectPtr, configCacheExpire) 100 | return nil 101 | } 102 | 103 | func GetPluginConfig(modelID int) (map[string]bool, error) { 104 | var configObject Config 105 | if err := LoadConfig(&configObject); err != nil { 106 | return nil, err 107 | } 108 | for _, modelConfig := range configObject.ModelConfig { 109 | if modelConfig.ID == modelID { 110 | return modelConfig.DefaultPluginConfig, nil 111 | } 112 | } 113 | // if not found, return default config of first model 114 | return configObject.ModelConfig[0].DefaultPluginConfig, nil 115 | } 116 |
-------------------------------------------------------------------------------- /models/email_blacklist.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "golang.org/x/exp/slices" 5 | "strings" 6 | ) 7 | 8 | type EmailBlacklist struct { 9 | ID int 10 | EmailDomain string 11 | } 12 | 13 | func IsEmailInBlacklist(email string) bool { 14 | var blacklist []string 15 | DB.Model(&EmailBlacklist{}).Select("email_domain").Scan(&blacklist) 16 | parts := strings.Split(email, "@") 17 | return slices.Contains(blacklist, parts[1]) 18 | } 19 | -------------------------------------------------------------------------------- /models/init.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "MOSS_backend/config" 5 | "errors" 6 | "gorm.io/driver/mysql" 7 | "gorm.io/driver/sqlite" 8 | "gorm.io/gorm" 9 | "gorm.io/gorm/logger" 10 | "gorm.io/gorm/schema" 11 | "log" 12 | "os" 13 | "time" 14 | ) 15 | 16 | var DB *gorm.DB 17 | 18 | var gormConfig = &gorm.Config{ 19 | NamingStrategy: schema.NamingStrategy{ 20 | SingularTable: true, // use singular table name, table for `User` would be `user` with this option enabled 21 | }, 22 | Logger: logger.New( 23 | log.Default(), 24 | logger.Config{ 25 | SlowThreshold: time.Second, // 慢 SQL 阈值 26 | LogLevel: logger.Error, // 日志级别 27 | IgnoreRecordNotFoundError: true, // 忽略ErrRecordNotFound(记录未找到)错误 28 | Colorful: false, // 禁用彩色打印 29 | }, 30 | ), 31 | } 32 | 33 | func InitDB() { 34 | mysqlDB := func() (*gorm.DB, error) { 35 | return gorm.Open(mysql.Open(config.Config.DbUrl), gormConfig) 36 | } 37 | sqliteDB := func() (*gorm.DB, error) { 38 | err := os.MkdirAll("data", 0755) 39 | if err != nil && !os.IsExist(err) { 40 | panic(err) 41 | } 42 | return gorm.Open(sqlite.Open("data/sqlite.db"), gormConfig) 43 | } 44 | memoryDB := func() (*gorm.DB, error) { 45 | return 
gorm.Open(sqlite.Open("file::memory:?cache=shared"), gormConfig) 46 | } 47 | 48 | var err error 49 | 50 | // connect to database with different mode 51 | switch config.Config.Mode { 52 | case "production": 53 | DB, err = mysqlDB() 54 | case "dev": 55 | if config.Config.DbUrl == "" { 56 | DB, err = sqliteDB() 57 | } else { 58 | DB, err = mysqlDB() 59 | } 60 | case "test": 61 | DB, err = memoryDB() 62 | case "bench": 63 | if config.Config.DbUrl == "" { 64 | DB, err = memoryDB() 65 | } else { 66 | DB, err = mysqlDB() 67 | } 68 | default: 69 | panic("unsupported mode") 70 | } 71 | 72 | if err != nil { 73 | panic(err) 74 | } 75 | 76 | if config.Config.Mode == "dev" || config.Config.Mode == "test" { 77 | DB = DB.Debug() 78 | } 79 | 80 | // migrate database 81 | err = DB.AutoMigrate( 82 | User{}, 83 | Chat{}, 84 | Record{}, 85 | ActiveStatus{}, 86 | Config{}, 87 | ModelConfig{}, 88 | InviteCode{}, 89 | Param{}, 90 | EmailBlacklist{}, 91 | DirectRecord{}, 92 | UserOffense{}, 93 | ) 94 | if err != nil { 95 | panic(err) 96 | } 97 | 98 | var configObject Config 99 | err = DB.First(&configObject).Error 100 | if errors.Is(err, gorm.ErrRecordNotFound) { 101 | DB.Create(&configObject) 102 | } 103 | var configModelObject ModelConfig 104 | err = DB.First(&configModelObject).Error 105 | if errors.Is(err, gorm.ErrRecordNotFound) { 106 | DB.Create(&configModelObject) 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /models/invite_code.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type InviteCode struct { 4 | ID int `gorm:"primaryKey"` 5 | Code string `gorm:"unique,size:32"` 6 | IsSend bool 7 | IsActivated bool 8 | } 9 | -------------------------------------------------------------------------------- /models/user.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "encoding/base64" 5 | 
"encoding/json" 6 | "errors" 7 | "strconv" 8 | "strings" 9 | "time" 10 | 11 | "MOSS_backend/config" 12 | "MOSS_backend/utils" 13 | 14 | "github.com/gofiber/fiber/v2" 15 | "github.com/gofiber/websocket/v2" 16 | "go.uber.org/zap" 17 | "golang.org/x/exp/slices" 18 | "gorm.io/gorm" 19 | ) 20 | 21 | type User struct { 22 | ID int `json:"id" gorm:"primaryKey"` 23 | JoinedTime time.Time `json:"joined_time" gorm:"autoCreateTime"` 24 | LastLogin time.Time `json:"last_login" gorm:"autoUpdateTime"` 25 | DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` 26 | Nickname string `json:"nickname" gorm:"size:128;default:'user'"` 27 | Email string `json:"email" gorm:"size:128;index:,length:5"` 28 | Phone string `json:"phone" gorm:"size:128;index:,length:5"` 29 | Password string `json:"-" gorm:"size:128"` 30 | RegisterIP string `json:"-" gorm:"size:32"` 31 | LastLoginIP string `json:"-" gorm:"size:32"` 32 | LoginIP []string `json:"-" gorm:"serializer:json"` 33 | Chats Chats `json:"chats,omitempty"` 34 | ShareConsent bool `json:"share_consent" gorm:"default:true"` 35 | InviteCode string `json:"-" gorm:"size:32"` 36 | IsAdmin bool `json:"is_admin"` 37 | DisableSensitiveCheck bool `json:"disable_sensitive_check"` 38 | Banned bool `json:"banned"` 39 | ModelID int `json:"model_id" default:"1" gorm:"default:1"` 40 | PluginConfig map[string]bool `json:"plugin_config" gorm:"serializer:json"` 41 | } 42 | 43 | func GetUserCacheKey(userID int) string { 44 | return "moss_user:" + strconv.Itoa(userID) 45 | } 46 | 47 | const UserCacheExpire = 48 * time.Hour 48 | 49 | func GetUserID(c *fiber.Ctx) (int, error) { 50 | if config.Config.Mode == "dev" || config.Config.Mode == "test" { 51 | return 1, nil 52 | } 53 | 54 | id, err := strconv.Atoi(c.Get("X-Consumer-Username")) 55 | if err != nil { 56 | return 0, utils.Unauthorized("Unauthorized") 57 | } 58 | 59 | return id, nil 60 | } 61 | 62 | // LoadUserByIDFromCache fills *userPtr from the cache, falling back to DB.Take() and re-caching on a miss; the returned err comes directly from DB.Take() 63 | func LoadUserByIDFromCache(userID int,
userPtr *User) error { 64 | cacheKey := GetUserCacheKey(userID) 65 | if config.GetCache(cacheKey, userPtr) != nil { 66 | err := DB.Take(userPtr, userID).Error 67 | if err != nil { 68 | return err 69 | } 70 | // err has been printed in SetCache 71 | _ = config.SetCache(cacheKey, *userPtr, UserCacheExpire) 72 | } 73 | return nil 74 | } 75 | 76 | func DeleteUserCacheByID(userID int) { 77 | cacheKey := GetUserCacheKey(userID) 78 | err := config.DeleteCache(cacheKey) 79 | if err != nil { 80 | utils.Logger.Error("err in DeleteUserCacheByID: ", zap.Error(err)) 81 | } 82 | } 83 | 84 | func LoadUserByID(userID int) (*User, error) { 85 | var user User 86 | err := LoadUserByIDFromCache(userID, &user) 87 | if err != nil { // something wrong in DB.Take() in LoadUserByIDFromCache() 88 | DeleteUserCacheByID(userID) 89 | return nil, err 90 | } 91 | updated := false 92 | // fall back to the default model when ModelID is unset or its ModelConfig cannot be loaded 93 | if user.ModelID == 0 { 94 | user.ModelID = config.Config.DefaultModelID 95 | updated = true 96 | } else { 97 | var modelConfig ModelConfig 98 | err = DB.Take(&modelConfig, user.ModelID).Error 99 | if err != nil { 100 | user.ModelID = config.Config.DefaultModelID 101 | updated = true 102 | } 103 | } 104 | defaultPluginConfig, err := GetPluginConfig(user.ModelID) 105 | if err != nil { // a nil default config would otherwise strip every plugin key below and persist that 106 | return nil, err 107 | } 108 | if user.PluginConfig == nil { 109 | user.PluginConfig = make(map[string]bool) 110 | for key := range defaultPluginConfig { 111 | user.PluginConfig[key] = false 112 | } 113 | updated = true 114 | } else { // add new key 115 | for key := range defaultPluginConfig { 116 | if _, ok := user.PluginConfig[key]; !ok { 117 | user.PluginConfig[key] = false 118 | updated = true 119 | } 120 | } 121 | 122 | // delete not used key 123 | for key := range user.PluginConfig { 124 | if _, ok := defaultPluginConfig[key]; !ok { 125 | delete(user.PluginConfig, key) 126 | updated = true 127 | } 128 | } 129 | } 130 | // persist and re-cache the normalized ModelID / PluginConfig 131 | if updated { 132 | DB.Model(&user).Select("ModelID", "PluginConfig").Updates(&user) 133 |
err = config.SetCache(GetUserCacheKey(userID), user, UserCacheExpire) 134 | } 135 | return &user, err 136 | } 137 | 138 | func LoadUser(c *fiber.Ctx) (*User, error) { 139 | userID, err := GetUserID(c) 140 | if err != nil { 141 | return nil, err 142 | } 143 | return LoadUserByID(userID) 144 | } 145 | 146 | func GetUserIDFromWs(c *websocket.Conn) (int, error) { 147 | // the token comes from the "jwt" query parameter, falling back to the "access" cookie 148 | token := c.Query("jwt") 149 | if token == "" { 150 | token = c.Cookies("access") 151 | if token == "" { 152 | return 0, utils.Unauthorized() 153 | } 154 | } 155 | // NOTE(review): parseJWT does not verify the signature; this presumably relies on an upstream gateway having validated the token — confirm 156 | data, err := parseJWT(token, false) 157 | if err != nil { 158 | return 0, err 159 | } 160 | id, ok := data["id"].(float64) // JSON numbers decode as float64; the checked assertion avoids a panic on a malformed claim 161 | if !ok { 162 | return 0, utils.Unauthorized() 163 | } 164 | return int(id), nil 165 | } 166 | 167 | func LoadUserFromWs(c *websocket.Conn) (*User, error) { 168 | userID, err := GetUserIDFromWs(c) 169 | if err != nil { 170 | return nil, err 171 | } 172 | return LoadUserByID(userID) 173 | } 174 | 175 | // parseJWT decodes the claims in a JWT's payload segment WITHOUT verifying its signature; when bearer is true a leading "Bearer " prefix is removed first 176 | func parseJWT(token string, bearer bool) (Map, error) { 177 | if len(token) < 7 { 178 | return nil, errors.New("bearer token required") 179 | } 180 | 181 | if bearer { // strip the "Bearer " prefix of a header value; tokens without it (presumably the refresh cookie) pass through unchanged 182 | token = strings.TrimPrefix(token, "Bearer ") 183 | } 184 | 185 | payloads := strings.SplitN(token, ".", 3) 186 | if len(payloads) < 3 { 187 | return nil, errors.New("jwt token required") 188 | } 189 | 190 | // JWT segments are unpadded base64url, hence RawURLEncoding 191 | payloadBytes, err := base64.RawURLEncoding.DecodeString(payloads[1]) // the middle one is payload 192 | if err != nil { 193 | return nil, err 194 | } 195 | 196 | var value Map 197 | err = json.Unmarshal(payloadBytes, &value) 198 | return value, err 199 | } 200 | 201 | func GetUserByRefreshToken(c *fiber.Ctx) (*User, error) { 202 | // get id 203 | userID, err := GetUserID(c) 204 | if err != nil { 205 | return nil, err 206 | } 207 | 208 |
tokenString := c.Get("Authorization") 209 | if tokenString == "" { // token can be in either header or cookie 210 | tokenString = c.Cookies("refresh") 211 | } 212 | 213 | payload, err := parseJWT(tokenString, true) 214 | if err != nil { 215 | return nil, err 216 | } 217 | 218 | if tokenType, ok := payload["type"]; !ok || tokenType != "refresh" { 219 | return nil, utils.Unauthorized("refresh token invalid") 220 | } 221 | 222 | var user User 223 | err = LoadUserByIDFromCache(userID, &user) 224 | return &user, err 225 | } 226 | 227 | func (user *User) UpdateIP(ip string) { 228 | user.LastLoginIP = ip 229 | if !slices.Contains(user.LoginIP, ip) { 230 | user.LoginIP = append(user.LoginIP, ip) 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /models/user_offense.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import "time" 4 | 5 | type UserOffense struct { 6 | ID int 7 | CreatedAt time.Time 8 | UserID int 9 | Type UserOffenseType 10 | } 11 | 12 | const OffenseMessage = `您因为多次违规,账号被锁定,如有意见请发送邮件至 txsun19@fudan.edu.cn` 13 | 14 | type UserOffenseType = int 15 | 16 | const ( 17 | UserOffensePrompt UserOffenseType = iota + 1 18 | UserOffenseMoss 19 | ) 20 | 21 | func (user *User) AddUserOffense(offenseType UserOffenseType) (bool, error) { 22 | var offense = UserOffense{ 23 | UserID: user.ID, 24 | Type: offenseType, 25 | } 26 | err := DB.Create(&offense).Error 27 | if err != nil { 28 | return false, err 29 | } 30 | return user.CheckUserOffense() 31 | } 32 | 33 | func (user *User) CheckUserOffense() (bool, error) { 34 | var ( 35 | count int64 36 | err error 37 | ) 38 | 39 | var configObject Config 40 | err = LoadConfig(&configObject) 41 | if err != nil { 42 | return false, err 43 | } 44 | if !configObject.OffenseCheck { 45 | return false, nil 46 | } 47 | if user.Banned { 48 | return true, nil 49 | } 50 | // ban after 3 or more UserOffensePrompt offenses within the last 5 minutes 51 | err = DB.Model(&UserOffense{}).
52 | Where("created_at between ? and ? and type = ? and user_id = ?", 53 | time.Now().Add(-5*time.Minute), 54 | time.Now(), 55 | UserOffensePrompt, 56 | user.ID). 57 | Count(&count).Error 58 | if err != nil { 59 | return false, err 60 | } 61 | if count >= 3 { 62 | user.Banned = true 63 | err = DB.Model(user).Select("Banned").Updates(user).Error 64 | if err != nil { 65 | return false, err 66 | } 67 | return true, err 68 | } 69 | // same check for UserOffenseMoss offenses, with a threshold of 10 70 | err = DB.Model(&UserOffense{}). 71 | Where("created_at between ? and ? and type = ? and user_id = ?", 72 | time.Now().Add(-5*time.Minute), 73 | time.Now(), 74 | UserOffenseMoss, 75 | user.ID). 76 | Count(&count).Error 77 | if err != nil { 78 | return false, err 79 | } 80 | if count >= 10 { 81 | user.Banned = true 82 | err = DB.Model(user).Select("Banned").Updates(user).Error 83 | if err != nil { 84 | return false, err 85 | } 86 | return true, err 87 | } 88 | return false, err 89 | } 90 | -------------------------------------------------------------------------------- /service/yocsef.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "MOSS_backend/config" 5 | "MOSS_backend/models" 6 | "MOSS_backend/utils" 7 | "bufio" 8 | "bytes" 9 | "context" 10 | "encoding/json" 11 | "errors" 12 | "io" 13 | "net/http" 14 | "strings" 15 | ) 16 | 17 | type InferYocsefRequest struct { 18 | Question string `json:"question,omitempty"` 19 | ChatHistory [][]string `json:"chat_history,omitempty"` 20 | } 21 | 22 | var yocsefHttpClient = &http.Client{} 23 | 24 | func InferYocsef( 25 | ctx context.Context, 26 | w utils.JSONWriter, 27 | prompt string, 28 | records models.RecordModels, 29 | ) ( 30 | model *models.DirectRecord, 31 | err error, 32 | ) { 33 | if config.Config.YocsefInferenceUrl == "" { 34 | return nil, errors.New("yocsef 推理模型暂不可用") 35 | } 36 | 37 | var chatHistory = make([][]string, len(records)) 38 | for i, record := range records { 39 | chatHistory[i] =
[]string{record.Request, record.Response} 40 | } 41 | 42 | var request = map[string]any{ 43 | "input": map[string]any{ 44 | "question": prompt, 45 | "chat_history": chatHistory, 46 | }, 47 | } 48 | requestData, err := json.Marshal(request) 49 | if err != nil { 50 | return 51 | } 52 | 53 | // server-sent events; binding the request to ctx makes cancellation also abort the upstream stream 54 | req, err := http.NewRequestWithContext(ctx, "POST", config.Config.YocsefInferenceUrl, bytes.NewBuffer(requestData)) 55 | if err != nil { 56 | return 57 | } 58 | 59 | req.Header.Set("Content-Type", "application/json") 60 | req.Header.Set("Accept", "text/event-stream") 61 | req.Header.Set("Cache-Control", "no-cache") 62 | req.Header.Set("Connection", "keep-alive") 63 | 64 | res, err := yocsefHttpClient.Do(req) 65 | if err != nil { 66 | return 67 | } 68 | defer func(Body io.ReadCloser) { 69 | _ = Body.Close() 70 | }(res.Body) 71 | 72 | if res.StatusCode != http.StatusOK { 73 | return nil, errors.New("yocsef 推理模型暂不可用") 74 | } 75 | 76 | var reader = bufio.NewReader(res.Body) 77 | var resultBuilder strings.Builder 78 | var nowOutput string 79 | var detectedOutput string 80 | // each non-empty, non-"event" line (minus any "data:" prefix) is decoded as JSON and its "content" string appended to the output 81 | for { 82 | var line []byte 83 | line, err = reader.ReadBytes('\n') 84 | if errors.Is(err, io.EOF) { 85 | break 86 | } 87 | if err != nil { 88 | return nil, err 89 | } 90 | 91 | line = bytes.Trim(line, " \n\r") 92 | if strings.HasPrefix(string(line), "event") { 93 | continue 94 | } 95 | if strings.HasPrefix(string(line), "data") { 96 | line = bytes.TrimPrefix(line, []byte("data:")) 97 | } 98 | line = bytes.Trim(line, " \n\r") 99 | if len(line) == 0 { 100 | continue 101 | } 102 | 103 | if ctx.Err() != nil { 104 | return nil, ctx.Err() 105 | } 106 | 107 | var response map[string]any 108 | err = json.Unmarshal(line, &response) 109 | if err != nil { 110 | continue 111 | } 112 | 113 | var ok bool 114 | nowOutput, ok = response["content"].(string) 115 | if !ok { 116 | continue 117 | } 118 | resultBuilder.WriteString(nowOutput) 119 | nowOutput = resultBuilder.String() 120 | 121 | var endDelimiter =
"<|im_end|>" 122 | if strings.Contains(nowOutput, endDelimiter) { 123 | nowOutput = strings.Split(nowOutput, endDelimiter)[0] 124 | break 125 | } 126 | 127 | before, _, found := utils.CutLastAny(nowOutput, ",.?!\n,。?!") 128 | if !found || before == detectedOutput { 129 | continue 130 | } 131 | detectedOutput = before 132 | 133 | _ = w.WriteJSON(InferResponseModel{ 134 | Status: 1, 135 | Output: nowOutput, 136 | Stage: "MOSS", 137 | }) 138 | } 139 | 140 | if nowOutput == "" { 141 | return nil, errors.New("yocsef 推理模型暂不可用") 142 | } 143 | 144 | if ctx.Err() != nil { 145 | return nil, ctx.Err() 146 | } 147 | if nowOutput != detectedOutput { 148 | _ = w.WriteJSON(InferResponseModel{ 149 | Status: 1, 150 | Output: nowOutput, 151 | Stage: "MOSS", 152 | }) 153 | } 154 | 155 | _ = w.WriteJSON(InferResponseModel{ 156 | Status: 0, 157 | Output: nowOutput, 158 | Stage: "MOSS", 159 | }) 160 | 161 | var record = models.DirectRecord{Request: prompt, Response: nowOutput} 162 | return &record, nil 163 | } 164 | 165 | type InferResponseModel struct { 166 | Status int `json:"status"` // 1 for output, 0 for end, -1 for error, -2 for sensitive 167 | StatusCode int `json:"status_code,omitempty"` 168 | Output string `json:"output,omitempty"` 169 | Stage string `json:"stage,omitempty"` 170 | } 171 | -------------------------------------------------------------------------------- /utils/auth/identifier.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "bytes" 5 | "crypto/rand" 6 | "crypto/sha256" 7 | "encoding/base64" 8 | "fmt" 9 | "golang.org/x/crypto/pbkdf2" 10 | "hash" 11 | "math/big" 12 | "strconv" 13 | "strings" 14 | ) 15 | 16 | func passwordHash(bytePassword, salt []byte, iterations, KeyLen int, hash func() hash.Hash) string { 17 | return base64.StdEncoding.EncodeToString(pbkdf2.Key(bytePassword, salt, iterations, KeyLen, hash)) 18 | } 19 | 20 | func saltGenerator(stringLen int) ([]byte, error) { 21 | const
chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" 22 | charsLength := len(chars) 23 | var builder bytes.Buffer 24 | for i := 0; i < stringLen; i++ { 25 | choiceIndex, err := rand.Int(rand.Reader, big.NewInt(int64(charsLength))) 26 | if err != nil { 27 | return nil, err 28 | } 29 | err = builder.WriteByte(chars[choiceIndex.Int64()]) 30 | if err != nil { 31 | return nil, err 32 | } 33 | } 34 | return builder.Bytes(), nil 35 | } 36 | 37 | func MakePassword(rawPassword string) (string, error) { 38 | salt, err := saltGenerator(12) 39 | if err != nil { 40 | return "", err 41 | } 42 | algorithm := "sha256" 43 | iterations := 216000 44 | hashBase64 := passwordHash([]byte(rawPassword), salt, iterations, 32, sha256.New) 45 | 46 | return fmt.Sprintf("pbkdf2_%v$%v$%v$%v", algorithm, iterations, string(salt), hashBase64), nil 47 | } 48 | 49 | func CheckPassword(rawPassword, encryptPassword string) (bool, error) { 50 | splitEncryptedPassword := strings.Split(encryptPassword, "$") 51 | if len(splitEncryptedPassword) != 4 { 52 | return false, fmt.Errorf("parse encryptPassword error: expected 4 '$'-separated fields, got %d", len(splitEncryptedPassword)) // never echo the stored hash into an error that may reach clients 53 | } 54 | algorithm := splitEncryptedPassword[0] 55 | splitAlgorithm := strings.Split(algorithm, "_") 56 | if len(splitAlgorithm) != 2 { 57 | return false, fmt.Errorf("parse encryptPassword algorithm error: %q", algorithm) 58 | } 59 | 60 | var hashOutputSize int 61 | var hashFactory func() hash.Hash 62 | if splitAlgorithm[1] == "sha256" { 63 | hashOutputSize = 32 64 | hashFactory = sha256.New 65 | } else { 66 | return false, fmt.Errorf("invalid sum algorithm: %v", splitAlgorithm[1]) 67 | } 68 | 69 | iterations, err := strconv.Atoi(splitEncryptedPassword[1]) 70 | if err != nil { 71 | return false, err 72 | } 73 | 74 | salt := splitEncryptedPassword[2] 75 | 76 | hashBase64 := passwordHash([]byte(rawPassword), []byte(salt), iterations, hashOutputSize, hashFactory) 77 | 78 | return hashBase64 == splitEncryptedPassword[3], nil // NOTE(review): consider crypto/subtle.ConstantTimeCompare here 79 | } 80 |
-------------------------------------------------------------------------------- /utils/auth/verification.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "MOSS_backend/config" 5 | "context" 6 | "crypto/rand" 7 | "fmt" 8 | "github.com/eko/gocache/lib/v4/cache" 9 | "github.com/eko/gocache/lib/v4/store" 10 | gocacheStore "github.com/eko/gocache/store/go_cache/v4" 11 | redisStore "github.com/eko/gocache/store/redis/v4" 12 | gocache "github.com/patrickmn/go-cache" 13 | "github.com/redis/go-redis/v9" 14 | "math/big" 15 | "time" 16 | ) 17 | 18 | var verificationCodeCache *cache.Cache[string] 19 | func InitCache() { 20 | if config.Config.RedisUrl != "" { 21 | verificationCodeCache = cache.New[string]( 22 | redisStore.NewRedis( 23 | redis.NewClient( 24 | &redis.Options{ 25 | Addr: config.Config.RedisUrl, 26 | }, 27 | ), 28 | ), 29 | ) 30 | fmt.Println("using redis") 31 | } else { 32 | verificationCodeCache = cache.New[string]( 33 | gocacheStore.NewGoCache( 34 | gocache.New( 35 | time.Duration(config.Config.VerificationCodeExpires)*time.Minute, 36 | 20*time.Minute), 37 | ), 38 | ) 39 | fmt.Println("using gocache") 40 | } 41 | } 42 | 43 | // SetVerificationCode stores a fresh random 6-digit code in the cache under key "{scope}-{info}" and returns it 44 | func SetVerificationCode(info, scope string) (string, error) { 45 | codeInt, err := rand.Int(rand.Reader, big.NewInt(1000000)) 46 | if err != nil { 47 | return "", err 48 | } 49 | code := fmt.Sprintf("%06d", codeInt.Uint64()) 50 | // pass the TTL explicitly: the redis store is created without a default expiration, so codes would otherwise never expire there 51 | return code, verificationCodeCache.Set( 52 | context.Background(), 53 | fmt.Sprintf("%v-%v", scope, info), 54 | code, store.WithExpiration(time.Duration(config.Config.VerificationCodeExpires)*time.Minute), 55 | ) 56 | } 57 | 58 | // CheckVerificationCode reports whether code matches the code cached under "{scope}-{info}" 59 | func CheckVerificationCode(info, scope, code string) bool { 60 | storedCode, err := verificationCodeCache.Get( 61 | context.Background(), 62 | fmt.Sprintf("%v-%v", scope, info), 63 | ) 64 | return err == nil && storedCode == code 65 | } 66 | 67 | func DeleteVerificationCode(info, scope string) error { 68 | return verificationCodeCache.Delete(
69 | context.Background(), 70 | fmt.Sprintf("%v-%v", scope, info), 71 | ) 72 | } 73 | -------------------------------------------------------------------------------- /utils/errors.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "errors" 5 | "github.com/gofiber/fiber/v2" 6 | "gorm.io/gorm" 7 | ) 8 | 9 | type MessageResponse struct { 10 | Message string `json:"message"` 11 | Data any `json:"data,omitempty"` 12 | } 13 | 14 | type HttpError struct { 15 | Code int `json:"code,omitempty"` 16 | Message string `json:"message,omitempty"` 17 | MessageType MessageType `json:"message_type,omitempty"` 18 | Detail *ErrorDetail `json:"detail,omitempty"` 19 | } 20 | 21 | func (e *HttpError) Error() string { 22 | return e.Message 23 | } 24 | 25 | func (e *HttpError) WithMessageType(messageType MessageType) *HttpError { 26 | e.MessageType = messageType 27 | return e 28 | } 29 | 30 | type MessageType = string 31 | 32 | const ( 33 | MaxLength MessageType = "max_length" 34 | Sensitive MessageType = "sensitive" 35 | ) 36 | 37 | func NoStatus(message string) *HttpError { 38 | return &HttpError{ 39 | Code: 0, 40 | Message: message, 41 | } 42 | } 43 | 44 | func BadRequest(messages ...string) *HttpError { 45 | message := "Bad Request" 46 | if len(messages) > 0 { 47 | message = messages[0] 48 | } 49 | return &HttpError{ 50 | Code: 400, 51 | Message: message, 52 | } 53 | } 54 | 55 | func Unauthorized(messages ...string) *HttpError { 56 | message := "Invalid JWT Token" 57 | if len(messages) > 0 { 58 | message = messages[0] 59 | } 60 | return &HttpError{ 61 | Code: 401, 62 | Message: message, 63 | } 64 | } 65 | 66 | func Forbidden(messages ...string) *HttpError { 67 | message := "您没有权限进行此操作" 68 | if len(messages) > 0 { 69 | message = messages[0] 70 | } 71 | return &HttpError{ 72 | Code: 403, 73 | Message: message, 74 | } 75 | } 76 | 77 | func NotFound(messages ...string) *HttpError { 78 | message := "Not Found" 79 | if
len(messages) > 0 { 80 | message = messages[0] 81 | } 82 | return &HttpError{ 83 | Code: 404, 84 | Message: message, 85 | } 86 | } 87 | 88 | func InternalServerError(messages ...string) *HttpError { 89 | message := "Unknown Error" 90 | if len(messages) > 0 { 91 | message = messages[0] 92 | } 93 | return &HttpError{ 94 | Code: 500, 95 | Message: message, 96 | } 97 | } 98 | 99 | func MyErrorHandler(ctx *fiber.Ctx, err error) error { 100 | if err == nil { 101 | return nil 102 | } 103 | // anything not recognized below is answered with status 500 and the raw err.Error() text 104 | httpError := HttpError{ 105 | Code: 500, 106 | Message: err.Error(), 107 | } 108 | 109 | if errors.Is(err, gorm.ErrRecordNotFound) { 110 | httpError.Code = 404 111 | } else { 112 | switch e := err.(type) { 113 | case *HttpError: 114 | httpError = *e 115 | case *fiber.Error: 116 | httpError.Code = e.Code 117 | case *ErrorDetail: 118 | httpError.Code = 400 119 | httpError.Detail = e 120 | case fiber.MultiError: 121 | httpError.Code = 400 122 | httpError.Message = "" 123 | for _, err = range e { 124 | httpError.Message += err.Error() + "\n" 125 | } 126 | } 127 | } 128 | 129 | return ctx.Status(httpError.Code).JSON(&httpError) 130 | } 131 | 132 | type ErrCollection struct { 133 | ErrVerificationCodeInvalid error 134 | ErrNeedInviteCode error 135 | ErrInviteCodeInvalid error 136 | ErrRegistered error 137 | ErrEmailRegistered error 138 | ErrEmailNotRegistered error 139 | ErrEmailCannotModify error 140 | ErrEmailCannotReset error 141 | ErrPhoneRegistered error 142 | ErrPhoneNotRegistered error 143 | ErrPhoneCannotModify error 144 | ErrPhoneCannotReset error 145 | ErrPasswordIncorrect error 146 | ErrEmailInBlacklist error 147 | } 148 | 149 | var ErrCollectionCN = ErrCollection{ 150 | ErrVerificationCodeInvalid: BadRequest("验证码错误"), 151 | ErrNeedInviteCode: BadRequest("需要邀请码"), 152 | ErrInviteCodeInvalid: BadRequest("邀请码错误"), 153 | ErrRegistered: BadRequest("您已注册,如果忘记密码,请使用忘记密码功能找回"), 154 | ErrEmailRegistered: BadRequest("该邮箱已被注册"), 155 | ErrEmailNotRegistered: BadRequest("该邮箱未注册"), 156 |
ErrEmailCannotModify: BadRequest("未登录状态,禁止修改邮箱"), 157 | ErrEmailCannotReset: BadRequest("登录状态无法重置密码,请退出登录然后重试"), 158 | ErrPhoneRegistered: BadRequest("该手机号已被注册"), 159 | ErrPhoneNotRegistered: BadRequest("该手机号未注册"), 160 | ErrPhoneCannotModify: BadRequest("未登录状态,禁止修改手机号"), 161 | ErrPhoneCannotReset: BadRequest("登录状态无法重置密码,请退出登录然后重试"), 162 | ErrPasswordIncorrect: Unauthorized("密码错误"), 163 | ErrEmailInBlacklist: BadRequest("该邮箱已被禁用"), 164 | } 165 | 166 | var ErrCollectionGlobal = ErrCollection{ 167 | ErrVerificationCodeInvalid: BadRequest("invalid verification code"), 168 | ErrNeedInviteCode: BadRequest("invitation code needed"), 169 | ErrInviteCodeInvalid: BadRequest("invalid invitation code"), 170 | ErrRegistered: BadRequest("You have registered, if you forget your password, please use reset password function to retrieve"), 171 | ErrEmailRegistered: BadRequest("email address registered"), 172 | ErrEmailNotRegistered: BadRequest("email address not registered"), 173 | ErrEmailCannotModify: BadRequest("cannot modify email address when not login"), 174 | ErrEmailCannotReset: BadRequest("cannot reset password when login, please logout and retry"), 175 | ErrPhoneRegistered: BadRequest("phone number registered"), 176 | ErrPhoneNotRegistered: BadRequest("phone number not registered"), 177 | ErrPhoneCannotModify: BadRequest("cannot modify phone number when not login"), 178 | ErrPhoneCannotReset: BadRequest("cannot reset password when login, please logout and retry"), 179 | ErrPasswordIncorrect: Unauthorized("password incorrect"), 180 | ErrEmailInBlacklist: BadRequest("banned email domain"), 181 | } 182 | 183 | type MessageCollection struct { 184 | MessageLoginSuccess string 185 | MessageRegisterSuccess string 186 | MessageLogoutSuccess string 187 | MessageResetPasswordSuccess string 188 | MessageVerificationEmailSend string 189 | MessageVerificationPhoneSend string 190 | } 191 | 192 | var MessageCollectionCN = MessageCollection{ 193 | MessageLoginSuccess: "登录成功", 194 |
MessageRegisterSuccess: "注册成功", 195 | MessageLogoutSuccess: "登出成功", 196 | MessageResetPasswordSuccess: "重置密码成功", 197 | MessageVerificationEmailSend: "验证邮件已发送,请查收\n如未收到,请检查邮件地址是否正确,检查垃圾箱,或重试", 198 | MessageVerificationPhoneSend: "验证短信已发送,请查收\n如未收到,请检查手机号是否正确,检查垃圾箱,或重试", 199 | } 200 | 201 | var MessageCollectionGlobal = MessageCollection{ 202 | MessageLoginSuccess: "Login successful", 203 | MessageRegisterSuccess: "register successful", 204 | MessageLogoutSuccess: "logout successful", 205 | MessageResetPasswordSuccess: "reset password successful", 206 | MessageVerificationEmailSend: "The verification email has been sent, please check\nIf not, please check if the email address is correct, check the spam box, or try again", 207 | MessageVerificationPhoneSend: "The verification message has been sent, please check\nIf not, please check if the phone number is correct, check the spam box, or try again", 208 | } 209 | 210 | func GetInfoByIP(ip string) (*ErrCollection, *MessageCollection) { 211 | if ok, _ := IsInChina(ip); ok { 212 | return &ErrCollectionCN, &MessageCollectionCN 213 | } else { 214 | return &ErrCollectionGlobal, &MessageCollectionGlobal 215 | } 216 | } 217 | -------------------------------------------------------------------------------- /utils/kong/jwt.go: -------------------------------------------------------------------------------- 1 | package kong 2 | 3 | import ( 4 | "MOSS_backend/config" 5 | "MOSS_backend/models" 6 | "github.com/golang-jwt/jwt/v4" 7 | "time" 8 | ) 9 | 10 | func CreateToken(user *models.User) (accessToken, refreshToken string, err error) { 11 | 12 | jwtCredential, err := GetJwtCredential(user.ID) 13 | if err != nil { 14 | return "", "", err 15 | } 16 | claim := jwt.MapClaims{ 17 | "uid": user.ID, 18 | "iss": jwtCredential.Key, 19 | "iat": time.Now().Unix(), 20 | "id": user.ID, 21 | "nickname": user.Nickname, 22 | "joined_time": user.JoinedTime.Format(time.RFC3339), 23 | } 24 | 25 | // access payload 26 | claim["type"] = "access" 27 |
claim["exp"] = time.Now().Add(time.Duration(config.Config.AccessExpireTime) * time.Minute).Unix() // AccessExpireTime minutes 28 | accessToken, err = jwt.NewWithClaims(jwt.SigningMethodHS256, claim).SignedString([]byte(jwtCredential.Secret)) 29 | if err != nil { 30 | return "", "", err 31 | } 32 | 33 | // refresh payload 34 | claim["type"] = "refresh" 35 | claim["exp"] = time.Now().Add(time.Duration(config.Config.RefreshExpireTime) * 24 * time.Hour).Unix() // RefreshExpireTime days 36 | refreshToken, err = jwt.NewWithClaims(jwt.SigningMethodHS256, claim).SignedString([]byte(jwtCredential.Secret)) 37 | if err != nil { 38 | return "", "", err 39 | } 40 | 41 | return 42 | } 43 | -------------------------------------------------------------------------------- /utils/kong/kong.go: -------------------------------------------------------------------------------- 1 | package kong 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "strconv" 10 | 11 | "github.com/gofiber/fiber/v2" 12 | "github.com/pkg/errors" 13 | 14 | "MOSS_backend/config" 15 | ) 16 | 17 | type JwtCredential struct { 18 | ID string `json:"id"` 19 | Secret string `json:"secret"` 20 | Key string `json:"key"` 21 | Algorithm string `json:"algorithm"` 22 | } 23 | 24 | type JwtCredentials struct { 25 | Next string `json:"next"` 26 | Data []*JwtCredential `json:"data"` 27 | } 28 | 29 | var kongClient = &http.Client{} 30 | 31 | func kongRequestDo(Method, URI string, body io.Reader, contentType string) (int, []byte, error) { 32 | req, err := http.NewRequest( 33 | Method, 34 | fmt.Sprintf("%v%v", config.Config.KongUrl, URI), 35 | body, 36 | ) 37 | if err != nil { 38 | return 500, nil, err 39 | } 40 | if contentType != "" { 41 | req.Header.Set("Content-Type", contentType) 42 | } 43 | rsp, err := kongClient.Do(req) 44 | if err != nil { 45 | return 500, nil, err 46 | } 47 | defer func() { // registered only after the error check: rsp is nil when Do fails 48 | _ = rsp.Body.Close() 49 | }() 50 | data, err := io.ReadAll(rsp.Body) 51 | return rsp.StatusCode, data, err 52 | } 53 | 54
| func Ping() error { 55 | rsp, err := kongClient.Get(config.Config.KongUrl) 56 | if err != nil { 57 | return err 58 | } 59 | defer func() { _ = rsp.Body.Close() }() 60 | 61 | if rsp.StatusCode != 200 { 62 | return fmt.Errorf("error connect to kong[%s]: status code %d", config.Config.KongUrl, rsp.StatusCode) 63 | } 64 | fmt.Println("ping kong success") 65 | return nil 66 | } 67 | 68 | func CreateUser(userID int) error { 69 | reqBodyObject := map[string]any{ 70 | "username": strconv.Itoa(userID), 71 | } 72 | reqData, err := json.Marshal(reqBodyObject) 73 | if err != nil { 74 | return err 75 | } 76 | statusCode, body, err := kongRequestDo( 77 | http.MethodPut, 78 | fmt.Sprintf("/consumers/%d", userID), 79 | bytes.NewReader(reqData), 80 | fiber.MIMEApplicationJSON, 81 | ) 82 | if err != nil { 83 | return err 84 | } 85 | if !(statusCode == 200 || statusCode == 201) { 86 | return fmt.Errorf("create user %v in kong error: %v", userID, string(body)) 87 | } 88 | return nil 89 | } 90 | 91 | func CreateJwtCredential(userID int) (*JwtCredential, error) { 92 | statusCode, body, err := kongRequestDo( 93 | http.MethodPost, fmt.Sprintf("/consumers/%d/jwt", userID), nil, fiber.MIMEApplicationForm, 94 | ) 95 | if err != nil { 96 | return nil, err 97 | } 98 | if statusCode == 404 { 99 | // the consumer does not exist yet: create it first, and give up if that fails instead of recursing without bound 100 | if err = CreateUser(userID); err != nil { 101 | return nil, err 102 | } 103 | return CreateJwtCredential(userID) 104 | } else if statusCode != 201 { 105 | return nil, fmt.Errorf("create user %v jwt credential error: %v", userID, string(body)) 106 | } 107 | // decode into the struct itself: decoding into &jwtCredential would turn a JSON null into a nil result 108 | jwtCredential := new(JwtCredential) 109 | if err = json.Unmarshal(body, jwtCredential); err != nil { 110 | return nil, err 111 | } 112 | return jwtCredential, nil 113 | } 114 | 115 | func ListJwtCredentials(userID int) ([]*JwtCredential, error) { 116 | statusCode, body, err := kongRequestDo( 117 | http.MethodGet, 118 | fmt.Sprintf("/consumers/%d/jwt", userID), 119 | nil, 120 | "", 121 | ) 122 | if err != nil { 123 | return nil, err 124 | } 125 | if statusCode != 200 { 126 | 127 | return nil, fmt.Errorf("list
credential error: %v", string(body)) 128 | } 129 | 130 | var jwtCredentials JwtCredentials 131 | err = json.Unmarshal(body, &jwtCredentials) 132 | if err != nil { 133 | return nil, err 134 | } 135 | 136 | return jwtCredentials.Data, nil 137 | } 138 | 139 | func GetJwtCredential(userID int) (*JwtCredential, error) { 140 | jwtCredentials, _ := ListJwtCredentials(userID) 141 | //if err != nil { 142 | // return nil, err 143 | //} 144 | if len(jwtCredentials) == 0 { 145 | return CreateJwtCredential(userID) 146 | } else { 147 | return jwtCredentials[0], nil 148 | } 149 | } 150 | 151 | func DeleteJwtCredential(userID int) error { 152 | deleteAJwtCredential := func(jwtID string) error { 153 | statusCode, _, err := kongRequestDo( 154 | http.MethodDelete, 155 | fmt.Sprintf("/consumers/%d/jwt/%v", userID, jwtID), 156 | nil, 157 | "", 158 | ) 159 | if err != nil { 160 | return err 161 | } 162 | if statusCode != 204 { 163 | return fmt.Errorf("delete user %v jwt credential %v error", userID, jwtID) 164 | } 165 | return nil 166 | } 167 | 168 | var err error 169 | jwtCredentials, err := ListJwtCredentials(userID) 170 | if err != nil { 171 | return err 172 | } 173 | for i := range jwtCredentials { 174 | innerErr := deleteAJwtCredential(jwtCredentials[i].ID) 175 | if innerErr != nil { 176 | if err == nil { 177 | err = innerErr 178 | } else { 179 | err = errors.Wrap(innerErr, err.Error()) 180 | } 181 | } 182 | } 183 | return err 184 | } 185 | -------------------------------------------------------------------------------- /utils/logger.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "go.uber.org/zap" 5 | "go.uber.org/zap/zapcore" 6 | ) 7 | 8 | var Logger, _ = zap.Config{ 9 | Level: zap.NewAtomicLevelAt(zapcore.InfoLevel), 10 | Development: false, 11 | Encoding: "json", 12 | EncoderConfig: zapcore.EncoderConfig{ 13 | TimeKey: "time", 14 | LevelKey: "level", 15 | NameKey: "logger", 16 | MessageKey: "msg", 17 |
EncodeLevel: zapcore.LowercaseLevelEncoder, 18 | EncodeTime: zapcore.RFC3339TimeEncoder, 19 | EncodeDuration: zapcore.SecondsDurationEncoder, 20 | EncodeName: zapcore.FullNameEncoder, 21 | }, 22 | OutputPaths: []string{"stdout"}, 23 | ErrorOutputPaths: []string{"stderr"}, 24 | }.Build() 25 | -------------------------------------------------------------------------------- /utils/region.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "MOSS_backend/data" 5 | "github.com/lionsoul2014/ip2region/binding/golang/xdb" 6 | "strings" 7 | ) 8 | 9 | var searcher *xdb.Searcher 10 | 11 | func init() { 12 | var err error 13 | searcher, err = xdb.NewWithBuffer(data.Ip2RegionDBFile) 14 | if err != nil { 15 | panic(err) 16 | } 17 | } 18 | 19 | func IsInChina(ip string) (bool, error) { 20 | region, err := searcher.SearchByStr(ip) 21 | if err != nil { 22 | return false, err 23 | } 24 | regionTable := strings.Split(region, "|") 25 | return regionTable[0] == "中国", nil 26 | } 27 | -------------------------------------------------------------------------------- /utils/region_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestRegion(t *testing.T) { 9 | const ip = "4.4.4.4" 10 | region, err := searcher.SearchByStr(ip) 11 | if err != nil { 12 | t.Fatal(err) 13 | } 14 | fmt.Println(region) 15 | 16 | fmt.Println(IsInChina("4.4.4.4")) 17 | } 18 | -------------------------------------------------------------------------------- /utils/sender.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "MOSS_backend/config" 5 | "fmt" 6 | unisms "github.com/apistd/uni-go-sdk/sms" 7 | "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" 8 | "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" 9 |
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/regions" 10 | ses "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ses/v20201002" 11 | "go.uber.org/zap" 12 | ) 13 | // SendCodeEmail emails code to receiver through Tencent Cloud SES (HongKong region) using the configured template, and logs the SES response at info level. 14 | func SendCodeEmail(code, receiver string) error { 15 | credential := common.NewCredential( 16 | config.Config.TencentSecretID, 17 | config.Config.TencentSecretKey, 18 | ) 19 | // Build a client profile (optional; can be skipped when no special settings are needed) 20 | cpf := profile.NewClientProfile() 21 | cpf.HttpProfile.Endpoint = "ses.tencentcloudapi.com" 22 | // Create the SES client; the client profile argument is optional 23 | client, err := ses.NewClient(credential, regions.HongKong, cpf) 24 | if err != nil { 25 | return err 26 | } 27 | 28 | // Create the request object (each API has its own request type) 29 | request := ses.NewSendEmailRequest() 30 | 31 | request.FromEmailAddress = common.StringPtr(config.Config.EmailUrl) 32 | request.Destination = common.StringPtrs([]string{receiver}) 33 | request.Template = &ses.Template{ 34 | TemplateID: common.Uint64Ptr(config.Config.TencentTemplateID), 35 | TemplateData: common.StringPtr(fmt.Sprintf("{\"code\": \"%s\"}", code)), 36 | } 37 | request.Subject = common.StringPtr("[MOSS] Verification Code") 38 | request.TriggerType = common.Uint64Ptr(1) 39 | 40 | // resp is the SendEmailResponse corresponding to the request 41 | resp, err := client.SendEmail(request) 42 | if err != nil { 43 | return err 44 | } 45 | Logger.Info("SendEmailResponse", zap.String("Response", resp.ToJsonString())) 46 | return err 47 | } 48 | // SendCodeMessage sends code to phone as a templated SMS through the uni-go-sdk SMS client; the send result is discarded and only the error is returned. 49 | func SendCodeMessage(code, phone string) error { 50 | // Initialize the client 51 | client := unisms.NewClient(config.Config.UniAccessID) // in simple signature-verification mode only the first argument is needed 52 | 53 | // Build the message 54 | message := unisms.BuildMessage() 55 | message.SetTo(phone) 56 | message.SetSignature(config.Config.UniSignature) 57 | message.SetTemplateId(config.Config.UniTemplateID) 58 | message.SetTemplateData(map[string]string{"code": code}) // set the template variables (templated SMS) 59 | 60 | // Send the SMS 61 | _, err := client.Send(message) 62 | return err 63 | } 64 |
-------------------------------------------------------------------------------- /utils/sensitive/diting/main.go: -------------------------------------------------------------------------------- 1 | package diting 2 | 3 | import ( 4 | "MOSS_backend/config" 5 | "bytes" 6 | "encoding/json" 7 | "github.com/google/uuid" 8 | "io" 9 | "log" 10 | "net/http" 11 | ) 12 | 13 | var sensitiveClient http.Client 14 | 15 | const sensitiveCheckUrl = `https://gtf.ai.xingzheai.cn/v2.0/game_chat_ban/detect_text` 16 | 17 | type SensitiveRequest struct { 18 | DataID string `json:"data_id"` 19 | Context string `json:"context"` 20 | ContextType string `json:"context_type"` 21 | Token string `json:"token"` 22 | } 23 | 24 | type SensitiveResponse struct { 25 | Code int `json:"code"` 26 | Msg string `json:"msg"` 27 | DataID string `json:"data_id"` 28 | Data struct { 29 | Suggestion string `json:"suggestion"` 30 | Label string `json:"label"` 31 | } `json:"data,omitempty"` 32 | } 33 | 34 | func IsSensitive(context string) bool { 35 | data, err := json.Marshal(SensitiveRequest{ 36 | DataID: uuid.NewString(), 37 | Context: context, 38 | ContextType: "chat", 39 | Token: config.Config.DiTingToken, 40 | }) 41 | if err != nil { 42 | log.Println("marshal data err") 43 | return false 44 | } 45 | rsp, err := sensitiveClient.Post( 46 | sensitiveCheckUrl, 47 | "application/json", 48 | bytes.NewBuffer(data), 49 | ) 50 | if err != nil { 51 | log.Println("sending detect request error") 52 | return false 53 | } 54 | defer func() { 55 | _ = rsp.Body.Close() 56 | }() 57 | 58 | if rsp.StatusCode != 200 { 59 | log.Printf("detect request status code: %d\n", rsp.StatusCode) 60 | return false 61 | } 62 | 63 | var response SensitiveResponse 64 | responseData, err := io.ReadAll(rsp.Body) 65 | if err != nil { 66 | log.Println("response read error") 67 | return false 68 | } 69 | err = json.Unmarshal(responseData, &response) 70 | if err != nil { 71 | log.Println("response decode error") 72 | return false 73 | } 74 | 75 
| if response.Code == -1 { 76 | log.Println("detect error") 77 | if response.Msg == "recharge" { 78 | log.Println("recharge sensitive detect platform") 79 | } 80 | } else { 81 | if response.Data.Suggestion == "pass" { 82 | return false 83 | } else { 84 | return true 85 | } 86 | } 87 | return false 88 | } 89 | -------------------------------------------------------------------------------- /utils/sensitive/sensitive.go: -------------------------------------------------------------------------------- 1 | package sensitive 2 | 3 | import ( 4 | "MOSS_backend/config" 5 | "MOSS_backend/models" 6 | "MOSS_backend/utils/sensitive/diting" 7 | "MOSS_backend/utils/sensitive/shumei" 8 | ) 9 | 10 | func IsSensitive(content string, user *models.User) bool { 11 | if content == "" { 12 | return false 13 | } 14 | if !config.Config.EnableSensitiveCheck { 15 | return false 16 | } 17 | if user.IsAdmin && user.DisableSensitiveCheck { 18 | return false 19 | } 20 | if config.Config.SensitiveCheckPlatform == "ShuMei" { 21 | return shumei.IsSensitive(content) 22 | } else { 23 | return diting.IsSensitive(content) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /utils/sensitive/shumei/main.go: -------------------------------------------------------------------------------- 1 | package shumei 2 | 3 | import ( 4 | "MOSS_backend/config" 5 | "MOSS_backend/utils" 6 | "bytes" 7 | "encoding/json" 8 | "github.com/google/uuid" 9 | "go.uber.org/zap" 10 | "io" 11 | "net/http" 12 | "time" 13 | ) 14 | 15 | const url = `http://api-text-bj.fengkongcloud.com/text/v4` 16 | 17 | var client = http.Client{Timeout: 1 * time.Second} 18 | 19 | type Request struct { 20 | AccessKey string `json:"accessKey"` 21 | AppId string `json:"appId"` 22 | EventId string `json:"eventId"` 23 | Type string `json:"type"` 24 | Data RequestData `json:"data"` 25 | } 26 | 27 | type RequestData struct { 28 | Text string `json:"text"` 29 | TokenId string `json:"tokenId"` 30 | } 31 | 32 
| type Response struct { 33 | Code int `json:"code"` 34 | Message string `json:"message"` 35 | RequestId string `json:"requestId"` 36 | RiskLevel string `json:"riskLevel"` 37 | } 38 | 39 | func IsSensitive(content string) bool { 40 | data, _ := json.Marshal(Request{ 41 | AccessKey: config.Config.ShuMeiAccessKey, 42 | AppId: config.Config.ShuMeiAppID, 43 | EventId: config.Config.ShuMeiEventID, 44 | Type: config.Config.ShuMeiType, 45 | Data: RequestData{ 46 | Text: content, 47 | TokenId: uuid.NewString(), 48 | }, 49 | }) 50 | 51 | // timer 52 | startTime := time.Now() 53 | defer func() { 54 | utils.Logger.Info( 55 | "shumei check", 56 | zap.Int64("duration", time.Since(startTime).Milliseconds()), 57 | ) 58 | }() 59 | 60 | rsp, err := client.Post(url, "application/json", bytes.NewBuffer(data)) 61 | if err != nil { 62 | utils.Logger.Error("shu mei: post error", 63 | zap.Error(err), 64 | ) 65 | return false 66 | } 67 | 68 | defer func() { 69 | _ = rsp.Body.Close() 70 | }() 71 | 72 | data, err = io.ReadAll(rsp.Body) 73 | if err != nil { 74 | utils.Logger.Error("shu mei: read body error", 75 | zap.Error(err), 76 | ) 77 | return false 78 | } 79 | 80 | if rsp.StatusCode != 200 { 81 | utils.Logger.Error("shu mei: platform error", 82 | zap.Int("status code", rsp.StatusCode), 83 | ) 84 | return false 85 | } 86 | 87 | var response Response 88 | err = json.Unmarshal(data, &response) 89 | if err != nil { 90 | utils.Logger.Error("shu mei: response decode error", 91 | zap.String("response", string(data)), 92 | zap.Error(err), 93 | ) 94 | return false 95 | } 96 | 97 | if response.Code != 1100 { 98 | utils.Logger.Warn("shu mei: check error", 99 | zap.String("message", response.Message), 100 | ) 101 | return false 102 | } else { 103 | if response.RiskLevel == "PASS" { 104 | return false 105 | } else { 106 | return true 107 | } 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /utils/tools/calculate.go: 
-------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "MOSS_backend/config" 5 | "MOSS_backend/utils" 6 | "bytes" 7 | "encoding/json" 8 | "fmt" 9 | "io" 10 | "net/http" 11 | "strconv" 12 | "time" 13 | 14 | "go.uber.org/zap" 15 | ) 16 | 17 | type keyNotExistError struct { 18 | Results Map 19 | } 20 | 21 | func (e keyNotExistError) Error() string { 22 | return fmt.Sprintf("`result` in results does not exist. results: %v", e.Results) 23 | } 24 | 25 | type resultNotStringError struct { 26 | Results Map 27 | } 28 | 29 | func (e resultNotStringError) Error() string { 30 | return fmt.Sprintf("`result` in results is not a string type. results: %v", e.Results) 31 | } 32 | 33 | type calculateTask struct { 34 | taskModel 35 | results Map 36 | resultString string 37 | } 38 | 39 | var _ task = (*calculateTask)(nil) 40 | 41 | var calculateHttpClient = http.Client{Timeout: 20 * time.Second} 42 | 43 | func (t *calculateTask) postprocess() *ResultModel { 44 | if t.err != nil { 45 | return NoneResultModel 46 | } 47 | return &ResultModel{ 48 | Result: t.resultString, 49 | ExtraData: &ExtraDataModel{ 50 | Type: "calculate", 51 | Request: t.args, 52 | Data: t.results, 53 | }, 54 | ProcessedExtraData: &ExtraDataModel{ 55 | Type: t.action, 56 | Request: t.args, 57 | Data: t.resultString, 58 | }, 59 | } 60 | } 61 | 62 | func (t *calculateTask) request() { 63 | data, _ := json.Marshal(map[string]any{"text": t.args}) 64 | res, err := calculateHttpClient.Post(config.Config.ToolsCalculateUrl, "application/json", bytes.NewBuffer(data)) 65 | if err != nil { 66 | utils.Logger.Error("post calculate(tools) error: ", zap.Error(err)) 67 | t.err = ErrGeneric 68 | return 69 | } 70 | 71 | if res.StatusCode != 200 { 72 | utils.Logger.Error("post calculate(tools) status code error: " + strconv.Itoa(res.StatusCode)) 73 | t.err = ErrGeneric 74 | return 75 | } 76 | 77 | responseData, err := io.ReadAll(res.Body) 78 | if err != nil { 79 | 
utils.Logger.Error("post calculate(tools) response read error: ", zap.Error(err)) 80 | t.err = ErrGeneric 81 | return 82 | } 83 | 84 | var results map[string]any 85 | err = json.Unmarshal(responseData, &results) 86 | if err != nil { 87 | utils.Logger.Error("post calculate(tools) response unmarshal error: ", zap.Error(err)) 88 | t.err = ErrGeneric 89 | return 90 | } 91 | calculateResult, exist := results["result"] 92 | if !exist { 93 | utils.Logger.Error("post calculate(tools) response format error: ", zap.Error(keyNotExistError{Results: results})) 94 | t.err = ErrGeneric 95 | return 96 | } 97 | resultsString, ok := calculateResult.(string) 98 | if !ok { 99 | utils.Logger.Error("post calculate(tools) response format error: ", zap.Error(resultNotStringError{Results: results})) 100 | t.err = ErrGeneric 101 | return 102 | } 103 | if _, err := strconv.ParseFloat(resultsString, 32); err != nil { 104 | utils.Logger.Error("post calculate(tools) response not number error: ", zap.Error(err)) 105 | t.err = ErrGeneric 106 | return 107 | } 108 | 109 | t.results = results 110 | t.resultString = resultsString 111 | } 112 | -------------------------------------------------------------------------------- /utils/tools/draw.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "MOSS_backend/config" 5 | "MOSS_backend/utils" 6 | "bytes" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "os" 11 | "strconv" 12 | "time" 13 | 14 | "github.com/google/uuid" 15 | "github.com/vmihailenco/msgpack/v5" 16 | "go.uber.org/zap" 17 | ) 18 | 19 | //func main() { 20 | // prompt := flag.String("p", "a photo of an astronaut riding a horse on Mars", "prompt") 21 | // host := flag.String("host", "0.0.0.0", "remote server host ip") 22 | // port := flag.Int("port", 443, "service port") 23 | // flag.Parse() 24 | // 25 | // client := &http.Client{} 26 | // reqBody := msgpack.MustMarshal(*prompt) 27 | // resp, err := client.Post( 28 | // 
fmt.Sprintf("http://%s:%d", *host, *port), 29 | // "application/x-msgpack", 30 | // bytes.NewBuffer(reqBody), 31 | // ) 32 | // if err != nil { 33 | // fmt.Printf("ERROR: %v\n", err) 34 | // return 35 | // } 36 | // defer resp.Body.Close() 37 | // 38 | // if resp.StatusCode == http.StatusOK { 39 | // data := make([]byte, resp.ContentLength) 40 | // if _, err := resp.Body.Read(data); err != nil { 41 | // fmt.Printf("ERROR: %v\n", err) 42 | // return 43 | // } 44 | // fmt.Println(base64.StdEncoding.EncodeToString(data)) 45 | // } else { 46 | // fmt.Printf("ERROR: <%d> %s\n", resp.StatusCode, resp.Status) 47 | // } 48 | //} 49 | 50 | type drawTask struct { 51 | taskModel 52 | results []byte 53 | url string 54 | } 55 | 56 | var _ task = (*drawTask)(nil) 57 | 58 | func (t *drawTask) request() { 59 | reqBody, err := msgpack.Marshal(t.args) 60 | if err != nil { 61 | utils.Logger.Error("post draw(tools) prompt cannot marshal error: ", zap.Error(err)) 62 | t.err = ErrGeneric 63 | return 64 | } 65 | res, err := drawHttpClient.Post(config.Config.ToolsDrawUrl, "application/x-msgpack", bytes.NewBuffer(reqBody)) 66 | if err != nil { 67 | utils.Logger.Error("post draw(tools) error: ", zap.Error(err)) 68 | t.err = ErrGeneric 69 | return 70 | } 71 | 72 | if res.StatusCode != 200 { 73 | utils.Logger.Error("post draw(tools) status code error: " + strconv.Itoa(res.StatusCode)) 74 | t.err = ErrGeneric 75 | return 76 | } 77 | data, err := io.ReadAll(res.Body) 78 | if err != nil { 79 | utils.Logger.Error("post draw(tools) response body data cannot read error: ", zap.Error(err)) 80 | t.err = ErrGeneric 81 | return 82 | } 83 | var resultsByte []byte 84 | if err = msgpack.Unmarshal(data, &resultsByte); err != nil { 85 | utils.Logger.Error("post draw(tools) response body data cannot Unmarshal error: ", zap.Error(err)) 86 | t.err = ErrGeneric 87 | return 88 | } 89 | 90 | // save 91 | t.results = resultsByte 92 | } 93 | 94 | func (t *drawTask) postprocess() *ResultModel { 95 | if t.err != nil 
{ 96 | return NoneResultModel 97 | } 98 | // save to file 99 | filename := uuid.NewString() + ".jpg" 100 | err := os.WriteFile(fmt.Sprintf("./draw/%s", filename), t.results, 0644) 101 | if err != nil { 102 | utils.Logger.Error("post draw(tools) response body data cannot save to file error: ", zap.Error(err)) 103 | return NoneResultModel 104 | } 105 | 106 | t.url = fmt.Sprintf("https://%s/api/draw/%s", config.Config.Hostname, filename) 107 | 108 | return &ResultModel{ 109 | Result: "a picture of the given prompt has been finished", 110 | ExtraData: &ExtraDataModel{ 111 | Type: "draw", 112 | Request: t.args, 113 | // sending resultsByte using json means `automatically encoding with BASE64` 114 | Data: t.results, 115 | }, 116 | ProcessedExtraData: &ExtraDataModel{ 117 | Type: t.action, 118 | Request: t.args, 119 | Data: t.url, 120 | }, 121 | } 122 | } 123 | 124 | var drawHttpClient = http.Client{Timeout: 20 * time.Second} 125 | -------------------------------------------------------------------------------- /utils/tools/main.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "MOSS_backend/config" 5 | "MOSS_backend/utils" 6 | "errors" 7 | "fmt" 8 | "regexp" 9 | "sort" 10 | "strings" 11 | "sync" 12 | 13 | "github.com/gofiber/websocket/v2" 14 | "go.uber.org/zap" 15 | ) 16 | 17 | type Map = map[string]any 18 | type CommandStatusModel struct { 19 | Status int `json:"status"` 20 | ID int `json:"id"` 21 | Args string `json:"output"` 22 | Type string `json:"type"` 23 | Stage string `json:"stage"` 24 | } 25 | 26 | const maxCommandNumber = 4 27 | 28 | var commandsFormatRegexp = regexp.MustCompile(`^\w+\("([\s\S]+?)"\)(, *?\w+\("([\s\S]+?)"\))*$`) 29 | var commandSplitRegexp = regexp.MustCompile(`(\w+)\("([\s\S]+?)"\)`) 30 | var commandOrder = map[string]int{"Search": 1, "Calculate": 2, "Solve": 3, "Text2Image": 4} 31 | var Command2Description = map[string]string{"Search": "Web search", "Calculate": 
"Calculator", "Solve": "Equation solver", "Text2Image": "Text-to-image"} 32 | var ErrInvalidCommandFormat = errors.New("commands format error") 33 | var ErrCommandIsNotNone = errors.New("command is not none") 34 | 35 | func Execute(c *websocket.Conn, rawCommand string, pluginConfig map[string]bool) (*ResultTotalModel, string, error) { 36 | if rawCommand == "None" || rawCommand == "none" { 37 | return NoneResultTotalModel, "None", ErrCommandIsNotNone 38 | } 39 | if !config.Config.EnableTools { 40 | return NoneResultTotalModel, "None", ErrCommandIsNotNone 41 | } 42 | if !commandsFormatRegexp.MatchString(rawCommand) { 43 | return NoneResultTotalModel, "None", ErrInvalidCommandFormat 44 | } 45 | // commands is like: [[Search("A"), Search, A,] [Solve("B"), Solve, B] [Search("C"), Search, C]] 46 | commands := commandSplitRegexp.FindAllStringSubmatch(rawCommand, -1) 47 | 48 | commands, newCommandString, err := filterCommand(commands, pluginConfig) 49 | if err != nil { 50 | return NoneResultTotalModel, "None", err 51 | } 52 | 53 | // sort, search should be at first 54 | sort.Slice(commands, func(i, j int) bool { 55 | return commandOrder[commands[i][1]] < commandOrder[commands[j][1]] 56 | }) 57 | // commands now like: [[Search("A"), Search, A,] [Search("C"), Search, C] [Solve("B"), Solve, B]] 58 | 59 | var s = &scheduler{ 60 | tasks: make([]task, 0, len(commands)), 61 | // the index of `the search results in <|results|>` starts with 1 62 | searchResultsIndex: 1, 63 | } 64 | 65 | var resultTotal = &ResultTotalModel{ 66 | ExtraData: make([]*ExtraDataModel, 0, len(commands)), 67 | ProcessedExtraData: make([]*ExtraDataModel, 0, len(commands)), 68 | } 69 | 70 | // generate tasks 71 | for i := range commands { 72 | if i >= maxCommandNumber { 73 | break 74 | } 75 | sendCommandStatus(c, i, commands[i][1], commands[i][2], "start") 76 | t := s.NewTask(commands[i][1], commands[i][2]) 77 | if t != nil { 78 | s.tasks = append(s.tasks, t) 79 | } 80 | } 81 | 82 | // request tools 
concurrently 83 | var wg sync.WaitGroup 84 | for _, t := range s.tasks { 85 | wg.Add(1) 86 | go func(t task) { 87 | defer wg.Done() 88 | t.request() 89 | }(t) 90 | } 91 | wg.Wait() 92 | 93 | // postprocess 94 | var resultsBuilder strings.Builder 95 | for i, t := range s.tasks { 96 | results := t.postprocess() 97 | 98 | resultsBuilder.WriteString(t.name()) 99 | if t.getAction() == "Calculate" { 100 | resultsBuilder.WriteString(" => ") 101 | } else { 102 | resultsBuilder.WriteString(" =>\n") 103 | } 104 | resultsBuilder.WriteString(results.Result) 105 | resultsBuilder.WriteString("\n") 106 | if results.ExtraData != nil { 107 | resultTotal.ExtraData = append(resultTotal.ExtraData, results.ExtraData) 108 | } 109 | if results.ProcessedExtraData != nil { 110 | resultTotal.ProcessedExtraData = append(resultTotal.ProcessedExtraData, results.ProcessedExtraData) 111 | } 112 | sendCommandStatus(c, i, commands[i][1], commands[i][2], "done") 113 | } 114 | 115 | if resultsBuilder.String() == "" { 116 | return NoneResultTotalModel, "None", nil 117 | } 118 | 119 | resultTotal.Result = resultsBuilder.String() 120 | return resultTotal, newCommandString, nil 121 | } 122 | 123 | func (s *scheduler) NewTask(action string, args string) task { 124 | if config.Config.Debug { 125 | fmt.Println(action + args) 126 | } 127 | t := taskModel{ 128 | s: s, 129 | action: action, 130 | args: args, 131 | err: nil, 132 | } 133 | switch action { 134 | case "Search": 135 | return &searchTask{taskModel: t} 136 | case "Calculate": 137 | return &calculateTask{taskModel: t} 138 | case "Solve": 139 | return &solveTask{taskModel: t} 140 | case "Text2Image": 141 | return &drawTask{taskModel: t} 142 | default: 143 | return nil 144 | } 145 | } 146 | 147 | // sendCommandStatus 148 | // a filter. 
only inform frontend well-formed commands 149 | func sendCommandStatus(c *websocket.Conn, id int, action, args, StatusString string) { 150 | if c == nil { 151 | //utils.Logger.Info("no ws connection") 152 | return 153 | } 154 | if err := c.WriteJSON(CommandStatusModel{ 155 | Status: 3, // 3 means `send command status` 156 | ID: id + 1, // id start with 1 157 | Type: action, 158 | Args: args, 159 | Stage: StatusString, // start or done 160 | }); err != nil { 161 | utils.Logger.Error("fail to send command status", zap.Error(err)) 162 | } 163 | } 164 | 165 | func filterCommand(commands [][]string, pluginConfig map[string]bool) ([][]string, string, error) { 166 | var newCommandBuilder strings.Builder 167 | var validCommands = make([][]string, 0, len(commands)) 168 | for i := range commands { 169 | if description, ok := Command2Description[commands[i][1]]; !ok { 170 | continue 171 | } else { 172 | if v, ok := pluginConfig[description]; !ok || !v { 173 | continue 174 | } 175 | } 176 | validCommands = append(validCommands, commands[i]) 177 | if i > 0 { 178 | newCommandBuilder.WriteString(", ") 179 | } 180 | newCommandBuilder.WriteString(commands[i][0]) 181 | } 182 | if len(validCommands) == 0 { 183 | return nil, "None", ErrCommandIsNotNone 184 | } 185 | return validCommands, newCommandBuilder.String(), nil 186 | } 187 | 188 | //func executeOnce(action string, args string, searchResultIndex *int) (string, map[string]any) { 189 | // if config.Config.Debug { 190 | // fmt.Println(action + args) 191 | // } 192 | // switch action { 193 | // case "Search": 194 | // results, extraData := search(args) 195 | // searchResult := searchResultsFormatter(results, searchResultIndex) 196 | // return searchResult, extraData 197 | // case "Calculate": 198 | // return calculate(args) 199 | // case "Solve": 200 | // return solve(args) 201 | // case "Draw": 202 | // return draw(args) 203 | // default: 204 | // return "None", nil 205 | // } 206 | //} 207 | 208 | //func cutCommand(command 
string) (string, string) { 209 | // before, after, found := strings.Cut(command, "(") 210 | // if found { 211 | // return before, strings.Trim(after, "\")") 212 | // } else { 213 | // return command, "" 214 | // } 215 | //} 216 | -------------------------------------------------------------------------------- /utils/tools/schema.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | type ResultModel struct { 9 | Result string 10 | ExtraData *ExtraDataModel 11 | ProcessedExtraData *ExtraDataModel 12 | } 13 | 14 | type ResultTotalModel struct { 15 | Result string `json:"-"` //`json:"result"` 16 | ExtraData []*ExtraDataModel `json:"-"` //`json:"extra_data"` 17 | ProcessedExtraData []*ExtraDataModel `json:"processed_extra_data"` 18 | } 19 | 20 | var NoneResultModel = &ResultModel{Result: "None"} 21 | 22 | var NoneResultTotalModel = &ResultTotalModel{Result: "None"} 23 | 24 | type ExtraDataModel struct { 25 | Type string `json:"type"` 26 | Request string `json:"request"` 27 | Data any `json:"data"` 28 | } 29 | 30 | type task interface { 31 | getAction() string 32 | name() string 33 | request() 34 | postprocess() *ResultModel 35 | } 36 | 37 | type taskModel struct { 38 | s *scheduler 39 | action string 40 | args string 41 | err error 42 | } 43 | 44 | func (t *taskModel) name() string { 45 | return fmt.Sprintf("%s(\"%s\")", t.action, t.args) 46 | } 47 | 48 | func (t *taskModel) getAction() string { 49 | return t.action 50 | } 51 | 52 | type scheduler struct { 53 | tasks []task 54 | searchResultsIndex int 55 | } 56 | 57 | var ErrGeneric = errors.New("default error") 58 | -------------------------------------------------------------------------------- /utils/tools/search.go: -------------------------------------------------------------------------------- 1 | package tools 2 | 3 | import ( 4 | "MOSS_backend/config" 5 | "MOSS_backend/utils" 6 | "bytes" 7 | "encoding/json" 8 | 
"fmt" 9 | "io" 10 | "net/http" 11 | "strconv" 12 | "strings" 13 | "time" 14 | 15 | "go.uber.org/zap" 16 | ) 17 | 18 | /* 19 | def clean(tmp_answer): 20 | tmp_answer = tmp_answer.replace('\n',' ') 21 | tmp_answer = tmp_answer.__repr__() 22 | return tmp_answer 23 | 24 | def convert(res): 25 | tmp_sample = [] 26 | id = 0 27 | 28 | try: 29 | line_dict = eval(res) 30 | line_dict = eval(line_dict) 31 | except: 32 | # tmp_sample.append('Error Responses.') 33 | pass 34 | if 'url' in line_dict: 35 | tmp_answer = 'No Results.' 36 | if 'snippet' in line_dict['summ']: 37 | tmp_answer = line_dict['summ']['snippet'].__repr__() 38 | # tmp_answer = clean(tmp_answer) 39 | elif 'title' in line_dict['summ']: 40 | tmp_answer = line_dict['summ']['title'] + ': ' + line_dict['summ']['answer'].__repr__() 41 | # tmp_answer = clean(tmp_answer) 42 | else: 43 | print ("decode error:)") 44 | exit(0) 45 | tmp_sample.append('<|{}|>: {}'.format(id, tmp_answer)) 46 | id += 1 47 | elif '0' in line_dict: 48 | item_num = 1 49 | for key in line_dict: 50 | if item_num <= 3: 51 | item_num += 1 52 | else: 53 | break 54 | tmp_answer = line_dict[key]['summ'] 55 | tmp_answer = clean(tmp_answer)[:400] 56 | tmp_sample.append('<|{}|>: {}'.format(id, tmp_answer)) 57 | id += 1 58 | return tmp_sample 59 | */ 60 | 61 | type searchTask struct { 62 | taskModel 63 | results map[string]any 64 | processedResults map[string]PrettySearch 65 | } 66 | 67 | var _ task = (*searchTask)(nil) 68 | 69 | type PrettySearch struct { 70 | Url string `json:"url"` 71 | Title string `json:"title"` 72 | } 73 | 74 | var searchHttpClient = http.Client{Timeout: 20 * time.Second} 75 | 76 | func clean(tmpAnswer string) string { 77 | tmpAnswer = strings.ReplaceAll(tmpAnswer, "\n", " ") 78 | tmpAnswer = strconv.Quote(tmpAnswer) 79 | return tmpAnswer 80 | } 81 | 82 | func (t *searchTask) postprocess() (r *ResultModel) { 83 | if t.results == nil || t.err != nil { 84 | return NoneResultModel 85 | } 86 | 87 | var ( 88 | dict = t.results 89 | 
tmpSample = make([]string, 0, 3) 90 | processedResult = make(map[int]PrettySearch) 91 | id = 0 // counter 92 | title, url string 93 | ) 94 | 95 | defer func() { 96 | if something := recover(); something != nil { 97 | utils.Logger.Error("search postprocess panic", zap.Any("something", something)) 98 | r = NoneResultModel 99 | } 100 | }() 101 | 102 | if u, exists := t.results["url"]; exists { 103 | url = u.(string) 104 | tmpAnswer := "No Results." 105 | if summ, ok := dict["summ"]; ok { 106 | // in summary, there are two types of response 107 | if snippet, exists := summ.(Map)["snippet"]; exists { 108 | tmpAnswer = strconv.Quote(fmt.Sprintf("%v", snippet)) 109 | if titleValue, exists := summ.(Map)["title"]; exists { 110 | title = titleValue.(string) 111 | } 112 | } else if titleValue, exists := summ.(Map)["title"]; exists { 113 | title = titleValue.(string) 114 | answer := summ.(Map)["answer"].(string) 115 | tmpAnswer = fmt.Sprintf("%v: %v", title, strconv.Quote(fmt.Sprintf("%v", answer))) 116 | } else { 117 | utils.Logger.Error("search response decode error") 118 | return NoneResultModel 119 | } 120 | } else { 121 | utils.Logger.Error("search response decode error") 122 | return NoneResultModel 123 | } 124 | 125 | tmpSample = append(tmpSample, fmt.Sprintf("<|%d|>: %s", t.s.searchResultsIndex, tmpAnswer)) 126 | 127 | // save to processedResult 128 | processedResult[t.s.searchResultsIndex] = PrettySearch{ 129 | Url: url, 130 | Title: title, 131 | } 132 | t.s.searchResultsIndex += 1 133 | } else if _, exists := dict["0"]; exists { 134 | for _, value := range dict { 135 | // get title, url and answer 136 | if titleValue, exists := value.(Map)["title"]; exists { 137 | title = titleValue.(string) 138 | } 139 | if urlValue, exists := value.(Map)["url"]; exists { 140 | url = urlValue.(string) 141 | } 142 | tmpAnswer := value.(Map)["summ"].(string) 143 | tmpAnswerRune := []rune(clean(tmpAnswer)) 144 | tmpAnswerRune = tmpAnswerRune[:min(len(tmpAnswerRune), 400)] 145 | 
tmpAnswer = string(tmpAnswerRune) 146 | tmpSample = append(tmpSample, fmt.Sprintf("<|%d|>: %s", t.s.searchResultsIndex, tmpAnswer)) 147 | 148 | // save to processedResult 149 | processedResult[t.s.searchResultsIndex] = PrettySearch{ 150 | Url: url, 151 | Title: title, 152 | } 153 | t.s.searchResultsIndex += 1 154 | 155 | // to next or break 156 | if id < 3 { // topk 157 | id++ 158 | } else { 159 | break 160 | } 161 | } 162 | } 163 | return &ResultModel{ 164 | Result: strings.Join(tmpSample, "\n"), 165 | ExtraData: &ExtraDataModel{ 166 | Type: "search", 167 | Request: t.args, 168 | Data: t.results, 169 | }, 170 | ProcessedExtraData: &ExtraDataModel{ 171 | Type: t.action, 172 | Request: t.args, 173 | Data: processedResult, 174 | }, 175 | } 176 | } 177 | 178 | func (t *searchTask) request() { 179 | data, _ := json.Marshal(map[string]any{"query": t.args, "topk": "3"}) 180 | res, err := searchHttpClient.Post(config.Config.ToolsSearchUrl, "application/json", bytes.NewBuffer(data)) 181 | if err != nil { 182 | utils.Logger.Error("post search error: ", zap.Error(err)) 183 | t.err = ErrGeneric 184 | return 185 | } 186 | 187 | if res.StatusCode != 200 { 188 | utils.Logger.Error("post search status code error: " + strconv.Itoa(res.StatusCode)) 189 | t.err = ErrGeneric 190 | return 191 | } 192 | 193 | responseData, err := io.ReadAll(res.Body) 194 | if err != nil { 195 | utils.Logger.Error("post search response read error: ", zap.Error(err)) 196 | t.err = ErrGeneric 197 | return 198 | } 199 | // result processing 200 | var results Map 201 | err = json.Unmarshal(responseData, &results) 202 | if err != nil { 203 | utils.Logger.Error("post search response unmarshal error: ", zap.Error(err)) 204 | t.err = ErrGeneric 205 | return 206 | } 207 | 208 | t.results = results 209 | } 210 | -------------------------------------------------------------------------------- /utils/tools/solve.go: -------------------------------------------------------------------------------- 1 | package tools 2 
| 3 | import ( 4 | "MOSS_backend/config" 5 | "MOSS_backend/utils" 6 | "bytes" 7 | "encoding/json" 8 | "io" 9 | "net/http" 10 | "strconv" 11 | "time" 12 | 13 | "go.uber.org/zap" 14 | ) 15 | 16 | type solveTask struct { 17 | taskModel 18 | results Map 19 | resultString string 20 | } 21 | 22 | var _ task = (*solveTask)(nil) 23 | 24 | func (t *solveTask) postprocess() *ResultModel { 25 | if t.err != nil { 26 | return NoneResultModel 27 | } 28 | return &ResultModel{ 29 | Result: t.resultString, 30 | ExtraData: &ExtraDataModel{ 31 | Type: "solve", 32 | Request: t.args, 33 | Data: t.results, 34 | }, 35 | ProcessedExtraData: &ExtraDataModel{ 36 | Type: t.action, 37 | Request: t.args, 38 | Data: t.resultString, 39 | }, 40 | } 41 | } 42 | 43 | var solveHttpClient = http.Client{Timeout: 20 * time.Second} 44 | 45 | func (t *solveTask) request() { 46 | data, _ := json.Marshal(map[string]any{"text": t.args}) 47 | res, err := solveHttpClient.Post(config.Config.ToolsSolveUrl, "application/json", bytes.NewBuffer(data)) 48 | if err != nil { 49 | utils.Logger.Error("post solve(tools) error: ", zap.Error(err)) 50 | t.err = ErrGeneric 51 | return 52 | } 53 | 54 | if res.StatusCode != 200 { 55 | utils.Logger.Error("post solve(tools) status code error: " + strconv.Itoa(res.StatusCode)) 56 | t.err = ErrGeneric 57 | return 58 | } 59 | 60 | responseData, err := io.ReadAll(res.Body) 61 | if err != nil { 62 | utils.Logger.Error("post solve(tools) response read error: ", zap.Error(err)) 63 | t.err = ErrGeneric 64 | return 65 | } 66 | 67 | var results map[string]any 68 | err = json.Unmarshal(responseData, &results) 69 | if err != nil { 70 | utils.Logger.Error("post solve(tools) response unmarshal error: ", zap.Error(err)) 71 | t.err = ErrGeneric 72 | return 73 | } 74 | 75 | solveResult, exist := results["result"] 76 | if !exist { 77 | utils.Logger.Error("post solve(tools) response format error: ", zap.Error(keyNotExistError{Results: results})) 78 | t.err = ErrGeneric 79 | return 80 | } 81 | 
resultsString, ok := solveResult.(string) 82 | if !ok { 83 | utils.Logger.Error("post solve(tools) response format error: ", zap.Error(resultNotStringError{Results: results})) 84 | t.err = ErrGeneric 85 | return 86 | } 87 | if resultsString == `[ERROR]` || resultsString == "" { 88 | utils.Logger.Warn("post solve(tools) request no solution") 89 | t.err = ErrGeneric 90 | return 91 | } 92 | 93 | t.results = results 94 | t.resultString = resultsString 95 | } 96 | 97 | //func solve(request string) (string, map[string]any) { 98 | // data, _ := json.Marshal(map[string]any{"text": request}) 99 | // res, err := solveHttpClient.Post(config.Config.ToolsSolveUrl, "application/json", bytes.NewBuffer(data)) 100 | // if err != nil { 101 | // utils.Logger.Error("post solve(tools) error: ", zap.Error(err)) 102 | // return "None", nil 103 | // } 104 | // 105 | // if res.StatusCode != 200 { 106 | // utils.Logger.Error("post solve(tools) status code error: " + strconv.Itoa(res.StatusCode)) 107 | // return "None", nil 108 | // } 109 | // 110 | // var results map[string]any 111 | // responseData, err := io.ReadAll(res.Body) 112 | // if err != nil { 113 | // utils.Logger.Error("post solve(tools) response read error: ", zap.Error(err)) 114 | // return "None", nil 115 | // } 116 | // err = json.Unmarshal(responseData, &results) 117 | // if err != nil { 118 | // utils.Logger.Error("post solve(tools) response unmarshal error: ", zap.Error(err)) 119 | // return "None", nil 120 | // } 121 | // solveResult, exist := results["result"] 122 | // if !exist { 123 | // utils.Logger.Error("post solve(tools) response format error: ", zap.Error(keyNotExistError{Results: results})) 124 | // return "None", nil 125 | // } 126 | // resultsString, ok := solveResult.(string) 127 | // if !ok { 128 | // utils.Logger.Error("post solve(tools) response format error: ", zap.Error(resultNotStringError{Results: results})) 129 | // return "None", nil 130 | // } 131 | // if resultsString == `[ERROR]` { 132 | // 
utils.Logger.Warn("post solve(tools) request no solution") 133 | // return "None", nil 134 | // } 135 | // return resultsString, map[string]any{"type": "solve", "data": results, "request": request} 136 | //} 137 | -------------------------------------------------------------------------------- /utils/utils.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "github.com/gofiber/fiber/v2" 5 | ) 6 | 7 | type CanPreprocess interface { 8 | Preprocess(c *fiber.Ctx) error 9 | } 10 | 11 | func Serialize(c *fiber.Ctx, obj CanPreprocess) error { 12 | err := obj.Preprocess(c) 13 | if err != nil { 14 | return err 15 | } 16 | return c.JSON(obj) 17 | } 18 | 19 | func GetRealIP(c *fiber.Ctx) string { 20 | IPs := c.IPs() 21 | if len(IPs) > 0 { 22 | return IPs[0] 23 | } else { 24 | return c.Get("X-Real-Ip", c.IP()) 25 | } 26 | } 27 | 28 | func StripContent(content string, length int) string { 29 | return string([]rune(content)[:min(len([]rune(content)), length)]) 30 | } 31 | 32 | func CutLastAny(s string, chars string) (before, after string, found bool) { 33 | sourceRunes := []rune(s) 34 | charRunes := []rune(chars) 35 | maxIndex := -1 36 | for _, char := range charRunes { 37 | index := -1 38 | for i, sourceRune := range sourceRunes { 39 | if char == sourceRune { 40 | index = i 41 | } 42 | } 43 | if index > 0 { 44 | maxIndex = max(maxIndex, index) 45 | } 46 | } 47 | if maxIndex == -1 { 48 | return s, "", false 49 | } else { 50 | return string(sourceRunes[:maxIndex+1]), string(sourceRunes[maxIndex+1:]), true 51 | } 52 | } 53 | 54 | type JSONReader interface { 55 | ReadJson(any) error 56 | } 57 | 58 | type JSONWriter interface { 59 | WriteJSON(any) error 60 | } 61 | 62 | type JsonReaderWriter interface { 63 | JSONReader 64 | JSONWriter 65 | } 66 | -------------------------------------------------------------------------------- /utils/validate.go: 
-------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "encoding/json" 5 | "github.com/creasty/defaults" 6 | "github.com/go-playground/validator/v10" 7 | "github.com/gofiber/fiber/v2" 8 | "reflect" 9 | "strings" 10 | ) 11 | 12 | type ErrorDetailElement struct { 13 | validator.FieldError 14 | Field string `json:"field"` 15 | Tag string `json:"tag"` 16 | Value string `json:"value"` 17 | } 18 | 19 | type ErrorDetail []*ErrorDetailElement 20 | 21 | func (e ErrorDetail) Error() string { 22 | var builder strings.Builder 23 | builder.WriteString("Validation Error: ") 24 | for _, err := range e { 25 | builder.WriteString("invalid " + err.Field) 26 | builder.WriteString("\n") 27 | } 28 | return builder.String() 29 | } 30 | 31 | var validate = validator.New() 32 | 33 | func init() { 34 | validate.RegisterTagNameFunc(func(fld reflect.StructField) string { 35 | name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0] 36 | 37 | if name == "-" { 38 | return "" 39 | } 40 | 41 | return name 42 | }) 43 | } 44 | 45 | func Validate(model any) error { 46 | errors := validate.Struct(model) 47 | if errors != nil { 48 | var errorDetail ErrorDetail 49 | for _, err := range errors.(validator.ValidationErrors) { 50 | detail := ErrorDetailElement{ 51 | FieldError: err, 52 | Field: err.Field(), 53 | Tag: err.Tag(), 54 | Value: err.Param(), 55 | } 56 | errorDetail = append(errorDetail, &detail) 57 | } 58 | return &errorDetail 59 | } 60 | return nil 61 | } 62 | 63 | func ValidateQuery(c *fiber.Ctx, model any) error { 64 | err := c.QueryParser(model) 65 | if err != nil { 66 | return err 67 | } 68 | err = defaults.Set(model) 69 | if err != nil { 70 | return err 71 | } 72 | return Validate(model) 73 | } 74 | 75 | // ValidateBody supports json only 76 | func ValidateBody(c *fiber.Ctx, model any) error { 77 | body := c.Body() 78 | if len(body) == 0 { 79 | body = []byte("{}") 80 | } 81 | err := json.Unmarshal(body, model) 82 | if err != 
nil { 83 | return err 84 | } 85 | err = defaults.Set(model) 86 | if err != nil { 87 | return err 88 | } 89 | return Validate(model) 90 | } 91 | --------------------------------------------------------------------------------