├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── dev.yaml │ └── main.yaml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── Dockerfile ├── LICENSE ├── README.md ├── apis ├── default.go ├── division │ ├── apis.go │ ├── routes.go │ ├── schemas.go │ └── utils.go ├── favourite │ ├── api.go │ ├── routes.go │ └── schemas.go ├── floor │ ├── apis.go │ ├── routes.go │ ├── schemas.go │ ├── search.go │ └── utils.go ├── hole │ ├── apis.go │ ├── purge_hole.go │ ├── routes.go │ ├── schemas.go │ └── update_views.go ├── message │ ├── apis.go │ ├── purge.go │ ├── routes.go │ └── schemas.go ├── penalty │ └── api.go ├── report │ ├── apis.go │ ├── routes.go │ └── schemas.go ├── routes.go ├── subscription │ ├── api.go │ ├── routes.go │ └── schemas.go ├── tag │ ├── apis.go │ ├── routes.go │ └── schemas.go └── user │ ├── apis.go │ └── schemas.go ├── benchmarks ├── floor_test.go ├── hole_test.go ├── init.go └── utils.go ├── bootstrap └── init.go ├── config └── config.go ├── data ├── data.go ├── meta.json └── names.json ├── docs ├── docs.go ├── swagger.json └── swagger.yaml ├── go.mod ├── go.sum ├── main.go ├── models ├── admin_log.go ├── anonyname.go ├── base.go ├── division.go ├── elastic.go ├── favorite_group.go ├── floor.go ├── floor_history.go ├── floor_like.go ├── floor_mention.go ├── hole.go ├── hole_tags.go ├── init.go ├── message.go ├── notification.go ├── punishment.go ├── report.go ├── report_punishment.go ├── tag.go ├── url_hostname_whitelist.go ├── user.go ├── user_favorite.go ├── user_subscription.go └── user_test.go ├── tests ├── default.go ├── default_test.go ├── division_test.go ├── favorite_test.go ├── floor_test.go ├── hole_test.go ├── init.go ├── report_test.go ├── tag_test.go └── utils.go └── utils ├── bot.go ├── cache.go ├── errors.go ├── log.go ├── model.go ├── name.go ├── sensitive ├── api.go ├── utils.go └── utils_test.go ├── utils.go └── utils_test.go /.github/ISSUE_TEMPLATE/bug_report.md: 
-------------------------------------------------------------------------------- 1 | --- 2 | name: Bug报告 3 | about: 说明你遇到的Bug。 4 | title: '[BUG]' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **描述 Bug** 11 | 简要描述 Bug 是什么。 12 | 如果你认为标题已经说得很清楚了,可以删除这一项。 13 | 14 | **复现步骤** 15 | 复现该 Bug 的步骤: -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 请求新功能 3 | about: 请求后端加入新的功能! 4 | title: '[Feature Request]' 5 | labels: feature request 6 | assignees: '' 7 | 8 | --- 9 | 10 | **你的功能需求和某个bug有关吗?** 11 | `[是/否]` 12 | 13 | **你想要什么样的功能?** 14 | 简要、清晰地描述你请求增加的功能。 15 | -------------------------------------------------------------------------------- /.github/workflows/dev.yaml: -------------------------------------------------------------------------------- 1 | name: Dev Build 2 | on: 3 | push: 4 | branches: [ dev ] 5 | 6 | env: 7 | APP_NAME: treehole_next 8 | 9 | jobs: 10 | docker: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@master 15 | 16 | # - name: Set up Go 17 | # uses: actions/setup-go@master 18 | # with: 19 | # go-version: 1.21.1 20 | # 21 | # - run: go version 22 | # 23 | # - name: Automated Testing 24 | # env: 25 | # MODE: test 26 | # run: go test -v -count=1 -json -tags release ./tests/... 
27 | 28 | - name: Set up QEMU 29 | uses: docker/setup-qemu-action@master 30 | 31 | - name: Set up Docker Buildx 32 | uses: docker/setup-buildx-action@master 33 | 34 | - name: Login to DockerHub 35 | uses: docker/login-action@master 36 | with: 37 | username: ${{ secrets.DOCKERHUB_USERNAME }} 38 | password: ${{ secrets.DOCKERHUB_TOKEN }} 39 | 40 | - name: Build and push 41 | id: docker_build 42 | uses: docker/build-push-action@master 43 | with: 44 | push: true 45 | tags: | 46 | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.APP_NAME }}:latest 47 | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.APP_NAME }}:dev 48 | 49 | -------------------------------------------------------------------------------- /.github/workflows/main.yaml: -------------------------------------------------------------------------------- 1 | name: Production Build 2 | on: 3 | push: 4 | branches: [ main ] 5 | 6 | env: 7 | APP_NAME: treehole_next 8 | 9 | jobs: 10 | docker: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@master 15 | 16 | # - name: Set up Go 17 | # uses: actions/setup-go@master 18 | # with: 19 | # go-version: 1.21.1 20 | # 21 | # - run: go version 22 | # - name: Automated Testing 23 | # env: 24 | # MODE: test 25 | # run: go test -v -count=1 -json -tags release ./tests/... 
26 | 27 | - name: Set up QEMU 28 | uses: docker/setup-qemu-action@master 29 | 30 | - name: Set up Docker Buildx 31 | uses: docker/setup-buildx-action@master 32 | 33 | - name: Login to DockerHub 34 | uses: docker/login-action@master 35 | with: 36 | username: ${{ secrets.DOCKERHUB_USERNAME }} 37 | password: ${{ secrets.DOCKERHUB_TOKEN }} 38 | 39 | - name: Build and push 40 | id: docker_build 41 | uses: docker/build-push-action@master 42 | with: 43 | push: true 44 | tags: | 45 | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.APP_NAME }}:latest 46 | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.APP_NAME }}:master 47 | 48 | 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | *.env 4 | *.exe 5 | *.db 6 | *.ps1 7 | *.log 8 | *.out 9 | *.http 10 | .DS_Store -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 
55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | dev@danta.tech. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 
# ---- Build stage: compiles the treehole binary ----
# Keyword casing is kept consistent (FROM ... AS) to satisfy Docker's
# FromAsCasing build check.
FROM golang:1.22-alpine AS builder

WORKDIR /app

# Copy only the module manifests first so the dependency download below is
# cached as its own layer until go.mod/go.sum change.
COPY go.mod go.sum ./
RUN apk add --no-cache --virtual .build-deps \
        ca-certificates \
        tzdata \
        gcc \
        g++ && \
    go mod download

COPY . .

# -s -w strips the symbol table and DWARF data to shrink the binary.
RUN go build -ldflags "-s -w" -o treehole

# ---- Runtime stage: minimal image holding only the binary and its data ----
FROM alpine

WORKDIR /app

COPY --from=builder /app/treehole /app/
# Time zone database from the builder (tzdata is installed there), needed for TZ below.
COPY --from=builder /usr/share/zoneinfo /usr/share/zoneinfo
COPY data data

ENV TZ=Asia/Shanghai
ENV MODE=production

EXPOSE 8000

