├── .github
├── ISSUE_TEMPLATE
│ ├── bug_report.md
│ └── feature_request.md
└── workflows
│ ├── dev.yaml
│ └── main.yaml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── Dockerfile
├── LICENSE
├── README.md
├── apis
├── default.go
├── division
│ ├── apis.go
│ ├── routes.go
│ ├── schemas.go
│ └── utils.go
├── favourite
│ ├── api.go
│ ├── routes.go
│ └── schemas.go
├── floor
│ ├── apis.go
│ ├── routes.go
│ ├── schemas.go
│ ├── search.go
│ └── utils.go
├── hole
│ ├── apis.go
│ ├── purge_hole.go
│ ├── routes.go
│ ├── schemas.go
│ └── update_views.go
├── message
│ ├── apis.go
│ ├── purge.go
│ ├── routes.go
│ └── schemas.go
├── penalty
│ └── api.go
├── report
│ ├── apis.go
│ ├── routes.go
│ └── schemas.go
├── routes.go
├── subscription
│ ├── api.go
│ ├── routes.go
│ └── schemas.go
├── tag
│ ├── apis.go
│ ├── routes.go
│ └── schemas.go
└── user
│ ├── apis.go
│ └── schemas.go
├── benchmarks
├── floor_test.go
├── hole_test.go
├── init.go
└── utils.go
├── bootstrap
└── init.go
├── config
└── config.go
├── data
├── data.go
├── meta.json
└── names.json
├── docs
├── docs.go
├── swagger.json
└── swagger.yaml
├── go.mod
├── go.sum
├── main.go
├── models
├── admin_log.go
├── anonyname.go
├── base.go
├── division.go
├── elastic.go
├── favorite_group.go
├── floor.go
├── floor_history.go
├── floor_like.go
├── floor_mention.go
├── hole.go
├── hole_tags.go
├── init.go
├── message.go
├── notification.go
├── punishment.go
├── report.go
├── report_punishment.go
├── tag.go
├── url_hostname_whitelist.go
├── user.go
├── user_favorite.go
├── user_subscription.go
└── user_test.go
├── tests
├── default.go
├── default_test.go
├── division_test.go
├── favorite_test.go
├── floor_test.go
├── hole_test.go
├── init.go
├── report_test.go
├── tag_test.go
└── utils.go
└── utils
├── bot.go
├── cache.go
├── errors.go
├── log.go
├── model.go
├── name.go
├── sensitive
├── api.go
├── utils.go
└── utils_test.go
├── utils.go
└── utils_test.go
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug报告
3 | about: 说明你遇到的Bug。
4 | title: '[BUG]'
5 | labels: bug
6 | assignees: ''
7 |
8 | ---
9 |
10 | **描述 Bug**
11 | 简要描述 Bug 是什么。
12 | 如果你认为标题已经说得很清楚了,可以删除这一项。
13 |
14 | **复现步骤**
15 | 复现该 Bug 的步骤:
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: 请求新功能
3 | about: 请求后端加入新的功能!
4 | title: '[Feature Request]'
5 | labels: feature request
6 | assignees: ''
7 |
8 | ---
9 |
10 | **你的功能需求和某个bug有关吗?**
11 | `[是/否]`
12 |
13 | **你想要什么样的功能?**
14 | 简要、清晰地描述你请求增加的功能。
15 |
--------------------------------------------------------------------------------
/.github/workflows/dev.yaml:
--------------------------------------------------------------------------------
1 | name: Dev Build
2 | on:
3 | push:
4 | branches: [ dev ]
5 |
6 | env:
7 | APP_NAME: treehole_next
8 |
9 | jobs:
10 | docker:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Checkout
14 | uses: actions/checkout@master
15 |
16 | # - name: Set up Go
17 | # uses: actions/setup-go@master
18 | # with:
19 | # go-version: 1.21.1
20 | #
21 | # - run: go version
22 | #
23 | # - name: Automated Testing
24 | # env:
25 | # MODE: test
26 | # run: go test -v -count=1 -json -tags release ./tests/...
27 |
28 | - name: Set up QEMU
29 | uses: docker/setup-qemu-action@master
30 |
31 | - name: Set up Docker Buildx
32 | uses: docker/setup-buildx-action@master
33 |
34 | - name: Login to DockerHub
35 | uses: docker/login-action@master
36 | with:
37 | username: ${{ secrets.DOCKERHUB_USERNAME }}
38 | password: ${{ secrets.DOCKERHUB_TOKEN }}
39 |
40 | - name: Build and push
41 | id: docker_build
42 | uses: docker/build-push-action@master
43 | with:
44 | push: true
45 | tags: |
46 | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.APP_NAME }}:latest
47 | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.APP_NAME }}:dev
48 |
49 |
--------------------------------------------------------------------------------
/.github/workflows/main.yaml:
--------------------------------------------------------------------------------
1 | name: Production Build
2 | on:
3 | push:
4 | branches: [ main ]
5 |
6 | env:
7 | APP_NAME: treehole_next
8 |
9 | jobs:
10 | docker:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Checkout
14 | uses: actions/checkout@master
15 |
16 | # - name: Set up Go
17 | # uses: actions/setup-go@master
18 | # with:
19 | # go-version: 1.21.1
20 | #
21 | # - run: go version
22 | # - name: Automated Testing
23 | # env:
24 | # MODE: test
25 | # run: go test -v -count=1 -json -tags release ./tests/...
26 |
27 | - name: Set up QEMU
28 | uses: docker/setup-qemu-action@master
29 |
30 | - name: Set up Docker Buildx
31 | uses: docker/setup-buildx-action@master
32 |
33 | - name: Login to DockerHub
34 | uses: docker/login-action@master
35 | with:
36 | username: ${{ secrets.DOCKERHUB_USERNAME }}
37 | password: ${{ secrets.DOCKERHUB_TOKEN }}
38 |
39 | - name: Build and push
40 | id: docker_build
41 | uses: docker/build-push-action@master
42 | with:
43 | push: true
44 | tags: |
45 | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.APP_NAME }}:latest
46 | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.APP_NAME }}:master
47 |
48 |
49 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .vscode
3 | *.env
4 | *.exe
5 | *.db
6 | *.ps1
7 | *.log
8 | *.out
9 | *.http
10 | .DS_Store
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | dev@danta.tech.
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series
86 | of actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or
93 | permanent ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within
113 | the community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.0, available at
119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120 |
121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct
122 | enforcement ladder](https://github.com/mozilla/diversity).
123 |
124 | [homepage]: https://www.contributor-covenant.org
125 |
126 | For answers to common questions about this code of conduct, see the FAQ at
127 | https://www.contributor-covenant.org/faq. Translations are available at
128 | https://www.contributor-covenant.org/translations.
129 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM golang:1.22-alpine as builder
2 |
3 | WORKDIR /app
4 |
5 | COPY go.mod go.sum ./
6 | RUN apk add --no-cache --virtual .build-deps \
7 | ca-certificates \
8 | tzdata \
9 | gcc \
10 | g++ && \
11 | go mod download
12 |
13 | COPY . .
14 |
15 | RUN go build -ldflags "-s -w" -o treehole
16 |
17 | FROM alpine
18 |
19 | WORKDIR /app
20 |
21 | COPY --from=builder /app/treehole /app/
22 | COPY --from=builder /usr/share/zoneinfo /usr/share/zoneinfo
23 | COPY data data
24 |
25 | ENV TZ=Asia/Shanghai
26 | ENV MODE=production
27 |
28 | EXPOSE 8000
29 |
30 | ENTRYPOINT ["./treehole"]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Open Tree Hole Next
2 |
3 | Next Generation of OpenTreeHole Written In Golang
4 |
5 | An Anonymous BBS
6 |
7 | ## Features
8 | To be done
9 |
10 | ## Usage
11 |
12 | ### build and run
13 |
14 | Before building, you should have go (>=1.22) installed on your machine. If not, please visit [golang.org](https://golang.org) to download the latest version.
15 |
16 | ```shell
17 | git clone https://github.com/OpenTreeHole/treehole_next.git
18 | cd treehole_next
19 | # install swag and generate docs
20 | go install github.com/swaggo/swag/cmd/swag@latest
21 | swag init --parseDependency --parseDepth 1 # to generate the latest docs, this should be run before compiling
22 | # build for debug
23 | go build -o treehole.exe
24 | # build for release
25 | go build -tags "release" -ldflags "-s -w" -o treehole.exe
26 | # run
27 | ./treehole.exe
28 | ```
29 |
30 | Note: You should check [config file](config/config.go) to see if required environment variables are set.
31 |
32 | ### test
33 |
34 | ```shell
35 | export MODE=test
36 | go test -v ./tests/...
37 | ```
38 |
39 | ### benchmark
40 |
41 | ```shell
42 | export MODE=bench
43 | go test -v -benchmem -cpuprofile=cpu.out -benchtime=1s ./benchmarks/... -bench .
44 | ```
45 | For documentation, please open http://localhost:8000/docs after running app
46 | ## Badge
47 |
48 | [//]: # ([](https://github.com/OpenTreeHole/treehole_next/actions/workflows/master.yaml))
49 | [//]: # ([](https://github.com/OpenTreeHole/treehole_next/actions/workflows/dev.yaml))
50 |
51 | [](https://github.com/OpenTreeHole/treehole_next/stargazers)
52 | [](https://github.com/OpenTreeHole/treehole_next/issues)
53 | [](https://github.com/OpenTreeHole/treehole_next/pulls)
54 |
55 | [](https://github.com/RichardLitt/standard-readme)
56 |
57 | ### Powered by
58 |
59 | 
60 | 
61 |
62 | ## Contributing
63 |
64 | Feel free to dive in! [Open an issue](https://github.com/OpenTreeHole/treehole_next/issues/new) or [Submit PRs](https://github.com/OpenTreeHole/treehole_next/compare).
65 |
66 | We are now in rapid development, any contribution would be of great help.
67 | For the developing roadmap, please visit [this issue](https://github.com/OpenTreeHole/treehole_next/issues/1).
68 |
69 | ### Contributors
70 |
71 | This project exists thanks to all the people who contribute.
72 |
73 |
74 |
75 |
76 |
77 | ## Licence
78 |
79 | [](https://github.com/OpenTreeHole/treehole_next/blob/master/LICENSE)
80 | © OpenTreeHole
81 |
--------------------------------------------------------------------------------
/apis/default.go:
--------------------------------------------------------------------------------
1 | package apis
2 |
3 | import (
4 | "treehole_next/data"
5 |
6 | "github.com/gofiber/fiber/v2"
7 | )
8 |
9 | // Index
10 | //
11 | // @Produce application/json
12 | // @Success 200 {object} models.MessageModel
13 | // @Router / [get]
14 | func Index(c *fiber.Ctx) error {
15 | return c.Send(data.MetaFile)
16 | }
17 |
--------------------------------------------------------------------------------
/apis/division/apis.go:
--------------------------------------------------------------------------------
1 | package division
2 |
3 | import (
4 | "strconv"
5 |
6 | "github.com/goccy/go-json"
7 | "gorm.io/gorm"
8 | "gorm.io/gorm/clause"
9 |
10 | "github.com/opentreehole/go-common"
11 |
12 | . "treehole_next/models"
13 | . "treehole_next/utils"
14 |
15 | "github.com/gofiber/fiber/v2"
16 | )
17 |
18 | // AddDivision
19 | //
20 | // @Summary Add A Division
21 | // @Tags Division
22 | // @Accept application/json
23 | // @Produce application/json
24 | // @Router /divisions [post]
25 | // @Param json body CreateModel true "json"
26 | // @Success 201 {object} models.Division
27 | // @Success 200 {object} models.Division
28 | func AddDivision(c *fiber.Ctx) error {
29 | // validate body
30 | var body CreateModel
31 | err := common.ValidateBody(c, &body)
32 | if err != nil {
33 | return err
34 | }
35 |
36 | // get user
37 | user, err := GetCurrLoginUser(c)
38 | if err != nil {
39 | return err
40 | }
41 |
42 | // permission check
43 | if !user.IsAdmin {
44 | return common.Forbidden()
45 | }
46 |
47 | // bind division
48 | division := Division{
49 | Name: body.Name,
50 | Description: body.Description,
51 | }
52 | result := DB.FirstOrCreate(&division, Division{Name: body.Name})
53 | if result.RowsAffected == 0 {
54 | c.Status(200)
55 | } else {
56 | c.Status(201)
57 | }
58 | return Serialize(c, &division)
59 | }
60 |
61 | // ListDivisions
62 | //
63 | // @Summary List All Divisions
64 | // @Tags Division
65 | // @Produce application/json
66 | // @Router /divisions [get]
67 | // @Success 200 {array} models.Division
68 | func ListDivisions(c *fiber.Ctx) error {
69 | var divisions Divisions
70 | if GetCache("divisions", &divisions) {
71 | return c.JSON(divisions)
72 | }
73 | err := DB.Find(&divisions, "hidden = false").Error
74 | if err != nil {
75 | return err
76 | }
77 | return Serialize(c, divisions)
78 | }
79 |
80 | // GetDivision
81 | //
82 | // @Summary Get Division
83 | // @Tags Division
84 | // @Produce application/json
85 | // @Router /divisions/{id} [get]
86 | // @Param id path int true "id"
87 | // @Success 200 {object} models.Division
88 | // @Failure 404 {object} MessageModel
89 | func GetDivision(c *fiber.Ctx) error {
90 | id, err := c.ParamsInt("id")
91 | if err != nil {
92 | return err
93 | }
94 | var division Division
95 | result := DB.Where("hidden = false").First(&division, id)
96 | if result.Error != nil {
97 | return result.Error
98 | }
99 | return Serialize(c, &division)
100 | }
101 |
102 | // ModifyDivision
103 | //
104 | // @Summary Modify A Division
105 | // @Tags Division
106 | // @Produce json
107 | // @Router /divisions/{id} [put]
108 | // @Router /divisions/{id}/_webvpn [patch]
109 | // @Param id path int true "id"
110 | // @Param json body ModifyDivisionModel true "json"
111 | // @Success 200 {object} models.Division
112 | // @Failure 404 {object} MessageModel
113 | func ModifyDivision(c *fiber.Ctx) error {
114 | // validate body
115 | var body ModifyDivisionModel
116 | err := common.ValidateBody(c, &body)
117 | if err != nil {
118 | return err
119 | }
120 |
121 | id, err := c.ParamsInt("id")
122 | if err != nil {
123 | return err
124 | }
125 |
126 | // get user
127 | user, err := GetCurrLoginUser(c)
128 | if err != nil {
129 | return err
130 | }
131 |
132 | // permission check
133 | if !user.IsAdmin {
134 | return common.Forbidden()
135 | }
136 |
137 | var division Division
138 | err = DB.Transaction(func(tx *gorm.DB) error {
139 | err = tx.Clauses(clause.Locking{Strength: "UPDATE"}).First(&division, id).Error
140 | if err != nil {
141 | return err
142 | }
143 |
144 | modifyData := make(map[string]any)
145 | if body.Name != nil {
146 | modifyData["name"] = *body.Name
147 | }
148 | if body.Description != nil {
149 | modifyData["description"] = *body.Description
150 | }
151 | if body.Pinned != nil {
152 | data, _ := json.Marshal(body.Pinned)
153 | modifyData["pinned"] = string(data)
154 | }
155 |
156 | if len(modifyData) == 0 {
157 | return common.BadRequest("No data to modify.")
158 | }
159 |
160 | return tx.Model(&division).Updates(modifyData).Error
161 | })
162 | if err != nil {
163 | return err
164 | }
165 |
166 | var newDivision Division
167 | err = DB.First(&newDivision, id).Error
168 | if err != nil {
169 | return err
170 | }
171 |
172 | MyLog("Division", "Modify", division.ID, user.ID, RoleAdmin)
173 |
174 | CreateAdminLog(DB, AdminLogTypeDivision, user.ID, map[string]any{
175 | "division_id": division.ID,
176 | "before": division,
177 | "after": newDivision,
178 | })
179 |
180 | // refresh cache. here should not use `go refreshCache`
181 | err = refreshCache(c)
182 | if err != nil {
183 | return err
184 | }
185 |
186 | return Serialize(c, &newDivision)
187 | }
188 |
189 | // DeleteDivision
190 | //
191 | // @Summary Delete A Division
192 | // @Description Delete a division and move all of its holes to another given division
193 | // @Tags Division
194 | // @Produce application/json
195 | // @Router /divisions/{id} [delete]
196 | // @Param id path int true "id"
197 | // @Param json body DeleteModel true "json"
198 | // @Success 204
199 | // @Failure 404 {object} MessageModel
200 | func DeleteDivision(c *fiber.Ctx) error {
201 | // validate body
202 | var body DeleteModel
203 | err := common.ValidateBody(c, &body)
204 | if err != nil {
205 | return err
206 | }
207 | id, err := c.ParamsInt("id")
208 | if err != nil {
209 | return err
210 | }
211 |
212 | // get user
213 | user, err := GetCurrLoginUser(c)
214 | if err != nil {
215 | return err
216 | }
217 | if !user.IsAdmin {
218 | return common.Forbidden()
219 | }
220 |
221 | if id == body.To {
222 | return common.BadRequest("The deleted division can't be the same as to.")
223 | }
224 | err = DB.Exec("UPDATE hole SET division_id = ? WHERE division_id = ?", body.To, id).Error
225 | if err != nil {
226 | return err
227 | }
228 | err = DB.Delete(&Division{ID: id}).Error
229 | if err != nil {
230 | return err
231 | }
232 |
233 | // log
234 | //if err != nil {
235 | // return err
236 | //}
237 | MyLog("Division", "Delete", id, user.ID, RoleAdmin, "To: ", strconv.Itoa(body.To))
238 |
239 | err = refreshCache(c)
240 | if err != nil {
241 | return err
242 | }
243 |
244 | return c.Status(204).JSON(nil)
245 | }
246 |
--------------------------------------------------------------------------------
/apis/division/routes.go:
--------------------------------------------------------------------------------
1 | package division
2 |
3 | import "github.com/gofiber/fiber/v2"
4 |
5 | func RegisterRoutes(app fiber.Router) {
6 | app.Post("/divisions", AddDivision)
7 | app.Get("/divisions", ListDivisions)
8 | app.Get("/divisions/:id", GetDivision)
9 | app.Put("/divisions/:id", ModifyDivision)
10 | app.Patch("/divisions/:id/_webvpn", ModifyDivision)
11 | app.Delete("/divisions/:id", DeleteDivision)
12 | }
13 |
--------------------------------------------------------------------------------
/apis/division/schemas.go:
--------------------------------------------------------------------------------
1 | package division
2 |
3 | type DeleteModel struct {
4 | // Admin only
5 | // ID of the target division that all the deleted division's holes will be moved to
6 | To int `json:"to" default:"1"`
7 | }
8 |
9 | type CreateModel struct {
10 | Name string `json:"name"`
11 | Description string `json:"description"`
12 | }
13 |
14 | type ModifyDivisionModel struct {
15 | Name *string `json:"name"`
16 | Description *string `json:"description"`
17 | Pinned []int `json:"pinned"`
18 | }
19 |
--------------------------------------------------------------------------------
/apis/division/utils.go:
--------------------------------------------------------------------------------
1 | package division
2 |
3 | import (
4 | "github.com/gofiber/fiber/v2"
5 | "github.com/rs/zerolog/log"
6 |
7 | . "treehole_next/models"
8 | )
9 |
10 | func refreshCache(c *fiber.Ctx) error {
11 |
12 | var divisions Divisions
13 | err := DB.Find(&divisions).Error
14 | if err != nil {
15 | return err
16 | }
17 |
18 | err = divisions.Preprocess(c)
19 | if err != nil {
20 | log.Err(err).Msg("error refreshing cache")
21 | return err
22 | }
23 |
24 | return nil
25 | }
26 |
--------------------------------------------------------------------------------
/apis/favourite/routes.go:
--------------------------------------------------------------------------------
1 | package favourite
2 |
3 | import "github.com/gofiber/fiber/v2"
4 |
5 | func RegisterRoutes(app fiber.Router) {
6 | app.Get("/user/favorites", ListFavorites)
7 | app.Post("/user/favorites", AddFavorite)
8 | app.Put("/user/favorites", ModifyFavorite)
9 | app.Patch("/user/favorites/_webvpn", ModifyFavorite)
10 | app.Delete("/user/favorites", DeleteFavorite)
11 | app.Get("/user/favorite_groups", ListFavoriteGroups)
12 | app.Post("/user/favorite_groups", AddFavoriteGroup)
13 | app.Put("/user/favorite_groups", ModifyFavoriteGroup)
14 | app.Patch("/user/favorite_groups/_webvpn", ModifyFavoriteGroup)
15 | app.Delete("/user/favorite_groups", DeleteFavoriteGroup)
16 | app.Put("/user/favorites/move", MoveFavorite)
17 | }
18 |
--------------------------------------------------------------------------------
/apis/favourite/schemas.go:
--------------------------------------------------------------------------------
1 | package favourite
2 |
3 | type Response struct {
4 | Message string `json:"message"`
5 | Data []int `json:"data"`
6 | }
7 |
8 | type ListFavoriteModel struct {
9 | Order string `json:"order" query:"order" validate:"omitempty,oneof=id time_created hole_time_updated" default:"time_created"`
10 | Plain bool `json:"plain" default:"false" query:"plain"`
11 | FavoriteGroupID *int `json:"favorite_group_id" query:"favorite_group_id"`
12 | }
13 |
14 | type AddModel struct {
15 | HoleID int `json:"hole_id"`
16 | FavoriteGroupID int `json:"favorite_group_id" default:"0"`
17 | }
18 |
19 | type ModifyModel struct {
20 | HoleIDs []int `json:"hole_ids"`
21 | FavoriteGroupID int `json:"favorite_group_id" default:"0"`
22 | }
23 |
24 | type DeleteModel struct {
25 | HoleID int `json:"hole_id"`
26 | FavoriteGroupID int `json:"favorite_group_id" default:"0"`
27 | }
28 |
29 | type AddFavoriteGroupModel struct {
30 | Name string `json:"name" validate:"required,max=64"`
31 | }
32 |
33 | type ModifyFavoriteGroupModel struct {
34 | Name string `json:"name" validate:"required,max=64"`
35 | FavoriteGroupID *int `json:"favorite_group_id" validate:"required"`
36 | }
37 |
38 | type DeleteFavoriteGroupModel struct {
39 | FavoriteGroupID *int `json:"favorite_group_id" validate:"required"`
40 | }
41 |
42 | type MoveModel struct {
43 | HoleIDs []int `json:"hole_ids"`
44 | FromFavoriteGroupID *int `json:"from_favorite_group_id" default:"0" validate:"required"`
45 | ToFavoriteGroupID *int `json:"to_favorite_group_id" validate:"required"`
46 | }
47 |
48 | type ListFavoriteGroupModel struct {
49 | Order string `json:"order" query:"order" validate:"omitempty,oneof=id time_created time_updated" default:"time_created"`
50 | Plain bool `json:"plain" default:"false" query:"plain"`
51 | }
52 |
--------------------------------------------------------------------------------
/apis/floor/routes.go:
--------------------------------------------------------------------------------
1 | package floor
2 |
3 | import (
4 | "github.com/gofiber/fiber/v2"
5 |
6 | "treehole_next/utils"
7 | )
8 |
9 | func RegisterRoutes(app fiber.Router) {
10 | app.Post("/floors/search", SearchFloors)
11 | app.Get("/floors/search", SearchFloors)
12 |
13 | app.Get("/holes/:id/floors", ListFloorsInAHole)
14 | app.Get("/floors", ListFloorsOld)
15 | app.Get("/floors/:id", GetFloor)
16 | app.Post("/holes/:id/floors", utils.MiddlewareHasAnsweredQuestions, CreateFloor)
17 | app.Post("/floors", utils.MiddlewareHasAnsweredQuestions, CreateFloorOld)
18 | app.Put("/floors/:id", ModifyFloor)
19 | app.Patch("/floors/:id/_webvpn", ModifyFloor)
20 | app.Post("/floors/:id/like/:like", ModifyFloorLike)
21 | app.Delete("/floors/:id", DeleteFloor)
22 |
23 | app.Get("/users/me/floors", ListReplyFloors)
24 |
25 | app.Get("/floors/:id/history", GetFloorHistory)
26 | app.Post("/floors/:id/restore/:floor_history_id", RestoreFloor)
27 |
28 | app.Post("/config/search", SearchConfig)
29 | app.Get("/floors/:id/punishment", GetPunishmentHistory)
30 | app.Get("/floors/:id/user_silence", GetUserSilence)
31 |
32 | app.Get("/floors/_sensitive", ListSensitiveFloors)
33 | app.Put("/floors/:id/_sensitive", ModifyFloorSensitive)
34 | app.Patch("/floors/:id/_sensitive/_webvpn", ModifyFloorSensitive)
35 | }
36 |
--------------------------------------------------------------------------------
/apis/floor/schemas.go:
--------------------------------------------------------------------------------
1 | package floor
2 |
3 | import (
4 | "time"
5 |
6 | "github.com/opentreehole/go-common"
7 |
8 | "treehole_next/models"
9 | )
10 |
11 | type ListModel struct {
12 | Size int `json:"size" query:"size" default:"30" validate:"min=0,max=50"` // length of object array
13 | Offset int `json:"offset" query:"offset" default:"0" validate:"min=0"` // offset of object array
14 | Sort string `json:"sort" query:"sort" default:"asc" validate:"oneof=asc desc"` // Sort order
15 | OrderBy string `json:"order_by" query:"order_by" default:"id" validate:"oneof=id like"` // SQL ORDER BY field
16 | }
17 |
18 | type ListOldModel struct {
19 | HoleID int `query:"hole_id" json:"hole_id"`
20 | Size int `query:"length" json:"length" validate:"min=0,max=50" `
21 | Offset int `query:"start_floor" json:"start_floor"`
22 | Search string `query:"s" json:"s"`
23 | }
24 |
25 | type CreateModel struct {
26 | Content string `json:"content" validate:"required"`
27 | // Admin and Operator only
28 | SpecialTag string `json:"special_tag" validate:"omitempty,max=16"`
29 | // id of the floor to which replied
30 | ReplyTo int `json:"reply_to" validate:"min=0"`
31 | }
32 |
33 | type CreateOldModel struct {
34 | HoleID int `json:"hole_id" validate:"min=1"`
35 | CreateModel
36 | }
37 |
38 | type CreateOldResponse struct {
39 | Data models.Floor `json:"data"`
40 | Message string `json:"message"`
41 | }
42 |
43 | type ModifyModel struct {
44 | // Owner or admin, the original content should be moved to floor_history
45 | Content *string `json:"content" validate:"omitempty"`
46 | // Admin and Operator only
47 | SpecialTag *string `json:"special_tag" validate:"omitempty,max=16"`
48 | // All user, deprecated, "add" is like, "cancel" is reset
49 | Like *string `json:"like" validate:"omitempty,oneof=add cancel"`
50 | // 仅管理员,留空则重置,高优先级
51 | Fold *string `json:"fold_v2" validate:"omitempty,max=64"`
52 | // 仅管理员,留空则重置,低优先级
53 | FoldFrontend []string `json:"fold" validate:"omitempty"`
54 | }
55 |
56 | func (body ModifyModel) DoNothing() bool {
57 | return body.Content == nil && body.SpecialTag == nil && body.Like == nil && body.Fold == nil && body.FoldFrontend == nil
58 | }
59 |
60 | func (body ModifyModel) CheckPermission(user *models.User, floor *models.Floor, hole *models.Hole) error {
61 | if body.Content != nil {
62 | if !user.IsAdmin {
63 | if user.ID != floor.UserID {
64 | return common.Forbidden("这不是您的楼层,您没有权限修改")
65 | } else {
66 | if user.BanDivision[hole.DivisionID] != nil {
67 | return common.Forbidden(user.BanDivisionMessage(hole.DivisionID))
68 | } else if hole.Locked {
69 | return common.Forbidden("此洞已被锁定,您无法修改")
70 | } else if floor.Deleted {
71 | return common.Forbidden("此洞已被删除,您无法修改")
72 | }
73 | }
74 | } else {
75 | if user.BanDivision[hole.DivisionID] != nil {
76 | return common.Forbidden(user.BanDivisionMessage(hole.DivisionID))
77 | }
78 | }
79 | }
80 | if (body.Fold != nil || body.FoldFrontend != nil) && !user.IsAdmin {
81 | return common.Forbidden("非管理员禁止折叠")
82 | }
83 | if body.SpecialTag != nil && !user.IsAdmin {
84 | return common.Forbidden("非管理员禁止修改特殊标签")
85 | }
86 | return nil
87 | }
88 |
89 | type DeleteModel struct {
90 | Reason string `json:"delete_reason" validate:"max=32"`
91 | }
92 |
93 | type RestoreModel struct {
94 | Reason string `json:"restore_reason" validate:"required,max=32"`
95 | }
96 |
97 | type SearchConfigModel struct {
98 | Open bool `json:"open"`
99 | }
100 |
101 | type SensitiveFloorRequest struct {
102 | Size int `json:"size" query:"size" default:"10" validate:"max=10"`
103 | Offset common.CustomTime `json:"offset" query:"offset" swaggertype:"string"`
104 | OrderBy string `json:"order_by" query:"order_by" default:"time_created" validate:"oneof=time_created time_updated"`
105 | Open bool `json:"open" query:"open"`
106 | All bool `json:"all" query:"all"`
107 | }
108 |
109 | type SensitiveFloorResponse struct {
110 | ID int `json:"id"`
111 | CreatedAt time.Time `json:"time_created"`
112 | UpdatedAt time.Time `json:"time_updated"`
113 | Content string `json:"content"`
114 | Modified int `json:"modified"`
115 | IsActualSensitive *bool `json:"is_actual_sensitive"`
116 | HoleID int `json:"hole_id"`
117 | Deleted bool `json:"deleted"`
118 | SensitiveDetail string `json:"sensitive_detail,omitempty"`
119 | }
120 |
121 | func (s *SensitiveFloorResponse) FromModel(floor *models.Floor) *SensitiveFloorResponse {
122 | s.ID = floor.ID
123 | s.CreatedAt = floor.CreatedAt
124 | s.UpdatedAt = floor.UpdatedAt
125 | s.Content = floor.Content
126 | s.Modified = floor.Modified
127 | s.IsActualSensitive = floor.IsActualSensitive
128 | s.HoleID = floor.HoleID
129 | s.Deleted = floor.Deleted
130 | s.SensitiveDetail = floor.SensitiveDetail
131 | return s
132 | }
133 |
134 | type ModifySensitiveFloorRequest struct {
135 | IsActualSensitive bool `json:"is_actual_sensitive"`
136 | }
137 |
138 | type BanDivision map[int]*time.Time
139 |
--------------------------------------------------------------------------------
/apis/floor/search.go:
--------------------------------------------------------------------------------
1 | package floor
2 |
3 | import (
4 | "github.com/gofiber/fiber/v2"
5 | "github.com/opentreehole/go-common"
6 |
7 | . "treehole_next/config"
8 | . "treehole_next/models"
9 | . "treehole_next/utils"
10 | )
11 |
12 | // SearchQuery is the query struct for searching floors
13 | type SearchQuery struct {
14 | Search string `json:"search" query:"search" validate:"required"`
15 | Size int `json:"size" query:"size" validate:"min=0" default:"10"`
16 | Offset int `json:"offset" query:"offset" validate:"min=0" default:"0"`
17 |
18 | // Accurate is used to determine whether to use accurate search
19 | Accurate bool `json:"accurate" query:"accurate" default:"false"`
20 |
21 | // StartTime and EndTime are used to filter floors by time
22 | // Both are Unix timestamps, and are optional
23 | StartTime *int64 `json:"start_time" query:"start_time"`
24 | EndTime *int64 `json:"end_time" query:"end_time"`
25 | }
26 |
27 | // SearchFloors
28 | //
29 | // @Summary SearchFloors In ElasticSearch
30 | // @Tags Search
31 | // @Produce application/json
32 | // @Router /floors/search [get]
33 | // @Router /floors/search [post]
34 | // @Param object query SearchQuery true "search_query"
35 | // @Success 200 {array} models.Floor
36 | func SearchFloors(c *fiber.Ctx) error {
37 | var query SearchQuery
38 | err := common.ValidateQuery(c, &query)
39 | if err != nil {
40 | return err
41 | }
42 |
43 | floors, err := Search(c, query.Search, query.Size, query.Offset, query.Accurate, query.StartTime, query.EndTime)
44 | if err != nil {
45 | return err
46 | }
47 |
48 | return Serialize(c, floors)
49 | }
50 |
51 | // SearchConfig
52 | //
53 | // @Summary change search config
54 | // @Tags Search
55 | // @Produce application/json
56 | // @Router /config/search [post]
57 | // @Param json body SearchConfigModel true "json"
58 | // @Success 200 {object} Map
59 | func SearchConfig(c *fiber.Ctx) error {
60 | var body SearchConfigModel
61 | err := c.BodyParser(&body)
62 | if err != nil {
63 | return err
64 | }
65 | user, err := GetCurrLoginUser(c)
66 | if err != nil {
67 | return err
68 | }
69 | if !user.IsAdmin {
70 | return common.Forbidden()
71 | }
72 | if DynamicConfig.OpenSearch.Load() == body.Open {
73 | return c.Status(200).JSON(Map{"message": "已经被修改"})
74 | } else {
75 | DynamicConfig.OpenSearch.Store(body.Open)
76 | return c.Status(201).JSON(Map{"message": "修改成功"})
77 | }
78 | }
79 |
80 | func SearchFloorsOld(c *fiber.Ctx, query *ListOldModel) error {
81 | if !DynamicConfig.OpenSearch.Load() {
82 | return common.Forbidden("茶楼流量激增,搜索功能暂缓开放")
83 | }
84 |
85 | floors, err := Search(c, query.Search, query.Size, query.Offset, false, nil, nil)
86 | if err != nil {
87 | return err
88 | }
89 |
90 | return Serialize(c, &floors)
91 | }
92 |
--------------------------------------------------------------------------------
/apis/floor/utils.go:
--------------------------------------------------------------------------------
1 | package floor
2 |
3 | import "fmt"
4 |
5 | func generateDeleteReason(reason string, isOwner bool) string {
6 | if reason == "" {
7 | if isOwner {
8 | return "该内容被作者删除"
9 | }
10 | reason = "违反社区规范"
11 | }
12 | return fmt.Sprintf("该内容因%s被删除", reason)
13 | }
14 |
--------------------------------------------------------------------------------
/apis/hole/purge_hole.go:
--------------------------------------------------------------------------------
1 | package hole
2 |
3 | import (
4 | "context"
5 | "github.com/rs/zerolog/log"
6 | "gorm.io/gorm"
7 | "gorm.io/gorm/clause"
8 | "time"
9 |
10 | "treehole_next/config"
11 | . "treehole_next/models"
12 | )
13 |
14 | func purgeHole() (err error) {
15 | const REASON = "purge_hole"
16 | const DELETE_CONTENT = "该内容已被删除"
17 |
18 | return DB.Transaction(func(tx *gorm.DB) (err error) {
19 |
20 | // load holeIDs, lock for update
21 | var holeIDs []int
22 | err = tx.Model(&Hole{}).
23 | Clauses(clause.Locking{Strength: "UPDATE"}).
24 | Where("no_purge = ?", false).
25 | Where("division_id IN ?", config.Config.HolePurgeDivisions).
26 | Where("updated_at < ?",
27 | time.Now().AddDate(0, 0, -config.Config.HolePurgeDays),
28 | ).Pluck("id", &holeIDs).Error
29 | if err != nil {
30 | return err
31 | }
32 |
33 | if len(holeIDs) == 0 {
34 | return nil
35 | }
36 |
37 | /* delete all floors in hole of holeIOs */
38 |
39 | // get floors, lock for update
40 | var floors []Floor
41 | err = tx.
42 | Clauses(clause.Locking{Strength: "UPDATE"}).
43 | Where("hole_id IN ?", holeIDs).
44 | Find(&floors).Error
45 | if err != nil {
46 | return err
47 | }
48 | if len(floors) == 0 {
49 | return nil
50 | }
51 |
52 | // generate floorHistory
53 | var floorHistorySlice = make([]FloorHistory, 0, len(floors))
54 | for i := range floors {
55 | floorHistorySlice = append(floorHistorySlice, FloorHistory{
56 | Content: floors[i].Content,
57 | Reason: REASON,
58 | FloorID: floors[i].ID,
59 | UserID: 1,
60 | SensitiveDetail: floors[i].SensitiveDetail,
61 | IsActualSensitive: floors[i].IsActualSensitive,
62 | IsSensitive: floors[i].IsSensitive,
63 | })
64 | }
65 | err = tx.Create(&floorHistorySlice).Error
66 | if err != nil {
67 | return err
68 | }
69 |
70 | // delete floors
71 | var floorIDs = make([]int, 0, len(floors))
72 | for i := range floors {
73 | floorIDs = append(floorIDs, floors[i].ID)
74 | }
75 |
76 | err = tx.Model(&Floor{}).
77 | Where("id IN ?", floorIDs).
78 | Updates(map[string]any{
79 | "deleted": true,
80 | "content": DELETE_CONTENT,
81 | }).Error
82 | if err != nil {
83 | return err
84 | }
85 |
86 | /* delete all holes in holeIDs */
87 | err = tx.
88 | Where("id IN ?", holeIDs).
89 | Delete(&Hole{}).Error
90 | if err != nil {
91 | return err
92 | }
93 |
94 | // delete floor in search engine
95 | go BulkDelete(floorIDs)
96 |
97 | // log
98 | log.Info().
99 | Ints("hole_ids", holeIDs).
100 | Ints("floor_ids", floorIDs).
101 | Msg("purge hole")
102 |
103 | return nil
104 | })
105 | }
106 |
107 | func PurgeHole(ctx context.Context) {
108 | ticker := time.NewTicker(time.Minute * 10)
109 | defer ticker.Stop()
110 | for {
111 | select {
112 | case <-ticker.C:
113 | err := purgeHole()
114 | if err != nil {
115 | log.Err(err).Msg("error purge hole")
116 | }
117 | case <-ctx.Done():
118 | return
119 | }
120 | }
121 | }
122 |
--------------------------------------------------------------------------------
/apis/hole/routes.go:
--------------------------------------------------------------------------------
1 | package hole
2 |
3 | import (
4 | "github.com/gofiber/fiber/v2"
5 |
6 | "treehole_next/utils"
7 | )
8 |
9 | func RegisterRoutes(app fiber.Router) {
10 | app.Get("/divisions/:id/holes", ListHolesByDivision)
11 | app.Get("/tags/:name/holes", ListHolesByTag)
12 | app.Get("/users/me/holes", ListHolesByMe)
13 | app.Get("/holes/:id", GetHole)
14 | app.Get("/holes", ListHoles)
15 | app.Get("/holes/_good", ListGoodHoles)
16 | app.Post("/divisions/:id/holes", utils.MiddlewareHasAnsweredQuestions, CreateHole)
17 | app.Post("/holes", utils.MiddlewareHasAnsweredQuestions, CreateHoleOld)
18 | app.Patch("/holes/:id/_webvpn", ModifyHole)
19 | app.Patch("/holes/:id", PatchHole)
20 | app.Put("/holes/:id", ModifyHole)
21 | app.Delete("/holes/:id", HideHole)
22 | app.Delete("/holes/:id/_force", DeleteHole)
23 | }
24 |
--------------------------------------------------------------------------------
/apis/hole/schemas.go:
--------------------------------------------------------------------------------
1 | package hole
2 |
3 | import (
4 | "time"
5 |
6 | "github.com/opentreehole/go-common"
7 |
8 | "treehole_next/apis/tag"
9 | "treehole_next/models"
10 | )
11 |
12 | type QueryTime struct {
13 | Size int `json:"size" query:"size" default:"10" validate:"max=10"`
14 | // updated time < offset (default is now)
15 | Offset common.CustomTime `json:"offset" query:"offset" swaggertype:"string"`
16 | Order string `json:"order" query:"order"`
17 | }
18 |
19 | func (q *QueryTime) SetDefaults() {
20 | if q.Offset.IsZero() {
21 | q.Offset = common.CustomTime{Time: time.Now()}
22 | }
23 | }
24 |
25 | type ListOldModel struct {
26 | Offset0 common.CustomTime `json:"start_time" query:"start_time" swaggertype:"string"`
27 | Offset common.CustomTime `json:"offset" query:"offset" swaggertype:"string"`
28 | Size0 int `json:"length" query:"length" default:"10" validate:"max=10"`
29 | Size int `json:"size" query:"size" default:"10" validate:"max=10" `
30 | Tag string `json:"tag" query:"tag"`
31 | Tags []string `json:"tags" query:"tags"`
32 | DivisionID int `json:"division_id" query:"division_id"`
33 | Order string `json:"order" query:"order"`
34 | CreatedStart *common.CustomTime `json:"created_start" query:"created_start" swaggertype:"string"`
35 | CreatedEnd *common.CustomTime `json:"created_end" query:"created_end" swaggertype:"string"`
36 | }
37 |
38 | func (q *ListOldModel) SetDefaults() {
39 | if q.Size == 0 {
40 | q.Size = q.Size0
41 | }
42 | if q.Offset.IsZero() {
43 | if q.Offset0.IsZero() {
44 | q.Offset = common.CustomTime{Time: time.Now()}
45 | } else {
46 | q.Offset = q.Offset0
47 | }
48 | }
49 | if q.CreatedStart == nil {
50 | q.CreatedStart = &common.CustomTime{Time: time.Time{}} // 默认值为零时间
51 | }
52 | if q.CreatedEnd == nil {
53 | q.CreatedEnd = &common.CustomTime{Time: time.Now()}
54 | }
55 | }
56 |
57 | type TagCreateModelSlice struct {
58 | Tags []tag.CreateModel `json:"tags" validate:"omitempty,min=1,max=10,dive"` // All users
59 | }
60 |
61 | func (tagCreateModelSlice TagCreateModelSlice) ToName() []string {
62 | tags := make([]string, 0, len(tagCreateModelSlice.Tags))
63 | for _, tagCreateModel := range tagCreateModelSlice.Tags {
64 | tags = append(tags, tagCreateModel.Name)
65 | }
66 | return tags
67 | }
68 |
69 | type CreateModel struct {
70 | Content string `json:"content" validate:"required"`
71 | TagCreateModelSlice
72 | // Admin and Operator only
73 | SpecialTag string `json:"special_tag" validate:"max=16"`
74 | }
75 |
76 | type CreateOldModel struct {
77 | CreateModel
78 | DivisionID int `json:"division_id" validate:"omitempty,min=1" default:"1"`
79 | }
80 |
81 | type CreateOldResponse struct {
82 | Data models.Hole `json:"data"`
83 | Message string `json:"message"`
84 | }
85 |
86 | type ModifyModel struct {
87 | TagCreateModelSlice
88 | DivisionID *int `json:"division_id" validate:"omitempty,min=1"` // Admin and owner only
89 | Hidden *bool `json:"hidden"` // Admin only
90 | Unhidden *bool `json:"unhidden"` // admin only
91 | Lock *bool `json:"lock"` // admin only
92 | }
93 |
94 | func (body ModifyModel) CheckPermission(user *models.User, hole *models.Hole) error {
95 | if body.DivisionID != nil && !user.IsAdmin {
96 | return common.Forbidden("非管理员禁止修改分区")
97 | }
98 | if body.Hidden != nil && !user.IsAdmin {
99 | return common.Forbidden("非管理员禁止隐藏帖子")
100 | }
101 | if body.Unhidden != nil && !user.IsAdmin {
102 | return common.BadRequest("非管理员禁止取消隐藏")
103 | }
104 | if body.Tags != nil && !(user.IsAdmin) {
105 | return common.Forbidden()
106 | }
107 | if body.Tags != nil && len(body.Tags) == 0 {
108 | return common.BadRequest("tags 不能为空")
109 | }
110 | if body.Lock != nil && !user.IsAdmin {
111 | return common.Forbidden("非管理员禁止锁定帖子")
112 | }
113 | return nil
114 | }
115 |
116 | func (body ModifyModel) DoNothing() bool {
117 | return body.Hidden == nil && body.Unhidden == nil && body.Tags == nil && body.DivisionID == nil && body.Lock == nil
118 | }
119 |
--------------------------------------------------------------------------------
/apis/hole/update_views.go:
--------------------------------------------------------------------------------
1 | package hole
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "strconv"
7 | "strings"
8 | "time"
9 |
10 | "github.com/rs/zerolog/log"
11 |
12 | . "treehole_next/models"
13 | )
14 |
15 | var holeViewsChan = make(chan int, 1000)
16 | var holeViews = map[int]int{}
17 |
18 | func updateHoleViews() {
19 | /*
20 | UPDATE table
21 | SET field = CASE id
22 | WHEN 1 THEN 'value'
23 | WHEN 2 THEN 'value'
24 | WHEN 3 THEN 'value'
25 | END
26 | WHERE id IN (1,2,3)
27 | */
28 | length := len(holeViews)
29 | if length == 0 {
30 | return
31 | }
32 | keys := make([]string, 0, length)
33 |
34 | var builder strings.Builder
35 | builder.WriteString("UPDATE hole SET view = CASE id ")
36 |
37 | for holeID, views := range holeViews {
38 | builder.WriteString(fmt.Sprintf("WHEN %d THEN view + %d ", holeID, views))
39 | keys = append(keys, strconv.Itoa(holeID))
40 | delete(holeViews, holeID)
41 | }
42 | builder.WriteString("END WHERE id IN (")
43 | builder.WriteString(strings.Join(keys, ","))
44 | builder.WriteString(")")
45 |
46 | result := DB.Exec(builder.String())
47 | if result.Error != nil {
48 | log.Err(result.Error).Msg("update hole views failed")
49 | } else {
50 | log.Info().Strs("updated", keys).Msg("update hole views success")
51 | }
52 | }
53 |
54 | func UpdateHoleViews(ctx context.Context) {
55 |
56 | ticker := time.NewTicker(time.Second * 60)
57 | defer ticker.Stop()
58 | for {
59 | select {
60 | case <-ticker.C:
61 | updateHoleViews()
62 | case holeID := <-holeViewsChan:
63 | holeViews[holeID]++
64 | case <-ctx.Done():
65 | updateHoleViews()
66 | log.Info().Msg("task UpdateHoleViews stopped...")
67 | return
68 | }
69 | }
70 | }
71 |
--------------------------------------------------------------------------------
/apis/message/apis.go:
--------------------------------------------------------------------------------
1 | package message
2 |
3 | import (
4 | "github.com/opentreehole/go-common"
5 |
6 | . "treehole_next/models"
7 | . "treehole_next/utils"
8 |
9 | "github.com/gofiber/fiber/v2"
10 | )
11 |
12 | // ListMessages
13 | // @Summary List Messages of a User
14 | // @Tags Message
15 | // @Produce application/json
16 | // @Router /messages [get]
17 | // @Success 200 {array} Message
18 | // @Param object query ListModel false "query"
19 | func ListMessages(c *fiber.Ctx) error {
20 | var query ListModel
21 | err := common.ValidateQuery(c, &query)
22 | if err != nil {
23 | return err
24 | }
25 |
26 | userID, err := common.GetUserID(c)
27 | if err != nil {
28 | return err
29 | }
30 |
31 | messages := Messages{}
32 |
33 | if query.NotRead {
34 | DB.Raw(`
35 | SELECT message.*,message_user.has_read FROM message
36 | INNER JOIN message_user
37 | WHERE message.id = message_user.message_id and message_user.user_id = ? and message_user.has_read = false
38 | ORDER BY updated_at DESC`,
39 | userID,
40 | ).Scan(&messages)
41 | } else {
42 | DB.Raw(`
43 | SELECT message.*,message_user.has_read FROM message
44 | INNER JOIN message_user
45 | WHERE message.id = message_user.message_id and message_user.user_id = ?
46 | ORDER BY updated_at DESC`,
47 | userID,
48 | ).Scan(&messages)
49 | }
50 |
51 | return Serialize(c, &messages)
52 | }
53 |
54 | // SendMail
55 | // @Summary Send a Mail
56 | // @Description Send to multiple recipients and save to db, admin only.
57 | // @Tags Message
58 | // @Produce application/json
59 | // @Param json body CreateModel true "json"
60 | // @Router /messages [post]
61 | // @Success 201 {object} Message
62 | func SendMail(c *fiber.Ctx) error {
63 | var body CreateModel
64 | err := common.ValidateBody(c, &body)
65 | if err != nil {
66 | return err
67 | }
68 |
69 | // get user
70 | user, err := GetCurrLoginUser(c)
71 | if err != nil {
72 | return err
73 | }
74 |
75 | // permission
76 | if !user.IsAdmin {
77 | return common.Forbidden()
78 | }
79 |
80 | // construct mail
81 | mail := Notification{
82 | Description: body.Description,
83 | Recipients: body.Recipients,
84 | Data: Map{},
85 | Title: "您有一封站内信",
86 | Type: MessageTypeMail,
87 | URL: "/api/messages",
88 | }
89 |
90 | // send
91 | message, err := mail.Send()
92 | if err != nil {
93 | return err
94 | }
95 |
96 | CreateAdminLog(DB, AdminLogTypeMessage, user.ID, body)
97 |
98 | return Serialize(c.Status(201), &message)
99 | }
100 |
101 | // ClearMessages
102 | // @Summary Clear Messages of a User
103 | // @Tags Message
104 | // @Produce application/json
105 | // @Router /messages/clear [post]
106 | // @Success 204
107 | func ClearMessages(c *fiber.Ctx) error {
108 | userID, err := common.GetUserID(c)
109 | if err != nil {
110 | return err
111 | }
112 |
113 | result := DB.Exec(
114 | "UPDATE message_user SET has_read = true WHERE user_id = ?",
115 | userID,
116 | )
117 | if result.Error != nil {
118 | return result.Error
119 | }
120 | return c.Status(204).JSON(nil)
121 | }
122 |
123 | // ClearMessagesDeprecated
124 | // @Summary Clear Messages Deprecated
125 | // @Tags Message
126 | // @Produce application/json
127 | // @Router /messages [put]
128 | // @Router /messages/_webvpn [patch]
129 | // @Success 204
130 | func ClearMessagesDeprecated(c *fiber.Ctx) error {
131 | return ClearMessages(c)
132 | }
133 |
134 | // DeleteMessage
135 | // @Summary Delete a message of a user
136 | // @Tags Message
137 | // @Produce application/json
138 | // @Router /messages/{id} [delete]
139 | // @Param id path int true "message id"
140 | // @Success 204
141 | func DeleteMessage(c *fiber.Ctx) error {
142 | userID, err := common.GetUserID(c)
143 | if err != nil {
144 | return err
145 | }
146 |
147 | id, _ := c.ParamsInt("id")
148 | result := DB.Exec(
149 | "UPDATE message_user SET has_read = true WHERE user_id = ? AND message_id = ?",
150 | userID, id,
151 | )
152 | if result.Error != nil {
153 | return result.Error
154 | }
155 | return c.Status(204).JSON(nil)
156 | }
157 |
--------------------------------------------------------------------------------
/apis/message/purge.go:
--------------------------------------------------------------------------------
1 | package message
2 |
3 | import (
4 | "time"
5 |
6 | "github.com/rs/zerolog/log"
7 |
8 | "treehole_next/config"
9 | . "treehole_next/models"
10 | )
11 |
12 | func purgeMessage() error {
13 | return DB.Exec(
14 | "DELETE FROM message WHERE created_at < ?",
15 | time.Now().Add(-time.Hour*24*time.Duration(config.Config.MessagePurgeDays)),
16 | ).Error
17 | }
18 |
19 | func PurgeMessage() {
20 | ticker := time.NewTicker(time.Hour * 24)
21 | defer ticker.Stop()
22 | for range ticker.C {
23 | err := purgeMessage()
24 | if err != nil {
25 | log.Err(err).Msg("error purge message")
26 | }
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/apis/message/routes.go:
--------------------------------------------------------------------------------
1 | package message
2 |
3 | import "github.com/gofiber/fiber/v2"
4 |
5 | func RegisterRoutes(app fiber.Router) {
6 | app.Post("/messages", SendMail)
7 | app.Get("/messages", ListMessages)
8 | app.Post("/messages/clear", ClearMessages)
9 | app.Put("/messages", ClearMessagesDeprecated)
10 | app.Patch("/messages/_webvpn", ClearMessagesDeprecated)
11 | app.Delete("/messages/:id", DeleteMessage)
12 | }
13 |
--------------------------------------------------------------------------------
/apis/message/schemas.go:
--------------------------------------------------------------------------------
1 | package message
2 |
3 | type CreateModel struct {
4 | // MessageTypeMail
5 | Description string `json:"description"`
6 | Recipients []int `json:"recipients" validate:"required"`
7 | }
8 |
9 | type ListModel struct {
10 | NotRead bool `json:"not_read" default:"false" query:"not_read"`
11 | }
12 |
--------------------------------------------------------------------------------
/apis/penalty/api.go:
--------------------------------------------------------------------------------
1 | // Package penalty is deprecated! Please use APIs in auth.
2 | package penalty
3 |
4 | import (
5 | "fmt"
6 | "time"
7 |
8 | "treehole_next/config"
9 | . "treehole_next/models"
10 | "treehole_next/utils"
11 |
12 | "github.com/opentreehole/go-common"
13 | "gorm.io/gorm"
14 | "gorm.io/gorm/clause"
15 |
16 | "github.com/gofiber/fiber/v2"
17 | )
18 |
19 | type PostBody struct {
20 | PenaltyLevel *int `json:"penalty_level" validate:"omitempty"` // low priority, deprecated
21 | Days *int `json:"days" validate:"omitempty,min=1"` // high priority
22 | Divisions []int `json:"divisions" validate:"omitempty,min=1"` // high priority
23 | Reason string `json:"reason"` // optional
24 | }
25 |
26 | type ForeverPostBody struct {
27 | Reason string `json:"reason"` // optional
28 | }
29 |
30 | // BanUser
31 | //
32 | // @Summary Ban publisher of a floor
33 | // @Tags Penalty
34 | // @Produce json
35 | // @Router /penalty/{floor_id} [post]
36 | // @Param json body PostBody true "json"
37 | // @Success 201 {object} User
38 | func BanUser(c *fiber.Ctx) error {
39 | // validate body
40 | var body PostBody
41 | err := common.ValidateBody(c, &body)
42 | if err != nil {
43 | return err
44 | }
45 |
46 | floorID, err := c.ParamsInt("id")
47 | if err != nil {
48 | return err
49 | }
50 |
51 | // get user
52 | user, err := GetCurrLoginUser(c)
53 | if err != nil {
54 | return err
55 | }
56 |
57 | // permission
58 | if !user.IsAdmin {
59 | return common.Forbidden()
60 | }
61 |
62 | var floor Floor
63 | err = DB.Take(&floor, floorID).Error
64 | if err != nil {
65 | return err
66 | }
67 |
68 | var hole Hole
69 | err = DB.Take(&hole, floor.HoleID).Error
70 | if err != nil {
71 | return err
72 | }
73 |
74 | var days int
75 | if body.Days != nil {
76 | days = *body.Days
77 | if days <= 0 {
78 | days = 1
79 | }
80 | } else if body.PenaltyLevel != nil {
81 | switch *body.PenaltyLevel {
82 | case 1:
83 | days = 1
84 | case 2:
85 | days = 5
86 | case 3:
87 | days = 999
88 | default:
89 | days = 1
90 | }
91 | }
92 |
93 | duration := time.Duration(days) * 24 * time.Hour
94 |
95 | punishment := Punishment{
96 | UserID: floor.UserID,
97 | MadeBy: user.ID,
98 | FloorID: &floor.ID,
99 | DivisionID: hole.DivisionID,
100 | Duration: &duration,
101 | Day: days,
102 | Reason: body.Reason,
103 | }
104 | user, err = punishment.Create()
105 | if err != nil {
106 | return err
107 | }
108 |
109 | // construct message for user
110 | message := Notification{
111 | Data: floor,
112 | Recipients: []int{floor.UserID},
113 | Description: fmt.Sprintf(
114 | "您因为违反社区公约被禁言。时间:%d天,原因:%s\n如有异议,请联系admin@danta.tech。",
115 | days,
116 | body.Reason,
117 | ),
118 | Title: "处罚通知",
119 | Type: MessageTypePermission,
120 | URL: fmt.Sprintf("/api/floors/%d", floor.ID),
121 | }
122 |
123 | // send
124 | _, err = message.Send()
125 | if err != nil {
126 | return err
127 | }
128 |
129 | return c.JSON(user)
130 | }
131 |
132 | // BanUserForever
133 | //
134 | // @Summary Ban publisher of a floor forever
135 | // @Tags Penalty
136 | // @Produce json
137 | // @Router /penalty/{floor_id}/_forever [post]
138 | // @Param json body ForeverPostBody true "json"
139 | // @Success 201 {object} User
140 | func BanUserForever(c *fiber.Ctx) error {
141 | // validate body
142 | var body ForeverPostBody
143 | err := common.ValidateBody(c, &body)
144 | if err != nil {
145 | return err
146 | }
147 |
148 | floorID, err := c.ParamsInt("id")
149 | if err != nil {
150 | return err
151 | }
152 |
153 | // get user
154 | user, err := GetCurrLoginUser(c)
155 | if err != nil {
156 | return err
157 | }
158 |
159 | // permission
160 | if !user.IsAdmin {
161 | return common.Forbidden()
162 | }
163 |
164 | var floor Floor
165 | err = DB.Take(&floor, floorID).Error
166 | if err != nil {
167 | return err
168 | }
169 |
170 | // var hole Hole
171 | // err = DB.Take(&hole, floor.HoleID).Error
172 | // if err != nil {
173 | // return err
174 | // }
175 |
176 | days := 3650
177 | duration := time.Duration(days) * 24 * time.Hour
178 |
179 | var punishments Punishments
180 | var punishment *Punishment
181 | var divisionIDs []int
182 | madeBy := user.ID
183 | user = &User{
184 | ID: floor.UserID,
185 | }
186 | err = DB.Transaction(func(tx *gorm.DB) (err error) {
187 | err = tx.Clauses(clause.Locking{Strength: "UPDATE"}).Take(&user).Error
188 | if err != nil {
189 | return err
190 | }
191 |
192 | err = tx.Clauses(clause.Locking{Strength: "UPDATE"}).Model(&Division{}).Select("ID").Scan(&divisionIDs).Error
193 | if err != nil {
194 | return err
195 | }
196 |
197 | ExcludeBanForeverDivisionIds := config.Config.ExcludeBanForeverDivisionIds
198 |
199 | divisionIDs = utils.Difference(divisionIDs, ExcludeBanForeverDivisionIds)
200 |
201 | for _, divisionID := range divisionIDs {
202 | punishment = &Punishment{
203 | UserID: floor.UserID,
204 | MadeBy: madeBy,
205 | FloorID: nil,
206 | DivisionID: divisionID,
207 | Duration: &duration,
208 | Day: days,
209 | Reason: body.Reason,
210 | StartTime: time.Now(),
211 | EndTime: time.Now().Add(duration),
212 | }
213 |
214 | if user.BanDivision[divisionID] == nil {
215 | user.BanDivision[divisionID] = &punishment.EndTime
216 | } else {
217 | user.BanDivision[divisionID].Add(*punishment.Duration)
218 | }
219 |
220 | punishments = append(punishments, punishment)
221 | }
222 | user.OffenceCount += len(divisionIDs)
223 |
224 | err = tx.Create(&punishments).Error
225 | if err != nil {
226 | return err
227 | }
228 |
229 | err = tx.Select("BanDivision", "OffenceCount").Save(&user).Error
230 | if err != nil {
231 | return err
232 | }
233 |
234 | return nil
235 | })
236 | if err != nil {
237 | return err
238 | }
239 |
240 | // construct message for user
241 | message := Notification{
242 | Data: floor,
243 | Recipients: []int{floor.UserID},
244 | Description: fmt.Sprintf(
245 | "您因为违反社区公约被禁言。时间:%d天,原因:%s\n如有异议,请联系admin@danta.tech。",
246 | days,
247 | body.Reason,
248 | ),
249 | Title: "处罚通知",
250 | Type: MessageTypePermission,
251 | URL: fmt.Sprintf("/api/floors/%d", floor.ID),
252 | }
253 |
254 | // send
255 | _, err = message.Send()
256 | if err != nil {
257 | return err
258 | }
259 |
260 | return c.JSON(user)
261 | }
262 |
263 | // ListMyPunishments godoc
264 | // @Summary List my punishments
265 | // @Tags Penalty
266 | // @Produce json
267 | // @Router /users/me/punishments [get]
268 | // @Success 200 {array} Punishment
269 | func ListMyPunishments(c *fiber.Ctx) error {
270 | userID, err := common.GetUserID(c)
271 | if err != nil {
272 | return err
273 | }
274 |
275 | punishments, err := listPunishmentsByUserID(userID)
276 | if err != nil {
277 | return err
278 | }
279 |
280 | return c.JSON(punishments)
281 | }
282 |
283 | // ListPunishmentsByUserID godoc
284 | // @Summary List punishments by user id
285 | // @Tags Penalty
286 | // @Produce json
287 | // @Router /users/{id}/punishments [get]
288 | // @Param id path int true "User ID"
289 | // @Success 200 {array} Punishment
290 | func ListPunishmentsByUserID(c *fiber.Ctx) error {
291 | userID, err := c.ParamsInt("id")
292 | if err != nil {
293 | return err
294 | }
295 |
296 | currentUser, err := GetCurrLoginUser(c)
297 | if err != nil {
298 | return err
299 | }
300 | if !currentUser.IsAdmin && currentUser.ID != userID {
301 | return common.Forbidden()
302 | }
303 |
304 | punishments, err := listPunishmentsByUserID(userID)
305 | if err != nil {
306 | return err
307 | }
308 |
309 | return c.JSON(punishments)
310 | }
311 |
312 | func listPunishmentsByUserID(userID int) ([]Punishment, error) {
313 | var punishments []Punishment
314 | err := DB.Where("user_id = ?", userID).Preload("Floor").Find(&punishments).Error
315 | if err != nil {
316 | return nil, err
317 | }
318 |
319 | // remove made_by
320 | for i := range punishments {
321 | punishments[i].MadeBy = 0
322 | }
323 |
324 | return punishments, nil
325 | }
326 |
327 | func RegisterRoutes(app fiber.Router) {
328 | app.Post("/penalty/:id/_forever", BanUserForever)
329 | app.Post("/penalty/:id", BanUser)
330 | app.Get("/users/me/punishments", ListMyPunishments)
331 | app.Get("/users/:id/punishments", ListPunishmentsByUserID)
332 | }
333 |
--------------------------------------------------------------------------------
/apis/report/apis.go:
--------------------------------------------------------------------------------
1 | package report
2 |
3 | import (
4 | "fmt"
5 | "time"
6 | . "treehole_next/models"
7 | . "treehole_next/utils"
8 |
9 | "github.com/opentreehole/go-common"
10 | "github.com/rs/zerolog/log"
11 |
12 | "github.com/gofiber/fiber/v2"
13 | "gorm.io/gorm"
14 | )
15 |
16 | // GetReport
17 | //
18 | // @Summary Get A Report
19 | // @Tags Report
20 | // @Produce application/json
21 | // @Router /reports/{id} [get]
22 | // @Param id path int true "id"
23 | // @Success 200 {object} Report
24 | // @Failure 404 {object} MessageModel
25 | func GetReport(c *fiber.Ctx) error {
26 | // validate query
27 | reportID, err := c.ParamsInt("id")
28 | if err != nil {
29 | return err
30 | }
31 |
32 | // find report
33 | var report Report
34 | result := LoadReportFloor(DB).First(&report, reportID)
35 | if result.Error != nil {
36 | return result.Error
37 | }
38 | return Serialize(c, &report)
39 | }
40 |
41 | // ListReports
42 | //
43 | // @Summary List All Reports
44 | // @Tags Report
45 | // @Produce application/json
46 | // @Router /reports [get]
47 | // @Param object query ListModel false "query"
48 | // @Success 200 {array} Report
49 | // @Failure 404 {object} MessageModel
50 | func ListReports(c *fiber.Ctx) error {
51 | // validate query
52 | var query ListModel
53 | err := common.ValidateQuery(c, &query)
54 | if err != nil {
55 | return err
56 | }
57 |
58 | // find reports
59 | var reports Reports
60 |
61 | querySet := LoadReportFloor(query.BaseQuery())
62 |
63 | var result *gorm.DB
64 | switch query.Range {
65 | case RangeNotDealt:
66 | result = querySet.Find(&reports, "dealt = ?", false)
67 | case RangeDealt:
68 | result = querySet.Find(&reports, "dealt = ?", true)
69 | case RangeAll:
70 | result = querySet.Find(&reports)
71 | }
72 | if result.Error != nil {
73 | return result.Error
74 | }
75 | return Serialize(c, &reports)
76 | }
77 |
78 | // AddReport
79 | //
80 | // @Summary Add a report
81 | // @Description Add a report and send notification to admins
82 | // @Tags Report
83 | // @Produce application/json
84 | // @Router /reports [post]
85 | // @Param json body AddModel true "json"
86 | // @Success 204
87 | //
88 | // @Failure 400 {object} common.HttpError
89 | func AddReport(c *fiber.Ctx) error {
90 | // validate body
91 | var body AddModel
92 | err := common.ValidateBody(c, &body)
93 | if err != nil {
94 | return err
95 | }
96 |
97 | user, err := GetCurrLoginUser(c)
98 | if err != nil {
99 | return err
100 | }
101 |
102 | // permission
103 | if user.BanReport != nil {
104 | return common.Forbidden(user.BanReportMessage())
105 | }
106 |
107 | // add report
108 | report := Report{
109 | FloorID: body.FloorID,
110 | Reason: body.Reason,
111 | Dealt: false,
112 | }
113 | err = report.Create(c)
114 | if err != nil {
115 | return err
116 | }
117 |
118 | // Send Notification
119 | err = report.SendCreate(DB)
120 | if err != nil {
121 | log.Err(err).Str("model", "Notification").Msg("SendCreate failed: ")
122 | // return err // only for test
123 | }
124 |
125 | return c.Status(204).JSON(nil)
126 | }
127 |
128 | // DeleteReport
129 | //
130 | // @Summary Deal a report
131 | // @Description Mark a report as "dealt" and send notification to reporter
132 | // @Tags Report
133 | // @Produce application/json
134 | // @Router /reports/{id} [delete]
135 | // @Param id path int true "id"
136 | // @Param json body DeleteModel true "json"
137 | // @Success 200 {object} Report
138 | // @Failure 400 {object} common.HttpError
139 | func DeleteReport(c *fiber.Ctx) error {
140 | // validate query
141 | reportID, err := c.ParamsInt("id")
142 | if err != nil {
143 | return err
144 | }
145 |
146 | // validate body
147 | var body DeleteModel
148 | err = common.ValidateBody(c, &body)
149 | if err != nil {
150 | return err
151 | }
152 |
153 | // get user id
154 | userID, err := common.GetUserID(c)
155 | if err != nil {
156 | return err
157 | }
158 |
159 | // modify report
160 | var report Report
161 | result := LoadReportFloor(DB).First(&report, reportID)
162 | if result.Error != nil {
163 | return result.Error
164 | }
165 | report.Dealt = true
166 | report.DealtBy = userID
167 | report.Result = body.Result
168 | DB.Omit("Floor").Save(&report)
169 |
170 | MyLog("Report", "Delete", reportID, userID, RoleAdmin)
171 | CreateAdminLog(DB, AdminLogTypeDeleteReport, userID, report)
172 |
173 | // Send Notification
174 | err = report.SendModify(DB)
175 | if err != nil {
176 | log.Err(err).Str("model", "Notification").Msg("SendModify failed")
177 | // return err // only for test
178 | }
179 |
180 | return Serialize(c, &report)
181 | }
182 |
183 | type banBody struct {
184 | Days *int `json:"days" validate:"omitempty,min=1"`
185 | Reason string `json:"reason"` // optional
186 | }
187 |
188 | // BanReporter
189 | //
190 | // @Summary Ban reporter of a report
191 | // @Tags Report
192 | // @Produce json
193 | // @Router /reports/ban/{id} [post]
194 | // @Param json body banBody true "json"
195 | // @Success 201 {object} User
196 | func BanReporter(c *fiber.Ctx) error {
197 | // validate body
198 | var body banBody
199 | err := common.ValidateBody(c, &body)
200 | if err != nil {
201 | return err
202 | }
203 |
204 | reportID, err := c.ParamsInt("id")
205 | if err != nil {
206 | return err
207 | }
208 |
209 | // get user
210 | user, err := GetCurrLoginUser(c)
211 | if err != nil {
212 | return err
213 | }
214 |
215 | // permission
216 | if !user.IsAdmin {
217 | return common.Forbidden()
218 | }
219 |
220 | var report Report
221 | err = DB.Take(&report, reportID).Error
222 | if err != nil {
223 | return err
224 | }
225 |
226 | var days int
227 | if body.Days != nil {
228 | days = *body.Days
229 | if days <= 0 {
230 | days = 1
231 | }
232 | } else {
233 | days = 1
234 | }
235 |
236 | duration := time.Duration(days) * 24 * time.Hour
237 |
238 | reportPunishment := ReportPunishment{
239 | UserID: report.UserID,
240 | MadeBy: user.ID,
241 | ReportId: report.ID,
242 | Duration: &duration,
243 | Reason: body.Reason,
244 | }
245 | user, err = reportPunishment.Create()
246 | if err != nil {
247 | return err
248 | }
249 |
250 | // construct message for user
251 | message := Notification{
252 | Data: report,
253 | Recipients: []int{report.UserID},
254 | Description: fmt.Sprintf(
255 | "您因违反社区公约被禁止举报。时间:%d天,原因:%s\n如有异议,请联系admin@danta.tech。",
256 | days,
257 | body.Reason,
258 | ),
259 | Title: "处罚通知",
260 | Type: MessageTypePermission,
261 | URL: fmt.Sprintf("/api/reports/%d", report.ID),
262 | }
263 |
264 | // send
265 | _, err = message.Send()
266 | if err != nil {
267 | return err
268 | }
269 |
270 | return c.JSON(user)
271 | }
272 |
--------------------------------------------------------------------------------
/apis/report/routes.go:
--------------------------------------------------------------------------------
1 | package report
2 |
3 | import "github.com/gofiber/fiber/v2"
4 |
5 | func RegisterRoutes(app fiber.Router) {
6 | app.Get("/reports/:id", GetReport)
7 | app.Get("/reports", ListReports)
8 | app.Post("/reports", AddReport)
9 | app.Delete("/reports/:id", DeleteReport)
10 |
11 | app.Post("/reports/ban/:id", BanReporter)
12 | }
13 |
--------------------------------------------------------------------------------
/apis/report/schemas.go:
--------------------------------------------------------------------------------
1 | package report
2 |
3 | import (
4 | "fmt"
5 |
6 | "gorm.io/gorm"
7 |
8 | . "treehole_next/models"
9 | )
10 |
11 | type Range int
12 |
13 | const (
14 | RangeNotDealt Range = iota
15 | RangeDealt
16 | RangeAll
17 | )
18 |
19 | type ListModel struct {
20 | Size int `query:"size" default:"30" validate:"min=0,max=50"`
21 | Offset int `query:"offset" default:"0" validate:"min=0"`
22 | OrderBy string `query:"order_by" default:"id"`
23 | // Sort order, default is desc
24 | Sort string `json:"sort" query:"sort" default:"desc" validate:"oneof=asc desc"`
25 | // Range, 0: not dealt, 1: dealt, 2: all
26 | Range Range `json:"range"`
27 | }
28 |
29 | func (q *ListModel) BaseQuery() *gorm.DB {
30 | return DB.
31 | Limit(q.Size).
32 | Offset(q.Offset).
33 | Order(fmt.Sprintf("`report`.`%s` %s", q.OrderBy, q.Sort))
34 | }
35 |
36 | type AddModel struct {
37 | FloorID int `json:"floor_id" validate:"required"`
38 | Reason string `json:"reason" validate:"required,max=128"`
39 | }
40 |
41 | type DeleteModel struct {
42 | // The deal result, send it to reporter
43 | Result string `json:"result" validate:"required,max=128"`
44 | }
45 |
--------------------------------------------------------------------------------
/apis/routes.go:
--------------------------------------------------------------------------------
1 | package apis
2 |
3 | import (
4 | "github.com/opentreehole/go-common"
5 |
6 | "treehole_next/apis/division"
7 | "treehole_next/apis/favourite"
8 | "treehole_next/apis/floor"
9 | "treehole_next/apis/hole"
10 | "treehole_next/apis/message"
11 | "treehole_next/apis/penalty"
12 | "treehole_next/apis/report"
13 | "treehole_next/apis/subscription"
14 | "treehole_next/apis/tag"
15 | "treehole_next/apis/user"
16 | "treehole_next/config"
17 | _ "treehole_next/docs"
18 | "treehole_next/models"
19 |
20 | "github.com/gofiber/fiber/v2"
21 | fiberSwagger "github.com/swaggo/fiber-swagger"
22 | )
23 |
24 | func registerRoutes(app *fiber.App) {
25 | app.Get("/", func(c *fiber.Ctx) error {
26 | return c.Redirect("/api")
27 | })
28 | app.Get("/docs", func(c *fiber.Ctx) error {
29 | return c.Redirect("/docs/index.html")
30 | })
31 | app.Get("/docs/*", fiberSwagger.WrapHandler)
32 | }
33 |
34 | func RegisterRoutes(app *fiber.App) {
35 | registerRoutes(app)
36 |
37 | group := app.Group("/api")
38 | group.Get("/", Index)
39 | group.Use(MiddlewareGetUser)
40 | division.RegisterRoutes(group)
41 | tag.RegisterRoutes(group)
42 | hole.RegisterRoutes(group)
43 | floor.RegisterRoutes(group)
44 | report.RegisterRoutes(group)
45 | favourite.RegisterRoutes(group)
46 | subscription.RegisterRoutes(group)
47 | penalty.RegisterRoutes(group)
48 | user.RegisterRoutes(group)
49 | message.RegisterRoutes(group)
50 | }
51 |
52 | func MiddlewareGetUser(c *fiber.Ctx) error {
53 | userObject, err := models.GetCurrLoginUser(c)
54 | if err != nil {
55 | return err
56 | }
57 | c.Locals("user", userObject)
58 | if config.Config.AdminOnly {
59 | if !userObject.IsAdmin {
60 | return common.Forbidden()
61 | }
62 | }
63 | return c.Next()
64 | }
65 |
--------------------------------------------------------------------------------
/apis/subscription/api.go:
--------------------------------------------------------------------------------
1 | package subscription
2 |
3 | import (
4 | "github.com/gofiber/fiber/v2"
5 | "github.com/opentreehole/go-common"
6 | "gorm.io/gorm"
7 | "gorm.io/plugin/dbresolver"
8 |
9 | . "treehole_next/models"
10 | . "treehole_next/utils"
11 | )
12 |
13 | // ListSubscriptions
14 | //
15 | // @Summary List User's Subscriptions
16 | // @Tags Subscription
17 | // @Produce application/json
18 | // @Router /users/subscriptions [get]
19 | // @Param object query ListModel false "query"
20 | // @Success 200 {object} models.Map
21 | // @Success 200 {array} models.Hole
22 | func ListSubscriptions(c *fiber.Ctx) error {
23 | // get userID
24 | userID, err := common.GetUserID(c)
25 | if err != nil {
26 | return err
27 | }
28 |
29 | var query ListModel
30 | err = common.ValidateQuery(c, &query)
31 | if err != nil {
32 | return err
33 | }
34 |
35 | if query.Plain {
36 | data, err := UserGetSubscriptionData(DB, userID)
37 | if err != nil {
38 | return err
39 | }
40 | return c.JSON(Map{"data": data})
41 | } else {
42 | holes := make(Holes, 0)
43 | err := DB.
44 | Joins("JOIN user_subscription ON user_subscription.hole_id = hole.id AND user_subscription.user_id = ?", userID).
45 | Order("user_subscription.created_at desc").Find(&holes).Error
46 | if err != nil {
47 | return err
48 | }
49 | return Serialize(c, &holes)
50 | }
51 | }
52 |
53 | // AddSubscription
54 | //
55 | // @Summary Add A Subscription
56 | // @Tags Subscription
57 | // @Accept application/json
58 | // @Produce application/json
59 | // @Router /users/subscriptions [post]
60 | // @Param json body AddModel true "json"
61 | // @Success 201 {object} Response
62 | func AddSubscription(c *fiber.Ctx) error {
63 | // validate body
64 | var body AddModel
65 | err := common.ValidateBody(c, &body)
66 | if err != nil {
67 | return err
68 | }
69 |
70 | // get userID
71 | userID, err := common.GetUserID(c)
72 | if err != nil {
73 | return err
74 | }
75 |
76 | var data []int
77 |
78 | err = DB.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error {
79 | // add favorites
80 | err = AddUserSubscription(tx, userID, body.HoleID)
81 | if err != nil {
82 | return err
83 | }
84 |
85 | // create response
86 | data, err = UserGetSubscriptionData(tx, userID)
87 | return err
88 | })
89 | if err != nil {
90 | return err
91 | }
92 |
93 | return c.Status(201).JSON(&Response{
94 | Message: "关注成功",
95 | Data: data,
96 | })
97 | }
98 |
99 | // DeleteSubscription
100 | //
101 | // @Summary Delete A Subscription
102 | // @Tags Subscription
103 | // @Produce application/json
104 | // @Router /users/subscription [delete]
105 | // @Param json body DeleteModel true "json"
106 | // @Success 200 {object} Response
107 | // @Failure 404 {object} Response
108 | func DeleteSubscription(c *fiber.Ctx) error {
109 | // validate body
110 | var body DeleteModel
111 | err := common.ValidateBody(c, &body)
112 | if err != nil {
113 | return err
114 | }
115 |
116 | // get userID
117 | userID, err := common.GetUserID(c)
118 | if err != nil {
119 | return err
120 | }
121 |
122 | // delete subscriptions
123 | err = DB.Delete(UserSubscription{UserID: userID, HoleID: body.HoleID}).Error
124 | if err != nil {
125 | return err
126 | }
127 |
128 | // create response
129 | data, err := UserGetSubscriptionData(DB, userID)
130 | if err != nil {
131 | return err
132 | }
133 |
134 | return c.JSON(&Response{
135 | Message: "删除成功",
136 | Data: data,
137 | })
138 | }
139 |
--------------------------------------------------------------------------------
/apis/subscription/routes.go:
--------------------------------------------------------------------------------
1 | package subscription
2 |
3 | import "github.com/gofiber/fiber/v2"
4 |
5 | func RegisterRoutes(app fiber.Router) {
6 | app.Get("/users/subscriptions", ListSubscriptions)
7 | app.Post("/users/subscriptions", AddSubscription)
8 | app.Delete("/users/subscriptions", DeleteSubscription)
9 | app.Delete("/users/subscription", DeleteSubscription)
10 | }
11 |
--------------------------------------------------------------------------------
/apis/subscription/schemas.go:
--------------------------------------------------------------------------------
1 | package subscription
2 |
3 | type Response struct {
4 | Message string `json:"message"`
5 | Data []int `json:"data"`
6 | }
7 |
8 | type ListModel struct {
9 | Plain bool `json:"plain" default:"false" query:"plain"`
10 | }
11 |
12 | type AddModel struct {
13 | HoleID int `json:"hole_id"`
14 | }
15 |
16 | type DeleteModel struct {
17 | HoleID int `json:"hole_id"`
18 | }
19 |
--------------------------------------------------------------------------------
/apis/tag/apis.go:
--------------------------------------------------------------------------------
1 | package tag
2 |
3 | import (
4 | "strings"
5 | "time"
6 | "treehole_next/utils/sensitive"
7 |
8 | "github.com/opentreehole/go-common"
9 | "gorm.io/plugin/dbresolver"
10 |
11 | . "treehole_next/models"
12 | . "treehole_next/utils"
13 |
14 | "github.com/gofiber/fiber/v2"
15 | "gorm.io/gorm"
16 | )
17 |
18 | // ListTags
19 | //
20 | // @Summary List All Tags
21 | // @Tags Tag
22 | // @Produce application/json
23 | // @Param object query SearchModel false "query"
24 | // @Router /tags [get]
25 | // @Success 200 {array} Tag
26 | func ListTags(c *fiber.Ctx) error {
27 | var query SearchModel
28 | err := common.ValidateQuery(c, &query)
29 | if err != nil {
30 | return err
31 | }
32 |
33 | tags := make(Tags, 0, 10)
34 | if query.Search == "" {
35 | if GetCache("tags", &tags) {
36 | return c.JSON(&tags)
37 | } else {
38 | err = DB.Order("temperature DESC").Find(&tags).Error
39 | if err != nil {
40 | return err
41 | }
42 | go UpdateTagCache(tags)
43 | return Serialize(c, &tags)
44 | }
45 | }
46 | err = DB.Where("name LIKE ?", "%"+query.Search+"%").
47 | Order("temperature DESC").Find(&tags).Error
48 | if err != nil {
49 | return err
50 | }
51 | return Serialize(c, &tags)
52 | }
53 |
54 | // GetTag
55 | //
56 | // @Summary Get A Tag
57 | // @Tags Tag
58 | // @Produce application/json
59 | // @Router /tags/{id} [get]
60 | // @Param id path int true "id"
61 | // @Success 200 {object} Tag
62 | // @Failure 404 {object} MessageModel
63 | func GetTag(c *fiber.Ctx) error {
64 | id, _ := c.ParamsInt("id")
65 | var tag Tag
66 | tag.ID = id
67 | result := DB.First(&tag)
68 | if result.Error != nil {
69 | return result.Error
70 | }
71 | return Serialize(c, &tag)
72 | }
73 |
74 | // CreateTag
75 | //
76 | // @Summary Create A Tag
77 | // @Tags Tag
78 | // @Produce application/json
79 | // @Router /tags [post]
80 | // @Param json body CreateModel true "json"
81 | // @Success 200 {object} Tag
82 | // @Success 201 {object} Tag
83 | func CreateTag(c *fiber.Ctx) error {
84 | // validate body
85 | var tag Tag
86 | var body CreateModel
87 | err := common.ValidateBody(c, &body)
88 | if err != nil {
89 | return err
90 | }
91 |
92 | // check tag prefix
93 | user, err := GetCurrLoginUser(c)
94 | if err != nil {
95 | return err
96 | }
97 | if !user.IsAdmin {
98 | if len(tag.Name) > 15 && len([]rune(tag.Name)) > 10 {
99 | return common.BadRequest("标签长度不能超过 10 个字符")
100 | }
101 | if strings.HasPrefix(body.Name, "#") {
102 | return common.BadRequest("只有管理员才能创建 # 开头的 tag")
103 | }
104 | if strings.HasPrefix(body.Name, "@") {
105 | return common.BadRequest("只有管理员才能创建 @ 开头的 tag")
106 | }
107 | if strings.HasPrefix(tag.Name, "*") {
108 | return common.BadRequest("只有管理员才能创建 * 开头的 tag")
109 | }
110 | }
111 |
112 | sensitiveResp, err := sensitive.CheckSensitive(sensitive.ParamsForCheck{
113 | Content: body.Name,
114 | Id: time.Now().UnixNano(),
115 | TypeName: sensitive.TypeTag,
116 | })
117 | if err != nil {
118 | return err
119 | }
120 | tag.IsSensitive = !sensitiveResp.Pass
121 |
122 | // bind and create tag
123 | body.Name = strings.TrimSpace(body.Name)
124 | tag.Name = body.Name
125 | result := DB.Where("name = ?", body.Name).FirstOrCreate(&tag)
126 |
127 | if result.RowsAffected == 0 {
128 | c.Status(200)
129 | } else {
130 | c.Status(201)
131 | }
132 | return Serialize(c, &tag)
133 | }
134 |
135 | // ModifyTag
136 | //
137 | // @Summary Modify A Tag, admin only
138 | // @Tags Tag
139 | // @Produce application/json
140 | // @Router /tags/{id} [put]
141 | // @Router /tags/{id}/_webvpn [patch]
142 | // @Param id path int true "id"
143 | // @Param json body ModifyModel true "json"
144 | // @Success 200 {object} Tag
145 | // @Failure 404 {object} MessageModel
146 | func ModifyTag(c *fiber.Ctx) error {
147 | // admin
148 | user, err := GetCurrLoginUser(c)
149 | if err != nil {
150 | return err
151 | }
152 | if !user.IsAdmin {
153 | return common.Forbidden()
154 | }
155 |
156 | // validate body
157 | var body ModifyModel
158 | err = common.ValidateBody(c, &body)
159 | if err != nil {
160 | return err
161 | }
162 | id, err := c.ParamsInt("id")
163 | if err != nil {
164 | return err
165 | }
166 |
167 | // modify tag
168 | var tag Tag
169 | DB.Find(&tag, id)
170 | tag.Name = strings.TrimSpace(body.Name)
171 | tag.Temperature = body.Temperature
172 |
173 | sensitiveResp, err := sensitive.CheckSensitive(sensitive.ParamsForCheck{
174 | Content: body.Name,
175 | Id: time.Now().UnixNano(),
176 | TypeName: sensitive.TypeTag,
177 | })
178 | if err != nil {
179 | return err
180 | }
181 | tag.IsSensitive = !sensitiveResp.Pass
182 |
183 | DB.Save(&tag)
184 |
185 | // log
186 | userID, err := common.GetUserID(c)
187 | if err != nil {
188 | return err
189 | }
190 | MyLog("Tag", "Modify", tag.ID, userID, RoleAdmin)
191 | CreateAdminLog(DB, AdminLogTypeTag, userID, struct {
192 | TagID int `json:"tag_id"`
193 | Body ModifyModel `json:"body"`
194 | }{
195 | TagID: tag.ID,
196 | Body: body,
197 | })
198 |
199 | return Serialize(c, &tag)
200 | }
201 |
202 | // DeleteTag
203 | //
204 | // @Summary Delete A Tag
205 | // @Description Delete a tag and link all of its holes to another given tag
206 | // @Tags Tag
207 | // @Produce application/json
208 | // @Router /tags/{id} [delete]
209 | // @Param id path int true "id"
210 | // @Param json body DeleteModel true "json"
211 | // @Success 200 {object} Tag
212 | // @Failure 404 {object} MessageModel
213 | func DeleteTag(c *fiber.Ctx) error {
214 | // admin
215 | user, err := GetCurrLoginUser(c)
216 | if err != nil {
217 | return err
218 | }
219 | if !user.IsAdmin {
220 | return common.Forbidden()
221 | }
222 |
223 | // validate body
224 | var body DeleteModel
225 | err = common.ValidateBody(c, &body)
226 | if err != nil {
227 | return err
228 | }
229 |
230 | id, err := c.ParamsInt("id")
231 | if err != nil {
232 | return err
233 | }
234 |
235 | var tag Tag
236 | result := DB.First(&tag, id)
237 | if result.Error != nil {
238 | return result.Error
239 | }
240 |
241 | var newTag Tag
242 | result = DB.Where("name = ?", body.To).First(&newTag)
243 | if result.Error != nil {
244 | return result.Error
245 | }
246 |
247 | newTag.Temperature += tag.Temperature
248 |
249 | err = DB.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error {
250 | result = tx.Exec(`
251 | DELETE FROM hole_tags WHERE tag_id = ? AND hole_id IN
252 | (SELECT a.hole_id FROM
253 | (SELECT hole_id FROM hole_tags WHERE tag_id = ?)a
254 | )`, id, newTag.ID)
255 | if result.Error != nil {
256 | return result.Error
257 | }
258 |
259 | result = tx.Exec(`UPDATE hole_tags SET tag_id = ? WHERE tag_id = ?`, newTag.ID, id)
260 | if result.Error != nil {
261 | return result.Error
262 | }
263 |
264 | result = tx.Updates(&newTag)
265 | if result.Error != nil {
266 | return result.Error
267 | }
268 |
269 | result = tx.Delete(&tag)
270 | if result.Error != nil {
271 | return result.Error
272 | }
273 |
274 | return nil
275 | })
276 | if err != nil {
277 | return err
278 | }
279 |
280 | // log
281 | userID, err := common.GetUserID(c)
282 | if err != nil {
283 | return err
284 | }
285 | MyLog("Tag", "Delete", id, userID, RoleAdmin)
286 | return Serialize(c, &newTag)
287 | }
288 |
--------------------------------------------------------------------------------
/apis/tag/routes.go:
--------------------------------------------------------------------------------
1 | package tag
2 |
3 | import "github.com/gofiber/fiber/v2"
4 |
5 | func RegisterRoutes(app fiber.Router) {
6 | app.Get("/tags", ListTags)
7 | app.Get("/tags/:id", GetTag)
8 | app.Post("/tags", CreateTag)
9 | app.Put("/tags/:id", ModifyTag)
10 | app.Patch("/tags/:id/_webvpn", ModifyTag)
11 | app.Delete("/tags/:id", DeleteTag)
12 | }
13 |
--------------------------------------------------------------------------------
/apis/tag/schemas.go:
--------------------------------------------------------------------------------
1 | package tag
2 |
3 | type CreateModel struct {
4 | Name string `json:"name,omitempty" validate:"max=20"` // Admin only
5 | }
6 |
7 | type ModifyModel struct {
8 | CreateModel
9 | Temperature int `json:"temperature,omitempty"` // Admin only
10 | }
11 |
12 | type DeleteModel struct {
13 | // Admin only
14 | // Name of the target tag that all the deleted tag's holes will be connected to
15 | To string `json:"to,omitempty"`
16 | }
17 |
18 | type SearchModel struct {
19 | Search string `json:"s" query:"s" validate:"max=32"` // search tag by name
20 | }
21 |
--------------------------------------------------------------------------------
/apis/user/apis.go:
--------------------------------------------------------------------------------
1 | package user
2 |
3 | import (
4 | "github.com/gofiber/fiber/v2"
5 | "github.com/opentreehole/go-common"
6 | "gorm.io/gorm/clause"
7 |
8 | . "treehole_next/models"
9 | )
10 |
11 | func RegisterRoutes(app fiber.Router) {
12 | app.Get("/users/me", GetCurrentUser)
13 | app.Get("/users/:id", GetUserByID)
14 | app.Put("/users/:id", ModifyUser)
15 | app.Patch("/users/:id/_webvpn", ModifyUser)
16 | app.Put("/users/me", ModifyCurrentUser)
17 | app.Patch("/users/me/_webvpn", ModifyCurrentUser)
18 | }
19 |
20 | // GetCurrentUser
21 | //
22 | // @Summary get current user
23 | // @Tags user
24 | // @Deprecated
25 | // @Produce json
26 | // @Router /users/me [get]
27 | // @Success 200 {object} User
28 | func GetCurrentUser(c *fiber.Ctx) error {
29 | user, err := GetCurrLoginUser(c)
30 | if err != nil {
31 | return err
32 | }
33 | return c.JSON(&user)
34 | }
35 |
36 | // GetUserByID
37 | //
38 | // @Summary get user by id, owner or admin
39 | // @Tags user
40 | // @Produce json
41 | // @Router /users/{user_id} [get]
42 | // @Success 200 {object} User
43 | func GetUserByID(c *fiber.Ctx) error {
44 | userID, err := c.ParamsInt("id")
45 | if err != nil {
46 | return err
47 | }
48 |
49 | user, err := GetCurrLoginUser(c)
50 | if err != nil {
51 | return err
52 | }
53 |
54 | if !user.IsAdmin || user.ID == userID {
55 | return common.Forbidden()
56 | }
57 |
58 | var getUser User
59 | err = getUser.LoadUserByID(userID)
60 | if err != nil {
61 | return err
62 | }
63 |
64 | return c.JSON(&getUser)
65 | }
66 |
67 | // ModifyUser
68 | //
69 | // @Summary modify user profiles
70 | // @Tags User
71 | // @Produce json
72 | // @Router /users/{user_id} [put]
73 | // @Router /users/{user_id}/_webvpn [patch]
74 | // @Param user_id path int true "user id"
75 | // @Param json body ModifyModel true "modify user"
76 | // @Success 200 {object} User
77 | func ModifyUser(c *fiber.Ctx) error {
78 | userID, err := c.ParamsInt("id")
79 | if err != nil {
80 | return err
81 | }
82 |
83 | user, err := GetCurrLoginUser(c)
84 | if err != nil {
85 | return err
86 | }
87 |
88 | if !user.IsAdmin && user.ID != userID {
89 | return common.Forbidden()
90 | }
91 |
92 | var body ModifyModel
93 | err = common.ValidateBody(c, &body)
94 | if err != nil {
95 | return err
96 | }
97 |
98 | // cannot get field "has_answered_questions" when admin changes other user's config
99 | if user.ID != userID {
100 | user = &User{
101 | ID: userID,
102 | }
103 | err = DB.Take(user).Error
104 | if err != nil {
105 | return err
106 | }
107 | }
108 |
109 | err = modifyUser(c, user, body)
110 | if err != nil {
111 | return err
112 | }
113 |
114 | return c.JSON(user)
115 | }
116 |
117 | // ModifyCurrentUser
118 | //
119 | // @Summary modify current user profiles
120 | // @Tags User
121 | // @Produce json
122 | // @Router /users/me [put]
123 | // @Router /users/me/_webvpn [patch]
124 | // @Param user_id path int true "user id"
125 | // @Param json body ModifyModel true "modify user"
126 | // @Success 200 {object} User
127 | func ModifyCurrentUser(c *fiber.Ctx) error {
128 | user, err := GetCurrLoginUser(c)
129 | if err != nil {
130 | return err
131 | }
132 |
133 | var body ModifyModel
134 | err = common.ValidateBody(c, &body)
135 | if err != nil {
136 | return err
137 | }
138 |
139 | err = modifyUser(c, user, body)
140 | if err != nil {
141 | return err
142 | }
143 |
144 | return c.JSON(&user)
145 | }
146 |
147 | func modifyUser(_ *fiber.Ctx, user *User, body ModifyModel) error {
148 | var newUser User
149 | err := DB.Select("config").First(&newUser, user.ID).Error
150 | if err != nil {
151 | return err
152 | }
153 |
154 | if body.Config != nil {
155 | if body.Config.Notify != nil {
156 | newUser.Config.Notify = body.Config.Notify
157 | }
158 | if body.Config.ShowFolded != nil {
159 | newUser.Config.ShowFolded = *body.Config.ShowFolded
160 | }
161 | }
162 |
163 | err = DB.Model(&user).Omit(clause.Associations).Select("Config").UpdateColumns(&newUser).Error
164 | if err != nil {
165 | return err
166 | }
167 |
168 | user.Config = newUser.Config
169 | return nil
170 | }
171 |
--------------------------------------------------------------------------------
/apis/user/schemas.go:
--------------------------------------------------------------------------------
1 | package user
2 |
3 | type ModifyModel struct {
4 | Nickname *string `json:"nickname" validate:"omitempty,min=1"`
5 | Config *UserConfigModel `json:"config"`
6 | }
7 |
8 | type UserConfigModel struct {
9 | Notify []string `json:"notify"`
10 | ShowFolded *string `json:"show_folded"`
11 | }
12 |
--------------------------------------------------------------------------------
/benchmarks/floor_test.go:
--------------------------------------------------------------------------------
1 | package benchmarks
2 |
3 | import (
4 | "math/rand"
5 | "strconv"
6 | "testing"
7 |
8 | . "treehole_next/models"
9 | )
10 |
11 | func BenchmarkListFloorsInAHole(b *testing.B) {
12 | for i := 0; i < b.N; i++ {
13 | b.StopTimer()
14 | route := "/api/holes/" + strconv.Itoa(rand.Intn(HOLE_MAX)+1) + "/floors/"
15 | b.StartTimer()
16 |
17 | benchmarkCommon(b, "get", route, REQUEST_BODY)
18 | }
19 | }
20 |
21 | func BenchmarkGetFloor(b *testing.B) {
22 | for i := 0; i < b.N; i++ {
23 | b.StopTimer()
24 | route := "/api/floors/" + strconv.Itoa(rand.Intn(FLOOR_MAX)+1) + "/"
25 | b.StartTimer()
26 |
27 | benchmarkCommon(b, "get", route, REQUEST_BODY)
28 | }
29 | }
30 |
31 | func BenchmarkCreateFloor(b *testing.B) {
32 | for i := 0; i < b.N; i++ {
33 | b.StopTimer()
34 | route := "/api/holes/" + strconv.Itoa(rand.Intn(HOLE_MAX)+1) + "/floors/"
35 | data := Map{
36 | "content": strconv.Itoa(rand.Int()),
37 | "reply_to": 0,
38 | }
39 | b.StartTimer()
40 |
41 | benchmarkCommon(b, "post", route, REQUEST_BODY, data)
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/benchmarks/hole_test.go:
--------------------------------------------------------------------------------
1 | package benchmarks
2 |
3 | import (
4 | "fmt"
5 | "math/rand"
6 | "strconv"
7 | "testing"
8 |
9 | . "treehole_next/models"
10 | _ "treehole_next/tests"
11 | )
12 |
13 | func BenchmarkListHoles(b *testing.B) {
14 | for i := 0; i < b.N; i++ {
15 | // prepare
16 | b.StopTimer()
17 | route := "/api/divisions/" + strconv.Itoa(rand.Intn(DIVISION_MAX)+1) + "/holes/"
18 | b.StartTimer()
19 |
20 | benchmarkCommon(b, "get", route, REQUEST_BODY)
21 | }
22 | }
23 |
24 | func BenchmarkCreateHoles(b *testing.B) {
25 | for i := 0; i < b.N; i++ {
26 | // prepare
27 | b.StopTimer()
28 | route := "/api/divisions/" + strconv.Itoa(rand.Intn(DIVISION_MAX)+1) + "/holes/"
29 | data := Map{
30 | "content": fmt.Sprintf("%v", rand.Uint64()),
31 | "tag": []Map{
32 | {"name": "123"},
33 | {"name": "456"},
34 | },
35 | }
36 | b.StartTimer()
37 |
38 | benchmarkCommon(b, "post", route, REQUEST_BODY, data)
39 | }
40 | }
41 |
42 | func BenchmarkGetHole(b *testing.B) {
43 | for i := 0; i < b.N; i++ {
44 | // prepare
45 | b.StopTimer()
46 | holeID := rand.Intn(HOLE_MAX) + 1
47 | url := "/api/holes/" + strconv.Itoa(holeID) + "/"
48 | b.StartTimer()
49 |
50 | benchmarkCommon(b, "get", url, REQUEST_BODY)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/benchmarks/init.go:
--------------------------------------------------------------------------------
1 | package benchmarks
2 |
3 | import (
4 | "fmt"
5 | "math/rand"
6 | "strings"
7 |
8 | "github.com/rs/zerolog/log"
9 | "gorm.io/gorm/logger"
10 |
11 | . "treehole_next/models"
12 | "treehole_next/utils"
13 | )
14 |
15 | const (
16 | DIVISION_MAX = 10
17 | TAG_MAX = 100
18 | HOLE_MAX = 100
19 | FLOOR_MAX = 1000
20 | )
21 |
22 | func init() {
23 | DB.Logger = logger.Default.LogMode(logger.Silent)
24 |
25 | divisions := make(Divisions, 0, DIVISION_MAX)
26 | tags := make(Tags, 0, TAG_MAX)
27 | holes := make(Holes, 0, HOLE_MAX)
28 | floors := make(Floors, 0, FLOOR_MAX)
29 |
30 | for i := 0; i < DIVISION_MAX; i++ {
31 | divisions = append(divisions, &Division{
32 | ID: i + 1,
33 | Name: strings.Repeat("d", i+1),
34 | Description: strings.Repeat("dd", i+1),
35 | })
36 | }
37 |
38 | for i := 0; i < TAG_MAX; i++ {
39 | content := fmt.Sprintf("%v", rand.Uint64())
40 | tags = append(tags, &Tag{
41 | ID: i + 1,
42 | Name: content,
43 | })
44 | }
45 |
46 | for i := 0; i < HOLE_MAX; i++ {
47 | generateTag := func() Tags {
48 | nowTags := make(Tags, rand.Intn(10))
49 | for i := range nowTags {
50 | nowTags[i] = tags[rand.Intn(TAG_MAX)]
51 | }
52 | return nowTags
53 | }
54 | holes = append(holes, &Hole{
55 | ID: i + 1,
56 | UserID: 1,
57 | DivisionID: rand.Intn(DIVISION_MAX) + 1,
58 | Tags: generateTag(),
59 | })
60 | }
61 |
62 | for i := 0; i < FLOOR_MAX; i++ {
63 | content := fmt.Sprintf("%v", rand.Uint64())
64 | generateMention := func() Floors {
65 | floorMentions := make(Floors, 0, rand.Intn(10))
66 | for j := range floorMentions {
67 | floorMentions[j] = &Floor{ID: rand.Intn(FLOOR_MAX) + 1}
68 | }
69 | return floorMentions
70 | }
71 | floors = append(floors, &Floor{
72 | ID: i + 1,
73 | Content: strings.Repeat(content, rand.Intn(2)),
74 | Anonyname: utils.GenerateName([]string{}),
75 | HoleID: rand.Intn(HOLE_MAX) + 1,
76 | Mention: generateMention(),
77 | })
78 | holes[floors[i].HoleID-1].Reply += 1
79 | }
80 |
81 | var err error
82 | err = DB.Create(divisions).Error
83 | if err != nil {
84 | log.Fatal().Err(err).Send()
85 | }
86 | err = DB.Create(tags).Error
87 | if err != nil {
88 | log.Fatal().Err(err).Send()
89 | }
90 | err = DB.Create(holes).Error
91 | if err != nil {
92 | log.Fatal().Err(err).Send()
93 | }
94 | err = DB.Create(floors).Error
95 | if err != nil {
96 | log.Fatal().Err(err).Send()
97 | }
98 | }
99 |
--------------------------------------------------------------------------------
/benchmarks/utils.go:
--------------------------------------------------------------------------------
1 | package benchmarks
2 |
3 | import (
4 | "bytes"
5 | "io"
6 | "net/http"
7 | "strings"
8 | "testing"
9 |
10 | "github.com/goccy/go-json"
11 | "github.com/hetiansu5/urlquery"
12 | "github.com/stretchr/testify/assert"
13 |
14 | "treehole_next/bootstrap"
15 | . "treehole_next/models"
16 | )
17 |
18 | var App, _ = bootstrap.Init()
19 |
20 | var _ Map
21 |
22 | const (
23 | REQUEST_BODY = iota
24 | REQUEST_QUERY
25 | )
26 |
27 | func benchmarkCommon(b *testing.B, method string, route string, requestType int, data ...Map) []byte {
28 | var requestData []byte
29 | var err error
30 | var req *http.Request
31 |
32 | b.StopTimer()
33 | switch requestType {
34 | case REQUEST_BODY:
35 | if len(data) > 0 && data[0] != nil { // data[0] is request data
36 | requestData, err = json.Marshal(data[0])
37 | assert.Nilf(b, err, "encode request body")
38 | }
39 | req, err = http.NewRequest(
40 | strings.ToUpper(method),
41 | route,
42 | bytes.NewBuffer(requestData),
43 | )
44 | case REQUEST_QUERY:
45 | req, err = http.NewRequest(
46 | strings.ToUpper(method),
47 | route,
48 | nil,
49 | )
50 | if len(data) > 0 && data[0] != nil { // data[0] is query data
51 | queryData, err := urlquery.Marshal(data[0])
52 | req.URL.RawQuery = string(queryData)
53 | assert.Nilf(b, err, "encode request body")
54 | }
55 | }
56 |
57 | req.Header.Add("Content-Type", "application/json")
58 | assert.Nilf(b, err, "constructs http request")
59 |
60 | b.StartTimer()
61 | res, err := App.Test(req, -1)
62 | b.StopTimer()
63 | assert.Nilf(b, err, "perform request")
64 |
65 | responseBody, err := io.ReadAll(res.Body)
66 | assert.Nilf(b, err, "decode response")
67 |
68 | if res.StatusCode != 200 && res.StatusCode != 201 {
69 | assert.Fail(b, string(responseBody))
70 | }
71 | return responseBody
72 | }
73 |
--------------------------------------------------------------------------------
/bootstrap/init.go:
--------------------------------------------------------------------------------
1 | package bootstrap
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/gofiber/fiber/v2"
7 | "github.com/opentreehole/go-common"
8 |
9 | "treehole_next/apis"
10 | "treehole_next/apis/hole"
11 | "treehole_next/apis/message"
12 | "treehole_next/config"
13 | "treehole_next/models"
14 | "treehole_next/utils"
15 | "treehole_next/utils/sensitive"
16 |
17 | "github.com/goccy/go-json"
18 | "github.com/gofiber/fiber/v2/middleware/pprof"
19 | "github.com/gofiber/fiber/v2/middleware/recover"
20 | )
21 |
22 | func Init() (*fiber.App, context.CancelFunc) {
23 | config.InitConfig()
24 | utils.InitCache()
25 | sensitive.InitSensitiveLabelMap()
26 | models.Init()
27 | models.InitDB()
28 | models.InitAdminList()
29 |
30 | app := fiber.New(fiber.Config{
31 | ErrorHandler: common.ErrorHandler,
32 | JSONEncoder: json.Marshal,
33 | JSONDecoder: json.Unmarshal,
34 | DisableStartupMessage: true,
35 | })
36 | registerMiddlewares(app)
37 | apis.RegisterRoutes(app)
38 |
39 | return app, startTasks()
40 | }
41 |
42 | func registerMiddlewares(app *fiber.App) {
43 | app.Use(recover.New(recover.Config{EnableStackTrace: true}))
44 | app.Use(common.MiddlewareGetUserID)
45 | if config.Config.Mode != "bench" {
46 | app.Use(common.MiddlewareCustomLogger)
47 | }
48 | app.Use(pprof.New())
49 | }
50 |
51 | func startTasks() context.CancelFunc {
52 | ctx, cancel := context.WithCancel(context.Background())
53 | go hole.UpdateHoleViews(ctx)
54 | go hole.PurgeHole(ctx)
55 | go message.PurgeMessage()
56 | // go models.UpdateAdminList(ctx)
57 | go sensitive.UpdateSensitiveLabelMap(ctx)
58 | return cancel
59 | }
60 |
--------------------------------------------------------------------------------
/config/config.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "github.com/caarlos0/env/v9"
5 | "net/url"
6 | "sync/atomic"
7 |
8 | "github.com/rs/zerolog/log"
9 | )
10 |
11 | var Config struct {
12 | Mode string `env:"MODE" envDefault:"dev"`
13 | TZ string `env:"TZ" envDefault:"Asia/Shanghai"`
14 | Size int `env:"SIZE" envDefault:"30"`
15 | MaxSize int `env:"MAX_SIZE" envDefault:"50"`
16 | TagSize int `env:"TAG_SIZE" envDefault:"5"`
17 | HoleFloorSize int `env:"HOLE_FLOOR_SIZE" envDefault:"10"`
18 | Debug bool `env:"DEBUG" envDefault:"false"`
19 | // example: user:pass@tcp(127.0.0.1:3306)/dbname?parseTime=true&loc=Asia%2fShanghai
20 | // set time_zone in url, otherwise UTC
21 | // for more detail, see https://github.com/go-sql-driver/mysql#dsn-data-source-name
22 | DbURL string `env:"DB_URL"`
23 | // example: MYSQL_REPLICA_URL="db1_dsn,db2_dsn", use ',' as separator
24 | // should also set time_zone in url
25 | MysqlReplicaURLs []string `env:"MYSQL_REPLICA_URL"`
26 | RedisURL string `env:"REDIS_URL"` // redis:6379
27 | NotificationUrl string `env:"NOTIFICATION_URL"`
28 | MessagePurgeDays int `envDefault:"7" env:"MESSAGE_PURGE_DAYS"`
29 | AuthUrl string `env:"AUTH_URL"`
30 | ElasticsearchUrl string `env:"ELASTICSEARCH_URL"`
31 | OpenSearch bool `env:"OPEN_SEARCH" envDefault:"true"`
32 | OpenFuzzName bool `env:"OPEN_FUZZ_NAME" envDefault:"false"`
33 | UserAllShowHidden bool `env:"USER_ALL_HIDDEN" envDefault:"false"`
34 | AdminOnly bool `env:"ADMIN_ONLY" envDefault:"false"`
35 | HolePurgeDivisions []int `env:"HOLE_PURGE_DIVISIONS" envDefault:"2"`
36 | HolePurgeDays int `env:"HOLE_PURGE_DAYS" envDefault:"30"`
37 | OpenSensitiveCheck bool `env:"OPEN_SENSITIVE_CHECK" envDefault:"true"`
38 |
39 | YiDunBusinessIdText string `env:"YI_DUN_BUSINESS_ID_TEXT" envDefault:""`
40 | YiDunBusinessIdImage string `env:"YI_DUN_BUSINESS_ID_IMAGE" envDefault:""`
41 | YiDunSecretId string `env:"YI_DUN_SECRET_ID" envDefault:""`
42 | YiDunSecretKey string `env:"YI_DUN_SECRET_KEY" envDefault:""`
43 | YiDunAccessKeyId string `env:"YI_DUN_ACCESS_KEY_ID" envDefault:""`
44 | YiDunAccessKeySecret string `env:"YI_DUN_ACCESS_KEY_SECRET" envDefault:""`
45 | ValidImageUrl []string `env:"VALID_IMAGE_URL"`
46 | UrlHostnameWhitelist []string `env:"URL_HOSTNAME_WHITELIST"`
47 | ExternalImageHost string `env:"EXTERNAL_IMAGE_HOSTNAME" envDefault:""`
48 | NotifiableAdminIds []int `env:"NOTIFIABLE_ADMIN_IDS"`
49 | ExcludeBanForeverDivisionIds []int `env:"EXCLUDE_BAN_FOREVER_DIVISION_IDS"`
50 | ProxyUrl *url.URL `env:"PROXY_URL"`
51 | QQBotPhysicsGroupID *int64 `env:"PHYSICS_GROUP_ID"`
52 | QQBotCodingGroupID *int64 `env:"CODING_GROUP_ID"`
53 | QQBotUserID *int64 `env:"USER_ID"`
54 | QQBotUrl *string `env:"QQ_BOT_URL"`
55 | FeishuBotUrl *string `env:"FEISHU_BOT_URL"`
56 | AdminOnlyTagIds []int `env:"ADMIN_ONLY_TAG_IDS"`
57 | }
58 |
59 | var DynamicConfig struct {
60 | OpenSearch atomic.Bool
61 | }
62 |
63 | func InitConfig() { // load config from environment variables
64 | if err := env.Parse(&Config); err != nil {
65 | log.Fatal().Err(err).Send()
66 | }
67 | log.Info().Any("config", Config).Msg("init config")
68 | DynamicConfig.OpenSearch.Store(Config.OpenSearch)
69 | }
70 |
--------------------------------------------------------------------------------
/data/data.go:
--------------------------------------------------------------------------------
1 | package data
2 |
3 | import (
4 | _ "embed"
5 | "os"
6 |
7 | "github.com/goccy/go-json"
8 | "github.com/rs/zerolog/log"
9 | )
10 |
11 | //go:embed names.json
12 | var NamesFile []byte
13 |
14 | //go:embed meta.json
15 | var MetaFile []byte
16 |
17 | var NamesMapping map[string]string
18 |
19 | func init() {
20 | err := initNamesMapping()
21 | if err != nil {
22 | log.Err(err).Msg("could not init names mapping")
23 | }
24 | }
25 |
26 | func initNamesMapping() error {
27 | NamesMappingData, err := os.ReadFile(`data/names_mapping.json`)
28 | if err != nil {
29 | return err
30 | }
31 |
32 | return json.Unmarshal(NamesMappingData, &NamesMapping)
33 | }
34 |
--------------------------------------------------------------------------------
/data/meta.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Open Tree Hole",
3 | "description": "Next Generation of OpenTreeHole Implemented In Go ---- An Anonymous BBS",
4 | "version": "2.0.0",
5 | "homepage": "https://github.com/opentreehole",
6 | "repository": "https://github.com/OpenTreeHole/treehole_next",
7 | "author": "hasbai",
8 | "maintainer": "jingyijun",
9 | "email": "dev@danta.tech",
10 | "license": "Apache-2.0"
11 | }
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module treehole_next
2 |
3 | go 1.22
4 |
5 | require (
6 | github.com/caarlos0/env/v9 v9.0.0
7 | github.com/eko/gocache/lib/v4 v4.1.6
8 | github.com/eko/gocache/store/go_cache/v4 v4.2.2
9 | github.com/eko/gocache/store/redis/v4 v4.2.2
10 | github.com/elastic/go-elasticsearch/v8 v8.14.0
11 | github.com/goccy/go-json v0.10.3
12 | github.com/gofiber/fiber/v2 v2.52.5
13 | github.com/hetiansu5/urlquery v1.2.7
14 | github.com/opentreehole/go-common v0.1.7
15 | github.com/patrickmn/go-cache v2.1.0+incompatible
16 | github.com/redis/go-redis/v9 v9.6.1
17 | github.com/rs/zerolog v1.33.0
18 | github.com/stretchr/testify v1.9.0
19 | github.com/swaggo/fiber-swagger v1.3.0
20 | github.com/swaggo/swag v1.16.3
21 | github.com/yidun/yidun-golang-sdk v1.0.14
22 | golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56
23 | gorm.io/driver/mysql v1.5.7
24 | gorm.io/driver/sqlite v1.5.6
25 | gorm.io/gorm v1.25.11
26 | gorm.io/plugin/dbresolver v1.5.2
27 | mvdan.cc/xurls/v2 v2.5.0
28 | )
29 |
30 | require (
31 | filippo.io/edwards25519 v1.1.0 // indirect
32 | github.com/KyleBanks/depth v1.2.1 // indirect
33 | github.com/andybalholm/brotli v1.1.0 // indirect
34 | github.com/beorn7/perks v1.0.1 // indirect
35 | github.com/cespare/xxhash/v2 v2.3.0 // indirect
36 | github.com/creasty/defaults v1.7.0 // indirect
37 | github.com/davecgh/go-spew v1.1.1 // indirect
38 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
39 | github.com/elastic/elastic-transport-go/v8 v8.6.0 // indirect
40 | github.com/gabriel-vasile/mimetype v1.4.5 // indirect
41 | github.com/go-logr/logr v1.4.2 // indirect
42 | github.com/go-logr/stdr v1.2.2 // indirect
43 | github.com/go-openapi/jsonpointer v0.21.0 // indirect
44 | github.com/go-openapi/jsonreference v0.21.0 // indirect
45 | github.com/go-openapi/spec v0.21.0 // indirect
46 | github.com/go-openapi/swag v0.23.0 // indirect
47 | github.com/go-playground/locales v0.14.1 // indirect
48 | github.com/go-playground/universal-translator v0.18.1 // indirect
49 | github.com/go-playground/validator/v10 v10.22.0 // indirect
50 | github.com/go-sql-driver/mysql v1.8.1 // indirect
51 | github.com/golang/mock v1.6.0 // indirect
52 | github.com/google/uuid v1.6.0 // indirect
53 | github.com/jinzhu/inflection v1.0.0 // indirect
54 | github.com/jinzhu/now v1.1.5 // indirect
55 | github.com/josharian/intern v1.0.0 // indirect
56 | github.com/klauspost/compress v1.17.9 // indirect
57 | github.com/leodido/go-urn v1.4.0 // indirect
58 | github.com/mailru/easyjson v0.7.7 // indirect
59 | github.com/mattn/go-colorable v0.1.13 // indirect
60 | github.com/mattn/go-isatty v0.0.20 // indirect
61 | github.com/mattn/go-runewidth v0.0.16 // indirect
62 | github.com/mattn/go-sqlite3 v1.14.22 // indirect
63 | github.com/pmezard/go-difflib v1.0.0 // indirect
64 | github.com/prometheus/client_golang v1.19.0 // indirect
65 | github.com/prometheus/client_model v0.6.0 // indirect
66 | github.com/prometheus/common v0.51.1 // indirect
67 | github.com/prometheus/procfs v0.13.0 // indirect
68 | github.com/rivo/uniseg v0.4.7 // indirect
69 | github.com/swaggo/files v1.0.1 // indirect
70 | github.com/tjfoc/gmsm v1.4.1 // indirect
71 | github.com/valyala/bytebufferpool v1.0.0 // indirect
72 | github.com/valyala/fasthttp v1.55.0 // indirect
73 | github.com/valyala/tcplisten v1.0.0 // indirect
74 | go.opentelemetry.io/otel v1.28.0 // indirect
75 | go.opentelemetry.io/otel/metric v1.28.0 // indirect
76 | go.opentelemetry.io/otel/trace v1.28.0 // indirect
77 | golang.org/x/crypto v0.26.0 // indirect
78 | golang.org/x/net v0.28.0 // indirect
79 | golang.org/x/sync v0.8.0 // indirect
80 | golang.org/x/sys v0.23.0 // indirect
81 | golang.org/x/text v0.17.0 // indirect
82 | golang.org/x/tools v0.24.0 // indirect
83 | google.golang.org/protobuf v1.33.0 // indirect
84 | gopkg.in/yaml.v3 v3.0.1 // indirect
85 | )
86 |
87 | replace github.com/yidun/yidun-golang-sdk => github.com/jingyijun/yidun-golang-sdk v1.0.13-0.20240709102803-aaae270a5671
88 |
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "os"
5 | "os/signal"
6 | "syscall"
7 |
8 | "github.com/rs/zerolog/log"
9 |
10 | "treehole_next/bootstrap"
11 | )
12 |
13 | // @title Open Tree Hole
14 | // @version 2.1.0
15 | // @description An Anonymous BBS \n Note: PUT methods are used to PARTLY update, and we don't use PATCH method.
16 |
17 | // @contact.name Maintainer Ke Chen
18 | // @contact.email dev@danta.tech
19 |
20 | // @license.name Apache 2.0
21 | // @license.url https://www.apache.org/licenses/LICENSE-2.0.html
22 |
23 | // @host
24 | // @BasePath /api
25 |
26 | func main() {
27 | app, cancel := bootstrap.Init()
28 | go func() {
29 | err := app.Listen("0.0.0.0:8000")
30 | if err != nil {
31 | log.Fatal().Err(err).Msg("app listen failed")
32 | }
33 | }()
34 |
35 | interrupt := make(chan os.Signal, 1)
36 |
37 | // wait for CTRL-C interrupt
38 | signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
39 | <-interrupt
40 |
41 | // close app
42 | err := app.Shutdown()
43 | if err != nil {
44 | log.Err(err).Msg("error shutdown app")
45 | }
46 | // stop tasks
47 | cancel()
48 | }
49 |
--------------------------------------------------------------------------------
/models/admin_log.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "github.com/rs/zerolog/log"
5 | "gorm.io/gorm"
6 | "time"
7 | )
8 |
9 | type AdminLog struct {
10 | ID int `gorm:"primaryKey"`
11 | CreatedAt time.Time
12 | Type AdminLogType `gorm:"size:16;not null"`
13 | UserID int `gorm:"not null"`
14 | Data any `gorm:"serializer:json"`
15 | }
16 |
17 | type AdminLogType string
18 |
19 | const (
20 | AdminLogTypeHole AdminLogType = "edit_hole"
21 | AdminLogTypeHideHole AdminLogType = "hide_hole"
22 | AdminLogTypeTag AdminLogType = "edit_tag"
23 | AdminLogTypeDivision AdminLogType = "edit_division"
24 | AdminLogTypeMessage AdminLogType = "send_message"
25 | AdminLogTypeDeleteReport AdminLogType = "delete_report"
26 | AdminLogTypeChangeSensitive AdminLogType = "change_sensitive"
27 | )
28 |
29 | // CreateAdminLog
30 | // save admin edit log for audit purpose
31 | func CreateAdminLog(tx *gorm.DB, logType AdminLogType, userID int, data any) {
32 | adminLog := AdminLog{
33 | Type: logType,
34 | UserID: userID,
35 | Data: data,
36 | }
37 | err := tx.Create(&adminLog).Error // omit error
38 | if err != nil {
39 | log.Error().Err(err).Msg("failed to create admin log")
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/models/anonyname.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "errors"
5 |
6 | "gorm.io/gorm"
7 | "gorm.io/gorm/clause"
8 |
9 | "treehole_next/utils"
10 | )
11 |
12 | type AnonynameMapping struct {
13 | HoleID int `json:"hole_id" gorm:"primaryKey"`
14 | UserID int `json:"user_id" gorm:"primaryKey"`
15 | Anonyname string `json:"anonyname" gorm:"size:32"`
16 | }
17 |
18 | func NewAnonyname(tx *gorm.DB, holeID, userID int) (string, error) {
19 | name := utils.NewRandName()
20 | return name, tx.Create(&AnonynameMapping{
21 | HoleID: holeID,
22 | UserID: userID,
23 | Anonyname: name,
24 | }).Error
25 | }
26 |
27 | func FindOrGenerateAnonyname(tx *gorm.DB, holeID, userID int) (string, error) {
28 | var anonyname string
29 | err := tx.
30 | Model(&AnonynameMapping{}).
31 | Select("anonyname").
32 | Where("hole_id = ?", holeID).
33 | Where("user_id = ?", userID).
34 | Take(&anonyname).Error
35 |
36 | if err != nil {
37 | if errors.Is(err, gorm.ErrRecordNotFound) {
38 | var names []string
39 | err = tx.
40 | Clauses(clause.Locking{Strength: "UPDATE"}).
41 | Model(&AnonynameMapping{}).
42 | Select("anonyname").
43 | Where("hole_id = ?", holeID).
44 | Order("anonyname").
45 | Scan(&names).Error
46 | if err != nil {
47 | return "", err
48 | }
49 |
50 | anonyname = utils.GenerateName(names)
51 | err = tx.Create(&AnonynameMapping{
52 | HoleID: holeID,
53 | UserID: userID,
54 | Anonyname: anonyname,
55 | }).Error
56 | if err != nil {
57 | return anonyname, err
58 | }
59 | } else {
60 | return "", err
61 | }
62 | }
63 | return anonyname, nil
64 | }
65 |
--------------------------------------------------------------------------------
/models/base.go:
--------------------------------------------------------------------------------
1 | // Package models contains database models
2 | package models
3 |
4 | type Map = map[string]interface{}
5 |
6 | type Models interface {
7 | Division | Hole | Floor | Tag | User | Report | Message |
8 | Divisions | Holes | Floors | Tags | Users | Reports | Messages
9 | }
10 |
11 | type MessageModel struct {
12 | Message string `json:"message"`
13 | }
14 |
--------------------------------------------------------------------------------
/models/division.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "time"
5 |
6 | "treehole_next/utils"
7 |
8 | "github.com/gofiber/fiber/v2"
9 | "gorm.io/gorm"
10 | )
11 |
12 | type Division struct {
13 | /// saved fields
14 | ID int `json:"id" gorm:"primaryKey"`
15 | CreatedAt time.Time `json:"time_created" gorm:"not null"`
16 | UpdatedAt time.Time `json:"time_updated" gorm:"not null"`
17 |
18 | /// base info
19 | Name string `json:"name" gorm:"unique;size:10"`
20 | Description string `json:"description" gorm:"size:64"`
21 | Hidden bool `json:"hidden" gorm:"not null;default:false"`
22 |
23 | // pinned holes in given order
24 | Pinned []int `json:"-" gorm:"serializer:json;size:100;not null;default:\"[]\""`
25 |
26 | /// association fields, should add foreign key
27 |
28 | // return pinned hole to frontend
29 | Holes Holes `json:"pinned"`
30 |
31 | /// generated field
32 | DivisionID int `json:"division_id" gorm:"-:all"`
33 | }
34 |
35 | func (division *Division) GetID() int {
36 | return division.ID
37 | }
38 |
39 | type Divisions []*Division
40 |
41 | func (divisions Divisions) Preprocess(c *fiber.Ctx) error {
42 | for _, division := range divisions {
43 | err := division.Preprocess(c)
44 | if err != nil {
45 | return err
46 | }
47 | }
48 | return utils.SetCache("divisions", divisions, 0)
49 | }
50 |
51 | func (division *Division) Preprocess(c *fiber.Ctx) error {
52 | var pinned = division.Pinned
53 | division.Holes = make(Holes, 0, 10)
54 | if len(pinned) == 0 {
55 | return nil
56 | }
57 | DB.Find(&division.Holes, pinned)
58 | if len(division.Holes) == 0 {
59 | return nil
60 | }
61 | division.Holes = utils.OrderInGivenOrder(division.Holes, pinned)
62 | // division.Holes = division.Holes.RemoveIf(func(hole *Hole) bool {
63 | // return hole.Hidden
64 | // })
65 | return division.Holes.Preprocess(c)
66 | }
67 |
68 | func (division *Division) AfterFind(_ *gorm.DB) (err error) {
69 | division.DivisionID = division.ID
70 | return nil
71 | }
72 |
73 | func (division *Division) AfterCreate(_ *gorm.DB) (err error) {
74 | division.DivisionID = division.ID
75 | return nil
76 | }
77 |
--------------------------------------------------------------------------------
/models/favorite_group.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "errors"
5 | "github.com/opentreehole/go-common"
6 | "gorm.io/gorm"
7 | "gorm.io/plugin/dbresolver"
8 | "time"
9 | )
10 |
11 | type FavoriteGroup struct {
12 | FavoriteGroupID int `json:"favorite_group_id" gorm:"primaryKey"`
13 | UserID int `json:"user_id" gorm:"primaryKey"`
14 | Name string `json:"name" gorm:"not null;size:64" default:"默认"`
15 | CreatedAt time.Time `json:"time_created"`
16 | UpdatedAt time.Time `json:"time_updated"`
17 | Deleted bool `json:"deleted" gorm:"default:false"`
18 | Count int `json:"count" gorm:"default:0"`
19 | }
20 |
21 | const MaxGroupPerUser = 10
22 |
23 | type FavoriteGroups []FavoriteGroup
24 |
25 | func (FavoriteGroup) TableName() string {
26 | return "favorite_groups"
27 | }
28 |
29 | // make sure use this function in a transaction
30 | func UserGetFavoriteGroups(tx *gorm.DB, userID int, order *string) (favoriteGroups FavoriteGroups, err error) {
31 | err = CheckDefaultFavoriteGroup(tx, userID)
32 | if err != nil {
33 | return
34 | }
35 |
36 | if order == nil {
37 | err = tx.Where("user_id = ? and deleted = false", userID).Find(&favoriteGroups).Error
38 | } else {
39 | err = tx.Where("user_id = ? and deleted = false", userID).Order(*order).Find(&favoriteGroups).Error
40 | }
41 | return
42 | }
43 |
44 | func DeleteUserFavoriteGroup(tx *gorm.DB, userID int, groupID int) (err error) {
45 | if groupID == 0 {
46 | return common.Forbidden("默认收藏夹不可删除")
47 | }
48 | err = tx.Model(&UserFavorite{}).Where("user_id = ? AND favorite_group_id = ?", userID, groupID).Take(&UserFavorite{}).Error
49 | if err != nil {
50 | if !errors.Is(err, gorm.ErrRecordNotFound) {
51 | return err
52 | }
53 | } else {
54 | return common.Forbidden("收藏夹中存在收藏内容,请先移除")
55 | }
56 |
57 | result := tx.Clauses(dbresolver.Write).Where("user_id = ? AND favorite_group_id = ?", userID, groupID).Updates(FavoriteGroup{Deleted: true})
58 | if result.Error != nil {
59 | return err
60 | }
61 | if result.RowsAffected == 0 {
62 | return common.NotFound("收藏夹不存在")
63 | }
64 | err = tx.Model(&UserFavorite{}).Where("user_id = ? AND favorite_group_id = ?", userID, groupID).Delete(&UserFavorite{}).Error
65 | if err != nil {
66 | return err
67 | }
68 | return tx.Model(&User{}).Where("id = ?", userID).Update("favorite_group_count", gorm.Expr("favorite_group_count - 1")).Error
69 | }
70 |
71 | func CheckDefaultFavoriteGroup(tx *gorm.DB, userID int) (err error) {
72 | return tx.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error {
73 | err = tx.Model(&FavoriteGroup{}).Where("user_id = ? AND favorite_group_id = 0", userID).Take(&FavoriteGroup{}).Error
74 | if err != nil {
75 | if !errors.Is(err, gorm.ErrRecordNotFound) {
76 | return err
77 | }
78 |
79 | // insert default favorite group if not exists
80 | err = tx.Create(&FavoriteGroup{
81 | UserID: userID,
82 | Name: "默认收藏夹",
83 | FavoriteGroupID: 0,
84 | CreatedAt: time.Now(),
85 | }).Error
86 | if err != nil {
87 | return err
88 | }
89 | return tx.Model(&User{}).Where("id = ?", userID).Update("favorite_group_count", gorm.Expr("favorite_group_count + 1")).Error
90 | }
91 |
92 | // default favorite group exists
93 | return nil
94 | })
95 |
96 | }
97 |
98 | func AddUserFavoriteGroup(tx *gorm.DB, userID int, name string) (err error) {
99 | return tx.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error {
100 | var groupID int
101 | err = tx.Model(&FavoriteGroup{}).Select("IFNULL(MAX(favorite_group_id), 0) AS max_id").Where("user_id = ? and deleted = false", userID).
102 | Take(&groupID).Error
103 | groupID++
104 | if err != nil {
105 | return err
106 | }
107 | if groupID >= MaxGroupPerUser {
108 | err = tx.Model(&FavoriteGroup{}).Where("user_id = ? and deleted = true", userID).Order("favorite_group_id").Limit(1).Take(&groupID).Error
109 | }
110 | if errors.Is(err, gorm.ErrRecordNotFound) {
111 | return common.Forbidden("收藏夹数量已达上限")
112 | }
113 | if err != nil {
114 | return err
115 | }
116 |
117 | err = tx.Create(&FavoriteGroup{
118 | UserID: userID,
119 | Name: name,
120 | FavoriteGroupID: groupID,
121 | CreatedAt: time.Now(),
122 | }).Error
123 | if err != nil {
124 | return err
125 | }
126 | return tx.Model(&User{}).Where("id = ?", userID).Update("favorite_group_count", gorm.Expr("favorite_group_count + 1")).Error
127 | })
128 | }
129 |
130 | func ModifyUserFavoriteGroup(tx *gorm.DB, userID int, groupID int, name string) (err error) {
131 | return tx.Clauses(dbresolver.Write).Where("user_id = ? AND favorite_group_id = ?", userID, groupID).
132 | Updates(FavoriteGroup{Name: name, UpdatedAt: time.Now()}).Error
133 | }
134 |
--------------------------------------------------------------------------------
/models/floor_history.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import "time"
4 |
5 | type FloorHistory struct {
6 | /// base info
7 | ID int `json:"id" gorm:"primaryKey"`
8 | CreatedAt time.Time `json:"time_created"`
9 | UpdatedAt time.Time `json:"time_updated"`
10 | Content string `json:"content" gorm:"size:15000"`
11 | Reason string `json:"reason"`
12 | FloorID int `json:"floor_id"`
13 | // auto sensitive check
14 | IsSensitive bool `json:"is_sensitive"`
15 |
16 | // manual sensitive check
17 | IsActualSensitive *bool `json:"is_actual_sensitive"`
18 |
19 | SensitiveDetail string `json:"sensitive_detail,omitempty"`
20 | // The one who modified the floor
21 | UserID int `json:"user_id"`
22 | }
23 |
24 | type FloorHistorySlice []*FloorHistory
25 |
--------------------------------------------------------------------------------
/models/floor_like.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | type FloorLike struct {
4 | FloorID int `json:"floor_id" gorm:"primaryKey"`
5 | UserID int `json:"user_id" gorm:"primaryKey"`
6 | LikeData int8 `json:"like_data"`
7 | }
8 |
--------------------------------------------------------------------------------
/models/floor_mention.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "regexp"
5 |
6 | "gorm.io/gorm"
7 |
8 | "treehole_next/utils"
9 | )
10 |
11 | type FloorMention struct {
12 | FloorID int `json:"floor_id" gorm:"primaryKey"`
13 | MentionID int `json:"mention_id" gorm:"primaryKey"`
14 | }
15 |
16 | func (FloorMention) TableName() string {
17 | return "floor_mention"
18 | }
19 |
20 | var reHole = regexp.MustCompile(`[^#]#(\d+)`)
21 | var reFloor = regexp.MustCompile(`##(\d+)`)
22 |
23 | func parseMentionIDs(content string) (holeIDs []int, floorIDs []int, err error) {
24 | // todo: parse replyTo
25 |
26 | // find mentioned holeIDs
27 | holeIDsText := reHole.FindAllStringSubmatch(" "+content, -1)
28 | holeIDs, err = utils.RegText2IntArray(holeIDsText)
29 | if err != nil {
30 | return nil, nil, err
31 | }
32 |
33 | // find mentioned floorIDs
34 | floorIDsText := reFloor.FindAllStringSubmatch(" "+content, -1)
35 | floorIDs, err = utils.RegText2IntArray(floorIDsText)
36 | return holeIDs, floorIDs, err
37 | }
38 |
39 | func LoadFloorMentions(tx *gorm.DB, content string) (Floors, error) {
40 | holeIDs, floorIDs, err := parseMentionIDs(content)
41 | if err != nil {
42 | return nil, err
43 | }
44 |
45 | queryGetHoleFloors := tx.Model(&Floor{}).Where("hole_id in ? and ranking = 0", holeIDs)
46 | queryGetFloors := tx.Model(&Floor{}).Where("id in ?", floorIDs)
47 | mentionFloors := Floors{}
48 | if len(holeIDs) > 0 && len(floorIDs) > 0 {
49 | err = tx.Raw(`? UNION ?`, queryGetHoleFloors, queryGetFloors).Scan(&mentionFloors).Error
50 | } else if len(holeIDs) > 0 {
51 | err = queryGetHoleFloors.Scan(&mentionFloors).Error
52 | } else if len(floorIDs) > 0 {
53 | err = queryGetFloors.Scan(&mentionFloors).Error
54 | }
55 | return mentionFloors, err
56 | }
57 |
--------------------------------------------------------------------------------
/models/hole_tags.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | type HoleTag struct {
4 | HoleID int `json:"hole_id" gorm:"index"`
5 | TagID int `json:"tag_id" gorm:"index"`
6 | }
7 |
8 | func (HoleTag) TableName() string {
9 | return "hole_tags"
10 | }
11 |
12 | type HoleTags []*HoleTag
13 |
--------------------------------------------------------------------------------
/models/init.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "os"
5 | "time"
6 |
7 | "github.com/rs/zerolog/log"
8 |
9 | "treehole_next/config"
10 |
11 | "gorm.io/gorm/logger"
12 | "gorm.io/plugin/dbresolver"
13 |
14 | "gorm.io/driver/mysql"
15 | "gorm.io/driver/sqlite"
16 | "gorm.io/gorm"
17 | "gorm.io/gorm/schema"
18 | )
19 |
20 | var DB *gorm.DB
21 |
22 | var gormConfig = &gorm.Config{
23 | NamingStrategy: schema.NamingStrategy{
24 | SingularTable: true, // use singular table name, table for `User` would be `user` with this option enabled
25 | },
26 | Logger: logger.New(
27 | &log.Logger,
28 | logger.Config{
29 | SlowThreshold: time.Second, // 慢 SQL 阈值
30 | LogLevel: logger.Error, // 日志级别
31 | IgnoreRecordNotFoundError: true, // 忽略ErrRecordNotFound(记录未找到)错误
32 | Colorful: false, // 禁用彩色打印
33 | },
34 | ),
35 | }
36 |
37 | // Read/Write Splitting
38 | func mysqlDB() *gorm.DB {
39 | // set source databases
40 | source := mysql.Open(config.Config.DbURL)
41 | db, err := gorm.Open(source, gormConfig)
42 | if err != nil {
43 | log.Fatal().Err(err).Send()
44 | }
45 |
46 | // set replica databases
47 | var replicas []gorm.Dialector
48 | for _, url := range config.Config.MysqlReplicaURLs {
49 | replicas = append(replicas, mysql.Open(url))
50 | }
51 | err = db.Use(dbresolver.Register(dbresolver.Config{
52 | Sources: []gorm.Dialector{source},
53 | Replicas: replicas,
54 | Policy: dbresolver.RandomPolicy{},
55 | }))
56 | if err != nil {
57 | log.Fatal().Err(err).Send()
58 | }
59 | return db
60 | }
61 |
62 | func sqliteDB() *gorm.DB {
63 | err := os.MkdirAll("data", 0750)
64 | if err != nil {
65 | log.Fatal().Err(err).Send()
66 | }
67 | db, err := gorm.Open(sqlite.Open("data/sqlite.db"), gormConfig)
68 | if err != nil {
69 | log.Fatal().Err(err).Send()
70 | }
71 | // https://github.com/go-gorm/gorm/issues/3709
72 | phyDB, err := db.DB()
73 | if err != nil {
74 | log.Fatal().Err(err).Send()
75 | }
76 | phyDB.SetMaxOpenConns(1)
77 | return db
78 | }
79 |
80 | func memoryDB() *gorm.DB {
81 | db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), gormConfig)
82 | if err != nil {
83 | log.Fatal().Err(err).Send()
84 | }
85 | // https://github.com/go-gorm/gorm/issues/3709
86 | phyDB, err := db.DB()
87 | if err != nil {
88 | log.Fatal().Err(err).Send()
89 | }
90 | phyDB.SetMaxOpenConns(1)
91 | return db
92 | }
93 |
94 | func InitDB() {
95 | var err error
96 | switch config.Config.Mode {
97 | case "production":
98 | DB = mysqlDB()
99 | case "test":
100 | fallthrough
101 | case "bench":
102 | DB = memoryDB()
103 | case "dev":
104 | if config.Config.DbURL == "" {
105 | DB = sqliteDB()
106 | } else {
107 | DB = mysqlDB()
108 | }
109 | default:
110 | log.Fatal().Msg("unknown mode")
111 | }
112 |
113 | switch config.Config.Mode {
114 | case "test":
115 | fallthrough
116 | case "dev":
117 | DB = DB.Debug()
118 | }
119 |
120 | err = DB.SetupJoinTable(&User{}, "UserLikedFloors", &FloorLike{})
121 | if err != nil {
122 | log.Fatal().Err(err).Send()
123 | }
124 |
125 | err = DB.SetupJoinTable(&Hole{}, "Mapping", &AnonynameMapping{})
126 | if err != nil {
127 | log.Fatal().Err(err).Send()
128 | }
129 |
130 | err = DB.SetupJoinTable(&Message{}, "Users", &MessageUser{})
131 | if err != nil {
132 | log.Fatal().Err(err).Send()
133 | }
134 |
135 | err = DB.SetupJoinTable(&User{}, "UserSubscription", &UserSubscription{})
136 | if err != nil {
137 | log.Fatal().Err(err).Send()
138 | }
139 |
140 | // models must be registered here to migrate into the database
141 | err = DB.AutoMigrate(
142 | &Division{},
143 | &Tag{},
144 | &User{},
145 | &Floor{},
146 | &Hole{},
147 | &Report{},
148 | &Punishment{},
149 | &ReportPunishment{},
150 | &Message{},
151 | &FloorHistory{},
152 | &AdminLog{},
153 | &UserFavorite{},
154 | &FavoriteGroup{},
155 | &UrlHostnameWhitelist{},
156 | )
157 | if err != nil {
158 | log.Fatal().Err(err).Send()
159 | }
160 |
161 | err = DB.Model(&UrlHostnameWhitelist{}).Pluck("hostname", &config.Config.UrlHostnameWhitelist).Error
162 | if err != nil {
163 | log.Fatal().Err(err).Send()
164 | }
165 | }
166 |
--------------------------------------------------------------------------------
/models/message.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | // Should be same as message in notification project
4 |
5 | import (
6 | "time"
7 |
8 | "github.com/gofiber/fiber/v2"
9 | "gorm.io/gorm"
10 | )
11 |
12 | type Messages []Message
13 |
14 | type Message struct {
15 | ID int `gorm:"primaryKey" json:"id"`
16 | CreatedAt time.Time `json:"time_created"`
17 | UpdatedAt time.Time `json:"time_updated"`
18 | Title string `json:"message" gorm:"size:1024;not null"`
19 | Description string `json:"description" gorm:"size:65536;not null"`
20 | Data any `json:"data" gorm:"serializer:json" `
21 | Type MessageType `json:"code" gorm:"size:16;not null"`
22 | URL string `json:"url" gorm:"size:64;default:'';not null"`
23 | Recipients []int `json:"-" gorm:"-:all" `
24 | MessageID int `json:"message_id" gorm:"-:all"` // 兼容旧版 id
25 | HasRead bool `json:"has_read" gorm:"default:false"` // 兼容旧版, 永远为false,以MessageUser的HasRead为准
26 | Users Users `json:"-" gorm:"many2many:message_user;constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"`
27 | }
28 |
29 | type MessageUser struct {
30 | MessageID int `json:"message_id" gorm:"primaryKey"`
31 | UserID int `json:"user_id" gorm:"primaryKey"`
32 | HasRead bool `json:"has_read" gorm:"default:false"` // 兼容旧版
33 | }
34 |
35 | type MessageType string
36 |
37 | const (
38 | MessageTypeFavorite MessageType = "favorite"
39 | MessageTypeReply MessageType = "reply"
40 | MessageTypeMention MessageType = "mention"
41 | MessageTypeModify MessageType = "modify" // including fold and delete
42 | MessageTypePermission MessageType = "permission"
43 | MessageTypeReport MessageType = "report"
44 | MessageTypeReportDealt MessageType = "report_dealt"
45 | MessageTypeMail MessageType = "mail"
46 | MessageTypeSensitive MessageType = "sensitive"
47 | )
48 |
49 | func (messages Messages) Preprocess(c *fiber.Ctx) error {
50 | for i := 0; i < len(messages); i++ {
51 | err := messages[i].Preprocess(c)
52 | if err != nil {
53 | return err
54 | }
55 | }
56 | return nil
57 | }
58 |
59 | func (message *Message) Preprocess(_ *fiber.Ctx) error {
60 | message.MessageID = message.ID
61 | return nil
62 | }
63 |
64 | func (message *Message) AfterCreate(tx *gorm.DB) (err error) {
65 | mapping := make([]MessageUser, len(message.Recipients))
66 | for i, userID := range message.Recipients {
67 | mapping[i] = MessageUser{
68 | MessageID: message.ID,
69 | UserID: userID,
70 | }
71 | }
72 | return tx.Create(&mapping).Error
73 | }
74 |
--------------------------------------------------------------------------------
/models/notification.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "errors"
7 | "fmt"
8 | "io"
9 | "math/rand"
10 | "net/http"
11 | "regexp"
12 | "strings"
13 | "sync"
14 | "time"
15 |
16 | "github.com/rs/zerolog/log"
17 |
18 | "treehole_next/config"
19 | "treehole_next/utils"
20 |
21 | "golang.org/x/exp/slices"
22 | "gorm.io/gorm/clause"
23 |
24 | "github.com/goccy/go-json"
25 | )
26 |
27 | const (
28 | timeout = time.Second * 10
29 | )
30 |
31 | var client = http.Client{Timeout: timeout}
32 |
33 | type Notifications []Notification
34 |
35 | type Notification struct {
36 | // Should be same as CrateModel in notification project
37 | Title string `json:"message"`
38 | Description string `json:"description"`
39 | Data any `json:"data"`
40 | Type MessageType `json:"code"`
41 | URL string `json:"url"`
42 | Recipients []int `json:"recipients"`
43 | }
44 |
45 | func readRespNotification(body io.ReadCloser) Notification {
46 | defer func(body io.ReadCloser) {
47 | err := body.Close()
48 | if err != nil {
49 | log.Err(err).Str("model", "Notification").Msg("error close body")
50 | }
51 | }(body)
52 |
53 | data, err := io.ReadAll(body)
54 | if err != nil {
55 | log.Err(err).Str("model", "Notification").Msg("error read body")
56 | return Notification{}
57 | }
58 | var response Notification
59 | err = json.Unmarshal(data, &response)
60 | if err != nil {
61 | log.Err(err).Str("model", "Notification").Msg("error unmarshal body")
62 | return Notification{}
63 | }
64 | return response
65 | }
66 |
67 | func (messages Notifications) Merge(newNotification Notification) Notifications {
68 | if len(newNotification.Recipients) == 0 {
69 | return messages
70 | }
71 |
72 | newMerge := newNotification.Recipients
73 | for _, message := range messages {
74 | old := message.Recipients
75 | for _, r1 := range old {
76 | for id, r2 := range newMerge {
77 | if r1 == r2 {
78 | newMerge = append(newMerge[:id], newMerge[id+1:]...)
79 | break
80 | }
81 | }
82 | }
83 | if len(newMerge) == 0 {
84 | return messages
85 | }
86 | }
87 |
88 | newNotification.Recipients = newMerge
89 | return append(messages, newNotification)
90 | }
91 |
92 | func (messages Notifications) Send() error {
93 | if messages == nil {
94 | return nil
95 | }
96 |
97 | for _, message := range messages {
98 | _, err := message.Send()
99 | if err != nil {
100 | return err
101 | }
102 | }
103 | return nil
104 | }
105 |
106 | // check user.config.Notify contain message.Type
107 | func (message *Notification) checkConfig() {
108 | // generate new recipients
109 | var newRecipient []int
110 |
111 | // find users
112 | var users []User
113 | result := DB.Find(&users, "id in ?", message.Recipients)
114 | if result.Error != nil {
115 | message.Recipients = newRecipient
116 | return
117 | }
118 |
119 | // filter recipients
120 | for _, user := range users {
121 | if slices.Contains(defaultUserConfig.Notify, string(message.Type)) && !slices.Contains(user.Config.Notify, string(message.Type)) {
122 | continue
123 | }
124 | newRecipient = append(newRecipient, user.ID)
125 | }
126 | message.Recipients = newRecipient
127 | }
128 |
129 | func (message Notification) Send() (Message, error) {
130 | // only for test
131 | // message["recipients"] = []int{1}
132 |
133 | var err error
134 |
135 | message.checkConfig()
136 | // return if no recipient
137 | if len(message.Recipients) == 0 {
138 | return Message{}, nil
139 | }
140 |
141 | // save to database first
142 | body := Message{
143 | Type: message.Type,
144 | Title: message.Title,
145 | Description: message.Description,
146 | Data: message.Data,
147 | URL: message.URL,
148 | Recipients: message.Recipients,
149 | }
150 | err = DB.Omit(clause.Associations).Create(&body).Error
151 | if err != nil {
152 | log.Err(err).Str("model", "Notification").Msg("message save failed: " + err.Error())
153 | return Message{}, err
154 | }
155 | if config.Config.NotificationUrl == "" {
156 | return Message{}, nil
157 | }
158 | message.Title = utils.StripContent(message.Title, 32) //varchar(32)
159 | message.Description = utils.StripContent(cleanNotificationDescription(message.Description), 64) //varchar(64)
160 | body.Title = message.Title
161 | body.Description = message.Description
162 |
163 | // construct form
164 | form, err := json.Marshal(message)
165 | if err != nil {
166 | log.Err(err).Str("model", "Notification").Msg("error encoding notification")
167 | return Message{}, err
168 | }
169 |
170 | // construct http request
171 | req, err := http.NewRequest(
172 | "POST",
173 | fmt.Sprintf("%s/messages", config.Config.NotificationUrl),
174 | bytes.NewBuffer(form),
175 | )
176 | if err != nil {
177 | log.Err(err).Str("model", "Notification").Msg("error making request")
178 | return Message{}, err
179 | }
180 | req.Header.Add("Content-Type", "application/json")
181 |
182 | // bench and simulation
183 | if config.Config.Mode == "bench" {
184 | time.Sleep(time.Millisecond)
185 | return Message{}, nil
186 | }
187 |
188 | // get response
189 | resp, err := client.Do(req)
190 | if err != nil {
191 | log.Err(err).Str("model", "Notification").Msg("error sending notification")
192 | return Message{}, err
193 | }
194 |
195 | response := readRespNotification(resp.Body)
196 | if resp.StatusCode != 201 {
197 | log.Error().Str("model", "Notification").Any("response", response).Msg("notification response failed")
198 | return Message{}, errors.New(fmt.Sprint(response))
199 | }
200 |
201 | return body, nil
202 | }
203 |
204 | var adminList struct {
205 | sync.RWMutex
206 | data []int
207 | }
208 |
209 | func InitAdminList() {
210 | // skip when bench
211 | if config.Config.Mode == "bench" || config.Config.AuthUrl == "" {
212 | return
213 | }
214 |
215 | // // http request
216 | // res, err := http.Get(config.Config.AuthUrl + "/users/admin")
217 |
218 | // // handle err
219 | // if err != nil {
220 | // log.Err(err).Str("model", "get admin").Msg("error sending auth server")
221 | // return
222 | // }
223 |
224 | // defer func() {
225 | // _ = res.Body.Close()
226 | // }()
227 |
228 | // if res.StatusCode != 200 {
229 | // log.Error().Str("model", "get admin").Msg("auth server response failed" + res.Status)
230 | // return
231 | // }
232 |
233 | // data, err := io.ReadAll(res.Body)
234 | // if err != nil {
235 | // log.Err(err).Str("model", "get admin").Msg("error reading auth server response")
236 | // return
237 | // }
238 |
239 | adminList.Lock()
240 | defer adminList.Unlock()
241 |
242 | // err = json.Unmarshal(data, &adminList.data)
243 | // if err != nil {
244 | // log.Err(err).Str("model", "get admin").Msg("error unmarshal auth server response")
245 | // return
246 | // }
247 | adminList.data = config.Config.NotifiableAdminIds
248 |
249 | // shuffle ids
250 | for i := range adminList.data {
251 | j := rand.Intn(i + 1)
252 | adminList.data[i], adminList.data[j] = adminList.data[j], adminList.data[i]
253 | }
254 | }
255 |
256 | func UpdateAdminList(ctx context.Context) {
257 | ticker := time.NewTicker(time.Minute)
258 | for {
259 | select {
260 | case <-ctx.Done():
261 | return
262 | case <-ticker.C:
263 | InitAdminList()
264 | }
265 | }
266 | }
267 |
268 | var (
269 | reMention = regexp.MustCompile(`#{1,2}\d+`)
270 | reFormula = regexp.MustCompile(`(?s)\${1,2}.*?\${1,2}`)
271 | reSticker = regexp.MustCompile(`!\[\]\(dx_\S+?\)`)
272 | reImage = regexp.MustCompile(`!\[.*?\]\(.*?\)`)
273 | )
274 |
275 | func cleanNotificationDescription(content string) string {
276 | newContent := reMention.ReplaceAllString(content, "")
277 | newContent = reFormula.ReplaceAllString(newContent, "[公式]")
278 | newContent = reSticker.ReplaceAllString(newContent, "[表情]")
279 | newContent = reImage.ReplaceAllString(newContent, "[图片]")
280 | newContent = strings.ReplaceAll(newContent, "\n", "")
281 | if newContent == "" {
282 | return content
283 | }
284 | return newContent
285 | }
286 |
--------------------------------------------------------------------------------
/models/punishment.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "errors"
5 | "time"
6 |
7 | "gorm.io/gorm"
8 | "gorm.io/gorm/clause"
9 | "gorm.io/plugin/dbresolver"
10 | )
11 |
12 | // Punishment
13 | // a record of user punishment
14 | // when a record created, it can't be modified if other admins punish this user on the same floor
15 | // whether a user is banned to post on one division based on the latest / max(id) record
16 | // if admin want to modify punishment duration, manually modify the latest record of this user in database
17 | // admin can be granted update privilege on SQL view of this table
18 | type Punishment struct {
19 | ID int `json:"id" gorm:"primaryKey"`
20 |
21 | // time when this punishment creates
22 | CreatedAt time.Time `json:"created_at"`
23 |
24 | UpdatedAt time.Time `json:"updated_at"`
25 |
26 | // time when this punishment revoked
27 | DeletedAt gorm.DeletedAt `json:"-"`
28 |
29 | // start from end_time of previous punishment (punishment accumulation of different floors)
30 | // if no previous punishment or previous punishment end time less than time.Now() (synced), set start time time.Now()
31 | StartTime time.Time `json:"start_time" gorm:"not null"`
32 |
33 | // end_time of this punishment
34 | EndTime time.Time `json:"end_time" gorm:"not null"`
35 |
36 | Duration *time.Duration `json:"duration" swaggertype:"integer"`
37 |
38 | Day int `json:"day"`
39 |
40 | // user punished
41 | UserID int `json:"user_id" gorm:"not null;index"`
42 |
43 | // admin user_id who made this punish
44 | MadeBy int `json:"made_by,omitempty"`
45 |
46 | // punished because of this floor
47 | FloorID *int `json:"floor_id" gorm:"uniqueIndex"`
48 |
49 | Floor *Floor `json:"floor,omitempty" gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"` // foreign key
50 |
51 | DivisionID int `json:"division_id" gorm:"not null"`
52 |
53 | Division *Division `json:"division,omitempty"` // foreign key
54 |
55 | // reason
56 | Reason string `json:"reason" gorm:"size:128"`
57 | }
58 |
59 | type Punishments []*Punishment
60 |
61 | func (punishment *Punishment) Create() (*User, error) {
62 | var user User
63 |
64 | err := DB.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error {
65 | err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Take(&user, punishment.UserID).Error
66 | if err != nil {
67 | return err
68 | }
69 |
70 | var previousPunishment Punishment
71 | err = tx.Where("user_id = ? and floor_id = ?", user.ID, punishment.FloorID).Take(&previousPunishment).Error
72 | if err == nil {
73 | // return common.Forbidden("该用户已被禁言")
74 |
75 | // same as before, do nothing
76 | if previousPunishment.Duration == punishment.Duration && previousPunishment.Day == punishment.Day {
77 | return nil
78 | }
79 |
80 | // different duration, revoke previous punishment
81 | diffDuration := time.Duration(punishment.Day-previousPunishment.Day) * 24 * time.Hour
82 |
83 | previousPunishment.Duration = punishment.Duration
84 | previousPunishment.Day = punishment.Day
85 | previousPunishment.EndTime = previousPunishment.StartTime.Add(*punishment.Duration)
86 | previousPunishment.Reason = punishment.Reason
87 | previousPunishment.MadeBy = punishment.MadeBy
88 | // conflict with previous punishment if not equal
89 | // ignore it as it's rare
90 | previousPunishment.DivisionID = punishment.DivisionID
91 |
92 | if user.BanDivision[punishment.DivisionID] == nil {
93 | user.BanDivision[punishment.DivisionID] = &previousPunishment.EndTime
94 | } else {
95 | *user.BanDivision[punishment.DivisionID] = user.BanDivision[punishment.DivisionID].Add(diffDuration)
96 | }
97 |
98 | err = tx.Updates(&previousPunishment).Error
99 | if err != nil {
100 | return err
101 | }
102 | } else if !errors.Is(err, gorm.ErrRecordNotFound) {
103 | return err
104 | } else {
105 | user.OffenceCount += 1
106 | punishment.StartTime = time.Now()
107 | punishment.EndTime = punishment.StartTime.Add(*punishment.Duration)
108 |
109 | if user.BanDivision[punishment.DivisionID] == nil {
110 | user.BanDivision[punishment.DivisionID] = &punishment.EndTime
111 | } else {
112 | *user.BanDivision[punishment.DivisionID] = user.BanDivision[punishment.DivisionID].Add(*punishment.Duration)
113 | }
114 |
115 | err = tx.Create(&punishment).Error
116 | if err != nil {
117 | return err
118 | }
119 | }
120 |
121 | err = tx.Select("BanDivision", "OffenceCount").Save(&user).Error
122 | if err != nil {
123 | return err
124 | }
125 |
126 | return nil
127 | })
128 | return &user, err
129 | }
130 |
--------------------------------------------------------------------------------
/models/report.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "sync/atomic"
7 | "time"
8 |
9 | "github.com/rs/zerolog/log"
10 |
11 | "github.com/gofiber/fiber/v2"
12 | "github.com/opentreehole/go-common"
13 | "gorm.io/gorm"
14 | )
15 |
16 | type Report struct {
17 | ID int `json:"id" gorm:"primaryKey"`
18 | CreatedAt time.Time `json:"time_created"`
19 | UpdatedAt time.Time `json:"time_updated"`
20 | ReportID int `json:"report_id" gorm:"-:all"`
21 | FloorID int `json:"floor_id"`
22 | HoleID int `json:"hole_id" gorm:"-:all"`
23 | Floor *Floor `json:"floor" gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"`
24 | UserID int `json:"-"` // the reporter's id, should keep a secret
25 | Reason string `json:"reason" gorm:"size:128"`
26 | Dealt bool `json:"dealt"` // the report has been dealt
27 | // who dealt the report
28 | DealtBy int `json:"dealt_by" gorm:"index"`
29 | Result string `json:"result" gorm:"size:128"` // deal result
30 | }
31 |
32 | func (report *Report) GetID() int {
33 | return report.ID
34 | }
35 |
36 | type Reports []*Report
37 |
38 | func (report *Report) Preprocess(c *fiber.Ctx) (err error) {
39 | err = report.Floor.SetDefaults(c)
40 | if err != nil {
41 | return err
42 | }
43 | for i := range report.Floor.Mention {
44 | err = report.Floor.Mention[i].SetDefaults(c)
45 | if err != nil {
46 | return err
47 | }
48 | }
49 | report.HoleID = report.Floor.HoleID
50 | return nil
51 | }
52 |
53 | func (reports Reports) Preprocess(c *fiber.Ctx) error {
54 | for i := 0; i < len(reports); i++ {
55 | _ = reports[i].Preprocess(c)
56 | }
57 | return nil
58 | }
59 |
60 | func LoadReportFloor(tx *gorm.DB) *gorm.DB {
61 | return tx.Preload("Floor.Mention").Preload("Floor")
62 | }
63 |
64 | func (report *Report) Create(c *fiber.Ctx, db ...*gorm.DB) error {
65 | var tx *gorm.DB
66 | if len(db) > 0 {
67 | tx = db[0]
68 | } else {
69 | tx = DB
70 | }
71 | userID, err := common.GetUserID(c)
72 | if err != nil {
73 | return err
74 | }
75 |
76 | existingReport := Report{}
77 | err = tx.Where("user_id = ? AND floor_id = ?", userID, report.FloorID).First(&existingReport).Error
78 | if err != nil {
79 | if !errors.Is(err, gorm.ErrRecordNotFound) {
80 | return err
81 | }
82 | }
83 |
84 | if errors.Is(err, gorm.ErrRecordNotFound) {
85 | report.UserID = userID
86 | err = tx.Create(&report).Error
87 | if err != nil {
88 | return err
89 | }
90 |
91 | report.ReportID = report.ID
92 |
93 | err = tx.Model(report).Association("Floor").Find(&report.Floor)
94 | if err != nil {
95 | return err
96 | }
97 |
98 | err = report.Preprocess(c)
99 | if err != nil {
100 | return err
101 | }
102 | } else {
103 | existingReport.Reason = existingReport.Reason + "\n" + report.Reason
104 | err = tx.Model(&existingReport).Updates(map[string]any{
105 | "reason": existingReport.Reason,
106 | "dealt": false,
107 | }).Error // update reason and load floor in AfterUpdate hook
108 | if err != nil {
109 | return err
110 | }
111 | report.Floor = existingReport.Floor
112 | }
113 |
114 | return nil
115 | }
116 |
117 | func (report *Report) AfterFind(_ *gorm.DB) (err error) {
118 | report.ReportID = report.ID
119 |
120 | return nil
121 | }
122 |
123 | func (report *Report) AfterUpdate(tx *gorm.DB) (err error) {
124 | err = tx.Model(report).Association("Floor").Find(&report.Floor)
125 | if err != nil {
126 | return err
127 | }
128 |
129 | //err = report.Preprocess(nil)
130 | //if err != nil {
131 | // return err
132 | //}
133 |
134 | return nil
135 | }
136 |
137 | var adminCounter = new(int32)
138 |
139 | func (report *Report) SendCreate(_ *gorm.DB) error {
140 | // adminList.RLock()
141 | // defer adminList.RUnlock()
142 | if len(adminList.data) == 0 {
143 | return nil
144 | }
145 |
146 | // get counter
147 | currentCounter := atomic.AddInt32(adminCounter, 1)
148 | result := atomic.CompareAndSwapInt32(adminCounter, int32(len(adminList.data)), 0)
149 | if result {
150 | log.Info().Str("model", "get admin").Msg("adminCounter Reset")
151 | }
152 | userIDs := []int{adminList.data[currentCounter-1]}
153 |
154 | // construct message
155 | message := Notification{
156 | Data: report,
157 | Recipients: userIDs,
158 | Description: fmt.Sprintf(
159 | "理由:%s,内容:%s",
160 | report.Reason,
161 | report.Floor.Content,
162 | ),
163 | Title: "您有举报需要处理",
164 | Type: MessageTypeReport,
165 | URL: fmt.Sprintf("/api/reports/%d", report.ID),
166 | }
167 |
168 | // send
169 | _, err := message.Send()
170 | return err
171 | }
172 |
173 | func (report *Report) SendModify(_ *gorm.DB) error {
174 | // get recipients
175 | userIDs := []int{report.UserID}
176 |
177 | // construct message
178 | message := Notification{
179 | Data: report,
180 | Recipients: userIDs,
181 | Description: fmt.Sprintf(
182 | "处理结果:%s\n感谢您为维护社区秩序所做的贡献。",
183 | report.Result,
184 | ),
185 | Title: "您的举报已得到处理",
186 | Type: MessageTypeReportDealt,
187 | URL: fmt.Sprintf("/api/reports/%d", report.ID),
188 | }
189 |
190 | // send
191 | _, err := message.Send()
192 | return err
193 | }
194 |
--------------------------------------------------------------------------------
/models/report_punishment.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "errors"
5 | "github.com/opentreehole/go-common"
6 | "gorm.io/gorm"
7 | "gorm.io/gorm/clause"
8 | "gorm.io/plugin/dbresolver"
9 | "time"
10 | )
11 |
12 | type ReportPunishment struct {
13 | ID int `json:"id" gorm:"primaryKey"`
14 |
15 | // time when this report punishment creates
16 | CreatedAt time.Time `json:"created_at"`
17 |
18 | // time when this report punishment revoked
19 | DeletedAt gorm.DeletedAt `json:"deleted_at"`
20 |
21 | // start from end_time of previous punishment (punishment accumulation of different floors)
22 | // if no previous punishment or previous punishment end time less than time.Now() (synced), set start time time.Now()
23 | StartTime time.Time `json:"start_time" gorm:"not null"`
24 |
25 | // end_time of this report punishment
26 | EndTime time.Time `json:"end_time" gorm:"not null"`
27 |
28 | Duration *time.Duration `json:"duration"`
29 |
30 | // user punished
31 | UserID int `json:"user_id" gorm:"not null;index"`
32 |
33 | // admin user_id who made this punish
34 | MadeBy int `json:"made_by,omitempty"`
35 |
36 | // punished because of this report
37 | ReportId int `json:"report_id" gorm:"uniqueIndex"`
38 |
39 | Report *Report `json:"report,omitempty"` // foreign key
40 |
41 | // reason
42 | Reason string `json:"reason" gorm:"size:128"`
43 | }
44 |
45 | type ReportPunishments []*ReportPunishment
46 |
47 | func (reportPunishment *ReportPunishment) Create() (*User, error) {
48 | var user User
49 |
50 | err := DB.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error {
51 | err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Take(&user, reportPunishment.UserID).Error
52 | if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
53 | return err
54 | }
55 |
56 | // if the user has been banned from this report
57 | var punishmentRecord ReportPunishment
58 | err = tx.Where("user_id = ? and report_id = ?", user.ID, reportPunishment.ReportId).Take(&punishmentRecord).Error
59 | if err == nil {
60 | return common.Forbidden("该用户已被限制使用举报功能")
61 | } else if !errors.Is(err, gorm.ErrRecordNotFound) {
62 | return err
63 | }
64 |
65 | var lastPunishment ReportPunishment
66 | err = tx.Where("user_id = ?", user.ID).Last(&lastPunishment).Error
67 | if err == nil {
68 | if lastPunishment.EndTime.Before(time.Now()) {
69 | reportPunishment.StartTime = time.Now()
70 | } else {
71 | reportPunishment.StartTime = lastPunishment.EndTime
72 | }
73 | } else if errors.Is(err, gorm.ErrRecordNotFound) {
74 | reportPunishment.StartTime = time.Now()
75 | } else {
76 | return err
77 | }
78 |
79 | reportPunishment.EndTime = reportPunishment.StartTime.Add(*reportPunishment.Duration)
80 |
81 | user.BanReport = &reportPunishment.EndTime
82 | user.BanReportCount += 1
83 |
84 | err = tx.Create(&reportPunishment).Error
85 | if err != nil {
86 | return err
87 | }
88 |
89 | err = tx.Select("BanReport", "BanReportCount").Save(&user).Error
90 | if err != nil {
91 | return err
92 | }
93 |
94 | return nil
95 | })
96 | return &user, err
97 | }
98 |
--------------------------------------------------------------------------------
/models/tag.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "fmt"
5 | "strings"
6 | "sync"
7 | "time"
8 | "treehole_next/config"
9 | "treehole_next/utils/sensitive"
10 |
11 | "github.com/gofiber/fiber/v2"
12 |
13 | "github.com/opentreehole/go-common"
14 | "github.com/rs/zerolog/log"
15 | "golang.org/x/exp/slices"
16 | "gorm.io/gorm"
17 | "gorm.io/gorm/clause"
18 |
19 | "treehole_next/utils"
20 | )
21 |
22 | type Tag struct {
23 | /// saved fields
24 | ID int `json:"id" gorm:"primaryKey"`
25 | CreatedAt time.Time `json:"-" gorm:"not null"`
26 | UpdatedAt time.Time `json:"-" gorm:"not null"`
27 |
28 | /// base info
29 | Name string `json:"name" gorm:"not null;unique;size:32"`
30 | Temperature int `json:"temperature" gorm:"not null;default:0"`
31 |
32 | IsZZMG bool `json:"-" gorm:"not null;default:false"`
33 |
34 | /// association info, should add foreign key
35 | Holes Holes `json:"-" gorm:"many2many:hole_tags;constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"`
36 | // auto sensitive check
37 | IsSensitive bool `json:"-" gorm:"index:idx_tag_actual_sensitive,priority:1"`
38 |
39 | // manual sensitive check
40 | IsActualSensitive *bool `json:"-" gorm:"index:idx_tag_actual_sensitive,priority:2"`
41 | /// generated field
42 | TagID int `json:"tag_id" gorm:"-:all"`
43 |
44 | Nsfw bool `json:"nsfw" gorm:"not null;default:false;index"`
45 | }
46 |
47 | type Tags []*Tag
48 |
49 | func (tag *Tag) GetID() int {
50 | return tag.ID
51 | }
52 |
53 | func (tag *Tag) AfterFind(_ *gorm.DB) (err error) {
54 | tag.TagID = tag.ID
55 | return nil
56 | }
57 |
58 | func (tag *Tag) BeforeCreate(_ *gorm.DB) (err error) {
59 | if len(tag.Name) > 0 && tag.Name[0] == '*' {
60 | tag.Nsfw = true
61 | }
62 | return nil
63 | }
64 |
65 | func (tag *Tag) AfterCreate(_ *gorm.DB) (err error) {
66 | tag.TagID = tag.ID
67 | return nil
68 | }
69 |
70 | func FindOrCreateTags(tx *gorm.DB, user *User, names []string) (Tags, error) {
71 | tags := make(Tags, 0)
72 | for i, name := range names {
73 | names[i] = strings.TrimSpace(name)
74 | }
75 | err := tx.Where("name in ?", names).Find(&tags).Error
76 | if err != nil {
77 | return nil, err
78 | }
79 |
80 | existTagNames := make([]string, 0)
81 | for _, tag := range tags {
82 | existTagNames = append(existTagNames, tag.Name)
83 | if !user.IsAdmin {
84 | if slices.ContainsFunc(config.Config.AdminOnlyTagIds, func(i int) bool {
85 | return i == tag.ID
86 | }) {
87 | return nil, common.Forbidden(fmt.Sprintf("标签 %s 为管理员专用标签", tag.Name))
88 | }
89 | }
90 | }
91 |
92 | newTags := make(Tags, 0)
93 | for _, name := range names {
94 | name = strings.TrimSpace(name)
95 | if !slices.ContainsFunc(existTagNames, func(s string) bool {
96 | return strings.EqualFold(s, name)
97 | }) {
98 | newTags = append(newTags, &Tag{Name: name})
99 | }
100 | }
101 |
102 | if len(newTags) == 0 {
103 | return tags, nil
104 | }
105 | for _, tag := range newTags {
106 | if !user.IsAdmin {
107 | if len(tag.Name) > 15 && len([]rune(tag.Name)) > 10 {
108 | return nil, common.BadRequest("标签长度不能超过 10 个字符")
109 | }
110 | if strings.HasPrefix(tag.Name, "#") {
111 | return nil, common.BadRequest("只有管理员才能创建 # 开头的 tag")
112 | }
113 | if strings.HasPrefix(tag.Name, "@") {
114 | return nil, common.BadRequest("只有管理员才能创建 @ 开头的 tag")
115 | }
116 | if strings.HasPrefix(tag.Name, "*") {
117 | return nil, common.BadRequest("只有管理员才能创建 * 开头的 tag")
118 | }
119 | }
120 | }
121 |
122 | var wg sync.WaitGroup
123 | for _, tag := range newTags {
124 | wg.Add(1)
125 | go func(tag *Tag) {
126 | sensitiveResp, err := sensitive.CheckSensitive(sensitive.ParamsForCheck{
127 | Content: tag.Name,
128 | Id: time.Now().UnixNano(),
129 | TypeName: sensitive.TypeTag,
130 | })
131 | if err != nil {
132 | return
133 | }
134 | tag.IsSensitive = !sensitiveResp.Pass
135 | wg.Done()
136 | }(tag)
137 | }
138 | wg.Wait()
139 |
140 | err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&newTags).Error
141 |
142 | go UpdateTagCache(nil)
143 |
144 | return append(tags, newTags...), err
145 | }
146 |
147 | func UpdateTagCache(tags Tags) {
148 | var err error
149 | if len(tags) == 0 {
150 | err := DB.Order("temperature desc").Find(&tags).Error
151 | if err != nil {
152 | log.Printf("update tag cache error: %s", err)
153 | }
154 | }
155 | err = utils.SetCache("tags", tags, 10*time.Minute)
156 | if err != nil {
157 | log.Printf("update tag cache error: %s", err)
158 | }
159 | }
160 |
161 | func (tag *Tag) Preprocess(c *fiber.Ctx) error {
162 | return Tags{tag}.Preprocess(c)
163 | }
164 |
165 | func (tags Tags) Preprocess(_ *fiber.Ctx) error {
166 | tagIDs := make([]int, len(tags))
167 | IdTagMapping := make(map[int]*Tag)
168 | for i, tag := range tags {
169 | if tags[i].Sensitive() {
170 | tags[i].Name = ""
171 | }
172 | tagIDs[i] = tag.ID
173 | IdTagMapping[tag.ID] = tags[i]
174 | }
175 | return nil
176 | }
177 |
178 | func (tag *Tag) Sensitive() bool {
179 | if tag == nil {
180 | return false
181 | }
182 | if tag.IsActualSensitive != nil {
183 | return *tag.IsActualSensitive
184 | }
185 | return tag.IsSensitive
186 | }
187 |
--------------------------------------------------------------------------------
/models/url_hostname_whitelist.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | type UrlHostnameWhitelist struct {
4 | ID int `json:"id" gorm:"primaryKey"`
5 | Hostname string `json:"hostname" gorm:"size:255;not null"`
6 | }
7 |
--------------------------------------------------------------------------------
/models/user.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "time"
7 |
8 | "golang.org/x/exp/slices"
9 |
10 | "treehole_next/config"
11 |
12 | "github.com/gofiber/fiber/v2"
13 | "github.com/opentreehole/go-common"
14 | "github.com/rs/zerolog/log"
15 | "gorm.io/gorm"
16 | )
17 |
18 | type User struct {
19 | /// base info
20 | ID int `json:"id" gorm:"primaryKey"`
21 |
22 | Config UserConfig `json:"config" gorm:"serializer:json;not null;default:\"{}\""`
23 |
24 | BanDivision map[int]*time.Time `json:"-" gorm:"serializer:json;not null;default:\"{}\""`
25 |
26 | OffenceCount int `json:"-" gorm:"not null;default:0"`
27 |
28 | BanReport *time.Time `json:"-" gorm:"serializer:json"`
29 |
30 | BanReportCount int `json:"-" gorm:"not null;default:0"`
31 |
32 | DefaultSpecialTag string `json:"default_special_tag" gorm:"size:32"`
33 |
34 | SpecialTags []string `json:"special_tags" gorm:"serializer:json;not null;default:\"[]\""`
35 |
36 | FavoriteGroupCount int `json:"favorite_group_count" gorm:"not null;default:0"`
37 |
38 | /// association fields, should add foreign key
39 |
40 | // holes owned by the user
41 | UserHoles Holes `json:"-" gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"`
42 |
43 | // floors owned by the user
44 | UserFloors Floors `json:"-" gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"`
45 |
46 | // reports made by the user; a user has many report
47 | UserReports Reports `json:"-"`
48 |
49 | // floors liked by the user
50 | UserLikedFloors Floors `json:"-" gorm:"many2many:floor_like;constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"`
51 |
52 | // floor history made by the user
53 | UserFloorHistory FloorHistorySlice `json:"-" gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"`
54 |
55 | // user punishments on division
56 | UserPunishments Punishments `json:"-"`
57 |
58 | // punishments made by this user
59 | UserMakePunishments Punishments `json:"-" gorm:"foreignKey:MadeBy"`
60 |
61 | // user punishments on report
62 | UserReportPunishments ReportPunishments `json:"-"`
63 |
64 | // report punishments made by this user
65 | UserMakeReportPunishments ReportPunishments `json:"-" gorm:"foreignKey:MadeBy"`
66 |
67 | UserSubscription Holes `json:"-" gorm:"many2many:user_subscription;constraint:OnUpdate:CASCADE,OnDelete:CASCADE;"`
68 |
69 | /// dynamically generated field
70 |
71 | UserID int `json:"user_id" gorm:"-:all"`
72 |
73 | Permission struct {
74 | // 管理员权限到期时间
75 | Admin time.Time `json:"admin"`
76 | // key: division_id value: 对应分区禁言解除时间
77 | Silent map[int]*time.Time `json:"silent"`
78 | OffenseCount int `json:"offense_count"`
79 | } `json:"permission" gorm:"-:all"`
80 |
81 | // get from jwt
82 | IsAdmin bool `json:"is_admin" gorm:"-:all"`
83 | JoinedTime time.Time `json:"joined_time" gorm:"-:all"`
84 | Nickname string `json:"nickname" gorm:"-:all"`
85 | HasAnsweredQuestions bool `json:"has_answered_questions" gorm:"-:all"`
86 | }
87 |
88 | type Users []*User
89 |
90 | type UserConfig struct {
91 | // used when notify
92 | Notify []string `json:"notify"`
93 |
94 | // 对折叠内容的处理
95 | // fold 折叠, hide 隐藏, show 展示
96 | ShowFolded string `json:"show_folded"`
97 | }
98 |
99 | var defaultUserConfig = UserConfig{
100 | Notify: []string{"mention", "favorite", "report"},
101 | ShowFolded: "hide",
102 | }
103 |
104 | var showFoldedOptions = []string{"hide", "fold", "show"}
105 |
106 | func (user *User) GetID() int {
107 | return user.ID
108 | }
109 |
110 | func (user *User) AfterCreate(_ *gorm.DB) error {
111 | user.UserID = user.ID
112 | return nil
113 | }
114 |
115 | func (user *User) AfterFind(_ *gorm.DB) error {
116 | user.UserID = user.ID
117 | return nil
118 | }
119 |
120 | var (
121 | maxTime time.Time
122 | minTime time.Time
123 | )
124 |
125 | func init() {
126 | var err error
127 | maxTime, err = time.Parse(time.RFC3339, "9999-01-01T00:00:00+00:00")
128 | if err != nil {
129 | log.Fatal().Err(err).Send()
130 | }
131 | minTime = time.Unix(0, 0)
132 | }
133 |
134 | // GetCurrLoginUser get current login user
135 | // In dev or test mode, return a default admin user
136 | func GetCurrLoginUser(c *fiber.Ctx) (*User, error) {
137 | user := &User{
138 | BanDivision: make(map[int]*time.Time),
139 | }
140 | if config.Config.Mode == "dev" || config.Config.Mode == "test" {
141 | user.ID = 1
142 | user.IsAdmin = true
143 | user.HasAnsweredQuestions = true
144 | return user, nil
145 | }
146 |
147 | if c.Locals("user") != nil {
148 | return c.Locals("user").(*User), nil
149 | }
150 |
151 | // get id
152 | userID, err := common.GetUserID(c)
153 | if err != nil {
154 | return nil, err
155 | }
156 |
157 | // parse JWT
158 | err = common.ParseJWTToken(common.GetJWTToken(c), user)
159 | if err != nil {
160 | return nil, err
161 | }
162 |
163 | // load user from database in transaction
164 | err = user.LoadUserByID(userID)
165 |
166 | if user.IsAdmin {
167 | user.Permission.Admin = maxTime
168 | } else {
169 | user.Permission.Admin = minTime
170 | }
171 | user.Permission.Silent = user.BanDivision
172 | user.Permission.OffenseCount = user.OffenceCount
173 |
174 | if config.Config.UserAllShowHidden {
175 | user.Config.ShowFolded = "hide"
176 | }
177 |
178 | // save user in c.Locals
179 | c.Locals("user", user)
180 |
181 | return user, err
182 | }
183 |
184 | func (user *User) LoadUserByID(userID int) error {
185 | return DB.Transaction(func(tx *gorm.DB) error {
186 | err := tx.Take(&user, userID).Error
187 | if err != nil {
188 | if errors.Is(err, gorm.ErrRecordNotFound) {
189 | // insert user if not found
190 | user.ID = userID
191 | user.Config = defaultUserConfig
192 | err = tx.Create(&user).Error
193 | if err != nil {
194 | return err
195 | }
196 | } else {
197 | return err
198 | }
199 | }
200 |
201 | err = CheckDefaultFavoriteGroup(tx, userID)
202 | if err != nil {
203 | return err
204 | }
205 |
206 | // check latest permission
207 | modified := false
208 | for divisionID := range user.BanDivision {
209 | endTime := user.BanDivision[divisionID]
210 | if endTime != nil && endTime.Before(time.Now()) {
211 | delete(user.BanDivision, divisionID)
212 | modified = true
213 | }
214 | }
215 |
216 | // check config
217 | if !slices.Contains(showFoldedOptions, user.Config.ShowFolded) {
218 | user.Config.ShowFolded = defaultUserConfig.ShowFolded
219 | modified = true
220 | }
221 |
222 | if user.Config.Notify == nil {
223 | user.Config.Notify = defaultUserConfig.Notify
224 | modified = true
225 | }
226 |
227 | if modified {
228 | err = tx.Select("BanDivision", "Config").Save(&user).Error
229 | if err != nil {
230 | return err
231 | }
232 | }
233 |
234 | return nil
235 | })
236 | }
237 |
238 | func (user *User) BanDivisionMessage(divisionID int) string {
239 | if user.BanDivision[divisionID] == nil {
240 | return fmt.Sprintf("您在此板块已被禁言")
241 | } else {
242 | return fmt.Sprintf(
243 | "您在此板块已被禁言,解封时间:%s",
244 | user.BanDivision[divisionID].Format("2006-01-02 15:04:05"))
245 | }
246 | }
247 |
248 | func (user *User) BanReportMessage() string {
249 | if user.BanReport == nil {
250 | return fmt.Sprintf("您已被限制使用举报功能")
251 | } else {
252 | return fmt.Sprintf(
253 | "您已被限制使用举报功能,解封时间:%s",
254 | user.BanReport.Format("2006-01-02 15:04:05"))
255 | }
256 | }
257 |
--------------------------------------------------------------------------------
/models/user_favorite.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "github.com/opentreehole/go-common"
5 | "time"
6 |
7 | "gorm.io/gorm"
8 | "gorm.io/gorm/clause"
9 | "gorm.io/plugin/dbresolver"
10 |
11 | "treehole_next/utils"
12 | )
13 |
14 | type UserFavorite struct {
15 | UserID int `json:"user_id" gorm:"primaryKey"`
16 | FavoriteGroupID int `json:"favorite_group_id" gorm:"primaryKey"`
17 | HoleID int `json:"hole_id" gorm:"primaryKey"`
18 | CreatedAt time.Time `json:"time_created"`
19 | }
20 |
21 | type UserFavorites []UserFavorite
22 |
23 | func (UserFavorite) TableName() string {
24 | return "user_favorites"
25 | }
26 |
27 | func IsFavoriteGroupExist(tx *gorm.DB, userID int, favoriteGroupID int) bool {
28 | var num int64
29 | tx.Model(&FavoriteGroup{}).Where("user_id = ? AND favorite_group_id = ? AND deleted = false", userID, favoriteGroupID).Count(&num)
30 | return num > 0
31 | }
32 |
33 | // ModifyUserFavorite only take effect in the same favorite_group
34 | func ModifyUserFavorite(tx *gorm.DB, userID int, holeIDs []int, favoriteGroupID int) error {
35 | if len(holeIDs) == 0 {
36 | return nil
37 | }
38 | if !IsFavoriteGroupExist(tx, userID, favoriteGroupID) {
39 | return common.NotFound("收藏夹不存在")
40 | }
41 | if !IsHolesExist(tx, holeIDs) {
42 | return common.Forbidden("帖子不存在")
43 | }
44 | return tx.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error {
45 | var oldHoleIDs []int
46 | err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
47 | Model(&UserFavorite{}).Where("user_id = ? AND favorite_group_id = ?", userID, favoriteGroupID).
48 | Pluck("hole_id", &oldHoleIDs).Error
49 | if err != nil {
50 | return err
51 | }
52 |
53 | // remove user_favorite that not in holeIDs
54 | var removingHoleIDMapping = make(map[int]bool)
55 | for _, holeID := range oldHoleIDs {
56 | removingHoleIDMapping[holeID] = true
57 | }
58 | for _, holeID := range holeIDs {
59 | if removingHoleIDMapping[holeID] {
60 | delete(removingHoleIDMapping, holeID)
61 | }
62 | }
63 | removingHoleIDs := utils.Keys(removingHoleIDMapping)
64 | if len(removingHoleIDs) > 0 {
65 | deleteUserFavorite := make(UserFavorites, 0)
66 | for _, holeID := range removingHoleIDs {
67 | deleteUserFavorite = append(deleteUserFavorite, UserFavorite{UserID: userID, HoleID: holeID, FavoriteGroupID: favoriteGroupID})
68 | }
69 | err = tx.Delete(&deleteUserFavorite).Error
70 | if err != nil {
71 | return err
72 | }
73 | }
74 |
75 | // insert user_favorite that not in oldHoleIDs
76 | var newHoleIDMapping = make(map[int]bool)
77 | for _, holeID := range holeIDs {
78 | newHoleIDMapping[holeID] = true
79 | }
80 | for _, holeID := range oldHoleIDs {
81 | if newHoleIDMapping[holeID] {
82 | delete(newHoleIDMapping, holeID)
83 | }
84 | }
85 | newHoleIDs := utils.Keys(newHoleIDMapping)
86 | if len(newHoleIDs) > 0 {
87 | insertUserFavorite := make(UserFavorites, 0)
88 | for _, holeID := range newHoleIDs {
89 | insertUserFavorite = append(insertUserFavorite, UserFavorite{UserID: userID, HoleID: holeID, FavoriteGroupID: favoriteGroupID})
90 | }
91 | err = tx.Create(&insertUserFavorite).Error
92 | if err != nil {
93 | return err
94 | }
95 | }
96 | return tx.Model(&FavoriteGroup{}).Where("user_id = ? AND favorite_group_id = ?", userID, favoriteGroupID).Update("count", len(holeIDs)).Error
97 | })
98 | }
99 |
100 | func AddUserFavorite(tx *gorm.DB, userID int, holeID int, favoriteGroupID int) error {
101 | if !IsFavoriteGroupExist(tx, userID, favoriteGroupID) {
102 | return common.NotFound("收藏夹不存在")
103 | }
104 | if !IsHolesExist(tx, []int{holeID}) {
105 | return common.NotFound("帖子不存在")
106 | }
107 | var err = tx.Clauses(clause.OnConflict{
108 | DoUpdates: clause.Assignments(Map{"created_at": time.Now()}),
109 | }).Create(&UserFavorite{
110 | UserID: userID,
111 | HoleID: holeID,
112 | FavoriteGroupID: favoriteGroupID,
113 | }).Error
114 | if err != nil {
115 | return err
116 | }
117 | return tx.Clauses(dbresolver.Write).Model(&FavoriteGroup{}).
118 | Where("user_id = ? AND favorite_group_id = ?", userID, favoriteGroupID).Update("count", gorm.Expr("count + 1")).Error
119 | }
120 |
121 | // UserGetFavoriteData get all favorite data of a user
122 | func UserGetFavoriteData(tx *gorm.DB, userID int) ([]int, error) {
123 | data := make([]int, 0, 10)
124 | err := tx.Clauses(dbresolver.Write).Model(&UserFavorite{}).Where("user_id = ?", userID).Distinct().
125 | Pluck("hole_id", &data).Error
126 | return data, err
127 | }
128 |
129 | // UserGetFavoriteDataByFavoriteGroup get favorite data in specific favorite group
130 | func UserGetFavoriteDataByFavoriteGroup(tx *gorm.DB, userID int, favoriteGroupID int) ([]int, error) {
131 | if !IsFavoriteGroupExist(tx, userID, favoriteGroupID) {
132 | return nil, common.NotFound("收藏夹不存在")
133 | }
134 | data := make([]int, 0, 10)
135 | err := tx.Clauses(dbresolver.Write).Model(&UserFavorite{}).
136 | Where("user_id = ? AND favorite_group_id = ?", userID, favoriteGroupID).Pluck("hole_id", &data).Error
137 | return data, err
138 | }
139 |
140 | // DeleteUserFavorite delete user favorite
141 | // if user favorite hole only once, delete the hole
142 | // otherwise, delete the favorite in the specific favorite group
143 | func DeleteUserFavorite(tx *gorm.DB, userID int, holeID int, favoriteGroupID int) error {
144 | if !IsFavoriteGroupExist(tx, userID, favoriteGroupID) {
145 | return common.NotFound("收藏夹不存在")
146 | }
147 | if !IsHolesExist(tx, []int{holeID}) {
148 | return common.NotFound("帖子不存在")
149 | }
150 | return tx.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error {
151 | err := tx.Delete(&UserFavorite{UserID: userID, HoleID: holeID, FavoriteGroupID: favoriteGroupID}).Error
152 | if err != nil {
153 | return err
154 | }
155 | return tx.Clauses(dbresolver.Write).Model(&FavoriteGroup{}).Where("user_id = ? AND favorite_group_id = ?", userID, favoriteGroupID).Update("count", gorm.Expr("count - 1")).Error
156 | })
157 | }
158 |
159 | // MoveUserFavorite move holes that are really in the fromFavoriteGroup
160 | func MoveUserFavorite(tx *gorm.DB, userID int, holeIDs []int, fromFavoriteGroupID int, toFavoriteGroupID int) error {
161 | if fromFavoriteGroupID == toFavoriteGroupID {
162 | return nil
163 | }
164 | if len(holeIDs) == 0 {
165 | return nil
166 | }
167 | if !IsFavoriteGroupExist(tx, userID, fromFavoriteGroupID) || !IsFavoriteGroupExist(tx, userID, toFavoriteGroupID) {
168 | return common.NotFound("收藏夹不存在")
169 | }
170 | if !IsHolesExist(tx, holeIDs) {
171 | return common.Forbidden("帖子不存在")
172 | }
173 | return tx.Clauses(dbresolver.Write).Transaction(func(tx *gorm.DB) error {
174 | var oldHoleIDs []int
175 | err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
176 | Model(&UserFavorite{}).Where("user_id = ? AND favorite_group_id = ?", userID, fromFavoriteGroupID).
177 | Pluck("hole_id", &oldHoleIDs).Error
178 | if err != nil {
179 | return err
180 | }
181 |
182 | // move user_favorite that in holeIDs
183 | var removingHoleIDMapping = make(map[int]bool)
184 | var removingHoleIDs []int
185 | for _, holeID := range oldHoleIDs {
186 | removingHoleIDMapping[holeID] = true
187 | }
188 | for _, holeID := range holeIDs {
189 | if removingHoleIDMapping[holeID] {
190 | removingHoleIDs = append(removingHoleIDs, holeID)
191 | }
192 | }
193 | if len(removingHoleIDs) > 0 {
194 | err = tx.Table("user_favorites").
195 | Where("user_id = ? AND favorite_group_id = ? AND hole_id IN ?", userID, fromFavoriteGroupID, removingHoleIDs).
196 | Updates(map[string]interface{}{"favorite_group_id": toFavoriteGroupID}).Error
197 | if err != nil {
198 | return err
199 | }
200 | }
201 | err = tx.Model(&FavoriteGroup{}).Where("user_id = ? AND favorite_group_id = ?", userID, fromFavoriteGroupID).Update("count", gorm.Expr("count - ?", len(removingHoleIDs))).Error
202 | if err != nil {
203 | return err
204 | }
205 | return tx.Model(&FavoriteGroup{}).Where("user_id = ? AND favorite_group_id = ?", userID, toFavoriteGroupID).Update("count", gorm.Expr("count + ?", len(removingHoleIDs))).Error
206 | })
207 | }
208 |
--------------------------------------------------------------------------------
/models/user_subscription.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "time"
5 |
6 | "gorm.io/gorm"
7 | "gorm.io/gorm/clause"
8 | "gorm.io/plugin/dbresolver"
9 | )
10 |
11 | type UserSubscription struct {
12 | UserID int `json:"user_id" gorm:"primaryKey"`
13 | HoleID int `json:"hole_id" gorm:"primaryKey"`
14 | CreatedAt time.Time `json:"time_created"`
15 | }
16 |
17 | type UserSubscriptions []UserSubscription
18 |
19 | func (UserSubscription) TableName() string {
20 | return "user_subscription"
21 | }
22 |
23 | func UserGetSubscriptionData(tx *gorm.DB, userID int) ([]int, error) {
24 | data := make([]int, 0, 10)
25 | err := tx.Clauses(dbresolver.Write).Raw("SELECT hole_id FROM user_subscription WHERE user_id = ? ORDER BY created_at", userID).Scan(&data).Error
26 | return data, err
27 | }
28 |
29 | func AddUserSubscription(tx *gorm.DB, userID int, holeID int) error {
30 | return tx.Clauses(clause.OnConflict{
31 | DoUpdates: clause.Assignments(Map{"created_at": time.Now()}),
32 | }).Create(&UserSubscription{
33 | UserID: userID,
34 | HoleID: holeID}).Error
35 | }
36 |
--------------------------------------------------------------------------------
/models/user_test.go:
--------------------------------------------------------------------------------
1 | package models
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/opentreehole/go-common"
7 | "github.com/stretchr/testify/assert"
8 | )
9 |
10 | func TestParseJWT(t *testing.T) {
11 | var user User
12 | jwt := "Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJ1aWQiOjE2LCJpc3MiOiJEU2lSa2NvWDJZV3dta3VqM3FFdFVxSE1uUnNvMjZQYiIsImlhdCI6MTY2MjUyNzg5OSwiaWQiOjE2LCJpc19hZG1pbiI6ZmFsc2UsIm5pY2tuYW1lIjoidXNlciIsIm9mZmVuc2VfY291bnQiOjAsInJvbGVzIjpbXSwidHlwZSI6ImFjY2VzcyIsImV4cCI6MTY2MjUyOTY5OX0.Ov_8cJay-Ta0jsPYUx1D-XDc_D1WK1iTdjnuEKAelaM"
13 | err := common.ParseJWTToken(jwt, &user)
14 | assert.Nilf(t, err, "ParseJWTToken failed: %v", err)
15 | }
16 |
--------------------------------------------------------------------------------
/tests/default.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | var largeInt = 1145141919810
4 |
--------------------------------------------------------------------------------
/tests/default_test.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/assert"
7 | )
8 |
9 | func TestIndex(t *testing.T) {
10 | testAPI(t, "get", "/", 302, nil)
11 | testAPI(t, "get", "/api", 200, nil)
12 | web404 := testAPI(t, "get", "/404", 404, nil)
13 | assert.EqualValues(t, "Cannot GET /404", web404["message"])
14 | }
15 |
16 | func TestDocs(t *testing.T) {
17 | testCommon(t, "get", "/docs", 302)
18 | testCommon(t, "get", "/docs/index.html", 200)
19 | }
20 |
--------------------------------------------------------------------------------
/tests/division_test.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | "strconv"
5 | "testing"
6 |
7 | . "treehole_next/models"
8 |
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | func TestGetDivision(t *testing.T) {
13 | var divisionPinned = []int{0, 2, 3, 1, largeInt}
14 |
15 | var d Division
16 | DB.First(&d, 1)
17 | d.Pinned = divisionPinned
18 | DB.Save(&d)
19 |
20 | var division Division
21 | testAPIModel(t, "get", "/api/divisions/1", 200, &division)
22 | // test pinned order
23 | respPinned := make([]int, 3)
24 | for i, p := range division.Holes {
25 | respPinned[i] = p.ID
26 | }
27 | assert.Equal(t, []int{2, 3, 1}, respPinned)
28 | }
29 |
30 | func TestListDivision(t *testing.T) {
31 | // return all divisions
32 | var length int64
33 | DB.Table("division").Count(&length)
34 | resp := testAPIArray(t, "get", "/api/divisions", 200)
35 | assert.Equal(t, length, int64(len(resp)))
36 | }
37 |
38 | func TestAddDivision(t *testing.T) {
39 | data := Map{"name": "TestAddDivision", "description": "TestAddDivisionDescription"}
40 | testAPI(t, "post", "/api/divisions", 201, data)
41 |
42 | // duplicate post, return 200 and change nothing
43 | data["description"] = "another"
44 | resp := testAPI(t, "post", "/api/divisions", 200, data)
45 | assert.Equal(t, "TestAddDivisionDescription", resp["description"])
46 | }
47 |
48 | func TestModifyDivision(t *testing.T) {
49 | pinned := []int{3, 2, 5, 1, 4}
50 | data := Map{"name": "modify", "description": "modify", "pinned": pinned}
51 |
52 | var division Division
53 | testAPIModel(t, "put", "/api/divisions/1", 200, &division, data)
54 |
55 | // test modify
56 | assert.Equal(t, "modify", division.Name)
57 | assert.Equal(t, "modify", division.Description)
58 |
59 | // test pinned order
60 | respPinned := make([]int, 5)
61 | for i, d := range division.Holes {
62 | respPinned[i] = d.ID
63 | }
64 | assert.Equal(t, pinned, respPinned)
65 | }
66 |
67 | func TestDeleteDivision(t *testing.T) {
68 | id := 3
69 | toID := 2
70 |
71 | hole := Hole{DivisionID: id}
72 | DB.Create(&hole)
73 | testAPI(t, "delete", "/api/divisions/"+strconv.Itoa(id), 204, Map{"to": toID})
74 | testAPI(t, "delete", "/api/divisions/"+strconv.Itoa(id), 204, Map{}) // repeat delete
75 |
76 | // deleted
77 | var d Division
78 | result := DB.First(&d, id)
79 | assert.True(t, result.Error != nil)
80 |
81 | // hole moved
82 | DB.First(&hole, hole.ID)
83 | assert.Equal(t, toID, hole.DivisionID)
84 |
85 | }
86 |
87 | func TestDeleteDivisionDefaultValue(t *testing.T) {
88 | id := 4
89 | toID := 1
90 |
91 | // if create hole here, say database lock, pending enquiry
92 | var hole, getHole Hole
93 | DB.Where("division_id = ?", id).First(&hole)
94 | testAPI(t, "delete", "/api/divisions/"+strconv.Itoa(id), 204, Map{})
95 |
96 | // hole moved
97 | DB.Take(&getHole, hole.ID)
98 | assert.Equal(t, toID, getHole.DivisionID)
99 |
100 | }
101 |
--------------------------------------------------------------------------------
/tests/favorite_test.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | "testing"
5 |
6 | . "treehole_next/models"
7 |
8 | "github.com/stretchr/testify/assert"
9 | "golang.org/x/exp/slices"
10 | )
11 |
12 | func TestListFavorites(t *testing.T) {
13 | var holes Holes
14 | testAPIModel(t, "get", "/api/user/favorites", 200, &holes)
15 | assert.EqualValues(t, 10, len(holes))
16 | }
17 |
18 | func TestAddFavorite(t *testing.T) {
19 | data := Map{"hole_id": 11}
20 | testAPI(t, "post", "/api/user/favorites", 201, data)
21 | testAPI(t, "post", "/api/user/favorites", 201, data) // duplicated, refresh updated_at
22 | }
23 |
24 | func TestModifyFavorites(t *testing.T) {
25 | data := Map{"hole_ids": []int{1, 2, 5, 6, 7}}
26 | testAPI(t, "put", "/api/user/favorites", 201, data)
27 | testAPI(t, "put", "/api/user/favorites", 201, data) // duplicated
28 | var userFavorites []UserFavorite
29 | DB.Where("user_id = ?", 1).Find(&userFavorites)
30 | assert.EqualValues(t, 5, len(userFavorites))
31 | }
32 |
33 | func TestDeleteFavorite(t *testing.T) {
34 | data := Map{"hole_id": 1}
35 | testAPI(t, "delete", "/api/user/favorites", 200, data)
36 | var userFavorites []UserFavorite
37 | DB.Where("user_id = ?", 1).Find(&userFavorites)
38 | assert.EqualValues(t, false, slices.Contains(userFavorites, UserFavorite{UserID: 1, HoleID: 1}))
39 | favouriteLen := len(userFavorites)
40 |
41 | testAPI(t, "delete", "/api/user/favorites", 200, data) // duplicated
42 | DB.Where("user_id = ?", 1).Find(&userFavorites)
43 | assert.EqualValues(t, favouriteLen, len(userFavorites))
44 | }
45 |
--------------------------------------------------------------------------------
/tests/floor_test.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | "strconv"
5 | "strings"
6 | "testing"
7 |
8 | "github.com/goccy/go-json"
9 |
10 | . "treehole_next/config"
11 | . "treehole_next/models"
12 |
13 | "github.com/stretchr/testify/assert"
14 | )
15 |
16 | func TestListFloorsInAHole(t *testing.T) {
17 | var hole Hole
18 | DB.Where("division_id = ?", 7).First(&hole)
19 | var floors Floors
20 | testAPIModel(t, "get", "/api/holes/"+strconv.Itoa(hole.ID)+"/floors", 200, &floors)
21 | assert.EqualValues(t, Config.Size, len(floors))
22 | if len(floors) != 0 {
23 | assert.EqualValues(t, "1", floors[0].Content)
24 | }
25 |
26 | // size
27 | size := 38
28 | data := Map{"size": size}
29 | testAPIModelWithQuery(t, "get", "/api/holes/"+strconv.Itoa(hole.ID)+"/floors", 200, &floors, data)
30 | assert.EqualValues(t, size, len(floors))
31 | if len(floors) != 0 {
32 | assert.EqualValues(t, "1", floors[0].Content)
33 | }
34 |
35 | // offset
36 | offset := 7
37 | data = Map{"offset": offset}
38 | testAPIModelWithQuery(t, "get", "/api/holes/"+strconv.Itoa(hole.ID)+"/floors", 200, &floors, data)
39 | assert.EqualValues(t, Config.Size, len(floors))
40 | if len(floors) != 0 {
41 | assert.EqualValues(t, strings.Repeat("1", offset+1), floors[0].Content)
42 | }
43 | }
44 |
45 | func TestListFloorsOld(t *testing.T) {
46 | var hole Hole
47 | DB.Where("division_id = ?", 7).First(&hole)
48 | data := Map{"hole_id": hole.ID}
49 | var floors Floors
50 | testAPIModelWithQuery(t, "get", "/api/floors", 200, &floors, data)
51 | assert.EqualValues(t, Config.MaxSize, len(floors))
52 | if len(floors) != 0 {
53 | assert.EqualValues(t, "1", floors[0].Content)
54 | }
55 | }
56 |
57 | func TestGetFloor(t *testing.T) {
58 | var hole Hole
59 | DB.Where("division_id = ?", 7).First(&hole)
60 | var floor Floor
61 | DB.Where("hole_id = ?", hole.ID).First(&floor)
62 | var getFloor Floor
63 | testAPIModel(t, "get", "/api/floors/"+strconv.Itoa(floor.ID), 200, &getFloor)
64 | assert.EqualValues(t, floor.Content, getFloor.Content)
65 |
66 | testAPIModel(t, "get", "/api/floors/"+strconv.Itoa(largeInt), 404, &getFloor)
67 | }
68 |
69 | func TestCreateFloor(t *testing.T) {
70 | var hole Hole
71 | DB.Where("division_id = ?", 7).Offset(1).First(&hole)
72 | content := "123"
73 | data := Map{"content": content}
74 | var getFloor Floor
75 | testAPIModel(t, "post", "/api/holes/"+strconv.Itoa(hole.ID)+"/floors", 201, &getFloor, data)
76 | assert.EqualValues(t, content, getFloor.Content)
77 |
78 | var floors Floors
79 | DB.Where("hole_id = ?", hole.ID).Find(&floors)
80 | assert.EqualValues(t, 2, len(floors))
81 |
82 | testAPIModel(t, "post", "/api/holes/"+strconv.Itoa(largeInt)+"/floors", 404, &getFloor, data)
83 | }
84 |
85 | func TestCreateFloorOld(t *testing.T) {
86 | var hole Hole
87 | DB.Where("division_id = ?", 7).Offset(2).First(&hole)
88 | content := "1234"
89 | data := Map{"hole_id": hole.ID, "content": content}
90 | type CreateOLdResponse struct {
91 | Data Floor
92 | Message string
93 | }
94 | var getFloor CreateOLdResponse
95 | rsp := testCommon(t, "post", "/api/floors", 201, data)
96 | err := json.Unmarshal(rsp, &getFloor)
97 | assert.Nilf(t, err, "Unmarshal Failed")
98 | assert.EqualValues(t, content, getFloor.Data.Content)
99 |
100 | var floors Floors
101 | DB.Where("hole_id = ?", hole.ID).Find(&floors)
102 | assert.EqualValues(t, 2, len(floors))
103 | if len(floors) != 0 {
104 | assert.EqualValues(t, content, floors[1].Content)
105 | }
106 |
107 | testCommon(t, "post", "/api/holes/"+strconv.Itoa(123456)+"/floors", 404, data)
108 | }
109 |
110 | func TestModifyFloor(t *testing.T) {
111 | var hole Hole
112 | DB.Where("division_id = ?", 7).Offset(3).First(&hole)
113 | var floor Floor
114 | DB.Where("hole_id = ?", hole.ID).First(&floor)
115 | content := "12341234"
116 | data := Map{"content": content}
117 | var getFloor Floor
118 |
119 | // modify content
120 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data)
121 |
122 | DB.Find(&getFloor, floor.ID)
123 | assert.EqualValues(t, content, getFloor.Content)
124 |
125 | // modify fold
126 | // test 1: fold == ["test"], fold_v2 == ""
127 | data = Map{"fold": []string{"test"}}
128 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data)
129 | DB.Find(&getFloor, floor.ID)
130 | assert.EqualValues(t, data["fold"].([]string)[0], getFloor.Fold)
131 |
132 | // test2: fold == [], fold_v2 == "": expect reset fold
133 | data = Map{"fold": []string{}}
134 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data)
135 | DB.Find(&getFloor, floor.ID)
136 | assert.EqualValues(t, "", getFloor.Fold)
137 |
138 | // test3: fold == [], fold_v2 == "test_test"
139 | data = Map{"fold_v2": "test_test"}
140 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data)
141 | DB.Find(&getFloor, floor.ID)
142 | assert.EqualValues(t, data["fold_v2"], getFloor.Fold)
143 |
144 | // test4: fold == [], fold_v2 == "": expect reset fold
145 | data = Map{"fold": []string{}}
146 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data)
147 | DB.Find(&getFloor, floor.ID)
148 | assert.EqualValues(t, "", getFloor.Fold)
149 |
150 | // test5: fold == ["test"], fold_v2 == "test_test": expect "test_test", fold_v2 has the priority
151 | data = Map{"fold": []string{"test"}, "fold_v2": "test_test"}
152 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data)
153 | DB.Find(&getFloor, floor.ID)
154 | assert.EqualValues(t, "test_test", getFloor.Fold)
155 |
156 | // test6: fold == nil, fold_v2 == "": do nothing; 无效请求
157 | data = Map{}
158 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 400, data)
159 | DB.Find(&getFloor, floor.ID)
160 | assert.EqualValues(t, "test_test", getFloor.Fold)
161 |
162 | // modify like add old
163 | data = Map{"like": "add"}
164 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data)
165 | DB.Find(&getFloor, floor.ID)
166 | assert.EqualValues(t, 1, getFloor.Like)
167 |
168 | // modify like reset old
169 | data = Map{"like": "cancel"}
170 | testAPI(t, "put", "/api/floors/"+strconv.Itoa(floor.ID), 200, data)
171 | DB.Find(&getFloor, floor.ID)
172 | assert.EqualValues(t, 0, getFloor.Like)
173 | }
174 |
175 | func TestModifyFloorLike(t *testing.T) {
176 | var hole Hole
177 | DB.Where("division_id = ?", 7).Offset(4).First(&hole)
178 | var floor Floor
179 | DB.Where("hole_id = ?", hole.ID).First(&floor)
180 |
181 | // like
182 | for i := 0; i < 10; i++ {
183 | testAPI(t, "post", "/api/floors/"+strconv.Itoa(floor.ID)+"/like/1", 200)
184 | }
185 | DB.First(&floor, floor.ID)
186 | assert.EqualValues(t, 1, floor.Like)
187 | assert.EqualValues(t, 0, floor.Dislike)
188 |
189 | // dislike
190 | for i := 0; i < 15; i++ {
191 | testAPI(t, "post", "/api/floors/"+strconv.Itoa(floor.ID)+"/like/-1", 200)
192 | }
193 | DB.First(&floor, floor.ID)
194 | assert.EqualValues(t, 0, floor.Like)
195 | assert.EqualValues(t, 1, floor.Dislike)
196 |
197 | // reset
198 | testAPI(t, "post", "/api/floors/"+strconv.Itoa(floor.ID)+"/like/0", 200)
199 | DB.First(&floor, floor.ID)
200 | assert.EqualValues(t, 0, floor.Like)
201 | }
202 |
203 | func TestDeleteFloor(t *testing.T) {
204 | var hole Hole
205 | DB.Where("division_id = ?", 7).Offset(5).First(&hole)
206 | var floor Floor
207 | DB.Where("hole_id = ?", hole.ID).First(&floor)
208 | content := "1234567"
209 | data := Map{"delete_reason": content}
210 |
211 | testAPI(t, "delete", "/api/floors/"+strconv.Itoa(floor.ID), 200, data)
212 |
213 | DB.First(&floor, floor.ID)
214 | assert.EqualValues(t, true, floor.Deleted)
215 | var floorHistory FloorHistory
216 | DB.Where("floor_id = ?", floor.ID).First(&floorHistory)
217 | assert.EqualValues(t, content, floorHistory.Reason)
218 |
219 | // permission
220 | floor = Floor{}
221 | DB.Where("hole_id = ?", hole.ID).Offset(1).First(&floor)
222 | testAPI(t, "delete", "/api/floors/"+strconv.Itoa(floor.ID), 200, data)
223 | }
224 |
--------------------------------------------------------------------------------
/tests/hole_test.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | "strconv"
5 | "strings"
6 | "testing"
7 |
8 | . "treehole_next/config"
9 | . "treehole_next/models"
10 | "treehole_next/utils"
11 |
12 | "github.com/stretchr/testify/assert"
13 | )
14 |
15 | func TestListHoleInADivision(t *testing.T) {
16 | var holes Holes
17 | var ids []int
18 |
19 | DB.Raw("SELECT id FROM hole WHERE division_id = 6 AND hidden = 0 ORDER BY updated_at DESC").Scan(&ids)
20 |
21 | testAPIModel(t, "get", "/api/divisions/6/holes", 200, &holes)
22 | assert.Equal(t, ids[:Config.HoleFloorSize], utils.Models2IDSlice(holes))
23 |
24 | testAPIModel(t, "get", "/api/divisions/"+strconv.Itoa(largeInt)+"/holes", 200, &holes) // return empty holes
25 | testAPI(t, "get", "/api/divisions/"+strings.Repeat(strconv.Itoa(largeInt), 15)+"/holes", 500) // huge divisionID
26 | }
27 |
28 | func TestListHolesByTag(t *testing.T) {
29 | var tag Tag
30 | DB.Where("name = ?", "114").First(&tag)
31 | var holes Holes
32 | err := DB.Model(&tag).Association("Holes").Find(&holes)
33 | if err != nil {
34 | t.Fatal(err)
35 | }
36 |
37 | var getHoles Holes
38 | testAPIModel(t, "get", "/api/tags/114/holes", 200, &getHoles)
39 | assert.EqualValues(t, len(holes), len(getHoles))
40 |
41 | // empty holes
42 | testAPIModel(t, "get", "/api/tags/115/holes", 200, &getHoles)
43 | assert.EqualValues(t, Holes{}, getHoles)
44 | }
45 |
46 | func TestCreateHole(t *testing.T) {
47 | content := "abcdef"
48 | data := Map{"content": content, "tags": []Map{{"name": "a"}, {"name": "ab"}, {"name": "abc"}}}
49 | testAPI(t, "post", "/api/divisions/1/holes", 201, data)
50 | data["tags"] = []Map{{"name": "abcd"}, {"name": "ab"}, {"name": "abc"}} // update temperature or create tag
51 | testAPI(t, "post", "/api/divisions/1/holes", 201, data)
52 |
53 | tag := Tag{}
54 | DB.Where("name = ?", "a").First(&tag)
55 | assert.EqualValues(t, 1, tag.Temperature)
56 | tag = Tag{}
57 | DB.Where("name = ?", "abc").First(&tag)
58 | assert.EqualValues(t, 2, tag.Temperature)
59 | assert.EqualValues(t, 2, DB.Model(&tag).Association("Holes").Count())
60 |
61 | data = Map{"content": content, "tags": []Map{}}
62 | testAPI(t, "post", "/api/divisions/1/holes", 400, data) // at least one tag
63 |
64 | content = strings.Repeat("~", 15001)
65 | data = Map{"content": content, "tags": []Map{{"name": "a"}, {"name": "ab"}, {"name": "abc"}}}
66 | testAPI(t, "post", "/api/divisions/1/holes", 400, data) // data no more than 10000
67 |
68 | tags := make([]Map, 11)
69 | for i := range tags {
70 | tags[i] = Map{"name": strconv.Itoa(i)}
71 | }
72 | data = Map{"content": "123456789", "tags": tags} // at most 10 tags
73 | testAPI(t, "post", "/api/divisions/1/holes", 400, data)
74 | }
75 |
76 | func TestCreateHoleOld(t *testing.T) {
77 | content := "abcdef"
78 | tagName := []Map{{"name": "d"}, {"name": "de"}, {"name": "def"}}
79 | division_id := 1
80 | data := Map{"content": content, "tags": tagName, "division_id": division_id}
81 | testAPI(t, "post", "/api/holes", 201, data)
82 | tagName = []Map{{"name": "abc"}, {"name": "defg"}, {"name": "de"}}
83 | data = Map{"content": content, "tags": tagName, "division_id": division_id}
84 | testAPI(t, "post", "/api/holes", 201, data)
85 |
86 | var holes Holes
87 | var tag Tag
88 | DB.Where("name = ?", "def").First(&tag)
89 | err := DB.Model(&tag).Association("Holes").Find(&holes)
90 | if err != nil {
91 | t.Fatal(err)
92 | }
93 | }
94 |
95 | func TestModifyHole(t *testing.T) {
96 | var tag Tag
97 | DB.Where("name = ?", "111").First(&tag)
98 | var holes Holes
99 | err := DB.Model(&tag).Association("Holes").Find(&holes)
100 | if err != nil {
101 | t.Fatal(err)
102 | }
103 |
104 | tagName := []Map{{"name": "111"}, {"name": "d"}, {"name": "de"}, {"name": "def"}}
105 | division_id := 5
106 | data := Map{"tags": tagName, "division_id": division_id}
107 | testAPI(t, "put", "/api/holes/"+strconv.Itoa(holes[0].ID), 200, data)
108 |
109 | DB.Preload("Tags").Where("id = ?", holes[0].ID).Find(&holes[0])
110 |
111 | var getTagName []Map
112 | for _, v := range holes[0].Tags {
113 | getTagName = append(getTagName, Map{"name": v.Name})
114 | }
115 | assert.EqualValues(t, tagName, getTagName)
116 | assert.EqualValues(t, division_id, holes[0].DivisionID)
117 |
118 | // default schemas
119 | testAPI(t, "put", "/api/holes/"+strconv.Itoa(holes[0].ID), 400, Map{}) // bad request if modify nothing
120 | DB.Where("id = ?", holes[0].ID).Find(&holes[0])
121 | assert.Equal(t, division_id, holes[0].DivisionID)
122 | }
123 |
124 | func TestDeleteHole(t *testing.T) {
125 | var hole Hole
126 | holeID := 10
127 | testAPI(t, "delete", "/api/holes/"+strconv.Itoa(holeID), 204)
128 | testAPI(t, "delete", "/api/holes/"+strconv.Itoa(largeInt), 404)
129 | DB.Where("id = ?", 10).Find(&hole)
130 | assert.Equal(t, true, hole.Hidden)
131 | }
132 |
--------------------------------------------------------------------------------
/tests/init.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | "strconv"
5 | "strings"
6 |
7 | "github.com/rs/zerolog/log"
8 |
9 | "treehole_next/config"
10 | . "treehole_next/models"
11 | )
12 |
13 | func init() {
14 | initTestDivision()
15 | initTestHoles()
16 | initTestFloors()
17 | initTestTags()
18 | initTestFavorites()
19 | initTestReports()
20 |
21 | config.Config.OpenSensitiveCheck = false
22 | }
23 |
24 | func initTestDivision() {
25 | divisions := make(Divisions, 10)
26 | for i := range divisions {
27 | divisions[i] = &Division{
28 | ID: i + 1,
29 | Name: strconv.Itoa(i),
30 | Description: strconv.Itoa(i),
31 | }
32 | }
33 | holes := make(Holes, 10)
34 | for i := range holes {
35 | holes[i] = &Hole{
36 | DivisionID: 1,
37 | }
38 | }
39 | holes[9].DivisionID = 4 // for TestDeleteDivisionDefaultValue
40 | err := DB.Create(&divisions).Error
41 | if err != nil {
42 | log.Fatal().Err(err).Send()
43 | }
44 | err = DB.Create(&holes).Error
45 | if err != nil {
46 | log.Fatal().Err(err).Send()
47 | }
48 | }
49 |
50 | func initTestHoles() {
51 | holes := make(Holes, 10)
52 | for i := range holes {
53 | holes[i] = &Hole{
54 | DivisionID: 6,
55 | }
56 | }
57 | tag := Tag{Name: "114", Temperature: 15}
58 | holes[1].Tags = Tags{&tag}
59 | holes[2].Tags = Tags{&tag} // here it will insert twice in latest version gorm
60 | holes[3].Tags = Tags{{Name: "111", Temperature: 23}, {Name: "222", Temperature: 45}}
61 | err := DB.Create(&holes).Error
62 | if err != nil {
63 | log.Fatal().Err(err).Send()
64 | }
65 | tag = Tag{Name: "115"}
66 | err = DB.Create(&tag).Error
67 | if err != nil {
68 | log.Fatal().Err(err).Send()
69 | }
70 | }
71 |
72 | func initTestFloors() {
73 | holes := make(Holes, 10)
74 | for i := range holes {
75 | holes[i] = &Hole{
76 | DivisionID: 7,
77 | }
78 | }
79 | for i := 1; i <= 50; i++ {
80 | holes[0].Floors = append(holes[0].Floors, &Floor{Content: strings.Repeat("1", i), Ranking: i - 1})
81 | }
82 | holes[0].Floors[10].Mention = Floors{
83 | {HoleID: 102},
84 | {HoleID: 304},
85 | }
86 | holes[0].Floors[11].Mention = Floors{
87 | {HoleID: 506},
88 | {HoleID: 708},
89 | }
90 | holes[1].Floors = Floors{{Content: "123456789"}} // for TestCreate
91 | holes[2].Floors = Floors{{Content: "123456789"}} // for TestCreate
92 | holes[3].Floors = Floors{{Content: "123456789"}} // for TestModify
93 | holes[4].Floors = Floors{{Content: "123456789"}} // for TestModify like
94 | holes[5].Floors = Floors{{Content: "123456789", UserID: 1}, {Content: "23333", UserID: 5, Ranking: 1}} // for TestDelete
95 | err := DB.Create(&holes).Error
96 | if err != nil {
97 | log.Fatal().Err(err).Send()
98 | }
99 | }
100 |
101 | func initTestTags() {
102 | holes := make(Holes, 5)
103 | tags := make(Tags, 6)
104 | hole_tags := [][]int{
105 | {0, 1, 2},
106 | {3},
107 | {0, 4},
108 | {1, 0, 2},
109 | {2, 3, 4},
110 | {0, 4},
111 | } // int[tag_id][hole_id]
112 |
113 | for i := range holes {
114 | holes[i] = &Hole{DivisionID: 8}
115 | }
116 |
117 | for i := range tags {
118 | tags[i] = &Tag{Name: strconv.Itoa(i + 1)}
119 | for _, v := range hole_tags[i] {
120 | tags[i].Holes = append(tags[i].Holes, holes[v])
121 | }
122 | }
123 |
124 | tags[0].Temperature = 5
125 | tags[2].Temperature = 25
126 | tags[5].Temperature = 34
127 | err := DB.Create(&tags).Error
128 | if err != nil {
129 | log.Fatal().Err(err).Send()
130 | }
131 | }
132 |
133 | func initTestFavorites() {
134 | favoriteGroup := FavoriteGroup{Name: "test", UserID: 1, FavoriteGroupID: 0}
135 | err := DB.Create(&favoriteGroup).Error
136 | if err != nil {
137 | log.Fatal().Err(err).Send()
138 | }
139 | userFavorites := make([]UserFavorite, 10)
140 | for i := range userFavorites {
141 | userFavorites[i].HoleID = i + 1
142 | userFavorites[i].UserID = 1
143 | }
144 | err = DB.Create(&userFavorites).Error
145 | if err != nil {
146 | log.Fatal().Err(err).Send()
147 | }
148 | }
149 |
150 | const (
151 | REPORT_BASE_ID = 1
152 | REPORT_FLOOR_BASE_ID = 1001
153 | )
154 |
155 | func initTestReports() {
156 | hole := Hole{ID: 1000}
157 | floors := make(Floors, 20)
158 | for i := range floors {
159 | floors[i] = &Floor{
160 | ID: REPORT_FLOOR_BASE_ID + i,
161 | HoleID: 1000,
162 | Ranking: i,
163 | UserID: 1,
164 | }
165 | }
166 | reports := make([]Report, 10)
167 | for i := range reports {
168 | reports[i].ID = REPORT_BASE_ID + i
169 | reports[i].FloorID = REPORT_FLOOR_BASE_ID + i
170 | reports[i].UserID = 1
171 | if i < 5 {
172 | reports[i].Dealt = true
173 | }
174 | }
175 |
176 | err := DB.Create(&hole).Error
177 | if err != nil {
178 | log.Fatal().Err(err).Send()
179 | }
180 | err = DB.Create(&floors).Error
181 | if err != nil {
182 | log.Fatal().Err(err).Send()
183 | }
184 | err = DB.Create(&reports).Error
185 | if err != nil {
186 | log.Fatal().Err(err).Send()
187 | }
188 | }
189 |
--------------------------------------------------------------------------------
/tests/report_test.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | "strconv"
5 | "testing"
6 |
7 | "github.com/rs/zerolog/log"
8 |
9 | . "treehole_next/models"
10 |
11 | "github.com/stretchr/testify/assert"
12 | )
13 |
14 | func TestGetReport(t *testing.T) {
15 | reportID := REPORT_BASE_ID
16 | var report Report
17 | DB.First(&report, reportID)
18 | log.Info().Any("report", report).Send()
19 |
20 | var getReport Report
21 | testAPIModel(t, "get", "/api/reports/"+strconv.Itoa(reportID), 200, &getReport)
22 | assert.EqualValues(t, report.FloorID, getReport.FloorID)
23 | assert.EqualValues(t, report.FloorID, getReport.Floor.FloorID)
24 | }
25 |
26 | func TestListReport(t *testing.T) {
27 | data := Map{}
28 |
29 | var getReports Reports
30 | testAPIModelWithQuery(t, "get", "/api/reports", 200, &getReports, data)
31 | log.Printf("getReports: %+v\n", getReports)
32 |
33 | data = Map{"range": 1}
34 | testAPIModelWithQuery(t, "get", "/api/reports", 200, &getReports, data)
35 | log.Printf("getReports: %+v\n", getReports)
36 |
37 | data = Map{"range": 2}
38 | testAPIModelWithQuery(t, "get", "/api/reports", 200, &getReports, data)
39 | log.Printf("getReports: %+v\n", getReports)
40 | }
41 |
42 | func TestAddReport(t *testing.T) {
43 | data := Map{"floor_id": REPORT_FLOOR_BASE_ID + 14, "reason": "123456789"}
44 |
45 | testAPI(t, "post", "/api/reports", 204, data)
46 | }
47 |
48 | func TestDeleteReport(t *testing.T) {
49 | reportID := REPORT_BASE_ID + 7
50 | var getReport Report
51 | data := Map{"result": "123456789"}
52 | testAPI(t, "delete", "/api/reports/"+strconv.Itoa(reportID), 200, data)
53 |
54 | DB.First(&getReport, reportID)
55 | assert.EqualValues(t, true, getReport.Dealt)
56 | }
57 |
--------------------------------------------------------------------------------
/tests/tag_test.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | "strconv"
5 | "testing"
6 |
7 | . "treehole_next/models"
8 |
9 | "github.com/stretchr/testify/assert"
10 | )
11 |
12 | func init() {
13 |
14 | }
15 |
16 | func TestListTag(t *testing.T) {
17 | var length int64
18 | DB.Table("tag").Count(&length)
19 | resp := testAPIArray(t, "get", "/api/tags", 200)
20 | assert.Equal(t, length, int64(len(resp)))
21 | }
22 |
23 | func TestGetTag(t *testing.T) {
24 | id := 3
25 |
26 | var tag Tag
27 | DB.First(&tag, id)
28 |
29 | var newTag Tag
30 | testAPIModel(t, "get", "/api/tags/"+strconv.Itoa(id), 200, &newTag)
31 | assert.Equalf(t, tag.Name, newTag.Name, "get tag")
32 | }
33 |
34 | func TestCreateTag(t *testing.T) {
35 | data := Map{"name": "name"}
36 | testAPI(t, "post", "/api/tags", 201, data)
37 |
38 | // duplicate post, return 200 and change nothing
39 | testAPI(t, "post", "/api/tags", 200, data)
40 | }
41 |
42 | func TestModifyTag(t *testing.T) {
43 | id := 3
44 | data := Map{"name": "another", "temperature": 34}
45 |
46 | testAPI(t, "put", "/api/tags/"+strconv.Itoa(id), 200, data)
47 |
48 | var tag Tag
49 | DB.Model(&Tag{}).First(&tag, 3)
50 | assert.Equalf(t, "another", tag.Name, "modify tag name")
51 | assert.Equalf(t, 34, tag.Temperature, "modify tag tempeture")
52 | }
53 |
54 | func TestDeleteTag(t *testing.T) {
55 |
56 | // Move holes to existed tag
57 | fromName := "1"
58 | toName := "6"
59 | var id int
60 | DB.Model(Tag{}).Where("name = ?", fromName).Pluck("id", &id)
61 | data := Map{"to": toName}
62 | testAPI(t, "delete", "/api/tags/"+strconv.Itoa(id), 200, data)
63 | var tag Tag
64 | DB.Where("name = ?", toName).First(&tag)
65 | associationHolesLen := DB.Model(&tag).Association("Holes").Count()
66 | assert.EqualValuesf(t, 4, associationHolesLen, "move holes")
67 | assert.EqualValuesf(t, 39, tag.Temperature, "tag Temperature add")
68 | tag = Tag{}
69 |
70 | if result := DB.First(&tag, id); result.Error == nil {
71 | assert.Error(t, result.Error, "delete tags")
72 | }
73 |
74 | // Duplicated delete holes
75 | testAPI(t, "delete", "/api/tags/"+strconv.Itoa(id), 404, data)
76 |
77 | // Move holes to new tag
78 | id = 8
79 | data["to"] = "iii555"
80 | testAPI(t, "delete", "/api/tags/"+strconv.Itoa(id), 404, data)
81 | }
82 |
--------------------------------------------------------------------------------
/tests/utils.go:
--------------------------------------------------------------------------------
1 | package tests
2 |
3 | import (
4 | "bytes"
5 | "io"
6 | "net/http"
7 | "strings"
8 | "testing"
9 |
10 | "github.com/goccy/go-json"
11 |
12 | "treehole_next/bootstrap"
13 | . "treehole_next/models"
14 |
15 | "github.com/hetiansu5/urlquery"
16 | "github.com/stretchr/testify/assert"
17 | )
18 |
19 | type JsonData interface {
20 | Map | []Map
21 | }
22 |
23 | var App, _ = bootstrap.Init()
24 |
25 | // testCommon tests status code and returns response body in bytes
26 | func testCommon(t *testing.T, method string, route string, statusCode int, data ...Map) []byte {
27 | var requestData []byte
28 | var err error
29 |
30 | if len(data) > 0 && data[0] != nil { // data[0] is request data
31 | requestData, err = json.Marshal(data[0])
32 | assert.Nilf(t, err, "encode request body")
33 | }
34 | req, err := http.NewRequest(
35 | strings.ToUpper(method),
36 | route,
37 | bytes.NewBuffer(requestData),
38 | )
39 | req.Header.Add("Content-Type", "application/json")
40 | req.Header.Add("X-Consumer-Username", "1") // for common.GetUserID
41 | assert.Nilf(t, err, "constructs http request")
42 |
43 | res, err := App.Test(req, -1)
44 | assert.Nilf(t, err, "perform request")
45 | assert.Equalf(t, statusCode, res.StatusCode, "status code")
46 |
47 | responseBody, err := io.ReadAll(res.Body)
48 | assert.Nilf(t, err, "decode response")
49 |
50 | return responseBody
51 | }
52 |
53 | // testCommonQuery tests status code and returns response body in bytes
54 | func testCommonQuery(t *testing.T, method string, route string, statusCode int, data ...Map) []byte {
55 | var err error
56 | req, err := http.NewRequest(
57 | strings.ToUpper(method),
58 | route,
59 | nil,
60 | )
61 | if len(data) > 0 && data[0] != nil { // data[0] is query data
62 | queryData, err := urlquery.Marshal(data[0])
63 | req.URL.RawQuery = string(queryData)
64 | assert.Nilf(t, err, "encode request body")
65 | }
66 |
67 | req.Header.Add("Content-Type", "application/json")
68 | req.Header.Add("X-Consumer-Username", "1") // for common.GetUserID
69 | assert.Nilf(t, err, "constructs http request")
70 |
71 | res, err := App.Test(req, -1)
72 | assert.Nilf(t, err, "perform request")
73 | assert.Equalf(t, statusCode, res.StatusCode, "status code")
74 |
75 | responseBody, err := io.ReadAll(res.Body)
76 | assert.Nilf(t, err, "decode response")
77 |
78 | return responseBody
79 | }
80 |
81 | // testAPIGeneric inherits testCommon, decodes response body to json, tests whether it's expected
82 | func testAPIGeneric[T JsonData](t *testing.T, method string, route string, statusCode int, data ...Map) T {
83 | responseBody := testCommon(t, method, route, statusCode, data...)
84 |
85 | if statusCode == 204 || statusCode == 302 { // no content and redirect
86 | return nil
87 | }
88 | var responseData T
89 | err := json.Unmarshal(responseBody, &responseData)
90 | assert.Nilf(t, err, "decode response")
91 |
92 | if len(data) > 1 { // data[1] is response data
93 | assert.EqualValuesf(t, data[1], responseData, "response data")
94 | }
95 |
96 | return responseData
97 | }
98 |
99 | // testAPI returns a Map
100 | func testAPI(t *testing.T, method string, route string, statusCode int, data ...Map) Map {
101 | return testAPIGeneric[Map](t, method, route, statusCode, data...)
102 | }
103 |
104 | // testAPIArray returns []Map
105 | func testAPIArray(t *testing.T, method string, route string, statusCode int, data ...Map) []Map {
106 | return testAPIGeneric[[]Map](t, method, route, statusCode, data...)
107 | }
108 |
109 | func testAPIModel[T Models](t *testing.T, method string, route string, statusCode int, obj *T, data ...Map) {
110 | responseBytes := testCommon(t, method, route, statusCode, data...)
111 | err := json.Unmarshal(responseBytes, obj)
112 | assert.Nilf(t, err, "unmarshal response")
113 | }
114 |
115 | func testAPIModelWithQuery[T Models](t *testing.T, method string, route string, statusCode int, obj *T, data ...Map) {
116 | responseBytes := testCommonQuery(t, method, route, statusCode, data...)
117 | err := json.Unmarshal(responseBytes, obj)
118 | assert.Nilf(t, err, "unmarshal response")
119 | }
120 |
--------------------------------------------------------------------------------
/utils/bot.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "bytes"
5 | "encoding/json"
6 | "fmt"
7 | "io"
8 | "net/http"
9 | "treehole_next/config"
10 | )
11 |
12 | type BotMessageType string
13 |
14 | const (
15 | MessageTypeGroup BotMessageType = "group"
16 | MessageTypePrivate BotMessageType = "private"
17 | )
18 |
19 | type BotMessage struct {
20 | MessageType BotMessageType `json:"message_type"`
21 | GroupID *int64 `json:"group_id"`
22 | UserID *int64 `json:"user_id"`
23 | Message string `json:"message"`
24 | AutoEscape bool `json:"auto_escape default:false"`
25 | }
26 |
27 | type FeishuMessage struct {
28 | MsgType string `json:"msg_type"`
29 | Content string `json:"message"`
30 | }
31 |
32 | func NotifyFeishu(feishuMessage *FeishuMessage) {
33 | if feishuMessage == nil || feishuMessage.MsgType == "" {
34 | return
35 | }
36 | if config.Config.FeishuBotUrl == nil {
37 | return
38 | }
39 | url := *config.Config.FeishuBotUrl
40 |
41 | jsonData, err := json.Marshal(feishuMessage)
42 | if err != nil {
43 | RequestLog("Error marshaling JSON", "NotifyFeishu", 0, false)
44 | return
45 | }
46 |
47 | RequestLog(fmt.Sprintf("Request: %s", string(jsonData)), "NotifyFeishu", 0, false)
48 |
49 | resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
50 | if err != nil {
51 | RequestLog("Error creating request", "NotifyFeishu", 0, false)
52 | return
53 | }
54 |
55 | defer resp.Body.Close()
56 | if resp.StatusCode != 200 {
57 | response, err := io.ReadAll(resp.Body)
58 | if err != nil {
59 | RequestLog("Error Unmarshaling response", "NotifyFeishu", 0, false)
60 | }
61 | RequestLog(fmt.Sprintf("Error sending request %s", string(response)), "NotifyFeishu", 0, false)
62 | }
63 | }
64 |
65 | func NotifyQQ(botMessage *BotMessage) {
66 | if botMessage == nil {
67 | return
68 | }
69 | if botMessage.MessageType == MessageTypeGroup && botMessage.GroupID == nil {
70 | return
71 | }
72 | if botMessage.MessageType == MessageTypePrivate && botMessage.UserID == nil {
73 | return
74 | }
75 | if config.Config.QQBotUrl == nil {
76 | return
77 | }
78 | // "[CQ:face,id=199]test[CQ:image,file=https://ts1.cn.mm.bing.net/th?id=OIP-C.K5AFHsGlWeLUzKjXGXxdQgHaFj&w=224&h=150&c=8&rs=1&qlt=90&o=6&dpr=1.5&pid=3.1&rm=2]",
79 | url := *config.Config.QQBotUrl + "/send_msg"
80 |
81 | jsonData, err := json.Marshal(botMessage)
82 | if err != nil {
83 | RequestLog("Error marshaling JSON", "NotifyQQ", 0, false)
84 | return
85 | }
86 |
87 | RequestLog(fmt.Sprintf("Request: %s", string(jsonData)), "NotifyQQ", 0, false)
88 |
89 | resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
90 | if err != nil {
91 | RequestLog("Error creating request", "NotifyQQ", 0, false)
92 | return
93 | }
94 |
95 | defer resp.Body.Close()
96 | if resp.StatusCode != 200 {
97 | response, err := io.ReadAll(resp.Body)
98 | if err != nil {
99 | RequestLog("Error Unmarshaling response", "NotifyQQ", 0, false)
100 | }
101 | RequestLog(fmt.Sprintf("Error sending request %s", string(response)), "NotifyQQ", 0, false)
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
/utils/cache.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "context"
5 | "time"
6 |
7 | "github.com/eko/gocache/lib/v4/cache"
8 | "github.com/eko/gocache/lib/v4/store"
9 | gocache_store "github.com/eko/gocache/store/go_cache/v4"
10 | redis_store "github.com/eko/gocache/store/redis/v4"
11 | "github.com/goccy/go-json"
12 | gocache "github.com/patrickmn/go-cache"
13 | "github.com/redis/go-redis/v9"
14 |
15 | "treehole_next/config"
16 | )
17 |
18 | var Cache *cache.Cache[[]byte]
19 |
20 | func InitCache() {
21 | if config.Config.RedisURL != "" {
22 | redisStore := redis_store.NewRedis(redis.NewClient(&redis.Options{
23 | Addr: config.Config.RedisURL,
24 | }))
25 | Cache = cache.New[[]byte](redisStore)
26 | } else {
27 | gocacheStore := gocache_store.NewGoCache(gocache.New(5*time.Minute, 10*time.Minute))
28 | Cache = cache.New[[]byte](gocacheStore)
29 | }
30 | }
31 |
32 | const maxDuration time.Duration = 1<<63 - 1
33 |
34 | func SetCache(key string, value any, expiration time.Duration) error {
35 | data, err := json.Marshal(value)
36 | if err != nil {
37 | return err
38 | }
39 | if expiration == 0 {
40 | expiration = maxDuration
41 | }
42 | return Cache.Set(context.Background(), key, data, store.WithExpiration(expiration))
43 | }
44 |
45 | func GetCache(key string, value any) bool {
46 | data, err := Cache.Get(context.Background(), key)
47 | if err != nil {
48 | return false
49 | }
50 | err = json.Unmarshal(data, value)
51 | return err == nil
52 | }
53 |
54 | func DeleteCache(key string) error {
55 | err := Cache.Delete(context.Background(), key)
56 | if err == nil {
57 | return nil
58 | }
59 | if err.Error() == "Entry not found" {
60 | return nil
61 | }
62 | return err
63 | }
64 |
--------------------------------------------------------------------------------
/utils/errors.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | const (
4 | ErrCodeNotAnsweredQuestions = iota + 403001
5 | )
6 |
--------------------------------------------------------------------------------
/utils/log.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "github.com/rs/zerolog/log"
5 | )
6 |
7 | type Role string
8 |
9 | const (
10 | RoleOwner = "owner"
11 | RoleAdmin = "admin"
12 | RoleOperator = "operator"
13 | )
14 |
15 | func MyLog(model string, action string, objectID, userID int, role Role, msg ...string) {
16 | message := ""
17 | for _, v := range msg {
18 | message += v
19 | }
20 | log.Info().
21 | Str("model", model).
22 | Int("user_id", userID).
23 | Int("object_id", objectID).
24 | Str("action", action).
25 | Str("role", string(role)).
26 | Msg(message)
27 | }
28 |
29 | func RequestLog(msg string, TypeName string, Id int64, ans bool) {
30 | log.Info().Str("TypeName", TypeName).
31 | Int64("Id", Id).
32 | Bool("CheckAnswer", ans).
33 | Msg(msg)
34 | }
35 |
--------------------------------------------------------------------------------
/utils/model.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | type IDModel[T any] interface {
4 | *T
5 | GetID() int
6 | }
7 |
8 | func binarySearch[T any, PT IDModel[T]](models []PT, targetID int) int {
9 | left := 0
10 | right := len(models)
11 | for left < right {
12 | mid := left + (right-left)>>1
13 | if models[mid].GetID() < targetID {
14 | left = mid + 1
15 | } else if models[mid].GetID() > targetID {
16 | right = mid
17 | } else {
18 | return mid
19 | }
20 | }
21 | return -1
22 | }
23 |
24 | func OrderInGivenOrder[T any, PT IDModel[T]](models []PT, order []int) (result []PT) {
25 | for _, i := range order {
26 | index := binarySearch(models, i)
27 | if index >= 0 {
28 | result = append(result, models[index])
29 | }
30 | }
31 | return result
32 | }
33 |
34 | func Models2IDSlice[T any, PT IDModel[T]](models []PT) (result []int) {
35 | result = make([]int, len(models))
36 | for i := range models {
37 | result[i] = models[i].GetID()
38 | }
39 | return result
40 | }
41 |
--------------------------------------------------------------------------------
/utils/name.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "encoding/base64"
5 | "encoding/binary"
6 | "math/rand"
7 | "sort"
8 | "time"
9 |
10 | "github.com/goccy/go-json"
11 | "github.com/rs/zerolog/log"
12 |
13 | "treehole_next/config"
14 | "treehole_next/data"
15 |
16 | "golang.org/x/exp/slices"
17 | )
18 |
19 | var names []string
20 | var length int
21 |
22 | const (
23 | charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
24 | randomCodeLength = 6
25 | )
26 |
27 | func init() {
28 | err := json.Unmarshal(data.NamesFile, &names)
29 | if err != nil {
30 | log.Fatal().Err(err).Send()
31 | }
32 | sort.Strings(names)
33 | length = len(names)
34 | }
35 |
36 | func inArray(target string, array []string) bool {
37 | _, in := slices.BinarySearch(array, target)
38 | return in
39 | }
40 |
41 | func timeStampBase64() string {
42 | bytes := make([]byte, 8)
43 | binary.LittleEndian.PutUint64(bytes, uint64(time.Now().Unix()))
44 | return base64.StdEncoding.EncodeToString(bytes)
45 | }
46 |
47 | func generateRandomCode() string {
48 | code := make([]byte, randomCodeLength)
49 | charsetLength := len(charset)
50 |
51 | for i := 0; i < randomCodeLength; i++ {
52 | n := rand.Intn(charsetLength)
53 | code[i] = charset[n]
54 | }
55 |
56 | return string(code)
57 | }
58 |
59 | func NewRandName() string {
60 | return names[rand.Intn(length)]
61 | }
62 |
63 | func GenerateName(compareList []string) string {
64 | if len(compareList) < length>>3 {
65 | for {
66 | name := NewRandName()
67 | if !inArray(name, compareList) {
68 | return name
69 | }
70 | }
71 | } else if len(compareList) < length {
72 | var j, k int
73 | list := make([]string, length)
74 | for i := 0; i < length; i++ {
75 | if j < len(compareList) && names[i] == compareList[j] {
76 | j++
77 | } else {
78 | list[k] = names[i]
79 | k++
80 | }
81 | }
82 | return list[rand.Intn(k)]
83 | } else {
84 | for {
85 | // name := names[rand.Intn(length)] + "_" + timeStampBase64()
86 | name := names[rand.Intn(length)] + "_" + generateRandomCode()
87 | if !inArray(name, compareList) {
88 | return name
89 | }
90 | }
91 | }
92 | }
93 |
94 | func GetFuzzName(name string) string {
95 | if !config.Config.OpenFuzzName {
96 | return name
97 | }
98 | if fuzzName, ok := data.NamesMapping[name]; ok {
99 | return fuzzName
100 | } else {
101 | return name
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
/utils/sensitive/utils.go:
--------------------------------------------------------------------------------
1 | package sensitive
2 |
3 | import (
4 | "errors"
5 | "golang.org/x/exp/slices"
6 | "mvdan.cc/xurls/v2"
7 | "net/url"
8 | "regexp"
9 | "strings"
10 | "treehole_next/config"
11 | )
12 |
13 | var imageRegex = regexp.MustCompile(
14 | `!\[(.*?)]\(([^" )]*?)\s*(".*?")?\)`,
15 | )
16 |
17 | var (
18 | ErrUrlParsing = errors.New("error parsing url")
19 | ErrInvalidImageHost = errors.New("不允许使用外部图片链接")
20 | ErrImageLinkTextOnly = errors.New("image link only contains text")
21 | )
22 |
23 | // findImagesInMarkdown 从Markdown文本中查找所有图片链接,检查图片链接是否合法,并且返回清除链接之后的文本
24 | func findImagesInMarkdownContent(content string) (imageUrls []string, clearContent string, err error) {
25 | err = nil
26 | clearContent = imageRegex.ReplaceAllStringFunc(content, func(s string) string {
27 | if err != nil {
28 | return ""
29 | }
30 | submatch := imageRegex.FindStringSubmatch(s)
31 | altText := submatch[1]
32 |
33 | var imageUrl string
34 | if len(submatch) > 2 && submatch[2] != "" {
35 | imageUrl = submatch[2]
36 | innerErr := checkValidUrl(imageUrl)
37 | if innerErr != nil {
38 | if errors.Is(innerErr, ErrInvalidImageHost) {
39 | err = innerErr
40 | return ""
41 | }
42 | // if the url is not valid, treat as text only
43 | } else {
44 | // append only valid image url
45 | imageUrls = append(imageUrls, imageUrl)
46 | imageUrl = ""
47 | }
48 | }
49 |
50 | var title string
51 | if len(submatch) > 3 && submatch[3] != "" {
52 | title = strings.Trim(submatch[3], "\"")
53 | }
54 |
55 | var ret strings.Builder
56 | if altText != "" {
57 | ret.WriteString(altText)
58 | }
59 | if imageUrl != "" {
60 | if ret.String() != "" {
61 | ret.WriteString(" ")
62 | }
63 | ret.WriteString(imageUrl)
64 | }
65 | if title != "" {
66 | if ret.String() != "" {
67 | ret.WriteString(" ")
68 | }
69 | ret.WriteString(title)
70 | }
71 | return ret.String()
72 | })
73 | return
74 | }
75 |
76 | func checkType(params ParamsForCheck) bool {
77 | return slices.Contains(checkTypes, params.TypeName)
78 | }
79 |
80 | func containsUnsafeURL(content string) (bool, string) {
81 | xurlsRelaxed := xurls.Relaxed()
82 | matchedURLs := xurlsRelaxed.FindAllString(content, -1)
83 | if len(matchedURLs) == 0 {
84 | return false, ""
85 | }
86 |
87 | for _, matchedURL := range matchedURLs {
88 | if !strings.Contains(matchedURL, "://") {
89 | matchedURL = "http://" + matchedURL
90 | }
91 | parsedURL, err := url.Parse(matchedURL)
92 | if err != nil || parsedURL == nil {
93 | return true, matchedURL
94 | }
95 | checked := slices.ContainsFunc(config.Config.UrlHostnameWhitelist, func(s string) bool {
96 | return strings.HasSuffix(parsedURL.Host, s)
97 | })
98 | if !checked {
99 | return true, parsedURL.Host
100 | }
101 | }
102 | return false, ""
103 | }
104 |
105 | func checkValidUrl(input string) error {
106 | imageUrl, err := url.Parse(input)
107 | if err != nil {
108 | return ErrUrlParsing
109 | }
110 | // if the url is text only, skip check
111 | if imageUrl.Scheme == "" && imageUrl.Host == "" {
112 | return ErrImageLinkTextOnly
113 | }
114 | if !slices.Contains(config.Config.ValidImageUrl, imageUrl.Hostname()) {
115 | return ErrInvalidImageHost
116 | }
117 | return nil
118 | }
119 |
120 | var reHole = regexp.MustCompile(`[^#]#(\d+)`)
121 | var reFloor = regexp.MustCompile(`##(\d+)`)
122 |
123 | func removeIDReprInContent(content string) string {
124 | content = " " + content
125 | content = reHole.ReplaceAllString(content, "")
126 | content = reFloor.ReplaceAllString(content, "")
127 | return strings.TrimSpace(content)
128 | }
129 |
--------------------------------------------------------------------------------
/utils/sensitive/utils_test.go:
--------------------------------------------------------------------------------
1 | package sensitive
2 |
3 | import (
4 | "github.com/stretchr/testify/assert"
5 | "testing"
6 | "treehole_next/config"
7 | )
8 |
9 | func TestFindImagesInMarkdown(t *testing.T) {
10 | config.Config.ValidImageUrl = []string{"example.com"}
11 |
12 | type wantStruct struct {
13 | clearContent string
14 | imageUrls []string
15 | err error
16 | }
17 | tests := []struct {
18 | text string
19 | want wantStruct
20 | }{
21 | {
22 | text: ``,
23 | want: wantStruct{
24 | clearContent: `image1`,
25 | imageUrls: []string{"https://example.com/image1"},
26 | err: nil,
27 | },
28 | },
29 | {
30 | text: ` `,
31 | want: wantStruct{
32 | clearContent: `image1 image2`,
33 | imageUrls: []string{"https://example.com/image1", "https://example.com/image2"},
34 | },
35 | },
36 | {
37 | text: ` `,
38 | want: wantStruct{
39 | clearContent: `image1 title1 image2 title2`,
40 | imageUrls: []string{"https://example.com/image1", "https://example.com/image2"},
41 | },
42 | },
43 | {
44 | text: ` `,
45 | want: wantStruct{
46 | clearContent: `image1 123 image2 456`,
47 | imageUrls: nil,
48 | },
49 | },
50 | {
51 | text: ` `,
52 | want: wantStruct{
53 | clearContent: `123 456`,
54 | imageUrls: nil,
55 | },
56 | },
57 | {
58 | text: "",
59 | want: wantStruct{
60 | clearContent: "",
61 | imageUrls: nil,
62 | err: ErrInvalidImageHost,
63 | },
64 | },
65 | }
66 |
67 | for _, tt := range tests {
68 | imageUrls, cleanText, err := findImagesInMarkdownContent(tt.text)
69 | assert.EqualValues(t, tt.want.clearContent, cleanText, "cleanText should be equal")
70 | assert.EqualValues(t, tt.want.imageUrls, imageUrls, "imageUrls should be equal")
71 | assert.EqualValues(t, tt.want.err, err, "err should be equal")
72 | }
73 | }
74 |
75 | func TestCheckValidUrl(t *testing.T) {
76 | config.Config.ValidImageUrl = []string{"example.com"}
77 | type wantStruct struct {
78 | err error
79 | }
80 | tests := []struct {
81 | url string
82 | want wantStruct
83 | }{
84 | {
85 | url: "https://example.com/image1",
86 | want: wantStruct{
87 | err: nil,
88 | },
89 | },
90 | {
91 | url: "https://example.com/image2",
92 | want: wantStruct{
93 | err: nil,
94 | },
95 | },
96 | {
97 | url: "123456",
98 | want: wantStruct{
99 | err: ErrImageLinkTextOnly,
100 | },
101 | },
102 | {
103 | url: "https://example2.com",
104 | want: wantStruct{
105 | err: ErrInvalidImageHost,
106 | },
107 | },
108 | }
109 |
110 | for _, tt := range tests {
111 | err := checkValidUrl(tt.url)
112 | assert.EqualValues(t, tt.want.err, err, "err should be equal")
113 | }
114 | }
115 |
--------------------------------------------------------------------------------
/utils/utils.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "golang.org/x/exp/slices"
5 | "strconv"
6 |
7 | "github.com/gofiber/fiber/v2"
8 | "github.com/opentreehole/go-common"
9 | "golang.org/x/exp/constraints"
10 |
11 | "treehole_next/config"
12 | )
13 |
14 | type CanPreprocess interface {
15 | Preprocess(c *fiber.Ctx) error
16 | }
17 |
18 | func Serialize(c *fiber.Ctx, obj CanPreprocess) error {
19 | err := obj.Preprocess(c)
20 | if err != nil {
21 | return err
22 | }
23 | return c.JSON(obj)
24 | }
25 |
26 | func RegText2IntArray(IDs [][]string) ([]int, error) {
27 | ansIDs := make([]int, 0)
28 | for _, v := range IDs {
29 | id, err := strconv.Atoi(v[1])
30 | if err != nil {
31 | return nil, err
32 | }
33 | ansIDs = append(ansIDs, id)
34 | }
35 | return ansIDs, nil
36 | }
37 |
38 | func Keys[T comparable, S any](m map[T]S) (s []T) {
39 | for k := range m {
40 | s = append(s, k)
41 | }
42 | return s
43 | }
44 |
45 | func Min[T constraints.Ordered](x T, y T) T {
46 | if x > y {
47 | return y
48 | } else {
49 | return x
50 | }
51 | }
52 |
53 | func Intersect[T comparable](x []T, y []T) []T {
54 | var result = make([]T, 0)
55 | for i := range x {
56 | if slices.Contains(y, x[i]) {
57 | result = append(result, x[i])
58 | }
59 | }
60 | return result
61 | }
62 |
63 | // Difference returns the elements in a that aren't in b
64 | func Difference[T comparable](a, b []T) []T {
65 | m := make(map[T]bool)
66 | var result []T
67 |
68 | for _, item := range b {
69 | m[item] = true
70 | }
71 |
72 | for _, item := range a {
73 | if _, ok := m[item]; !ok {
74 | result = append(result, item)
75 | }
76 | }
77 |
78 | return result
79 | }
80 |
81 | func StripContent(content string, contentMaxSize int) string {
82 | return string([]rune(content)[:Min(len([]rune(content)), contentMaxSize)])
83 | }
84 |
85 | func MiddlewareHasAnsweredQuestions(c *fiber.Ctx) error {
86 | if config.Config.Mode == "test" || config.Config.Mode == "bench" {
87 | return c.Next()
88 | }
89 | var user struct {
90 | HasAnsweredQuestions bool `json:"has_answered_questions"`
91 | }
92 | err := common.ParseJWTToken(common.GetJWTToken(c), &user)
93 | if err != nil {
94 | return err
95 | }
96 | if !user.HasAnsweredQuestions {
97 | return &common.HttpError{
98 | Code: ErrCodeNotAnsweredQuestions,
99 | Message: "请先通过注册答题",
100 | }
101 | }
102 | return c.Next()
103 | }
104 |
--------------------------------------------------------------------------------
/utils/utils_test.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/stretchr/testify/assert"
7 | )
8 |
9 | func TestStripContent(t *testing.T) {
10 | var str string
11 | str = "愿中国青年都摆脱冷气,只是向上走,不必听自暴自弃者流的话。能做事的做事,能发声的发声。有一分热,发一分光。就令萤火一般,也可以在黑暗里发一点光,不必等候炬火。"
12 | println(len(str))
13 | println(len([]rune(str)))
14 | assert.Equal(t, "愿中国青年都摆脱冷气", StripContent(str, 10))
15 | assert.Equal(t, str, StripContent(str, 100))
16 | }
17 |
--------------------------------------------------------------------------------