ENTRYPOINT ["./treehole"]
44 | ``` 45 | For documentation, please open http://localhost:8000/docs after running app 46 | ## Badge 47 | 48 | [//]: # ([![build](https://github.com/OpenTreeHole/treehole_next/actions/workflows/master.yaml/badge.svg)](https://github.com/OpenTreeHole/treehole_next/actions/workflows/master.yaml)) 49 | [//]: # ([![dev build](https://github.com/OpenTreeHole/treehole_next/actions/workflows/dev.yaml/badge.svg)](https://github.com/OpenTreeHole/treehole_next/actions/workflows/dev.yaml)) 50 | 51 | [![stars](https://img.shields.io/github/stars/OpenTreeHole/treehole_next)](https://github.com/OpenTreeHole/treehole_next/stargazers) 52 | [![issues](https://img.shields.io/github/issues/OpenTreeHole/treehole_next)](https://github.com/OpenTreeHole/treehole_next/issues) 53 | [![pull requests](https://img.shields.io/github/issues-pr/OpenTreeHole/treehole_next)](https://github.com/OpenTreeHole/treehole_next/pulls) 54 | 55 | [![standard-readme compliant](https://img.shields.io/badge/readme%20style-standard-brightgreen.svg?style=flat-square)](https://github.com/RichardLitt/standard-readme) 56 | 57 | ### Powered by 58 | 59 | ![Go](https://img.shields.io/badge/go-%2300ADD8.svg?style=for-the-badge&logo=go&logoColor=white) 60 | ![Swagger](https://img.shields.io/badge/-Swagger-%23Clojure?style=for-the-badge&logo=swagger&logoColor=white) 61 | 62 | ## Contributing 63 | 64 | Feel free to dive in! [Open an issue](https://github.com/OpenTreeHole/treehole_next/issues/new) or [Submit PRs](https://github.com/OpenTreeHole/treehole_next/compare). 65 | 66 | We are now in rapid development, any contribution would be of great help. 67 | For the developing roadmap, please visit [this issue](https://github.com/OpenTreeHole/treehole_next/issues/1). 68 | 69 | ### Contributors 70 | 71 | This project exists thanks to all the people who contribute. 
72 | 73 | 74 | contributors 75 | 76 | 77 | ## Licence 78 | 79 | [![license](https://img.shields.io/github/license/OpenTreeHole/treehole_next)](https://github.com/OpenTreeHole/treehole_next/blob/master/LICENSE) 80 | © OpenTreeHole 81 | -------------------------------------------------------------------------------- /apis/default.go: -------------------------------------------------------------------------------- 1 | package apis 2 | 3 | import ( 4 | "treehole_next/data" 5 | 6 | "github.com/gofiber/fiber/v2" 7 | ) 8 | 9 | // Index 10 | // 11 | // @Produce application/json 12 | // @Success 200 {object} models.MessageModel 13 | // @Router / [get] 14 | func Index(c *fiber.Ctx) error { 15 | return c.Send(data.MetaFile) 16 | } 17 | -------------------------------------------------------------------------------- /apis/division/apis.go: -------------------------------------------------------------------------------- 1 | package division 2 | 3 | import ( 4 | "strconv" 5 | 6 | "github.com/goccy/go-json" 7 | "gorm.io/gorm" 8 | "gorm.io/gorm/clause" 9 | 10 | "github.com/opentreehole/go-common" 11 | 12 | . "treehole_next/models" 13 | . 
"treehole_next/utils" 14 | 15 | "github.com/gofiber/fiber/v2" 16 | ) 17 | 18 | // AddDivision 19 | // 20 | // @Summary Add A Division 21 | // @Tags Division 22 | // @Accept application/json 23 | // @Produce application/json 24 | // @Router /divisions [post] 25 | // @Param json body CreateModel true "json" 26 | // @Success 201 {object} models.Division 27 | // @Success 200 {object} models.Division 28 | func AddDivision(c *fiber.Ctx) error { 29 | // validate body 30 | var body CreateModel 31 | err := common.ValidateBody(c, &body) 32 | if err != nil { 33 | return err 34 | } 35 | 36 | // get user 37 | user, err := GetCurrLoginUser(c) 38 | if err != nil { 39 | return err 40 | } 41 | 42 | // permission check 43 | if !user.IsAdmin { 44 | return common.Forbidden() 45 | } 46 | 47 | // bind division 48 | division := Division{ 49 | Name: body.Name, 50 | Description: body.Description, 51 | } 52 | result := DB.FirstOrCreate(&division, Division{Name: body.Name}) 53 | if result.RowsAffected == 0 { 54 | c.Status(200) 55 | } else { 56 | c.Status(201) 57 | } 58 | return Serialize(c, &division) 59 | } 60 | 61 | // ListDivisions 62 | // 63 | // @Summary List All Divisions 64 | // @Tags Division 65 | // @Produce application/json 66 | // @Router /divisions [get] 67 | // @Success 200 {array} models.Division 68 | func ListDivisions(c *fiber.Ctx) error { 69 | var divisions Divisions 70 | if GetCache("divisions", &divisions) { 71 | return c.JSON(divisions) 72 | } 73 | err := DB.Find(&divisions, "hidden = false").Error 74 | if err != nil { 75 | return err 76 | } 77 | return Serialize(c, divisions) 78 | } 79 | 80 | // GetDivision 81 | // 82 | // @Summary Get Division 83 | // @Tags Division 84 | // @Produce application/json 85 | // @Router /divisions/{id} [get] 86 | // @Param id path int true "id" 87 | // @Success 200 {object} models.Division 88 | // @Failure 404 {object} MessageModel 89 | func GetDivision(c *fiber.Ctx) error { 90 | id, err := c.ParamsInt("id") 91 | if err != nil { 92 | return 
err 93 | } 94 | var division Division 95 | result := DB.Where("hidden = false").First(&division, id) 96 | if result.Error != nil { 97 | return result.Error 98 | } 99 | return Serialize(c, &division) 100 | } 101 | 102 | // ModifyDivision 103 | // 104 | // @Summary Modify A Division 105 | // @Tags Division 106 | // @Produce json 107 | // @Router /divisions/{id} [put] 108 | // @Router /divisions/{id}/_webvpn [patch] 109 | // @Param id path int true "id" 110 | // @Param json body ModifyDivisionModel true "json" 111 | // @Success 200 {object} models.Division 112 | // @Failure 404 {object} MessageModel 113 | func ModifyDivision(c *fiber.Ctx) error { 114 | // validate body 115 | var body ModifyDivisionModel 116 | err := common.ValidateBody(c, &body) 117 | if err != nil { 118 | return err 119 | } 120 | 121 | id, err := c.ParamsInt("id") 122 | if err != nil { 123 | return err 124 | } 125 | 126 | // get user 127 | user, err := GetCurrLoginUser(c) 128 | if err != nil { 129 | return err 130 | } 131 | 132 | // permission check 133 | if !user.IsAdmin { 134 | return common.Forbidden() 135 | } 136 | 137 | var division Division 138 | err = DB.Transaction(func(tx *gorm.DB) error { 139 | err = tx.Clauses(clause.Locking{Strength: "UPDATE"}).First(&division, id).Error 140 | if err != nil { 141 | return err 142 | } 143 | 144 | modifyData := make(map[string]any) 145 | if body.Name != nil { 146 | modifyData["name"] = *body.Name 147 | } 148 | if body.Description != nil { 149 | modifyData["description"] = *body.Description 150 | } 151 | if body.Pinned != nil { 152 | data, _ := json.Marshal(body.Pinned) 153 | modifyData["pinned"] = string(data) 154 | } 155 | 156 | if len(modifyData) == 0 { 157 | return common.BadRequest("No data to modify.") 158 | } 159 | 160 | return tx.Model(&division).Updates(modifyData).Error 161 | }) 162 | if err != nil { 163 | return err 164 | } 165 | 166 | var newDivision Division 167 | err = DB.First(&newDivision, id).Error 168 | if err != nil { 169 | return err 170 
| } 171 | 172 | MyLog("Division", "Modify", division.ID, user.ID, RoleAdmin) 173 | 174 | CreateAdminLog(DB, AdminLogTypeDivision, user.ID, map[string]any{ 175 | "division_id": division.ID, 176 | "before": division, 177 | "after": newDivision, 178 | }) 179 | 180 | // refresh cache. here should not use `go refreshCache` 181 | err = refreshCache(c) 182 | if err != nil { 183 | return err 184 | } 185 | 186 | return Serialize(c, &newDivision) 187 | } 188 | 189 | // DeleteDivision 190 | // 191 | // @Summary Delete A Division 192 | // @Description Delete a division and move all of its holes to another given division 193 | // @Tags Division 194 | // @Produce application/json 195 | // @Router /divisions/{id} [delete] 196 | // @Param id path int true "id" 197 | // @Param json body DeleteModel true "json" 198 | // @Success 204 199 | // @Failure 404 {object} MessageModel 200 | func DeleteDivision(c *fiber.Ctx) error { 201 | // validate body 202 | var body DeleteModel 203 | err := common.ValidateBody(c, &body) 204 | if err != nil { 205 | return err 206 | } 207 | id, err := c.ParamsInt("id") 208 | if err != nil { 209 | return err 210 | } 211 | 212 | // get user 213 | user, err := GetCurrLoginUser(c) 214 | if err != nil { 215 | return err 216 | } 217 | if !user.IsAdmin { 218 | return common.Forbidden() 219 | } 220 | 221 | if id == body.To { 222 | return common.BadRequest("The deleted division can't be the same as to.") 223 | } 224 | err = DB.Exec("UPDATE hole SET division_id = ? 
WHERE division_id = ?", body.To, id).Error 225 | if err != nil { 226 | return err 227 | } 228 | err = DB.Delete(&Division{ID: id}).Error 229 | if err != nil { 230 | return err 231 | } 232 | 233 | // log 234 | //if err != nil { 235 | // return err 236 | //} 237 | MyLog("Division", "Delete", id, user.ID, RoleAdmin, "To: ", strconv.Itoa(body.To)) 238 | 239 | err = refreshCache(c) 240 | if err != nil { 241 | return err 242 | } 243 | 244 | return c.Status(204).JSON(nil) 245 | } 246 | -------------------------------------------------------------------------------- /apis/division/routes.go: -------------------------------------------------------------------------------- 1 | package division 2 | 3 | import "github.com/gofiber/fiber/v2" 4 | 5 | func RegisterRoutes(app fiber.Router) { 6 | app.Post("/divisions", AddDivision) 7 | app.Get("/divisions", ListDivisions) 8 | app.Get("/divisions/:id", GetDivision) 9 | app.Put("/divisions/:id", ModifyDivision) 10 | app.Patch("/divisions/:id/_webvpn", ModifyDivision) 11 | app.Delete("/divisions/:id", DeleteDivision) 12 | } 13 | -------------------------------------------------------------------------------- /apis/division/schemas.go: -------------------------------------------------------------------------------- 1 | package division 2 | 3 | type DeleteModel struct { 4 | // Admin only 5 | // ID of the target division that all the deleted division's holes will be moved to 6 | To int `json:"to" default:"1"` 7 | } 8 | 9 | type CreateModel struct { 10 | Name string `json:"name"` 11 | Description string `json:"description"` 12 | } 13 | 14 | type ModifyDivisionModel struct { 15 | Name *string `json:"name"` 16 | Description *string `json:"description"` 17 | Pinned []int `json:"pinned"` 18 | } 19 | -------------------------------------------------------------------------------- /apis/division/utils.go: -------------------------------------------------------------------------------- 1 | package division 2 | 3 | import ( 4 | 
"github.com/gofiber/fiber/v2" 5 | "github.com/rs/zerolog/log" 6 | 7 | . "treehole_next/models" 8 | ) 9 | 10 | func refreshCache(c *fiber.Ctx) error { 11 | 12 | var divisions Divisions 13 | err := DB.Find(&divisions).Error 14 | if err != nil { 15 | return err 16 | } 17 | 18 | err = divisions.Preprocess(c) 19 | if err != nil { 20 | log.Err(err).Msg("error refreshing cache") 21 | return err 22 | } 23 | 24 | return nil 25 | } 26 | -------------------------------------------------------------------------------- /apis/favourite/routes.go: -------------------------------------------------------------------------------- 1 | package favourite 2 | 3 | import "github.com/gofiber/fiber/v2" 4 | 5 | func RegisterRoutes(app fiber.Router) { 6 | app.Get("/user/favorites", ListFavorites) 7 | app.Post("/user/favorites", AddFavorite) 8 | app.Put("/user/favorites", ModifyFavorite) 9 | app.Patch("/user/favorites/_webvpn", ModifyFavorite) 10 | app.Delete("/user/favorites", DeleteFavorite) 11 | app.Get("/user/favorite_groups", ListFavoriteGroups) 12 | app.Post("/user/favorite_groups", AddFavoriteGroup) 13 | app.Put("/user/favorite_groups", ModifyFavoriteGroup) 14 | app.Patch("/user/favorite_groups/_webvpn", ModifyFavoriteGroup) 15 | app.Delete("/user/favorite_groups", DeleteFavoriteGroup) 16 | app.Put("/user/favorites/move", MoveFavorite) 17 | } 18 | -------------------------------------------------------------------------------- /apis/favourite/schemas.go: -------------------------------------------------------------------------------- 1 | package favourite 2 | 3 | type Response struct { 4 | Message string `json:"message"` 5 | Data []int `json:"data"` 6 | } 7 | 8 | type ListFavoriteModel struct { 9 | Order string `json:"order" query:"order" validate:"omitempty,oneof=id time_created hole_time_updated" default:"time_created"` 10 | Plain bool `json:"plain" default:"false" query:"plain"` 11 | FavoriteGroupID *int `json:"favorite_group_id" query:"favorite_group_id"` 12 | } 13 | 14 | type 
AddModel struct { 15 | HoleID int `json:"hole_id"` 16 | FavoriteGroupID int `json:"favorite_group_id" default:"0"` 17 | } 18 | 19 | type ModifyModel struct { 20 | HoleIDs []int `json:"hole_ids"` 21 | FavoriteGroupID int `json:"favorite_group_id" default:"0"` 22 | } 23 | 24 | type DeleteModel struct { 25 | HoleID int `json:"hole_id"` 26 | FavoriteGroupID int `json:"favorite_group_id" default:"0"` 27 | } 28 | 29 | type AddFavoriteGroupModel struct { 30 | Name string `json:"name" validate:"required,max=64"` 31 | } 32 | 33 | type ModifyFavoriteGroupModel struct { 34 | Name string `json:"name" validate:"required,max=64"` 35 | FavoriteGroupID *int `json:"favorite_group_id" validate:"required"` 36 | } 37 | 38 | type DeleteFavoriteGroupModel struct { 39 | FavoriteGroupID *int `json:"favorite_group_id" validate:"required"` 40 | } 41 | 42 | type MoveModel struct { 43 | HoleIDs []int `json:"hole_ids"` 44 | FromFavoriteGroupID *int `json:"from_favorite_group_id" default:"0" validate:"required"` 45 | ToFavoriteGroupID *int `json:"to_favorite_group_id" validate:"required"` 46 | } 47 | 48 | type ListFavoriteGroupModel struct { 49 | Order string `json:"order" query:"order" validate:"omitempty,oneof=id time_created time_updated" default:"time_created"` 50 | Plain bool `json:"plain" default:"false" query:"plain"` 51 | } 52 | -------------------------------------------------------------------------------- /apis/floor/routes.go: -------------------------------------------------------------------------------- 1 | package floor 2 | 3 | import ( 4 | "github.com/gofiber/fiber/v2" 5 | 6 | "treehole_next/utils" 7 | ) 8 | 9 | func RegisterRoutes(app fiber.Router) { 10 | app.Post("/floors/search", SearchFloors) 11 | app.Get("/floors/search", SearchFloors) 12 | 13 | app.Get("/holes/:id/floors", ListFloorsInAHole) 14 | app.Get("/floors", ListFloorsOld) 15 | app.Get("/floors/:id", GetFloor) 16 | app.Post("/holes/:id/floors", utils.MiddlewareHasAnsweredQuestions, CreateFloor) 17 | 
app.Post("/floors", utils.MiddlewareHasAnsweredQuestions, CreateFloorOld) 18 | app.Put("/floors/:id", ModifyFloor) 19 | app.Patch("/floors/:id/_webvpn", ModifyFloor) 20 | app.Post("/floors/:id/like/:like", ModifyFloorLike) 21 | app.Delete("/floors/:id", DeleteFloor) 22 | 23 | app.Get("/users/me/floors", ListReplyFloors) 24 | 25 | app.Get("/floors/:id/history", GetFloorHistory) 26 | app.Post("/floors/:id/restore/:floor_history_id", RestoreFloor) 27 | 28 | app.Post("/config/search", SearchConfig) 29 | app.Get("/floors/:id/punishment", GetPunishmentHistory) 30 | app.Get("/floors/:id/user_silence", GetUserSilence) 31 | 32 | app.Get("/floors/_sensitive", ListSensitiveFloors) 33 | app.Put("/floors/:id/_sensitive", ModifyFloorSensitive) 34 | app.Patch("/floors/:id/_sensitive/_webvpn", ModifyFloorSensitive) 35 | } 36 | -------------------------------------------------------------------------------- /apis/floor/schemas.go: -------------------------------------------------------------------------------- 1 | package floor 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/opentreehole/go-common" 7 | 8 | "treehole_next/models" 9 | ) 10 | 11 | type ListModel struct { 12 | Size int `json:"size" query:"size" default:"30" validate:"min=0,max=50"` // length of object array 13 | Offset int `json:"offset" query:"offset" default:"0" validate:"min=0"` // offset of object array 14 | Sort string `json:"sort" query:"sort" default:"asc" validate:"oneof=asc desc"` // Sort order 15 | OrderBy string `json:"order_by" query:"order_by" default:"id" validate:"oneof=id like"` // SQL ORDER BY field 16 | } 17 | 18 | type ListOldModel struct { 19 | HoleID int `query:"hole_id" json:"hole_id"` 20 | Size int `query:"length" json:"length" validate:"min=0,max=50" ` 21 | Offset int `query:"start_floor" json:"start_floor"` 22 | Search string `query:"s" json:"s"` 23 | } 24 | 25 | type CreateModel struct { 26 | Content string `json:"content" validate:"required"` 27 | // Admin and Operator only 28 | SpecialTag 
string `json:"special_tag" validate:"omitempty,max=16"` 29 | // id of the floor to which replied 30 | ReplyTo int `json:"reply_to" validate:"min=0"` 31 | } 32 | 33 | type CreateOldModel struct { 34 | HoleID int `json:"hole_id" validate:"min=1"` 35 | CreateModel 36 | } 37 | 38 | type CreateOldResponse struct { 39 | Data models.Floor `json:"data"` 40 | Message string `json:"message"` 41 | } 42 | 43 | type ModifyModel struct { 44 | // Owner or admin, the original content should be moved to floor_history 45 | Content *string `json:"content" validate:"omitempty"` 46 | // Admin and Operator only 47 | SpecialTag *string `json:"special_tag" validate:"omitempty,max=16"` 48 | // All user, deprecated, "add" is like, "cancel" is reset 49 | Like *string `json:"like" validate:"omitempty,oneof=add cancel"` 50 | // 仅管理员,留空则重置,高优先级 51 | Fold *string `json:"fold_v2" validate:"omitempty,max=64"` 52 | // 仅管理员,留空则重置,低优先级 53 | FoldFrontend []string `json:"fold" validate:"omitempty"` 54 | } 55 | 56 | func (body ModifyModel) DoNothing() bool { 57 | return body.Content == nil && body.SpecialTag == nil && body.Like == nil && body.Fold == nil && body.FoldFrontend == nil 58 | } 59 | 60 | func (body ModifyModel) CheckPermission(user *models.User, floor *models.Floor, hole *models.Hole) error { 61 | if body.Content != nil { 62 | if !user.IsAdmin { 63 | if user.ID != floor.UserID { 64 | return common.Forbidden("这不是您的楼层,您没有权限修改") 65 | } else { 66 | if user.BanDivision[hole.DivisionID] != nil { 67 | return common.Forbidden(user.BanDivisionMessage(hole.DivisionID)) 68 | } else if hole.Locked { 69 | return common.Forbidden("此洞已被锁定,您无法修改") 70 | } else if floor.Deleted { 71 | return common.Forbidden("此洞已被删除,您无法修改") 72 | } 73 | } 74 | } else { 75 | if user.BanDivision[hole.DivisionID] != nil { 76 | return common.Forbidden(user.BanDivisionMessage(hole.DivisionID)) 77 | } 78 | } 79 | } 80 | if (body.Fold != nil || body.FoldFrontend != nil) && !user.IsAdmin { 81 | return common.Forbidden("非管理员禁止折叠") 82 | 
} 83 | if body.SpecialTag != nil && !user.IsAdmin { 84 | return common.Forbidden("非管理员禁止修改特殊标签") 85 | } 86 | return nil 87 | } 88 | 89 | type DeleteModel struct { 90 | Reason string `json:"delete_reason" validate:"max=32"` 91 | } 92 | 93 | type RestoreModel struct { 94 | Reason string `json:"restore_reason" validate:"required,max=32"` 95 | } 96 | 97 | type SearchConfigModel struct { 98 | Open bool `json:"open"` 99 | } 100 | 101 | type SensitiveFloorRequest struct { 102 | Size int `json:"size" query:"size" default:"10" validate:"max=10"` 103 | Offset common.CustomTime `json:"offset" query:"offset" swaggertype:"string"` 104 | OrderBy string `json:"order_by" query:"order_by" default:"time_created" validate:"oneof=time_created time_updated"` 105 | Open bool `json:"open" query:"open"` 106 | All bool `json:"all" query:"all"` 107 | } 108 | 109 | type SensitiveFloorResponse struct { 110 | ID int `json:"id"` 111 | CreatedAt time.Time `json:"time_created"` 112 | UpdatedAt time.Time `json:"time_updated"` 113 | Content string `json:"content"` 114 | Modified int `json:"modified"` 115 | IsActualSensitive *bool `json:"is_actual_sensitive"` 116 | HoleID int `json:"hole_id"` 117 | Deleted bool `json:"deleted"` 118 | SensitiveDetail string `json:"sensitive_detail,omitempty"` 119 | } 120 | 121 | func (s *SensitiveFloorResponse) FromModel(floor *models.Floor) *SensitiveFloorResponse { 122 | s.ID = floor.ID 123 | s.CreatedAt = floor.CreatedAt 124 | s.UpdatedAt = floor.UpdatedAt 125 | s.Content = floor.Content 126 | s.Modified = floor.Modified 127 | s.IsActualSensitive = floor.IsActualSensitive 128 | s.HoleID = floor.HoleID 129 | s.Deleted = floor.Deleted 130 | s.SensitiveDetail = floor.SensitiveDetail 131 | return s 132 | } 133 | 134 | type ModifySensitiveFloorRequest struct { 135 | IsActualSensitive bool `json:"is_actual_sensitive"` 136 | } 137 | 138 | type BanDivision map[int]*time.Time 139 | -------------------------------------------------------------------------------- 
/apis/floor/search.go: -------------------------------------------------------------------------------- 1 | package floor 2 | 3 | import ( 4 | "github.com/gofiber/fiber/v2" 5 | "github.com/opentreehole/go-common" 6 | 7 | . "treehole_next/config" 8 | . "treehole_next/models" 9 | . "treehole_next/utils" 10 | ) 11 | 12 | // SearchQuery is the query struct for searching floors 13 | type SearchQuery struct { 14 | Search string `json:"search" query:"search" validate:"required"` 15 | Size int `json:"size" query:"size" validate:"min=0" default:"10"` 16 | Offset int `json:"offset" query:"offset" validate:"min=0" default:"0"` 17 | 18 | // Accurate is used to determine whether to use accurate search 19 | Accurate bool `json:"accurate" query:"accurate" default:"false"` 20 | 21 | // StartTime and EndTime are used to filter floors by time 22 | // Both are Unix timestamps, and are optional 23 | StartTime *int64 `json:"start_time" query:"start_time"` 24 | EndTime *int64 `json:"end_time" query:"end_time"` 25 | } 26 | 27 | // SearchFloors 28 | // 29 | // @Summary SearchFloors In ElasticSearch 30 | // @Tags Search 31 | // @Produce application/json 32 | // @Router /floors/search [get] 33 | // @Router /floors/search [post] 34 | // @Param object query SearchQuery true "search_query" 35 | // @Success 200 {array} models.Floor 36 | func SearchFloors(c *fiber.Ctx) error { 37 | var query SearchQuery 38 | err := common.ValidateQuery(c, &query) 39 | if err != nil { 40 | return err 41 | } 42 | 43 | floors, err := Search(c, query.Search, query.Size, query.Offset, query.Accurate, query.StartTime, query.EndTime) 44 | if err != nil { 45 | return err 46 | } 47 | 48 | return Serialize(c, floors) 49 | } 50 | 51 | // SearchConfig 52 | // 53 | // @Summary change search config 54 | // @Tags Search 55 | // @Produce application/json 56 | // @Router /config/search [post] 57 | // @Param json body SearchConfigModel true "json" 58 | // @Success 200 {object} Map 59 | func SearchConfig(c *fiber.Ctx) error { 60 | 
var body SearchConfigModel 61 | err := c.BodyParser(&body) 62 | if err != nil { 63 | return err 64 | } 65 | user, err := GetCurrLoginUser(c) 66 | if err != nil { 67 | return err 68 | } 69 | if !user.IsAdmin { 70 | return common.Forbidden() 71 | } 72 | if DynamicConfig.OpenSearch.Load() == body.Open { 73 | return c.Status(200).JSON(Map{"message": "已经被修改"}) 74 | } else { 75 | DynamicConfig.OpenSearch.Store(body.Open) 76 | return c.Status(201).JSON(Map{"message": "修改成功"}) 77 | } 78 | } 79 | 80 | func SearchFloorsOld(c *fiber.Ctx, query *ListOldModel) error { 81 | if !DynamicConfig.OpenSearch.Load() { 82 | return common.Forbidden("茶楼流量激增,搜索功能暂缓开放") 83 | } 84 | 85 | floors, err := Search(c, query.Search, query.Size, query.Offset, false, nil, nil) 86 | if err != nil { 87 | return err 88 | } 89 | 90 | return Serialize(c, &floors) 91 | } 92 | -------------------------------------------------------------------------------- /apis/floor/utils.go: -------------------------------------------------------------------------------- 1 | package floor 2 | 3 | import "fmt" 4 | 5 | func generateDeleteReason(reason string, isOwner bool) string { 6 | if reason == "" { 7 | if isOwner { 8 | return "该内容被作者删除" 9 | } 10 | reason = "违反社区规范" 11 | } 12 | return fmt.Sprintf("该内容因%s被删除", reason) 13 | } 14 | -------------------------------------------------------------------------------- /apis/hole/purge_hole.go: -------------------------------------------------------------------------------- 1 | package hole 2 | 3 | import ( 4 | "context" 5 | "github.com/rs/zerolog/log" 6 | "gorm.io/gorm" 7 | "gorm.io/gorm/clause" 8 | "time" 9 | 10 | "treehole_next/config" 11 | . "treehole_next/models" 12 | ) 13 | 14 | func purgeHole() (err error) { 15 | const REASON = "purge_hole" 16 | const DELETE_CONTENT = "该内容已被删除" 17 | 18 | return DB.Transaction(func(tx *gorm.DB) (err error) { 19 | 20 | // load holeIDs, lock for update 21 | var holeIDs []int 22 | err = tx.Model(&Hole{}). 
23 | Clauses(clause.Locking{Strength: "UPDATE"}). 24 | Where("no_purge = ?", false). 25 | Where("division_id IN ?", config.Config.HolePurgeDivisions). 26 | Where("updated_at < ?", 27 | time.Now().AddDate(0, 0, -config.Config.HolePurgeDays), 28 | ).Pluck("id", &holeIDs).Error 29 | if err != nil { 30 | return err 31 | } 32 | 33 | if len(holeIDs) == 0 { 34 | return nil 35 | } 36 | 37 | /* delete all floors in hole of holeIOs */ 38 | 39 | // get floors, lock for update 40 | var floors []Floor 41 | err = tx. 42 | Clauses(clause.Locking{Strength: "UPDATE"}). 43 | Where("hole_id IN ?", holeIDs). 44 | Find(&floors).Error 45 | if err != nil { 46 | return err 47 | } 48 | if len(floors) == 0 { 49 | return nil 50 | } 51 | 52 | // generate floorHistory 53 | var floorHistorySlice = make([]FloorHistory, 0, len(floors)) 54 | for i := range floors { 55 | floorHistorySlice = append(floorHistorySlice, FloorHistory{ 56 | Content: floors[i].Content, 57 | Reason: REASON, 58 | FloorID: floors[i].ID, 59 | UserID: 1, 60 | SensitiveDetail: floors[i].SensitiveDetail, 61 | IsActualSensitive: floors[i].IsActualSensitive, 62 | IsSensitive: floors[i].IsSensitive, 63 | }) 64 | } 65 | err = tx.Create(&floorHistorySlice).Error 66 | if err != nil { 67 | return err 68 | } 69 | 70 | // delete floors 71 | var floorIDs = make([]int, 0, len(floors)) 72 | for i := range floors { 73 | floorIDs = append(floorIDs, floors[i].ID) 74 | } 75 | 76 | err = tx.Model(&Floor{}). 77 | Where("id IN ?", floorIDs). 78 | Updates(map[string]any{ 79 | "deleted": true, 80 | "content": DELETE_CONTENT, 81 | }).Error 82 | if err != nil { 83 | return err 84 | } 85 | 86 | /* delete all holes in holeIDs */ 87 | err = tx. 88 | Where("id IN ?", holeIDs). 89 | Delete(&Hole{}).Error 90 | if err != nil { 91 | return err 92 | } 93 | 94 | // delete floor in search engine 95 | go BulkDelete(floorIDs) 96 | 97 | // log 98 | log.Info(). 99 | Ints("hole_ids", holeIDs). 100 | Ints("floor_ids", floorIDs). 
101 | Msg("purge hole") 102 | 103 | return nil 104 | }) 105 | } 106 | 107 | func PurgeHole(ctx context.Context) { 108 | ticker := time.NewTicker(time.Minute * 10) 109 | defer ticker.Stop() 110 | for { 111 | select { 112 | case <-ticker.C: 113 | err := purgeHole() 114 | if err != nil { 115 | log.Err(err).Msg("error purge hole") 116 | } 117 | case <-ctx.Done(): 118 | return 119 | } 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /apis/hole/routes.go: -------------------------------------------------------------------------------- 1 | package hole 2 | 3 | import ( 4 | "github.com/gofiber/fiber/v2" 5 | 6 | "treehole_next/utils" 7 | ) 8 | 9 | func RegisterRoutes(app fiber.Router) { 10 | app.Get("/divisions/:id/holes", ListHolesByDivision) 11 | app.Get("/tags/:name/holes", ListHolesByTag) 12 | app.Get("/users/me/holes", ListHolesByMe) 13 | app.Get("/holes/:id", GetHole) 14 | app.Get("/holes", ListHoles) 15 | app.Get("/holes/_good", ListGoodHoles) 16 | app.Post("/divisions/:id/holes", utils.MiddlewareHasAnsweredQuestions, CreateHole) 17 | app.Post("/holes", utils.MiddlewareHasAnsweredQuestions, CreateHoleOld) 18 | app.Patch("/holes/:id/_webvpn", ModifyHole) 19 | app.Patch("/holes/:id", PatchHole) 20 | app.Put("/holes/:id", ModifyHole) 21 | app.Delete("/holes/:id", HideHole) 22 | app.Delete("/holes/:id/_force", DeleteHole) 23 | } 24 | -------------------------------------------------------------------------------- /apis/hole/schemas.go: -------------------------------------------------------------------------------- 1 | package hole 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/opentreehole/go-common" 7 | 8 | "treehole_next/apis/tag" 9 | "treehole_next/models" 10 | ) 11 | 12 | type QueryTime struct { 13 | Size int `json:"size" query:"size" default:"10" validate:"max=10"` 14 | // updated time < offset (default is now) 15 | Offset common.CustomTime `json:"offset" query:"offset" swaggertype:"string"` 16 | Order string 
`json:"order" query:"order"` 17 | } 18 | 19 | func (q *QueryTime) SetDefaults() { 20 | if q.Offset.IsZero() { 21 | q.Offset = common.CustomTime{Time: time.Now()} 22 | } 23 | } 24 | 25 | type ListOldModel struct { 26 | Offset0 common.CustomTime `json:"start_time" query:"start_time" swaggertype:"string"` 27 | Offset common.CustomTime `json:"offset" query:"offset" swaggertype:"string"` 28 | Size0 int `json:"length" query:"length" default:"10" validate:"max=10"` 29 | Size int `json:"size" query:"size" default:"10" validate:"max=10" ` 30 | Tag string `json:"tag" query:"tag"` 31 | Tags []string `json:"tags" query:"tags"` 32 | DivisionID int `json:"division_id" query:"division_id"` 33 | Order string `json:"order" query:"order"` 34 | CreatedStart *common.CustomTime `json:"created_start" query:"created_start" swaggertype:"string"` 35 | CreatedEnd *common.CustomTime `json:"created_end" query:"created_end" swaggertype:"string"` 36 | } 37 | 38 | func (q *ListOldModel) SetDefaults() { 39 | if q.Size == 0 { 40 | q.Size = q.Size0 41 | } 42 | if q.Offset.IsZero() { 43 | if q.Offset0.IsZero() { 44 | q.Offset = common.CustomTime{Time: time.Now()} 45 | } else { 46 | q.Offset = q.Offset0 47 | } 48 | } 49 | if q.CreatedStart == nil { 50 | q.CreatedStart = &common.CustomTime{Time: time.Time{}} // 默认值为零时间 51 | } 52 | if q.CreatedEnd == nil { 53 | q.CreatedEnd = &common.CustomTime{Time: time.Now()} 54 | } 55 | } 56 | 57 | type TagCreateModelSlice struct { 58 | Tags []tag.CreateModel `json:"tags" validate:"omitempty,min=1,max=10,dive"` // All users 59 | } 60 | 61 | func (tagCreateModelSlice TagCreateModelSlice) ToName() []string { 62 | tags := make([]string, 0, len(tagCreateModelSlice.Tags)) 63 | for _, tagCreateModel := range tagCreateModelSlice.Tags { 64 | tags = append(tags, tagCreateModel.Name) 65 | } 66 | return tags 67 | } 68 | 69 | type CreateModel struct { 70 | Content string `json:"content" validate:"required"` 71 | TagCreateModelSlice 72 | // Admin and Operator only 73 | 
SpecialTag string `json:"special_tag" validate:"max=16"` 74 | } 75 | 76 | type CreateOldModel struct { 77 | CreateModel 78 | DivisionID int `json:"division_id" validate:"omitempty,min=1" default:"1"` 79 | } 80 | 81 | type CreateOldResponse struct { 82 | Data models.Hole `json:"data"` 83 | Message string `json:"message"` 84 | } 85 | 86 | type ModifyModel struct { 87 | TagCreateModelSlice 88 | DivisionID *int `json:"division_id" validate:"omitempty,min=1"` // Admin and owner only 89 | Hidden *bool `json:"hidden"` // Admin only 90 | Unhidden *bool `json:"unhidden"` // admin only 91 | Lock *bool `json:"lock"` // admin only 92 | } 93 | 94 | func (body ModifyModel) CheckPermission(user *models.User, hole *models.Hole) error { 95 | if body.DivisionID != nil && !user.IsAdmin { 96 | return common.Forbidden("非管理员禁止修改分区") 97 | } 98 | if body.Hidden != nil && !user.IsAdmin { 99 | return common.Forbidden("非管理员禁止隐藏帖子") 100 | } 101 | if body.Unhidden != nil && !user.IsAdmin { 102 | return common.BadRequest("非管理员禁止取消隐藏") 103 | } 104 | if body.Tags != nil && !(user.IsAdmin) { 105 | return common.Forbidden() 106 | } 107 | if body.Tags != nil && len(body.Tags) == 0 { 108 | return common.BadRequest("tags 不能为空") 109 | } 110 | if body.Lock != nil && !user.IsAdmin { 111 | return common.Forbidden("非管理员禁止锁定帖子") 112 | } 113 | return nil 114 | } 115 | 116 | func (body ModifyModel) DoNothing() bool { 117 | return body.Hidden == nil && body.Unhidden == nil && body.Tags == nil && body.DivisionID == nil && body.Lock == nil 118 | } 119 | -------------------------------------------------------------------------------- /apis/hole/update_views.go: -------------------------------------------------------------------------------- 1 | package hole 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strconv" 7 | "strings" 8 | "time" 9 | 10 | "github.com/rs/zerolog/log" 11 | 12 | . 
"treehole_next/models" 13 | ) 14 | 15 | var holeViewsChan = make(chan int, 1000) 16 | var holeViews = map[int]int{} 17 | 18 | func updateHoleViews() { 19 | /* 20 | UPDATE table 21 | SET field = CASE id 22 | WHEN 1 THEN 'value' 23 | WHEN 2 THEN 'value' 24 | WHEN 3 THEN 'value' 25 | END 26 | WHERE id IN (1,2,3) 27 | */ 28 | length := len(holeViews) 29 | if length == 0 { 30 | return 31 | } 32 | keys := make([]string, 0, length) 33 | 34 | var builder strings.Builder 35 | builder.WriteString("UPDATE hole SET view = CASE id ") 36 | 37 | for holeID, views := range holeViews { 38 | builder.WriteString(fmt.Sprintf("WHEN %d THEN view + %d ", holeID, views)) 39 | keys = append(keys, strconv.Itoa(holeID)) 40 | delete(holeViews, holeID) 41 | } 42 | builder.WriteString("END WHERE id IN (") 43 | builder.WriteString(strings.Join(keys, ",")) 44 | builder.WriteString(")") 45 | 46 | result := DB.Exec(builder.String()) 47 | if result.Error != nil { 48 | log.Err(result.Error).Msg("update hole views failed") 49 | } else { 50 | log.Info().Strs("updated", keys).Msg("update hole views success") 51 | } 52 | } 53 | 54 | func UpdateHoleViews(ctx context.Context) { 55 | 56 | ticker := time.NewTicker(time.Second * 60) 57 | defer ticker.Stop() 58 | for { 59 | select { 60 | case <-ticker.C: 61 | updateHoleViews() 62 | case holeID := <-holeViewsChan: 63 | holeViews[holeID]++ 64 | case <-ctx.Done(): 65 | updateHoleViews() 66 | log.Info().Msg("task UpdateHoleViews stopped...") 67 | return 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /apis/message/apis.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "github.com/opentreehole/go-common" 5 | 6 | . "treehole_next/models" 7 | . 
"treehole_next/utils" 8 | 9 | "github.com/gofiber/fiber/v2" 10 | ) 11 | 12 | // ListMessages 13 | // @Summary List Messages of a User 14 | // @Tags Message 15 | // @Produce application/json 16 | // @Router /messages [get] 17 | // @Success 200 {array} Message 18 | // @Param object query ListModel false "query" 19 | func ListMessages(c *fiber.Ctx) error { 20 | var query ListModel 21 | err := common.ValidateQuery(c, &query) 22 | if err != nil { 23 | return err 24 | } 25 | 26 | userID, err := common.GetUserID(c) 27 | if err != nil { 28 | return err 29 | } 30 | 31 | messages := Messages{} 32 | 33 | if query.NotRead { 34 | DB.Raw(` 35 | SELECT message.*,message_user.has_read FROM message 36 | INNER JOIN message_user 37 | WHERE message.id = message_user.message_id and message_user.user_id = ? and message_user.has_read = false 38 | ORDER BY updated_at DESC`, 39 | userID, 40 | ).Scan(&messages) 41 | } else { 42 | DB.Raw(` 43 | SELECT message.*,message_user.has_read FROM message 44 | INNER JOIN message_user 45 | WHERE message.id = message_user.message_id and message_user.user_id = ? 46 | ORDER BY updated_at DESC`, 47 | userID, 48 | ).Scan(&messages) 49 | } 50 | 51 | return Serialize(c, &messages) 52 | } 53 | 54 | // SendMail 55 | // @Summary Send a Mail 56 | // @Description Send to multiple recipients and save to db, admin only. 
57 | // @Tags Message 58 | // @Produce application/json 59 | // @Param json body CreateModel true "json" 60 | // @Router /messages [post] 61 | // @Success 201 {object} Message 62 | func SendMail(c *fiber.Ctx) error { 63 | var body CreateModel 64 | err := common.ValidateBody(c, &body) 65 | if err != nil { 66 | return err 67 | } 68 | 69 | // get user 70 | user, err := GetCurrLoginUser(c) 71 | if err != nil { 72 | return err 73 | } 74 | 75 | // permission 76 | if !user.IsAdmin { 77 | return common.Forbidden() 78 | } 79 | 80 | // construct mail 81 | mail := Notification{ 82 | Description: body.Description, 83 | Recipients: body.Recipients, 84 | Data: Map{}, 85 | Title: "您有一封站内信", 86 | Type: MessageTypeMail, 87 | URL: "/api/messages", 88 | } 89 | 90 | // send 91 | message, err := mail.Send() 92 | if err != nil { 93 | return err 94 | } 95 | 96 | CreateAdminLog(DB, AdminLogTypeMessage, user.ID, body) 97 | 98 | return Serialize(c.Status(201), &message) 99 | } 100 | 101 | // ClearMessages 102 | // @Summary Clear Messages of a User 103 | // @Tags Message 104 | // @Produce application/json 105 | // @Router /messages/clear [post] 106 | // @Success 204 107 | func ClearMessages(c *fiber.Ctx) error { 108 | userID, err := common.GetUserID(c) 109 | if err != nil { 110 | return err 111 | } 112 | 113 | result := DB.Exec( 114 | "UPDATE message_user SET has_read = true WHERE user_id = ?", 115 | userID, 116 | ) 117 | if result.Error != nil { 118 | return result.Error 119 | } 120 | return c.Status(204).JSON(nil) 121 | } 122 | 123 | // ClearMessagesDeprecated 124 | // @Summary Clear Messages Deprecated 125 | // @Tags Message 126 | // @Produce application/json 127 | // @Router /messages [put] 128 | // @Router /messages/_webvpn [patch] 129 | // @Success 204 130 | func ClearMessagesDeprecated(c *fiber.Ctx) error { 131 | return ClearMessages(c) 132 | } 133 | 134 | // DeleteMessage 135 | // @Summary Delete a message of a user 136 | // @Tags Message 137 | // @Produce application/json 138 | // 
@Router /messages/{id} [delete] 139 | // @Param id path int true "message id" 140 | // @Success 204 141 | func DeleteMessage(c *fiber.Ctx) error { 142 | userID, err := common.GetUserID(c) 143 | if err != nil { 144 | return err 145 | } 146 | 147 | id, _ := c.ParamsInt("id") 148 | result := DB.Exec( 149 | "UPDATE message_user SET has_read = true WHERE user_id = ? AND message_id = ?", 150 | userID, id, 151 | ) 152 | if result.Error != nil { 153 | return result.Error 154 | } 155 | return c.Status(204).JSON(nil) 156 | } 157 | -------------------------------------------------------------------------------- /apis/message/purge.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/rs/zerolog/log" 7 | 8 | "treehole_next/config" 9 | . "treehole_next/models" 10 | ) 11 | 12 | func purgeMessage() error { 13 | return DB.Exec( 14 | "DELETE FROM message WHERE created_at < ?", 15 | time.Now().Add(-time.Hour*24*time.Duration(config.Config.MessagePurgeDays)), 16 | ).Error 17 | } 18 | 19 | func PurgeMessage() { 20 | ticker := time.NewTicker(time.Hour * 24) 21 | defer ticker.Stop() 22 | for range ticker.C { 23 | err := purgeMessage() 24 | if err != nil { 25 | log.Err(err).Msg("error purge message") 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /apis/message/routes.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | import "github.com/gofiber/fiber/v2" 4 | 5 | func RegisterRoutes(app fiber.Router) { 6 | app.Post("/messages", SendMail) 7 | app.Get("/messages", ListMessages) 8 | app.Post("/messages/clear", ClearMessages) 9 | app.Put("/messages", ClearMessagesDeprecated) 10 | app.Patch("/messages/_webvpn", ClearMessagesDeprecated) 11 | app.Delete("/messages/:id", DeleteMessage) 12 | } 13 | -------------------------------------------------------------------------------- 
/apis/message/schemas.go: -------------------------------------------------------------------------------- 1 | package message 2 | 3 | type CreateModel struct { 4 | // MessageTypeMail 5 | Description string `json:"description"` 6 | Recipients []int `json:"recipients" validate:"required"` 7 | } 8 | 9 | type ListModel struct { 10 | NotRead bool `json:"not_read" default:"false" query:"not_read"` 11 | } 12 | -------------------------------------------------------------------------------- /apis/penalty/api.go: -------------------------------------------------------------------------------- 1 | // Package penalty is deprecated! Please use APIs in auth. 2 | package penalty 3 | 4 | import ( 5 | "fmt" 6 | "time" 7 | 8 | "treehole_next/config" 9 | . "treehole_next/models" 10 | "treehole_next/utils" 11 | 12 | "github.com/opentreehole/go-common" 13 | "gorm.io/gorm" 14 | "gorm.io/gorm/clause" 15 | 16 | "github.com/gofiber/fiber/v2" 17 | ) 18 | 19 | type PostBody struct { 20 | PenaltyLevel *int `json:"penalty_level" validate:"omitempty"` // low priority, deprecated 21 | Days *int `json:"days" validate:"omitempty,min=1"` // high priority 22 | Divisions []int `json:"divisions" validate:"omitempty,min=1"` // high priority 23 | Reason string `json:"reason"` // optional 24 | } 25 | 26 | type ForeverPostBody struct { 27 | Reason string `json:"reason"` // optional 28 | } 29 | 30 | // BanUser 31 | // 32 | // @Summary Ban publisher of a floor 33 | // @Tags Penalty 34 | // @Produce json 35 | // @Router /penalty/{floor_id} [post] 36 | // @Param json body PostBody true "json" 37 | // @Success 201 {object} User 38 | func BanUser(c *fiber.Ctx) error { 39 | // validate body 40 | var body PostBody 41 | err := common.ValidateBody(c, &body) 42 | if err != nil { 43 | return err 44 | } 45 | 46 | floorID, err := c.ParamsInt("id") 47 | if err != nil { 48 | return err 49 | } 50 | 51 | // get user 52 | user, err := GetCurrLoginUser(c) 53 | if err != nil { 54 | return err 55 | } 56 | 57 | // permission 
58 | if !user.IsAdmin { 59 | return common.Forbidden() 60 | } 61 | 62 | var floor Floor 63 | err = DB.Take(&floor, floorID).Error 64 | if err != nil { 65 | return err 66 | } 67 | 68 | var hole Hole 69 | err = DB.Take(&hole, floor.HoleID).Error 70 | if err != nil { 71 | return err 72 | } 73 | 74 | var days int 75 | if body.Days != nil { 76 | days = *body.Days 77 | if days <= 0 { 78 | days = 1 79 | } 80 | } else if body.PenaltyLevel != nil { 81 | switch *body.PenaltyLevel { 82 | case 1: 83 | days = 1 84 | case 2: 85 | days = 5 86 | case 3: 87 | days = 999 88 | default: 89 | days = 1 90 | } 91 | } 92 | 93 | duration := time.Duration(days) * 24 * time.Hour 94 | 95 | punishment := Punishment{ 96 | UserID: floor.UserID, 97 | MadeBy: user.ID, 98 | FloorID: &floor.ID, 99 | DivisionID: hole.DivisionID, 100 | Duration: &duration, 101 | Day: days, 102 | Reason: body.Reason, 103 | } 104 | user, err = punishment.Create() 105 | if err != nil { 106 | return err 107 | } 108 | 109 | // construct message for user 110 | message := Notification{ 111 | Data: floor, 112 | Recipients: []int{floor.UserID}, 113 | Description: fmt.Sprintf( 114 | "您因为违反社区公约被禁言。时间:%d天,原因:%s\n如有异议,请联系admin@danta.tech。", 115 | days, 116 | body.Reason, 117 | ), 118 | Title: "处罚通知", 119 | Type: MessageTypePermission, 120 | URL: fmt.Sprintf("/api/floors/%d", floor.ID), 121 | } 122 | 123 | // send 124 | _, err = message.Send() 125 | if err != nil { 126 | return err 127 | } 128 | 129 | return c.JSON(user) 130 | } 131 | 132 | // BanUserForever 133 | // 134 | // @Summary Ban publisher of a floor forever 135 | // @Tags Penalty 136 | // @Produce json 137 | // @Router /penalty/{floor_id}/_forever [post] 138 | // @Param json body ForeverPostBody true "json" 139 | // @Success 201 {object} User 140 | func BanUserForever(c *fiber.Ctx) error { 141 | // validate body 142 | var body ForeverPostBody 143 | err := common.ValidateBody(c, &body) 144 | if err != nil { 145 | return err 146 | } 147 | 148 | floorID, err := 
c.ParamsInt("id") 149 | if err != nil { 150 | return err 151 | } 152 | 153 | // get user 154 | user, err := GetCurrLoginUser(c) 155 | if err != nil { 156 | return err 157 | } 158 | 159 | // permission 160 | if !user.IsAdmin { 161 | return common.Forbidden() 162 | } 163 | 164 | var floor Floor 165 | err = DB.Take(&floor, floorID).Error 166 | if err != nil { 167 | return err 168 | } 169 | 170 | // var hole Hole 171 | // err = DB.Take(&hole, floor.HoleID).Error 172 | // if err != nil { 173 | // return err 174 | // } 175 | 176 | days := 3650 177 | duration := time.Duration(days) * 24 * time.Hour 178 | 179 | var punishments Punishments 180 | var punishment *Punishment 181 | var divisionIDs []int 182 | madeBy := user.ID 183 | user = &User{ 184 | ID: floor.UserID, 185 | } 186 | err = DB.Transaction(func(tx *gorm.DB) (err error) { 187 | err = tx.Clauses(clause.Locking{Strength: "UPDATE"}).Take(&user).Error 188 | if err != nil { 189 | return err 190 | } 191 | 192 | err = tx.Clauses(clause.Locking{Strength: "UPDATE"}).Model(&Division{}).Select("ID").Scan(&divisionIDs).Error 193 | if err != nil { 194 | return err 195 | } 196 | 197 | ExcludeBanForeverDivisionIds := config.Config.ExcludeBanForeverDivisionIds 198 | 199 | divisionIDs = utils.Difference(divisionIDs, ExcludeBanForeverDivisionIds) 200 | 201 | for _, divisionID := range divisionIDs { 202 | punishment = &Punishment{ 203 | UserID: floor.UserID, 204 | MadeBy: madeBy, 205 | FloorID: nil, 206 | DivisionID: divisionID, 207 | Duration: &duration, 208 | Day: days, 209 | Reason: body.Reason, 210 | StartTime: time.Now(), 211 | EndTime: time.Now().Add(duration), 212 | } 213 | 214 | if user.BanDivision[divisionID] == nil { 215 | user.BanDivision[divisionID] = &punishment.EndTime 216 | } else { 217 | user.BanDivision[divisionID].Add(*punishment.Duration) 218 | } 219 | 220 | punishments = append(punishments, punishment) 221 | } 222 | user.OffenceCount += len(divisionIDs) 223 | 224 | err = tx.Create(&punishments).Error 225 | if err 
!= nil { 226 | return err 227 | } 228 | 229 | err = tx.Select("BanDivision", "OffenceCount").Save(&user).Error 230 | if err != nil { 231 | return err 232 | } 233 | 234 | return nil 235 | }) 236 | if err != nil { 237 | return err 238 | } 239 | 240 | // construct message for user 241 | message := Notification{ 242 | Data: floor, 243 | Recipients: []int{floor.UserID}, 244 | Description: fmt.Sprintf( 245 | "您因为违反社区公约被禁言。时间:%d天,原因:%s\n如有异议,请联系admin@danta.tech。", 246 | days, 247 | body.Reason, 248 | ), 249 | Title: "处罚通知", 250 | Type: MessageTypePermission, 251 | URL: fmt.Sprintf("/api/floors/%d", floor.ID), 252 | } 253 | 254 | // send 255 | _, err = message.Send() 256 | if err != nil { 257 | return err 258 | } 259 | 260 | return c.JSON(user) 261 | } 262 | 263 | // ListMyPunishments godoc 264 | // @Summary List my punishments 265 | // @Tags Penalty 266 | // @Produce json 267 | // @Router /users/me/punishments [get] 268 | // @Success 200 {array} Punishment 269 | func ListMyPunishments(c *fiber.Ctx) error { 270 | userID, err := common.GetUserID(c) 271 | if err != nil { 272 | return err 273 | } 274 | 275 | punishments, err := listPunishmentsByUserID(userID) 276 | if err != nil { 277 | return err 278 | } 279 | 280 | return c.JSON(punishments) 281 | } 282 | 283 | // ListPunishmentsByUserID godoc 284 | // @Summary List punishments by user id 285 | // @Tags Penalty 286 | // @Produce json 287 | // @Router /users/{id}/punishments [get] 288 | // @Param id path int true "User ID" 289 | // @Success 200 {array} Punishment 290 | func ListPunishmentsByUserID(c *fiber.Ctx) error { 291 | userID, err := c.ParamsInt("id") 292 | if err != nil { 293 | return err 294 | } 295 | 296 | currentUser, err := GetCurrLoginUser(c) 297 | if err != nil { 298 | return err 299 | } 300 | if !currentUser.IsAdmin && currentUser.ID != userID { 301 | return common.Forbidden() 302 | } 303 | 304 | punishments, err := listPunishmentsByUserID(userID) 305 | if err != nil { 306 | return err 307 | } 308 | 309 | 
return c.JSON(punishments) 310 | } 311 | 312 | func listPunishmentsByUserID(userID int) ([]Punishment, error) { 313 | var punishments []Punishment 314 | err := DB.Where("user_id = ?", userID).Preload("Floor").Find(&punishments).Error 315 | if err != nil { 316 | return nil, err 317 | } 318 | 319 | // remove made_by 320 | for i := range punishments { 321 | punishments[i].MadeBy = 0 322 | } 323 | 324 | return punishments, nil 325 | } 326 | 327 | func RegisterRoutes(app fiber.Router) { 328 | app.Post("/penalty/:id/_forever", BanUserForever) 329 | app.Post("/penalty/:id", BanUser) 330 | app.Get("/users/me/punishments", ListMyPunishments) 331 | app.Get("/users/:id/punishments", ListPunishmentsByUserID) 332 | } 333 | -------------------------------------------------------------------------------- /apis/report/apis.go: -------------------------------------------------------------------------------- 1 | package report 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | . "treehole_next/models" 7 | . "treehole_next/utils" 8 | 9 | "github.com/opentreehole/go-common" 10 | "github.com/rs/zerolog/log" 11 | 12 | "github.com/gofiber/fiber/v2" 13 | "gorm.io/gorm" 14 | ) 15 | 16 | // GetReport 17 | // 18 | // @Summary Get A Report 19 | // @Tags Report 20 | // @Produce application/json 21 | // @Router /reports/{id} [get] 22 | // @Param id path int true "id" 23 | // @Success 200 {object} Report 24 | // @Failure 404 {object} MessageModel 25 | func GetReport(c *fiber.Ctx) error { 26 | // validate query 27 | reportID, err := c.ParamsInt("id") 28 | if err != nil { 29 | return err 30 | } 31 | 32 | // find report 33 | var report Report 34 | result := LoadReportFloor(DB).First(&report, reportID) 35 | if result.Error != nil { 36 | return result.Error 37 | } 38 | return Serialize(c, &report) 39 | } 40 | 41 | // ListReports 42 | // 43 | // @Summary List All Reports 44 | // @Tags Report 45 | // @Produce application/json 46 | // @Router /reports [get] 47 | // @Param object query ListModel false "query" 48 
| // @Success 200 {array} Report 49 | // @Failure 404 {object} MessageModel 50 | func ListReports(c *fiber.Ctx) error { 51 | // validate query 52 | var query ListModel 53 | err := common.ValidateQuery(c, &query) 54 | if err != nil { 55 | return err 56 | } 57 | 58 | // find reports 59 | var reports Reports 60 | 61 | querySet := LoadReportFloor(query.BaseQuery()) 62 | 63 | var result *gorm.DB 64 | switch query.Range { 65 | case RangeNotDealt: 66 | result = querySet.Find(&reports, "dealt = ?", false) 67 | case RangeDealt: 68 | result = querySet.Find(&reports, "dealt = ?", true) 69 | case RangeAll: 70 | result = querySet.Find(&reports) 71 | } 72 | if result.Error != nil { 73 | return result.Error 74 | } 75 | return Serialize(c, &reports) 76 | } 77 | 78 | // AddReport 79 | // 80 | // @Summary Add a report 81 | // @Description Add a report and send notification to admins 82 | // @Tags Report 83 | // @Produce application/json 84 | // @Router /reports [post] 85 | // @Param json body AddModel true "json" 86 | // @Success 204 87 | // 88 | // @Failure 400 {object} common.HttpError 89 | func AddReport(c *fiber.Ctx) error { 90 | // validate body 91 | var body AddModel 92 | err := common.ValidateBody(c, &body) 93 | if err != nil { 94 | return err 95 | } 96 | 97 | user, err := GetCurrLoginUser(c) 98 | if err != nil { 99 | return err 100 | } 101 | 102 | // permission 103 | if user.BanReport != nil { 104 | return common.Forbidden(user.BanReportMessage()) 105 | } 106 | 107 | // add report 108 | report := Report{ 109 | FloorID: body.FloorID, 110 | Reason: body.Reason, 111 | Dealt: false, 112 | } 113 | err = report.Create(c) 114 | if err != nil { 115 | return err 116 | } 117 | 118 | // Send Notification 119 | err = report.SendCreate(DB) 120 | if err != nil { 121 | log.Err(err).Str("model", "Notification").Msg("SendCreate failed: ") 122 | // return err // only for test 123 | } 124 | 125 | return c.Status(204).JSON(nil) 126 | } 127 | 128 | // DeleteReport 129 | // 130 | // @Summary Deal 
a report 131 | // @Description Mark a report as "dealt" and send notification to reporter 132 | // @Tags Report 133 | // @Produce application/json 134 | // @Router /reports/{id} [delete] 135 | // @Param id path int true "id" 136 | // @Param json body DeleteModel true "json" 137 | // @Success 200 {object} Report 138 | // @Failure 400 {object} common.HttpError 139 | func DeleteReport(c *fiber.Ctx) error { 140 | // validate query 141 | reportID, err := c.ParamsInt("id") 142 | if err != nil { 143 | return err 144 | } 145 | 146 | // validate body 147 | var body DeleteModel 148 | err = common.ValidateBody(c, &body) 149 | if err != nil { 150 | return err 151 | } 152 | 153 | // get user id 154 | userID, err := common.GetUserID(c) 155 | if err != nil { 156 | return err 157 | } 158 | 159 | // modify report 160 | var report Report 161 | result := LoadReportFloor(DB).First(&report, reportID) 162 | if result.Error != nil { 163 | return result.Error 164 | } 165 | report.Dealt = true 166 | report.DealtBy = userID 167 | report.Result = body.Result 168 | if err = DB.Omit("Floor").Save(&report).Error; err != nil { return err } // do not log or notify when the update was not persisted 169 | 170 | MyLog("Report", "Delete", reportID, userID, RoleAdmin) 171 | CreateAdminLog(DB, AdminLogTypeDeleteReport, userID, report) 172 | 173 | // Send Notification 174 | err = report.SendModify(DB) 175 | if err != nil { 176 | log.Err(err).Str("model", "Notification").Msg("SendModify failed") 177 | // return err // only for test 178 | } 179 | 180 | return Serialize(c, &report) 181 | } 182 | 183 | type banBody struct { 184 | Days *int `json:"days" validate:"omitempty,min=1"` 185 | Reason string `json:"reason"` // optional 186 | } 187 | 188 | // BanReporter 189 | // 190 | // @Summary Ban reporter of a report 191 | // @Tags Report 192 | // @Produce json 193 | // @Router /reports/ban/{id} [post] 194 | // @Param json body banBody true "json" 195 | // @Success 200 {object} User 196 | func BanReporter(c *fiber.Ctx) error { 197 | // validate body 198 | var body banBody 199 | err := common.ValidateBody(c,
&body) 200 | if err != nil { 201 | return err 202 | } 203 | 204 | reportID, err := c.ParamsInt("id") 205 | if err != nil { 206 | return err 207 | } 208 | 209 | // get user 210 | user, err := GetCurrLoginUser(c) 211 | if err != nil { 212 | return err 213 | } 214 | 215 | // permission 216 | if !user.IsAdmin { 217 | return common.Forbidden() 218 | } 219 | 220 | var report Report 221 | err = DB.Take(&report, reportID).Error 222 | if err != nil { 223 | return err 224 | } 225 | 226 | var days int 227 | if body.Days != nil { 228 | days = *body.Days 229 | if days <= 0 { 230 | days = 1 231 | } 232 | } else { 233 | days = 1 234 | } 235 | 236 | duration := time.Duration(days) * 24 * time.Hour 237 | 238 | reportPunishment := ReportPunishment{ 239 | UserID: report.UserID, 240 | MadeBy: user.ID, 241 | ReportId: report.ID, 242 | Duration: &duration, 243 | Reason: body.Reason, 244 | } 245 | user, err = reportPunishment.Create() 246 | if err != nil { 247 | return err 248 | } 249 | 250 | // construct message for user 251 | message := Notification{ 252 | Data: report, 253 | Recipients: []int{report.UserID}, 254 | Description: fmt.Sprintf( 255 | "您因违反社区公约被禁止举报。时间:%d天,原因:%s\n如有异议,请联系admin@danta.tech。", 256 | days, 257 | body.Reason, 258 | ), 259 | Title: "处罚通知", 260 | Type: MessageTypePermission, 261 | URL: fmt.Sprintf("/api/reports/%d", report.ID), 262 | } 263 | 264 | // send 265 | _, err = message.Send() 266 | if err != nil { 267 | return err 268 | } 269 | 270 | return c.JSON(user) 271 | } 272 | -------------------------------------------------------------------------------- /apis/report/routes.go: -------------------------------------------------------------------------------- 1 | package report 2 | 3 | import "github.com/gofiber/fiber/v2" 4 | 5 | func RegisterRoutes(app fiber.Router) { 6 | app.Get("/reports/:id", GetReport) 7 | app.Get("/reports", ListReports) 8 | app.Post("/reports", AddReport) 9 | app.Delete("/reports/:id", DeleteReport) 10 | 11 | app.Post("/reports/ban/:id", 
BanReporter) 12 | } 13 | -------------------------------------------------------------------------------- /apis/report/schemas.go: -------------------------------------------------------------------------------- 1 | package report 2 | 3 | import ( 4 | "fmt" 5 | 6 | "gorm.io/gorm" 7 | 8 | . "treehole_next/models" 9 | ) 10 | 11 | type Range int 12 | 13 | const ( 14 | RangeNotDealt Range = iota 15 | RangeDealt 16 | RangeAll 17 | ) 18 | 19 | type ListModel struct { 20 | Size int `query:"size" default:"30" validate:"min=0,max=50"` 21 | Offset int `query:"offset" default:"0" validate:"min=0"` 22 | OrderBy string `query:"order_by" default:"id"` 23 | // Sort order, default is desc 24 | Sort string `json:"sort" query:"sort" default:"desc" validate:"oneof=asc desc"` 25 | // Range, 0: not dealt, 1: dealt, 2: all 26 | Range Range `json:"range"` 27 | } 28 | // BaseQuery applies limit, offset and ordering; OrderBy comes from the query string unvalidated, so it is quoted as an identifier instead of being pasted into the SQL 29 | func (q *ListModel) BaseQuery() *gorm.DB { 30 | return DB. 31 | Limit(q.Size). 32 | Offset(q.Offset). 33 | Order(fmt.Sprintf("%s %s", DB.Statement.Quote("report."+q.OrderBy), q.Sort)) 34 | } 35 | 36 | type AddModel struct { 37 | FloorID int `json:"floor_id" validate:"required"` 38 | Reason string `json:"reason" validate:"required,max=128"` 39 | } 40 | 41 | type DeleteModel struct { 42 | // The deal result, send it to reporter 43 | Result string `json:"result" validate:"required,max=128"` 44 | } 45 | -------------------------------------------------------------------------------- /apis/routes.go: -------------------------------------------------------------------------------- 1 | package apis 2 | 3 | import ( 4 | "github.com/opentreehole/go-common" 5 | 6 | "treehole_next/apis/division" 7 | "treehole_next/apis/favourite" 8 | "treehole_next/apis/floor" 9 | "treehole_next/apis/hole" 10 | "treehole_next/apis/message" 11 | "treehole_next/apis/penalty" 12 | "treehole_next/apis/report" 13 | "treehole_next/apis/subscription" 14 | "treehole_next/apis/tag" 15 | "treehole_next/apis/user" 16 | "treehole_next/config" 17 | _ "treehole_next/docs" 18 |
"treehole_next/models" 19 | 20 | "github.com/gofiber/fiber/v2" 21 | fiberSwagger "github.com/swaggo/fiber-swagger" 22 | ) 23 | 24 | func registerRoutes(app *fiber.App) { 25 | app.Get("/", func(c *fiber.Ctx) error { 26 | return c.Redirect("/api") 27 | }) 28 | app.Get("/docs", func(c *fiber.Ctx) error { 29 | return c.Redirect("/docs/index.html") 30 | }) 31 | app.Get("/docs/*", fiberSwagger.WrapHandler) 32 | } 33 | 34 | func RegisterRoutes(app *fiber.App) { 35 | registerRoutes(app) 36 | 37 | group := app.Group("/api") 38 | group.Get("/", Index) 39 | group.Use(MiddlewareGetUser) 40 | division.RegisterRoutes(group) 41 | tag.RegisterRoutes(group) 42 | hole.RegisterRoutes(group) 43 | floor.RegisterRoutes(group) 44 | report.RegisterRoutes(group) 45 | favourite.RegisterRoutes(group) 46 | subscription.RegisterRoutes(group) 47 | penalty.RegisterRoutes(group) 48 | user.RegisterRoutes(group) 49 | message.RegisterRoutes(group) 50 | } 51 | 52 | func MiddlewareGetUser(c *fiber.Ctx) error { 53 | userObject, err := models.GetCurrLoginUser(c) 54 | if err != nil { 55 | return err 56 | } 57 | c.Locals("user", userObject) 58 | if config.Config.AdminOnly { 59 | if !userObject.IsAdmin { 60 | return common.Forbidden() 61 | } 62 | } 63 | return c.Next() 64 | } 65 | -------------------------------------------------------------------------------- /apis/subscription/api.go: -------------------------------------------------------------------------------- 1 | package subscription 2 | 3 | import ( 4 | "github.com/gofiber/fiber/v2" 5 | "github.com/opentreehole/go-common" 6 | "gorm.io/gorm" 7 | "gorm.io/plugin/dbresolver" 8 | 9 | . "treehole_next/models" 10 | . 
"treehole_next/utils" 11 | ) 12 | 13 | // ListSubscriptions 14 | // 15 | // @Summary List User's Subscriptions 16 | // @Tags Subscription 17 | // @Produce application/json 18 | // @Router /users/subscriptions [get] 19 | // @Param object query ListModel false "query" 20 | // @Success 200 {object} models.Map 21 | // @Success 200 {array} models.Hole 22 | func ListSubscriptions(c *fiber.Ctx) error { 23 | // get userID 24 | userID, err := common.GetUserID(c) 25 | if err != nil { 26 | return err 27 | } 28 | 29 | var query ListModel 30 | err = common.ValidateQuery(c, &query) 31 | if err != nil { 32 | return err 33 | } 34 | 35 | if query.Plain { 36 | data, err := UserGetSubscriptionData(DB, userID) 37 | if err != nil { 38 | return err 39 | } 40 | return c.JSON(Map{"data": data}) 41 | } else { 42 | holes := make(Holes, 0) 43 | err := DB. 44 | Joins("JOIN user_subscription ON user_subscription.hole_id = hole.id AND user_subscription.user_id = ?", userID). 45 | Order("user_subscription.created_at desc").Find(&holes).Error 46 | if err != nil { 47 | return err 48 | } 49 | return Serialize(c, &holes) 50 | } 51 | } 52 | 53 | // AddSubscription 54 | // 55 | // @Summary Add A Subscription 56 | // @Tags Subscription 57 | // @Accept application/json 58 | // @Produce application/json 59 | // @Router /users/subscriptions [post] 60 | // @Param json body AddModel true "json" 61 | // @Success 201 {object} Response 62 | func AddSubscription(c *fiber.Ctx) error { 63 | // validate body 64 | var body AddModel 65 | err := common.ValidateBody(c, &body) 66 | if err != nil { 67 | return err 68 | } 69 | 70 | // get userID 71 | userID, err := common.GetUserID(c) 72 | if err != nil { 73 | return err 74 | } 75 | 76 | var data []int 77 | 78 | err = DB.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error { 79 | // add favorites 80 | err = AddUserSubscription(tx, userID, body.HoleID) 81 | if err != nil { 82 | return err 83 | } 84 | 85 | // create response 86 | data, err = 
UserGetSubscriptionData(tx, userID) 87 | return err 88 | }) 89 | if err != nil { 90 | return err 91 | } 92 | 93 | return c.Status(201).JSON(&Response{ 94 | Message: "关注成功", 95 | Data: data, 96 | }) 97 | } 98 | 99 | // DeleteSubscription 100 | // 101 | // @Summary Delete A Subscription 102 | // @Tags Subscription 103 | // @Produce application/json 104 | // @Router /users/subscription [delete] 105 | // @Param json body DeleteModel true "json" 106 | // @Success 200 {object} Response 107 | // @Failure 404 {object} Response 108 | func DeleteSubscription(c *fiber.Ctx) error { 109 | // validate body 110 | var body DeleteModel 111 | err := common.ValidateBody(c, &body) 112 | if err != nil { 113 | return err 114 | } 115 | 116 | // get userID 117 | userID, err := common.GetUserID(c) 118 | if err != nil { 119 | return err 120 | } 121 | 122 | // delete subscriptions 123 | err = DB.Delete(UserSubscription{UserID: userID, HoleID: body.HoleID}).Error 124 | if err != nil { 125 | return err 126 | } 127 | 128 | // create response 129 | data, err := UserGetSubscriptionData(DB, userID) 130 | if err != nil { 131 | return err 132 | } 133 | 134 | return c.JSON(&Response{ 135 | Message: "删除成功", 136 | Data: data, 137 | }) 138 | } 139 | -------------------------------------------------------------------------------- /apis/subscription/routes.go: -------------------------------------------------------------------------------- 1 | package subscription 2 | 3 | import "github.com/gofiber/fiber/v2" 4 | 5 | func RegisterRoutes(app fiber.Router) { 6 | app.Get("/users/subscriptions", ListSubscriptions) 7 | app.Post("/users/subscriptions", AddSubscription) 8 | app.Delete("/users/subscriptions", DeleteSubscription) 9 | app.Delete("/users/subscription", DeleteSubscription) 10 | } 11 | -------------------------------------------------------------------------------- /apis/subscription/schemas.go: -------------------------------------------------------------------------------- 1 | package subscription 2 
| 3 | type Response struct { 4 | Message string `json:"message"` 5 | Data []int `json:"data"` 6 | } 7 | 8 | type ListModel struct { 9 | Plain bool `json:"plain" default:"false" query:"plain"` 10 | } 11 | 12 | type AddModel struct { 13 | HoleID int `json:"hole_id"` 14 | } 15 | 16 | type DeleteModel struct { 17 | HoleID int `json:"hole_id"` 18 | } 19 | -------------------------------------------------------------------------------- /apis/tag/apis.go: -------------------------------------------------------------------------------- 1 | package tag 2 | 3 | import ( 4 | "strings" 5 | "time" 6 | "treehole_next/utils/sensitive" 7 | 8 | "github.com/opentreehole/go-common" 9 | "gorm.io/plugin/dbresolver" 10 | 11 | . "treehole_next/models" 12 | . "treehole_next/utils" 13 | 14 | "github.com/gofiber/fiber/v2" 15 | "gorm.io/gorm" 16 | ) 17 | 18 | // ListTags 19 | // 20 | // @Summary List All Tags 21 | // @Tags Tag 22 | // @Produce application/json 23 | // @Param object query SearchModel false "query" 24 | // @Router /tags [get] 25 | // @Success 200 {array} Tag 26 | func ListTags(c *fiber.Ctx) error { 27 | var query SearchModel 28 | err := common.ValidateQuery(c, &query) 29 | if err != nil { 30 | return err 31 | } 32 | 33 | tags := make(Tags, 0, 10) 34 | if query.Search == "" { 35 | if GetCache("tags", &tags) { 36 | return c.JSON(&tags) 37 | } else { 38 | err = DB.Order("temperature DESC").Find(&tags).Error 39 | if err != nil { 40 | return err 41 | } 42 | go UpdateTagCache(tags) 43 | return Serialize(c, &tags) 44 | } 45 | } 46 | err = DB.Where("name LIKE ?", "%"+query.Search+"%"). 
47 | Order("temperature DESC").Find(&tags).Error 48 | if err != nil { 49 | return err 50 | } 51 | return Serialize(c, &tags) 52 | } 53 | 54 | // GetTag 55 | // 56 | // @Summary Get A Tag 57 | // @Tags Tag 58 | // @Produce application/json 59 | // @Router /tags/{id} [get] 60 | // @Param id path int true "id" 61 | // @Success 200 {object} Tag 62 | // @Failure 404 {object} MessageModel 63 | func GetTag(c *fiber.Ctx) error { 64 | id, err := c.ParamsInt("id") 65 | if err != nil { return err } // a non-numeric id used to become 0 and return the first tag 66 | var tag Tag 67 | result := DB.First(&tag, id) 68 | if result.Error != nil { 69 | return result.Error 70 | } 71 | return Serialize(c, &tag) 72 | } 73 | 74 | // CreateTag 75 | // 76 | // @Summary Create A Tag 77 | // @Tags Tag 78 | // @Produce application/json 79 | // @Router /tags [post] 80 | // @Param json body CreateModel true "json" 81 | // @Success 200 {object} Tag 82 | // @Success 201 {object} Tag 83 | func CreateTag(c *fiber.Ctx) error { 84 | // validate body 85 | var tag Tag 86 | var body CreateModel 87 | err := common.ValidateBody(c, &body) 88 | if err != nil { 89 | return err 90 | } 91 | body.Name = strings.TrimSpace(body.Name) // trim first so the checks below see the name that will be stored 92 | // check tag prefix 93 | user, err := GetCurrLoginUser(c) 94 | if err != nil { 95 | return err 96 | } 97 | if !user.IsAdmin { 98 | if len(body.Name) > 15 && len([]rune(body.Name)) > 10 { 99 | return common.BadRequest("标签长度不能超过 10 个字符") 100 | } 101 | if strings.HasPrefix(body.Name, "#") { 102 | return common.BadRequest("只有管理员才能创建 # 开头的 tag") 103 | } 104 | if strings.HasPrefix(body.Name, "@") { 105 | return common.BadRequest("只有管理员才能创建 @ 开头的 tag") 106 | } 107 | if strings.HasPrefix(body.Name, "*") { 108 | return common.BadRequest("只有管理员才能创建 * 开头的 tag") 109 | } 110 | } 111 | 112 | sensitiveResp, err := sensitive.CheckSensitive(sensitive.ParamsForCheck{ 113 | Content: body.Name, 114 | Id: time.Now().UnixNano(), 115 | TypeName: sensitive.TypeTag, 116 | }) 117 | if err != nil { 118 | return err 119 | } 120 | tag.IsSensitive = !sensitiveResp.Pass 121 | 122 | // bind and create tag 123 | body.Name =
strings.TrimSpace(body.Name) 124 | tag.Name = body.Name 125 | result := DB.Where("name = ?", body.Name).FirstOrCreate(&tag) 126 | if result.Error != nil { return result.Error } 127 | if result.RowsAffected == 0 { 128 | c.Status(200) 129 | } else { 130 | c.Status(201) 131 | } 132 | return Serialize(c, &tag) 133 | } 134 | 135 | // ModifyTag 136 | // 137 | // @Summary Modify A Tag, admin only 138 | // @Tags Tag 139 | // @Produce application/json 140 | // @Router /tags/{id} [put] 141 | // @Router /tags/{id}/_webvpn [patch] 142 | // @Param id path int true "id" 143 | // @Param json body ModifyModel true "json" 144 | // @Success 200 {object} Tag 145 | // @Failure 404 {object} MessageModel 146 | func ModifyTag(c *fiber.Ctx) error { 147 | // admin 148 | user, err := GetCurrLoginUser(c) 149 | if err != nil { 150 | return err 151 | } 152 | if !user.IsAdmin { 153 | return common.Forbidden() 154 | } 155 | 156 | // validate body 157 | var body ModifyModel 158 | err = common.ValidateBody(c, &body) 159 | if err != nil { 160 | return err 161 | } 162 | id, err := c.ParamsInt("id") 163 | if err != nil { 164 | return err 165 | } 166 | 167 | // modify tag 168 | var tag Tag 169 | if err = DB.First(&tag, id).Error; err != nil { return err } // a missing id must fail here; Find left a zero-valued tag that Save would then insert 170 | tag.Name = strings.TrimSpace(body.Name) 171 | tag.Temperature = body.Temperature 172 | 173 | sensitiveResp, err := sensitive.CheckSensitive(sensitive.ParamsForCheck{ 174 | Content: body.Name, 175 | Id: time.Now().UnixNano(), 176 | TypeName: sensitive.TypeTag, 177 | }) 178 | if err != nil { 179 | return err 180 | } 181 | tag.IsSensitive = !sensitiveResp.Pass 182 | 183 | if err = DB.Save(&tag).Error; err != nil { return err } 184 | 185 | // log 186 | userID, err := common.GetUserID(c) 187 | if err != nil { 188 | return err 189 | } 190 | MyLog("Tag", "Modify", tag.ID, userID, RoleAdmin) 191 | CreateAdminLog(DB, AdminLogTypeTag, userID, struct { 192 | TagID int `json:"tag_id"` 193 | Body ModifyModel `json:"body"` 194 | }{ 195 | TagID: tag.ID, 196 | Body: body, 197 | }) 198 | 199 | return Serialize(c, &tag) 200 | } 201 | 202 | // DeleteTag 203 | // 204 | // @Summary Delete A Tag
205 | // @Description Delete a tag and link all of its holes to another given tag 206 | // @Tags Tag 207 | // @Produce application/json 208 | // @Router /tags/{id} [delete] 209 | // @Param id path int true "id" 210 | // @Param json body DeleteModel true "json" 211 | // @Success 200 {object} Tag 212 | // @Failure 404 {object} MessageModel 213 | func DeleteTag(c *fiber.Ctx) error { 214 | // admin 215 | user, err := GetCurrLoginUser(c) 216 | if err != nil { 217 | return err 218 | } 219 | if !user.IsAdmin { 220 | return common.Forbidden() 221 | } 222 | 223 | // validate body 224 | var body DeleteModel 225 | err = common.ValidateBody(c, &body) 226 | if err != nil { 227 | return err 228 | } 229 | 230 | id, err := c.ParamsInt("id") 231 | if err != nil { 232 | return err 233 | } 234 | 235 | var tag Tag 236 | result := DB.First(&tag, id) 237 | if result.Error != nil { 238 | return result.Error 239 | } 240 | 241 | var newTag Tag 242 | result = DB.Where("name = ?", body.To).First(&newTag) 243 | if result.Error != nil { 244 | return result.Error 245 | } 246 | 247 | newTag.Temperature += tag.Temperature 248 | 249 | err = DB.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error { 250 | result = tx.Exec(` 251 | DELETE FROM hole_tags WHERE tag_id = ? AND hole_id IN 252 | (SELECT a.hole_id FROM 253 | (SELECT hole_id FROM hole_tags WHERE tag_id = ?)a 254 | )`, id, newTag.ID) 255 | if result.Error != nil { 256 | return result.Error 257 | } 258 | 259 | result = tx.Exec(`UPDATE hole_tags SET tag_id = ? 
WHERE tag_id = ?`, newTag.ID, id) 260 | if result.Error != nil { 261 | return result.Error 262 | } 263 | 264 | result = tx.Updates(&newTag) 265 | if result.Error != nil { 266 | return result.Error 267 | } 268 | 269 | result = tx.Delete(&tag) 270 | if result.Error != nil { 271 | return result.Error 272 | } 273 | 274 | return nil 275 | }) 276 | if err != nil { 277 | return err 278 | } 279 | 280 | // log 281 | userID, err := common.GetUserID(c) 282 | if err != nil { 283 | return err 284 | } 285 | MyLog("Tag", "Delete", id, userID, RoleAdmin) 286 | return Serialize(c, &newTag) 287 | } 288 | -------------------------------------------------------------------------------- /apis/tag/routes.go: -------------------------------------------------------------------------------- 1 | package tag 2 | 3 | import "github.com/gofiber/fiber/v2" 4 | 5 | func RegisterRoutes(app fiber.Router) { 6 | app.Get("/tags", ListTags) 7 | app.Get("/tags/:id", GetTag) 8 | app.Post("/tags", CreateTag) 9 | app.Put("/tags/:id", ModifyTag) 10 | app.Patch("/tags/:id/_webvpn", ModifyTag) 11 | app.Delete("/tags/:id", DeleteTag) 12 | } 13 | -------------------------------------------------------------------------------- /apis/tag/schemas.go: -------------------------------------------------------------------------------- 1 | package tag 2 | 3 | type CreateModel struct { 4 | Name string `json:"name,omitempty" validate:"max=20"` // Admin only 5 | } 6 | 7 | type ModifyModel struct { 8 | CreateModel 9 | Temperature int `json:"temperature,omitempty"` // Admin only 10 | } 11 | 12 | type DeleteModel struct { 13 | // Admin only 14 | // Name of the target tag that all the deleted tag's holes will be connected to 15 | To string `json:"to,omitempty"` 16 | } 17 | 18 | type SearchModel struct { 19 | Search string `json:"s" query:"s" validate:"max=32"` // search tag by name 20 | } 21 | -------------------------------------------------------------------------------- /apis/user/apis.go: 
-------------------------------------------------------------------------------- 1 | package user 2 | 3 | import ( 4 | "github.com/gofiber/fiber/v2" 5 | "github.com/opentreehole/go-common" 6 | "gorm.io/gorm/clause" 7 | 8 | . "treehole_next/models" 9 | ) 10 | // RegisterRoutes mounts the user endpoints; the static /users/me routes are registered before /users/:id so that "me" is not captured as an id 11 | func RegisterRoutes(app fiber.Router) { 12 | app.Get("/users/me", GetCurrentUser) 13 | app.Get("/users/:id", GetUserByID) 14 | app.Put("/users/me", ModifyCurrentUser) 15 | app.Patch("/users/me/_webvpn", ModifyCurrentUser) 16 | app.Put("/users/:id", ModifyUser) 17 | app.Patch("/users/:id/_webvpn", ModifyUser) 18 | } 19 | 20 | // GetCurrentUser 21 | // 22 | // @Summary get current user 23 | // @Tags user 24 | // @Deprecated 25 | // @Produce json 26 | // @Router /users/me [get] 27 | // @Success 200 {object} User 28 | func GetCurrentUser(c *fiber.Ctx) error { 29 | user, err := GetCurrLoginUser(c) 30 | if err != nil { 31 | return err 32 | } 33 | return c.JSON(&user) 34 | } 35 | 36 | // GetUserByID 37 | // 38 | // @Summary get user by id, owner or admin 39 | // @Tags user 40 | // @Produce json 41 | // @Router /users/{user_id} [get] 42 | // @Success 200 {object} User 43 | func GetUserByID(c *fiber.Ctx) error { 44 | userID, err := c.ParamsInt("id") 45 | if err != nil { 46 | return err 47 | } 48 | 49 | user, err := GetCurrLoginUser(c) 50 | if err != nil { 51 | return err 52 | } 53 | 54 | if !user.IsAdmin && user.ID != userID { // only the owner or an admin may read a user, same rule as ModifyUser 55 | return common.Forbidden() 56 | } 57 | 58 | var getUser User 59 | err = getUser.LoadUserByID(userID) 60 | if err != nil { 61 | return err 62 | } 63 | 64 | return c.JSON(&getUser) 65 | } 66 | 67 | // ModifyUser 68 | // 69 | // @Summary modify user profiles 70 | // @Tags User 71 | // @Produce json 72 | // @Router /users/{user_id} [put] 73 | // @Router /users/{user_id}/_webvpn [patch] 74 | // @Param user_id path int true "user id" 75 | // @Param json body ModifyModel true "modify user" 76 | // @Success 200 {object} User 77 | func ModifyUser(c *fiber.Ctx) error { 78 | userID, err :=
c.ParamsInt("id") 79 | if err != nil { 80 | return err 81 | } 82 | 83 | user, err := GetCurrLoginUser(c) 84 | if err != nil { 85 | return err 86 | } 87 | 88 | if !user.IsAdmin && user.ID != userID { 89 | return common.Forbidden() 90 | } 91 | 92 | var body ModifyModel 93 | err = common.ValidateBody(c, &body) 94 | if err != nil { 95 | return err 96 | } 97 | 98 | // cannot get field "has_answered_questions" when admin changes other user's config 99 | if user.ID != userID { 100 | user = &User{ 101 | ID: userID, 102 | } 103 | err = DB.Take(user).Error 104 | if err != nil { 105 | return err 106 | } 107 | } 108 | 109 | err = modifyUser(c, user, body) 110 | if err != nil { 111 | return err 112 | } 113 | 114 | return c.JSON(user) 115 | } 116 | 117 | // ModifyCurrentUser 118 | // 119 | // @Summary modify current user profiles 120 | // @Tags User 121 | // @Produce json 122 | // @Router /users/me [put] 123 | // @Router /users/me/_webvpn [patch] 124 | // @Param user_id path int true "user id" 125 | // @Param json body ModifyModel true "modify user" 126 | // @Success 200 {object} User 127 | func ModifyCurrentUser(c *fiber.Ctx) error { 128 | user, err := GetCurrLoginUser(c) 129 | if err != nil { 130 | return err 131 | } 132 | 133 | var body ModifyModel 134 | err = common.ValidateBody(c, &body) 135 | if err != nil { 136 | return err 137 | } 138 | 139 | err = modifyUser(c, user, body) 140 | if err != nil { 141 | return err 142 | } 143 | 144 | return c.JSON(&user) 145 | } 146 | 147 | func modifyUser(_ *fiber.Ctx, user *User, body ModifyModel) error { 148 | var newUser User 149 | err := DB.Select("config").First(&newUser, user.ID).Error 150 | if err != nil { 151 | return err 152 | } 153 | 154 | if body.Config != nil { 155 | if body.Config.Notify != nil { 156 | newUser.Config.Notify = body.Config.Notify 157 | } 158 | if body.Config.ShowFolded != nil { 159 | newUser.Config.ShowFolded = *body.Config.ShowFolded 160 | } 161 | } 162 | 163 | err = 
DB.Model(&user).Omit(clause.Associations).Select("Config").UpdateColumns(&newUser).Error 164 | if err != nil { 165 | return err 166 | } 167 | 168 | user.Config = newUser.Config 169 | return nil 170 | } 171 | -------------------------------------------------------------------------------- /apis/user/schemas.go: -------------------------------------------------------------------------------- 1 | package user 2 | 3 | type ModifyModel struct { 4 | Nickname *string `json:"nickname" validate:"omitempty,min=1"` 5 | Config *UserConfigModel `json:"config"` 6 | } 7 | 8 | type UserConfigModel struct { 9 | Notify []string `json:"notify"` 10 | ShowFolded *string `json:"show_folded"` 11 | } 12 | -------------------------------------------------------------------------------- /benchmarks/floor_test.go: -------------------------------------------------------------------------------- 1 | package benchmarks 2 | 3 | import ( 4 | "math/rand" 5 | "strconv" 6 | "testing" 7 | 8 | . "treehole_next/models" 9 | ) 10 | 11 | func BenchmarkListFloorsInAHole(b *testing.B) { 12 | for i := 0; i < b.N; i++ { 13 | b.StopTimer() 14 | route := "/api/holes/" + strconv.Itoa(rand.Intn(HOLE_MAX)+1) + "/floors/" 15 | b.StartTimer() 16 | 17 | benchmarkCommon(b, "get", route, REQUEST_BODY) 18 | } 19 | } 20 | 21 | func BenchmarkGetFloor(b *testing.B) { 22 | for i := 0; i < b.N; i++ { 23 | b.StopTimer() 24 | route := "/api/floors/" + strconv.Itoa(rand.Intn(FLOOR_MAX)+1) + "/" 25 | b.StartTimer() 26 | 27 | benchmarkCommon(b, "get", route, REQUEST_BODY) 28 | } 29 | } 30 | 31 | func BenchmarkCreateFloor(b *testing.B) { 32 | for i := 0; i < b.N; i++ { 33 | b.StopTimer() 34 | route := "/api/holes/" + strconv.Itoa(rand.Intn(HOLE_MAX)+1) + "/floors/" 35 | data := Map{ 36 | "content": strconv.Itoa(rand.Int()), 37 | "reply_to": 0, 38 | } 39 | b.StartTimer() 40 | 41 | benchmarkCommon(b, "post", route, REQUEST_BODY, data) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- 
/benchmarks/hole_test.go: -------------------------------------------------------------------------------- 1 | package benchmarks 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "strconv" 7 | "testing" 8 | 9 | . "treehole_next/models" 10 | _ "treehole_next/tests" 11 | ) 12 | 13 | func BenchmarkListHoles(b *testing.B) { 14 | for i := 0; i < b.N; i++ { 15 | // prepare 16 | b.StopTimer() 17 | route := "/api/divisions/" + strconv.Itoa(rand.Intn(DIVISION_MAX)+1) + "/holes/" 18 | b.StartTimer() 19 | 20 | benchmarkCommon(b, "get", route, REQUEST_BODY) 21 | } 22 | } 23 | 24 | func BenchmarkCreateHoles(b *testing.B) { 25 | for i := 0; i < b.N; i++ { 26 | // prepare 27 | b.StopTimer() 28 | route := "/api/divisions/" + strconv.Itoa(rand.Intn(DIVISION_MAX)+1) + "/holes/" 29 | data := Map{ 30 | "content": fmt.Sprintf("%v", rand.Uint64()), 31 | "tag": []Map{ 32 | {"name": "123"}, 33 | {"name": "456"}, 34 | }, 35 | } 36 | b.StartTimer() 37 | 38 | benchmarkCommon(b, "post", route, REQUEST_BODY, data) 39 | } 40 | } 41 | 42 | func BenchmarkGetHole(b *testing.B) { 43 | for i := 0; i < b.N; i++ { 44 | // prepare 45 | b.StopTimer() 46 | holeID := rand.Intn(HOLE_MAX) + 1 47 | url := "/api/holes/" + strconv.Itoa(holeID) + "/" 48 | b.StartTimer() 49 | 50 | benchmarkCommon(b, "get", url, REQUEST_BODY) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /benchmarks/init.go: -------------------------------------------------------------------------------- 1 | package benchmarks 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "strings" 7 | 8 | "github.com/rs/zerolog/log" 9 | "gorm.io/gorm/logger" 10 | 11 | . 
"treehole_next/models" 12 | "treehole_next/utils" 13 | ) 14 | 15 | const ( 16 | DIVISION_MAX = 10 17 | TAG_MAX = 100 18 | HOLE_MAX = 100 19 | FLOOR_MAX = 1000 20 | ) 21 | 22 | func init() { 23 | DB.Logger = logger.Default.LogMode(logger.Silent) 24 | 25 | divisions := make(Divisions, 0, DIVISION_MAX) 26 | tags := make(Tags, 0, TAG_MAX) 27 | holes := make(Holes, 0, HOLE_MAX) 28 | floors := make(Floors, 0, FLOOR_MAX) 29 | 30 | for i := 0; i < DIVISION_MAX; i++ { 31 | divisions = append(divisions, &Division{ 32 | ID: i + 1, 33 | Name: strings.Repeat("d", i+1), 34 | Description: strings.Repeat("dd", i+1), 35 | }) 36 | } 37 | 38 | for i := 0; i < TAG_MAX; i++ { 39 | content := fmt.Sprintf("%v", rand.Uint64()) 40 | tags = append(tags, &Tag{ 41 | ID: i + 1, 42 | Name: content, 43 | }) 44 | } 45 | 46 | for i := 0; i < HOLE_MAX; i++ { 47 | generateTag := func() Tags { 48 | nowTags := make(Tags, rand.Intn(10)) 49 | for i := range nowTags { 50 | nowTags[i] = tags[rand.Intn(TAG_MAX)] 51 | } 52 | return nowTags 53 | } 54 | holes = append(holes, &Hole{ 55 | ID: i + 1, 56 | UserID: 1, 57 | DivisionID: rand.Intn(DIVISION_MAX) + 1, 58 | Tags: generateTag(), 59 | }) 60 | } 61 | 62 | for i := 0; i < FLOOR_MAX; i++ { 63 | content := fmt.Sprintf("%v", rand.Uint64()) 64 | generateMention := func() Floors { 65 | floorMentions := make(Floors, 0, rand.Intn(10)) 66 | for j := range floorMentions { 67 | floorMentions[j] = &Floor{ID: rand.Intn(FLOOR_MAX) + 1} 68 | } 69 | return floorMentions 70 | } 71 | floors = append(floors, &Floor{ 72 | ID: i + 1, 73 | Content: strings.Repeat(content, rand.Intn(2)), 74 | Anonyname: utils.GenerateName([]string{}), 75 | HoleID: rand.Intn(HOLE_MAX) + 1, 76 | Mention: generateMention(), 77 | }) 78 | holes[floors[i].HoleID-1].Reply += 1 79 | } 80 | 81 | var err error 82 | err = DB.Create(divisions).Error 83 | if err != nil { 84 | log.Fatal().Err(err).Send() 85 | } 86 | err = DB.Create(tags).Error 87 | if err != nil { 88 | log.Fatal().Err(err).Send() 89 | } 90 | 
err = DB.Create(holes).Error 91 | if err != nil { 92 | log.Fatal().Err(err).Send() 93 | } 94 | err = DB.Create(floors).Error 95 | if err != nil { 96 | log.Fatal().Err(err).Send() 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /benchmarks/utils.go: -------------------------------------------------------------------------------- 1 | package benchmarks 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "net/http" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/goccy/go-json" 11 | "github.com/hetiansu5/urlquery" 12 | "github.com/stretchr/testify/assert" 13 | 14 | "treehole_next/bootstrap" 15 | . "treehole_next/models" 16 | ) 17 | 18 | var App, _ = bootstrap.Init() 19 | 20 | var _ Map 21 | 22 | const ( 23 | REQUEST_BODY = iota 24 | REQUEST_QUERY 25 | ) 26 | 27 | func benchmarkCommon(b *testing.B, method string, route string, requestType int, data ...Map) []byte { 28 | var requestData []byte 29 | var err error 30 | var req *http.Request 31 | 32 | b.StopTimer() 33 | switch requestType { 34 | case REQUEST_BODY: 35 | if len(data) > 0 && data[0] != nil { // data[0] is request data 36 | requestData, err = json.Marshal(data[0]) 37 | assert.Nilf(b, err, "encode request body") 38 | } 39 | req, err = http.NewRequest( 40 | strings.ToUpper(method), 41 | route, 42 | bytes.NewBuffer(requestData), 43 | ) 44 | case REQUEST_QUERY: 45 | req, err = http.NewRequest( 46 | strings.ToUpper(method), 47 | route, 48 | nil, 49 | ) 50 | if len(data) > 0 && data[0] != nil { // data[0] is query data 51 | queryData, err := urlquery.Marshal(data[0]) 52 | req.URL.RawQuery = string(queryData) 53 | assert.Nilf(b, err, "encode request body") 54 | } 55 | } 56 | 57 | req.Header.Add("Content-Type", "application/json") 58 | assert.Nilf(b, err, "constructs http request") 59 | 60 | b.StartTimer() 61 | res, err := App.Test(req, -1) 62 | b.StopTimer() 63 | assert.Nilf(b, err, "perform request") 64 | 65 | responseBody, err := io.ReadAll(res.Body) 66 | assert.Nilf(b, 
err, "decode response") 67 | 68 | if res.StatusCode != 200 && res.StatusCode != 201 { 69 | assert.Fail(b, string(responseBody)) 70 | } 71 | return responseBody 72 | } 73 | -------------------------------------------------------------------------------- /bootstrap/init.go: -------------------------------------------------------------------------------- 1 | package bootstrap 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/gofiber/fiber/v2" 7 | "github.com/opentreehole/go-common" 8 | 9 | "treehole_next/apis" 10 | "treehole_next/apis/hole" 11 | "treehole_next/apis/message" 12 | "treehole_next/config" 13 | "treehole_next/models" 14 | "treehole_next/utils" 15 | "treehole_next/utils/sensitive" 16 | 17 | "github.com/goccy/go-json" 18 | "github.com/gofiber/fiber/v2/middleware/pprof" 19 | "github.com/gofiber/fiber/v2/middleware/recover" 20 | ) 21 | 22 | func Init() (*fiber.App, context.CancelFunc) { 23 | config.InitConfig() 24 | utils.InitCache() 25 | sensitive.InitSensitiveLabelMap() 26 | models.Init() 27 | models.InitDB() 28 | models.InitAdminList() 29 | 30 | app := fiber.New(fiber.Config{ 31 | ErrorHandler: common.ErrorHandler, 32 | JSONEncoder: json.Marshal, 33 | JSONDecoder: json.Unmarshal, 34 | DisableStartupMessage: true, 35 | }) 36 | registerMiddlewares(app) 37 | apis.RegisterRoutes(app) 38 | 39 | return app, startTasks() 40 | } 41 | 42 | func registerMiddlewares(app *fiber.App) { 43 | app.Use(recover.New(recover.Config{EnableStackTrace: true})) 44 | app.Use(common.MiddlewareGetUserID) 45 | if config.Config.Mode != "bench" { 46 | app.Use(common.MiddlewareCustomLogger) 47 | } 48 | app.Use(pprof.New()) 49 | } 50 | 51 | func startTasks() context.CancelFunc { 52 | ctx, cancel := context.WithCancel(context.Background()) 53 | go hole.UpdateHoleViews(ctx) 54 | go hole.PurgeHole(ctx) 55 | go message.PurgeMessage() 56 | // go models.UpdateAdminList(ctx) 57 | go sensitive.UpdateSensitiveLabelMap(ctx) 58 | return cancel 59 | } 60 | 
-------------------------------------------------------------------------------- /config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/caarlos0/env/v9" 5 | "net/url" 6 | "sync/atomic" 7 | 8 | "github.com/rs/zerolog/log" 9 | ) 10 | 11 | var Config struct { 12 | Mode string `env:"MODE" envDefault:"dev"` 13 | TZ string `env:"TZ" envDefault:"Asia/Shanghai"` 14 | Size int `env:"SIZE" envDefault:"30"` 15 | MaxSize int `env:"MAX_SIZE" envDefault:"50"` 16 | TagSize int `env:"TAG_SIZE" envDefault:"5"` 17 | HoleFloorSize int `env:"HOLE_FLOOR_SIZE" envDefault:"10"` 18 | Debug bool `env:"DEBUG" envDefault:"false"` 19 | // example: user:pass@tcp(127.0.0.1:3306)/dbname?parseTime=true&loc=Asia%2fShanghai 20 | // set time_zone in url, otherwise UTC 21 | // for more detail, see https://github.com/go-sql-driver/mysql#dsn-data-source-name 22 | DbURL string `env:"DB_URL"` 23 | // example: MYSQL_REPLICA_URL="db1_dsn,db2_dsn", use ',' as separator 24 | // should also set time_zone in url 25 | MysqlReplicaURLs []string `env:"MYSQL_REPLICA_URL"` 26 | RedisURL string `env:"REDIS_URL"` // redis:6379 27 | NotificationUrl string `env:"NOTIFICATION_URL"` 28 | MessagePurgeDays int `envDefault:"7" env:"MESSAGE_PURGE_DAYS"` 29 | AuthUrl string `env:"AUTH_URL"` 30 | ElasticsearchUrl string `env:"ELASTICSEARCH_URL"` 31 | OpenSearch bool `env:"OPEN_SEARCH" envDefault:"true"` 32 | OpenFuzzName bool `env:"OPEN_FUZZ_NAME" envDefault:"false"` 33 | UserAllShowHidden bool `env:"USER_ALL_HIDDEN" envDefault:"false"` 34 | AdminOnly bool `env:"ADMIN_ONLY" envDefault:"false"` 35 | HolePurgeDivisions []int `env:"HOLE_PURGE_DIVISIONS" envDefault:"2"` 36 | HolePurgeDays int `env:"HOLE_PURGE_DAYS" envDefault:"30"` 37 | OpenSensitiveCheck bool `env:"OPEN_SENSITIVE_CHECK" envDefault:"true"` 38 | 39 | YiDunBusinessIdText string `env:"YI_DUN_BUSINESS_ID_TEXT" envDefault:""` 40 | YiDunBusinessIdImage string 
`env:"YI_DUN_BUSINESS_ID_IMAGE" envDefault:""` 41 | YiDunSecretId string `env:"YI_DUN_SECRET_ID" envDefault:""` 42 | YiDunSecretKey string `env:"YI_DUN_SECRET_KEY" envDefault:""` 43 | YiDunAccessKeyId string `env:"YI_DUN_ACCESS_KEY_ID" envDefault:""` 44 | YiDunAccessKeySecret string `env:"YI_DUN_ACCESS_KEY_SECRET" envDefault:""` 45 | ValidImageUrl []string `env:"VALID_IMAGE_URL"` 46 | UrlHostnameWhitelist []string `env:"URL_HOSTNAME_WHITELIST"` 47 | ExternalImageHost string `env:"EXTERNAL_IMAGE_HOSTNAME" envDefault:""` 48 | NotifiableAdminIds []int `env:"NOTIFIABLE_ADMIN_IDS"` 49 | ExcludeBanForeverDivisionIds []int `env:"EXCLUDE_BAN_FOREVER_DIVISION_IDS"` 50 | ProxyUrl *url.URL `env:"PROXY_URL"` 51 | QQBotPhysicsGroupID *int64 `env:"PHYSICS_GROUP_ID"` 52 | QQBotCodingGroupID *int64 `env:"CODING_GROUP_ID"` 53 | QQBotUserID *int64 `env:"USER_ID"` 54 | QQBotUrl *string `env:"QQ_BOT_URL"` 55 | FeishuBotUrl *string `env:"FEISHU_BOT_URL"` 56 | AdminOnlyTagIds []int `env:"ADMIN_ONLY_TAG_IDS"` 57 | } 58 | 59 | var DynamicConfig struct { 60 | OpenSearch atomic.Bool 61 | } 62 | 63 | func InitConfig() { // load config from environment variables 64 | if err := env.Parse(&Config); err != nil { 65 | log.Fatal().Err(err).Send() 66 | } 67 | log.Info().Any("config", Config).Msg("init config") 68 | DynamicConfig.OpenSearch.Store(Config.OpenSearch) 69 | } 70 | -------------------------------------------------------------------------------- /data/data.go: -------------------------------------------------------------------------------- 1 | package data 2 | 3 | import ( 4 | _ "embed" 5 | "os" 6 | 7 | "github.com/goccy/go-json" 8 | "github.com/rs/zerolog/log" 9 | ) 10 | 11 | //go:embed names.json 12 | var NamesFile []byte 13 | 14 | //go:embed meta.json 15 | var MetaFile []byte 16 | 17 | var NamesMapping map[string]string 18 | 19 | func init() { 20 | err := initNamesMapping() 21 | if err != nil { 22 | log.Err(err).Msg("could not init names mapping") 23 | } 24 | } 25 | 26 | func 
initNamesMapping() error { 27 | NamesMappingData, err := os.ReadFile(`data/names_mapping.json`) 28 | if err != nil { 29 | return err 30 | } 31 | 32 | return json.Unmarshal(NamesMappingData, &NamesMapping) 33 | } 34 | -------------------------------------------------------------------------------- /data/meta.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Open Tree Hole", 3 | "description": "Next Generation of OpenTreeHole Implemented In Go ---- An Anonymous BBS", 4 | "version": "2.0.0", 5 | "homepage": "https://github.com/opentreehole", 6 | "repository": "https://github.com/OpenTreeHole/treehole_next", 7 | "author": "hasbai", 8 | "maintainer": "jingyijun", 9 | "email": "dev@danta.tech", 10 | "license": "Apache-2.0" 11 | } -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module treehole_next 2 | 3 | go 1.22 4 | 5 | require ( 6 | github.com/caarlos0/env/v9 v9.0.0 7 | github.com/eko/gocache/lib/v4 v4.1.6 8 | github.com/eko/gocache/store/go_cache/v4 v4.2.2 9 | github.com/eko/gocache/store/redis/v4 v4.2.2 10 | github.com/elastic/go-elasticsearch/v8 v8.14.0 11 | github.com/goccy/go-json v0.10.3 12 | github.com/gofiber/fiber/v2 v2.52.5 13 | github.com/hetiansu5/urlquery v1.2.7 14 | github.com/opentreehole/go-common v0.1.7 15 | github.com/patrickmn/go-cache v2.1.0+incompatible 16 | github.com/redis/go-redis/v9 v9.6.1 17 | github.com/rs/zerolog v1.33.0 18 | github.com/stretchr/testify v1.9.0 19 | github.com/swaggo/fiber-swagger v1.3.0 20 | github.com/swaggo/swag v1.16.3 21 | github.com/yidun/yidun-golang-sdk v1.0.14 22 | golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 23 | gorm.io/driver/mysql v1.5.7 24 | gorm.io/driver/sqlite v1.5.6 25 | gorm.io/gorm v1.25.11 26 | gorm.io/plugin/dbresolver v1.5.2 27 | mvdan.cc/xurls/v2 v2.5.0 28 | ) 29 | 30 | require ( 31 | filippo.io/edwards25519 v1.1.0 
// indirect 32 | github.com/KyleBanks/depth v1.2.1 // indirect 33 | github.com/andybalholm/brotli v1.1.0 // indirect 34 | github.com/beorn7/perks v1.0.1 // indirect 35 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 36 | github.com/creasty/defaults v1.7.0 // indirect 37 | github.com/davecgh/go-spew v1.1.1 // indirect 38 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 39 | github.com/elastic/elastic-transport-go/v8 v8.6.0 // indirect 40 | github.com/gabriel-vasile/mimetype v1.4.5 // indirect 41 | github.com/go-logr/logr v1.4.2 // indirect 42 | github.com/go-logr/stdr v1.2.2 // indirect 43 | github.com/go-openapi/jsonpointer v0.21.0 // indirect 44 | github.com/go-openapi/jsonreference v0.21.0 // indirect 45 | github.com/go-openapi/spec v0.21.0 // indirect 46 | github.com/go-openapi/swag v0.23.0 // indirect 47 | github.com/go-playground/locales v0.14.1 // indirect 48 | github.com/go-playground/universal-translator v0.18.1 // indirect 49 | github.com/go-playground/validator/v10 v10.22.0 // indirect 50 | github.com/go-sql-driver/mysql v1.8.1 // indirect 51 | github.com/golang/mock v1.6.0 // indirect 52 | github.com/google/uuid v1.6.0 // indirect 53 | github.com/jinzhu/inflection v1.0.0 // indirect 54 | github.com/jinzhu/now v1.1.5 // indirect 55 | github.com/josharian/intern v1.0.0 // indirect 56 | github.com/klauspost/compress v1.17.9 // indirect 57 | github.com/leodido/go-urn v1.4.0 // indirect 58 | github.com/mailru/easyjson v0.7.7 // indirect 59 | github.com/mattn/go-colorable v0.1.13 // indirect 60 | github.com/mattn/go-isatty v0.0.20 // indirect 61 | github.com/mattn/go-runewidth v0.0.16 // indirect 62 | github.com/mattn/go-sqlite3 v1.14.22 // indirect 63 | github.com/pmezard/go-difflib v1.0.0 // indirect 64 | github.com/prometheus/client_golang v1.19.0 // indirect 65 | github.com/prometheus/client_model v0.6.0 // indirect 66 | github.com/prometheus/common v0.51.1 // indirect 67 | github.com/prometheus/procfs v0.13.0 // indirect 
68 | github.com/rivo/uniseg v0.4.7 // indirect 69 | github.com/swaggo/files v1.0.1 // indirect 70 | github.com/tjfoc/gmsm v1.4.1 // indirect 71 | github.com/valyala/bytebufferpool v1.0.0 // indirect 72 | github.com/valyala/fasthttp v1.55.0 // indirect 73 | github.com/valyala/tcplisten v1.0.0 // indirect 74 | go.opentelemetry.io/otel v1.28.0 // indirect 75 | go.opentelemetry.io/otel/metric v1.28.0 // indirect 76 | go.opentelemetry.io/otel/trace v1.28.0 // indirect 77 | golang.org/x/crypto v0.26.0 // indirect 78 | golang.org/x/net v0.28.0 // indirect 79 | golang.org/x/sync v0.8.0 // indirect 80 | golang.org/x/sys v0.23.0 // indirect 81 | golang.org/x/text v0.17.0 // indirect 82 | golang.org/x/tools v0.24.0 // indirect 83 | google.golang.org/protobuf v1.33.0 // indirect 84 | gopkg.in/yaml.v3 v3.0.1 // indirect 85 | ) 86 | 87 | replace github.com/yidun/yidun-golang-sdk => github.com/jingyijun/yidun-golang-sdk v1.0.13-0.20240709102803-aaae270a5671 88 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | "os/signal" 6 | "syscall" 7 | 8 | "github.com/rs/zerolog/log" 9 | 10 | "treehole_next/bootstrap" 11 | ) 12 | 13 | // @title Open Tree Hole 14 | // @version 2.1.0 15 | // @description An Anonymous BBS \n Note: PUT methods are used to PARTLY update, and we don't use PATCH method. 
16 | 17 | // @contact.name Maintainer Ke Chen 18 | // @contact.email dev@danta.tech 19 | 20 | // @license.name Apache 2.0 21 | // @license.url https://www.apache.org/licenses/LICENSE-2.0.html 22 | 23 | // @host 24 | // @BasePath /api 25 | 26 | func main() { 27 | app, cancel := bootstrap.Init() 28 | go func() { 29 | err := app.Listen("0.0.0.0:8000") 30 | if err != nil { 31 | log.Fatal().Err(err).Msg("app listen failed") 32 | } 33 | }() 34 | 35 | interrupt := make(chan os.Signal, 1) 36 | 37 | // wait for CTRL-C interrupt 38 | signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM) 39 | <-interrupt 40 | 41 | // close app 42 | err := app.Shutdown() 43 | if err != nil { 44 | log.Err(err).Msg("error shutdown app") 45 | } 46 | // stop tasks 47 | cancel() 48 | } 49 | -------------------------------------------------------------------------------- /models/admin_log.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "github.com/rs/zerolog/log" 5 | "gorm.io/gorm" 6 | "time" 7 | ) 8 | 9 | type AdminLog struct { 10 | ID int `gorm:"primaryKey"` 11 | CreatedAt time.Time 12 | Type AdminLogType `gorm:"size:16;not null"` 13 | UserID int `gorm:"not null"` 14 | Data any `gorm:"serializer:json"` 15 | } 16 | 17 | type AdminLogType string 18 | 19 | const ( 20 | AdminLogTypeHole AdminLogType = "edit_hole" 21 | AdminLogTypeHideHole AdminLogType = "hide_hole" 22 | AdminLogTypeTag AdminLogType = "edit_tag" 23 | AdminLogTypeDivision AdminLogType = "edit_division" 24 | AdminLogTypeMessage AdminLogType = "send_message" 25 | AdminLogTypeDeleteReport AdminLogType = "delete_report" 26 | AdminLogTypeChangeSensitive AdminLogType = "change_sensitive" 27 | ) 28 | 29 | // CreateAdminLog 30 | // save admin edit log for audit purpose 31 | func CreateAdminLog(tx *gorm.DB, logType AdminLogType, userID int, data any) { 32 | adminLog := AdminLog{ 33 | Type: logType, 34 | UserID: userID, 35 | Data: data, 36 | } 37 | err := 
tx.Create(&adminLog).Error // omit error 38 | if err != nil { 39 | log.Error().Err(err).Msg("failed to create admin log") 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /models/anonyname.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "errors" 5 | 6 | "gorm.io/gorm" 7 | "gorm.io/gorm/clause" 8 | 9 | "treehole_next/utils" 10 | ) 11 | 12 | type AnonynameMapping struct { 13 | HoleID int `json:"hole_id" gorm:"primaryKey"` 14 | UserID int `json:"user_id" gorm:"primaryKey"` 15 | Anonyname string `json:"anonyname" gorm:"size:32"` 16 | } 17 | 18 | func NewAnonyname(tx *gorm.DB, holeID, userID int) (string, error) { 19 | name := utils.NewRandName() 20 | return name, tx.Create(&AnonynameMapping{ 21 | HoleID: holeID, 22 | UserID: userID, 23 | Anonyname: name, 24 | }).Error 25 | } 26 | 27 | func FindOrGenerateAnonyname(tx *gorm.DB, holeID, userID int) (string, error) { 28 | var anonyname string 29 | err := tx. 30 | Model(&AnonynameMapping{}). 31 | Select("anonyname"). 32 | Where("hole_id = ?", holeID). 33 | Where("user_id = ?", userID). 34 | Take(&anonyname).Error 35 | 36 | if err != nil { 37 | if errors.Is(err, gorm.ErrRecordNotFound) { 38 | var names []string 39 | err = tx. 40 | Clauses(clause.Locking{Strength: "UPDATE"}). 41 | Model(&AnonynameMapping{}). 42 | Select("anonyname"). 43 | Where("hole_id = ?", holeID). 44 | Order("anonyname"). 
45 | Scan(&names).Error 46 | if err != nil { 47 | return "", err 48 | } 49 | 50 | anonyname = utils.GenerateName(names) 51 | err = tx.Create(&AnonynameMapping{ 52 | HoleID: holeID, 53 | UserID: userID, 54 | Anonyname: anonyname, 55 | }).Error 56 | if err != nil { 57 | return anonyname, err 58 | } 59 | } else { 60 | return "", err 61 | } 62 | } 63 | return anonyname, nil 64 | } 65 | -------------------------------------------------------------------------------- /models/base.go: -------------------------------------------------------------------------------- 1 | // Package models contains database models 2 | package models 3 | 4 | type Map = map[string]interface{} 5 | 6 | type Models interface { 7 | Division | Hole | Floor | Tag | User | Report | Message | 8 | Divisions | Holes | Floors | Tags | Users | Reports | Messages 9 | } 10 | 11 | type MessageModel struct { 12 | Message string `json:"message"` 13 | } 14 | -------------------------------------------------------------------------------- /models/division.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "time" 5 | 6 | "treehole_next/utils" 7 | 8 | "github.com/gofiber/fiber/v2" 9 | "gorm.io/gorm" 10 | ) 11 | 12 | type Division struct { 13 | /// saved fields 14 | ID int `json:"id" gorm:"primaryKey"` 15 | CreatedAt time.Time `json:"time_created" gorm:"not null"` 16 | UpdatedAt time.Time `json:"time_updated" gorm:"not null"` 17 | 18 | /// base info 19 | Name string `json:"name" gorm:"unique;size:10"` 20 | Description string `json:"description" gorm:"size:64"` 21 | Hidden bool `json:"hidden" gorm:"not null;default:false"` 22 | 23 | // pinned holes in given order 24 | Pinned []int `json:"-" gorm:"serializer:json;size:100;not null;default:\"[]\""` 25 | 26 | /// association fields, should add foreign key 27 | 28 | // return pinned hole to frontend 29 | Holes Holes `json:"pinned"` 30 | 31 | /// generated field 32 | DivisionID int 
`json:"division_id" gorm:"-:all"` 33 | } 34 | 35 | func (division *Division) GetID() int { 36 | return division.ID 37 | } 38 | 39 | type Divisions []*Division 40 | 41 | func (divisions Divisions) Preprocess(c *fiber.Ctx) error { 42 | for _, division := range divisions { 43 | err := division.Preprocess(c) 44 | if err != nil { 45 | return err 46 | } 47 | } 48 | return utils.SetCache("divisions", divisions, 0) 49 | } 50 | 51 | // Preprocess loads the holes listed in division.Pinned, passes them through utils.OrderInGivenOrder and preprocesses them. 52 | func (division *Division) Preprocess(c *fiber.Ctx) error { 53 | var pinned = division.Pinned 54 | division.Holes = make(Holes, 0, 10) 55 | if len(pinned) == 0 { 56 | return nil 57 | } 58 | // surface query failures instead of silently returning a division with no pinned holes 59 | if err := DB.Find(&division.Holes, pinned).Error; err != nil { 60 | return err 61 | } 62 | if len(division.Holes) == 0 { 63 | return nil 64 | } 65 | division.Holes = utils.OrderInGivenOrder(division.Holes, pinned) 66 | // division.Holes = division.Holes.RemoveIf(func(hole *Hole) bool { 67 | // return hole.Hidden 68 | // }) 69 | return division.Holes.Preprocess(c) 70 | } 71 | 72 | func (division *Division) AfterFind(_ *gorm.DB) (err error) { 73 | division.DivisionID = division.ID 74 | return nil 75 | } 76 | 77 | func (division *Division) AfterCreate(_ *gorm.DB) (err error) { 78 | division.DivisionID = division.ID 79 | return nil 80 | } 81 | -------------------------------------------------------------------------------- /models/favorite_group.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "errors" 5 | "github.com/opentreehole/go-common" 6 | "gorm.io/gorm" 7 | "gorm.io/plugin/dbresolver" 8 | "time" 9 | ) 10 | 11 | type FavoriteGroup struct { 12 | FavoriteGroupID int `json:"favorite_group_id" gorm:"primaryKey"` 13 | UserID int `json:"user_id" gorm:"primaryKey"` 14 | Name string `json:"name" gorm:"not null;size:64" default:"默认"` 15 | CreatedAt time.Time `json:"time_created"` 16 | UpdatedAt time.Time `json:"time_updated"` 17 | Deleted bool `json:"deleted" gorm:"default:false"` 18 | Count int `json:"count" gorm:"default:0"` 19 | } 20 | 21 | 
const MaxGroupPerUser = 10 22 | 23 | type FavoriteGroups []FavoriteGroup 24 | 25 | func (FavoriteGroup) TableName() string { 26 | return "favorite_groups" 27 | } 28 | 29 | // make sure use this function in a transaction 30 | func UserGetFavoriteGroups(tx *gorm.DB, userID int, order *string) (favoriteGroups FavoriteGroups, err error) { 31 | err = CheckDefaultFavoriteGroup(tx, userID) 32 | if err != nil { 33 | return 34 | } 35 | 36 | if order == nil { 37 | err = tx.Where("user_id = ? and deleted = false", userID).Find(&favoriteGroups).Error 38 | } else { 39 | err = tx.Where("user_id = ? and deleted = false", userID).Order(*order).Find(&favoriteGroups).Error 40 | } 41 | return 42 | } 43 | 44 | func DeleteUserFavoriteGroup(tx *gorm.DB, userID int, groupID int) (err error) { // marks a non-default group as deleted; refuses while it still holds favorites 45 | if groupID == 0 { 46 | return common.Forbidden("默认收藏夹不可删除") 47 | } 48 | err = tx.Model(&UserFavorite{}).Where("user_id = ? AND favorite_group_id = ?", userID, groupID).Take(&UserFavorite{}).Error 49 | if err != nil { 50 | if !errors.Is(err, gorm.ErrRecordNotFound) { 51 | return err 52 | } 53 | } else { 54 | return common.Forbidden("收藏夹中存在收藏内容,请先移除") 55 | } 56 | 57 | result := tx.Clauses(dbresolver.Write).Where("user_id = ? AND favorite_group_id = ?", userID, groupID).Updates(FavoriteGroup{Deleted: true}) 58 | if result.Error != nil { 59 | return result.Error // report the update's own failure; err still holds ErrRecordNotFound from the emptiness check above 60 | } 61 | if result.RowsAffected == 0 { 62 | return common.NotFound("收藏夹不存在") 63 | } 64 | err = tx.Model(&UserFavorite{}).Where("user_id = ? AND favorite_group_id = ?", userID, groupID).Delete(&UserFavorite{}).Error 65 | if err != nil { 66 | return err 67 | } 68 | return tx.Model(&User{}).Where("id = ?", userID).Update("favorite_group_count", gorm.Expr("favorite_group_count - 1")).Error 69 | } 70 | 71 | func CheckDefaultFavoriteGroup(tx *gorm.DB, userID int) (err error) { 72 | return tx.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error { 73 | err = tx.Model(&FavoriteGroup{}).Where("user_id = ? 
AND favorite_group_id = 0", userID).Take(&FavoriteGroup{}).Error 74 | if err != nil { 75 | if !errors.Is(err, gorm.ErrRecordNotFound) { 76 | return err 77 | } 78 | 79 | // insert default favorite group if not exists 80 | err = tx.Create(&FavoriteGroup{ 81 | UserID: userID, 82 | Name: "默认收藏夹", 83 | FavoriteGroupID: 0, 84 | CreatedAt: time.Now(), 85 | }).Error 86 | if err != nil { 87 | return err 88 | } 89 | return tx.Model(&User{}).Where("id = ?", userID).Update("favorite_group_count", gorm.Expr("favorite_group_count + 1")).Error 90 | } 91 | 92 | // default favorite group exists 93 | return nil 94 | }) 95 | 96 | } 97 | 98 | func AddUserFavoriteGroup(tx *gorm.DB, userID int, name string) (err error) { 99 | return tx.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error { 100 | var groupID int 101 | err = tx.Model(&FavoriteGroup{}).Select("IFNULL(MAX(favorite_group_id), 0) AS max_id").Where("user_id = ? and deleted = false", userID). 102 | Take(&groupID).Error 103 | groupID++ 104 | if err != nil { 105 | return err 106 | } 107 | if groupID >= MaxGroupPerUser { 108 | err = tx.Model(&FavoriteGroup{}).Where("user_id = ? and deleted = true", userID).Order("favorite_group_id").Limit(1).Take(&groupID).Error 109 | } 110 | if errors.Is(err, gorm.ErrRecordNotFound) { 111 | return common.Forbidden("收藏夹数量已达上限") 112 | } 113 | if err != nil { 114 | return err 115 | } 116 | 117 | err = tx.Create(&FavoriteGroup{ 118 | UserID: userID, 119 | Name: name, 120 | FavoriteGroupID: groupID, 121 | CreatedAt: time.Now(), 122 | }).Error 123 | if err != nil { 124 | return err 125 | } 126 | return tx.Model(&User{}).Where("id = ?", userID).Update("favorite_group_count", gorm.Expr("favorite_group_count + 1")).Error 127 | }) 128 | } 129 | 130 | func ModifyUserFavoriteGroup(tx *gorm.DB, userID int, groupID int, name string) (err error) { 131 | return tx.Clauses(dbresolver.Write).Where("user_id = ? AND favorite_group_id = ?", userID, groupID). 
132 | Updates(FavoriteGroup{Name: name, UpdatedAt: time.Now()}).Error 133 | } 134 | -------------------------------------------------------------------------------- /models/floor_history.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import "time" 4 | 5 | type FloorHistory struct { 6 | /// base info 7 | ID int `json:"id" gorm:"primaryKey"` 8 | CreatedAt time.Time `json:"time_created"` 9 | UpdatedAt time.Time `json:"time_updated"` 10 | Content string `json:"content" gorm:"size:15000"` 11 | Reason string `json:"reason"` 12 | FloorID int `json:"floor_id"` 13 | // auto sensitive check 14 | IsSensitive bool `json:"is_sensitive"` 15 | 16 | // manual sensitive check 17 | IsActualSensitive *bool `json:"is_actual_sensitive"` 18 | 19 | SensitiveDetail string `json:"sensitive_detail,omitempty"` 20 | // The one who modified the floor 21 | UserID int `json:"user_id"` 22 | } 23 | 24 | type FloorHistorySlice []*FloorHistory 25 | -------------------------------------------------------------------------------- /models/floor_like.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type FloorLike struct { 4 | FloorID int `json:"floor_id" gorm:"primaryKey"` 5 | UserID int `json:"user_id" gorm:"primaryKey"` 6 | LikeData int8 `json:"like_data"` 7 | } 8 | -------------------------------------------------------------------------------- /models/floor_mention.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "regexp" 5 | 6 | "gorm.io/gorm" 7 | 8 | "treehole_next/utils" 9 | ) 10 | 11 | type FloorMention struct { 12 | FloorID int `json:"floor_id" gorm:"primaryKey"` 13 | MentionID int `json:"mention_id" gorm:"primaryKey"` 14 | } 15 | 16 | func (FloorMention) TableName() string { 17 | return "floor_mention" 18 | } 19 | 20 | var reHole = regexp.MustCompile(`[^#]#(\d+)`) 21 | var reFloor = 
regexp.MustCompile(`##(\d+)`) 22 | 23 | func parseMentionIDs(content string) (holeIDs []int, floorIDs []int, err error) { 24 | // todo: parse replyTo 25 | 26 | // find mentioned holeIDs 27 | holeIDsText := reHole.FindAllStringSubmatch(" "+content, -1) 28 | holeIDs, err = utils.RegText2IntArray(holeIDsText) 29 | if err != nil { 30 | return nil, nil, err 31 | } 32 | 33 | // find mentioned floorIDs 34 | floorIDsText := reFloor.FindAllStringSubmatch(" "+content, -1) 35 | floorIDs, err = utils.RegText2IntArray(floorIDsText) 36 | return holeIDs, floorIDs, err 37 | } 38 | 39 | func LoadFloorMentions(tx *gorm.DB, content string) (Floors, error) { 40 | holeIDs, floorIDs, err := parseMentionIDs(content) 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | queryGetHoleFloors := tx.Model(&Floor{}).Where("hole_id in ? and ranking = 0", holeIDs) 46 | queryGetFloors := tx.Model(&Floor{}).Where("id in ?", floorIDs) 47 | mentionFloors := Floors{} 48 | if len(holeIDs) > 0 && len(floorIDs) > 0 { 49 | err = tx.Raw(`? 
UNION ?`, queryGetHoleFloors, queryGetFloors).Scan(&mentionFloors).Error 50 | } else if len(holeIDs) > 0 { 51 | err = queryGetHoleFloors.Scan(&mentionFloors).Error 52 | } else if len(floorIDs) > 0 { 53 | err = queryGetFloors.Scan(&mentionFloors).Error 54 | } 55 | return mentionFloors, err 56 | } 57 | -------------------------------------------------------------------------------- /models/hole_tags.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type HoleTag struct { 4 | HoleID int `json:"hole_id" gorm:"index"` 5 | TagID int `json:"tag_id" gorm:"index"` 6 | } 7 | 8 | func (HoleTag) TableName() string { 9 | return "hole_tags" 10 | } 11 | 12 | type HoleTags []*HoleTag 13 | -------------------------------------------------------------------------------- /models/init.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "os" 5 | "time" 6 | 7 | "github.com/rs/zerolog/log" 8 | 9 | "treehole_next/config" 10 | 11 | "gorm.io/gorm/logger" 12 | "gorm.io/plugin/dbresolver" 13 | 14 | "gorm.io/driver/mysql" 15 | "gorm.io/driver/sqlite" 16 | "gorm.io/gorm" 17 | "gorm.io/gorm/schema" 18 | ) 19 | 20 | var DB *gorm.DB 21 | 22 | var gormConfig = &gorm.Config{ 23 | NamingStrategy: schema.NamingStrategy{ 24 | SingularTable: true, // use singular table name, table for `User` would be `user` with this option enabled 25 | }, 26 | Logger: logger.New( 27 | &log.Logger, 28 | logger.Config{ 29 | SlowThreshold: time.Second, // 慢 SQL 阈值 30 | LogLevel: logger.Error, // 日志级别 31 | IgnoreRecordNotFoundError: true, // 忽略ErrRecordNotFound(记录未找到)错误 32 | Colorful: false, // 禁用彩色打印 33 | }, 34 | ), 35 | } 36 | 37 | // Read/Write Splitting 38 | func mysqlDB() *gorm.DB { 39 | // set source databases 40 | source := mysql.Open(config.Config.DbURL) 41 | db, err := gorm.Open(source, gormConfig) 42 | if err != nil { 43 | log.Fatal().Err(err).Send() 44 | } 45 | 46 | // set replica 
databases 47 | var replicas []gorm.Dialector 48 | for _, url := range config.Config.MysqlReplicaURLs { 49 | replicas = append(replicas, mysql.Open(url)) 50 | } 51 | err = db.Use(dbresolver.Register(dbresolver.Config{ 52 | Sources: []gorm.Dialector{source}, 53 | Replicas: replicas, 54 | Policy: dbresolver.RandomPolicy{}, 55 | })) 56 | if err != nil { 57 | log.Fatal().Err(err).Send() 58 | } 59 | return db 60 | } 61 | 62 | func sqliteDB() *gorm.DB { 63 | err := os.MkdirAll("data", 0750) 64 | if err != nil { 65 | log.Fatal().Err(err).Send() 66 | } 67 | db, err := gorm.Open(sqlite.Open("data/sqlite.db"), gormConfig) 68 | if err != nil { 69 | log.Fatal().Err(err).Send() 70 | } 71 | // https://github.com/go-gorm/gorm/issues/3709 72 | phyDB, err := db.DB() 73 | if err != nil { 74 | log.Fatal().Err(err).Send() 75 | } 76 | phyDB.SetMaxOpenConns(1) 77 | return db 78 | } 79 | 80 | func memoryDB() *gorm.DB { 81 | db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), gormConfig) 82 | if err != nil { 83 | log.Fatal().Err(err).Send() 84 | } 85 | // https://github.com/go-gorm/gorm/issues/3709 86 | phyDB, err := db.DB() 87 | if err != nil { 88 | log.Fatal().Err(err).Send() 89 | } 90 | phyDB.SetMaxOpenConns(1) 91 | return db 92 | } 93 | 94 | func InitDB() { 95 | var err error 96 | switch config.Config.Mode { 97 | case "production": 98 | DB = mysqlDB() 99 | case "test": 100 | fallthrough 101 | case "bench": 102 | DB = memoryDB() 103 | case "dev": 104 | if config.Config.DbURL == "" { 105 | DB = sqliteDB() 106 | } else { 107 | DB = mysqlDB() 108 | } 109 | default: 110 | log.Fatal().Msg("unknown mode") 111 | } 112 | 113 | switch config.Config.Mode { 114 | case "test": 115 | fallthrough 116 | case "dev": 117 | DB = DB.Debug() 118 | } 119 | 120 | err = DB.SetupJoinTable(&User{}, "UserLikedFloors", &FloorLike{}) 121 | if err != nil { 122 | log.Fatal().Err(err).Send() 123 | } 124 | 125 | err = DB.SetupJoinTable(&Hole{}, "Mapping", &AnonynameMapping{}) 126 | if err != nil { 127 | 
log.Fatal().Err(err).Send() 128 | } 129 | 130 | err = DB.SetupJoinTable(&Message{}, "Users", &MessageUser{}) 131 | if err != nil { 132 | log.Fatal().Err(err).Send() 133 | } 134 | 135 | err = DB.SetupJoinTable(&User{}, "UserSubscription", &UserSubscription{}) 136 | if err != nil { 137 | log.Fatal().Err(err).Send() 138 | } 139 | 140 | // models must be registered here to migrate into the database 141 | err = DB.AutoMigrate( 142 | &Division{}, 143 | &Tag{}, 144 | &User{}, 145 | &Floor{}, 146 | &Hole{}, 147 | &Report{}, 148 | &Punishment{}, 149 | &ReportPunishment{}, 150 | &Message{}, 151 | &FloorHistory{}, 152 | &AdminLog{}, 153 | &UserFavorite{}, 154 | &FavoriteGroup{}, 155 | &UrlHostnameWhitelist{}, 156 | ) 157 | if err != nil { 158 | log.Fatal().Err(err).Send() 159 | } 160 | 161 | err = DB.Model(&UrlHostnameWhitelist{}).Pluck("hostname", &config.Config.UrlHostnameWhitelist).Error 162 | if err != nil { 163 | log.Fatal().Err(err).Send() 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /models/message.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | // Should be same as message in notification project 4 | 5 | import ( 6 | "time" 7 | 8 | "github.com/gofiber/fiber/v2" 9 | "gorm.io/gorm" 10 | ) 11 | 12 | type Messages []Message 13 | 14 | type Message struct { 15 | ID int `gorm:"primaryKey" json:"id"` 16 | CreatedAt time.Time `json:"time_created"` 17 | UpdatedAt time.Time `json:"time_updated"` 18 | Title string `json:"message" gorm:"size:1024;not null"` 19 | Description string `json:"description" gorm:"size:65536;not null"` 20 | Data any `json:"data" gorm:"serializer:json" ` 21 | Type MessageType `json:"code" gorm:"size:16;not null"` 22 | URL string `json:"url" gorm:"size:64;default:'';not null"` 23 | Recipients []int `json:"-" gorm:"-:all" ` 24 | MessageID int `json:"message_id" gorm:"-:all"` // 兼容旧版 id 25 | HasRead bool `json:"has_read" 
gorm:"default:false"` // kept for legacy clients; always false, MessageUser.HasRead is authoritative 26 | Users Users `json:"-" gorm:"many2many:message_user;constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"` 27 | } 28 | 29 | type MessageUser struct { 30 | MessageID int `json:"message_id" gorm:"primaryKey"` 31 | UserID int `json:"user_id" gorm:"primaryKey"` 32 | HasRead bool `json:"has_read" gorm:"default:false"` // kept for legacy clients 33 | } 34 | 35 | type MessageType string 36 | 37 | const ( 38 | MessageTypeFavorite MessageType = "favorite" 39 | MessageTypeReply MessageType = "reply" 40 | MessageTypeMention MessageType = "mention" 41 | MessageTypeModify MessageType = "modify" // including fold and delete 42 | MessageTypePermission MessageType = "permission" 43 | MessageTypeReport MessageType = "report" 44 | MessageTypeReportDealt MessageType = "report_dealt" 45 | MessageTypeMail MessageType = "mail" 46 | MessageTypeSensitive MessageType = "sensitive" 47 | ) 48 | 49 | func (messages Messages) Preprocess(c *fiber.Ctx) error { 50 | for i := 0; i < len(messages); i++ { 51 | err := messages[i].Preprocess(c) 52 | if err != nil { 53 | return err 54 | } 55 | } 56 | return nil 57 | } 58 | 59 | func (message *Message) Preprocess(_ *fiber.Ctx) error { 60 | message.MessageID = message.ID 61 | return nil 62 | } 63 | 64 | // AfterCreate is the gorm hook that inserts one MessageUser row per entry in message.Recipients. 65 | func (message *Message) AfterCreate(tx *gorm.DB) (err error) { 66 | // gorm rejects Create on an empty slice (ErrEmptySlice); a message without recipients needs no join rows 67 | if len(message.Recipients) == 0 { 68 | return nil 69 | } 70 | mapping := make([]MessageUser, len(message.Recipients)) 71 | for i, userID := range message.Recipients { 72 | mapping[i] = MessageUser{ 73 | MessageID: message.ID, 74 | UserID: userID, 75 | } 76 | } 77 | return tx.Create(&mapping).Error 78 | } 79 | -------------------------------------------------------------------------------- /models/notification.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "math/rand" 10 | "net/http" 11 | "regexp" 12 | "strings" 13 | "sync" 14 | "time" 15 | 16 | "github.com/rs/zerolog/log" 17 
| 18 | "treehole_next/config" 19 | "treehole_next/utils" 20 | 21 | "golang.org/x/exp/slices" 22 | "gorm.io/gorm/clause" 23 | 24 | "github.com/goccy/go-json" 25 | ) 26 | 27 | const ( 28 | timeout = time.Second * 10 29 | ) 30 | 31 | var client = http.Client{Timeout: timeout} 32 | 33 | type Notifications []Notification 34 | 35 | type Notification struct { 36 | // Should be same as CrateModel in notification project 37 | Title string `json:"message"` 38 | Description string `json:"description"` 39 | Data any `json:"data"` 40 | Type MessageType `json:"code"` 41 | URL string `json:"url"` 42 | Recipients []int `json:"recipients"` 43 | } 44 | 45 | func readRespNotification(body io.ReadCloser) Notification { 46 | defer func(body io.ReadCloser) { 47 | err := body.Close() 48 | if err != nil { 49 | log.Err(err).Str("model", "Notification").Msg("error close body") 50 | } 51 | }(body) 52 | 53 | data, err := io.ReadAll(body) 54 | if err != nil { 55 | log.Err(err).Str("model", "Notification").Msg("error read body") 56 | return Notification{} 57 | } 58 | var response Notification 59 | err = json.Unmarshal(data, &response) 60 | if err != nil { 61 | log.Err(err).Str("model", "Notification").Msg("error unmarshal body") 62 | return Notification{} 63 | } 64 | return response 65 | } 66 | 67 | func (messages Notifications) Merge(newNotification Notification) Notifications { 68 | if len(newNotification.Recipients) == 0 { 69 | return messages 70 | } 71 | 72 | newMerge := newNotification.Recipients 73 | for _, message := range messages { 74 | old := message.Recipients 75 | for _, r1 := range old { 76 | for id, r2 := range newMerge { 77 | if r1 == r2 { 78 | newMerge = append(newMerge[:id], newMerge[id+1:]...) 
79 | break 80 | } 81 | } 82 | } 83 | if len(newMerge) == 0 { 84 | return messages 85 | } 86 | } 87 | 88 | newNotification.Recipients = newMerge 89 | return append(messages, newNotification) 90 | } 91 | 92 | func (messages Notifications) Send() error { 93 | if messages == nil { 94 | return nil 95 | } 96 | 97 | for _, message := range messages { 98 | _, err := message.Send() 99 | if err != nil { 100 | return err 101 | } 102 | } 103 | return nil 104 | } 105 | 106 | // check user.config.Notify contain message.Type 107 | func (message *Notification) checkConfig() { 108 | // generate new recipients 109 | var newRecipient []int 110 | 111 | // find users 112 | var users []User 113 | result := DB.Find(&users, "id in ?", message.Recipients) 114 | if result.Error != nil { 115 | message.Recipients = newRecipient 116 | return 117 | } 118 | 119 | // filter recipients 120 | for _, user := range users { 121 | if slices.Contains(defaultUserConfig.Notify, string(message.Type)) && !slices.Contains(user.Config.Notify, string(message.Type)) { 122 | continue 123 | } 124 | newRecipient = append(newRecipient, user.ID) 125 | } 126 | message.Recipients = newRecipient 127 | } 128 | 129 | func (message Notification) Send() (Message, error) { 130 | // only for test 131 | // message["recipients"] = []int{1} 132 | 133 | var err error 134 | 135 | message.checkConfig() 136 | // return if no recipient 137 | if len(message.Recipients) == 0 { 138 | return Message{}, nil 139 | } 140 | 141 | // save to database first 142 | body := Message{ 143 | Type: message.Type, 144 | Title: message.Title, 145 | Description: message.Description, 146 | Data: message.Data, 147 | URL: message.URL, 148 | Recipients: message.Recipients, 149 | } 150 | err = DB.Omit(clause.Associations).Create(&body).Error 151 | if err != nil { 152 | log.Err(err).Str("model", "Notification").Msg("message save failed: " + err.Error()) 153 | return Message{}, err 154 | } 155 | if config.Config.NotificationUrl == "" { 156 | return Message{}, 
nil 157 | } 158 | message.Title = utils.StripContent(message.Title, 32) //varchar(32) 159 | message.Description = utils.StripContent(cleanNotificationDescription(message.Description), 64) //varchar(64) 160 | body.Title = message.Title 161 | body.Description = message.Description 162 | 163 | // construct form 164 | form, err := json.Marshal(message) 165 | if err != nil { 166 | log.Err(err).Str("model", "Notification").Msg("error encoding notification") 167 | return Message{}, err 168 | } 169 | 170 | // construct http request 171 | req, err := http.NewRequest( 172 | "POST", 173 | fmt.Sprintf("%s/messages", config.Config.NotificationUrl), 174 | bytes.NewBuffer(form), 175 | ) 176 | if err != nil { 177 | log.Err(err).Str("model", "Notification").Msg("error making request") 178 | return Message{}, err 179 | } 180 | req.Header.Add("Content-Type", "application/json") 181 | 182 | // bench and simulation 183 | if config.Config.Mode == "bench" { 184 | time.Sleep(time.Millisecond) 185 | return Message{}, nil 186 | } 187 | 188 | // get response 189 | resp, err := client.Do(req) 190 | if err != nil { 191 | log.Err(err).Str("model", "Notification").Msg("error sending notification") 192 | return Message{}, err 193 | } 194 | 195 | response := readRespNotification(resp.Body) 196 | if resp.StatusCode != 201 { 197 | log.Error().Str("model", "Notification").Any("response", response).Msg("notification response failed") 198 | return Message{}, errors.New(fmt.Sprint(response)) 199 | } 200 | 201 | return body, nil 202 | } 203 | 204 | var adminList struct { 205 | sync.RWMutex 206 | data []int 207 | } 208 | 209 | func InitAdminList() { 210 | // skip when bench 211 | if config.Config.Mode == "bench" || config.Config.AuthUrl == "" { 212 | return 213 | } 214 | 215 | // // http request 216 | // res, err := http.Get(config.Config.AuthUrl + "/users/admin") 217 | 218 | // // handle err 219 | // if err != nil { 220 | // log.Err(err).Str("model", "get admin").Msg("error sending auth server") 221 | // 
return 222 | // } 223 | 224 | // defer func() { 225 | // _ = res.Body.Close() 226 | // }() 227 | 228 | // if res.StatusCode != 200 { 229 | // log.Error().Str("model", "get admin").Msg("auth server response failed" + res.Status) 230 | // return 231 | // } 232 | 233 | // data, err := io.ReadAll(res.Body) 234 | // if err != nil { 235 | // log.Err(err).Str("model", "get admin").Msg("error reading auth server response") 236 | // return 237 | // } 238 | 239 | adminList.Lock() 240 | defer adminList.Unlock() 241 | 242 | // err = json.Unmarshal(data, &adminList.data) 243 | // if err != nil { 244 | // log.Err(err).Str("model", "get admin").Msg("error unmarshal auth server response") 245 | // return 246 | // } 247 | adminList.data = config.Config.NotifiableAdminIds 248 | 249 | // shuffle ids 250 | for i := range adminList.data { 251 | j := rand.Intn(i + 1) 252 | adminList.data[i], adminList.data[j] = adminList.data[j], adminList.data[i] 253 | } 254 | } 255 | 256 | func UpdateAdminList(ctx context.Context) { 257 | ticker := time.NewTicker(time.Minute) 258 | for { 259 | select { 260 | case <-ctx.Done(): 261 | return 262 | case <-ticker.C: 263 | InitAdminList() 264 | } 265 | } 266 | } 267 | 268 | var ( 269 | reMention = regexp.MustCompile(`#{1,2}\d+`) 270 | reFormula = regexp.MustCompile(`(?s)\${1,2}.*?\${1,2}`) 271 | reSticker = regexp.MustCompile(`!\[\]\(dx_\S+?\)`) 272 | reImage = regexp.MustCompile(`!\[.*?\]\(.*?\)`) 273 | ) 274 | 275 | func cleanNotificationDescription(content string) string { 276 | newContent := reMention.ReplaceAllString(content, "") 277 | newContent = reFormula.ReplaceAllString(newContent, "[公式]") 278 | newContent = reSticker.ReplaceAllString(newContent, "[表情]") 279 | newContent = reImage.ReplaceAllString(newContent, "[图片]") 280 | newContent = strings.ReplaceAll(newContent, "\n", "") 281 | if newContent == "" { 282 | return content 283 | } 284 | return newContent 285 | } 286 | 
-------------------------------------------------------------------------------- /models/punishment.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "errors" 5 | "time" 6 | 7 | "gorm.io/gorm" 8 | "gorm.io/gorm/clause" 9 | "gorm.io/plugin/dbresolver" 10 | ) 11 | 12 | // Punishment 13 | // a record of user punishment 14 | // when a record created, it can't be modified if other admins punish this user on the same floor 15 | // whether a user is banned to post on one division based on the latest / max(id) record 16 | // if admin want to modify punishment duration, manually modify the latest record of this user in database 17 | // admin can be granted update privilege on SQL view of this table 18 | type Punishment struct { 19 | ID int `json:"id" gorm:"primaryKey"` 20 | 21 | // time when this punishment creates 22 | CreatedAt time.Time `json:"created_at"` 23 | 24 | UpdatedAt time.Time `json:"updated_at"` 25 | 26 | // time when this punishment revoked 27 | DeletedAt gorm.DeletedAt `json:"-"` 28 | 29 | // start from end_time of previous punishment (punishment accumulation of different floors) 30 | // if no previous punishment or previous punishment end time less than time.Now() (synced), set start time time.Now() 31 | StartTime time.Time `json:"start_time" gorm:"not null"` 32 | 33 | // end_time of this punishment 34 | EndTime time.Time `json:"end_time" gorm:"not null"` 35 | 36 | Duration *time.Duration `json:"duration" swaggertype:"integer"` 37 | 38 | Day int `json:"day"` 39 | 40 | // user punished 41 | UserID int `json:"user_id" gorm:"not null;index"` 42 | 43 | // admin user_id who made this punish 44 | MadeBy int `json:"made_by,omitempty"` 45 | 46 | // punished because of this floor 47 | FloorID *int `json:"floor_id" gorm:"uniqueIndex"` 48 | 49 | Floor *Floor `json:"floor,omitempty" gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"` // foreign key 50 | 51 | DivisionID int `json:"division_id" 
gorm:"not null"` 52 | 53 | Division *Division `json:"division,omitempty"` // foreign key 54 | 55 | // reason 56 | Reason string `json:"reason" gorm:"size:128"` 57 | } 58 | 59 | type Punishments []*Punishment 60 | 61 | func (punishment *Punishment) Create() (*User, error) { 62 | var user User 63 | 64 | err := DB.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error { 65 | err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Take(&user, punishment.UserID).Error 66 | if err != nil { 67 | return err 68 | } 69 | 70 | var previousPunishment Punishment 71 | err = tx.Where("user_id = ? and floor_id = ?", user.ID, punishment.FloorID).Take(&previousPunishment).Error 72 | if err == nil { 73 | // return common.Forbidden("该用户已被禁言") 74 | 75 | // same as before, do nothing 76 | if previousPunishment.Duration == punishment.Duration && previousPunishment.Day == punishment.Day { 77 | return nil 78 | } 79 | 80 | // different duration, revoke previous punishment 81 | diffDuration := time.Duration(punishment.Day-previousPunishment.Day) * 24 * time.Hour 82 | 83 | previousPunishment.Duration = punishment.Duration 84 | previousPunishment.Day = punishment.Day 85 | previousPunishment.EndTime = previousPunishment.StartTime.Add(*punishment.Duration) 86 | previousPunishment.Reason = punishment.Reason 87 | previousPunishment.MadeBy = punishment.MadeBy 88 | // conflict with previous punishment if not equal 89 | // ignore it as it's rare 90 | previousPunishment.DivisionID = punishment.DivisionID 91 | 92 | if user.BanDivision[punishment.DivisionID] == nil { 93 | user.BanDivision[punishment.DivisionID] = &previousPunishment.EndTime 94 | } else { 95 | *user.BanDivision[punishment.DivisionID] = user.BanDivision[punishment.DivisionID].Add(diffDuration) 96 | } 97 | 98 | err = tx.Updates(&previousPunishment).Error 99 | if err != nil { 100 | return err 101 | } 102 | } else if !errors.Is(err, gorm.ErrRecordNotFound) { 103 | return err 104 | } else { 105 | user.OffenceCount += 1 106 | 
punishment.StartTime = time.Now() 107 | punishment.EndTime = punishment.StartTime.Add(*punishment.Duration) 108 | 109 | if user.BanDivision[punishment.DivisionID] == nil { 110 | user.BanDivision[punishment.DivisionID] = &punishment.EndTime 111 | } else { 112 | *user.BanDivision[punishment.DivisionID] = user.BanDivision[punishment.DivisionID].Add(*punishment.Duration) 113 | } 114 | 115 | err = tx.Create(&punishment).Error 116 | if err != nil { 117 | return err 118 | } 119 | } 120 | 121 | err = tx.Select("BanDivision", "OffenceCount").Save(&user).Error 122 | if err != nil { 123 | return err 124 | } 125 | 126 | return nil 127 | }) 128 | return &user, err 129 | } 130 | -------------------------------------------------------------------------------- /models/report.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "sync/atomic" 7 | "time" 8 | 9 | "github.com/rs/zerolog/log" 10 | 11 | "github.com/gofiber/fiber/v2" 12 | "github.com/opentreehole/go-common" 13 | "gorm.io/gorm" 14 | ) 15 | 16 | type Report struct { 17 | ID int `json:"id" gorm:"primaryKey"` 18 | CreatedAt time.Time `json:"time_created"` 19 | UpdatedAt time.Time `json:"time_updated"` 20 | ReportID int `json:"report_id" gorm:"-:all"` 21 | FloorID int `json:"floor_id"` 22 | HoleID int `json:"hole_id" gorm:"-:all"` 23 | Floor *Floor `json:"floor" gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"` 24 | UserID int `json:"-"` // the reporter's id, should keep a secret 25 | Reason string `json:"reason" gorm:"size:128"` 26 | Dealt bool `json:"dealt"` // the report has been dealt 27 | // who dealt the report 28 | DealtBy int `json:"dealt_by" gorm:"index"` 29 | Result string `json:"result" gorm:"size:128"` // deal result 30 | } 31 | 32 | func (report *Report) GetID() int { 33 | return report.ID 34 | } 35 | 36 | type Reports []*Report 37 | 38 | func (report *Report) Preprocess(c *fiber.Ctx) (err error) { 39 | err = 
report.Floor.SetDefaults(c) 40 | if err != nil { 41 | return err 42 | } 43 | for i := range report.Floor.Mention { 44 | err = report.Floor.Mention[i].SetDefaults(c) 45 | if err != nil { 46 | return err 47 | } 48 | } 49 | report.HoleID = report.Floor.HoleID 50 | return nil 51 | } 52 | 53 | func (reports Reports) Preprocess(c *fiber.Ctx) error { 54 | for i := 0; i < len(reports); i++ { 55 | _ = reports[i].Preprocess(c) 56 | } 57 | return nil 58 | } 59 | 60 | func LoadReportFloor(tx *gorm.DB) *gorm.DB { 61 | return tx.Preload("Floor.Mention").Preload("Floor") 62 | } 63 | 64 | func (report *Report) Create(c *fiber.Ctx, db ...*gorm.DB) error { 65 | var tx *gorm.DB 66 | if len(db) > 0 { 67 | tx = db[0] 68 | } else { 69 | tx = DB 70 | } 71 | userID, err := common.GetUserID(c) 72 | if err != nil { 73 | return err 74 | } 75 | 76 | existingReport := Report{} 77 | err = tx.Where("user_id = ? AND floor_id = ?", userID, report.FloorID).First(&existingReport).Error 78 | if err != nil { 79 | if !errors.Is(err, gorm.ErrRecordNotFound) { 80 | return err 81 | } 82 | } 83 | 84 | if errors.Is(err, gorm.ErrRecordNotFound) { 85 | report.UserID = userID 86 | err = tx.Create(&report).Error 87 | if err != nil { 88 | return err 89 | } 90 | 91 | report.ReportID = report.ID 92 | 93 | err = tx.Model(report).Association("Floor").Find(&report.Floor) 94 | if err != nil { 95 | return err 96 | } 97 | 98 | err = report.Preprocess(c) 99 | if err != nil { 100 | return err 101 | } 102 | } else { 103 | existingReport.Reason = existingReport.Reason + "\n" + report.Reason 104 | err = tx.Model(&existingReport).Updates(map[string]any{ 105 | "reason": existingReport.Reason, 106 | "dealt": false, 107 | }).Error // update reason and load floor in AfterUpdate hook 108 | if err != nil { 109 | return err 110 | } 111 | report.Floor = existingReport.Floor 112 | } 113 | 114 | return nil 115 | } 116 | 117 | func (report *Report) AfterFind(_ *gorm.DB) (err error) { 118 | report.ReportID = report.ID 119 | 120 | return 
nil 121 | } 122 | 123 | func (report *Report) AfterUpdate(tx *gorm.DB) (err error) { 124 | err = tx.Model(report).Association("Floor").Find(&report.Floor) 125 | if err != nil { 126 | return err 127 | } 128 | 129 | //err = report.Preprocess(nil) 130 | //if err != nil { 131 | // return err 132 | //} 133 | 134 | return nil 135 | } 136 | 137 | var adminCounter = new(int32) 138 | 139 | func (report *Report) SendCreate(_ *gorm.DB) error { 140 | // adminList.RLock() 141 | // defer adminList.RUnlock() 142 | if len(adminList.data) == 0 { 143 | return nil 144 | } 145 | 146 | // get counter 147 | currentCounter := atomic.AddInt32(adminCounter, 1) 148 | result := atomic.CompareAndSwapInt32(adminCounter, int32(len(adminList.data)), 0) 149 | if result { 150 | log.Info().Str("model", "get admin").Msg("adminCounter Reset") 151 | } 152 | userIDs := []int{adminList.data[currentCounter-1]} 153 | 154 | // construct message 155 | message := Notification{ 156 | Data: report, 157 | Recipients: userIDs, 158 | Description: fmt.Sprintf( 159 | "理由:%s,内容:%s", 160 | report.Reason, 161 | report.Floor.Content, 162 | ), 163 | Title: "您有举报需要处理", 164 | Type: MessageTypeReport, 165 | URL: fmt.Sprintf("/api/reports/%d", report.ID), 166 | } 167 | 168 | // send 169 | _, err := message.Send() 170 | return err 171 | } 172 | 173 | func (report *Report) SendModify(_ *gorm.DB) error { 174 | // get recipients 175 | userIDs := []int{report.UserID} 176 | 177 | // construct message 178 | message := Notification{ 179 | Data: report, 180 | Recipients: userIDs, 181 | Description: fmt.Sprintf( 182 | "处理结果:%s\n感谢您为维护社区秩序所做的贡献。", 183 | report.Result, 184 | ), 185 | Title: "您的举报已得到处理", 186 | Type: MessageTypeReportDealt, 187 | URL: fmt.Sprintf("/api/reports/%d", report.ID), 188 | } 189 | 190 | // send 191 | _, err := message.Send() 192 | return err 193 | } 194 | -------------------------------------------------------------------------------- /models/report_punishment.go: 
-------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "errors" 5 | "github.com/opentreehole/go-common" 6 | "gorm.io/gorm" 7 | "gorm.io/gorm/clause" 8 | "gorm.io/plugin/dbresolver" 9 | "time" 10 | ) 11 | 12 | type ReportPunishment struct { 13 | ID int `json:"id" gorm:"primaryKey"` 14 | 15 | // time when this report punishment creates 16 | CreatedAt time.Time `json:"created_at"` 17 | 18 | // time when this report punishment revoked 19 | DeletedAt gorm.DeletedAt `json:"deleted_at"` 20 | 21 | // start from end_time of previous punishment (punishment accumulation of different floors) 22 | // if no previous punishment or previous punishment end time less than time.Now() (synced), set start time time.Now() 23 | StartTime time.Time `json:"start_time" gorm:"not null"` 24 | 25 | // end_time of this report punishment 26 | EndTime time.Time `json:"end_time" gorm:"not null"` 27 | 28 | Duration *time.Duration `json:"duration"` 29 | 30 | // user punished 31 | UserID int `json:"user_id" gorm:"not null;index"` 32 | 33 | // admin user_id who made this punish 34 | MadeBy int `json:"made_by,omitempty"` 35 | 36 | // punished because of this report 37 | ReportId int `json:"report_id" gorm:"uniqueIndex"` 38 | 39 | Report *Report `json:"report,omitempty"` // foreign key 40 | 41 | // reason 42 | Reason string `json:"reason" gorm:"size:128"` 43 | } 44 | 45 | type ReportPunishments []*ReportPunishment 46 | 47 | func (reportPunishment *ReportPunishment) Create() (*User, error) { 48 | var user User 49 | 50 | err := DB.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error { 51 | err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Take(&user, reportPunishment.UserID).Error 52 | if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { 53 | return err 54 | } 55 | 56 | // if the user has been banned from this report 57 | var punishmentRecord ReportPunishment 58 | err = tx.Where("user_id = ? 
and report_id = ?", user.ID, reportPunishment.ReportId).Take(&punishmentRecord).Error 59 | if err == nil { 60 | return common.Forbidden("该用户已被限制使用举报功能") 61 | } else if !errors.Is(err, gorm.ErrRecordNotFound) { 62 | return err 63 | } 64 | 65 | var lastPunishment ReportPunishment 66 | err = tx.Where("user_id = ?", user.ID).Last(&lastPunishment).Error 67 | if err == nil { 68 | if lastPunishment.EndTime.Before(time.Now()) { 69 | reportPunishment.StartTime = time.Now() 70 | } else { 71 | reportPunishment.StartTime = lastPunishment.EndTime 72 | } 73 | } else if errors.Is(err, gorm.ErrRecordNotFound) { 74 | reportPunishment.StartTime = time.Now() 75 | } else { 76 | return err 77 | } 78 | 79 | reportPunishment.EndTime = reportPunishment.StartTime.Add(*reportPunishment.Duration) 80 | 81 | user.BanReport = &reportPunishment.EndTime 82 | user.BanReportCount += 1 83 | 84 | err = tx.Create(&reportPunishment).Error 85 | if err != nil { 86 | return err 87 | } 88 | 89 | err = tx.Select("BanReport", "BanReportCount").Save(&user).Error 90 | if err != nil { 91 | return err 92 | } 93 | 94 | return nil 95 | }) 96 | return &user, err 97 | } 98 | -------------------------------------------------------------------------------- /models/tag.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "sync" 7 | "time" 8 | "treehole_next/config" 9 | "treehole_next/utils/sensitive" 10 | 11 | "github.com/gofiber/fiber/v2" 12 | 13 | "github.com/opentreehole/go-common" 14 | "github.com/rs/zerolog/log" 15 | "golang.org/x/exp/slices" 16 | "gorm.io/gorm" 17 | "gorm.io/gorm/clause" 18 | 19 | "treehole_next/utils" 20 | ) 21 | 22 | type Tag struct { 23 | /// saved fields 24 | ID int `json:"id" gorm:"primaryKey"` 25 | CreatedAt time.Time `json:"-" gorm:"not null"` 26 | UpdatedAt time.Time `json:"-" gorm:"not null"` 27 | 28 | /// base info 29 | Name string `json:"name" gorm:"not null;unique;size:32"` 30 | Temperature 
int `json:"temperature" gorm:"not null;default:0"` 31 | 32 | IsZZMG bool `json:"-" gorm:"not null;default:false"` 33 | 34 | /// association info, should add foreign key 35 | Holes Holes `json:"-" gorm:"many2many:hole_tags;constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"` 36 | // auto sensitive check 37 | IsSensitive bool `json:"-" gorm:"index:idx_tag_actual_sensitive,priority:1"` 38 | 39 | // manual sensitive check 40 | IsActualSensitive *bool `json:"-" gorm:"index:idx_tag_actual_sensitive,priority:2"` 41 | /// generated field 42 | TagID int `json:"tag_id" gorm:"-:all"` 43 | 44 | Nsfw bool `json:"nsfw" gorm:"not null;default:false;index"` 45 | } 46 | 47 | type Tags []*Tag 48 | 49 | func (tag *Tag) GetID() int { 50 | return tag.ID 51 | } 52 | 53 | func (tag *Tag) AfterFind(_ *gorm.DB) (err error) { 54 | tag.TagID = tag.ID 55 | return nil 56 | } 57 | 58 | func (tag *Tag) BeforeCreate(_ *gorm.DB) (err error) { 59 | if len(tag.Name) > 0 && tag.Name[0] == '*' { 60 | tag.Nsfw = true 61 | } 62 | return nil 63 | } 64 | 65 | func (tag *Tag) AfterCreate(_ *gorm.DB) (err error) { 66 | tag.TagID = tag.ID 67 | return nil 68 | } 69 | 70 | func FindOrCreateTags(tx *gorm.DB, user *User, names []string) (Tags, error) { 71 | tags := make(Tags, 0) 72 | for i, name := range names { 73 | names[i] = strings.TrimSpace(name) 74 | } 75 | err := tx.Where("name in ?", names).Find(&tags).Error 76 | if err != nil { 77 | return nil, err 78 | } 79 | 80 | existTagNames := make([]string, 0) 81 | for _, tag := range tags { 82 | existTagNames = append(existTagNames, tag.Name) 83 | if !user.IsAdmin { 84 | if slices.ContainsFunc(config.Config.AdminOnlyTagIds, func(i int) bool { 85 | return i == tag.ID 86 | }) { 87 | return nil, common.Forbidden(fmt.Sprintf("标签 %s 为管理员专用标签", tag.Name)) 88 | } 89 | } 90 | } 91 | 92 | newTags := make(Tags, 0) 93 | for _, name := range names { 94 | name = strings.TrimSpace(name) 95 | if !slices.ContainsFunc(existTagNames, func(s string) bool { 96 | return 
strings.EqualFold(s, name) 97 | }) { 98 | newTags = append(newTags, &Tag{Name: name}) 99 | } 100 | } 101 | 102 | if len(newTags) == 0 { 103 | return tags, nil 104 | } 105 | for _, tag := range newTags { 106 | if !user.IsAdmin { 107 | if len(tag.Name) > 15 && len([]rune(tag.Name)) > 10 { 108 | return nil, common.BadRequest("标签长度不能超过 10 个字符") 109 | } 110 | if strings.HasPrefix(tag.Name, "#") { 111 | return nil, common.BadRequest("只有管理员才能创建 # 开头的 tag") 112 | } 113 | if strings.HasPrefix(tag.Name, "@") { 114 | return nil, common.BadRequest("只有管理员才能创建 @ 开头的 tag") 115 | } 116 | if strings.HasPrefix(tag.Name, "*") { 117 | return nil, common.BadRequest("只有管理员才能创建 * 开头的 tag") 118 | } 119 | } 120 | } 121 | 122 | var wg sync.WaitGroup 123 | for _, tag := range newTags { 124 | wg.Add(1) 125 | go func(tag *Tag) { 126 | sensitiveResp, err := sensitive.CheckSensitive(sensitive.ParamsForCheck{ 127 | Content: tag.Name, 128 | Id: time.Now().UnixNano(), 129 | TypeName: sensitive.TypeTag, 130 | }) 131 | if err != nil { 132 | return 133 | } 134 | tag.IsSensitive = !sensitiveResp.Pass 135 | wg.Done() 136 | }(tag) 137 | } 138 | wg.Wait() 139 | 140 | err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&newTags).Error 141 | 142 | go UpdateTagCache(nil) 143 | 144 | return append(tags, newTags...), err 145 | } 146 | 147 | func UpdateTagCache(tags Tags) { 148 | var err error 149 | if len(tags) == 0 { 150 | err := DB.Order("temperature desc").Find(&tags).Error 151 | if err != nil { 152 | log.Printf("update tag cache error: %s", err) 153 | } 154 | } 155 | err = utils.SetCache("tags", tags, 10*time.Minute) 156 | if err != nil { 157 | log.Printf("update tag cache error: %s", err) 158 | } 159 | } 160 | 161 | func (tag *Tag) Preprocess(c *fiber.Ctx) error { 162 | return Tags{tag}.Preprocess(c) 163 | } 164 | 165 | func (tags Tags) Preprocess(_ *fiber.Ctx) error { 166 | tagIDs := make([]int, len(tags)) 167 | IdTagMapping := make(map[int]*Tag) 168 | for i, tag := range tags { 169 | if 
tags[i].Sensitive() { 170 | tags[i].Name = "" 171 | } 172 | tagIDs[i] = tag.ID 173 | IdTagMapping[tag.ID] = tags[i] 174 | } 175 | return nil 176 | } 177 | 178 | func (tag *Tag) Sensitive() bool { 179 | if tag == nil { 180 | return false 181 | } 182 | if tag.IsActualSensitive != nil { 183 | return *tag.IsActualSensitive 184 | } 185 | return tag.IsSensitive 186 | } 187 | -------------------------------------------------------------------------------- /models/url_hostname_whitelist.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | type UrlHostnameWhitelist struct { 4 | ID int `json:"id" gorm:"primaryKey"` 5 | Hostname string `json:"hostname" gorm:"size:255;not null"` 6 | } 7 | -------------------------------------------------------------------------------- /models/user.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "time" 7 | 8 | "golang.org/x/exp/slices" 9 | 10 | "treehole_next/config" 11 | 12 | "github.com/gofiber/fiber/v2" 13 | "github.com/opentreehole/go-common" 14 | "github.com/rs/zerolog/log" 15 | "gorm.io/gorm" 16 | ) 17 | 18 | type User struct { 19 | /// base info 20 | ID int `json:"id" gorm:"primaryKey"` 21 | 22 | Config UserConfig `json:"config" gorm:"serializer:json;not null;default:\"{}\""` 23 | 24 | BanDivision map[int]*time.Time `json:"-" gorm:"serializer:json;not null;default:\"{}\""` 25 | 26 | OffenceCount int `json:"-" gorm:"not null;default:0"` 27 | 28 | BanReport *time.Time `json:"-" gorm:"serializer:json"` 29 | 30 | BanReportCount int `json:"-" gorm:"not null;default:0"` 31 | 32 | DefaultSpecialTag string `json:"default_special_tag" gorm:"size:32"` 33 | 34 | SpecialTags []string `json:"special_tags" gorm:"serializer:json;not null;default:\"[]\""` 35 | 36 | FavoriteGroupCount int `json:"favorite_group_count" gorm:"not null;default:0"` 37 | 38 | /// association fields, should add foreign 
key 39 | 40 | // holes owned by the user 41 | UserHoles Holes `json:"-" gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"` 42 | 43 | // floors owned by the user 44 | UserFloors Floors `json:"-" gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"` 45 | 46 | // reports made by the user; a user has many report 47 | UserReports Reports `json:"-"` 48 | 49 | // floors liked by the user 50 | UserLikedFloors Floors `json:"-" gorm:"many2many:floor_like;constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"` 51 | 52 | // floor history made by the user 53 | UserFloorHistory FloorHistorySlice `json:"-" gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"` 54 | 55 | // user punishments on division 56 | UserPunishments Punishments `json:"-"` 57 | 58 | // punishments made by this user 59 | UserMakePunishments Punishments `json:"-" gorm:"foreignKey:MadeBy"` 60 | 61 | // user punishments on report 62 | UserReportPunishments ReportPunishments `json:"-"` 63 | 64 | // report punishments made by this user 65 | UserMakeReportPunishments ReportPunishments `json:"-" gorm:"foreignKey:MadeBy"` 66 | 67 | UserSubscription Holes `json:"-" gorm:"many2many:user_subscription;constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"` 68 | 69 | /// dynamically generated field 70 | 71 | UserID int `json:"user_id" gorm:"-:all"` 72 | 73 | Permission struct { 74 | // 管理员权限到期时间 75 | Admin time.Time `json:"admin"` 76 | // key: division_id value: 对应分区禁言解除时间 77 | Silent map[int]*time.Time `json:"silent"` 78 | OffenseCount int `json:"offense_count"` 79 | } `json:"permission" gorm:"-:all"` 80 | 81 | // get from jwt 82 | IsAdmin bool `json:"is_admin" gorm:"-:all"` 83 | JoinedTime time.Time `json:"joined_time" gorm:"-:all"` 84 | Nickname string `json:"nickname" gorm:"-:all"` 85 | HasAnsweredQuestions bool `json:"has_answered_questions" gorm:"-:all"` 86 | } 87 | 88 | type Users []*User 89 | 90 | type UserConfig struct { 91 | // used when notify 92 | Notify []string `json:"notify"` 93 | 94 | // 对折叠内容的处理 95 | // fold 折叠, hide 
隐藏, show 展示 96 | ShowFolded string `json:"show_folded"` 97 | } 98 | 99 | var defaultUserConfig = UserConfig{ 100 | Notify: []string{"mention", "favorite", "report"}, 101 | ShowFolded: "hide", 102 | } 103 | 104 | var showFoldedOptions = []string{"hide", "fold", "show"} 105 | 106 | func (user *User) GetID() int { 107 | return user.ID 108 | } 109 | 110 | func (user *User) AfterCreate(_ *gorm.DB) error { 111 | user.UserID = user.ID 112 | return nil 113 | } 114 | 115 | func (user *User) AfterFind(_ *gorm.DB) error { 116 | user.UserID = user.ID 117 | return nil 118 | } 119 | 120 | var ( 121 | maxTime time.Time 122 | minTime time.Time 123 | ) 124 | 125 | func init() { 126 | var err error 127 | maxTime, err = time.Parse(time.RFC3339, "9999-01-01T00:00:00+00:00") 128 | if err != nil { 129 | log.Fatal().Err(err).Send() 130 | } 131 | minTime = time.Unix(0, 0) 132 | } 133 | 134 | // GetCurrLoginUser get current login user 135 | // In dev or test mode, return a default admin user 136 | func GetCurrLoginUser(c *fiber.Ctx) (*User, error) { 137 | user := &User{ 138 | BanDivision: make(map[int]*time.Time), 139 | } 140 | if config.Config.Mode == "dev" || config.Config.Mode == "test" { 141 | user.ID = 1 142 | user.IsAdmin = true 143 | user.HasAnsweredQuestions = true 144 | return user, nil 145 | } 146 | 147 | if c.Locals("user") != nil { 148 | return c.Locals("user").(*User), nil 149 | } 150 | 151 | // get id 152 | userID, err := common.GetUserID(c) 153 | if err != nil { 154 | return nil, err 155 | } 156 | 157 | // parse JWT 158 | err = common.ParseJWTToken(common.GetJWTToken(c), user) 159 | if err != nil { 160 | return nil, err 161 | } 162 | 163 | // load user from database in transaction 164 | err = user.LoadUserByID(userID) 165 | 166 | if user.IsAdmin { 167 | user.Permission.Admin = maxTime 168 | } else { 169 | user.Permission.Admin = minTime 170 | } 171 | user.Permission.Silent = user.BanDivision 172 | user.Permission.OffenseCount = user.OffenceCount 173 | 174 | if 
config.Config.UserAllShowHidden { 175 | user.Config.ShowFolded = "hide" 176 | } 177 | 178 | // save user in c.Locals 179 | c.Locals("user", user) 180 | 181 | return user, err 182 | } 183 | 184 | func (user *User) LoadUserByID(userID int) error { 185 | return DB.Transaction(func(tx *gorm.DB) error { 186 | err := tx.Take(&user, userID).Error 187 | if err != nil { 188 | if errors.Is(err, gorm.ErrRecordNotFound) { 189 | // insert user if not found 190 | user.ID = userID 191 | user.Config = defaultUserConfig 192 | err = tx.Create(&user).Error 193 | if err != nil { 194 | return err 195 | } 196 | } else { 197 | return err 198 | } 199 | } 200 | 201 | err = CheckDefaultFavoriteGroup(tx, userID) 202 | if err != nil { 203 | return err 204 | } 205 | 206 | // check latest permission 207 | modified := false 208 | for divisionID := range user.BanDivision { 209 | endTime := user.BanDivision[divisionID] 210 | if endTime != nil && endTime.Before(time.Now()) { 211 | delete(user.BanDivision, divisionID) 212 | modified = true 213 | } 214 | } 215 | 216 | // check config 217 | if !slices.Contains(showFoldedOptions, user.Config.ShowFolded) { 218 | user.Config.ShowFolded = defaultUserConfig.ShowFolded 219 | modified = true 220 | } 221 | 222 | if user.Config.Notify == nil { 223 | user.Config.Notify = defaultUserConfig.Notify 224 | modified = true 225 | } 226 | 227 | if modified { 228 | err = tx.Select("BanDivision", "Config").Save(&user).Error 229 | if err != nil { 230 | return err 231 | } 232 | } 233 | 234 | return nil 235 | }) 236 | } 237 | 238 | func (user *User) BanDivisionMessage(divisionID int) string { 239 | if user.BanDivision[divisionID] == nil { 240 | return fmt.Sprintf("您在此板块已被禁言") 241 | } else { 242 | return fmt.Sprintf( 243 | "您在此板块已被禁言,解封时间:%s", 244 | user.BanDivision[divisionID].Format("2006-01-02 15:04:05")) 245 | } 246 | } 247 | 248 | func (user *User) BanReportMessage() string { 249 | if user.BanReport == nil { 250 | return fmt.Sprintf("您已被限制使用举报功能") 251 | } else { 252 | 
return fmt.Sprintf( 253 | "您已被限制使用举报功能,解封时间:%s", 254 | user.BanReport.Format("2006-01-02 15:04:05")) 255 | } 256 | } 257 | -------------------------------------------------------------------------------- /models/user_favorite.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "github.com/opentreehole/go-common" 5 | "time" 6 | 7 | "gorm.io/gorm" 8 | "gorm.io/gorm/clause" 9 | "gorm.io/plugin/dbresolver" 10 | 11 | "treehole_next/utils" 12 | ) 13 | 14 | type UserFavorite struct { 15 | UserID int `json:"user_id" gorm:"primaryKey"` 16 | FavoriteGroupID int `json:"favorite_group_id" gorm:"primaryKey"` 17 | HoleID int `json:"hole_id" gorm:"primaryKey"` 18 | CreatedAt time.Time `json:"time_created"` 19 | } 20 | 21 | type UserFavorites []UserFavorite 22 | 23 | func (UserFavorite) TableName() string { 24 | return "user_favorites" 25 | } 26 | 27 | func IsFavoriteGroupExist(tx *gorm.DB, userID int, favoriteGroupID int) bool { 28 | var num int64 29 | tx.Model(&FavoriteGroup{}).Where("user_id = ? AND favorite_group_id = ? AND deleted = false", userID, favoriteGroupID).Count(&num) 30 | return num > 0 31 | } 32 | 33 | // ModifyUserFavorite only take effect in the same favorite_group 34 | func ModifyUserFavorite(tx *gorm.DB, userID int, holeIDs []int, favoriteGroupID int) error { 35 | if len(holeIDs) == 0 { 36 | return nil 37 | } 38 | if !IsFavoriteGroupExist(tx, userID, favoriteGroupID) { 39 | return common.NotFound("收藏夹不存在") 40 | } 41 | if !IsHolesExist(tx, holeIDs) { 42 | return common.Forbidden("帖子不存在") 43 | } 44 | return tx.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error { 45 | var oldHoleIDs []int 46 | err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). 47 | Model(&UserFavorite{}).Where("user_id = ? AND favorite_group_id = ?", userID, favoriteGroupID). 
48 | Pluck("hole_id", &oldHoleIDs).Error 49 | if err != nil { 50 | return err 51 | } 52 | 53 | // remove user_favorite that not in holeIDs 54 | var removingHoleIDMapping = make(map[int]bool) 55 | for _, holeID := range oldHoleIDs { 56 | removingHoleIDMapping[holeID] = true 57 | } 58 | for _, holeID := range holeIDs { 59 | if removingHoleIDMapping[holeID] { 60 | delete(removingHoleIDMapping, holeID) 61 | } 62 | } 63 | removingHoleIDs := utils.Keys(removingHoleIDMapping) 64 | if len(removingHoleIDs) > 0 { 65 | deleteUserFavorite := make(UserFavorites, 0) 66 | for _, holeID := range removingHoleIDs { 67 | deleteUserFavorite = append(deleteUserFavorite, UserFavorite{UserID: userID, HoleID: holeID, FavoriteGroupID: favoriteGroupID}) 68 | } 69 | err = tx.Delete(&deleteUserFavorite).Error 70 | if err != nil { 71 | return err 72 | } 73 | } 74 | 75 | // insert user_favorite that not in oldHoleIDs 76 | var newHoleIDMapping = make(map[int]bool) 77 | for _, holeID := range holeIDs { 78 | newHoleIDMapping[holeID] = true 79 | } 80 | for _, holeID := range oldHoleIDs { 81 | if newHoleIDMapping[holeID] { 82 | delete(newHoleIDMapping, holeID) 83 | } 84 | } 85 | newHoleIDs := utils.Keys(newHoleIDMapping) 86 | if len(newHoleIDs) > 0 { 87 | insertUserFavorite := make(UserFavorites, 0) 88 | for _, holeID := range newHoleIDs { 89 | insertUserFavorite = append(insertUserFavorite, UserFavorite{UserID: userID, HoleID: holeID, FavoriteGroupID: favoriteGroupID}) 90 | } 91 | err = tx.Create(&insertUserFavorite).Error 92 | if err != nil { 93 | return err 94 | } 95 | } 96 | return tx.Model(&FavoriteGroup{}).Where("user_id = ? 
AND favorite_group_id = ?", userID, favoriteGroupID).Update("count", len(holeIDs)).Error 97 | }) 98 | } 99 | 100 | func AddUserFavorite(tx *gorm.DB, userID int, holeID int, favoriteGroupID int) error { 101 | if !IsFavoriteGroupExist(tx, userID, favoriteGroupID) { 102 | return common.NotFound("收藏夹不存在") 103 | } 104 | if !IsHolesExist(tx, []int{holeID}) { 105 | return common.NotFound("帖子不存在") 106 | } 107 | var err = tx.Clauses(clause.OnConflict{ 108 | DoUpdates: clause.Assignments(Map{"created_at": time.Now()}), 109 | }).Create(&UserFavorite{ 110 | UserID: userID, 111 | HoleID: holeID, 112 | FavoriteGroupID: favoriteGroupID, 113 | }).Error 114 | if err != nil { 115 | return err 116 | } 117 | return tx.Clauses(dbresolver.Write).Model(&FavoriteGroup{}). 118 | Where("user_id = ? AND favorite_group_id = ?", userID, favoriteGroupID).Update("count", gorm.Expr("count + 1")).Error 119 | } 120 | 121 | // UserGetFavoriteData get all favorite data of a user 122 | func UserGetFavoriteData(tx *gorm.DB, userID int) ([]int, error) { 123 | data := make([]int, 0, 10) 124 | err := tx.Clauses(dbresolver.Write).Model(&UserFavorite{}).Where("user_id = ?", userID).Distinct(). 125 | Pluck("hole_id", &data).Error 126 | return data, err 127 | } 128 | 129 | // UserGetFavoriteDataByFavoriteGroup get favorite data in specific favorite group 130 | func UserGetFavoriteDataByFavoriteGroup(tx *gorm.DB, userID int, favoriteGroupID int) ([]int, error) { 131 | if !IsFavoriteGroupExist(tx, userID, favoriteGroupID) { 132 | return nil, common.NotFound("收藏夹不存在") 133 | } 134 | data := make([]int, 0, 10) 135 | err := tx.Clauses(dbresolver.Write).Model(&UserFavorite{}). 136 | Where("user_id = ? 
AND favorite_group_id = ?", userID, favoriteGroupID).Pluck("hole_id", &data).Error 137 | return data, err 138 | } 139 | 140 | // DeleteUserFavorite delete user favorite 141 | // if user favorite hole only once, delete the hole 142 | // otherwise, delete the favorite in the specific favorite group 143 | func DeleteUserFavorite(tx *gorm.DB, userID int, holeID int, favoriteGroupID int) error { 144 | if !IsFavoriteGroupExist(tx, userID, favoriteGroupID) { 145 | return common.NotFound("收藏夹不存在") 146 | } 147 | if !IsHolesExist(tx, []int{holeID}) { 148 | return common.NotFound("帖子不存在") 149 | } 150 | return tx.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error { 151 | err := tx.Delete(&UserFavorite{UserID: userID, HoleID: holeID, FavoriteGroupID: favoriteGroupID}).Error 152 | if err != nil { 153 | return err 154 | } 155 | return tx.Clauses(dbresolver.Write).Model(&FavoriteGroup{}).Where("user_id = ? AND favorite_group_id = ?", userID, favoriteGroupID).Update("count", gorm.Expr("count - 1")).Error 156 | }) 157 | } 158 | 159 | // MoveUserFavorite move holes that are really in the fromFavoriteGroup 160 | func MoveUserFavorite(tx *gorm.DB, userID int, holeIDs []int, fromFavoriteGroupID int, toFavoriteGroupID int) error { 161 | if fromFavoriteGroupID == toFavoriteGroupID { 162 | return nil 163 | } 164 | if len(holeIDs) == 0 { 165 | return nil 166 | } 167 | if !IsFavoriteGroupExist(tx, userID, fromFavoriteGroupID) || !IsFavoriteGroupExist(tx, userID, toFavoriteGroupID) { 168 | return common.NotFound("收藏夹不存在") 169 | } 170 | if !IsHolesExist(tx, holeIDs) { 171 | return common.Forbidden("帖子不存在") 172 | } 173 | return tx.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error { 174 | var oldHoleIDs []int 175 | err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). 176 | Model(&UserFavorite{}).Where("user_id = ? AND favorite_group_id = ?", userID, fromFavoriteGroupID). 
177 | Pluck("hole_id", &oldHoleIDs).Error 178 | if err != nil { 179 | return err 180 | } 181 | 182 | // move user_favorite that in holeIDs 183 | var removingHoleIDMapping = make(map[int]bool) 184 | var removingHoleIDs []int 185 | for _, holeID := range oldHoleIDs { 186 | removingHoleIDMapping[holeID] = true 187 | } 188 | for _, holeID := range holeIDs { 189 | if removingHoleIDMapping[holeID] { 190 | removingHoleIDs = append(removingHoleIDs, holeID) 191 | } 192 | } 193 | if len(removingHoleIDs) > 0 { 194 | err = tx.Table("user_favorites"). 195 | Where("user_id = ? AND favorite_group_id = ? AND hole_id IN ?", userID, fromFavoriteGroupID, removingHoleIDs). 196 | Updates(map[string]interface{}{"favorite_group_id": toFavoriteGroupID}).Error 197 | if err != nil { 198 | return err 199 | } 200 | } 201 | err = tx.Model(&FavoriteGroup{}).Where("user_id = ? AND favorite_group_id = ?", userID, fromFavoriteGroupID).Update("count", gorm.Expr("count - ?", len(removingHoleIDs))).Error 202 | if err != nil { 203 | return err 204 | } 205 | return tx.Model(&FavoriteGroup{}).Where("user_id = ? 
AND favorite_group_id = ?", userID, toFavoriteGroupID).Update("count", gorm.Expr("count + ?", len(removingHoleIDs))).Error 206 | }) 207 | } 208 | -------------------------------------------------------------------------------- /models/user_subscription.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "time" 5 | 6 | "gorm.io/gorm" 7 | "gorm.io/gorm/clause" 8 | "gorm.io/plugin/dbresolver" 9 | ) 10 | 11 | type UserSubscription struct { 12 | UserID int `json:"user_id" gorm:"primaryKey"` 13 | HoleID int `json:"hole_id" gorm:"primaryKey"` 14 | CreatedAt time.Time `json:"time_created"` 15 | } 16 | 17 | type UserSubscriptions []UserSubscription 18 | 19 | func (UserSubscription) TableName() string { 20 | return "user_subscription" 21 | } 22 | 23 | func UserGetSubscriptionData(tx *gorm.DB, userID int) ([]int, error) { 24 | data := make([]int, 0, 10) 25 | err := tx.Clauses(dbresolver.Write).Raw("SELECT hole_id FROM user_subscription WHERE user_id = ? 
ORDER BY created_at", userID).Scan(&data).Error 26 | return data, err 27 | } 28 | 29 | func AddUserSubscription(tx *gorm.DB, userID int, holeID int) error { 30 | return tx.Clauses(clause.OnConflict{ 31 | DoUpdates: clause.Assignments(Map{"created_at": time.Now()}), 32 | }).Create(&UserSubscription{ 33 | UserID: userID, 34 | HoleID: holeID}).Error 35 | } 36 | -------------------------------------------------------------------------------- /models/user_test.go: -------------------------------------------------------------------------------- 1 | package models 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/opentreehole/go-common" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestParseJWT(t *testing.T) { 11 | var user User 12 | jwt := "Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ1aWQiOjE2LCJpc3MiOiJEU2lSa2NvWDJZV3dta3VqM3FFdFVxSE1uUnNvMjZQYiIsImlhdCI6MTY2MjUyNzg5OSwiaWQiOjE2LCJpc19hZG1pbiI6ZmFsc2UsIm5pY2tuYW1lIjoidXNlciIsIm9mZmVuc2VfY291bnQiOjAsInJvbGVzIjpbXSwidHlwZSI6ImFjY2VzcyIsImV4cCI6MTY2MjUyOTY5OX0.Ov_8cJay-Ta0jsPYUx1D-XDc_D1WK1iTdjnuEKAelaM" 13 | err := common.ParseJWTToken(jwt, &user) 14 | assert.Nilf(t, err, "ParseJWTToken failed: %v", err) 15 | } 16 | -------------------------------------------------------------------------------- /tests/default.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | var largeInt = 1145141919810 4 | -------------------------------------------------------------------------------- /tests/default_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestIndex(t *testing.T) { 10 | testAPI(t, "get", "/", 302, nil) 11 | testAPI(t, "get", "/api", 200, nil) 12 | web404 := testAPI(t, "get", "/404", 404, nil) 13 | assert.EqualValues(t, "Cannot GET /404", web404["message"]) 14 | } 15 | 16 | func TestDocs(t 
*testing.T) { 17 | testCommon(t, "get", "/docs", 302) 18 | testCommon(t, "get", "/docs/index.html", 200) 19 | } 20 | -------------------------------------------------------------------------------- /tests/division_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "strconv" 5 | "testing" 6 | 7 | . "treehole_next/models" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestGetDivision(t *testing.T) { 13 | var divisionPinned = []int{0, 2, 3, 1, largeInt} 14 | 15 | var d Division 16 | DB.First(&d, 1) 17 | d.Pinned = divisionPinned 18 | DB.Save(&d) 19 | 20 | var division Division 21 | testAPIModel(t, "get", "/api/divisions/1", 200, &division) 22 | // test pinned order 23 | respPinned := make([]int, 3) 24 | for i, p := range division.Holes { 25 | respPinned[i] = p.ID 26 | } 27 | assert.Equal(t, []int{2, 3, 1}, respPinned) 28 | } 29 | 30 | func TestListDivision(t *testing.T) { 31 | // return all divisions 32 | var length int64 33 | DB.Table("division").Count(&length) 34 | resp := testAPIArray(t, "get", "/api/divisions", 200) 35 | assert.Equal(t, length, int64(len(resp))) 36 | } 37 | 38 | func TestAddDivision(t *testing.T) { 39 | data := Map{"name": "TestAddDivision", "description": "TestAddDivisionDescription"} 40 | testAPI(t, "post", "/api/divisions", 201, data) 41 | 42 | // duplicate post, return 200 and change nothing 43 | data["description"] = "another" 44 | resp := testAPI(t, "post", "/api/divisions", 200, data) 45 | assert.Equal(t, "TestAddDivisionDescription", resp["description"]) 46 | } 47 | 48 | func TestModifyDivision(t *testing.T) { 49 | pinned := []int{3, 2, 5, 1, 4} 50 | data := Map{"name": "modify", "description": "modify", "pinned": pinned} 51 | 52 | var division Division 53 | testAPIModel(t, "put", "/api/divisions/1", 200, &division, data) 54 | 55 | // test modify 56 | assert.Equal(t, "modify", division.Name) 57 | assert.Equal(t, "modify", division.Description) 
58 | 59 | // test pinned order 60 | respPinned := make([]int, 5) 61 | for i, d := range division.Holes { 62 | respPinned[i] = d.ID 63 | } 64 | assert.Equal(t, pinned, respPinned) 65 | } 66 | 67 | func TestDeleteDivision(t *testing.T) { 68 | id := 3 69 | toID := 2 70 | 71 | hole := Hole{DivisionID: id} 72 | DB.Create(&hole) 73 | testAPI(t, "delete", "/api/divisions/"+strconv.Itoa(id), 204, Map{"to": toID}) 74 | testAPI(t, "delete", "/api/divisions/"+strconv.Itoa(id), 204, Map{}) // repeat delete 75 | 76 | // deleted 77 | var d Division 78 | result := DB.First(&d, id) 79 | assert.True(t, result.Error != nil) 80 | 81 | // hole moved 82 | DB.First(&hole, hole.ID) 83 | assert.Equal(t, toID, hole.DivisionID) 84 | 85 | } 86 | 87 | func TestDeleteDivisionDefaultValue(t *testing.T) { 88 | id := 4 89 | toID := 1 90 | 91 | // if create hole here, say database lock, pending enquiry 92 | var hole, getHole Hole 93 | DB.Where("division_id = ?", id).First(&hole) 94 | testAPI(t, "delete", "/api/divisions/"+strconv.Itoa(id), 204, Map{}) 95 | 96 | // hole moved 97 | DB.Take(&getHole, hole.ID) 98 | assert.Equal(t, toID, getHole.DivisionID) 99 | 100 | } 101 | -------------------------------------------------------------------------------- /tests/favorite_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "testing" 5 | 6 | . 
"treehole_next/models" 7 | 8 | "github.com/stretchr/testify/assert" 9 | "golang.org/x/exp/slices" 10 | ) 11 | 12 | func TestListFavorites(t *testing.T) { 13 | var holes Holes 14 | testAPIModel(t, "get", "/api/user/favorites", 200, &holes) 15 | assert.EqualValues(t, 10, len(holes)) 16 | } 17 | 18 | func TestAddFavorite(t *testing.T) { 19 | data := Map{"hole_id": 11} 20 | testAPI(t, "post", "/api/user/favorites", 201, data) 21 | testAPI(t, "post", "/api/user/favorites", 201, data) // duplicated, refresh updated_at 22 | } 23 | 24 | func TestModifyFavorites(t *testing.T) { 25 | data := Map{"hole_ids": []int{1, 2, 5, 6, 7}} 26 | testAPI(t, "put", "/api/user/favorites", 201, data) 27 | testAPI(t, "put", "/api/user/favorites", 201, data) // duplicated 28 | var userFavorites []UserFavorite 29 | DB.Where("user_id = ?", 1).Find(&userFavorites) 30 | assert.EqualValues(t, 5, len(userFavorites)) 31 | } 32 | 33 | func TestDeleteFavorite(t *testing.T) { 34 | data := Map{"hole_id": 1} 35 | testAPI(t, "delete", "/api/user/favorites", 200, data) 36 | var userFavorites []UserFavorite 37 | DB.Where("user_id = ?", 1).Find(&userFavorites) 38 | assert.EqualValues(t, false, slices.Contains(userFavorites, UserFavorite{UserID: 1, HoleID: 1})) 39 | favouriteLen := len(userFavorites) 40 | 41 | testAPI(t, "delete", "/api/user/favorites", 200, data) // duplicated 42 | DB.Where("user_id = ?", 1).Find(&userFavorites) 43 | assert.EqualValues(t, favouriteLen, len(userFavorites)) 44 | } 45 | -------------------------------------------------------------------------------- /tests/floor_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "strconv" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/goccy/go-json" 9 | 10 | . "treehole_next/config" 11 | . 
"treehole_next/models" 12 | 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func TestListFloorsInAHole(t *testing.T) { 17 | var hole Hole 18 | DB.Where("division_id = ?", 7).First(&hole) 19 | var floors Floors 20 | testAPIModel(t, "get", "/api/holes/"+strconv.Itoa(hole.ID)+"/floors", 200, &floors) 21 | assert.EqualValues(t, Config.Size, len(floors)) 22 | if len(floors) != 0 { 23 | assert.EqualValues(t, "1", floors[0].Content) 24 | } 25 | 26 | // size 27 | size := 38 28 | data := Map{"size": size} 29 | testAPIModelWithQuery(t, "get", "/api/holes/"+strconv.Itoa(hole.ID)+"/floors", 200, &floors, data) 30 | assert.EqualValues(t, size, len(floors)) 31 | if len(floors) != 0 { 32 | assert.EqualValues(t, "1", floors[0].Content) 33 | } 34 | 35 | // offset 36 | offset := 7 37 | data = Map{"offset": offset} 38 | testAPIModelWithQuery(t, "get", "/api/holes/"+strconv.Itoa(hole.ID)+"/floors", 200, &floors, data) 39 | assert.EqualValues(t, Config.Size, len(floors)) 40 | if len(floors) != 0 { 41 | assert.EqualValues(t, strings.Repeat("1", offset+1), floors[0].Content) 42 | } 43 | } 44 | 45 | func TestListFloorsOld(t *testing.T) { 46 | var hole Hole 47 | DB.Where("division_id = ?", 7).First(&hole) 48 | data := Map{"hole_id": hole.ID} 49 | var floors Floors 50 | testAPIModelWithQuery(t, "get", "/api/floors", 200, &floors, data) 51 | assert.EqualValues(t, Config.MaxSize, len(floors)) 52 | if len(floors) != 0 { 53 | assert.EqualValues(t, "1", floors[0].Content) 54 | } 55 | } 56 | 57 | func TestGetFloor(t *testing.T) { 58 | var hole Hole 59 | DB.Where("division_id = ?", 7).First(&hole) 60 | var floor Floor 61 | DB.Where("hole_id = ?", hole.ID).First(&floor) 62 | var getFloor Floor 63 | testAPIModel(t, "get", "/api/floors/"+strconv.Itoa(floor.ID), 200, &getFloor) 64 | assert.EqualValues(t, floor.Content, getFloor.Content) 65 | 66 | testAPIModel(t, "get", "/api/floors/"+strconv.Itoa(largeInt), 404, &getFloor) 67 | } 68 | 69 | func TestCreateFloor(t *testing.T) { 70 | var hole 
Hole 71 | DB.Where("division_id = ?", 7).Offset(1).First(&hole) 72 | content := "123" 73 | data := Map{"content": content} 74 | var getFloor Floor 75 | testAPIModel(t, "post", "/api/holes/"+strconv.Itoa(hole.ID)+"/floors", 201, &getFloor, data) 76 | assert.EqualValues(t, content, getFloor.Content) 77 | 78 | var floors Floors 79 | DB.Where("hole_id = ?", hole.ID).Find(&floors) 80 | assert.EqualValues(t, 2, len(floors)) 81 | 82 | testAPIModel(t, "post", "/api/holes/"+strconv.Itoa(largeInt)+"/floors", 404, &getFloor, data) 83 | } 84 | 85 | func TestCreateFloorOld(t *testing.T) { 86 | var hole Hole 87 | DB.Where("division_id = ?", 7).Offset(2).First(&hole) 88 | content := "1234" 89 | data := Map{"hole_id": hole.ID, "content": content} 90 | type CreateOLdResponse struct { 91 | Data Floor 92 | Message string 93 | } 94 | var getFloor CreateOLdResponse 95 | rsp := testCommon(t, "post", "/api/floors", 201, data) 96 | err := json.Unmarshal(rsp, &getFloor) 97 | assert.Nilf(t, err, "Unmarshal Failed") 98 | assert.EqualValues(t, content, getFloor.Data.Content) 99 | 100 | var floors Floors 101 | DB.Where("hole_id = ?", hole.ID).Find(&floors) 102 | assert.EqualValues(t, 2, len(floors)) 103 | if len(floors) != 0 { 104 | assert.EqualValues(t, content, floors[1].Content) 105 | } 106 | 107 | testCommon(t, "post", "/api/holes/"+strconv.Itoa(123456)+"/floors", 404, data) 108 | } 109 | 110 | func TestModifyFloor(t *testing.T) { 111 | var hole Hole 112 | DB.Where("division_id = ?", 7).Offset(3).First(&hole) 113 | var floor Floor 114 | DB.Where("hole_id = ?", hole.ID).First(&floor) 115 | content := "12341234" 116 | data := Map{"content": content} 117 | var getFloor Floor 118 | 119 | // modify content 120 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data) 121 | 122 | DB.Find(&getFloor, floor.ID) 123 | assert.EqualValues(t, content, getFloor.Content) 124 | 125 | // modify fold 126 | // test 1: fold == ["test"], fold_v2 == "" 127 | data = Map{"fold": []string{"test"}} 128 
| testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data) 129 | DB.Find(&getFloor, floor.ID) 130 | assert.EqualValues(t, data["fold"].([]string)[0], getFloor.Fold) 131 | 132 | // test2: fold == [], fold_v2 == "": expect reset fold 133 | data = Map{"fold": []string{}} 134 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data) 135 | DB.Find(&getFloor, floor.ID) 136 | assert.EqualValues(t, "", getFloor.Fold) 137 | 138 | // test3: fold == [], fold_v2 == "test_test" 139 | data = Map{"fold_v2": "test_test"} 140 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data) 141 | DB.Find(&getFloor, floor.ID) 142 | assert.EqualValues(t, data["fold_v2"], getFloor.Fold) 143 | 144 | // test4: fold == [], fold_v2 == "": expect reset fold 145 | data = Map{"fold": []string{}} 146 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data) 147 | DB.Find(&getFloor, floor.ID) 148 | assert.EqualValues(t, "", getFloor.Fold) 149 | 150 | // test5: fold == ["test"], fold_v2 == "test_test": expect "test_test", fold_v2 has the priority 151 | data = Map{"fold": []string{"test"}, "fold_v2": "test_test"} 152 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data) 153 | DB.Find(&getFloor, floor.ID) 154 | assert.EqualValues(t, "test_test", getFloor.Fold) 155 | 156 | // test6: fold == nil, fold_v2 == "": do nothing; 无效请求 157 | data = Map{} 158 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 400, data) 159 | DB.Find(&getFloor, floor.ID) 160 | assert.EqualValues(t, "test_test", getFloor.Fold) 161 | 162 | // modify like add old 163 | data = Map{"like": "add"} 164 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data) 165 | DB.Find(&getFloor, floor.ID) 166 | assert.EqualValues(t, 1, getFloor.Like) 167 | 168 | // modify like reset old 169 | data = Map{"like": "cancel"} 170 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data) 171 | DB.Find(&getFloor, floor.ID) 172 | assert.EqualValues(t, 0, 
getFloor.Like) 173 | } 174 | 175 | func TestModifyFloorLike(t *testing.T) { 176 | var hole Hole 177 | DB.Where("division_id = ?", 7).Offset(4).First(&hole) 178 | var floor Floor 179 | DB.Where("hole_id = ?", hole.ID).First(&floor) 180 | 181 | // like 182 | for i := 0; i < 10; i++ { 183 | testAPI(t, "post", "/api/floors/"+strconv.Itoa(floor.ID)+"/like/1", 200) 184 | } 185 | DB.First(&floor, floor.ID) 186 | assert.EqualValues(t, 1, floor.Like) 187 | assert.EqualValues(t, 0, floor.Dislike) 188 | 189 | // dislike 190 | for i := 0; i < 15; i++ { 191 | testAPI(t, "post", "/api/floors/"+strconv.Itoa(floor.ID)+"/like/-1", 200) 192 | } 193 | DB.First(&floor, floor.ID) 194 | assert.EqualValues(t, 0, floor.Like) 195 | assert.EqualValues(t, 1, floor.Dislike) 196 | 197 | // reset 198 | testAPI(t, "post", "/api/floors/"+strconv.Itoa(floor.ID)+"/like/0", 200) 199 | DB.First(&floor, floor.ID) 200 | assert.EqualValues(t, 0, floor.Like) 201 | } 202 | 203 | func TestDeleteFloor(t *testing.T) { 204 | var hole Hole 205 | DB.Where("division_id = ?", 7).Offset(5).First(&hole) 206 | var floor Floor 207 | DB.Where("hole_id = ?", hole.ID).First(&floor) 208 | content := "1234567" 209 | data := Map{"delete_reason": content} 210 | 211 | testAPI(t, "delete", "/api/floors/"+strconv.Itoa(floor.ID), 200, data) 212 | 213 | DB.First(&floor, floor.ID) 214 | assert.EqualValues(t, true, floor.Deleted) 215 | var floorHistory FloorHistory 216 | DB.Where("floor_id = ?", floor.ID).First(&floorHistory) 217 | assert.EqualValues(t, content, floorHistory.Reason) 218 | 219 | // permission 220 | floor = Floor{} 221 | DB.Where("hole_id = ?", hole.ID).Offset(1).First(&floor) 222 | testAPI(t, "delete", "/api/floors/"+strconv.Itoa(floor.ID), 200, data) 223 | } 224 | -------------------------------------------------------------------------------- /tests/hole_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "strconv" 5 | "strings" 6 | 
"testing" 7 | 8 | . "treehole_next/config" 9 | . "treehole_next/models" 10 | "treehole_next/utils" 11 | 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestListHoleInADivision(t *testing.T) { 16 | var holes Holes 17 | var ids []int 18 | 19 | DB.Raw("SELECT id FROM hole WHERE division_id = 6 AND hidden = 0 ORDER BY updated_at DESC").Scan(&ids) 20 | 21 | testAPIModel(t, "get", "/api/divisions/6/holes", 200, &holes) 22 | assert.Equal(t, ids[:Config.HoleFloorSize], utils.Models2IDSlice(holes)) 23 | 24 | testAPIModel(t, "get", "/api/divisions/"+strconv.Itoa(largeInt)+"/holes", 200, &holes) // return empty holes 25 | testAPI(t, "get", "/api/divisions/"+strings.Repeat(strconv.Itoa(largeInt), 15)+"/holes", 500) // huge divisionID 26 | } 27 | 28 | func TestListHolesByTag(t *testing.T) { 29 | var tag Tag 30 | DB.Where("name = ?", "114").First(&tag) 31 | var holes Holes 32 | err := DB.Model(&tag).Association("Holes").Find(&holes) 33 | if err != nil { 34 | t.Fatal(err) 35 | } 36 | 37 | var getHoles Holes 38 | testAPIModel(t, "get", "/api/tags/114/holes", 200, &getHoles) 39 | assert.EqualValues(t, len(holes), len(getHoles)) 40 | 41 | // empty holes 42 | testAPIModel(t, "get", "/api/tags/115/holes", 200, &getHoles) 43 | assert.EqualValues(t, Holes{}, getHoles) 44 | } 45 | 46 | func TestCreateHole(t *testing.T) { 47 | content := "abcdef" 48 | data := Map{"content": content, "tags": []Map{{"name": "a"}, {"name": "ab"}, {"name": "abc"}}} 49 | testAPI(t, "post", "/api/divisions/1/holes", 201, data) 50 | data["tags"] = []Map{{"name": "abcd"}, {"name": "ab"}, {"name": "abc"}} // update temperature or create tag 51 | testAPI(t, "post", "/api/divisions/1/holes", 201, data) 52 | 53 | tag := Tag{} 54 | DB.Where("name = ?", "a").First(&tag) 55 | assert.EqualValues(t, 1, tag.Temperature) 56 | tag = Tag{} 57 | DB.Where("name = ?", "abc").First(&tag) 58 | assert.EqualValues(t, 2, tag.Temperature) 59 | assert.EqualValues(t, 2, DB.Model(&tag).Association("Holes").Count()) 60 | 
61 | data = Map{"content": content, "tags": []Map{}} 62 | testAPI(t, "post", "/api/divisions/1/holes", 400, data) // at least one tag 63 | 64 | content = strings.Repeat("~", 15001) 65 | data = Map{"content": content, "tags": []Map{{"name": "a"}, {"name": "ab"}, {"name": "abc"}}} 66 | testAPI(t, "post", "/api/divisions/1/holes", 400, data) // data no more than 10000 67 | 68 | tags := make([]Map, 11) 69 | for i := range tags { 70 | tags[i] = Map{"name": strconv.Itoa(i)} 71 | } 72 | data = Map{"content": "123456789", "tags": tags} // at most 10 tags 73 | testAPI(t, "post", "/api/divisions/1/holes", 400, data) 74 | } 75 | 76 | func TestCreateHoleOld(t *testing.T) { 77 | content := "abcdef" 78 | tagName := []Map{{"name": "d"}, {"name": "de"}, {"name": "def"}} 79 | division_id := 1 80 | data := Map{"content": content, "tags": tagName, "division_id": division_id} 81 | testAPI(t, "post", "/api/holes", 201, data) 82 | tagName = []Map{{"name": "abc"}, {"name": "defg"}, {"name": "de"}} 83 | data = Map{"content": content, "tags": tagName, "division_id": division_id} 84 | testAPI(t, "post", "/api/holes", 201, data) 85 | 86 | var holes Holes 87 | var tag Tag 88 | DB.Where("name = ?", "def").First(&tag) 89 | err := DB.Model(&tag).Association("Holes").Find(&holes) 90 | if err != nil { 91 | t.Fatal(err) 92 | } 93 | } 94 | 95 | func TestModifyHole(t *testing.T) { 96 | var tag Tag 97 | DB.Where("name = ?", "111").First(&tag) 98 | var holes Holes 99 | err := DB.Model(&tag).Association("Holes").Find(&holes) 100 | if err != nil { 101 | t.Fatal(err) 102 | } 103 | 104 | tagName := []Map{{"name": "111"}, {"name": "d"}, {"name": "de"}, {"name": "def"}} 105 | division_id := 5 106 | data := Map{"tags": tagName, "division_id": division_id} 107 | testAPI(t, "put", "/api/holes/"+strconv.Itoa(holes[0].ID), 200, data) 108 | 109 | DB.Preload("Tags").Where("id = ?", holes[0].ID).Find(&holes[0]) 110 | 111 | var getTagName []Map 112 | for _, v := range holes[0].Tags { 113 | getTagName = 
append(getTagName, Map{"name": v.Name}) 114 | } 115 | assert.EqualValues(t, tagName, getTagName) 116 | assert.EqualValues(t, division_id, holes[0].DivisionID) 117 | 118 | // default schemas 119 | testAPI(t, "put", "/api/holes/"+strconv.Itoa(holes[0].ID), 400, Map{}) // bad request if modify nothing 120 | DB.Where("id = ?", holes[0].ID).Find(&holes[0]) 121 | assert.Equal(t, division_id, holes[0].DivisionID) 122 | } 123 | 124 | func TestDeleteHole(t *testing.T) { 125 | var hole Hole 126 | holeID := 10 127 | testAPI(t, "delete", "/api/holes/"+strconv.Itoa(holeID), 204) 128 | testAPI(t, "delete", "/api/holes/"+strconv.Itoa(largeInt), 404) 129 | DB.Where("id = ?", 10).Find(&hole) 130 | assert.Equal(t, true, hole.Hidden) 131 | } 132 | -------------------------------------------------------------------------------- /tests/init.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "strconv" 5 | "strings" 6 | 7 | "github.com/rs/zerolog/log" 8 | 9 | "treehole_next/config" 10 | . 
"treehole_next/models" 11 | ) 12 | 13 | func init() { 14 | initTestDivision() 15 | initTestHoles() 16 | initTestFloors() 17 | initTestTags() 18 | initTestFavorites() 19 | initTestReports() 20 | 21 | config.Config.OpenSensitiveCheck = false 22 | } 23 | 24 | func initTestDivision() { 25 | divisions := make(Divisions, 10) 26 | for i := range divisions { 27 | divisions[i] = &Division{ 28 | ID: i + 1, 29 | Name: strconv.Itoa(i), 30 | Description: strconv.Itoa(i), 31 | } 32 | } 33 | holes := make(Holes, 10) 34 | for i := range holes { 35 | holes[i] = &Hole{ 36 | DivisionID: 1, 37 | } 38 | } 39 | holes[9].DivisionID = 4 // for TestDeleteDivisionDefaultValue 40 | err := DB.Create(&divisions).Error 41 | if err != nil { 42 | log.Fatal().Err(err).Send() 43 | } 44 | err = DB.Create(&holes).Error 45 | if err != nil { 46 | log.Fatal().Err(err).Send() 47 | } 48 | } 49 | 50 | func initTestHoles() { 51 | holes := make(Holes, 10) 52 | for i := range holes { 53 | holes[i] = &Hole{ 54 | DivisionID: 6, 55 | } 56 | } 57 | tag := Tag{Name: "114", Temperature: 15} 58 | holes[1].Tags = Tags{&tag} 59 | holes[2].Tags = Tags{&tag} // here it will insert twice in latest version gorm 60 | holes[3].Tags = Tags{{Name: "111", Temperature: 23}, {Name: "222", Temperature: 45}} 61 | err := DB.Create(&holes).Error 62 | if err != nil { 63 | log.Fatal().Err(err).Send() 64 | } 65 | tag = Tag{Name: "115"} 66 | err = DB.Create(&tag).Error 67 | if err != nil { 68 | log.Fatal().Err(err).Send() 69 | } 70 | } 71 | 72 | func initTestFloors() { 73 | holes := make(Holes, 10) 74 | for i := range holes { 75 | holes[i] = &Hole{ 76 | DivisionID: 7, 77 | } 78 | } 79 | for i := 1; i <= 50; i++ { 80 | holes[0].Floors = append(holes[0].Floors, &Floor{Content: strings.Repeat("1", i), Ranking: i - 1}) 81 | } 82 | holes[0].Floors[10].Mention = Floors{ 83 | {HoleID: 102}, 84 | {HoleID: 304}, 85 | } 86 | holes[0].Floors[11].Mention = Floors{ 87 | {HoleID: 506}, 88 | {HoleID: 708}, 89 | } 90 | holes[1].Floors = 
Floors{{Content: "123456789"}} // for TestCreate 91 | holes[2].Floors = Floors{{Content: "123456789"}} // for TestCreate 92 | holes[3].Floors = Floors{{Content: "123456789"}} // for TestModify 93 | holes[4].Floors = Floors{{Content: "123456789"}} // for TestModify like 94 | holes[5].Floors = Floors{{Content: "123456789", UserID: 1}, {Content: "23333", UserID: 5, Ranking: 1}} // for TestDelete 95 | err := DB.Create(&holes).Error 96 | if err != nil { 97 | log.Fatal().Err(err).Send() 98 | } 99 | } 100 | 101 | func initTestTags() { 102 | holes := make(Holes, 5) 103 | tags := make(Tags, 6) 104 | hole_tags := [][]int{ 105 | {0, 1, 2}, 106 | {3}, 107 | {0, 4}, 108 | {1, 0, 2}, 109 | {2, 3, 4}, 110 | {0, 4}, 111 | } // int[tag_id][hole_id] 112 | 113 | for i := range holes { 114 | holes[i] = &Hole{DivisionID: 8} 115 | } 116 | 117 | for i := range tags { 118 | tags[i] = &Tag{Name: strconv.Itoa(i + 1)} 119 | for _, v := range hole_tags[i] { 120 | tags[i].Holes = append(tags[i].Holes, holes[v]) 121 | } 122 | } 123 | 124 | tags[0].Temperature = 5 125 | tags[2].Temperature = 25 126 | tags[5].Temperature = 34 127 | err := DB.Create(&tags).Error 128 | if err != nil { 129 | log.Fatal().Err(err).Send() 130 | } 131 | } 132 | 133 | func initTestFavorites() { 134 | favoriteGroup := FavoriteGroup{Name: "test", UserID: 1, FavoriteGroupID: 0} 135 | err := DB.Create(&favoriteGroup).Error 136 | if err != nil { 137 | log.Fatal().Err(err).Send() 138 | } 139 | userFavorites := make([]UserFavorite, 10) 140 | for i := range userFavorites { 141 | userFavorites[i].HoleID = i + 1 142 | userFavorites[i].UserID = 1 143 | } 144 | err = DB.Create(&userFavorites).Error 145 | if err != nil { 146 | log.Fatal().Err(err).Send() 147 | } 148 | } 149 | 150 | const ( 151 | REPORT_BASE_ID = 1 152 | REPORT_FLOOR_BASE_ID = 1001 153 | ) 154 | 155 | func initTestReports() { 156 | hole := Hole{ID: 1000} 157 | floors := make(Floors, 20) 158 | for i := range floors { 159 | floors[i] = &Floor{ 160 | ID: 
REPORT_FLOOR_BASE_ID + i, 161 | HoleID: 1000, 162 | Ranking: i, 163 | UserID: 1, 164 | } 165 | } 166 | reports := make([]Report, 10) 167 | for i := range reports { 168 | reports[i].ID = REPORT_BASE_ID + i 169 | reports[i].FloorID = REPORT_FLOOR_BASE_ID + i 170 | reports[i].UserID = 1 171 | if i < 5 { 172 | reports[i].Dealt = true 173 | } 174 | } 175 | 176 | err := DB.Create(&hole).Error 177 | if err != nil { 178 | log.Fatal().Err(err).Send() 179 | } 180 | err = DB.Create(&floors).Error 181 | if err != nil { 182 | log.Fatal().Err(err).Send() 183 | } 184 | err = DB.Create(&reports).Error 185 | if err != nil { 186 | log.Fatal().Err(err).Send() 187 | } 188 | } 189 | -------------------------------------------------------------------------------- /tests/report_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "strconv" 5 | "testing" 6 | 7 | "github.com/rs/zerolog/log" 8 | 9 | . "treehole_next/models" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestGetReport(t *testing.T) { 15 | reportID := REPORT_BASE_ID 16 | var report Report 17 | DB.First(&report, reportID) 18 | log.Info().Any("report", report).Send() 19 | 20 | var getReport Report 21 | testAPIModel(t, "get", "/api/reports/"+strconv.Itoa(reportID), 200, &getReport) 22 | assert.EqualValues(t, report.FloorID, getReport.FloorID) 23 | assert.EqualValues(t, report.FloorID, getReport.Floor.FloorID) 24 | } 25 | 26 | func TestListReport(t *testing.T) { 27 | data := Map{} 28 | 29 | var getReports Reports 30 | testAPIModelWithQuery(t, "get", "/api/reports", 200, &getReports, data) 31 | log.Printf("getReports: %+v\n", getReports) 32 | 33 | data = Map{"range": 1} 34 | testAPIModelWithQuery(t, "get", "/api/reports", 200, &getReports, data) 35 | log.Printf("getReports: %+v\n", getReports) 36 | 37 | data = Map{"range": 2} 38 | testAPIModelWithQuery(t, "get", "/api/reports", 200, &getReports, data) 39 | log.Printf("getReports: 
%+v\n", getReports) 40 | } 41 | 42 | func TestAddReport(t *testing.T) { 43 | data := Map{"floor_id": REPORT_FLOOR_BASE_ID + 14, "reason": "123456789"} 44 | 45 | testAPI(t, "post", "/api/reports", 204, data) 46 | } 47 | 48 | func TestDeleteReport(t *testing.T) { 49 | reportID := REPORT_BASE_ID + 7 50 | var getReport Report 51 | data := Map{"result": "123456789"} 52 | testAPI(t, "delete", "/api/reports/"+strconv.Itoa(reportID), 200, data) 53 | 54 | DB.First(&getReport, reportID) 55 | assert.EqualValues(t, true, getReport.Dealt) 56 | } 57 | -------------------------------------------------------------------------------- /tests/tag_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "strconv" 5 | "testing" 6 | 7 | . "treehole_next/models" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func init() { 13 | 14 | } 15 | 16 | func TestListTag(t *testing.T) { 17 | var length int64 18 | DB.Table("tag").Count(&length) 19 | resp := testAPIArray(t, "get", "/api/tags", 200) 20 | assert.Equal(t, length, int64(len(resp))) 21 | } 22 | 23 | func TestGetTag(t *testing.T) { 24 | id := 3 25 | 26 | var tag Tag 27 | DB.First(&tag, id) 28 | 29 | var newTag Tag 30 | testAPIModel(t, "get", "/api/tags/"+strconv.Itoa(id), 200, &newTag) 31 | assert.Equalf(t, tag.Name, newTag.Name, "get tag") 32 | } 33 | 34 | func TestCreateTag(t *testing.T) { 35 | data := Map{"name": "name"} 36 | testAPI(t, "post", "/api/tags", 201, data) 37 | 38 | // duplicate post, return 200 and change nothing 39 | testAPI(t, "post", "/api/tags", 200, data) 40 | } 41 | 42 | func TestModifyTag(t *testing.T) { 43 | id := 3 44 | data := Map{"name": "another", "temperature": 34} 45 | 46 | testAPI(t, "put", "/api/tags/"+strconv.Itoa(id), 200, data) 47 | 48 | var tag Tag 49 | DB.Model(&Tag{}).First(&tag, 3) 50 | assert.Equalf(t, "another", tag.Name, "modify tag name") 51 | assert.Equalf(t, 34, tag.Temperature, "modify tag tempeture") 52 | } 
53 | 54 | func TestDeleteTag(t *testing.T) { 55 | 56 | // Move holes to existed tag 57 | fromName := "1" 58 | toName := "6" 59 | var id int 60 | DB.Model(Tag{}).Where("name = ?", fromName).Pluck("id", &id) 61 | data := Map{"to": toName} 62 | testAPI(t, "delete", "/api/tags/"+strconv.Itoa(id), 200, data) 63 | var tag Tag 64 | DB.Where("name = ?", toName).First(&tag) 65 | associationHolesLen := DB.Model(&tag).Association("Holes").Count() 66 | assert.EqualValuesf(t, 4, associationHolesLen, "move holes") 67 | assert.EqualValuesf(t, 39, tag.Temperature, "tag Temperature add") 68 | tag = Tag{} 69 | 70 | if result := DB.First(&tag, id); result.Error == nil { 71 | assert.Error(t, result.Error, "delete tags") 72 | } 73 | 74 | // Duplicated delete holes 75 | testAPI(t, "delete", "/api/tags/"+strconv.Itoa(id), 404, data) 76 | 77 | // Move holes to new tag 78 | id = 8 79 | data["to"] = "iii555" 80 | testAPI(t, "delete", "/api/tags/"+strconv.Itoa(id), 404, data) 81 | } 82 | -------------------------------------------------------------------------------- /tests/utils.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "net/http" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/goccy/go-json" 11 | 12 | "treehole_next/bootstrap" 13 | . 
"treehole_next/models" 14 | 15 | "github.com/hetiansu5/urlquery" 16 | "github.com/stretchr/testify/assert" 17 | ) 18 | 19 | type JsonData interface { 20 | Map | []Map 21 | } 22 | 23 | var App, _ = bootstrap.Init() 24 | 25 | // testCommon tests status code and returns response body in bytes 26 | func testCommon(t *testing.T, method string, route string, statusCode int, data ...Map) []byte { 27 | var requestData []byte 28 | var err error 29 | 30 | if len(data) > 0 && data[0] != nil { // data[0] is request data 31 | requestData, err = json.Marshal(data[0]) 32 | assert.Nilf(t, err, "encode request body") 33 | } 34 | req, err := http.NewRequest( 35 | strings.ToUpper(method), 36 | route, 37 | bytes.NewBuffer(requestData), 38 | ) 39 | req.Header.Add("Content-Type", "application/json") 40 | req.Header.Add("X-Consumer-Username", "1") // for common.GetUserID 41 | assert.Nilf(t, err, "constructs http request") 42 | 43 | res, err := App.Test(req, -1) 44 | assert.Nilf(t, err, "perform request") 45 | assert.Equalf(t, statusCode, res.StatusCode, "status code") 46 | 47 | responseBody, err := io.ReadAll(res.Body) 48 | assert.Nilf(t, err, "decode response") 49 | 50 | return responseBody 51 | } 52 | 53 | // testCommonQuery tests status code and returns response body in bytes 54 | func testCommonQuery(t *testing.T, method string, route string, statusCode int, data ...Map) []byte { 55 | var err error 56 | req, err := http.NewRequest( 57 | strings.ToUpper(method), 58 | route, 59 | nil, 60 | ) 61 | if len(data) > 0 && data[0] != nil { // data[0] is query data 62 | queryData, err := urlquery.Marshal(data[0]) 63 | req.URL.RawQuery = string(queryData) 64 | assert.Nilf(t, err, "encode request body") 65 | } 66 | 67 | req.Header.Add("Content-Type", "application/json") 68 | req.Header.Add("X-Consumer-Username", "1") // for common.GetUserID 69 | assert.Nilf(t, err, "constructs http request") 70 | 71 | res, err := App.Test(req, -1) 72 | assert.Nilf(t, err, "perform request") 73 | assert.Equalf(t, 
statusCode, res.StatusCode, "status code") 74 | 75 | responseBody, err := io.ReadAll(res.Body) 76 | assert.Nilf(t, err, "decode response") 77 | 78 | return responseBody 79 | } 80 | 81 | // testAPIGeneric inherits testCommon, decodes response body to json, tests whether it's expected 82 | func testAPIGeneric[T JsonData](t *testing.T, method string, route string, statusCode int, data ...Map) T { 83 | responseBody := testCommon(t, method, route, statusCode, data...) 84 | 85 | if statusCode == 204 || statusCode == 302 { // no content and redirect 86 | return nil 87 | } 88 | var responseData T 89 | err := json.Unmarshal(responseBody, &responseData) 90 | assert.Nilf(t, err, "decode response") 91 | 92 | if len(data) > 1 { // data[1] is response data 93 | assert.EqualValuesf(t, data[1], responseData, "response data") 94 | } 95 | 96 | return responseData 97 | } 98 | 99 | // testAPI returns a Map 100 | func testAPI(t *testing.T, method string, route string, statusCode int, data ...Map) Map { 101 | return testAPIGeneric[Map](t, method, route, statusCode, data...) 102 | } 103 | 104 | // testAPIArray returns []Map 105 | func testAPIArray(t *testing.T, method string, route string, statusCode int, data ...Map) []Map { 106 | return testAPIGeneric[[]Map](t, method, route, statusCode, data...) 107 | } 108 | 109 | func testAPIModel[T Models](t *testing.T, method string, route string, statusCode int, obj *T, data ...Map) { 110 | responseBytes := testCommon(t, method, route, statusCode, data...) 111 | err := json.Unmarshal(responseBytes, obj) 112 | assert.Nilf(t, err, "unmarshal response") 113 | } 114 | 115 | func testAPIModelWithQuery[T Models](t *testing.T, method string, route string, statusCode int, obj *T, data ...Map) { 116 | responseBytes := testCommonQuery(t, method, route, statusCode, data...) 
117 | err := json.Unmarshal(responseBytes, obj) 118 | assert.Nilf(t, err, "unmarshal response") 119 | } 120 | -------------------------------------------------------------------------------- /utils/bot.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "treehole_next/config" 10 | ) 11 | 12 | type BotMessageType string 13 | 14 | const ( 15 | MessageTypeGroup BotMessageType = "group" 16 | MessageTypePrivate BotMessageType = "private" 17 | ) 18 | 19 | type BotMessage struct { 20 | MessageType BotMessageType `json:"message_type"` 21 | GroupID *int64 `json:"group_id"` 22 | UserID *int64 `json:"user_id"` 23 | Message string `json:"message"` 24 | AutoEscape bool `json:"auto_escape default:false"` 25 | } 26 | 27 | type FeishuMessage struct { 28 | MsgType string `json:"msg_type"` 29 | Content string `json:"message"` 30 | } 31 | 32 | func NotifyFeishu(feishuMessage *FeishuMessage) { 33 | if feishuMessage == nil || feishuMessage.MsgType == "" { 34 | return 35 | } 36 | if config.Config.FeishuBotUrl == nil { 37 | return 38 | } 39 | url := *config.Config.FeishuBotUrl 40 | 41 | jsonData, err := json.Marshal(feishuMessage) 42 | if err != nil { 43 | RequestLog("Error marshaling JSON", "NotifyFeishu", 0, false) 44 | return 45 | } 46 | 47 | RequestLog(fmt.Sprintf("Request: %s", string(jsonData)), "NotifyFeishu", 0, false) 48 | 49 | resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData)) 50 | if err != nil { 51 | RequestLog("Error creating request", "NotifyFeishu", 0, false) 52 | return 53 | } 54 | 55 | defer resp.Body.Close() 56 | if resp.StatusCode != 200 { 57 | response, err := io.ReadAll(resp.Body) 58 | if err != nil { 59 | RequestLog("Error Unmarshaling response", "NotifyFeishu", 0, false) 60 | } 61 | RequestLog(fmt.Sprintf("Error sending request %s", string(response)), "NotifyFeishu", 0, false) 62 | } 63 | } 64 | 65 | func 
NotifyQQ(botMessage *BotMessage) { 66 | if botMessage == nil { 67 | return 68 | } 69 | if botMessage.MessageType == MessageTypeGroup && botMessage.GroupID == nil { 70 | return 71 | } 72 | if botMessage.MessageType == MessageTypePrivate && botMessage.UserID == nil { 73 | return 74 | } 75 | if config.Config.QQBotUrl == nil { 76 | return 77 | } 78 | // "[CQ:face,id=199]test[CQ:image,file=https://ts1.cn.mm.bing.net/th?id=OIP-C.K5AFHsGlWeLUzKjXGXxdQgHaFj&w=224&h=150&c=8&rs=1&qlt=90&o=6&dpr=1.5&pid=3.1&rm=2]", 79 | url := *config.Config.QQBotUrl + "/send_msg" 80 | 81 | jsonData, err := json.Marshal(botMessage) 82 | if err != nil { 83 | RequestLog("Error marshaling JSON", "NotifyQQ", 0, false) 84 | return 85 | } 86 | 87 | RequestLog(fmt.Sprintf("Request: %s", string(jsonData)), "NotifyQQ", 0, false) 88 | 89 | resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData)) 90 | if err != nil { 91 | RequestLog("Error creating request", "NotifyQQ", 0, false) 92 | return 93 | } 94 | 95 | defer resp.Body.Close() 96 | if resp.StatusCode != 200 { 97 | response, err := io.ReadAll(resp.Body) 98 | if err != nil { 99 | RequestLog("Error Unmarshaling response", "NotifyQQ", 0, false) 100 | } 101 | RequestLog(fmt.Sprintf("Error sending request %s", string(response)), "NotifyQQ", 0, false) 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /utils/cache.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/eko/gocache/lib/v4/cache" 8 | "github.com/eko/gocache/lib/v4/store" 9 | gocache_store "github.com/eko/gocache/store/go_cache/v4" 10 | redis_store "github.com/eko/gocache/store/redis/v4" 11 | "github.com/goccy/go-json" 12 | gocache "github.com/patrickmn/go-cache" 13 | "github.com/redis/go-redis/v9" 14 | 15 | "treehole_next/config" 16 | ) 17 | 18 | var Cache *cache.Cache[[]byte] 19 | 20 | func InitCache() { 21 | if 
config.Config.RedisURL != "" { 22 | redisStore := redis_store.NewRedis(redis.NewClient(&redis.Options{ 23 | Addr: config.Config.RedisURL, 24 | })) 25 | Cache = cache.New[[]byte](redisStore) 26 | } else { 27 | gocacheStore := gocache_store.NewGoCache(gocache.New(5*time.Minute, 10*time.Minute)) 28 | Cache = cache.New[[]byte](gocacheStore) 29 | } 30 | } 31 | 32 | const maxDuration time.Duration = 1<<63 - 1 33 | 34 | func SetCache(key string, value any, expiration time.Duration) error { 35 | data, err := json.Marshal(value) 36 | if err != nil { 37 | return err 38 | } 39 | if expiration == 0 { 40 | expiration = maxDuration 41 | } 42 | return Cache.Set(context.Background(), key, data, store.WithExpiration(expiration)) 43 | } 44 | 45 | func GetCache(key string, value any) bool { 46 | data, err := Cache.Get(context.Background(), key) 47 | if err != nil { 48 | return false 49 | } 50 | err = json.Unmarshal(data, value) 51 | return err == nil 52 | } 53 | 54 | func DeleteCache(key string) error { 55 | err := Cache.Delete(context.Background(), key) 56 | if err == nil { 57 | return nil 58 | } 59 | if err.Error() == "Entry not found" { 60 | return nil 61 | } 62 | return err 63 | } 64 | -------------------------------------------------------------------------------- /utils/errors.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | const ( 4 | ErrCodeNotAnsweredQuestions = iota + 403001 5 | ) 6 | -------------------------------------------------------------------------------- /utils/log.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "github.com/rs/zerolog/log" 5 | ) 6 | 7 | type Role string 8 | 9 | const ( 10 | RoleOwner = "owner" 11 | RoleAdmin = "admin" 12 | RoleOperator = "operator" 13 | ) 14 | 15 | func MyLog(model string, action string, objectID, userID int, role Role, msg ...string) { 16 | message := "" 17 | for _, v := range msg { 18 | 
message += v 19 | } 20 | log.Info(). 21 | Str("model", model). 22 | Int("user_id", userID). 23 | Int("object_id", objectID). 24 | Str("action", action). 25 | Str("role", string(role)). 26 | Msg(message) 27 | } 28 | 29 | func RequestLog(msg string, TypeName string, Id int64, ans bool) { 30 | log.Info().Str("TypeName", TypeName). 31 | Int64("Id", Id). 32 | Bool("CheckAnswer", ans). 33 | Msg(msg) 34 | } 35 | -------------------------------------------------------------------------------- /utils/model.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | type IDModel[T any] interface { 4 | *T 5 | GetID() int 6 | } 7 | 8 | func binarySearch[T any, PT IDModel[T]](models []PT, targetID int) int { 9 | left := 0 10 | right := len(models) 11 | for left < right { 12 | mid := left + (right-left)>>1 13 | if models[mid].GetID() < targetID { 14 | left = mid + 1 15 | } else if models[mid].GetID() > targetID { 16 | right = mid 17 | } else { 18 | return mid 19 | } 20 | } 21 | return -1 22 | } 23 | 24 | func OrderInGivenOrder[T any, PT IDModel[T]](models []PT, order []int) (result []PT) { 25 | for _, i := range order { 26 | index := binarySearch(models, i) 27 | if index >= 0 { 28 | result = append(result, models[index]) 29 | } 30 | } 31 | return result 32 | } 33 | 34 | func Models2IDSlice[T any, PT IDModel[T]](models []PT) (result []int) { 35 | result = make([]int, len(models)) 36 | for i := range models { 37 | result[i] = models[i].GetID() 38 | } 39 | return result 40 | } 41 | -------------------------------------------------------------------------------- /utils/name.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "encoding/base64" 5 | "encoding/binary" 6 | "math/rand" 7 | "sort" 8 | "time" 9 | 10 | "github.com/goccy/go-json" 11 | "github.com/rs/zerolog/log" 12 | 13 | "treehole_next/config" 14 | "treehole_next/data" 15 | 16 | "golang.org/x/exp/slices" 
17 | ) 18 | 19 | var names []string 20 | var length int 21 | 22 | const ( 23 | charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 24 | randomCodeLength = 6 25 | ) 26 | 27 | func init() { 28 | err := json.Unmarshal(data.NamesFile, &names) 29 | if err != nil { 30 | log.Fatal().Err(err).Send() 31 | } 32 | sort.Strings(names) 33 | length = len(names) 34 | } 35 | 36 | func inArray(target string, array []string) bool { 37 | _, in := slices.BinarySearch(array, target) 38 | return in 39 | } 40 | 41 | func timeStampBase64() string { 42 | bytes := make([]byte, 8) 43 | binary.LittleEndian.PutUint64(bytes, uint64(time.Now().Unix())) 44 | return base64.StdEncoding.EncodeToString(bytes) 45 | } 46 | 47 | func generateRandomCode() string { 48 | code := make([]byte, randomCodeLength) 49 | charsetLength := len(charset) 50 | 51 | for i := 0; i < randomCodeLength; i++ { 52 | n := rand.Intn(charsetLength) 53 | code[i] = charset[n] 54 | } 55 | 56 | return string(code) 57 | } 58 | 59 | func NewRandName() string { 60 | return names[rand.Intn(length)] 61 | } 62 | 63 | func GenerateName(compareList []string) string { 64 | if len(compareList) < length>>3 { 65 | for { 66 | name := NewRandName() 67 | if !inArray(name, compareList) { 68 | return name 69 | } 70 | } 71 | } else if len(compareList) < length { 72 | var j, k int 73 | list := make([]string, length) 74 | for i := 0; i < length; i++ { 75 | if j < len(compareList) && names[i] == compareList[j] { 76 | j++ 77 | } else { 78 | list[k] = names[i] 79 | k++ 80 | } 81 | } 82 | return list[rand.Intn(k)] 83 | } else { 84 | for { 85 | // name := names[rand.Intn(length)] + "_" + timeStampBase64() 86 | name := names[rand.Intn(length)] + "_" + generateRandomCode() 87 | if !inArray(name, compareList) { 88 | return name 89 | } 90 | } 91 | } 92 | } 93 | 94 | func GetFuzzName(name string) string { 95 | if !config.Config.OpenFuzzName { 96 | return name 97 | } 98 | if fuzzName, ok := data.NamesMapping[name]; ok { 99 | return 
fuzzName 100 | } else { 101 | return name 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /utils/sensitive/utils.go: -------------------------------------------------------------------------------- 1 | package sensitive 2 | 3 | import ( 4 | "errors" 5 | "golang.org/x/exp/slices" 6 | "mvdan.cc/xurls/v2" 7 | "net/url" 8 | "regexp" 9 | "strings" 10 | "treehole_next/config" 11 | ) 12 | 13 | var imageRegex = regexp.MustCompile( 14 | `!\[(.*?)]\(([^" )]*?)\s*(".*?")?\)`, 15 | ) 16 | 17 | var ( 18 | ErrUrlParsing = errors.New("error parsing url") 19 | ErrInvalidImageHost = errors.New("不允许使用外部图片链接") 20 | ErrImageLinkTextOnly = errors.New("image link only contains text") 21 | ) 22 | 23 | // findImagesInMarkdown 从Markdown文本中查找所有图片链接,检查图片链接是否合法,并且返回清除链接之后的文本 24 | func findImagesInMarkdownContent(content string) (imageUrls []string, clearContent string, err error) { 25 | err = nil 26 | clearContent = imageRegex.ReplaceAllStringFunc(content, func(s string) string { 27 | if err != nil { 28 | return "" 29 | } 30 | submatch := imageRegex.FindStringSubmatch(s) 31 | altText := submatch[1] 32 | 33 | var imageUrl string 34 | if len(submatch) > 2 && submatch[2] != "" { 35 | imageUrl = submatch[2] 36 | innerErr := checkValidUrl(imageUrl) 37 | if innerErr != nil { 38 | if errors.Is(innerErr, ErrInvalidImageHost) { 39 | err = innerErr 40 | return "" 41 | } 42 | // if the url is not valid, treat as text only 43 | } else { 44 | // append only valid image url 45 | imageUrls = append(imageUrls, imageUrl) 46 | imageUrl = "" 47 | } 48 | } 49 | 50 | var title string 51 | if len(submatch) > 3 && submatch[3] != "" { 52 | title = strings.Trim(submatch[3], "\"") 53 | } 54 | 55 | var ret strings.Builder 56 | if altText != "" { 57 | ret.WriteString(altText) 58 | } 59 | if imageUrl != "" { 60 | if ret.String() != "" { 61 | ret.WriteString(" ") 62 | } 63 | ret.WriteString(imageUrl) 64 | } 65 | if title != "" { 66 | if ret.String() != "" { 67 | 
ret.WriteString(" ") 68 | } 69 | ret.WriteString(title) 70 | } 71 | return ret.String() 72 | }) 73 | return 74 | } 75 | 76 | func checkType(params ParamsForCheck) bool { 77 | return slices.Contains(checkTypes, params.TypeName) 78 | } 79 | 80 | func containsUnsafeURL(content string) (bool, string) { 81 | xurlsRelaxed := xurls.Relaxed() 82 | matchedURLs := xurlsRelaxed.FindAllString(content, -1) 83 | if len(matchedURLs) == 0 { 84 | return false, "" 85 | } 86 | 87 | for _, matchedURL := range matchedURLs { 88 | if !strings.Contains(matchedURL, "://") { 89 | matchedURL = "http://" + matchedURL 90 | } 91 | parsedURL, err := url.Parse(matchedURL) 92 | if err != nil || parsedURL == nil { 93 | return true, matchedURL 94 | } 95 | checked := slices.ContainsFunc(config.Config.UrlHostnameWhitelist, func(s string) bool { 96 | return strings.HasSuffix(parsedURL.Host, s) 97 | }) 98 | if !checked { 99 | return true, parsedURL.Host 100 | } 101 | } 102 | return false, "" 103 | } 104 | 105 | func checkValidUrl(input string) error { 106 | imageUrl, err := url.Parse(input) 107 | if err != nil { 108 | return ErrUrlParsing 109 | } 110 | // if the url is text only, skip check 111 | if imageUrl.Scheme == "" && imageUrl.Host == "" { 112 | return ErrImageLinkTextOnly 113 | } 114 | if !slices.Contains(config.Config.ValidImageUrl, imageUrl.Hostname()) { 115 | return ErrInvalidImageHost 116 | } 117 | return nil 118 | } 119 | 120 | var reHole = regexp.MustCompile(`[^#]#(\d+)`) 121 | var reFloor = regexp.MustCompile(`##(\d+)`) 122 | 123 | func removeIDReprInContent(content string) string { 124 | content = " " + content 125 | content = reHole.ReplaceAllString(content, "") 126 | content = reFloor.ReplaceAllString(content, "") 127 | return strings.TrimSpace(content) 128 | } 129 | -------------------------------------------------------------------------------- /utils/sensitive/utils_test.go: -------------------------------------------------------------------------------- 1 | package sensitive 2 | 3 | 
import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | "treehole_next/config" 7 | ) 8 | 9 | func TestFindImagesInMarkdown(t *testing.T) { 10 | config.Config.ValidImageUrl = []string{"example.com"} 11 | 12 | type wantStruct struct { 13 | clearContent string 14 | imageUrls []string 15 | err error 16 | } 17 | tests := []struct { 18 | text string 19 | want wantStruct 20 | }{ 21 | { 22 | text: `![image1](https://example.com/image1)`, 23 | want: wantStruct{ 24 | clearContent: `image1`, 25 | imageUrls: []string{"https://example.com/image1"}, 26 | err: nil, 27 | }, 28 | }, 29 | { 30 | text: `![image1](https://example.com/image1) ![image2](https://example.com/image2)`, 31 | want: wantStruct{ 32 | clearContent: `image1 image2`, 33 | imageUrls: []string{"https://example.com/image1", "https://example.com/image2"}, 34 | }, 35 | }, 36 | { 37 | text: `![image1](https://example.com/image1 "title1") ![image2](https://example.com/image2 "title2")`, 38 | want: wantStruct{ 39 | clearContent: `image1 title1 image2 title2`, 40 | imageUrls: []string{"https://example.com/image1", "https://example.com/image2"}, 41 | }, 42 | }, 43 | { 44 | text: `![image1](123) ![image2](456)`, 45 | want: wantStruct{ 46 | clearContent: `image1 123 image2 456`, 47 | imageUrls: nil, 48 | }, 49 | }, 50 | { 51 | text: `![](123) ![](456)`, 52 | want: wantStruct{ 53 | clearContent: `123 456`, 54 | imageUrls: nil, 55 | }, 56 | }, 57 | { 58 | text: "![](https://example2.com/image1)", 59 | want: wantStruct{ 60 | clearContent: "", 61 | imageUrls: nil, 62 | err: ErrInvalidImageHost, 63 | }, 64 | }, 65 | } 66 | 67 | for _, tt := range tests { 68 | imageUrls, cleanText, err := findImagesInMarkdownContent(tt.text) 69 | assert.EqualValues(t, tt.want.clearContent, cleanText, "cleanText should be equal") 70 | assert.EqualValues(t, tt.want.imageUrls, imageUrls, "imageUrls should be equal") 71 | assert.EqualValues(t, tt.want.err, err, "err should be equal") 72 | } 73 | } 74 | 75 | func TestCheckValidUrl(t 
*testing.T) { 76 | config.Config.ValidImageUrl = []string{"example.com"} 77 | type wantStruct struct { 78 | err error 79 | } 80 | tests := []struct { 81 | url string 82 | want wantStruct 83 | }{ 84 | { 85 | url: "https://example.com/image1", 86 | want: wantStruct{ 87 | err: nil, 88 | }, 89 | }, 90 | { 91 | url: "https://example.com/image2", 92 | want: wantStruct{ 93 | err: nil, 94 | }, 95 | }, 96 | { 97 | url: "123456", 98 | want: wantStruct{ 99 | err: ErrImageLinkTextOnly, 100 | }, 101 | }, 102 | { 103 | url: "https://example2.com", 104 | want: wantStruct{ 105 | err: ErrInvalidImageHost, 106 | }, 107 | }, 108 | } 109 | 110 | for _, tt := range tests { 111 | err := checkValidUrl(tt.url) 112 | assert.EqualValues(t, tt.want.err, err, "err should be equal") 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /utils/utils.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "golang.org/x/exp/slices" 5 | "strconv" 6 | 7 | "github.com/gofiber/fiber/v2" 8 | "github.com/opentreehole/go-common" 9 | "golang.org/x/exp/constraints" 10 | 11 | "treehole_next/config" 12 | ) 13 | 14 | type CanPreprocess interface { 15 | Preprocess(c *fiber.Ctx) error 16 | } 17 | 18 | func Serialize(c *fiber.Ctx, obj CanPreprocess) error { 19 | err := obj.Preprocess(c) 20 | if err != nil { 21 | return err 22 | } 23 | return c.JSON(obj) 24 | } 25 | 26 | func RegText2IntArray(IDs [][]string) ([]int, error) { 27 | ansIDs := make([]int, 0) 28 | for _, v := range IDs { 29 | id, err := strconv.Atoi(v[1]) 30 | if err != nil { 31 | return nil, err 32 | } 33 | ansIDs = append(ansIDs, id) 34 | } 35 | return ansIDs, nil 36 | } 37 | 38 | func Keys[T comparable, S any](m map[T]S) (s []T) { 39 | for k := range m { 40 | s = append(s, k) 41 | } 42 | return s 43 | } 44 | 45 | func Min[T constraints.Ordered](x T, y T) T { 46 | if x > y { 47 | return y 48 | } else { 49 | return x 50 | } 51 | } 52 
| 53 | func Intersect[T comparable](x []T, y []T) []T { 54 | var result = make([]T, 0) 55 | for i := range x { 56 | if slices.Contains(y, x[i]) { 57 | result = append(result, x[i]) 58 | } 59 | } 60 | return result 61 | } 62 | 63 | // Difference returns the elements in a that aren't in b 64 | func Difference[T comparable](a, b []T) []T { 65 | m := make(map[T]bool) 66 | var result []T 67 | 68 | for _, item := range b { 69 | m[item] = true 70 | } 71 | 72 | for _, item := range a { 73 | if _, ok := m[item]; !ok { 74 | result = append(result, item) 75 | } 76 | } 77 | 78 | return result 79 | } 80 | 81 | func StripContent(content string, contentMaxSize int) string { 82 | return string([]rune(content)[:Min(len([]rune(content)), contentMaxSize)]) 83 | } 84 | 85 | func MiddlewareHasAnsweredQuestions(c *fiber.Ctx) error { 86 | if config.Config.Mode == "test" || config.Config.Mode == "bench" { 87 | return c.Next() 88 | } 89 | var user struct { 90 | HasAnsweredQuestions bool `json:"has_answered_questions"` 91 | } 92 | err := common.ParseJWTToken(common.GetJWTToken(c), &user) 93 | if err != nil { 94 | return err 95 | } 96 | if !user.HasAnsweredQuestions { 97 | return &common.HttpError{ 98 | Code: ErrCodeNotAnsweredQuestions, 99 | Message: "请先通过注册答题", 100 | } 101 | } 102 | return c.Next() 103 | } 104 | -------------------------------------------------------------------------------- /utils/utils_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestStripContent(t *testing.T) { 10 | var str string 11 | str = "愿中国青年都摆脱冷气,只是向上走,不必听自暴自弃者流的话。能做事的做事,能发声的发声。有一分热,发一分光。就令萤火一般,也可以在黑暗里发一点光,不必等候炬火。" 12 | println(len(str)) 13 | println(len([]rune(str))) 14 | assert.Equal(t, "愿中国青年都摆脱冷气", StripContent(str, 10)) 15 | assert.Equal(t, str, StripContent(str, 100)) 16 | } 17 | 
--------------------------------------------------------------------------------