├── pkg
├── logger
│ ├── logger.go
│ ├── error.go
│ └── ginLogger
│ │ └── ginLogger.go
├── s3
│ ├── s3_test.go
│ └── s3.go
├── utils
│ ├── utils_test.go
│ ├── geoDb.go
│ ├── aes_test.go
│ ├── aes.go
│ └── utils.go
├── mail
│ ├── mail_test.go
│ └── mail.go
├── push
│ ├── once.go
│ ├── route.go
│ ├── client.go
│ ├── notification.go
│ └── stream.go
├── base
│ ├── http_return.go
│ ├── cache_test.go
│ ├── redis.go
│ ├── send_push.go
│ ├── hooks.go
│ ├── permissions.go
│ ├── cache.go
│ ├── structs.go
│ └── sql.go
├── model
│ └── model.go
├── route
│ ├── security
│ │ ├── updateIOSDeviceToken.go
│ │ ├── route.go
│ │ ├── middlewares.go
│ │ ├── devices.go
│ │ ├── deleteAccount.go
│ │ ├── login.go
│ │ ├── checkEmail.go
│ │ └── createAccount.go
│ ├── contents
│ │ ├── hotPosts.go
│ │ ├── push.go
│ │ ├── vote.go
│ │ ├── route.go
│ │ ├── middlewares.go
│ │ ├── adminCommands.go
│ │ └── routeApiGET.go
│ └── auth
│ │ └── auth.go
├── config
│ └── config.go
├── bot
│ └── telegram.go
└── consts
│ └── consts.go
├── README.md
├── .gitignore
├── cmd
├── treehollow-v3-push-api
│ └── main.go
├── treehollow-v3-services-api
│ └── main.go
├── treehollow-v3-security-api
│ └── main.go
├── test-pressure-gen-tokens
│ └── main.go
├── test-pressure-connector
│ └── main.go
├── treehollow-v3-fallback
│ └── main.go
└── treehollow-migrate-v2-to-v3
│ └── main.go
├── go.mod
└── example.config.yml
/pkg/logger/logger.go:
--------------------------------------------------------------------------------
1 | package logger
2 |
3 | import (
4 | "io"
5 | "log"
6 | "os"
7 | )
8 |
// InitLog mirrors the standard logger's output to stdout and to the file
// named logFileName. The file is created when missing and appended to
// otherwise; failing to open it panics.
func InitLog(logFileName string) {
	const flags = os.O_CREATE | os.O_APPEND | os.O_RDWR
	f, openErr := os.OpenFile(logFileName, flags, 0666)
	if openErr != nil {
		panic(openErr)
	}
	// The handle is never closed here: it remains the log sink.
	log.SetOutput(io.MultiWriter(os.Stdout, f))
}
17 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # treehollow-v3-backend
2 | [License: AGPL v3](https://www.gnu.org/licenses/agpl-3.0)
3 | [codebeat](https://codebeat.co/projects/github-com-thuhole-thuhole-go-backend-master)
4 |
5 | 第三代树洞后端
6 |
7 | 安装文档:https://github.com/treehollow/install-doc
8 |
9 | ## License
10 | [AGPL v3](./LICENSE)
11 |
--------------------------------------------------------------------------------
/pkg/s3/s3_test.go:
--------------------------------------------------------------------------------
1 | package s3
2 |
3 | import (
4 | "os"
5 | "strings"
6 | "testing"
7 | "treehollow-v3-backend/pkg/config"
8 | )
9 |
10 | func TestS3(t *testing.T) {
11 | if os.Getenv("TRAVIS") != "true" {
12 | _ = os.Chdir("..")
13 | _ = os.Chdir("..")
14 | config.InitConfigFile()
15 | err := Upload("test/test.txt", strings.NewReader("hello1"))
16 | if err != nil {
17 | t.Errorf("err=%s", err)
18 | }
19 | }
20 | }
21 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Binaries for programs and plugins
2 | *.exe
3 | *.exe~
4 | *.dll
5 | *.so
6 | *.dylib
7 |
8 | # Test binary, built with `go test -c`
9 | *.test
10 |
11 | # Output of the go coverage tool, specifically when used with LiteIDE
12 | *.out
13 |
14 | # Dependency directories (remove the comment below to include it)
15 | # vendor/
16 |
17 | #######################################################################
18 | .idea/
19 | *.key
20 | config.json*
21 | config.yml
22 | log.txt
23 | TODO.md
24 | *.mmdb
25 | *.log
26 |
--------------------------------------------------------------------------------
/pkg/utils/utils_test.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | var (
8 | cases = []struct {
9 | mail string
10 | valid bool
11 | }{
12 | {mail: "admin@mails.tsinghua.edu.cn", valid: true},
13 | {mail: "thu-hole@mails.tsinghua.edu.cn", valid: true},
14 | {mail: "thu_hole@mails.tsinghua.edu.cn", valid: true},
15 | {mail: "yezhisheng@pku.edu.cn,admin@mails.tsinghua.edu.cn", valid: false},
16 | }
17 | )
18 |
19 | func TestCheckMail(t *testing.T) {
20 | for _, c := range cases {
21 | if CheckEmail(c.mail) != c.valid {
22 | t.Errorf("%s is expected to be %v", c.mail, c.valid)
23 | }
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/pkg/mail/mail_test.go:
--------------------------------------------------------------------------------
1 | package mail
2 |
3 | import (
4 | "os"
5 | "testing"
6 | "treehollow-v3-backend/pkg/config"
7 | )
8 |
9 | func TestSendCode(t *testing.T) {
10 | _ = os.Chdir("..")
11 | _ = os.Chdir("..")
12 | config.InitConfigFile()
13 | err := SendValidationEmail("123456", "test-treehollow3@srv1.mail-tester.com")
14 | if err != nil {
15 | t.Errorf("error: %s", err)
16 | }
17 | }
18 |
19 | func TestSendNonce(t *testing.T) {
20 | _ = os.Chdir("..")
21 | _ = os.Chdir("..")
22 | config.InitConfigFile()
23 | err := SendPasswordNonceEmail("nonce-198247832648712631", "test-treehollow3@srv1.mail-tester.com")
24 | if err != nil {
25 | t.Errorf("error: %s", err)
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/cmd/treehollow-v3-push-api/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "github.com/gin-gonic/gin"
5 | "github.com/spf13/viper"
6 | "log"
7 | "time"
8 | "treehollow-v3-backend/pkg/base"
9 | "treehollow-v3-backend/pkg/config"
10 | "treehollow-v3-backend/pkg/consts"
11 | "treehollow-v3-backend/pkg/logger"
12 | "treehollow-v3-backend/pkg/push"
13 | )
14 |
15 | func main() {
16 | logger.InitLog(consts.PushApiLogFile)
17 | config.InitConfigFile()
18 |
19 | base.InitDb()
20 | base.AutoMigrateDb()
21 |
22 | log.Println("start time: ", time.Now().Format("01-02 15:04:05"))
23 | if false == viper.GetBool("is_debug") {
24 | gin.SetMode(gin.ReleaseMode)
25 | }
26 |
27 | push.ApiListenHttp()
28 | }
29 |
--------------------------------------------------------------------------------
/pkg/push/once.go:
--------------------------------------------------------------------------------
1 | // from https://github.com/gotify/server/blob/3454dcd60226acf121009975d947f05d41267283/api/stream/once.go
2 | package push
3 |
4 | import (
5 | "sync"
6 | "sync/atomic"
7 | )
8 |
// once is a variant of sync.Once (https://github.com/golang/go/blob/master/src/sync/once.go)
// that releases its mutex before running f. Unlike sync.Once, concurrent
// callers of Do do not wait for the winning caller's f to finish; they return
// as soon as execution has been claimed.
type once struct {
	m    sync.Mutex
	done uint32
}

// Do runs f only for the first caller that claims this once. Every later or
// concurrent caller returns without running anything, including a nested call
// made from inside f.
func (o *once) Do(f func()) {
	// Lock-free fast path: execution was already claimed.
	if atomic.LoadUint32(&o.done) == 1 {
		return
	}
	if !o.mayExecute() {
		return
	}
	f()
}

// mayExecute claims execution under the mutex and reports whether this
// caller won the claim. The mutex is released before f runs.
func (o *once) mayExecute() bool {
	o.m.Lock()
	defer o.m.Unlock()
	if o.done != 0 {
		return false
	}
	atomic.StoreUint32(&o.done, 1)
	return true
}
34 |
--------------------------------------------------------------------------------
/cmd/treehollow-v3-services-api/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "github.com/gin-gonic/gin"
5 | "github.com/spf13/viper"
6 | "log"
7 | "time"
8 | "treehollow-v3-backend/pkg/base"
9 | "treehollow-v3-backend/pkg/config"
10 | "treehollow-v3-backend/pkg/consts"
11 | "treehollow-v3-backend/pkg/logger"
12 | "treehollow-v3-backend/pkg/route/contents"
13 | )
14 |
15 | func main() {
16 | logger.InitLog(consts.ServicesApiLogFile)
17 | config.InitConfigFile()
18 |
19 | base.InitDb()
20 | base.AutoMigrateDb()
21 |
22 | log.Println("start time: ", time.Now().Format("01-02 15:04:05"))
23 | if false == viper.GetBool("is_debug") {
24 | gin.SetMode(gin.ReleaseMode)
25 | }
26 |
27 | //utils.InitGeoDbRefreshCron()
28 | contents.RefreshHotPosts()
29 | contents.InitHotPostsRefreshCron()
30 |
31 | contents.ServicesApiListenHttp()
32 | }
33 |
--------------------------------------------------------------------------------
/pkg/base/http_return.go:
--------------------------------------------------------------------------------
1 | package base
2 |
3 | import (
4 | "github.com/gin-gonic/gin"
5 | "net/http"
6 | "strconv"
7 | "treehollow-v3-backend/pkg/logger"
8 | )
9 |
10 | func HttpReturnWithCodeMinusOne(c *gin.Context, e *logger.InternalError) {
11 | HttpReturnWithErr(c, -1, e)
12 | }
13 |
14 | func HttpReturnWithErr(c *gin.Context, code int, e *logger.InternalError) {
15 | user, exists := c.Get("user")
16 | if exists {
17 | e.InternalMsg = "(UserID=" + strconv.Itoa(int(user.(User).ID)) + ")" + e.InternalMsg
18 | }
19 | e.Log()
20 | c.JSON(http.StatusOK, gin.H{
21 | "code": code,
22 | "msg": e.DisplayMsg,
23 | })
24 | }
25 |
26 | func HttpReturnWithErrAndAbort(c *gin.Context, code int, e *logger.InternalError) {
27 | HttpReturnWithErr(c, code, e)
28 | c.Abort()
29 | }
30 |
31 | func HttpReturnWithCodeMinusOneAndAbort(c *gin.Context, e *logger.InternalError) {
32 | HttpReturnWithErrAndAbort(c, -1, e)
33 | c.Abort()
34 | }
35 |
--------------------------------------------------------------------------------
/pkg/base/cache_test.go:
--------------------------------------------------------------------------------
1 | package base
2 |
3 | import (
4 | "fmt"
5 | "github.com/vmihailenco/msgpack/v5"
6 | "testing"
7 | )
8 |
9 | func TestMsgPack(t *testing.T) {
10 | type Item struct {
11 | Foo string
12 | Slice []string
13 | Test []Item
14 | }
15 |
16 | b, err := msgpack.Marshal(&Item{Foo: "bar", Slice: []string{"123", "456"}, Test: []Item{{Foo: "test"}}})
17 | if err != nil {
18 | panic(err)
19 | }
20 | fmt.Println(string(b))
21 |
22 | var item Item
23 | err = msgpack.Unmarshal(b, &item)
24 | if err != nil {
25 | panic(err)
26 | }
27 | fmt.Println(item)
28 | }
29 |
30 | func TestMsgPack2(t *testing.T) {
31 | type Item struct {
32 | Foo string
33 | }
34 |
35 | b, err := msgpack.Marshal(&[]Item{{Foo: "Bar"}, {Foo: "123"}})
36 | if err != nil {
37 | panic(err)
38 | }
39 | fmt.Println(string(b))
40 |
41 | var item []Item
42 | err = msgpack.Unmarshal(b, &item)
43 | if err != nil {
44 | panic(err)
45 | }
46 | fmt.Println(item)
47 | }
48 |
--------------------------------------------------------------------------------
/pkg/base/redis.go:
--------------------------------------------------------------------------------
1 | package base
2 |
3 | import (
4 | libredis "github.com/go-redis/redis/v8"
5 | "github.com/spf13/viper"
6 | "github.com/ulule/limiter/v3"
7 | sredis "github.com/ulule/limiter/v3/drivers/store/redis"
8 | "treehollow-v3-backend/pkg/utils"
9 | )
10 |
11 | var redisClient *libredis.Client
12 |
13 | func initRedis() error {
14 | option, err := libredis.ParseURL(viper.GetString("redis_source"))
15 | if err != nil {
16 | utils.FatalErrorHandle(&err, "failed init redis url")
17 | return err
18 | }
19 | redisClient = libredis.NewClient(option)
20 | return nil
21 | }
22 |
23 | func InitLimiter(rate limiter.Rate, prefix string) *limiter.Limiter {
24 | client := GetRedisClient()
25 | store, err2 := sredis.NewStoreWithOptions(client, limiter.StoreOptions{
26 | Prefix: prefix,
27 | })
28 | if err2 != nil {
29 | utils.FatalErrorHandle(&err2, "failed init redis store")
30 | return nil
31 | }
32 | return limiter.New(store, rate)
33 | }
34 |
35 | func GetRedisClient() *libredis.Client {
36 | return redisClient
37 | }
38 |
--------------------------------------------------------------------------------
/pkg/utils/geoDb.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "github.com/oschwald/geoip2-golang"
5 | "github.com/robfig/cron/v3"
6 | "github.com/spf13/viper"
7 | "log"
8 | "sync"
9 | )
10 |
11 | type GeoDbRW struct {
12 | mu sync.RWMutex
13 | geoDb *geoip2.Reader
14 | }
15 |
16 | var GeoDb GeoDbRW
17 |
18 | func (GeoDbRW *GeoDbRW) Get() *geoip2.Reader {
19 | GeoDbRW.mu.RLock()
20 | rtn := GeoDbRW.geoDb
21 | GeoDbRW.mu.RUnlock()
22 | return rtn
23 | }
24 |
25 | func (GeoDbRW *GeoDbRW) Set(item *geoip2.Reader) {
26 | GeoDbRW.mu.Lock()
27 | GeoDbRW.geoDb = item
28 | GeoDbRW.mu.Unlock()
29 | }
30 |
31 | func RefreshGeoDb() {
32 | geoDb, err := geoip2.Open(viper.GetString("mmdb_path"))
33 | if err != nil {
34 | log.Println("geoip2 db load failed. No IP location restrictions would be available.")
35 | } else {
36 | GeoDb.Set(geoDb)
37 | log.Println("geoip2 db loaded.")
38 | }
39 | }
40 |
41 | func InitGeoDbRefreshCron() {
42 | c := cron.New()
43 | _, _ = c.AddFunc("00 05 * * *", func() {
44 | RefreshGeoDb()
45 | })
46 | c.Start()
47 | }
48 |
--------------------------------------------------------------------------------
/cmd/treehollow-v3-security-api/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "github.com/gin-gonic/gin"
5 | "github.com/spf13/viper"
6 | "log"
7 | "time"
8 | "treehollow-v3-backend/pkg/base"
9 | "treehollow-v3-backend/pkg/config"
10 | "treehollow-v3-backend/pkg/consts"
11 | "treehollow-v3-backend/pkg/logger"
12 | "treehollow-v3-backend/pkg/route/security"
13 | "treehollow-v3-backend/pkg/utils"
14 | )
15 |
16 | func main() {
17 | logger.InitLog(consts.SecurityApiLogFile)
18 | config.InitConfigFile()
19 |
20 | //if false == viper.GetBool("is_debug") {
21 | // fmt.Print("Read salt from stdin: ")
22 | // _, _ = fmt.Scanln(&utils.Salt)
23 | // if utils.SHA256(utils.Salt) != viper.GetString("salt_hashed") {
24 | // panic("salt verification failed!")
25 | // }
26 | //}
27 | utils.Salt = viper.GetString("salt")
28 |
29 | base.InitDb()
30 | base.AutoMigrateDb()
31 |
32 | utils.InitGeoDbRefreshCron()
33 |
34 | log.Println("start time: ", time.Now().Format("01-02 15:04:05"))
35 | if false == viper.GetBool("is_debug") {
36 | gin.SetMode(gin.ReleaseMode)
37 | }
38 |
39 | security.ApiListenHttp()
40 | }
41 |
--------------------------------------------------------------------------------
/pkg/model/model.go:
--------------------------------------------------------------------------------
1 | package model
2 |
3 | import "time"
4 |
// Message holds a notification's text, title, extra fields, and timestamp.
type Message struct {
	Message string
	Title   string
	Extras  map[string]interface{}
	Time    time.Time
}

// PushType identifies a category of push notification. The values are
// distinct bits, so they can presumably be OR-ed into a mask.
type PushType int8

const (
	SystemMessage      PushType = 0x01
	ReplyMeComment     PushType = 0x02
	CommentInFavorited PushType = 0x04
)

// SearchOrder selects how search results are sorted.
type SearchOrder int8

const (
	SearchOrderByID       SearchOrder = 0
	SearchOrderByLikeNum  SearchOrder = 1
	SearchOrderByReplyNum SearchOrder = 2
)

// searchOrderByName maps the accepted string spellings to their SearchOrder.
var searchOrderByName = map[string]SearchOrder{
	"id":        SearchOrderByID,
	"like_num":  SearchOrderByLikeNum,
	"reply_num": SearchOrderByReplyNum,
}

// SearchOrderFromString parses "id", "like_num", or "reply_num". Any other
// value falls back to SearchOrderByID.
func SearchOrderFromString(s string) (searchOrder SearchOrder) {
	if order, ok := searchOrderByName[s]; ok {
		return order
	}
	return SearchOrderByID
}

// ToString returns the ordering clause for the sort order (presumably used as
// an SQL ORDER BY). SearchOrderByID and any unknown value both map to
// "id desc".
func (searchOrder *SearchOrder) ToString() string {
	if *searchOrder == SearchOrderByLikeNum {
		return "like_num desc"
	}
	if *searchOrder == SearchOrderByReplyNum {
		return "reply_num desc"
	}
	return "id desc"
}
54 |
--------------------------------------------------------------------------------
/pkg/route/security/updateIOSDeviceToken.go:
--------------------------------------------------------------------------------
1 | package security
2 |
3 | import (
4 | "github.com/gin-gonic/gin"
5 | "net/http"
6 | "treehollow-v3-backend/pkg/base"
7 | "treehollow-v3-backend/pkg/consts"
8 | "treehollow-v3-backend/pkg/logger"
9 | "treehollow-v3-backend/pkg/utils"
10 | )
11 |
12 | func updateIOSToken(c *gin.Context) {
13 | token := c.GetHeader("TOKEN")
14 | iosDeviceToken := c.PostForm("ios_device_token")
15 | if len(iosDeviceToken) < 1 || len(iosDeviceToken) > 100 {
16 | base.HttpReturnWithErrAndAbort(c, -11, logger.NewSimpleError("NoIOSToken", "获取iOS推送口令失败", logger.WARN))
17 | return
18 | }
19 | result := base.GetDb(false).Model(&base.Device{}).
20 | Where("token = ? and created_at > ?", token, utils.GetEarliestAuthenticationTime()).
21 | Update("ios_device_token", iosDeviceToken)
22 | if result.Error != nil {
23 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(result.Error, "UpdateIOSTokenFailed", consts.DatabaseWriteFailedString))
24 | return
25 | }
26 | if result.RowsAffected != 1 {
27 | base.HttpReturnWithErrAndAbort(c, -100, logger.NewSimpleError("NoUpdateDeviceToken", "更新Device Token失败", logger.INFO))
28 | return
29 | }
30 | c.JSON(http.StatusOK, gin.H{
31 | "code": 0,
32 | })
33 | }
34 |
--------------------------------------------------------------------------------
/pkg/route/contents/hotPosts.go:
--------------------------------------------------------------------------------
1 | package contents
2 |
3 | import (
4 | "github.com/robfig/cron/v3"
5 | "github.com/shirou/gopsutil/v3/load"
6 | "github.com/spf13/viper"
7 | "log"
8 | "sync"
9 | "treehollow-v3-backend/pkg/base"
10 | )
11 |
12 | type HotPostsRW struct {
13 | mu sync.RWMutex
14 | hotPosts []base.Post
15 | }
16 |
17 | var HotPosts HotPostsRW
18 |
19 | func (hotPostRW *HotPostsRW) Get() []base.Post {
20 | hotPostRW.mu.RLock()
21 | rtn := hotPostRW.hotPosts
22 | hotPostRW.mu.RUnlock()
23 | return rtn
24 | }
25 |
26 | func (hotPostRW *HotPostsRW) Set(item []base.Post) {
27 | hotPostRW.mu.Lock()
28 | hotPostRW.hotPosts = item
29 | hotPostRW.mu.Unlock()
30 | }
31 |
32 | func RefreshHotPosts() {
33 | avg, err := load.Avg()
34 | if err == nil {
35 | if avg.Load1 <= viper.GetFloat64("sys_load_threshold") {
36 | hotPosts, err2 := base.GetHotPosts()
37 | if err2 == nil {
38 | HotPosts.Set(hotPosts)
39 | } else {
40 | log.Printf("db.GetHotPosts() failed: err=%s\n", err2)
41 | }
42 | }
43 | } else {
44 | log.Printf("load.Avg() failed: err=%s\n", err)
45 | }
46 | }
47 |
48 | func InitHotPostsRefreshCron() {
49 | c := cron.New()
50 | _, _ = c.AddFunc("*/1 * * * *", func() {
51 | RefreshHotPosts()
52 | })
53 | c.Start()
54 | }
55 |
--------------------------------------------------------------------------------
/pkg/utils/aes_test.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "fmt"
5 | "testing"
6 | )
7 |
8 | var (
9 | aesTestCases = []struct {
10 | plainText string
11 | key string
12 | }{
13 | {plainText: "1", key: "3"},
14 | {plainText: "2", key: "4"},
15 | {plainText: "3", key: ""},
16 | {plainText: "1111111111111111111111111111111111111111111111111111111111111111111111111123",
17 | key: "1111111111111111111111111111111111111111111111111111111111111111111111111123"},
18 | }
19 | )
20 |
21 | func TestAes(t *testing.T) {
22 | for _, c := range aesTestCases {
23 | fmt.Println("plaintext length:", len(c.plainText))
24 | fmt.Println("key length:", len(c.key))
25 | cipherText, err := AESEncrypt(c.plainText, c.key)
26 | if err != nil {
27 | t.Errorf(err.Error())
28 | }
29 | cipherText2, err := AESEncrypt(c.plainText, c.key)
30 | if err != nil {
31 | t.Errorf(err.Error())
32 | }
33 | if cipherText2 != cipherText {
34 | t.Errorf("Encryption is random!")
35 | }
36 | fmt.Println("ciphertext length:", len(cipherText))
37 | fmt.Println()
38 |
39 | newPlainText, err2 := AESDecrypt(cipherText, c.key)
40 | if err2 != nil {
41 | t.Errorf(err2.Error())
42 | }
43 | if newPlainText != c.plainText {
44 | t.Errorf("Decrypted text does not match!")
45 | }
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
/cmd/test-pressure-gen-tokens/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "github.com/google/uuid"
6 | "os"
7 | "treehollow-v3-backend/pkg/base"
8 | "treehollow-v3-backend/pkg/config"
9 | "treehollow-v3-backend/pkg/utils"
10 | )
11 |
12 | const N = 10000
13 |
14 | func main() {
15 | config.InitConfigFile()
16 | base.InitDb()
17 |
18 | logFile, err := os.OpenFile("pressure_test_tokens.txt", os.O_CREATE|os.O_APPEND|os.O_RDWR, 0666)
19 | if err != nil {
20 | panic(err)
21 | }
22 | defer logFile.Close()
23 |
24 | user := base.User{
25 | EmailEncrypted: "PressureTestUser",
26 | ForgetPwNonce: utils.GenNonce(),
27 | Role: base.NormalUserRole,
28 | }
29 | if err = base.GetDb(false).Create(&user).Error; err != nil {
30 | panic(err)
31 | }
32 |
33 | var devices = make([]base.Device, 0, N)
34 | for i := 0; i < N; i++ {
35 | token := utils.GenToken()
36 | _, _ = fmt.Fprintln(logFile, token)
37 | devices = append(devices, base.Device{
38 | ID: uuid.New().String(),
39 | UserID: user.ID,
40 | Token: token,
41 | DeviceInfo: "PressureTestToken",
42 | Type: base.AndroidDevice,
43 | LoginIP: "127.0.0.1",
44 | LoginCity: "Unknown",
45 | IOSDeviceToken: "",
46 | })
47 | }
48 |
49 | if err = base.GetDb(false).CreateInBatches(&devices, 1000).Error; err != nil {
50 | panic(err)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/pkg/config/config.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "github.com/fsnotify/fsnotify"
5 | "github.com/gin-gonic/gin"
6 | "github.com/spf13/viper"
7 | "log"
8 | "net"
9 | "treehollow-v3-backend/pkg/consts"
10 | "treehollow-v3-backend/pkg/utils"
11 | )
12 |
13 | func refreshAllowedSubnets() {
14 | utils.AllowedSubnets = make([]*net.IPNet, 0)
15 | subnets := viper.GetStringSlice("subnets_whitelist")
16 | for _, subnet := range subnets {
17 | _, tmp, _ := net.ParseCIDR(subnet)
18 | utils.AllowedSubnets = append(utils.AllowedSubnets, tmp)
19 | }
20 | log.Println("subnets: ", subnets)
21 | }
22 |
23 | func refreshConfig() {
24 | refreshAllowedSubnets()
25 | utils.RefreshGeoDb()
26 | viper.SetDefault("sys_load_threshold", consts.SystemLoadThreshold)
27 | viper.SetDefault("ws_ping_period_sec", 90)
28 | viper.SetDefault("ws_pong_timeout_sec", 10)
29 | viper.SetDefault("push_internal_api_listen_address", "127.0.0.1:3009")
30 | }
31 |
32 | func InitConfigFile() {
33 | viper.SetConfigType("yaml")
34 | viper.AddConfigPath(".")
35 | viper.SetConfigFile("config.yml")
36 | err := viper.ReadInConfig() // Find and read the config file
37 | utils.FatalErrorHandle(&err, "error while reading config file")
38 |
39 | viper.WatchConfig()
40 | viper.OnConfigChange(func(e fsnotify.Event) {
41 | log.Println("Config file changed:", e.Name)
42 | refreshConfig()
43 | })
44 | refreshConfig()
45 | }
46 |
47 | func GetFrontendConfigInfo() gin.H {
48 | return gin.H{
49 | "web_frontend_version": viper.GetString("web_frontend_version"),
50 | "announcement": viper.GetString("announcement"),
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/pkg/logger/error.go:
--------------------------------------------------------------------------------
1 | package logger
2 |
3 | import (
4 | "github.com/google/uuid"
5 | "log"
6 | )
7 |
// LogLevel is the severity attached to an InternalError.
type LogLevel int8

const (
	DEBUG LogLevel = 0
	INFO  LogLevel = 1
	WARN  LogLevel = 2
	ERROR LogLevel = 3
	FATAL LogLevel = 4
)

// levelLabels holds the fixed-width (5-character) label for each defined
// level, indexed by its LogLevel value.
var levelLabels = [...]string{
	DEBUG: "DEBUG",
	INFO:  "INFO ",
	WARN:  "WARN ",
	ERROR: "ERROR",
	FATAL: "FATAL",
}

// ToString returns the padded label for l, or "UNKNOWN" for any value
// outside DEBUG..FATAL.
func (l *LogLevel) ToString() string {
	if lvl := *l; lvl >= DEBUG && lvl <= FATAL {
		return levelLabels[lvl]
	}
	return "UNKNOWN"
}
33 |
34 | type InternalError struct {
35 | UUID string
36 | Err error
37 | InternalMsg string
38 | DisplayMsg string
39 | level LogLevel
40 | }
41 |
42 | func NewSimpleError(internalMsg string, displayMsg string, level LogLevel) *InternalError {
43 | return newInternalError(nil, internalMsg, displayMsg, level)
44 | }
45 |
46 | func NewError(err error, internalMsg string, displayMsg string) *InternalError {
47 | return newInternalError(err, internalMsg, displayMsg, ERROR)
48 | }
49 |
50 | func newInternalError(err error, internalMsg string, displayMsg string, level LogLevel) *InternalError {
51 | id := ""
52 |
53 | if level > INFO {
54 | id = uuid.New().String()
55 | displayMsg = displayMsg + "(ErrorID: " + id + ")"
56 | }
57 |
58 | return &InternalError{
59 | UUID: id,
60 | Err: err,
61 | InternalMsg: internalMsg,
62 | DisplayMsg: displayMsg,
63 | level: level,
64 | }
65 | }
66 |
67 | func (e *InternalError) Log() {
68 | if e.level >= WARN {
69 | log.Printf("[%s](%s): %s; Err=%s\n", e.level.ToString(), e.UUID, e.InternalMsg, e.Err)
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/pkg/bot/telegram.go:
--------------------------------------------------------------------------------
1 | package bot
2 |
3 | import (
4 | "github.com/spf13/viper"
5 | tb "gopkg.in/tucnak/telebot.v2"
6 | "time"
7 | "treehollow-v3-backend/pkg/utils"
8 | )
9 |
10 | type TgMessage struct {
11 | Text string
12 | ImagePath string
13 | }
14 |
15 | var TgMessageChannel = make(chan TgMessage)
16 |
17 | func InitBot() {
18 | if viper.GetBool("enable_telegram") {
19 | poller := &tb.LongPoller{Timeout: 10 * time.Second}
20 | filteredPoller := tb.NewMiddlewarePoller(poller, func(upd *tb.Update) bool {
21 | if upd.Message == nil {
22 | return false
23 | }
24 |
25 | if upd.Message.Chat.ID == viper.GetInt64("tg_chat_id") {
26 | return true
27 | }
28 |
29 | return false
30 | })
31 |
32 | b, err := tb.NewBot(tb.Settings{
33 | // You can also set custom API URL.
34 | // If field is empty it equals to "https://api.telegram.org".
35 | //URL: "http://195.129.111.17:8012",
36 |
37 | Token: viper.GetString("tg_token"),
38 | Poller: filteredPoller,
39 | })
40 |
41 | if err != nil {
42 | utils.FatalErrorHandle(&err, "Telegram bot init failed")
43 | return
44 | }
45 |
46 | b.Handle("/ping", func(m *tb.Message) {
47 | _, _ = b.Send(m.Sender, "pong!")
48 | })
49 |
50 | go func() {
51 | _, _ = b.Send(tb.ChatID(viper.GetInt64("tg_chat_id")), "Backend bot started!")
52 | b.Start()
53 | }()
54 |
55 | go func() {
56 | for m := range TgMessageChannel {
57 | if len(m.ImagePath) == 0 {
58 | _, _ = b.Send(tb.ChatID(viper.GetInt64("tg_chat_id")), utils.TrimText(m.Text, 4096))
59 | } else {
60 | _, _ = b.Send(tb.ChatID(viper.GetInt64("tg_chat_id")),
61 | &tb.Photo{File: tb.FromDisk(m.ImagePath), Caption: utils.TrimText(m.Text, 1024)})
62 | }
63 | }
64 | }()
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/pkg/route/auth/auth.go:
--------------------------------------------------------------------------------
1 | package auth
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "github.com/gin-gonic/gin"
7 | "github.com/spf13/viper"
8 | "gorm.io/gorm"
9 | "treehollow-v3-backend/pkg/base"
10 | "treehollow-v3-backend/pkg/consts"
11 | "treehollow-v3-backend/pkg/logger"
12 | "treehollow-v3-backend/pkg/utils"
13 | )
14 |
15 | func AuthMiddleware() gin.HandlerFunc {
16 | return func(c *gin.Context) {
17 | token := c.GetHeader("TOKEN")
18 | user, err := base.GetUserWithCache(token)
19 | if err != nil {
20 | fmt.Println(err.Error())
21 | if !errors.Is(err, gorm.ErrRecordNotFound) {
22 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "AuthDbFailed", consts.DatabaseReadFailedString))
23 | return
24 | }
25 | if !viper.GetBool("allow_unregistered_access") && !utils.IsInAllowedSubnet(c.ClientIP()) {
26 | base.HttpReturnWithErrAndAbort(c, -100, logger.NewSimpleError("TokenExpired",
27 | "登录凭据过期,请使用邮箱重新登录。", logger.INFO))
28 | return
29 | } else {
30 | c.Set("user", base.User{ID: -1, Role: base.UnregisteredRole, EmailEncrypted: ""})
31 | c.Next()
32 | }
33 | } else {
34 | if user.Role == base.BannedUserRole {
35 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("AccountFrozen",
36 | "您的账户已被冻结。如果需要解冻,请联系"+
37 | viper.GetString("contact_email")+"。", logger.ERROR))
38 |
39 | return
40 | }
41 | c.Set("user", user)
42 | c.Next()
43 | }
44 | }
45 | }
46 |
47 | func DisallowUnregisteredUsers() gin.HandlerFunc {
48 | return func(c *gin.Context) {
49 | user := c.MustGet("user").(base.User)
50 | if user.Role == base.UnregisteredRole {
51 | base.HttpReturnWithErrAndAbort(c, -100, logger.NewSimpleError("TokenExpired",
52 | "登录凭据过期,请使用邮箱重新登录。", logger.INFO))
53 | return
54 | }
55 | c.Next()
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/pkg/consts/consts.go:
--------------------------------------------------------------------------------
1 | package consts
2 |
3 | import "time"
4 |
// TODO: (low priority) move some of these consts to the config file.
const (
	PageSize                    = 30
	MsgPageSize                 = 10
	SystemLoadThreshold float64 = 5.0
	TokenExpireDays             = 31
	WanderPageSize              = 15
	MaxPage                     = 150
	SearchPageSize              = 30
	SearchMaxPage               = 100
	SearchMaxLength             = 30
	PostMaxLength               = 10000
	ImageMaxWidth               = 10000
	ImageMaxHeight              = 10000
	VoteOptionMaxCharacters     = 15
	VoteMaxOptions              = 4
	MaxDevicesPerUser           = 6
	ReportMaxLength             = 1000
	ImgMaxLength                = 2000000
	Base64Rate                  = 1.33333333
)

// AesIv is the fixed IV source for utils.AESEncrypt. Only its first 16 bytes
// (one AES block) are read. Because the IV is fixed, encryption is
// deterministic for a given plaintext and key.
// NOTE(review): determinism also reveals when two plaintexts are equal;
// confirm this is intended.
const AesIv = "12345678901234567890123456789012"

const (
	PushApiLogFile     = "push.log"
	ServicesApiLogFile = "services-api.log"
	DetailLogFile      = "detail-log.log"
	SecurityApiLogFile = "security-api.log"
)

const (
	DatabaseReadFailedString    = "数据库读取失败,请联系管理员"
	DatabaseWriteFailedString   = "数据库写入失败,请联系管理员"
	DatabaseDamagedString       = "数据库损坏,请联系管理员"
	DatabaseEncryptFailedString = "数据库加密失败,请联系管理员"
)

const (
	DzName          = "洞主"
	ExtraNamePrefix = "You Win "
)

// TimeLoc is the Asia/Shanghai time zone. If the tz database is unavailable
// (e.g. a minimal container image), it falls back to a fixed UTC+8 zone. The
// previous code left TimeLoc nil in that case, and time.Time.In panics on a
// nil location.
var TimeLoc = loadTimeLoc()

// loadTimeLoc loads Asia/Shanghai, falling back to a fixed UTC+8 zone.
func loadTimeLoc() *time.Location {
	if loc, err := time.LoadLocation("Asia/Shanghai"); err == nil {
		return loc
	}
	// Asia/Shanghai has used UTC+8 without DST since 1991.
	return time.FixedZone("CST", 8*60*60)
}

// Names0 and Names1 are word lists, one entry per letter A–Z.
var Names0 = []string{
	"Angry",
	"Baby",
	"Crazy",
	"Diligent",
	"Excited",
	"Fat",
	"Greedy",
	"Hungry",
	"Interesting",
	"Jolly",
	"Kind",
	"Little",
	"Magic",
	"Naïve",
	"Old",
	"PKU",
	"Quiet",
	"Rich",
	"Superman",
	"Tough",
	"Undefined",
	"Valuable",
	"Wifeless",
	"Xiangbuchulai",
	"Young",
	"Zombie",
}

var Names1 = []string{
	"Alice",
	"Bob",
	"Carol",
	"Dave",
	"Eve",
	"Francis",
	"Grace",
	"Hans",
	"Isabella",
	"Jason",
	"Kate",
	"Louis",
	"Margaret",
	"Nathan",
	"Olivia",
	"Paul",
	"Queen",
	"Richard",
	"Susan",
	"Thomas",
	"Uma",
	"Vivian",
	"Winnie",
	"Xander",
	"Yasmine",
	"Zach",
}
98 |
--------------------------------------------------------------------------------
/pkg/utils/aes.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "bytes"
5 | "crypto/aes"
6 | "crypto/cipher"
7 | "crypto/sha256"
8 | "encoding/hex"
9 | "errors"
10 | "io"
11 | "treehollow-v3-backend/pkg/consts"
12 | )
13 |
// errBadPadding is returned by Unpad when the input cannot carry valid
// PKCS#7 padding.
var errBadPadding = errors.New("unpad error. This could happen when incorrect encryption key is used")

// Pad applies PKCS#7 padding. It appends n copies of the byte n, where
// n = blockSize - len(src)%blockSize (1..blockSize), so the result length is
// a multiple of blockSize. append may reuse src's spare capacity.
func Pad(src []byte, blockSize int) []byte {
	padding := blockSize - len(src)%blockSize
	return append(src, bytes.Repeat([]byte{byte(padding)}, padding)...)
}

// Unpad strips PKCS#7 padding added by Pad. It returns errBadPadding instead
// of panicking when:
//   - src is empty;
//   - the padding length is 0;
//   - the padding length exceeds len(src);
//   - the padding bytes are not all equal to the padding length.
//
// These are the usual symptoms of decrypting with the wrong key.
func Unpad(src []byte) ([]byte, error) {
	length := len(src)
	if length == 0 {
		return nil, errBadPadding
	}
	unpadding := int(src[length-1])
	if unpadding == 0 || unpadding > length {
		return nil, errBadPadding
	}
	for _, b := range src[length-unpadding:] {
		if int(b) != unpadding {
			return nil, errBadPadding
		}
	}
	return src[:length-unpadding], nil
}
30 |
31 | //See https://gist.github.com/thuhole/ba41ade1ca97be838ddfcb030306d997
32 | func AESEncrypt(plaintext string, keyStr string) (string, error) {
33 | h := sha256.New()
34 | h.Write([]byte(keyStr))
35 | key := h.Sum(nil)
36 | block, err := aes.NewCipher(key)
37 | if err != nil {
38 | return "", err
39 | }
40 | blockSize := block.BlockSize()
41 |
42 | msg := Pad([]byte(plaintext), blockSize)
43 | ciphertext := make([]byte, blockSize+len(msg))
44 | iv := ciphertext[:blockSize]
45 | if _, err = io.ReadFull(bytes.NewReader([]byte(consts.AesIv)), iv); err != nil {
46 | return "", err
47 | }
48 |
49 | cfb := cipher.NewCFBEncrypter(block, iv)
50 | cfb.XORKeyStream(ciphertext[blockSize:], msg)
51 | finalMsg := hex.EncodeToString(ciphertext)
52 | return finalMsg, nil
53 | }
54 |
55 | func AESDecrypt(ciphertext string, keyStr string) (string, error) {
56 | h := sha256.New()
57 | h.Write([]byte(keyStr))
58 | key := h.Sum(nil)
59 | block, err := aes.NewCipher(key)
60 | if err != nil {
61 | return "", err
62 | }
63 | blockSize := block.BlockSize()
64 |
65 | decodedMsg, err := hex.DecodeString(ciphertext)
66 | if err != nil {
67 | return "", err
68 | }
69 |
70 | if (len(decodedMsg) % blockSize) != 0 {
71 | return "", errors.New("block_size must be multiple of decoded message length")
72 | }
73 |
74 | iv := decodedMsg[:blockSize]
75 | msg := decodedMsg[blockSize:]
76 |
77 | cfb := cipher.NewCFBDecrypter(block, iv)
78 | cfb.XORKeyStream(msg, msg)
79 |
80 | unpadMsg, err := Unpad(msg)
81 | if err != nil {
82 | return "", err
83 | }
84 |
85 | return string(unpadMsg), nil
86 | }
87 |
--------------------------------------------------------------------------------
/pkg/base/send_push.go:
--------------------------------------------------------------------------------
1 | package base
2 |
import (
	"bytes"
	"encoding/json"
	"log"
	"net/http"
	"time"

	"github.com/spf13/viper"
	"gorm.io/gorm"
	"treehollow-v3-backend/pkg/model"
)
12 |
13 | func PreProcessPushMessages(tx *gorm.DB, msgs []PushMessage) error {
14 | var userIDs []int32
15 | for _, msg := range msgs {
16 | userIDs = append(userIDs, msg.UserID)
17 | }
18 |
19 | var pushSettings []PushSettings
20 | err := tx.Model(&PushSettings{}).Where("user_id in (?)", userIDs).
21 | Find(&pushSettings).Error
22 | if err != nil {
23 | log.Printf("read push settings failed: %s", err)
24 | return err
25 | }
26 |
27 | pushSettingsMap := make(map[int32]PushSettings)
28 | for _, s := range pushSettings {
29 | pushSettingsMap[s.UserID] = s
30 | }
31 |
32 | for i, msg := range msgs {
33 | s, ok := pushSettingsMap[msg.UserID]
34 | if ok {
35 | if (s.Settings & msg.Type) > 0 {
36 | msgs[i].DoPush = true
37 | } else {
38 | msgs[i].DoPush = false
39 | }
40 | } else if (msg.Type & (model.SystemMessage | model.ReplyMeComment)) > 0 {
41 | msgs[i].DoPush = true
42 | } else {
43 | msgs[i].DoPush = false
44 | }
45 | }
46 | return nil
47 | }
48 |
49 | func SendToPushService(msgs []PushMessage) {
50 | postBody, _ := json.Marshal(msgs)
51 | bytesBody := bytes.NewBuffer(postBody)
52 | req, err2 := http.NewRequest("POST",
53 | "http://"+viper.GetString("push_internal_api_listen_address")+"/send_messages", bytesBody)
54 | if err2 != nil {
55 | log.Printf("push request build failed: %s\n", err2)
56 | return
57 | }
58 | clientHttp := &http.Client{}
59 | resp, err3 := clientHttp.Do(req)
60 | if err3 != nil {
61 | log.Printf("push failed: %s\n", err3)
62 | return
63 | }
64 | _ = resp.Body.Close()
65 | }
66 |
67 | func SendDeletionToPushService(commentID int32) {
68 | postBody, _ := json.Marshal(commentID)
69 | bytesBody := bytes.NewBuffer(postBody)
70 | req, err2 := http.NewRequest("POST",
71 | "http://"+viper.GetString("push_internal_api_listen_address")+"/delete_messages", bytesBody)
72 | if err2 != nil {
73 | log.Printf("push request build failed: %s\n", err2)
74 | return
75 | }
76 | clientHttp := &http.Client{}
77 | resp, err3 := clientHttp.Do(req)
78 | if err3 != nil {
79 | log.Printf("push failed: %s\n", err3)
80 | return
81 | }
82 | _ = resp.Body.Close()
83 | }
84 |
--------------------------------------------------------------------------------
/pkg/route/contents/push.go:
--------------------------------------------------------------------------------
1 | package contents
2 |
3 | import (
4 | "errors"
5 | "github.com/gin-gonic/gin"
6 | "gorm.io/gorm"
7 | "gorm.io/gorm/clause"
8 | "net/http"
9 | "treehollow-v3-backend/pkg/base"
10 | "treehollow-v3-backend/pkg/consts"
11 | "treehollow-v3-backend/pkg/logger"
12 | "treehollow-v3-backend/pkg/model"
13 | )
14 |
// boolToInt maps true to 1 and false to 0, which is how push flags are sent to clients.
func boolToInt(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}
21 |
22 | func getPush(c *gin.Context) {
23 | user := c.MustGet("user").(base.User)
24 | var pushSettings base.PushSettings
25 | err := base.GetDb(false).First(&pushSettings, user.ID).Error
26 | if err != nil {
27 | if errors.Is(err, gorm.ErrRecordNotFound) {
28 | c.JSON(http.StatusOK, gin.H{
29 | "code": 0,
30 | "data": gin.H{
31 | "push_system_msg": 1,
32 | "push_reply_me": 1,
33 | "push_favorited": 0,
34 | },
35 | })
36 | } else {
37 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err, "GetPushSettingsFailed", consts.DatabaseReadFailedString))
38 | }
39 | return
40 | }
41 | c.JSON(http.StatusOK, gin.H{
42 | "code": 0,
43 | "data": gin.H{
44 | "push_system_msg": boolToInt((pushSettings.Settings & model.SystemMessage) > 0),
45 | "push_reply_me": boolToInt((pushSettings.Settings & model.ReplyMeComment) > 0),
46 | "push_favorited": boolToInt((pushSettings.Settings & model.CommentInFavorited) > 0),
47 | },
48 | })
49 | }
50 |
51 | func setPush(c *gin.Context) {
52 | pushSystemMsg := c.PostForm("push_system_msg")
53 | pushReplyMe := c.PostForm("push_reply_me")
54 | pushFavorited := c.PostForm("push_favorited")
55 | user := c.MustGet("user").(base.User)
56 |
57 | var pushSettings model.PushType
58 | if pushSystemMsg == "1" {
59 | pushSettings += model.SystemMessage
60 | }
61 | if pushReplyMe == "1" {
62 | pushSettings += model.ReplyMeComment
63 | }
64 | if pushFavorited == "1" {
65 | pushSettings += model.CommentInFavorited
66 | }
67 |
68 | err := base.GetDb(false).Clauses(clause.OnConflict{
69 | UpdateAll: true,
70 | }).Create(&base.PushSettings{
71 | UserID: user.ID,
72 | Settings: pushSettings,
73 | }).Error
74 | if err != nil {
75 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err, "SavePushSettingsFailed", consts.DatabaseWriteFailedString))
76 | return
77 | }
78 | c.JSON(http.StatusOK, gin.H{
79 | "code": 0,
80 | })
81 | }
82 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module treehollow-v3-backend
2 |
3 | go 1.15
4 |
5 | require (
6 | github.com/ProtonMail/go-crypto v0.0.0-20201208181130-20fe99622a86 // indirect
7 | github.com/ProtonMail/gopenpgp/v2 v2.1.4
8 | github.com/SSSaaS/sssa-golang v0.0.0-20170502204618-d37d7782d752
9 | github.com/aws/aws-sdk-go v1.37.9
10 | github.com/fsnotify/fsnotify v1.4.9
11 | github.com/gin-contrib/cors v1.3.1
12 | github.com/gin-gonic/gin v1.6.3
13 | github.com/go-playground/validator/v10 v10.4.1 // indirect
14 | github.com/go-redis/cache/v8 v8.3.1
15 | github.com/go-redis/redis/v8 v8.5.0
16 | github.com/google/uuid v1.2.0
17 | github.com/gorilla/websocket v1.4.2
18 | github.com/iancoleman/orderedmap v0.2.0
19 | github.com/json-iterator/go v1.1.10 // indirect
20 | github.com/leodido/go-urn v1.2.1 // indirect
21 | github.com/magiconair/properties v1.8.4 // indirect
22 | github.com/mattn/go-isatty v0.0.12 // indirect
23 | github.com/mitchellh/mapstructure v1.4.1 // indirect
24 | github.com/oschwald/geoip2-golang v1.4.0
25 | github.com/oschwald/maxminddb-golang v1.8.0 // indirect
26 | github.com/pelletier/go-toml v1.8.1 // indirect
27 | github.com/pkg/errors v0.9.1
28 | github.com/robfig/cron/v3 v3.0.1
29 | github.com/shirou/gopsutil/v3 v3.21.1
30 | github.com/sideshow/apns2 v0.20.0
31 | github.com/sigurn/crc8 v0.0.0-20160107002456-e55481d6f45c
32 | github.com/sigurn/utils v0.0.0-20190728110027-e1fefb11a144 // indirect
33 | github.com/sirupsen/logrus v1.7.0 // indirect
34 | github.com/spf13/afero v1.5.1 // indirect
35 | github.com/spf13/cast v1.3.1 // indirect
36 | github.com/spf13/jwalterweatherman v1.1.0 // indirect
37 | github.com/spf13/pflag v1.0.5 // indirect
38 | github.com/spf13/viper v1.7.1
39 | github.com/ugorji/go v1.2.4 // indirect
40 | github.com/ulule/limiter/v3 v3.8.0
41 | github.com/vmihailenco/msgpack/v5 v5.1.0
42 | golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad // indirect
43 | golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c // indirect
44 | golang.org/x/text v0.3.5 // indirect
45 | gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
46 | gopkg.in/ezzarghili/recaptcha-go.v4 v4.3.0
47 | gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df
48 | gopkg.in/ini.v1 v1.62.0 // indirect
49 | gopkg.in/tucnak/telebot.v2 v2.3.5 // indirect
50 | gopkg.in/yaml.v2 v2.4.0 // indirect
51 | gorm.io/driver/mysql v1.0.4
52 | gorm.io/gorm v1.20.12
53 | )
54 |
--------------------------------------------------------------------------------
/pkg/route/security/route.go:
--------------------------------------------------------------------------------
1 | package security
2 |
3 | import (
4 | "github.com/gin-contrib/cors"
5 | "github.com/gin-gonic/gin"
6 | "github.com/spf13/viper"
7 | "github.com/ulule/limiter/v3"
8 | "log"
9 | "net"
10 | "net/http"
11 | "os"
12 | "path/filepath"
13 | "strings"
14 | "time"
15 | "treehollow-v3-backend/pkg/base"
16 | "treehollow-v3-backend/pkg/route/contents"
17 | "treehollow-v3-backend/pkg/utils"
18 | )
19 |
20 | func ApiListenHttp() {
21 | r := gin.Default()
22 | corsConfig := cors.DefaultConfig()
23 | corsConfig.AllowAllOrigins = true
24 | corsConfig.AllowHeaders = append(corsConfig.AllowHeaders, "TOKEN")
25 | r.Use(cors.New(corsConfig))
26 |
27 | contents.EmailLimiter = base.InitLimiter(limiter.Rate{
28 | Period: 24 * time.Hour,
29 | Limit: viper.GetInt64("max_email_per_ip_per_day"),
30 | }, "emailLimiter")
31 |
32 | r.POST("/v3/security/login/check_email",
33 | checkEmailParamsCheckMiddleware,
34 | checkEmailRegexMiddleware,
35 | checkEmailIsRegisteredUserMiddleware,
36 | checkEmailIsOldTreeholeUserMiddleware,
37 | checkEmailRateLimitVerificationCode,
38 | checkEmailReCaptchaValidationMiddleware,
39 | checkEmail)
40 | r.POST("/v3/security/login/check_email_unregister",
41 | checkEmailParamsCheckMiddleware,
42 | checkAccountIsRegistered,
43 | checkEmailRateLimitVerificationCode,
44 | checkEmailReCaptchaValidationMiddleware,
45 | unregisterEmail)
46 | r.POST("/v3/security/login/create_account",
47 | loginParamsCheckMiddleware,
48 | checkAccountNotRegistered,
49 | loginCheckIOSToken,
50 | createAccount)
51 | r.POST("/v3/security/login/login",
52 | loginParamsCheckMiddleware,
53 | checkAccountIsRegistered,
54 | loginGetUserMiddleware,
55 | loginCheckMaxDevices,
56 | loginCheckIOSToken,
57 | login)
58 | r.POST("/v3/security/login/change_password",
59 | checkAccountIsRegistered,
60 | changePassword)
61 | r.POST("/v3/security/login/unregister",
62 | checkAccountIsRegistered,
63 | deleteAccount)
64 | r.GET("/v3/security/devices/list", listDevices)
65 | r.POST("/v3/security/devices/terminate", terminateDevice)
66 | r.POST("/v3/security/logout", logout)
67 | r.POST("/v3/security/update_ios_token", updateIOSToken)
68 |
69 | listenAddr := viper.GetString("security_api_listen_address")
70 | if strings.Contains(listenAddr, ":") {
71 | _ = r.Run(listenAddr)
72 | } else {
73 | _ = os.MkdirAll(filepath.Dir(listenAddr), os.ModePerm)
74 | _ = os.Remove(listenAddr)
75 |
76 | listener, err := net.Listen("unix", listenAddr)
77 | utils.FatalErrorHandle(&err, "bind failed")
78 | log.Printf("Listening and serving HTTP on unix: %s.\n"+
79 | "Note: 0777 is not a safe permission for the unix socket file. "+
80 | "It would be better if the user manually set the permission after startup\n",
81 | listenAddr)
82 | _ = os.Chmod(listenAddr, 0777)
83 | err = http.Serve(listener, r)
84 | }
85 | }
86 |
--------------------------------------------------------------------------------
/cmd/test-pressure-connector/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "bufio"
5 | "fmt"
6 | "github.com/gorilla/websocket"
7 | "log"
8 | "net/http"
9 | "net/url"
10 | "os"
11 | "os/signal"
12 | "time"
13 | )
14 |
15 | func newClient(id int, token string, addr string, scheme string) {
16 |
17 | interrupt := make(chan os.Signal, 1)
18 | signal.Notify(interrupt, os.Interrupt)
19 |
20 | u := url.URL{Scheme: scheme, Host: addr, Path: "/v3/stream"}
21 | log.Printf("%d connecting to %s", id, u.String())
22 |
23 | header := http.Header{}
24 | header.Add("TOKEN", token)
25 | c, _, err := websocket.DefaultDialer.Dial(u.String(), header)
26 | if err != nil {
27 | log.Printf("%d dial failed:%s\n", id, err)
28 | return
29 | } else {
30 | log.Printf("%d connected to %s", id, u.String())
31 | }
32 | defer c.Close()
33 |
34 | done := make(chan struct{})
35 |
36 | go func() {
37 | defer close(done)
38 | for {
39 | _, message, err := c.ReadMessage()
40 | if err != nil {
41 | log.Println("read:", err)
42 | return
43 | }
44 | log.Printf("%d recv: %s", id, message)
45 | }
46 | }()
47 |
48 | //ticker := time.NewTicker(time.Second)
49 | //defer ticker.Stop()
50 |
51 | for {
52 | select {
53 | case <-done:
54 | return
55 | //case t := <-ticker.C:
56 | // err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
57 | // if err != nil {
58 | // log.Println("write:", err)
59 | // return
60 | // }
61 | case <-interrupt:
62 | log.Println("interrupt")
63 |
64 | // Cleanly close the connection by sending a close message and then
65 | // waiting (with timeout) for the server to close the connection.
66 | err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
67 | if err != nil {
68 | log.Printf("%d write close: %s\n", id, err)
69 | return
70 | }
71 | select {
72 | case <-done:
73 | case <-time.After(time.Second):
74 | }
75 | return
76 | }
77 | }
78 | }
79 |
// check panics with e when it is non-nil; it is a fail-fast helper for this test tool.
func check(e error) {
	if e == nil {
		return
	}
	panic(e)
}
85 |
86 | func main() {
87 | if len(os.Args) != 4 {
88 | fmt.Println("Please input scheme, host & token file")
89 | return
90 | }
91 | scheme := os.Args[1]
92 | host := os.Args[2]
93 | tokenFile := os.Args[3]
94 |
95 | file, err := os.Open(tokenFile)
96 | check(err)
97 | defer file.Close()
98 |
99 | var tokens []string
100 | scanner := bufio.NewScanner(file)
101 | for scanner.Scan() {
102 | tokens = append(tokens, scanner.Text())
103 | }
104 |
105 | if err = scanner.Err(); err != nil {
106 | panic(err)
107 | }
108 |
109 | for i, t := range tokens {
110 | token := t
111 | i2 := i
112 | go func() {
113 | newClient(i2, token, host, scheme)
114 | }()
115 | time.Sleep(10 * time.Millisecond)
116 | }
117 | select {}
118 | }
119 |
--------------------------------------------------------------------------------
/pkg/route/security/middlewares.go:
--------------------------------------------------------------------------------
1 | package security
2 |
3 | import (
4 | "github.com/gin-gonic/gin"
5 | "strconv"
6 | "strings"
7 | "treehollow-v3-backend/pkg/base"
8 | "treehollow-v3-backend/pkg/consts"
9 | "treehollow-v3-backend/pkg/logger"
10 | "treehollow-v3-backend/pkg/utils"
11 | )
12 |
13 | func loginCheckIOSToken(c *gin.Context) {
14 | deviceType := c.MustGet("device_type").(base.DeviceType)
15 | if deviceType == base.IOSDevice {
16 | iosDeviceToken := c.PostForm("ios_device_token")
17 | //if len(iosDeviceToken) < 1 || len(iosDeviceToken) > 100 {
18 | if len(iosDeviceToken) > 100 {
19 | base.HttpReturnWithErrAndAbort(c, -11, logger.NewSimpleError("NoIOSDeviceToken", "获取iOS推送口令失败", logger.WARN))
20 | return
21 | }
22 | }
23 | c.Next()
24 | }
25 |
26 | func loginParamsCheckMiddleware(c *gin.Context) {
27 | pwHashed := c.PostForm("password_hashed")
28 | email := strings.ToLower(c.PostForm("email"))
29 | deviceTypeStr := c.PostForm("device_type")
30 | deviceInfo := c.PostForm("device_info")
31 | iosDeviceToken := c.PostForm("ios_device_token")
32 |
33 | if len(email) > 100 || len(pwHashed) > 64 || len(deviceInfo) > 100 || len(iosDeviceToken) > 100 {
34 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("LoginParamsOutOfBound", "参数错误", logger.WARN))
35 | return
36 | }
37 | deviceTypeInt, err := strconv.Atoi(deviceTypeStr)
38 | deviceType := base.DeviceType(deviceTypeInt)
39 | if err != nil || (deviceType != base.AndroidDevice &&
40 | deviceType != base.IOSDevice &&
41 | deviceType != base.WebDevice) {
42 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("DeviceTypeError", "参数device_type错误", logger.WARN))
43 | return
44 | }
45 |
46 | c.Set("device_type", deviceType)
47 | c.Next()
48 | }
49 |
50 | func checkAccountNotRegistered(c *gin.Context) {
51 | email := strings.ToLower(c.PostForm("email"))
52 | emailHash := utils.HashEmail(email)
53 |
54 | var count int64
55 | err := base.GetDb(false).Where("email_hash = ?", emailHash).
56 | Model(&base.Email{}).Count(&count).Error
57 | if err != nil {
58 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "CheckAccountRegisteredFailed", consts.DatabaseReadFailedString))
59 | return
60 | }
61 | if count == 1 {
62 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("AlreadyRegisteredError", "你已经注册过了!", logger.WARN))
63 | return
64 | }
65 |
66 | c.Set("email_hash", emailHash)
67 | c.Next()
68 | }
69 |
70 | func checkAccountIsRegistered(c *gin.Context) {
71 | email := strings.ToLower(c.PostForm("email"))
72 | emailHash := utils.HashEmail(email)
73 |
74 | var count int64
75 | err := base.GetDb(false).Where("email_hash = ?", emailHash).
76 | Model(&base.Email{}).Count(&count).Error
77 | if err != nil {
78 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "CheckAccountUnRegisteredFailed", consts.DatabaseReadFailedString))
79 | return
80 | }
81 | if count != 1 {
82 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("NotRegisteredError", "你还没有注册过!", logger.WARN))
83 | return
84 | }
85 |
86 | c.Set("email_hash", emailHash)
87 | c.Next()
88 | }
89 |
--------------------------------------------------------------------------------
/pkg/s3/s3.go:
--------------------------------------------------------------------------------
1 | package s3
2 |
3 | import (
4 | "crypto/hmac"
5 | "crypto/sha1"
6 | "encoding/hex"
7 | "encoding/json"
8 | "errors"
9 | "fmt"
10 | "github.com/aws/aws-sdk-go/aws"
11 | "github.com/aws/aws-sdk-go/aws/credentials"
12 | "github.com/aws/aws-sdk-go/aws/session"
13 | "github.com/aws/aws-sdk-go/service/s3"
14 | "github.com/spf13/viper"
15 | "io"
16 | "io/ioutil"
17 | "log"
18 | "net/http"
19 | "net/url"
20 | "strings"
21 | "time"
22 | )
23 |
24 | func DogeCloudAPI(apiPath string, data map[string]interface{}, jsonMode bool) (ret map[string]interface{}, err error) {
25 | AccessKey := viper.GetString("DCAccessKey")
26 | SecretKey := viper.GetString("DCSecretKey")
27 |
28 | body := ""
29 | mime := ""
30 | if jsonMode {
31 | _body, err := json.Marshal(data)
32 | if err != nil {
33 | log.Fatalln(err)
34 | }
35 | body = string(_body)
36 | mime = "application/json"
37 | } else {
38 | values := url.Values{}
39 | for k, v := range data {
40 | values.Set(k, v.(string))
41 | }
42 | body = values.Encode()
43 | mime = "application/x-www-form-urlencoded"
44 | }
45 |
46 | signStr := apiPath + "\n" + body
47 | hmacObj := hmac.New(sha1.New, []byte(SecretKey))
48 | hmacObj.Write([]byte(signStr))
49 | sign := hex.EncodeToString(hmacObj.Sum(nil))
50 | Authorization := "TOKEN " + AccessKey + ":" + sign
51 |
52 | req, err := http.NewRequest("POST", "https://api.dogecloud.com"+apiPath, strings.NewReader(body))
53 | if err != nil {
54 | return
55 | }
56 | req.Header.Add("Content-Type", mime)
57 | req.Header.Add("Authorization", Authorization)
58 | client := http.Client{Timeout: 15 * time.Second}
59 | resp, err := client.Do(req)
60 | if err != nil {
61 | return
62 | } // 网络错误
63 | defer resp.Body.Close()
64 | r, _ := ioutil.ReadAll(resp.Body)
65 |
66 | _ = json.Unmarshal(r, &ret)
67 |
68 | fmt.Printf("[DogeCloudAPI] code: %d, msg: %s, data: %s\n", int(ret["code"].(float64)), ret["msg"], ret["data"])
69 | return
70 | }
71 |
72 | func Upload(filePath string, fileReader io.ReadSeeker) error {
73 | prof := make(map[string]interface{})
74 | prof["channel"] = "OSS_FULL"
75 | prof["scopes"] = "*"
76 | r, err := DogeCloudAPI("/auth/tmp_token.json", prof, true)
77 | if err != nil {
78 | return err
79 | }
80 | if r["data"] == nil {
81 | return errors.New("invalid DogeCloud response")
82 | }
83 | data := r["data"].(map[string]interface{})
84 | creds := data["Credentials"].(map[string]interface{})
85 |
86 | s3Config := &aws.Config{
87 | Credentials: credentials.NewStaticCredentials(creds["accessKeyId"].(string), creds["secretAccessKey"].(string), creds["sessionToken"].(string)),
88 | Region: aws.String("automatic"),
89 | Endpoint: aws.String(viper.GetString("DCS3Endpoint")),
90 | }
91 |
92 | newSession, err := session.NewSession(s3Config)
93 | if err != nil {
94 | return err
95 | }
96 |
97 | s3Client := s3.New(newSession)
98 |
99 | _, err = s3Client.PutObject(&s3.PutObjectInput{
100 | Bucket: aws.String(viper.GetString("DCS3Bucket")),
101 | Key: aws.String(filePath),
102 | Body: fileReader,
103 | })
104 | return err
105 | }
106 |
--------------------------------------------------------------------------------
/pkg/route/security/devices.go:
--------------------------------------------------------------------------------
1 | package security
2 |
3 | import (
4 | "errors"
5 | "github.com/gin-gonic/gin"
6 | "gorm.io/gorm"
7 | "net/http"
8 | "treehollow-v3-backend/pkg/base"
9 | "treehollow-v3-backend/pkg/consts"
10 | "treehollow-v3-backend/pkg/logger"
11 | "treehollow-v3-backend/pkg/utils"
12 | )
13 |
14 | func devicesToJson(devices []base.Device) []gin.H {
15 | var data []gin.H
16 | for _, device := range devices {
17 | data = append(data, gin.H{
18 | "device_uuid": device.ID,
19 | "login_date": device.CreatedAt.Format("2006-01-02"),
20 | "device_info": device.DeviceInfo,
21 | "device_type": int32(device.Type),
22 | })
23 | }
24 | return data
25 | }
26 |
27 | func listDevices(c *gin.Context) {
28 | token := c.GetHeader("TOKEN")
29 | var device base.Device
30 | err := base.GetDb(false).Model(&base.Device{}).
31 | Where("token = ? and created_at > ?", token, utils.GetEarliestAuthenticationTime()).
32 | First(&device).
33 | Error
34 | if err != nil {
35 | if errors.Is(err, gorm.ErrRecordNotFound) {
36 | base.HttpReturnWithErrAndAbort(c, -100, logger.NewSimpleError("TokenExpired",
37 | "登录凭据过期,请使用邮箱重新登录。", logger.INFO))
38 | } else {
39 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "GetDeviceByTokenFailed", consts.DatabaseReadFailedString))
40 | }
41 | return
42 | }
43 |
44 | var devices []base.Device
45 | err = base.GetDb(false).Model(&base.Device{}).
46 | Where("user_id = ? and created_at > ?", device.UserID, utils.GetEarliestAuthenticationTime()).
47 | Find(&devices).
48 | Error
49 | if err != nil {
50 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "GetDevicesByUserIDFailed", consts.DatabaseReadFailedString))
51 | }
52 | data := devicesToJson(devices)
53 | c.JSON(http.StatusOK, gin.H{
54 | "code": 0,
55 | "data": data,
56 | "this_device": device.ID,
57 | })
58 | }
59 |
60 | func terminateDevice(c *gin.Context) {
61 | token := c.GetHeader("TOKEN")
62 | var device base.Device
63 | err := base.GetDb(false).Model(&base.Device{}).
64 | Where("token = ? and created_at > ?", token, utils.GetEarliestAuthenticationTime()).
65 | First(&device).
66 | Error
67 | if err != nil {
68 | if errors.Is(err, gorm.ErrRecordNotFound) {
69 | base.HttpReturnWithErrAndAbort(c, -100, logger.NewSimpleError("TokenExpired",
70 | "登录凭据过期,请使用邮箱重新登录。", logger.INFO))
71 | } else {
72 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "GetDeviceByTokenFailed", consts.DatabaseReadFailedString))
73 | }
74 | return
75 | }
76 |
77 | deviceUUID := c.PostForm("device_uuid")
78 | result := base.GetDb(false).
79 | Where("user_id = ? and id = ? and created_at > ?", device.UserID, deviceUUID, utils.GetEarliestAuthenticationTime()).
80 | Delete(&base.Device{})
81 | if result.Error != nil {
82 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(result.Error, "DeleteDeviceByUUIDFailed", consts.DatabaseWriteFailedString))
83 | return
84 | }
85 | if result.RowsAffected != 1 {
86 | base.HttpReturnWithErrAndAbort(c, -100, logger.NewSimpleError("NoDeviceFound", "找不到这个设备。", logger.WARN))
87 | return
88 | }
89 | c.JSON(http.StatusOK, gin.H{
90 | "code": 0,
91 | })
92 | _ = base.DelUserCache(device.Token)
93 | }
94 |
--------------------------------------------------------------------------------
/pkg/base/hooks.go:
--------------------------------------------------------------------------------
1 | package base
2 |
3 | import (
4 | "gorm.io/gorm"
5 | "math"
6 | "strconv"
7 | "treehollow-v3-backend/pkg/model"
8 | "treehollow-v3-backend/pkg/utils"
9 | )
10 |
11 | //TODO: (high priority)delete push messages for ios and android
12 | //Set the first registered user to be superuser
13 | func (u *User) AfterCreate(tx *gorm.DB) (err error) {
14 | if u.ID == 1 {
15 | err = tx.Model(u).Update("role", SuperUserRole).Error
16 | }
17 | return
18 | }
19 |
20 | func (post *Post) AfterCreate(tx *gorm.DB) (err error) {
21 | err = tx.Create(&Attention{UserID: post.UserID, PostID: post.ID}).Error
22 | return
23 | }
24 |
25 | func (comment *Comment) AfterCreate(tx *gorm.DB) (err error) {
26 | var attention int64
27 | err = tx.Model(&Attention{}).Where(&Attention{UserID: comment.UserID, PostID: comment.PostID}).Count(&attention).Error
28 | if err == nil && attention == 0 {
29 | err = tx.Create(&Attention{UserID: comment.UserID, PostID: comment.PostID}).Error
30 | }
31 | if err == nil {
32 | err = tx.Model(&Post{}).Where("id = ?", comment.PostID).
33 | Update("reply_num", gorm.Expr("reply_num + 1")).Error
34 | }
35 | return
36 | }
37 |
38 | func (attention *Attention) AfterCreate(tx *gorm.DB) (err error) {
39 | err = tx.Table("posts").Where("id = ?", attention.PostID).
40 | UpdateColumn("like_num", gorm.Expr("like_num + 1")).Error
41 | return
42 | }
43 |
44 | func (attention *Attention) AfterDelete(tx *gorm.DB) (err error) {
45 | err = tx.Table("posts").Where("id = ?", attention.PostID).
46 | UpdateColumn("like_num", gorm.Expr("like_num - 1")).Error
47 | return
48 | }
49 |
50 | func (report *Report) AfterCreate(tx *gorm.DB) (err error) {
51 | if report.Type == UserReport && !report.IsComment {
52 | err = tx.Table("posts").Where("id = ?", report.PostID).
53 | UpdateColumn("report_num", gorm.Expr("report_num + 1")).Error
54 | }
55 | return
56 | }
57 |
58 | func calcReportedTimes(ban *Ban) string {
59 | return strconv.Itoa(int(math.Round(float64(ban.ExpireAt-ban.CreatedAt.Unix()) / 86400.0)))
60 | }
61 |
62 | func (ban *Ban) AfterCreate(tx *gorm.DB) (err error) {
63 | err = tx.Create(&SystemMessage{
64 | UserID: ban.UserID,
65 | BanID: ban.ID,
66 | Title: "封禁提示",
67 | Text: ban.Reason + "\n\n这是您第" + calcReportedTimes(ban) + "次被举报,在" + utils.TimestampToString(ban.ExpireAt) + "之前您将无法发布树洞。",
68 | }).Error
69 | return
70 | }
71 |
72 | //TODO: (low priority)maybe, show reason here?
73 | func (ban *Ban) AfterDelete(tx *gorm.DB) (err error) {
74 | err = tx.Create(&SystemMessage{
75 | UserID: ban.UserID,
76 | BanID: ban.ID,
77 | Title: "解除封禁提示",
78 | Text: "您的以下封禁已被管理员手动解除:\n\n\"" + ban.Reason + "\"",
79 | }).Error
80 | return
81 | }
82 |
83 | func (msg *SystemMessage) AfterCreate(tx *gorm.DB) error {
84 | msgs := []PushMessage{{
85 | UpdatedAt: msg.CreatedAt,
86 | Title: msg.Title,
87 | UserID: msg.UserID,
88 | Message: msg.Text,
89 | BanID: msg.BanID,
90 | Type: model.SystemMessage,
91 | }}
92 | err := PreProcessPushMessages(tx, msgs)
93 | if err != nil {
94 | return err
95 | }
96 |
97 | go func() {
98 | SendToPushService(msgs)
99 | }()
100 | return nil
101 | }
102 |
--------------------------------------------------------------------------------
/pkg/push/route.go:
--------------------------------------------------------------------------------
1 | package push
2 |
3 | import (
4 | "github.com/gin-contrib/cors"
5 | "github.com/gin-gonic/gin"
6 | "github.com/spf13/viper"
7 | "log"
8 | "net"
9 | "net/http"
10 | "os"
11 | "path/filepath"
12 | "strings"
13 | "time"
14 | "treehollow-v3-backend/pkg/base"
15 | "treehollow-v3-backend/pkg/route/auth"
16 | "treehollow-v3-backend/pkg/utils"
17 | )
18 |
19 | func ApiListenHttp() {
20 | r := gin.Default()
21 | corsConfig := cors.DefaultConfig()
22 | corsConfig.AllowAllOrigins = true
23 | corsConfig.AllowHeaders = append(corsConfig.AllowHeaders, "TOKEN")
24 | r.Use(cors.New(corsConfig))
25 |
26 | Api = New(time.Duration(viper.GetInt64("ws_ping_period_sec"))*time.Second,
27 | time.Duration(viper.GetInt64("ws_pong_timeout_sec"))*time.Second)
28 |
29 | go func() {
30 | r2 := gin.Default()
31 | r2.POST("/send_messages", func(c *gin.Context) {
32 | var messages []base.PushMessage
33 | //data, err := ioutil.ReadAll(c.Request.Body)
34 | c.String(http.StatusOK, "")
35 | err := c.BindJSON(&messages)
36 | if err != nil {
37 | //base.HttpReturnWithCodeMinusOne(c, logger.NewError(err, "error reading request body", "error reading request body"))
38 | log.Printf("push service read request body error: %s\n", err)
39 | return
40 | }
41 | SendMessages(messages, Api, false)
42 | })
43 | r2.POST("/delete_messages", func(c *gin.Context) {
44 | var commendID int32
45 | //data, err := ioutil.ReadAll(c.Request.Body)
46 | c.String(http.StatusOK, "")
47 | err := c.BindJSON(&commendID)
48 | if err != nil {
49 | //base.HttpReturnWithCodeMinusOne(c, logger.NewError(err, "error reading request body", "error reading request body"))
50 | log.Printf("push deletion service read request body error: %s\n", err)
51 | return
52 | }
53 |
54 | var msgs []base.PushMessage
55 | err = base.GetDb(false).Model(&base.PushMessage{}).
56 | Where("comment_id = ? and do_push = 1", commendID).
57 | Find(&msgs).Error
58 | if err != nil {
59 | log.Printf("push deletion service read push messages error: %s\n", err)
60 | return
61 | }
62 |
63 | err = base.GetDb(false).Where("comment_id = ?", commendID).Delete(&base.PushMessage{}).Error
64 | if err != nil {
65 | log.Printf("push deletion service delete push messages error: %s\n", err)
66 | return
67 | }
68 |
69 | SendMessages(msgs, Api, true)
70 | })
71 | _ = r2.Run(viper.GetString("push_internal_api_listen_address"))
72 | }()
73 |
74 | r.Use(auth.AuthMiddleware())
75 | r.GET("/v3/stream",
76 | auth.DisallowUnregisteredUsers(),
77 | Api.Handle)
78 |
79 | listenAddr := viper.GetString("push_api_listen_address")
80 | if strings.Contains(listenAddr, ":") {
81 | _ = r.Run(listenAddr)
82 | } else {
83 | _ = os.MkdirAll(filepath.Dir(listenAddr), os.ModePerm)
84 | _ = os.Remove(listenAddr)
85 |
86 | listener, err := net.Listen("unix", listenAddr)
87 | utils.FatalErrorHandle(&err, "bind failed")
88 | log.Printf("Listening and serving HTTP on unix: %s.\n"+
89 | "Note: 0777 is not a safe permission for the unix socket file. "+
90 | "It would be better if the user manually set the permission after startup\n",
91 | listenAddr)
92 | _ = os.Chmod(listenAddr, 0777)
93 | err = http.Serve(listener, r)
94 | }
95 |
96 | }
97 |
--------------------------------------------------------------------------------
/pkg/push/client.go:
--------------------------------------------------------------------------------
1 | // from https://github.com/gotify/server/blob/3454dcd60226acf121009975d947f05d41267283/api/stream/client.go
2 | package push
3 |
4 | import (
5 | "log"
6 | "time"
7 |
8 | "github.com/gorilla/websocket"
9 | )
10 |
11 | const (
12 | writeWait = 2 * time.Second
13 | )
14 |
15 | var ping = func(conn *websocket.Conn) error {
16 | return conn.WriteMessage(websocket.PingMessage, nil)
17 | }
18 |
19 | var writeBytes = func(conn *websocket.Conn, data []byte) error {
20 | return conn.WriteMessage(websocket.TextMessage, data)
21 | }
22 |
// client is one websocket subscriber of the push stream.
type client struct {
	// conn is the underlying websocket connection.
	conn *websocket.Conn
	// onClose is invoked by NotifyClose (but not by Close).
	onClose func(*client)
	// write carries outgoing messages to startWriteHandler; it is closed when the client closes.
	write chan *[]byte
	// token is the value passed to newClient; presumably the user's TOKEN header. Confirm.
	token string
	// once ensures the close logic in Close/NotifyClose runs at most once.
	once once
}
30 |
31 | func newClient(conn *websocket.Conn, token string, onClose func(*client)) *client {
32 | return &client{
33 | conn: conn,
34 | write: make(chan *[]byte, 1),
35 | token: token,
36 | onClose: onClose,
37 | }
38 | }
39 |
40 | // Close closes the connection.
41 | func (c *client) Close() {
42 | c.once.Do(func() {
43 | _ = c.conn.Close()
44 | close(c.write)
45 | })
46 | }
47 |
48 | // NotifyClose closes the connection and notifies that the connection was closed.
49 | func (c *client) NotifyClose() {
50 | c.once.Do(func() {
51 | _ = c.conn.Close()
52 | close(c.write)
53 | c.onClose(c)
54 | })
55 | }
56 |
57 | // startWriteHandler starts listening on the client connection. As we do not need anything from the client,
58 | // we ignore incoming messages. Leaves the loop on errors.
59 | func (c *client) startReading(pongWait time.Duration) {
60 | defer c.NotifyClose()
61 | c.conn.SetReadLimit(64)
62 | _ = c.conn.SetReadDeadline(time.Now().Add(pongWait))
63 | c.conn.SetPongHandler(func(appData string) error {
64 | _ = c.conn.SetReadDeadline(time.Now().Add(pongWait))
65 | return nil
66 | })
67 | for {
68 | if _, _, err := c.conn.NextReader(); err != nil {
69 | printWebSocketError("ReadError", err)
70 | return
71 | }
72 | }
73 | }
74 |
75 | // startWriteHandler starts the write loop. The method has the following tasks:
76 | // * ping the client in the interval provided as parameter
77 | // * write messages send by the channel to the client
78 | // * on errors exit the loop.
79 | func (c *client) startWriteHandler(pingPeriod time.Duration) {
80 | pingTicker := time.NewTicker(pingPeriod)
81 | defer func() {
82 | c.NotifyClose()
83 | pingTicker.Stop()
84 | }()
85 |
86 | for {
87 | select {
88 | case message, ok := <-c.write:
89 | if !ok {
90 | return
91 | }
92 |
93 | _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
94 | if err := writeBytes(c.conn, *message); err != nil {
95 | printWebSocketError("WriteError", err)
96 | return
97 | }
98 | case <-pingTicker.C:
99 | _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
100 | if err := ping(c.conn); err != nil {
101 | printWebSocketError("PingError", err)
102 | return
103 | }
104 | }
105 | }
106 | }
107 |
108 | func printWebSocketError(prefix string, err error) {
109 | closeError, ok := err.(*websocket.CloseError)
110 |
111 | if ok && closeError != nil && (closeError.Code == 1000 || closeError.Code == 1001) {
112 | // normal closure
113 | return
114 | }
115 |
116 | log.Printf("error WebSocket: %s %s", prefix, err)
117 | }
118 |
--------------------------------------------------------------------------------
/pkg/route/contents/vote.go:
--------------------------------------------------------------------------------
1 | package contents
2 |
3 | import (
4 | "encoding/json"
5 | "errors"
6 | "github.com/gin-gonic/gin"
7 | "github.com/iancoleman/orderedmap"
8 | "gorm.io/gorm"
9 | "gorm.io/gorm/clause"
10 | "net/http"
11 | "strconv"
12 | "treehollow-v3-backend/pkg/base"
13 | "treehollow-v3-backend/pkg/consts"
14 | "treehollow-v3-backend/pkg/logger"
15 | "treehollow-v3-backend/pkg/utils"
16 | )
17 |
18 | func sendVote(c *gin.Context) {
19 | user := c.MustGet("user").(base.User)
20 | canViewDelete := base.CanViewDeletedPost(&user)
21 | option := c.PostForm("option")
22 |
23 | pid, err := strconv.Atoi(c.PostForm("pid"))
24 | if err != nil {
25 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err, "SendVoteInvalidPid", "投票操作失败,pid不合法"))
26 | return
27 | }
28 |
29 | _ = base.GetDb(false).Transaction(func(tx *gorm.DB) error {
30 | var post base.Post
31 |
32 | err3 := utils.UnscopedTx(tx, canViewDelete).Clauses(clause.Locking{Strength: "UPDATE"}).
33 | First(&post, int32(pid)).Error
34 | if err3 != nil {
35 | if errors.Is(err3, gorm.ErrRecordNotFound) {
36 | base.HttpReturnWithCodeMinusOne(c, logger.NewSimpleError("SendVoteNoPid", "投票失败,pid不存在", logger.WARN))
37 | } else {
38 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err3, "SendVoteFailedGetPost", consts.DatabaseReadFailedString))
39 | }
40 | return err3
41 | }
42 |
43 | voteData := orderedmap.New()
44 | err = json.Unmarshal([]byte(post.VoteData), &voteData)
45 | if err != nil {
46 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err, "BadVoteData", consts.DatabaseDamagedString))
47 | return nil
48 | }
49 | voteOptionCount, optionExist := voteData.Get(option)
50 | if !optionExist {
51 | base.HttpReturnWithCodeMinusOne(c, logger.NewSimpleError("VoteNoOption", "投票失败,选项不存在", logger.ERROR))
52 | return nil
53 | }
54 |
55 | var count int64
56 | err3 = tx.Model(&base.Vote{}).Where("user_id = ? and post_id = ?", user.ID, post.ID).
57 | Count(&count).Error
58 | if err3 != nil {
59 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err3, "SendVoteFailedGetCount", consts.DatabaseReadFailedString))
60 | return err3
61 | }
62 | if count > 0 {
63 | base.HttpReturnWithCodeMinusOne(c, logger.NewSimpleError("AlreadyVoted", "投票失败,已经投过票了", logger.WARN))
64 | return errors.New("投票失败,已经投过票了")
65 | }
66 |
67 | err3 = tx.Create(&base.Vote{
68 | PostID: post.ID,
69 | UserID: user.ID,
70 | Option: option,
71 | }).Error
72 | if err3 != nil {
73 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err3, "SaveVoteFailed", consts.DatabaseWriteFailedString))
74 | return err3
75 | }
76 |
77 | voteData.Set(option, voteOptionCount.(float64)+1)
78 | _newVoteData, _ := json.Marshal(voteData)
79 | newVoteData := string(_newVoteData)
80 |
81 | err3 = tx.Table("posts").Where("id = ?", post.ID).
82 | UpdateColumn("vote_data", newVoteData).Error
83 | if err3 != nil {
84 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err3, "SaveVoteFPostFailed", consts.DatabaseWriteFailedString))
85 | return err3
86 | }
87 |
88 | c.JSON(http.StatusOK, gin.H{
89 | "code": 0,
90 | "vote": gin.H{
91 | "voted": option,
92 | "vote_options": voteData.Keys(),
93 | "vote_data": voteData,
94 | },
95 | })
96 |
97 | return nil
98 | })
99 | }
100 |
--------------------------------------------------------------------------------
/pkg/base/permissions.go:
--------------------------------------------------------------------------------
1 | package base
2 |
3 | import (
4 | "treehollow-v3-backend/pkg/utils"
5 | )
6 |
7 | func GetPermissionsByPost(user *User, post *Post) []string {
8 | return getPermissions(user, post, false)
9 | }
10 |
11 | func isDeleter(role UserRole) bool {
12 | return role == DeleterRole || role == Deleter2Role || role == Deleter3Role
13 | }
14 |
15 | func getPermissions(user *User, post *Post, isComment bool) []string {
16 | rtn := []string{"report"}
17 | if !isComment {
18 | rtn = append(rtn, "fold")
19 | }
20 | timestamp := utils.GetTimeStamp()
21 | if (user.Role == AdminRole || user.Role == SuperUserRole ||
22 | ((timestamp-post.CreatedAt.Unix() <= 120) && (user.ID == post.UserID))) && (!post.DeletedAt.Valid) {
23 | rtn = append(rtn, "delete")
24 | }
25 |
26 | if user.Role == AdminRole || user.Role == SuperUserRole {
27 | rtn = append(rtn, "set_tag")
28 | if post.DeletedAt.Valid {
29 | rtn = append(rtn, "unban")
30 | rtn = append(rtn, "undelete_unban")
31 | } else {
32 | rtn = append(rtn, "delete_ban")
33 | }
34 | } else if (timestamp-post.CreatedAt.Unix() <= 172800) && isDeleter(user.Role) && !post.DeletedAt.Valid {
35 | rtn = append(rtn, "delete_ban")
36 | } else if (timestamp-post.CreatedAt.Unix() <= 172800) && user.Role == UnDeleterRole && post.DeletedAt.Valid {
37 | rtn = append(rtn, "undelete_unban")
38 | }
39 |
40 | return rtn
41 | }
42 |
43 | func GetPermissionsByComment(user *User, comment *Comment) []string {
44 | return getPermissions(user, &Post{
45 | DeletedAt: comment.DeletedAt,
46 | CreatedAt: comment.CreatedAt,
47 | UserID: comment.UserID,
48 | }, true)
49 | }
50 |
51 | func GetReportWeight(user *User) int32 {
52 | return 10
53 | }
54 |
55 | func NeedLimiter(user *User) bool {
56 | return user.Role == NormalUserRole || isDeleter(user.Role) || user.Role == UnDeleterRole
57 | }
58 |
59 | func CanViewDeletedPost(user *User) bool {
60 | return user.Role == AdminRole || user.Role == UnDeleterRole ||
61 | user.Role == SuperUserRole
62 | }
63 |
64 | func GetDeletePostRateLimitIn24h(userRole UserRole) int64 {
65 | switch userRole {
66 | case SuperUserRole:
67 | return 10000
68 | case AdminRole:
69 | return 20
70 | case DeleterRole:
71 | return 20
72 | case Deleter2Role:
73 | return 5
74 | case Deleter3Role:
75 | return 0
76 | default:
77 | return 0
78 | }
79 | }
80 |
81 | func CanOverrideBan(user *User) bool {
82 | return user.Role == AdminRole || isDeleter(user.Role) || user.Role == UnDeleterRole ||
83 | user.Role == SuperUserRole
84 | }
85 |
86 | func CanViewStatistics(user *User) bool {
87 | return user.Role == SuperUserRole || user.Role == AdminRole
88 | }
89 |
90 | func CanViewAllSystemMessages(user *User) bool {
91 | return user.Role == SuperUserRole || user.Role == AdminRole
92 | }
93 |
94 | func CanViewReports(user *User) bool {
95 | return user.Role == AdminRole || isDeleter(user.Role) || user.Role == UnDeleterRole ||
96 | user.Role == SuperUserRole
97 | }
98 |
99 | func CanViewLogs(user *User) bool {
100 | return user.Role == SuperUserRole
101 | }
102 |
103 | func CanShowHelp(user *User) bool {
104 | return user.Role == AdminRole || isDeleter(user.Role) || user.Role == UnDeleterRole ||
105 | user.Role == SuperUserRole
106 | }
107 |
108 | func CanShutdown(user *User) bool {
109 | return user.Role == SuperUserRole
110 | }
111 |
112 | func CanViewDecryptionMessages(user *User) bool {
113 | return user.Role == SuperUserRole
114 | }
115 |
--------------------------------------------------------------------------------
/pkg/mail/mail.go:
--------------------------------------------------------------------------------
1 | package mail
2 |
3 | import (
4 | "github.com/spf13/viper"
5 | "gopkg.in/gomail.v2"
6 | "strconv"
7 | )
8 |
9 | func SendValidationEmail(code string, recipient string) error {
10 | websiteName := viper.GetString("name")
11 | m := gomail.NewMessage()
12 | m.SetHeader("From", viper.GetString("smtp_username"))
13 | m.SetHeader("To", recipient)
14 | title := "【" + websiteName + "】验证码"
15 | m.SetHeader("Subject", title)
16 |
17 | msg := `
18 |
19 |
20 |
21 |
22 | ` + title + `
23 |
24 |
25 | 欢迎您注册` + websiteName + `!
26 | 这是您的验证码,有效时间12小时。
27 | ` + code + `
28 |
29 | `
30 |
31 | port, err := strconv.Atoi(viper.GetString("smtp_port"))
32 | if err != nil {
33 | return err
34 | }
35 | m.SetBody("text/html", msg)
36 | m.AddAlternative("text/plain", "您好:\n\n欢迎您注册"+websiteName+"!\n\n"+code+"\n这是您注册"+websiteName+"的验证码,有效时间12小时。\n")
37 | d := gomail.NewDialer(viper.GetString("smtp_host"), port, viper.GetString("smtp_username"), viper.GetString("smtp_password"))
38 |
39 | if err = d.DialAndSend(m); err != nil {
40 | return err
41 | }
42 | return nil
43 | }
44 |
45 | func SendUnregisterValidationEmail(code string, recipient string) error {
46 | websiteName := viper.GetString("name")
47 | m := gomail.NewMessage()
48 | m.SetHeader("From", viper.GetString("smtp_username"))
49 | m.SetHeader("To", recipient)
50 | title := "【" + websiteName + "】验证码"
51 | m.SetHeader("Subject", title)
52 |
53 | msg := `
54 |
55 |
56 |
57 |
58 | ` + title + `
59 |
60 |
61 | 您好,您正在注销` + websiteName + `。
62 | 这是您的验证码,有效时间12小时。
63 | ` + code + `
64 |
65 | `
66 |
67 | port, err := strconv.Atoi(viper.GetString("smtp_port"))
68 | if err != nil {
69 | return err
70 | }
71 | m.SetBody("text/html", msg)
72 | m.AddAlternative("text/plain", "您好:\n\n您好,您正在注销"+websiteName+"。\n\n"+code+"\n这是您的验证码,有效时间12小时。\n")
73 | d := gomail.NewDialer(viper.GetString("smtp_host"), port, viper.GetString("smtp_username"), viper.GetString("smtp_password"))
74 |
75 | if err = d.DialAndSend(m); err != nil {
76 | return err
77 | }
78 | return nil
79 | }
80 |
81 | func SendPasswordNonceEmail(nonce string, recipient string) error {
82 | websiteName := viper.GetString("name")
83 | m := gomail.NewMessage()
84 | m.SetHeader("From", viper.GetString("smtp_username"))
85 | m.SetHeader("To", recipient)
86 | title := "欢迎您注册" + websiteName
87 | m.SetHeader("Subject", title)
88 |
89 | msg := `
90 |
91 |
92 |
93 |
94 | ` + title + `
95 |
96 |
97 | 欢迎您注册` + websiteName + `!
98 | 下方的字符串是当您忘记密码时可以帮助您找回密码的口令,请您妥善保管。
99 | ` + nonce + `
100 |
101 | `
102 |
103 | port, err := strconv.Atoi(viper.GetString("smtp_port"))
104 | if err != nil {
105 | return err
106 | }
107 | m.SetBody("text/html", msg)
108 | m.AddAlternative("text/plain", "您好:\n\n欢迎您注册"+websiteName+"!\n下方的字符串是当您忘记密码时可以帮助您找回密码的口令,请您妥善保管。\n"+nonce+"\n")
109 | d := gomail.NewDialer(viper.GetString("smtp_host"), port, viper.GetString("smtp_username"), viper.GetString("smtp_password"))
110 |
111 | if err = d.DialAndSend(m); err != nil {
112 | return err
113 | }
114 | return nil
115 | }
116 |
--------------------------------------------------------------------------------
/pkg/base/cache.go:
--------------------------------------------------------------------------------
1 | package base
2 |
3 | import (
4 | "context"
5 | "github.com/go-redis/cache/v8"
6 | "gorm.io/gorm"
7 | "log"
8 | "strconv"
9 | "time"
10 | "treehollow-v3-backend/pkg/consts"
11 | "treehollow-v3-backend/pkg/logger"
12 | "treehollow-v3-backend/pkg/utils"
13 | )
14 |
15 | var tokenCache *cache.Cache
16 | var commentCache *cache.Cache
17 |
18 | const CommentCacheExpireTime = 5 * time.Hour
19 | const TOKENCacheExpireTime = 1 * time.Minute
20 |
21 | func initCache() {
22 | tokenCache = cache.New(&cache.Options{Redis: redisClient})
23 | commentCache = cache.New(&cache.Options{Redis: redisClient})
24 | }
25 |
26 | func GetUserWithCache(token string) (User, error) {
27 | ctx := context.TODO()
28 | var user User
29 | err := tokenCache.Get(ctx, "token"+token, &user)
30 | if err == nil {
31 | return user, nil
32 | } else {
33 | subQuery := db.Model(&Device{}).Distinct().
34 | Where("token = ? and created_at > ?", token, utils.GetEarliestAuthenticationTime()).
35 | Select("user_id")
36 | err = db.Where("id = (?)", subQuery).First(&user).Error
37 | if err == nil {
38 | err = tokenCache.Set(&cache.Item{
39 | Ctx: ctx,
40 | Key: "token" + token,
41 | Value: &user,
42 | TTL: TOKENCacheExpireTime,
43 | })
44 | }
45 | return user, err
46 | }
47 | }
48 |
49 | func DelUserCache(token string) error {
50 | ctx := context.TODO()
51 | err := tokenCache.Delete(ctx, "token"+token)
52 | if err != nil {
53 | log.Printf("DelUserCache error: %s\n", err)
54 | }
55 | return err
56 | }
57 |
58 | func GetCommentsWithCache(post *Post, now time.Time) ([]Comment, error) {
59 | pid := post.ID
60 | if !NeedCacheComment(post, now) {
61 | return GetComments(pid)
62 | }
63 |
64 | ctx := context.TODO()
65 | pidStr := strconv.Itoa(int(pid))
66 | var comments []Comment
67 | err := commentCache.Get(ctx, "pid"+pidStr, &comments)
68 | if err == nil {
69 | return comments, err
70 | } else {
71 | comments, err = GetComments(pid)
72 | if err == nil {
73 | err = commentCache.Set(&cache.Item{
74 | Ctx: ctx,
75 | Key: "pid" + pidStr,
76 | Value: &comments,
77 | TTL: CommentCacheExpireTime,
78 | })
79 | }
80 | return comments, err
81 | }
82 | }
83 |
84 | func GetMultipleCommentsWithCache(tx *gorm.DB, posts []Post, now time.Time) (map[int32][]Comment, *logger.InternalError) {
85 | ctx := context.TODO()
86 | rtn := make(map[int32][]Comment)
87 | noCachePids := make(map[int32]bool)
88 | var noCachePidsArray []int32
89 | for _, post := range posts {
90 | pid := post.ID
91 | if !NeedCacheComment(&post, now) {
92 | noCachePids[pid] = true
93 | noCachePidsArray = append(noCachePidsArray, pid)
94 | continue
95 | }
96 |
97 | pidStr := strconv.Itoa(int(pid))
98 | var comments []Comment
99 | err := commentCache.Get(ctx, "pid"+pidStr, &comments)
100 | if err == nil {
101 | rtn[pid] = comments
102 | } else {
103 | noCachePids[pid] = true
104 | noCachePidsArray = append(noCachePidsArray, pid)
105 | continue
106 | }
107 | }
108 |
109 | if len(noCachePidsArray) > 0 {
110 | comments, err := GetMultipleComments(tx, noCachePidsArray)
111 | if err != nil {
112 | return nil, logger.NewError(err, "SQLGetMultipleCommentsFailed", consts.DatabaseReadFailedString)
113 | }
114 | for _, comment := range comments {
115 | rtn[comment.PostID] = append(rtn[comment.PostID], comment)
116 | }
117 | }
118 | for _, post := range posts {
119 | pid := post.ID
120 | if _, noCache := noCachePids[pid]; noCache {
121 | comments2, commentsExist := rtn[pid]
122 | if !commentsExist {
123 | comments2 = []Comment{}
124 | }
125 | err := commentCache.Set(&cache.Item{
126 | Ctx: ctx,
127 | Key: "pid" + strconv.Itoa(int(pid)),
128 | Value: &comments2,
129 | TTL: CommentCacheExpireTime,
130 | })
131 | if err != nil {
132 | return nil, logger.NewError(err, "CommentCacheSetFailed", consts.DatabaseReadFailedString)
133 | }
134 | }
135 | }
136 | return rtn, nil
137 | }
138 |
139 | func DelCommentCache(pid int) error {
140 | ctx := context.TODO()
141 | err := commentCache.Delete(ctx, "pid"+strconv.Itoa(pid))
142 | if err != nil {
143 | log.Printf("DelCommentCache error: %s\n", err)
144 | }
145 | return err
146 | }
147 |
148 | func NeedCacheComment(post *Post, now time.Time) bool {
149 | return now.Before(post.CreatedAt.AddDate(0, 0, 365))
150 | //return now.Before(post.CreatedAt.AddDate(0, 0, 2))
151 | }
152 |
--------------------------------------------------------------------------------
/pkg/route/security/deleteAccount.go:
--------------------------------------------------------------------------------
1 | package security
2 |
3 | import (
4 | "errors"
5 | "github.com/gin-gonic/gin"
6 | "github.com/spf13/viper"
7 | "gorm.io/gorm"
8 | "net/http"
9 | "strconv"
10 | "strings"
11 | "treehollow-v3-backend/pkg/base"
12 | "treehollow-v3-backend/pkg/consts"
13 | "treehollow-v3-backend/pkg/logger"
14 | "treehollow-v3-backend/pkg/utils"
15 | )
16 |
// deleteAccount handles an account-deletion request.
//
// Form parameters:
//   - email:      the account's e-mail address (lower-cased, then hashed with utils.HashEmail)
//   - nonce:      the account's password-recovery nonce (users.forget_pw_nonce); at least 10 characters
//   - valid_code: the verification code previously issued for this e-mail hash
//
// The verification code must match and be at most 43200 seconds (12 hours)
// old. Every mismatch increments failed_times, and after 10 or more failures
// within that window the request is refused. Inside a transaction the user is
// located by nonce. Deletion is refused if the account is banned
// (BannedUserRole), was created after utils.GetEarliestAuthenticationTime(), or
// has a Ban that has not expired yet. Otherwise the User row (matched by nonce)
// and the Email row (matched by e-mail hash) are deleted. If no Email row
// matched, the callback returns an error so the transaction is rolled back.
//
// NOTE(review): the User row is found via the nonce and the Email row via the
// e-mail hash; nothing here checks that both belong to the same account —
// confirm this is intended.
func deleteAccount(c *gin.Context) {
	email := strings.ToLower(c.PostForm("email"))
	emailHash := utils.HashEmail(email)
	nonce := c.PostForm("nonce")
	code := c.PostForm("valid_code")
	now := utils.GetTimeStamp()
	// Cheap sanity check before touching the database.
	if len(nonce) < 10 {
		base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("NonceNotEnoughLong", "Nonce错误", logger.INFO))
		return
	}

	// A missing verification-code record is not treated as a DB failure; it
	// presumably yields empty/zero values that fail the comparison below —
	// TODO confirm GetVerificationCode's return values on ErrRecordNotFound.
	correctCode, timeStamp, failedTimes, err2 := base.GetVerificationCode(emailHash)
	if err2 != nil && !errors.Is(err2, gorm.ErrRecordNotFound) {
		base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err2, "QueryValidCodeFailed", consts.DatabaseReadFailedString))
		return
	}
	// Lock out after 10 failed attempts while the code is still within its 12-hour lifetime.
	if failedTimes >= 10 && now-timeStamp <= 43200 {
		base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("ValidCodeTooMuchFailed", "验证码错误尝试次数过多,请重新发送验证码", logger.INFO))
		return
	}
	if correctCode != code || now-timeStamp > 43200 {
		base.HttpReturnWithErrAndAbort(c, -10, logger.NewSimpleError("ValidCodeInvalid", "验证码无效或过期", logger.WARN))
		// Best effort: record the failed attempt; the error is deliberately ignored.
		_ = base.GetDb(false).Model(&base.VerificationCode{}).Where("email_hash = ?", emailHash).
			Update("failed_times", gorm.Expr("failed_times + 1")).Error
		return
	}

	// Every path inside the callback writes its own HTTP response; returning a
	// non-nil error makes gorm roll the transaction back.
	_ = base.GetDb(false).Transaction(func(tx *gorm.DB) error {
		var user base.User
		err := tx.Model(&base.User{}).Where("forget_pw_nonce = ?", nonce).First(&user).Error

		if err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				base.HttpReturnWithCodeMinusOne(c, logger.NewSimpleError("NonceNotFound",
					"没有找到nonce对应的账户。请你重新查看刚刚注册树洞后收到的欢迎邮件中的“找回密码口令”(nonce)。"+
						"如果仍然无法解决问题,请联系"+viper.GetString("contact_email")+"。", logger.WARN))
				return errors.New("NonceNotFound")
			}
			base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "DeleteNonceFailed", consts.DatabaseReadFailedString))
			return err
		}

		// Frozen accounts cannot be deleted.
		if user.Role == base.BannedUserRole {
			base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("DeleteAccountFrozen",
				"您的账户已被冻结,无法注销。如果需要解冻,请联系"+
					viper.GetString("contact_email")+"。", logger.ERROR))

			return errors.New("DeleteBannedAccount")
		}

		// Accounts younger than the token lifetime (consts.TokenExpireDays days) cannot be deleted.
		if user.CreatedAt.After(utils.GetEarliestAuthenticationTime()) {
			base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("DeleteJustRegisteredAccount",
				"注销失败,账户需要注册"+strconv.Itoa(consts.TokenExpireDays)+"天以上才可以注销。", logger.ERROR))

			return errors.New("DeleteJustRegisteredAccount")
		}

		// Refuse while any ban on the user has not expired yet.
		timestamp := utils.GetTimeStamp()
		var count int64
		err3 := tx.Model(&base.Ban{}).Where("user_id = ? and expire_at > ?", user.ID, timestamp).Count(&count).Error
		if err3 != nil {
			base.HttpReturnWithCodeMinusOne(c, logger.NewError(err3, "GetBanFailed", consts.DatabaseReadFailedString))
			return err3
		}

		if count > 0 {
			base.HttpReturnWithCodeMinusOneAndAbort(c,
				logger.NewSimpleError("DisallowDeleteWhileBan", "很抱歉,您当前处于禁言状态,无法注销。", logger.ERROR))
			return errors.New("DisallowDeleteWhileBan")
		}

		result := tx.Where("forget_pw_nonce = ?", nonce).
			Delete(&base.User{})

		if result.Error != nil {
			base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(result.Error, "DeleteNonceFailed", consts.DatabaseWriteFailedString))
			return result.Error
		}

		result = tx.Where("email_hash = ?", emailHash).
			Delete(&base.Email{})

		if result.Error != nil {
			base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(result.Error, "DeleteEmailHashFailed", consts.DatabaseWriteFailedString))
			return result.Error
		}

		// No Email row for this hash: roll back the User deletion as well.
		if result.RowsAffected == 0 {
			base.HttpReturnWithCodeMinusOne(c, logger.NewSimpleError("EmailNotFound",
				"没有找到此邮箱对应的账户", logger.WARN))
			return errors.New("EmailNotFound")
		}

		c.JSON(http.StatusOK, gin.H{
			"code": 0,
		})
		return nil
	})
}
116 |
--------------------------------------------------------------------------------
/example.config.yml:
--------------------------------------------------------------------------------
1 | ######################## 必填配置
2 | ######################## 以下配置必须更改后才可运行程序
3 |
4 | ### 网站名称
5 | name: T大树洞
6 |
7 | ### 联系邮箱
8 | contact_email: contact@thuhole.com
9 |
10 | ### 检查邮箱是否合法的正则表达式。可参见https://html.spec.whatwg.org/multipage/input.html#valid-e-mail-address
11 | email_check_regex: ^[a-zA-Z]+[-]*[a-zA-Z]*[0-9]*@(mails\.tsinghua\.edu\.cn)$
12 |
13 | ### 图床域名。注意以"/"结尾
14 | img_base_url: https://img.thuhole.com/
15 |
16 | ### 图片文件存储文件夹。
17 | images_path: /path/to/images/folder
18 |
19 | ### 是否是debug模式
20 | is_debug: true
21 |
22 | ### 用户邮箱在数据库中会通过一个盐被单向加密。
23 | salt: CHANGE_ME!
24 |
25 | ### 登录服务(security)、树洞内容服务(services)和推送服务(push)各自的监听地址。含有":"的地址视为TCP地址(如127.0.0.1:8080),否则视为unix socket文件路径。这三个服务是相互独立的程序。
26 | security_api_listen_address: /tmp/treehollow/treehollow-security-api.sock
27 | services_api_listen_address: /tmp/treehollow/treehollow-services-api.sock
28 | push_api_listen_address: /tmp/treehollow/treehollow-push-api.sock
29 |
30 | ### Google reCAPTCHA v3密钥。需要前往reCAPTCHA官网获取。https://developers.google.com/recaptcha/docs/v3
31 | recaptcha_v3_private_key: YOUR_v3_KEY
32 | ### Google reCAPTCHA v2密钥。reCAPTCHA v3不通过时,
33 | ### 会使用v2的基于图片的人机识别验证。需要前往reCAPTCHA官网获取。https://developers.google.com/recaptcha/docs/display
34 | recaptcha_v2_private_key: YOUR_v2_KEY
35 |
36 | ### redis服务配置
37 | # TCP connection:
38 | # redis://<user>:<password>@<host>:<port>/<db_number>
39 | # Unix socket connection:
40 | # unix://<user>:<password>@</path/to/redis.sock>?db=<db_number>
41 | redis_source: unix:///var/run/redis/redis.sock?db=0
42 |
43 | ### MySQL服务配置
44 | # see https://pear.php.net/manual/en/package.database.db.intro-dsn.php for detail
45 | # e.g. user@unix(/path/to/socket)/pear
46 | # e.g. user:pass@tcp(localhost:5555)/pear
47 | sql_source: USER:PASSWORD@unix(/var/lib/mysql/mysql.sock)/DB_NAME
48 |
49 | ### SMTP配置
50 | smtp_host: smtp.thuhole.com
51 | smtp_password: YOUR_PASSWORD
52 | smtp_username: noreply@thuhole.com
53 | smtp_port: 465
54 |
55 | ### 解密所需的最少密钥保管员人数
56 | min_decryption_key_count: 3
57 | ### 密钥保管员的PGP密钥列表
58 | # 注意:每个公钥必须写在一行双引号字符串中,公钥内部的换行必须写成\n转义
59 | key_keepers_pgp_public_keys: [
60 | "-----BEGIN PGP PUBLIC KEY BLOCK-----\n\n...\n-----END PGP PUBLIC KEY BLOCK-----",
61 | "-----BEGIN PGP PUBLIC KEY BLOCK-----\n\n...\n-----END PGP PUBLIC KEY BLOCK-----",
62 | "-----BEGIN PGP PUBLIC KEY BLOCK-----\n\n...\n-----END PGP PUBLIC KEY BLOCK-----",
63 | "-----BEGIN PGP PUBLIC KEY BLOCK-----\n\n...\n-----END PGP PUBLIC KEY BLOCK-----",
64 | "-----BEGIN PGP PUBLIC KEY BLOCK-----\n\n...\n-----END PGP PUBLIC KEY BLOCK-----",
65 | "-----BEGIN PGP PUBLIC KEY BLOCK-----\n\n...\n-----END PGP PUBLIC KEY BLOCK-----"
66 | ]
67 |
68 |
69 |
70 | ######################## 可选配置
71 | ######################## 以下配置不必要更改
72 |
73 | ### 是否允许未注册用户浏览树洞
74 | allow_unregistered_access: true
75 |
76 | ### IP地址库文件,见 https://dev.maxmind.com/geoip/geoip2/geolite2/
77 | mmdb_path: /usr/local/share/GeoIP/GeoLite2-City.mmdb
78 | ### 允许注册的IP国家列表,需要保证mmdb_path的IP地址库可用
79 | allowed_register_countries:
80 | - 美国
81 | - 中国
82 |
83 | ### 公告
84 | ### TODO (low priority): this announcement is not kept in sync with the app's config.txt
85 | announcement: This is dev server.
86 |
87 | ### 图床CDN配置。这里使用了Dogecloud CDN(https://www.dogecloud.com/)。
88 | dcaccesskey: ""
89 | dcs3bucket: ""
90 | dcs3endpoint: ""
91 | dcsecretkey: ""
92 |
93 | ### 不允许用户举报的树洞号列表
94 | disallow_report_pids:
95 | - 118
96 | - 1
97 | - 4
98 |
99 | ### 折叠检测的正则表达式
100 | fold_regex: ([A-Za-z0-9+/]{100})|(BEGIN PUBLIC KEY)
101 | ### "性相关"折叠检测的正则表达式
102 | sex_related_regex: (porn(.*)hub|杜蕾斯)
103 | ### 折叠tag列表
104 | reportable_tags:
105 | - 性相关
106 | - 政治相关
107 | - NSFW
108 | - 刷屏
109 | - 重复内容
110 | - 引战
111 | - 未经证实的传闻
112 | - 令人不适
113 | sendable_tags:
114 | - 性相关
115 | - 政治相关
116 | - NSFW
117 | - 刷屏
118 | - 引战
119 | - 未经证实的传闻
120 | - 令人不适
121 |
122 | ### 备用图床域名
123 | img_base_url_bak: https://img2.thuhole.com/
124 |
125 | ### 每个IP每天最多发送多少注册邮件
126 | max_email_per_ip_per_day: 10
127 |
128 | ### 置顶的树洞号列表
129 | pin_pids: [ ]
130 |
131 | ### reCAPTCHA v3人机认证的通过分数阈值
132 | recaptcha_threshold: 0.5
133 |
134 | ### 当allow_unregistered_access=false时,允许未登录游客访问的IP白名单
135 | subnets_whitelist: [ ]
136 |
137 | ### 最新树洞Web版的版本号。树洞Web版会通过检查此版本号自动更新。
138 | web_frontend_version: v0.0.0
139 |
140 | ### 当服务宕机时的提示信息
141 | fallback_announcement: 系统正在维护升级,请稍后重试...
142 |
143 | ### 服务宕机时的错误提示程序监听端口
144 | fallback_listen_address: 127.0.0.1:8082
145 |
146 | ### iOS push (APNs) settings: either a .p12 certificate file and its password
147 | ios_push_auth_file: ./file.p12
148 | ios_push_auth_password: password
149 |
150 | ### OR a .p8 token key file together with its key ID and team ID
151 | #ios_push_auth_file: ./file.p8
152 | #ios_push_key_id: 123
153 | #ios_push_team_id: 123
154 |
155 | ### Internal API address of the push service. Only needs changing when the push service and the other services run on different servers
156 | push_internal_api_listen_address: 127.0.0.1:3009
157 |
158 | ### 除了邮箱正则表达式之外,额外允许注册的邮箱白名单
159 | email_whitelist: []
160 |
161 | ### 允许/help等管理员命令
162 | allow_admin_commands: true
163 |
--------------------------------------------------------------------------------
/pkg/route/security/login.go:
--------------------------------------------------------------------------------
1 | package security
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "github.com/gin-gonic/gin"
7 | "github.com/google/uuid"
8 | "github.com/spf13/viper"
9 | "gorm.io/gorm"
10 | "log"
11 | "net"
12 | "net/http"
13 | "strings"
14 | "time"
15 | "treehollow-v3-backend/pkg/base"
16 | "treehollow-v3-backend/pkg/consts"
17 | "treehollow-v3-backend/pkg/logger"
18 | "treehollow-v3-backend/pkg/utils"
19 | )
20 |
21 | func loginGetUserMiddleware(c *gin.Context) {
22 | pwHashed := c.PostForm("password_hashed")
23 | email := strings.ToLower(c.PostForm("email"))
24 |
25 | emailEncrypted, err := utils.AESEncrypt(email, pwHashed)
26 | if err != nil {
27 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "AESEncryptFailedInMiddleware", consts.DatabaseEncryptFailedString))
28 | return
29 | }
30 |
31 | var user base.User
32 |
33 | err = base.GetDb(false).Where("email_encrypted = ?", emailEncrypted).
34 | Model(&base.User{}).First(&user).Error
35 | if err != nil {
36 | if errors.Is(err, gorm.ErrRecordNotFound) {
37 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("MiddlewareNoAuth", "用户名或密码错误", logger.WARN))
38 | } else {
39 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "GetUserByEmailEncryptedFailed", consts.DatabaseReadFailedString))
40 | }
41 | return
42 | }
43 | if user.Role == base.BannedUserRole {
44 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("AccountFrozen",
45 | "您的账户已被冻结。如果需要解冻,请联系"+
46 | viper.GetString("contact_email")+"。", logger.ERROR))
47 |
48 | return
49 | }
50 |
51 | c.Set("user", user)
52 | c.Next()
53 | }
54 |
55 | func loginCheckMaxDevices(c *gin.Context) {
56 | user := c.MustGet("user").(base.User)
57 |
58 | var count int64
59 | err := base.GetDb(false).
60 | Where("user_id = ? and created_at > ?", user.ID, utils.GetEarliestAuthenticationTime()).
61 | Model(&base.Device{}).Count(&count).Error
62 | if err != nil {
63 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "GetEarliestDeviceFailed", consts.DatabaseReadFailedString))
64 | return
65 | }
66 | if count >= consts.MaxDevicesPerUser {
67 | log.Printf("user login more than max allowed: %d\n", user.ID)
68 | _ = base.GetDb(false).
69 | Where("user_id = ? and created_at > ?", user.ID, utils.GetEarliestAuthenticationTime()).
70 | Order("created_at asc").Limit(1).
71 | Delete(&base.Device{}).Error
72 | return
73 | }
74 | c.Next()
75 | }
76 |
77 | func login(c *gin.Context) {
78 | user := c.MustGet("user").(base.User)
79 | token := utils.GenToken()
80 | deviceUUID := uuid.New().String()
81 | deviceType := c.MustGet("device_type").(base.DeviceType)
82 | deviceInfo := c.PostForm("device_info")
83 | city := "Unknown"
84 |
85 | if geoDb := utils.GeoDb.Get(); geoDb != nil {
86 | ip := net.ParseIP(c.ClientIP())
87 | record, err5 := geoDb.City(ip)
88 | if err5 == nil {
89 | country := record.Country.Names["zh-CN"]
90 | if len(country) == 0 {
91 | country = record.Country.Names["en"]
92 | }
93 | if len(country) > 0 {
94 | cityName := record.City.Names["zh-CN"]
95 | if len(cityName) == 0 {
96 | cityName = record.City.Names["en"]
97 | }
98 | if len(cityName) > 0 {
99 | city = cityName + ", " + country
100 | } else {
101 | city = country
102 | }
103 | }
104 | }
105 | }
106 |
107 | err := base.GetDb(false).Create(&base.Device{
108 | ID: deviceUUID,
109 | UserID: user.ID,
110 | Token: token,
111 | DeviceInfo: deviceInfo,
112 | Type: deviceType,
113 | LoginIP: c.ClientIP(),
114 | LoginCity: city,
115 | IOSDeviceToken: c.PostForm("ios_device_token"),
116 | }).Error
117 | if err != nil {
118 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err, "SaveDeviceWhileLoginFailed", consts.DatabaseWriteFailedString))
119 | return
120 | }
121 |
122 | c.JSON(http.StatusOK, gin.H{
123 | "code": 0,
124 | "token": token,
125 | "uuid": deviceUUID,
126 | })
127 | _ = base.GetDb(false).Create(&base.SystemMessage{
128 | UserID: user.ID,
129 | Title: "新的登录",
130 | Text: fmt.Sprintf("您好,您的账户在%s于%s使用设备\"%s\"登录。\n\n如果这不是您本人所为,请您立刻修改密码。",
131 | time.Now().Format("2006-01-02 15:04"), city, deviceInfo),
132 | BanID: -1,
133 | }).Error
134 | //TODO: (middle priority) send email
135 | return
136 | }
137 |
138 | func logout(c *gin.Context) {
139 | token := c.GetHeader("TOKEN")
140 | result := base.GetDb(false).
141 | Where("token = ? and created_at > ?", token, utils.GetEarliestAuthenticationTime()).
142 | Delete(&base.Device{})
143 | if result.Error != nil {
144 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(result.Error, "DeleteDeviceFailed", consts.DatabaseWriteFailedString))
145 | return
146 | }
147 | if result.RowsAffected != 1 {
148 | base.HttpReturnWithErrAndAbort(c, -100, logger.NewSimpleError("TokenExpired",
149 | "登录凭据过期,请使用邮箱重新登录。", logger.INFO))
150 | return
151 | }
152 | c.JSON(http.StatusOK, gin.H{
153 | "code": 0,
154 | })
155 | }
156 |
--------------------------------------------------------------------------------
/pkg/push/notification.go:
--------------------------------------------------------------------------------
1 | package push
2 |
3 | import (
4 | "encoding/json"
5 | "github.com/sideshow/apns2"
6 | "github.com/sideshow/apns2/certificate"
7 | "github.com/sideshow/apns2/payload"
8 | "github.com/sideshow/apns2/token"
9 | "github.com/spf13/viper"
10 | "log"
11 | "strconv"
12 | "strings"
13 | "treehollow-v3-backend/pkg/base"
14 | "treehollow-v3-backend/pkg/model"
15 | "treehollow-v3-backend/pkg/utils"
16 | )
17 |
18 | func getIOSPushClient() (*apns2.Client, error) {
19 | authFile := viper.GetString("ios_push_auth_file")
20 | if strings.HasSuffix(authFile, "p12") {
21 | certPass := viper.GetString("ios_push_auth_password")
22 | cert, err := certificate.FromP12File(authFile, certPass)
23 | if err != nil {
24 | log.Printf("ios push cert error: %s\n", err)
25 | return nil, err
26 | }
27 | return apns2.NewClient(cert), nil
28 | }
29 | authKey, err := token.AuthKeyFromFile(authFile)
30 | if err != nil {
31 | log.Printf("ios push token error: %s\n", err)
32 | return nil, err
33 | }
34 | t := &token.Token{
35 | AuthKey: authKey,
36 | // KeyID from developer account (Certificates, Identifiers & Profiles -> Keys)
37 | KeyID: viper.GetString("ios_push_key_id"),
38 | // TeamID from developer account (View Account -> Membership)
39 | TeamID: viper.GetString("ios_push_team_id"),
40 | }
41 | pushClient := apns2.NewTokenClient(t)
42 | return pushClient, nil
43 | }
44 |
45 | func SendMessages(msgs []base.PushMessage, Api *API, isDoRecall bool) {
46 | var err error
47 |
48 | if !isDoRecall {
49 | err = base.GetDb(false).Create(&msgs).Error
50 | if err != nil {
51 | log.Printf("create push messages failed: %s", err)
52 | return
53 | }
54 | }
55 |
56 | pushUserIDs := make([]int32, 0, len(msgs))
57 | pushMap := make(map[int32]*base.PushMessage)
58 | for _, msg := range msgs {
59 | if msg.DoPush {
60 | pushUserIDs = append(pushUserIDs, msg.UserID)
61 | pushMap[msg.UserID] = &msg
62 | }
63 | }
64 |
65 | var devices []base.Device
66 | err = base.GetDb(false).Model(&base.Device{}).Where("user_id in (?)", pushUserIDs).
67 | Find(&devices).Error
68 | if err != nil {
69 | log.Printf("read push devices failed: %s", err)
70 | return
71 | }
72 |
73 | pushClient, err := getIOSPushClient()
74 | if err != nil {
75 | log.Printf("getIOSPushClient error: %s\n", err)
76 | return
77 | }
78 |
79 | //TODO: (middle priority)fix "recall before push" bug
80 | for _, device := range devices {
81 | msg := pushMap[device.UserID]
82 | switch device.Type {
83 | case base.IOSDevice:
84 | if len(device.IOSDeviceToken) > 0 {
85 | var p *payload.Payload
86 | if isDoRecall {
87 | p = payload.NewPayload().AlertTitle("消息已被删除").Custom("delete", 1).
88 | Custom("pid", msg.PostID)
89 | } else {
90 | p = payload.NewPayload().AlertTitle(utils.TrimText(msg.Title, 50)).
91 | AlertBody(utils.TrimText(msg.Message, 100)).Sound("default")
92 | if (msg.Type & (model.ReplyMeComment | model.CommentInFavorited)) > 0 {
93 | p = p.Custom("pid", msg.PostID).Custom("cid", msg.CommentID)
94 | }
95 | p.Custom("type", msg.Type)
96 | if viper.GetBool("is_debug") {
97 | log.Printf("ios push notification: %v\n", msg)
98 | }
99 | }
100 | res, err2 := pushClient.Production().Push(&apns2.Notification{
101 | DeviceToken: device.IOSDeviceToken,
102 | Topic: "treehollow.Hollow",
103 | Payload: p,
104 | CollapseID: strconv.Itoa(int(msg.ID)),
105 | })
106 |
107 | if err2 != nil {
108 | log.Printf("production push ios notifation failed: %s", err2)
109 | }
110 |
111 | if viper.GetBool("is_debug") {
112 | if res != nil {
113 | log.Printf("production push notification response: %s %d, %s\n", p, res.StatusCode, res.Reason)
114 | }
115 |
116 | res2, err2 := pushClient.Development().Push(&apns2.Notification{
117 | DeviceToken: device.IOSDeviceToken,
118 | Topic: "treehollow.Hollow",
119 | Payload: p,
120 | CollapseID: strconv.Itoa(int(msg.ID)),
121 | })
122 |
123 | if err2 != nil {
124 | log.Printf("dev push ios notifation failed: %s", err2)
125 | }
126 |
127 | if res2 != nil {
128 | log.Printf("dev push notification response: %s %d, %s\n", p, res2.StatusCode, res2.Reason)
129 | }
130 | }
131 | }
132 | case base.AndroidDevice:
133 | var p map[string]interface{}
134 | if isDoRecall {
135 | p = map[string]interface{}{
136 | "id": msg.ID,
137 | "delete": 1,
138 | }
139 | } else {
140 | p = map[string]interface{}{
141 | "id": msg.ID,
142 | "title": msg.Title,
143 | "body": utils.TrimText(msg.Message, 100),
144 | "type": msg.Type,
145 | "timestamp": msg.UpdatedAt.Unix(),
146 | }
147 | if (msg.Type & (model.ReplyMeComment | model.CommentInFavorited)) > 0 {
148 | p["pid"] = msg.PostID
149 | p["cid"] = msg.CommentID
150 | }
151 | }
152 | postBody, _ := json.Marshal(p)
153 | Api.Notify(device.Token, &postBody)
154 | }
155 | }
156 | }
157 |
--------------------------------------------------------------------------------
/pkg/route/contents/route.go:
--------------------------------------------------------------------------------
1 | package contents
2 |
3 | import (
4 | "github.com/gin-contrib/cors"
5 | "github.com/gin-gonic/gin"
6 | "github.com/robfig/cron/v3"
7 | "github.com/spf13/viper"
8 | "log"
9 | "net"
10 | "net/http"
11 | "os"
12 | "path/filepath"
13 | "strings"
14 | "treehollow-v3-backend/pkg/bot"
15 | "treehollow-v3-backend/pkg/consts"
16 | "treehollow-v3-backend/pkg/logger"
17 | "treehollow-v3-backend/pkg/logger/ginLogger"
18 | "treehollow-v3-backend/pkg/route/auth"
19 | "treehollow-v3-backend/pkg/utils"
20 | )
21 |
22 | func ServicesApiListenHttp() {
23 | r := gin.New()
24 |
25 | bot.InitBot()
26 | initLimiters()
27 | shutdownCountDown = 2
28 | c := cron.New()
29 | _, _ = c.AddFunc("0 0 * * *", func() {
30 | shutdownCountDown = 2
31 | })
32 |
33 | corsConfig := cors.DefaultConfig()
34 | corsConfig.AllowAllOrigins = true
35 | corsConfig.AllowHeaders = append(corsConfig.AllowHeaders, "TOKEN")
36 | if viper.GetBool("debug_log") {
37 | logFile, err := os.OpenFile(consts.DetailLogFile, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0666)
38 | if err != nil {
39 | panic(err)
40 | }
41 |
42 | r.Use(cors.New(corsConfig), auth.AuthMiddleware(), ginLogger.LoggerWithConfig(ginLogger.LoggerConfig{
43 | Output: logFile,
44 | }), gin.Recovery())
45 | } else {
46 | r.Use(gin.Logger(), gin.Recovery(), cors.New(corsConfig), auth.AuthMiddleware())
47 | }
48 | r.POST("/v3/config/set_push",
49 | auth.DisallowUnregisteredUsers(),
50 | setPush)
51 | r.GET("/v3/config/get_push",
52 | auth.DisallowUnregisteredUsers(),
53 | getPush)
54 | r.GET("/v3/contents/system_msg",
55 | auth.DisallowUnregisteredUsers(),
56 | systemMsg)
57 | r.GET("/v3/contents/post/list",
58 | checkParameterPage(consts.MaxPage),
59 | listPost)
60 | r.GET("/v3/contents/post/randomlist",
61 | limiterMiddleware(randomListLimiter, "你今天刷了太多树洞了,明天再来吧", logger.WARN),
62 | wanderListPost)
63 | r.GET("/v3/contents/post/detail",
64 | limiterMiddleware(detailPostLimiter, "你今天刷了太多树洞了,明天再来吧", logger.WARN),
65 | detailPost)
66 | r.GET("/v3/contents/search",
67 | checkParameterPage(consts.SearchMaxPage),
68 | limiterMiddleware(searchShortTimeLimiter, "请不要短时间内连续搜索树洞", logger.INFO),
69 | limiterMiddleware(searchLimiter, "你今天搜索太多树洞了,明天再来吧", logger.WARN),
70 | searchHotPosts(),
71 | adminHelpCommand(),
72 | adminDecryptionCommand(),
73 | adminLogsCommand(),
74 | adminReportsCommand(),
75 | adminStatisticsCommand(),
76 | adminSysMsgsCommand(),
77 | adminShutdownCommand(),
78 | sysLoadWarningMiddleware(viper.GetFloat64("sys_load_threshold"), "目前树洞服务器负载较高,搜索功能已被暂时停用"),
79 | searchPost)
80 | r.GET("/v3/contents/post/attentions",
81 | auth.DisallowUnregisteredUsers(),
82 | checkParameterPage(consts.MaxPage),
83 | attentionPosts)
84 | r.GET("/v3/contents/my_msgs",
85 | auth.DisallowUnregisteredUsers(),
86 | checkParameterPage(consts.MaxPage),
87 | myMsgs)
88 | r.GET("/v3/contents/search/attentions",
89 | auth.DisallowUnregisteredUsers(),
90 | checkParameterPage(consts.SearchMaxPage),
91 | limiterMiddleware(searchShortTimeLimiter, "请不要短时间内连续搜索树洞", logger.INFO),
92 | limiterMiddleware(searchLimiter, "你今天搜索太多树洞了,明天再来吧", logger.WARN),
93 | searchAttentionPost)
94 | r.POST("/v3/send/post",
95 | auth.DisallowUnregisteredUsers(),
96 | limiterMiddleware(postLimiter, "请不要短时间内连续发送树洞", logger.INFO),
97 | limiterMiddleware(postLimiter2, "你24小时内已经发送太多树洞了", logger.WARN),
98 | disallowBannedPostUsers(),
99 | checkParameterTextAndImage(),
100 | checkParameterVoteOptions,
101 | sendPost)
102 | r.POST("/v3/send/vote",
103 | auth.DisallowUnregisteredUsers(),
104 | disallowBannedPostUsers(),
105 | sendVote)
106 | r.POST("/v3/send/comment",
107 | auth.DisallowUnregisteredUsers(),
108 | limiterMiddleware(commentLimiter, "请不要短时间内连续发送树洞回复", logger.INFO),
109 | limiterMiddleware(commentLimiter2, "你24小时内已经发送太多树洞回复了", logger.WARN),
110 | disallowBannedPostUsers(),
111 | checkParameterTextAndImage(),
112 | sendComment)
113 | r.POST("/v3/edit/attention",
114 | auth.DisallowUnregisteredUsers(),
115 | limiterMiddleware(doAttentionLimiter, "你今天关注太多树洞了,明天再来吧", logger.WARN),
116 | editAttention)
117 | r.POST("/v3/edit/report/post",
118 | auth.DisallowUnregisteredUsers(),
119 | checkReportParams(true),
120 | handleReport(false))
121 | r.POST("/v3/edit/report/comment",
122 | auth.DisallowUnregisteredUsers(),
123 | checkReportParams(false),
124 | handleReport(true))
125 |
126 | listenAddr := viper.GetString("services_api_listen_address")
127 | if strings.Contains(listenAddr, ":") {
128 | _ = r.Run(listenAddr)
129 | } else {
130 | _ = os.MkdirAll(filepath.Dir(listenAddr), os.ModePerm)
131 | _ = os.Remove(listenAddr)
132 |
133 | listener, err := net.Listen("unix", listenAddr)
134 | utils.FatalErrorHandle(&err, "bind failed")
135 | log.Printf("Listening and serving HTTP on unix: %s.\n"+
136 | "Note: 0777 is not a safe permission for the unix socket file. "+
137 | "It would be better if the user manually set the permission after startup\n",
138 | listenAddr)
139 | _ = os.Chmod(listenAddr, 0777)
140 | err = http.Serve(listener, r)
141 | return
142 | }
143 | }
144 |
--------------------------------------------------------------------------------
/pkg/push/stream.go:
--------------------------------------------------------------------------------
1 | // from https://github.com/gotify/server/blob/3454dcd60226acf121009975d947f05d41267283/api/stream/stream.go
2 | package push
3 |
4 | import (
5 | "fmt"
6 | "github.com/gin-gonic/gin"
7 | "github.com/gorilla/websocket"
8 | "github.com/spf13/viper"
9 | "net/http"
10 | "sync"
11 | "time"
12 | )
13 |
14 | // The API provides a handler for a WebSocket stream API.
15 | type API struct {
16 | clients map[string]*client
17 | lock sync.RWMutex
18 | pingPeriod time.Duration
19 | pongTimeout time.Duration
20 | upgrader *websocket.Upgrader
21 | }
22 |
23 | var Api *API
24 |
25 | // New creates a new instance of API.
26 | // pingPeriod: is the interval, in which is server sends the a ping to the client.
27 | // pongTimeout: is the duration after the connection will be terminated, when the client does not respond with the
28 | // pong command.
29 | func New(pingPeriod, pongTimeout time.Duration) *API {
30 | return &API{
31 | clients: make(map[string]*client),
32 | pingPeriod: pingPeriod,
33 | pongTimeout: pingPeriod + pongTimeout,
34 | upgrader: newUpgrader(),
35 | }
36 | }
37 |
38 | // NotifyDeletedUser closes existing connections for the given user.
39 | func (a *API) NotifyDeletedUser(token string) error {
40 | a.lock.Lock()
41 | defer a.lock.Unlock()
42 | if c, ok := a.clients[token]; ok {
43 | c.Close()
44 | delete(a.clients, token)
45 | }
46 | return nil
47 | }
48 |
49 | // Notify notifies the clients with the given userID that a new messages was created.
50 | func (a *API) Notify(token string, msg *[]byte) {
51 | a.lock.RLock()
52 | defer a.lock.RUnlock()
53 | if c, ok := a.clients[token]; ok {
54 | if viper.GetBool("is_debug") {
55 | fmt.Printf("WebSocket Notify: %s", *msg)
56 | }
57 | c.write <- msg
58 | } else {
59 | if viper.GetBool("is_debug") {
60 | fmt.Printf("WebSocket Error: token not connected: %s", *msg)
61 | }
62 | }
63 | }
64 |
65 | func (a *API) remove(remove *client) {
66 | a.lock.Lock()
67 | defer a.lock.Unlock()
68 | if c, ok := a.clients[remove.token]; ok {
69 | c.Close()
70 | delete(a.clients, remove.token)
71 | }
72 | }
73 |
74 | func (a *API) register(client *client) {
75 | a.lock.Lock()
76 | defer a.lock.Unlock()
77 | a.clients[client.token] = client
78 | }
79 |
80 | // Handle handles incoming requests. First it upgrades the protocol to the WebSocket protocol and then starts listening
81 | // for read and writes.
82 | // swagger:operation GET /stream message streamMessages
83 | //
84 | // Websocket, return newly created messages.
85 | //
86 | // ---
87 | // schema: ws, wss
88 | // produces: [application/json]
89 | // security: [clientTokenHeader: [], clientTokenQuery: [], basicAuth: []]
90 | // responses:
91 | // 200:
92 | // description: Ok
93 | // schema:
94 | // $ref: "#/definitions/Message"
95 | // 400:
96 | // description: Bad Request
97 | // schema:
98 | // $ref: "#/definitions/Error"
99 | // 401:
100 | // description: Unauthorized
101 | // schema:
102 | // $ref: "#/definitions/Error"
103 | // 403:
104 | // description: Forbidden
105 | // schema:
106 | // $ref: "#/definitions/Error"
107 | // 500:
108 | // description: Server Error
109 | // schema:
110 | // $ref: "#/definitions/Error"
111 | func (a *API) Handle(ctx *gin.Context) {
112 | token := ctx.GetHeader("TOKEN")
113 | conn, err := a.upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
114 | if err != nil {
115 | _ = ctx.Error(err)
116 | return
117 | }
118 |
119 | client := newClient(conn, token, a.remove)
120 | a.remove(client)
121 | a.register(client)
122 | go client.startReading(a.pongTimeout)
123 | go client.startWriteHandler(a.pingPeriod)
124 | }
125 |
126 | // Close closes all client connections and stops answering new connections.
127 | func (a *API) Close() {
128 | a.lock.Lock()
129 | defer a.lock.Unlock()
130 |
131 | for _, c := range a.clients {
132 | c.Close()
133 | }
134 | for k := range a.clients {
135 | delete(a.clients, k)
136 | }
137 | }
138 |
139 | //func isAllowedOrigin(r *http.Request, allowedOrigins []*regexp.Regexp) bool {
140 | // origin := r.Header.Get("origin")
141 | // if origin == "" {
142 | // return true
143 | // }
144 | //
145 | // u, err := url.Parse(origin)
146 | // if err != nil {
147 | // return false
148 | // }
149 | //
150 | // if strings.EqualFold(u.Host, r.Host) {
151 | // return true
152 | // }
153 | //
154 | // for _, allowedOrigin := range allowedOrigins {
155 | // if allowedOrigin.Match([]byte(strings.ToLower(u.Hostname()))) {
156 | // return true
157 | // }
158 | // }
159 | //
160 | // return false
161 | //}
162 |
163 | func newUpgrader() *websocket.Upgrader {
164 | return &websocket.Upgrader{
165 | ReadBufferSize: 1024,
166 | WriteBufferSize: 1024,
167 | CheckOrigin: func(r *http.Request) bool {
168 | //if mode.IsDev() {
169 | return true
170 | //}
171 | //return isAllowedOrigin(r, compiledAllowedOrigins)
172 | },
173 | }
174 | }
175 |
176 | //func compileAllowedWebSocketOrigins(allowedOrigins []string) []*regexp.Regexp {
177 | // var compiledAllowedOrigins []*regexp.Regexp
178 | // for _, origin := range allowedOrigins {
179 | // compiledAllowedOrigins = append(compiledAllowedOrigins, regexp.MustCompile(origin))
180 | // }
181 | //
182 | // return compiledAllowedOrigins
183 | //}
184 |
--------------------------------------------------------------------------------
/cmd/treehollow-v3-fallback/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "github.com/gin-gonic/gin"
5 | "github.com/spf13/viper"
6 | "log"
7 | "net"
8 | "net/http"
9 | "os"
10 | "path/filepath"
11 | "strings"
12 | "treehollow-v3-backend/pkg/config"
13 | "treehollow-v3-backend/pkg/utils"
14 | )
15 |
16 | //func apiFallBack(c *gin.Context) {
17 | // c.JSON(http.StatusOK, gin.H{
18 | // "msg": viper.GetString("fallback_announcement"),
19 | // })
20 | // return
21 | //}
22 |
23 | func upgradePrompt(c *gin.Context) {
24 | action := c.Query("action")
25 | if action == "getlist" {
26 | c.JSON(http.StatusOK, gin.H{
27 | "code": 0,
28 | "config": gin.H{
29 | "img_base_url": viper.GetString("img_base_url"),
30 | "img_base_url_bak": viper.GetString("img_base_url_bak"),
31 | "fold_tags": viper.GetStringSlice("fold_tags"),
32 | "web_frontend_version": "v2.0.0",
33 | "announcement": "发现树洞新版本,正在更新...",
34 | },
35 | "data": []gin.H{{
36 | "pid": 0,
37 | "text": "请更新到最新版本树洞。(点击界面右上角“账户”( i ),点击“强制检查更新”)",
38 | "type": "text",
39 | "timestamp": 2147483647,
40 | "reply": 0,
41 | "likenum": 0,
42 | "url": "",
43 | "tag": nil,
44 | }},
45 | "count": 1,
46 | })
47 | } else {
48 | c.JSON(http.StatusOK, gin.H{
49 | "msg": "请更新到最新版本树洞。(点击界面右上角“账户”( i ),点击“强制检查更新”)",
50 | })
51 | }
52 | }
53 |
54 | func upgradePromptV2(c *gin.Context) {
55 | c.JSON(http.StatusOK, gin.H{
56 | "code": 0,
57 | "config": gin.H{
58 | "img_base_url": viper.GetString("img_base_url"),
59 | "img_base_url_bak": viper.GetString("img_base_url_bak"),
60 | "fold_tags": viper.GetStringSlice("fold_tags"),
61 | "web_frontend_version": "v3.0.0",
62 | "announcement": "发现树洞新版本,正在更新...",
63 | },
64 | "data": []gin.H{{
65 | "pid": 0,
66 | "text": "请更新到最新版本树洞。(点击界面右上角“账户”( i ),点击“强制检查更新”)",
67 | "type": "text",
68 | "timestamp": 2147483647,
69 | "reply": 0,
70 | "likenum": 0,
71 | "url": "",
72 | "tag": nil,
73 | "updated_at": 2147483647,
74 | "attention": false,
75 | "permissions": []string{},
76 | "deleted": false,
77 | }},
78 | "count": 1,
79 | })
80 | }
81 |
82 | func serviceFallBack(c *gin.Context) {
83 | c.JSON(http.StatusOK, gin.H{
84 | "code": -1,
85 | "msg": viper.GetString("fallback_announcement"),
86 | })
87 | return
88 | }
89 |
90 | func main() {
91 | config.InitConfigFile()
92 | viper.SetDefault("fallback_announcement", "系统正在维护升级,请稍后重试...")
93 | viper.SetDefault("fallback_listen_address", "127.0.0.1:3002")
94 |
95 | r := gin.Default()
96 | //r.Use(cors.Default())
97 |
98 | //Old v1, compatibility fallback
99 | r.POST("/api_xmcp/login/send_code", upgradePrompt)
100 | r.POST("/api_xmcp/login/login", upgradePrompt)
101 | r.GET("/api_xmcp/hole/system_msg", upgradePrompt)
102 | r.GET("/services/thuhole/api.php", upgradePrompt)
103 | r.POST("/services/thuhole/api.php", upgradePrompt)
104 |
105 | //Old v2, compatibility fallback
106 | r.GET("/contents/post/list", upgradePromptV2)
107 | r.GET("/contents/search", upgradePromptV2)
108 | r.GET("/contents/system_msg", upgradePrompt)
109 | r.GET("/contents/post/detail", upgradePrompt)
110 | r.GET("/contents/post/attentions", upgradePrompt)
111 | r.GET("/contents/search/attentions", upgradePrompt)
112 | r.POST("/send/post", upgradePrompt)
113 | r.POST("/send/comment", upgradePrompt)
114 | r.POST("/edit/attention", upgradePrompt)
115 | r.POST("/edit/report/post", upgradePrompt)
116 | r.POST("/edit/report/comment", upgradePrompt)
117 | r.POST("/security/login/send_code", upgradePrompt)
118 | r.POST("/security/login/login", upgradePrompt)
119 |
120 | //v3 fallback
121 | r.POST("/v3/config/set_push", serviceFallBack)
122 | r.GET("/v3/config/get_push", serviceFallBack)
123 | r.GET("/v3/contents/system_msg", serviceFallBack)
124 | r.GET("/v3/contents/post/list", serviceFallBack)
125 | r.GET("/v3/contents/post/randomlist", serviceFallBack)
126 | r.GET("/v3/contents/post/detail", serviceFallBack)
127 | r.GET("/v3/contents/search", serviceFallBack)
128 | r.GET("/v3/contents/post/attentions", serviceFallBack)
129 | r.GET("/v3/contents/my_msgs", serviceFallBack)
130 | r.GET("/v3/contents/search/attentions", serviceFallBack)
131 | r.POST("/v3/send/post", serviceFallBack)
132 | r.POST("/v3/send/vote", serviceFallBack)
133 | r.POST("/v3/send/comment", serviceFallBack)
134 | r.POST("/v3/edit/attention", serviceFallBack)
135 | r.POST("/v3/edit/report/post", serviceFallBack)
136 | r.POST("/v3/edit/report/comment", serviceFallBack)
137 | r.POST("/v3/security/login/check_email", serviceFallBack)
138 | r.POST("/v3/security/login/create_account", serviceFallBack)
139 | r.POST("/v3/security/login/login", serviceFallBack)
140 | r.POST("/v3/security/login/change_password", serviceFallBack)
141 | r.GET("/v3/security/devices/list", serviceFallBack)
142 | r.POST("/v3/security/devices/terminate", serviceFallBack)
143 | r.POST("/v3/security/logout", serviceFallBack)
144 |
145 | listenAddr := viper.GetString("fallback_listen_address")
146 | if strings.Contains(listenAddr, ":") {
147 | _ = r.Run(listenAddr)
148 | } else {
149 | _ = os.MkdirAll(filepath.Dir(listenAddr), os.ModePerm)
150 | _ = os.Remove(listenAddr)
151 |
152 | listener, err := net.Listen("unix", listenAddr)
153 | utils.FatalErrorHandle(&err, "bind failed")
154 | log.Printf("Listening and serving HTTP on unix: %s.\n"+
155 | "Note: 0777 is not a safe permission for the unix socket file. "+
156 | "It would be better if the user manually set the permission after startup\n",
157 | listenAddr)
158 | _ = os.Chmod(listenAddr, 0777)
159 | err = http.Serve(listener, r)
160 | }
161 | }
162 |
--------------------------------------------------------------------------------
/pkg/utils/utils.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "bytes"
5 | "crypto/rand"
6 | "crypto/sha256"
7 | "encoding/base32"
8 | "encoding/base64"
9 | "encoding/hex"
10 | "encoding/json"
11 | "fmt"
12 | "github.com/ProtonMail/gopenpgp/v2/crypto"
13 | "github.com/google/uuid"
14 | errors2 "github.com/pkg/errors"
15 | "github.com/sigurn/crc8"
16 | "github.com/spf13/viper"
17 | "gorm.io/gorm"
18 | "image"
19 | _ "image/gif"
20 | _ "image/jpeg"
21 | _ "image/png"
22 | "io/ioutil"
23 | "math/big"
24 | "net"
25 | "net/http"
26 | "os"
27 | "path/filepath"
28 | "regexp"
29 | "strconv"
30 | "strings"
31 | "time"
32 | "treehollow-v3-backend/pkg/consts"
33 | "treehollow-v3-backend/pkg/logger"
34 | )
35 |
36 | var AllowedSubnets []*net.IPNet
37 | var Salt string
38 |
39 | func GenCode() string {
40 | nBig, err := rand.Int(rand.Reader, big.NewInt(1000000))
41 | if err != nil {
42 | panic(err)
43 | }
44 | n := nBig.Int64()
45 | return fmt.Sprintf("%06d", n)
46 | }
47 |
48 | func GenToken() string {
49 | randomBytes := make([]byte, 20)
50 | _, err := rand.Read(randomBytes)
51 | if err != nil {
52 | panic(err)
53 | }
54 | return strings.ToLower(base32.StdEncoding.EncodeToString(randomBytes))
55 | }
56 |
57 | func GenNonce() string {
58 | return uuid.New().String()
59 | }
60 |
61 | func SHA256(user string) string {
62 | h := sha256.New()
63 | h.Write([]byte(user))
64 | return hex.EncodeToString(h.Sum(nil))
65 | }
66 |
67 | func HashEmail(user string) string {
68 | return SHA256(Salt + SHA256(strings.ToLower(user)))
69 | }
70 |
71 | func GetTimeStamp() int64 {
72 | return time.Now().Unix()
73 | }
74 |
75 | func FatalErrorHandle(err *error, msg string) {
76 | if *err != nil {
77 | panic(fmt.Errorf("Fatal error: %s \n %s \n", msg, *err))
78 | }
79 | }
80 |
81 | func ContainsString(s []string, e string) (int, bool) {
82 | i := -1
83 | for i, a := range s {
84 | if a == e {
85 | return i, true
86 | }
87 | }
88 | return i, false
89 | }
90 |
91 | func ContainsInt(s []int, e int) (int, bool) {
92 | i := -1
93 | for i, a := range s {
94 | if a == e {
95 | return i, true
96 | }
97 | }
98 | return i, false
99 | }
100 |
101 | func GetCommenterName(id int, names0 []string, names1 []string) string {
102 | switch {
103 | case id == 0:
104 | return consts.DzName
105 | case id <= 26:
106 | return names1[id-1]
107 | case id <= 26*27:
108 | return names0[(id-1)/26-1] + " " + names1[(id-1)%26]
109 | default:
110 | return consts.ExtraNamePrefix + strconv.Itoa(id-26*27)
111 | }
112 | }
113 |
114 | //func remove(s []int, i int) []int {
115 | // s[len(s)-1], s[i] = s[i], s[len(s)-1]
116 | // return s[:len(s)-1]
117 | //}
118 |
119 | func IfThenElse(condition bool, a interface{}, b interface{}) interface{} {
120 | if condition {
121 | return a
122 | }
123 | return b
124 | }
125 |
126 | func CheckEmail(email string) bool {
127 | // REF: https://html.spec.whatwg.org/multipage/input.html#valid-e-mail-address
128 | var emailRegexp = regexp.MustCompile("^[a-zA-Z0-9.!#$%&'*+/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$")
129 | return emailRegexp.MatchString(email)
130 | }
131 |
132 | func CreatePublicKeyRing(publicKey string) (*crypto.KeyRing, error) {
133 | publicKeyObj, err := crypto.NewKeyFromArmored(publicKey)
134 | if err != nil {
135 | return nil, errors2.Wrap(err, "gopenpgp: unable to parse public key")
136 | }
137 |
138 | if publicKeyObj.IsPrivate() {
139 | publicKeyObj, err = publicKeyObj.ToPublic()
140 | if err != nil {
141 | return nil, errors2.Wrap(err, "gopenpgp: unable to extract public key from private key")
142 | }
143 | }
144 |
145 | publicKeyRing, err := crypto.NewKeyRing(publicKeyObj)
146 | if err != nil {
147 | return nil, errors2.Wrap(err, "gopenpgp: unable to create new keyring")
148 | }
149 |
150 | return publicKeyRing, nil
151 | }
152 |
153 | func IsInAllowedSubnet(ip string) bool {
154 | for _, subnet := range AllowedSubnets {
155 | if subnet.Contains(net.ParseIP(ip)) {
156 | return true
157 | }
158 | }
159 | return false
160 | }
161 |
162 | func GetHashedFilePath(filePath string) string {
163 | if len(filePath) > 2 {
164 | return filePath[:2] + "/" + filePath
165 | }
166 | return filePath
167 | }
168 |
169 | func SaveImage(base64img string, imgPath string) ([]byte, string, string, *logger.InternalError) {
170 | var suffix string
171 | sDec, err2 := base64.StdEncoding.DecodeString(base64img)
172 | if err2 != nil {
173 | return nil, "", "{}", logger.NewSimpleError("InvalidImgBase64", "图片数据不合法", logger.WARN)
174 | }
175 | fileType := http.DetectContentType(sDec)
176 | if fileType != "image/jpeg" && fileType != "image/jpg" && fileType != "image/png" && fileType != "image/gif" {
177 | return nil, "", "{}", logger.NewSimpleError("InvalidImgType", "图片数据不合法", logger.WARN)
178 | }
179 |
180 | im, _, err := image.DecodeConfig(bytes.NewReader(sDec))
181 | if err != nil {
182 | return nil, "", "{}", logger.NewError(err, "ImageDecodeFailed", "图片解析失败")
183 | }
184 | if im.Width > consts.ImageMaxWidth || im.Height > consts.ImageMaxHeight {
185 | return nil, "", "{}", logger.NewSimpleError("TooLargeImg", "图片过大", logger.WARN)
186 | }
187 | metadataBytes, err := json.Marshal(map[string]int{"w": im.Width, "h": im.Height})
188 | if err != nil {
189 | return nil, "", "{}", logger.NewError(err, "ImageSizeDecodeFailed", "图片大小解析失败")
190 | }
191 |
192 | if fileType == "image/png" {
193 | suffix = ".png"
194 | } else if fileType == "image/gif" {
195 | suffix = ".gif"
196 | } else {
197 | suffix = ".jpeg"
198 | }
199 |
200 | hashedPath := filepath.Join(viper.GetString("images_path"), imgPath[:2])
201 | _ = os.MkdirAll(hashedPath, os.ModePerm)
202 | err3 := ioutil.WriteFile(filepath.Join(hashedPath, imgPath+suffix), sDec, 0644)
203 | if err3 != nil {
204 | return nil, suffix, string(metadataBytes), logger.NewError(err3, "ErrorSavingImage", "图片存储失败")
205 | }
206 | return sDec, suffix, string(metadataBytes), nil
207 | }
208 |
209 | func CalcExtra(str1 string, str2 string) int64 {
210 | table := crc8.MakeTable(crc8.CRC8)
211 | rtn := int64(crc8.Checksum([]byte(str2+str1), table) % 4)
212 |
213 | return rtn
214 | }
215 |
216 | type void struct{}
217 |
218 | var member void
219 |
220 | func Int32SliceToSet(ids []int32) map[int32]void {
221 | set := make(map[int32]void)
222 | for _, id := range ids {
223 | set[id] = member
224 | }
225 | return set
226 | }
227 |
228 | func Int32IsInSet(id int32, ids map[int32]void) (rtn bool) {
229 | _, rtn = ids[id]
230 | return
231 | }
232 |
233 | func TimestampToString(timestamp int64) string {
234 | return time.Unix(timestamp, 0).
235 | In(consts.TimeLoc).Format("01-02 15:04")
236 | }
237 |
238 | func TrimText(text string, maxLength int) string {
239 | runeStr := []rune(text)
240 | if len(runeStr) > maxLength {
241 | return string(runeStr[:maxLength]) + "..."
242 | }
243 | return text
244 | }
245 |
246 | func GetEarliestAuthenticationTime() time.Time {
247 | return time.Now().AddDate(0, 0, -consts.TokenExpireDays)
248 | }
249 |
250 | func UnscopedTx(tx *gorm.DB, b bool) *gorm.DB {
251 | if b {
252 | return tx.Unscoped()
253 | }
254 | return tx
255 | }
256 |
--------------------------------------------------------------------------------
/pkg/route/security/checkEmail.go:
--------------------------------------------------------------------------------
1 | package security
2 |
3 | import (
4 | "github.com/gin-gonic/gin"
5 | "github.com/spf13/viper"
6 | "gopkg.in/ezzarghili/recaptcha-go.v4"
7 | "gorm.io/gorm/clause"
8 | "log"
9 | "net"
10 | "net/http"
11 | "regexp"
12 | "strings"
13 | "time"
14 | "treehollow-v3-backend/pkg/base"
15 | "treehollow-v3-backend/pkg/consts"
16 | "treehollow-v3-backend/pkg/logger"
17 | "treehollow-v3-backend/pkg/mail"
18 | "treehollow-v3-backend/pkg/route/contents"
19 | "treehollow-v3-backend/pkg/utils"
20 | )
21 |
22 | func checkEmailParamsCheckMiddleware(c *gin.Context) {
23 | recaptchaVersion := c.PostForm("recaptcha_version")
24 | recaptchaToken := c.PostForm("recaptcha_token")
25 | oldToken := c.PostForm("old_token")
26 | email := strings.ToLower(c.PostForm("email"))
27 |
28 | if len(email) > 100 || len(oldToken) > 32 || len(recaptchaToken) > 2000 || len(recaptchaVersion) > 2 {
29 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("CheckEmailParamsOutOfBound", "参数错误", logger.WARN))
30 | return
31 | }
32 | emailHash := utils.HashEmail(email)
33 | c.Set("email_hash", emailHash)
34 | c.Next()
35 | }
36 |
37 | func checkEmailRegexMiddleware(c *gin.Context) {
38 | email := strings.ToLower(c.PostForm("email"))
39 | emailCheck, err := regexp.Compile(viper.GetString("email_check_regex"))
40 | if err != nil {
41 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "RegexError", "服务器配置错误,请联系管理员。"))
42 | return
43 | }
44 | if !emailCheck.MatchString(email) {
45 | emailWhitelist := viper.GetStringSlice("email_whitelist")
46 | if _, ok := utils.ContainsString(emailWhitelist, email); !ok {
47 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("EmailRegexCheckNotPass", "很抱歉,您的邮箱无法注册"+viper.GetString("name"), logger.INFO))
48 | return
49 | }
50 | }
51 | }
52 |
53 | func checkEmailIsRegisteredUserMiddleware(c *gin.Context) {
54 | emailHash := c.MustGet("email_hash").(string)
55 | var count int64
56 | //check if user is registered
57 | err := base.GetDb(false).Where("email_hash = ?", emailHash).Model(&base.Email{}).Count(&count).Error
58 | if err != nil {
59 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "SearchEmailHashFailed", consts.DatabaseReadFailedString))
60 | return
61 | }
62 | if count == 1 {
63 | c.JSON(http.StatusOK, gin.H{
64 | "code": 0,
65 | })
66 | c.Abort()
67 | return
68 | }
69 | c.Next()
70 | }
71 |
72 | //compatibility settings
73 | func checkEmailIsOldTreeholeUserMiddleware(c *gin.Context) {
74 | oldToken := c.PostForm("old_token")
75 | emailHash := c.MustGet("email_hash").(string)
76 | var count int64
77 |
78 | //check if user is old v2 version user
79 | err := base.GetDb(false).Where("old_email_hash = ? and old_token = ?", emailHash, oldToken).
80 | Model(&base.User{}).Count(&count).Error
81 | if err != nil {
82 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "SearchOldEmailHashFailed", consts.DatabaseReadFailedString))
83 | return
84 | }
85 | if count == 1 {
86 | c.JSON(http.StatusOK, gin.H{
87 | "code": 2,
88 | })
89 | c.Abort()
90 | return
91 | }
92 | c.Next()
93 | }
94 |
95 | func checkEmailRateLimitVerificationCode(c *gin.Context) {
96 | emailHash := c.MustGet("email_hash").(string)
97 |
98 | now := utils.GetTimeStamp()
99 | _, timeStamp, _, _ := base.GetVerificationCode(emailHash)
100 | if now-timeStamp < 60 {
101 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("TooMuchEmailInOneMinute", "请不要短时间内重复发送邮件。", logger.INFO))
102 | return
103 | }
104 | c.Next()
105 | }
106 |
107 | func checkEmailReCaptchaValidationMiddleware(c *gin.Context) {
108 | recaptchaVersion := c.PostForm("recaptcha_version")
109 | recaptchaToken := c.PostForm("recaptcha_token")
110 | email := strings.ToLower(c.PostForm("email"))
111 |
112 | if len(c.PostForm("recaptcha_token")) < 1 {
113 | c.JSON(http.StatusOK, gin.H{
114 | "code": 3,
115 | })
116 | c.Abort()
117 | return
118 | }
119 |
120 | context, err2 := contents.EmailLimiter.Get(c, c.ClientIP())
121 | if err2 != nil {
122 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err2, "EmailLimiterFailed", consts.DatabaseReadFailedString))
123 | return
124 | }
125 | if context.Reached {
126 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("EmailLimiterReached"+c.ClientIP(), "您今天已经发送了过多验证码,请24小时之后重试。", logger.WARN))
127 | return
128 | }
129 |
130 | geoDb := utils.GeoDb.Get()
131 | if geoDb != nil && len(viper.GetStringSlice("allowed_register_countries")) != 0 {
132 | ip := net.ParseIP(c.ClientIP())
133 | record, err5 := geoDb.Country(ip)
134 | if err5 == nil {
135 | country := record.Country.Names["zh-CN"]
136 | if _, ok := utils.ContainsString(viper.GetStringSlice("allowed_register_countries"), country); !ok {
137 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("RegisterNotAllowed"+c.ClientIP()+country+email, "您所在的国家暂未开放注册。", logger.WARN))
138 | return
139 | }
140 | }
141 | }
142 |
143 | var captcha recaptcha.ReCAPTCHA
144 | if recaptchaVersion == "v2" {
145 | captcha, _ = recaptcha.NewReCAPTCHA(viper.GetString("recaptcha_v2_private_key"), recaptcha.V2, 10*time.Second)
146 | } else {
147 | captcha, _ = recaptcha.NewReCAPTCHA(viper.GetString("recaptcha_v3_private_key"), recaptcha.V3, 10*time.Second)
148 | }
149 | captcha.ReCAPTCHALink = "https://www.recaptcha.net/recaptcha/api/siteverify"
150 | err := captcha.VerifyWithOptions(recaptchaToken, recaptcha.VerifyOption{
151 | RemoteIP: c.ClientIP(),
152 | Threshold: float32(viper.GetFloat64("recaptcha_threshold")),
153 | })
154 | if err != nil {
155 | log.Println("recaptcha server error", err, c.ClientIP(), email)
156 | c.JSON(http.StatusOK, gin.H{
157 | "code": 3,
158 | })
159 | c.Abort()
160 | return
161 | }
162 | c.Next()
163 | }
164 |
165 | func checkEmail(c *gin.Context) {
166 | email := strings.ToLower(c.PostForm("email"))
167 |
168 | emailHash := c.MustGet("email_hash").(string)
169 |
170 | code := utils.GenCode()
171 |
172 | err := mail.SendValidationEmail(code, email)
173 | if err != nil {
174 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err, "SendEmailFailed"+email, "验证码邮件发送失败。"))
175 | return
176 | }
177 |
178 | err = base.GetDb(false).Clauses(clause.OnConflict{
179 | UpdateAll: true,
180 | }).Create(&base.VerificationCode{Code: code, EmailHash: emailHash, FailedTimes: 0, UpdatedAt: time.Now()}).Error
181 | if err != nil {
182 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err, "SaveVerificationCodeFailed", consts.DatabaseWriteFailedString))
183 | return
184 | }
185 |
186 | c.JSON(http.StatusOK, gin.H{
187 | "code": 1,
188 | "msg": "验证码发送成功,5分钟内无法重复发送验证码。请记得查看垃圾邮件。",
189 | })
190 | }
191 |
192 | func unregisterEmail(c *gin.Context) {
193 | email := strings.ToLower(c.PostForm("email"))
194 |
195 | emailHash := c.MustGet("email_hash").(string)
196 |
197 | code := utils.GenCode()
198 |
199 | err := mail.SendUnregisterValidationEmail(code, email)
200 | if err != nil {
201 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err, "SendEmailFailed"+email, "验证码邮件发送失败。"))
202 | return
203 | }
204 |
205 | err = base.GetDb(false).Clauses(clause.OnConflict{
206 | UpdateAll: true,
207 | }).Create(&base.VerificationCode{Code: code, EmailHash: emailHash, FailedTimes: 0, UpdatedAt: time.Now()}).Error
208 | if err != nil {
209 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err, "SaveVerificationCodeFailed", consts.DatabaseWriteFailedString))
210 | return
211 | }
212 |
213 | c.JSON(http.StatusOK, gin.H{
214 | "code": 1,
215 | "msg": "验证码发送成功,5分钟内无法重复发送验证码。请记得查看垃圾邮件。",
216 | })
217 | }
218 |
--------------------------------------------------------------------------------
/pkg/logger/ginLogger/ginLogger.go:
--------------------------------------------------------------------------------
1 | package ginLogger
2 |
3 | import (
4 | "fmt"
5 | "github.com/gin-gonic/gin"
6 | "github.com/mattn/go-isatty"
7 | "io"
8 | "net/http"
9 | "os"
10 | "time"
11 | "treehollow-v3-backend/pkg/base"
12 | )
13 |
// consoleColorModeValue selects when LogFormatterParams.IsOutputColor allows
// ANSI colours in log lines.
type consoleColorModeValue int

// Colour modes. disableColor is commented out, so forceColor currently has
// the value 1.
const (
	autoColor consoleColorModeValue = iota
	//disableColor
	forceColor
)

// ANSI SGR escape sequences used to colour log fields; each constant is named
// after the background colour it selects. reset clears all attributes.
const (
	green   = "\033[97;42m"
	white   = "\033[90;47m"
	yellow  = "\033[90;43m"
	red     = "\033[97;41m"
	blue    = "\033[97;44m"
	magenta = "\033[97;45m"
	cyan    = "\033[97;46m"
	reset   = "\033[0m"
)

// consoleColorMode is the active colour mode. The functions that would change
// it are commented out in this file, so it stays autoColor.
var consoleColorMode = autoColor
34 |
// LoggerConfig configures the middleware returned by LoggerWithConfig.
type LoggerConfig struct {
	// Formatter renders one log line. Optional; when nil, LoggerWithConfig
	// uses this package's defaultLogFormatter.
	Formatter LogFormatter

	// Output is a writer where logs are written.
	// Optional. Default value is gin.DefaultWriter.
	Output io.Writer

	// SkipPaths lists URL paths (compared exactly, without the query string)
	// whose requests are not logged.
	// Optional.
	SkipPaths []string
}

// LogFormatter gives the signature of the formatter function used by LoggerWithConfig.
type LogFormatter func(params LogFormatterParams) string

// LogFormatterParams is the data handed to a LogFormatter for each logged request.
type LogFormatterParams struct {
	Request *http.Request

	// TimeStamp is taken after the remaining handlers have run.
	TimeStamp time.Time
	// StatusCode is HTTP response code.
	StatusCode int
	// Latency is how much time the server cost to process a certain request.
	Latency time.Duration
	// ClientIP equals Context's ClientIP method.
	ClientIP string
	// Method is the HTTP method given to the request.
	Method string
	// Path is the request path, followed by "?" and the raw query when one is present.
	Path string
	// ErrorMessage holds the context's private-type errors, formatted as a string.
	ErrorMessage string
	// isTerm reports whether the log output is a terminal.
	isTerm bool
	// BodySize is the size of the Response Body
	BodySize int
	// Keys are the keys set on the request's context.
	Keys map[string]interface{}

	// UserID is the ID of the base.User stored in the request context under "user".
	UserID int32
}
79 |
80 | // StatusCodeColor is the ANSI color for appropriately logging http status code to a terminal.
81 | func (p *LogFormatterParams) StatusCodeColor() string {
82 | code := p.StatusCode
83 |
84 | switch {
85 | case code >= http.StatusOK && code < http.StatusMultipleChoices:
86 | return green
87 | case code >= http.StatusMultipleChoices && code < http.StatusBadRequest:
88 | return white
89 | case code >= http.StatusBadRequest && code < http.StatusInternalServerError:
90 | return yellow
91 | default:
92 | return red
93 | }
94 | }
95 |
96 | // MethodColor is the ANSI color for appropriately logging http method to a terminal.
97 | func (p *LogFormatterParams) MethodColor() string {
98 | method := p.Method
99 |
100 | switch method {
101 | case http.MethodGet:
102 | return blue
103 | case http.MethodPost:
104 | return cyan
105 | case http.MethodPut:
106 | return yellow
107 | case http.MethodDelete:
108 | return red
109 | case http.MethodPatch:
110 | return green
111 | case http.MethodHead:
112 | return magenta
113 | case http.MethodOptions:
114 | return white
115 | default:
116 | return reset
117 | }
118 | }
119 |
120 | // ResetColor resets all escape attributes.
121 | func (p *LogFormatterParams) ResetColor() string {
122 | return reset
123 | }
124 |
125 | // IsOutputColor indicates whether can colors be outputted to the log.
126 | func (p *LogFormatterParams) IsOutputColor() bool {
127 | return consoleColorMode == forceColor || (consoleColorMode == autoColor && p.isTerm)
128 | }
129 |
130 | // defaultLogFormatter is the default log format function Logger middleware uses.
131 | var defaultLogFormatter = func(param LogFormatterParams) string {
132 | var statusColor, methodColor, resetColor string
133 | if param.IsOutputColor() {
134 | statusColor = param.StatusCodeColor()
135 | methodColor = param.MethodColor()
136 | resetColor = param.ResetColor()
137 | }
138 |
139 | if param.Latency > time.Minute {
140 | // Truncate in a golang < 1.8 safe way
141 | param.Latency = param.Latency - param.Latency%time.Second
142 | }
143 | return fmt.Sprintf("[GIN] %v |%s %3d %s| %13v | %5d | %15s |%s %-7s %s %#v\n%s",
144 | param.TimeStamp.Format("2006/01/02 - 15:04:05"),
145 | statusColor, param.StatusCode, resetColor,
146 | param.Latency,
147 | param.UserID,
148 | param.ClientIP,
149 | methodColor, param.Method, resetColor,
150 | param.Path,
151 | param.ErrorMessage,
152 | )
153 | }
154 |
155 | //// DisableConsoleColor disables color output in the console.
156 | //func DisableConsoleColor() {
157 | // consoleColorMode = disableColor
158 | //}
159 | //
160 | //// ForceConsoleColor force color output in the console.
161 | //func ForceConsoleColor() {
162 | // consoleColorMode = forceColor
163 | //}
164 | //
165 | //// ErrorLogger returns a gin.HandlerFunc for any error type.
166 | //func ErrorLogger() gin.HandlerFunc {
167 | // return ErrorLoggerT(gin.ErrorTypeAny)
168 | //}
169 |
170 | //// ErrorLoggerT returns a gin.HandlerFunc for a given error type.
171 | //func ErrorLoggerT(typ gin.ErrorType) gin.HandlerFunc {
172 | // return func(c *gin.Context) {
173 | // c.Next()
174 | // errors := c.Errors.ByType(typ)
175 | // if len(errors) > 0 {
176 | // c.JSON(-1, errors)
177 | // }
178 | // }
179 | //}
180 |
181 | //// Logger instances a Logger middleware that will write the logs to gin.DefaultWriter.
182 | //// By default gin.DefaultWriter = os.Stdout.
183 | //func Logger() gin.HandlerFunc {
184 | // return LoggerWithConfig(LoggerConfig{})
185 | //}
186 |
187 | // LoggerWithConfig instance a Logger middleware with config.
188 | func LoggerWithConfig(conf LoggerConfig) gin.HandlerFunc {
189 | formatter := conf.Formatter
190 | if formatter == nil {
191 | formatter = defaultLogFormatter
192 | }
193 |
194 | out := conf.Output
195 | if out == nil {
196 | out = gin.DefaultWriter
197 | }
198 |
199 | notlogged := conf.SkipPaths
200 |
201 | isTerm := true
202 |
203 | if w, ok := out.(*os.File); !ok || os.Getenv("TERM") == "dumb" ||
204 | (!isatty.IsTerminal(w.Fd()) && !isatty.IsCygwinTerminal(w.Fd())) {
205 | isTerm = false
206 | }
207 |
208 | var skip map[string]struct{}
209 |
210 | if length := len(notlogged); length > 0 {
211 | skip = make(map[string]struct{}, length)
212 |
213 | for _, path := range notlogged {
214 | skip[path] = struct{}{}
215 | }
216 | }
217 |
218 | return func(c *gin.Context) {
219 | // Start timer
220 | start := time.Now()
221 | path := c.Request.URL.Path
222 | raw := c.Request.URL.RawQuery
223 |
224 | // Process request
225 | c.Next()
226 |
227 | // Log only when path is not being skipped
228 | if _, ok := skip[path]; !ok {
229 | param := LogFormatterParams{
230 | Request: c.Request,
231 | isTerm: isTerm,
232 | Keys: c.Keys,
233 | }
234 |
235 | // Stop timer
236 | param.TimeStamp = time.Now()
237 | param.Latency = param.TimeStamp.Sub(start)
238 |
239 | param.ClientIP = c.ClientIP()
240 | param.Method = c.Request.Method
241 | param.StatusCode = c.Writer.Status()
242 | param.ErrorMessage = c.Errors.ByType(gin.ErrorTypePrivate).String()
243 |
244 | param.UserID = c.MustGet("user").(base.User).ID
245 |
246 | param.BodySize = c.Writer.Size()
247 |
248 | if raw != "" {
249 | path = path + "?" + raw
250 | }
251 |
252 | param.Path = path
253 |
254 | _, _ = fmt.Fprint(out, formatter(param))
255 | }
256 | }
257 | }
258 |
--------------------------------------------------------------------------------
/pkg/base/structs.go:
--------------------------------------------------------------------------------
1 | package base
2 |
3 | import (
4 | "fmt"
5 | "gorm.io/gorm"
6 | "time"
7 | "treehollow-v3-backend/pkg/model"
8 | "treehollow-v3-backend/pkg/utils"
9 | )
10 |
// UserRole is the permission level stored in User.Role.
type UserRole int32

// Role values.
// NOTE(review): only BannedUserRole is typed. Every other constant here has
// its own explicit value and no type, so each is an untyped integer constant
// rather than a UserRole. Typing them would be stricter, but could break
// callers that use them as plain integers, so check callers before changing.
const (
	BannedUserRole   UserRole = -100
	SuperUserRole             = 0
	AdminRole                 = 1
	DeleterRole               = 2
	UnDeleterRole             = 3
	Deleter2Role              = 20
	Deleter3Role              = 21
	NormalUserRole            = 50
	UnregisteredRole          = 100
)

// ReportType classifies a Report; the string value is what Report.Type
// stores (varchar(20)). The stored strings of AdminDeleteAndBan
// ("AdminDeleteBan") and AdminUndelete ("Undelete") differ from their
// identifier names.
type ReportType string

const (
	UserReport        ReportType = "UserReport"
	UserReportFold    ReportType = "UserReportFold"
	UserDelete        ReportType = "UserDelete" // delete, no ban
	AdminTag          ReportType = "AdminTag"
	AdminDeleteAndBan ReportType = "AdminDeleteBan" // delete, ban
	AdminUndelete     ReportType = "Undelete"       // undelete + unban
	AdminUnban        ReportType = "AdminUnban"     // delete + unban
	// For now, there's no "undelete + no unban" option
)
37 |
// User is an account row (GORM model; the gorm.DeletedAt field gives it GORM
// soft-delete behaviour).
// codebeat:disable[TOO_MANY_IVARS]
type User struct {
	ID int32 `gorm:"primaryKey;autoIncrement;not null"`
	// OldEmailHash and OldToken hold v2 credentials of migrated accounts;
	// checkEmailIsOldTreeholeUserMiddleware matches on both.
	OldEmailHash string `gorm:"index;type:varchar(64) NOT NULL"`
	OldToken     string `gorm:"index;type:varchar(32) NOT NULL"`
	// EmailEncrypted is presumably the user's e-mail in encrypted form. It is
	// copied into DecryptionKeyShares.EmailEncrypted by saveKeyShares.
	EmailEncrypted string `gorm:"index;type:varchar(200) NOT NULL"`
	//KeyEncrypted string `gorm:"type:varchar(200) NOT NULL"`
	ForgetPwNonce string `gorm:"type:varchar(36) NOT NULL"`
	Role          UserRole
	// Relation fields are commented out; related tables point back at the
	// user through their own UserID columns instead.
	//SystemMessages []SystemMessage
	//Bans []Ban
	//Posts []Post
	//Comments []Comment
	//Devices []Device
	CreatedAt time.Time
	UpdatedAt time.Time
	DeletedAt gorm.DeletedAt `gorm:"index"`
}

// DecryptionKeyShares stores one PGP-armored message per key keeper, written
// by saveKeyShares. Each message carries one secret-sharing slice of a user's
// secret. EmailEncrypted links the row to User.EmailEncrypted, and PGPEmail
// is the e-mail of the keeper's PGP identity.
type DecryptionKeyShares struct {
	EmailEncrypted string `gorm:"index;type:varchar(200) NOT NULL"`
	PGPMessage     string `gorm:"type:varchar(5000) NOT NULL"`
	PGPEmail       string `gorm:"index;type:varchar(100) NOT NULL"`
	CreatedAt      time.Time
	DeletedAt      gorm.DeletedAt `gorm:"index"`
}

// Email holds one row per e-mail hash that counts as registered; see
// checkEmailIsRegisteredUserMiddleware.
// NOTE(review): char(64) suggests a hex-encoded 256-bit digest. Confirm the
// hash function.
type Email struct {
	EmailHash string `gorm:"primaryKey;type:char(64) NOT NULL"`
}
68 |
// DeviceType identifies the client platform of a Device.
type DeviceType int32

// NOTE(review): only WebDevice is typed. AndroidDevice and IOSDevice have
// explicit values without a type and are therefore untyped constants.
const (
	WebDevice     DeviceType = 0
	AndroidDevice            = 1
	IOSDevice                = 2
)

// Device is one login of a user on a client (see createDevice).
type Device struct {
	// ID is a UUID string generated by createDevice.
	ID         string `gorm:"type:char(36);primary_key"`
	UserID     int32  `gorm:"index;not null"`
	DeviceInfo string `gorm:"type:varchar(100) NOT NULL"`
	Type       DeviceType
	// IOSDeviceToken is presumably the APNs push token (nullable). TODO confirm.
	IOSDeviceToken string `gorm:"type:varchar(100)"`
	// Token is generated by utils.GenToken in createDevice.
	Token   string `gorm:"index;type:char(32) NOT NULL"`
	LoginIP string `gorm:"type:varchar(50) NOT NULL"`
	// LoginCity is presumably the GeoIP-derived "city, country" label that
	// createDevice computes ("Unknown" when the lookup fails). Confirm.
	LoginCity string         `gorm:"type:varchar(50) NOT NULL"`
	CreatedAt time.Time      `gorm:"index"`
	DeletedAt gorm.DeletedAt `gorm:"index"`
}

// PushSettings stores one model.PushType value per user.
type PushSettings struct {
	UserID   int32 `gorm:"primaryKey;not null"`
	Settings model.PushType
}

// VerificationCode holds the latest e-mail verification code per email hash.
// checkEmail and unregisterEmail upsert it with FailedTimes reset to 0 and
// UpdatedAt set to the current time.
type VerificationCode struct {
	EmailHash   string `gorm:"primaryKey;type:char(64) NOT NULL"`
	Code        string `gorm:"type:varchar(20) NOT NULL"`
	FailedTimes int
	CreatedAt   time.Time
	UpdatedAt   time.Time
}
102 |
// Post is a tree-hole post. Text carries a MySQL FULLTEXT index using the
// ngram parser (declared in the gorm tag).
type Post struct {
	ID int32 `gorm:"primaryKey;autoIncrement;not null"`
	//User User
	UserID int32
	Text   string `gorm:"index:,class:FULLTEXT,option:WITH PARSER ngram;type: varchar(10000) NOT NULL"`
	Tag    string `gorm:"index;type:varchar(60) NOT NULL"`
	// Type is presumably "text" or "image", the values checkParameterTextAndImage accepts.
	Type         string `gorm:"type:varchar(20) NOT NULL"`
	FilePath     string `gorm:"type:varchar(60) NOT NULL"`
	FileMetadata string `gorm:"type:varchar(40) NOT NULL"`
	// VoteData is presumably the JSON "vote_data" built by
	// checkParameterVoteOptions (option -> count). Confirm.
	VoteData  string `gorm:"type:varchar(200) NOT NULL"`
	LikeNum   int32  `gorm:"index"`
	ReplyNum  int32  `gorm:"index"`
	ReportNum int32
	//Comments []Comment
	CreatedAt time.Time      `gorm:"index"`
	UpdatedAt time.Time      `gorm:"index"`
	DeletedAt gorm.DeletedAt `gorm:"index"`
}

// Comment is a reply under a Post (PostID). Like Post.Text, Text carries an
// ngram FULLTEXT index.
type Comment struct {
	ID int32 `gorm:"primaryKey;autoIncrement;not null"`
	// ReplyTo presumably references the ID of the comment being answered. Confirm.
	ReplyTo int32 `gorm:"index"`
	//Post Post
	PostID int32 `gorm:"index"`
	//User User
	UserID       int32
	Text         string `gorm:"index:,class:FULLTEXT,option:WITH PARSER ngram;type: varchar(10000) NOT NULL"`
	Tag          string `gorm:"index;type:varchar(60) NOT NULL"`
	Type         string `gorm:"type:varchar(20) NOT NULL"`
	FilePath     string `gorm:"type:varchar(60) NOT NULL"`
	FileMetadata string `gorm:"type:varchar(40) NOT NULL"`
	Name         string `gorm:"type:varchar(60) NOT NULL"`
	CreatedAt    time.Time
	UpdatedAt    time.Time
	DeletedAt    gorm.DeletedAt `gorm:"index"`
}

// Report records a user report or a moderation action (see ReportType).
// UserID is the actor and ReportedUserID the target's author. When IsComment
// is set, the target is comment CommentID of post PostID; otherwise it is
// post PostID.
type Report struct {
	ID int32 `gorm:"primaryKey;autoIncrement;not null"`
	//User User
	UserID int32
	//ReportedUser User
	ReportedUserID int32
	//Post Post
	PostID int32
	//Comment Comment
	CommentID int32
	Reason    string     `gorm:"type: varchar(1000) NOT NULL"`
	Type      ReportType `gorm:"type:varchar(20) NOT NULL"`
	IsComment bool
	Weight    int32
	CreatedAt time.Time      `gorm:"index"`
	DeletedAt gorm.DeletedAt `gorm:"index"`
}
157 |
//TODO: (low priority)undelete = remove user reports

// Attention marks that a user follows a post. The composite primary key is
// (UserID, PostID), and User/Post are GORM associations.
type Attention struct {
	User   User
	UserID int32 `gorm:"primaryKey;index"`
	Post   Post
	PostID int32 `gorm:"primaryKey;index"`
}

// Vote stores the option a user chose in a post's vote; the composite
// primary key (UserID, PostID) allows one vote per user per post.
type Vote struct {
	User   User
	UserID int32 `gorm:"primaryKey;index"`
	Post   Post
	PostID int32 `gorm:"primaryKey;index"`
	Option string `gorm:"type:varchar(100) NOT NULL"`
}

// SystemMessage is a titled message addressed to a user, optionally linked
// to a Ban through BanID.
type SystemMessage struct {
	ID int32 `gorm:"primaryKey;autoIncrement;not null"`
	//User User
	UserID int32
	Text   string `gorm:"type: varchar(11000) NOT NULL"`
	Title  string `gorm:"type: varchar(100) NOT NULL"`
	//Ban Ban
	BanID     int32          `gorm:"index"`
	CreatedAt time.Time      `gorm:"index"`
	DeletedAt gorm.DeletedAt `gorm:"index"`
}

// Ban restricts a user until ExpireAt. ExpireAt is compared against
// utils.GetTimeStamp() in disallowBannedPostUsers, so it uses the same unit.
type Ban struct {
	ID int32 `gorm:"primaryKey;autoIncrement;not null"`
	//User User
	UserID int32
	//Report Report
	ReportID  int32
	Reason    string `gorm:"type: varchar(11000) NOT NULL"`
	ExpireAt  int64
	CreatedAt time.Time      `gorm:"index"`
	DeletedAt gorm.DeletedAt `gorm:"index"`
}

// PushMessage is a notification for UserID, optionally tied to a post,
// comment or ban. DoPush is indexed, presumably so pending pushes can be
// selected. Confirm.
type PushMessage struct {
	ID        int32  `gorm:"primaryKey;autoIncrement;not null"`
	Message   string `gorm:"type: varchar(10000) NOT NULL"`
	Title     string `gorm:"type: varchar(200) NOT NULL"`
	UserID    int32  `gorm:"index"`
	PostID    int32
	CommentID int32 `gorm:"index"`
	BanID     int32 `gorm:"index"`
	DoPush    bool  `gorm:"index"`
	Type      model.PushType
	UpdatedAt time.Time      `gorm:"index"`
	DeletedAt gorm.DeletedAt `gorm:"index"`
}

//type Messages struct {
//	ID int32 `gorm:"primaryKey;autoIncrement;not null"`
//	UserID int32 `gorm:"index"`
//	CommentID int32
//}
218 |
219 | func (report *Report) ToString() string {
220 | rtn := ""
221 | var name string
222 | if report.IsComment {
223 | name = fmt.Sprintf("To:树洞回复#%d-%d", report.PostID, report.CommentID)
224 | } else {
225 | name = fmt.Sprintf("To:树洞#%d", report.PostID)
226 | }
227 | rtn = fmt.Sprintf("%s\n***\nReason: %s", name, report.Reason)
228 | return rtn
229 | }
230 |
231 | func (typ *ReportType) ToString() string {
232 | switch *typ {
233 | case UserReport:
234 | return "用户举报"
235 | case UserReportFold:
236 | return "用户举报折叠"
237 | case AdminTag:
238 | return "管理员打Tag"
239 | case UserDelete:
240 | return "撤回或管理员删除"
241 | case AdminUndelete:
242 | return "撤销删除并解禁"
243 | case AdminDeleteAndBan:
244 | return "删帖禁言"
245 | case AdminUnban:
246 | return "解禁"
247 | default:
248 | return "unknown"
249 | }
250 | }
251 |
252 | func (report *Report) ToDetailedString() string {
253 | typeStr := report.Type.ToString()
254 | if report.Type == UserDelete {
255 | typeStr = utils.IfThenElse(report.UserID == report.ReportedUserID, "撤回", "管理员删除").(string)
256 | }
257 | rtn := fmt.Sprintf("From User ID:%d\nTo User ID:%d\nType:%s\n%s", report.UserID, report.ReportedUserID,
258 | typeStr, report.ToString())
259 | return rtn
260 | }
261 |
262 | func (msg *SystemMessage) ToString() string {
263 | return fmt.Sprintf("User ID:%d\nTitle:%s\n***\n%s", msg.UserID, msg.Title, msg.Text)
264 | }
265 |
--------------------------------------------------------------------------------
/pkg/route/contents/middlewares.go:
--------------------------------------------------------------------------------
1 | package contents
2 |
3 | import (
4 | "encoding/json"
5 | "github.com/gin-gonic/gin"
6 | "github.com/iancoleman/orderedmap"
7 | "github.com/shirou/gopsutil/v3/load"
8 | "github.com/spf13/viper"
9 | "github.com/ulule/limiter/v3"
10 | "log"
11 | "net/http"
12 | "strconv"
13 | "strings"
14 | "time"
15 | "treehollow-v3-backend/pkg/base"
16 | "treehollow-v3-backend/pkg/consts"
17 | "treehollow-v3-backend/pkg/logger"
18 | "treehollow-v3-backend/pkg/utils"
19 | "unicode/utf8"
20 | )
21 |
// EmailLimiter throttles verification-code e-mails per client IP. The
// security package's reCAPTCHA middleware calls
// EmailLimiter.Get(c, c.ClientIP()).
// NOTE(review): initLimiters does not assign it, so it is presumably set
// elsewhere in the package. Confirm it is non-nil before the security routes
// serve requests.
var EmailLimiter *limiter.Limiter

// Limiters created by initLimiters. When applied through limiterMiddleware
// they are keyed by user ID.
var postLimiter *limiter.Limiter
var postLimiter2 *limiter.Limiter
var commentLimiter *limiter.Limiter
var commentLimiter2 *limiter.Limiter
var detailPostLimiter *limiter.Limiter
var randomListLimiter *limiter.Limiter
var doAttentionLimiter *limiter.Limiter
var searchLimiter *limiter.Limiter
var searchShortTimeLimiter *limiter.Limiter
var deleteBanLimiter *limiter.Limiter
33 |
34 | func initLimiters() {
35 | randomListLimiter = base.InitLimiter(limiter.Rate{
36 | Period: 24 * time.Hour,
37 | Limit: 200,
38 | }, "randomListLimiter")
39 | postLimiter = base.InitLimiter(limiter.Rate{
40 | Period: 6 * time.Second,
41 | Limit: 1,
42 | }, "postLimiter")
43 | postLimiter2 = base.InitLimiter(limiter.Rate{
44 | Period: 24 * time.Hour,
45 | Limit: 100,
46 | }, "postLimiter2")
47 | commentLimiter = base.InitLimiter(limiter.Rate{
48 | Period: 3 * time.Second,
49 | Limit: 1,
50 | }, "commentLimiter")
51 | commentLimiter2 = base.InitLimiter(limiter.Rate{
52 | Period: 24 * time.Hour,
53 | Limit: 500,
54 | }, "commentLimiter2")
55 | detailPostLimiter = base.InitLimiter(limiter.Rate{
56 | Period: 24 * time.Hour,
57 | Limit: 8000,
58 | }, "detailPostLimiter")
59 | searchShortTimeLimiter = base.InitLimiter(limiter.Rate{
60 | Period: 2 * time.Second,
61 | Limit: 1,
62 | }, "searchShortTimeLimiter")
63 | searchLimiter = base.InitLimiter(limiter.Rate{
64 | Period: 24 * time.Hour,
65 | Limit: 1000,
66 | }, "searchLimiter")
67 | doAttentionLimiter = base.InitLimiter(limiter.Rate{
68 | Period: 24 * time.Hour,
69 | Limit: 2000,
70 | }, "doAttentionLimiter")
71 | deleteBanLimiter = base.InitLimiter(limiter.Rate{
72 | Period: 24 * time.Hour,
73 | Limit: base.GetDeletePostRateLimitIn24h(base.SuperUserRole),
74 | }, "deleteBanLimiter")
75 | }
76 |
77 | func limiterMiddleware(limiter *limiter.Limiter, msg string, level logger.LogLevel) gin.HandlerFunc {
78 | return func(c *gin.Context) {
79 | user := c.MustGet("user").(base.User)
80 | uidStr := strconv.Itoa(int(user.ID))
81 |
82 | if base.NeedLimiter(&user) {
83 | context, err6 := limiter.Get(c, uidStr)
84 | if err6 != nil {
85 | c.AbortWithStatus(500)
86 | return
87 | }
88 | if context.Reached {
89 | logMsg := "limiter reached: " + msg
90 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError(logMsg, msg, level))
91 | return
92 | }
93 | }
94 | c.Next()
95 | }
96 | }
97 |
98 | func sysLoadWarningMiddleware(threshold float64, msg string) gin.HandlerFunc {
99 | return func(c *gin.Context) {
100 | avg, err := load.Avg()
101 | if err == nil {
102 | if avg.Load1 > threshold {
103 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError(msg, msg, logger.WARN))
104 | return
105 | }
106 | }
107 | c.Next()
108 | }
109 | }
110 |
111 | func textToFrontendJson(id int32, timestamp int64, text string) gin.H {
112 | return gin.H{
113 | "pid": id,
114 | "text": text,
115 | "type": "text",
116 | "timestamp": timestamp,
117 | "updated_at": timestamp,
118 | "reply": 0,
119 | "likenum": 0,
120 | "attention": false,
121 | "permissions": []string{},
122 | "url": "",
123 | "tag": nil,
124 | "deleted": false,
125 | "image_metadata": gin.H{},
126 | "vote": gin.H{},
127 | }
128 | }
129 |
130 | func httpReturnInfo(c *gin.Context, text string) {
131 | c.JSON(http.StatusOK, gin.H{
132 | "code": 0,
133 | "data": []map[string]interface{}{textToFrontendJson(0, 2147483647, text)},
134 | "comments": map[string]string{},
135 | //"timestamp": utils.GetTimeStamp(),
136 | "count": 1,
137 | })
138 | c.Abort()
139 | }
140 |
141 | func disallowBannedPostUsers() gin.HandlerFunc {
142 | return func(c *gin.Context) {
143 | user := c.MustGet("user").(base.User)
144 | if !base.CanOverrideBan(&user) {
145 | timestamp := utils.GetTimeStamp()
146 | bannedTimes, err := base.GetBannedTime(base.GetDb(false), user.ID, timestamp)
147 | if bannedTimes > 0 && err == nil {
148 | var ban base.Ban
149 | err2 := base.GetDb(false).Model(&base.Ban{}).Where("user_id = ? and expire_at > ?", user.ID, timestamp).
150 | Order("expire_at desc").First(&ban).Error
151 | if err2 == nil {
152 | base.HttpReturnWithCodeMinusOneAndAbort(c,
153 | logger.NewSimpleError("DisallowBan", "很抱歉,您当前处于禁言状态,在"+
154 | utils.TimestampToString(ban.ExpireAt)+"之前您将无法发布树洞。", logger.WARN))
155 | return
156 | }
157 | }
158 | }
159 | c.Next()
160 | }
161 | }
162 |
163 | func checkReportParams(isPost bool) gin.HandlerFunc {
164 | return func(c *gin.Context) {
165 | reason := c.PostForm("reason")
166 | if len(reason) > consts.ReportMaxLength {
167 | base.HttpReturnWithCodeMinusOneAndAbort(c,
168 | logger.NewSimpleError("TooLongReport",
169 | "字数过长!字数限制为"+strconv.Itoa(consts.ReportMaxLength)+"字节。", logger.INFO))
170 | return
171 | }
172 | id, err := strconv.Atoi(c.PostForm("id"))
173 | if err != nil {
174 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "InvalidIdReport", "操作失败,id不合法"))
175 | return
176 | }
177 | if isPost {
178 | typ := c.PostForm("type")
179 | if _, ok := utils.ContainsInt(viper.GetIntSlice("disallow_report_pids"), id); ok && typ == "report" {
180 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("DisallowReport", "这个树洞无法举报哦", logger.WARN))
181 | return
182 | }
183 | }
184 | c.Set("id", id)
185 | c.Next()
186 | }
187 | }
188 |
189 | func checkParameterTextAndImage() gin.HandlerFunc {
190 | return func(c *gin.Context) {
191 | text := c.PostForm("text")
192 | typ := c.PostForm("type")
193 | img := c.PostForm("data")
194 | if utf8.RuneCountInString(text) > consts.PostMaxLength {
195 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("TooLongText", "字数过长!字数限制为"+strconv.Itoa(consts.PostMaxLength)+"字。", logger.INFO))
196 | return
197 | } else if len(text) == 0 && typ == "text" {
198 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("NoContent", "请输入内容", logger.INFO))
199 | return
200 | } else if typ != "text" && typ != "image" {
201 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("UnknownType", "未知类型的树洞", logger.WARN))
202 | return
203 | } else if int(float64(len(img))/consts.Base64Rate) > consts.ImgMaxLength {
204 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("TooLargeImage", "图片大小超出限制!", logger.WARN))
205 | return
206 | }
207 | c.Next()
208 | }
209 | }
210 |
211 | func safeSubSlice(slice []base.Post, low int, high int) []base.Post {
212 | if high > len(slice) {
213 | high = len(slice)
214 | }
215 | if 0 <= low && low <= high {
216 | return slice[low:high]
217 | }
218 | return nil
219 | }
220 |
221 | func searchHotPosts() gin.HandlerFunc {
222 | return func(c *gin.Context) {
223 | page := c.MustGet("page").(int)
224 | pageSize := consts.SearchPageSize
225 | keywords := c.Query("keywords")
226 |
227 | if keywords == "热榜" {
228 | user := c.MustGet("user").(base.User)
229 | posts := safeSubSlice(HotPosts.Get(), (page-1)*pageSize, page*pageSize)
230 | rtn, err := appendPostDetail(base.GetDb(false), posts, &user)
231 | if err != nil {
232 | base.HttpReturnWithCodeMinusOneAndAbort(c, err)
233 | return
234 | }
235 |
236 | comments, err4 := getCommentsByPosts(posts, &user)
237 | if err4 != nil {
238 | base.HttpReturnWithCodeMinusOne(c, err4)
239 | return
240 | }
241 |
242 | c.JSON(http.StatusOK, gin.H{
243 | "code": 0,
244 | "data": utils.IfThenElse(rtn != nil, rtn, []string{}),
245 | //"timestamp": utils.GetTimeStamp(),
246 | "count": utils.IfThenElse(rtn != nil, len(rtn), 0),
247 | "comments": comments,
248 | })
249 | c.Abort()
250 | return
251 | }
252 | c.Next()
253 | }
254 | }
255 |
256 | func checkParameterPage(maxPage int) gin.HandlerFunc {
257 | return func(c *gin.Context) {
258 | page, err := strconv.Atoi(c.Query("page"))
259 | if err != nil {
260 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "PageConversionFailed", "获取失败,参数page不合法"))
261 | return
262 | }
263 |
264 | if page > maxPage || page <= 0 {
265 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("PageOutOfBounds", "获取失败,参数page超出范围", logger.WARN))
266 | return
267 | }
268 | c.Set("page", page)
269 | c.Next()
270 | }
271 | }
272 |
273 | func checkParameterVoteOptions(c *gin.Context) {
274 | //voteOptions := c.PostForm("vote_options")
275 | //var optionsList []string
276 | //err := json.Unmarshal([]byte(voteOptions), &optionsList)
277 | //if err != nil {
278 | // c.Set("vote_data", "{}")
279 | // c.Next()
280 | // return
281 | //}
282 | optionsList := c.PostFormArray("vote_options[]")
283 | for _, option := range optionsList {
284 | if utf8.RuneCountInString(option) > consts.VoteOptionMaxCharacters {
285 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("TooLongVoteOption", "发送失败,选项长度最大是"+
286 | strconv.Itoa(consts.VoteOptionMaxCharacters)+"个字符", logger.INFO))
287 | return
288 | }
289 | }
290 | voteData := orderedmap.New()
291 | for _, option := range optionsList {
292 | striped := strings.TrimSpace(option)
293 | if len(striped) > 0 {
294 | voteData.Set(striped, 0)
295 | }
296 | }
297 | if len(voteData.Keys()) > consts.VoteMaxOptions {
298 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("TooManyVoteOptions", "发送失败,最多4个投票选项", logger.WARN))
299 | return
300 | }
301 | if len(voteData.Keys()) == 1 {
302 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("TooFewVoteOptions", "发送失败,至少两个选项", logger.WARN))
303 | return
304 | }
305 | _voteData, err := json.Marshal(voteData)
306 | if err != nil {
307 | log.Printf("error json marshal voteData! %s\n", err)
308 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "VoteDataMarshalFailed", "投票参数解析失败,请联系管理员"))
309 | return
310 | }
311 | strVoteData := string(_voteData)
312 |
313 | c.Set("vote_data", strVoteData)
314 | c.Next()
315 | }
316 |
--------------------------------------------------------------------------------
/pkg/route/security/createAccount.go:
--------------------------------------------------------------------------------
1 | package security
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "github.com/ProtonMail/gopenpgp/v2/helper"
7 | "github.com/SSSaaS/sssa-golang"
8 | "github.com/gin-gonic/gin"
9 | "github.com/google/uuid"
10 | "github.com/spf13/viper"
11 | "gorm.io/gorm"
12 | "gorm.io/gorm/clause"
13 | "net"
14 | "net/http"
15 | "strings"
16 | "time"
17 | "treehollow-v3-backend/pkg/base"
18 | "treehollow-v3-backend/pkg/consts"
19 | "treehollow-v3-backend/pkg/logger"
20 | "treehollow-v3-backend/pkg/mail"
21 | "treehollow-v3-backend/pkg/utils"
22 | )
23 |
24 | func saveKeyShares(user *base.User, pwHashed string, tx *gorm.DB) *logger.InternalError {
25 | pgpPublicKeys := viper.GetStringSlice("key_keepers_pgp_public_keys")
26 | minDecryptShares := viper.GetInt("min_decryption_key_count")
27 | if len(pgpPublicKeys) < 1 || minDecryptShares < 1 {
28 | return logger.NewSimpleError("ServerConfigError", "服务器加密配置错误,请联系管理员", logger.FATAL)
29 | }
30 | shares, err2 := sssa.Create(minDecryptShares, len(pgpPublicKeys), pwHashed)
31 | if err2 != nil {
32 | return logger.NewError(err2, "SSSACreateFailed", "加密失败,请联系管理员")
33 | }
34 | for i, share := range shares {
35 | keyRing, _ := utils.CreatePublicKeyRing(pgpPublicKeys[i])
36 | PGPEmail := keyRing.GetIdentities()[0].Email
37 | msg := fmt.Sprintf(`Hello keykeeper %s,
38 |
39 | If you can see this message, you've successfully obtained your key slice.
40 |
41 | The following string is the key slice that can be used to decrypt the user whose id=%d. There are %d such key slices in total and the user's personal information can be decrypted when the number of available key slices is greater than or equal to %d.
42 |
43 | If you agree to decrypt this user's personal information, please submit the following key slice to technician for decryption. If you do not agree to the decryption, please do not disclose this key slice to anyone.
44 |
45 | ======================
46 | %s
47 | ======================`, PGPEmail, user.ID, len(pgpPublicKeys), minDecryptShares, share)
48 | armor, err3 := helper.EncryptMessageArmored(pgpPublicKeys[i], msg)
49 | if err3 != nil {
50 | return logger.NewError(err3, "EncryptMessageArmoredFailed", "加密失败,请联系管理员")
51 | }
52 | err4 := tx.Create(&base.DecryptionKeyShares{
53 | EmailEncrypted: user.EmailEncrypted,
54 | PGPMessage: armor,
55 | PGPEmail: PGPEmail,
56 | }).Error
57 | if err4 != nil {
58 | return logger.NewError(err4, "SaveDecryptionKeySharesFailed", consts.DatabaseWriteFailedString)
59 | }
60 | }
61 | return nil
62 | }
63 |
64 | func createDevice(c *gin.Context, user *base.User, pwHashed string, tx *gorm.DB) error {
65 | email := strings.ToLower(c.PostForm("email"))
66 | token := utils.GenToken()
67 | deviceUUID := uuid.New().String()
68 | deviceType := c.MustGet("device_type").(base.DeviceType)
69 | deviceInfo := c.PostForm("device_info")
70 | city := "Unknown"
71 |
72 | if geoDb := utils.GeoDb.Get(); geoDb != nil {
73 | ip := net.ParseIP(c.ClientIP())
74 | record, err5 := geoDb.City(ip)
75 | if err5 == nil {
76 | country := record.Country.Names["zh-CN"]
77 | if len(country) == 0 {
78 | country = record.Country.Names["en"]
79 | }
80 | if len(country) > 0 {
81 | cityName := record.City.Names["zh-CN"]
82 | if len(cityName) == 0 {
83 | cityName = record.City.Names["en"]
84 | }
85 | if len(cityName) > 0 {
86 | city = cityName + ", " + country
87 | } else {
88 | city = country
89 | }
90 | }
91 | }
92 | }
93 |
94 | err := tx.Create(&base.Device{
95 | ID: deviceUUID,
96 | UserID: user.ID,
97 | Token: token,
98 | DeviceInfo: deviceInfo,
99 | Type: deviceType,
100 | LoginIP: c.ClientIP(),
101 | LoginCity: city,
102 | IOSDeviceToken: c.PostForm("ios_device_token"),
103 | }).Error
104 | if err != nil {
105 | rtn := logger.NewError(err, "CreateSaveDeviceFailed", consts.DatabaseWriteFailedString)
106 | base.HttpReturnWithCodeMinusOne(c, rtn)
107 | return rtn.Err
108 | }
109 |
110 | err4 := saveKeyShares(user, pwHashed, tx)
111 | if err4 != nil {
112 | base.HttpReturnWithCodeMinusOne(c, err4)
113 | return err4.Err
114 | }
115 |
116 | c.JSON(http.StatusOK, gin.H{
117 | "code": 0,
118 | "token": token,
119 | "uuid": deviceUUID,
120 | })
121 | go func() {
122 | _ = mail.SendPasswordNonceEmail(user.ForgetPwNonce, email)
123 | }()
124 | return nil
125 | }
126 |
127 | func createAccount(c *gin.Context) {
128 | oldToken := c.PostForm("old_token")
129 | emailHash := c.MustGet("email_hash").(string)
130 | email := strings.ToLower(c.PostForm("email"))
131 | pwHashed := c.PostForm("password_hashed")
132 | emailEncrypted, err := utils.AESEncrypt(email, pwHashed)
133 |
134 | if err != nil {
135 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "AESEncryptFailedInCreateAccount", consts.DatabaseEncryptFailedString))
136 | return
137 | }
138 |
139 | var user base.User
140 | err5 := base.GetDb(false).Where("old_email_hash = ?", emailHash).
141 | Model(&base.User{}).First(&user).Error
142 | if err5 == nil && user.OldToken == oldToken {
143 | // Don't need valid code
144 | } else {
145 | if err5 != nil && !errors.Is(err5, gorm.ErrRecordNotFound) {
146 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err5, "QueryOldEmailHashFailed", consts.DatabaseReadFailedString))
147 | return
148 | }
149 | code := c.PostForm("valid_code")
150 | now := utils.GetTimeStamp()
151 | correctCode, timeStamp, failedTimes, err2 := base.GetVerificationCode(emailHash)
152 | if err2 != nil && !errors.Is(err2, gorm.ErrRecordNotFound) {
153 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err2, "QueryValidCodeFailed", consts.DatabaseReadFailedString))
154 | return
155 | }
156 | if failedTimes >= 10 && now-timeStamp <= 43200 {
157 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("ValidCodeTooMuchFailed", "验证码错误尝试次数过多,请重新发送验证码", logger.INFO))
158 | return
159 | }
160 | if correctCode != code || now-timeStamp > 43200 {
161 | base.HttpReturnWithErrAndAbort(c, -10, logger.NewSimpleError("ValidCodeInvalid", "验证码无效或过期", logger.WARN))
162 | _ = base.GetDb(false).Model(&base.VerificationCode{}).Where("email_hash = ?", emailHash).
163 | Update("failed_times", gorm.Expr("failed_times + 1")).Error
164 | return
165 | }
166 | }
167 |
168 | _ = base.GetDb(false).Transaction(func(tx *gorm.DB) error {
169 | if err = tx.Create(&base.Email{EmailHash: emailHash}).Error; err != nil {
170 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err, "CreateEmailHashFailed", consts.DatabaseWriteFailedString))
171 | return err
172 | }
173 |
174 | if err5 != nil {
175 | user = base.User{
176 | EmailEncrypted: emailEncrypted,
177 | ForgetPwNonce: utils.GenNonce(),
178 | Role: base.NormalUserRole,
179 | }
180 | if err = tx.Create(&user).Error; err != nil {
181 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err, "CreateUserFailed", consts.DatabaseWriteFailedString))
182 | return err
183 | }
184 | } else {
185 | user.OldEmailHash = ""
186 | user.OldToken = ""
187 | user.EmailEncrypted = emailEncrypted
188 | user.UpdatedAt = time.Now()
189 | user.ForgetPwNonce = utils.GenNonce()
190 | if err = tx.Model(&base.User{}).Where("id = ?", user.ID).Updates(user).Error; err != nil {
191 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err, "UpdateOldUserFailed", consts.DatabaseWriteFailedString))
192 | return err
193 | }
194 | }
195 |
196 | return createDevice(c, &user, pwHashed, tx)
197 | })
198 | }
199 |
200 | func changePassword(c *gin.Context) {
201 | oldPwHashed := c.PostForm("old_password_hashed")
202 | newPwHashed := c.PostForm("new_password_hashed")
203 | email := strings.ToLower(c.PostForm("email"))
204 |
205 | if len(email) > 100 || len(oldPwHashed) > 64 || len(newPwHashed) > 64 {
206 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("ChangePasswordInvalidParam", "参数错误", logger.WARN))
207 | return
208 | }
209 |
210 | oldEmailEncrypted, err := utils.AESEncrypt(email, oldPwHashed)
211 | if err != nil {
212 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "AESEncryptFailed", consts.DatabaseEncryptFailedString))
213 | return
214 | }
215 | newEmailEncrypted, err2 := utils.AESEncrypt(email, newPwHashed)
216 | if err2 != nil {
217 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err2, "AESEncryptFailed2", consts.DatabaseEncryptFailedString))
218 | return
219 | }
220 |
221 | _ = base.GetDb(false).Transaction(func(tx *gorm.DB) error {
222 | var user base.User
223 | result := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("email_encrypted = ?", oldEmailEncrypted).
224 | Model(&base.User{}).First(&user)
225 | if result.Error != nil {
226 | if errors.Is(result.Error, gorm.ErrRecordNotFound) {
227 | base.HttpReturnWithCodeMinusOne(c, logger.NewSimpleError("ChangePasswordNoAuth", "用户名或密码错误", logger.WARN))
228 | return nil
229 | }
230 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(result.Error, "GetUserByEmailEncryptedFailed", consts.DatabaseReadFailedString))
231 | return result.Error
232 | }
233 | //if result.RowsAffected != 1 {
234 | // base.HttpReturnWithCodeMinusOne(c, logger.NewSimpleError("ChangePasswordNoAuth", "用户名或密码错误", logger.WARN))
235 | // return nil
236 | //}
237 |
238 | result = tx.Model(&base.User{}).Where("email_encrypted = ?", oldEmailEncrypted).
239 | Update("email_encrypted", newEmailEncrypted)
240 |
241 | if result.Error != nil {
242 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(result.Error, "UpdateEmailEncryptedFailed", consts.DatabaseWriteFailedString))
243 | return result.Error
244 | }
245 |
246 | err3 := tx.Where("email_encrypted = ?", oldEmailEncrypted).
247 | Delete(&base.DecryptionKeyShares{}).Error
248 | if err3 != nil {
249 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err3, "DeleteDecryptionSharesFailed", consts.DatabaseWriteFailedString))
250 | return err3
251 | }
252 |
253 | err4 := saveKeyShares(&base.User{
254 | EmailEncrypted: newEmailEncrypted,
255 | }, newPwHashed, tx)
256 | if err4 != nil {
257 | base.HttpReturnWithCodeMinusOne(c, err4)
258 | return err4.Err
259 | }
260 |
261 | err5 := tx.Where("user_id = ?", user.ID).
262 | Delete(&base.Device{}).Error
263 | if err5 != nil {
264 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err, "DeleteUserAllDevicesFailed", consts.DatabaseWriteFailedString))
265 | return err5
266 | }
267 |
268 | //TODO: (middle priority) send email
269 | c.JSON(http.StatusOK, gin.H{
270 | "code": 0,
271 | })
272 |
273 | return nil
274 | })
275 | }
276 |
--------------------------------------------------------------------------------
/pkg/base/sql.go:
--------------------------------------------------------------------------------
1 | package base
2 |
3 | import (
4 | "github.com/spf13/viper"
5 | "gorm.io/driver/mysql"
6 | "gorm.io/gorm"
7 | "gorm.io/gorm/logger"
8 | "io"
9 | "log"
10 | "os"
11 | "strconv"
12 | "strings"
13 | "time"
14 | "treehollow-v3-backend/pkg/consts"
15 | "treehollow-v3-backend/pkg/model"
16 | "treehollow-v3-backend/pkg/utils"
17 | )
18 |
// db is the package-wide GORM handle. InitDb assigns it; GetDb and the query
// helpers in this file read it.
var db *gorm.DB
20 |
21 | func AutoMigrateDb() {
22 | err := db.AutoMigrate(&User{}, &DecryptionKeyShares{}, &Email{},
23 | &Device{}, &PushSettings{}, &Vote{},
24 | &VerificationCode{}, &Post{}, &PushMessage{},
25 | &Comment{}, &Attention{}, &Report{}, &SystemMessage{}, Ban{})
26 | utils.FatalErrorHandle(&err, "error migrating database!")
27 | }
28 |
29 | func InitDb() {
30 | err2 := initRedis()
31 | utils.FatalErrorHandle(&err2, "error init redis")
32 | initCache()
33 |
34 | logFile, err := os.OpenFile("sql.log", os.O_CREATE|os.O_APPEND|os.O_RDWR, 0666)
35 | utils.FatalErrorHandle(&err, "error init sql log file")
36 | mw := io.MultiWriter(os.Stdout, logFile)
37 | logLevel := logger.Warn
38 | if viper.GetBool("is_debug") {
39 | logLevel = logger.Info
40 | }
41 | newLogger := logger.New(
42 | log.New(mw, "\r\n", log.LstdFlags), // io writer
43 | logger.Config{
44 | SlowThreshold: time.Millisecond * 500, // Slow SQL threshold
45 | LogLevel: logLevel, // Log level
46 | Colorful: false,
47 | },
48 | )
49 |
50 | db, err = gorm.Open(mysql.Open(
51 | viper.GetString("sql_source")+"?charset=utf8mb4&parseTime=True&loc=Asia%2FShanghai"), &gorm.Config{
52 | DisableForeignKeyConstraintWhenMigrating: true,
53 | Logger: newLogger,
54 | })
55 | utils.FatalErrorHandle(&err, "error opening sql db")
56 | }
57 |
58 | func GetDb(unscoped bool) *gorm.DB {
59 | if unscoped {
60 | return db.Unscoped()
61 | }
62 | return db
63 | }
64 |
65 | func ListPosts(tx *gorm.DB, p int, user *User) (posts []Post, err error) {
66 | offset := (p - 1) * consts.PageSize
67 | limit := consts.PageSize
68 | pinnedPids := viper.GetIntSlice("pin_pids")
69 | if CanViewDeletedPost(user) {
70 | tx = tx.Unscoped()
71 | }
72 | if len(pinnedPids) == 0 {
73 | err = tx.Order("id desc").Limit(limit).Offset(offset).Find(&posts).Error
74 | } else {
75 | err = tx.Where("id not in ?", pinnedPids).Order("id desc").Limit(limit).Offset(offset).
76 | Find(&posts).Error
77 | }
78 | return
79 | }
80 |
81 | func ListMsgs(p int, minId int32, userId int32, pushOnly bool) (msgs []PushMessage, err error) {
82 | offset := (p - 1) * consts.MsgPageSize
83 | limit := consts.MsgPageSize
84 | tx := db
85 | if pushOnly {
86 | tx = tx.Where("do_push = ?", true)
87 | }
88 | err = tx.Where("user_id = ? and id > ?", userId, minId).Order("id desc").Limit(limit).Offset(offset).
89 | Find(&msgs).Error
90 | return
91 | }
92 |
93 | func GetComments(pid int32) ([]Comment, error) {
94 | var comments []Comment
95 | err := db.Unscoped().Where("post_id = ?", pid).Order("id asc").Find(&comments).Error
96 | return comments, err
97 | }
98 |
99 | func GetMultipleComments(tx *gorm.DB, pids []int32) ([]Comment, error) {
100 | var comments []Comment
101 | err := tx.Unscoped().Where("post_id in (?)", pids).Order("id asc").Find(&comments).Error
102 | return comments, err
103 | }
104 |
105 | func SearchPosts(page int, keywords string, limitPids []int32, user User, order model.SearchOrder,
106 | includeComment bool, beforeTimestamp int64, afterTimestamp int64) (posts []Post, err error) {
107 | canViewDelete := CanViewDeletedPost(&user)
108 | var thePost Post
109 | var err2 error
110 | pid := -1
111 | if page == 1 {
112 | if strings.HasPrefix(keywords, "#") {
113 | pid, err2 = strconv.Atoi(keywords[1:])
114 | } else {
115 | pid, err2 = strconv.Atoi(keywords)
116 | }
117 | if err2 == nil {
118 | err2 = GetDb(canViewDelete).First(&thePost, int32(pid)).Error
119 | }
120 | }
121 | offset := (page - 1) * consts.SearchPageSize
122 | limit := consts.SearchPageSize
123 |
124 | tx := GetDb(canViewDelete)
125 | if limitPids != nil {
126 | tx = tx.Where("id in ?", limitPids)
127 | }
128 |
129 | subSearch := func(tx0 *gorm.DB, isTag bool) *gorm.DB {
130 | if isTag {
131 | return tx0.Where("tag = ?", keywords[1:])
132 | }
133 | replacedKeywords := "+" + strings.ReplaceAll(keywords, " ", " +")
134 | return tx0.Where("match(text) against(? IN BOOLEAN MODE)", replacedKeywords)
135 | }
136 |
137 | if canViewDelete && keywords == "dels" {
138 | subQuery1 := db.Unscoped().Model(&Report{}).Distinct().
139 | Where("type in (?) and user_id != reported_user_id and post_id = posts.id",
140 | []ReportType{UserDelete, AdminDeleteAndBan}).Select("post_id")
141 | err = db.Unscoped().Where("id in (?)", subQuery1).
142 | Order(order.ToString()).Limit(limit).Offset(offset).Find(&posts).Error
143 | } else {
144 | var subQuery2 *gorm.DB
145 | if includeComment {
146 | subQuery := subSearch(GetDb(canViewDelete).Model(&Comment{}).Distinct(),
147 | strings.HasPrefix(keywords, "#")).
148 | Select("post_id")
149 | subQuery2 = subSearch(GetDb(canViewDelete), strings.HasPrefix(keywords, "#")).
150 | Or("id in (?)", subQuery)
151 | } else {
152 | subQuery2 = subSearch(GetDb(canViewDelete), strings.HasPrefix(keywords, "#"))
153 | }
154 |
155 | if beforeTimestamp > 0 {
156 | tx = tx.Where("created_at < ?", time.Unix(beforeTimestamp, 0).In(consts.TimeLoc))
157 | }
158 | if afterTimestamp > 0 {
159 | tx = tx.Where("created_at >= ?", time.Unix(afterTimestamp, 0).In(consts.TimeLoc))
160 | }
161 | if pid > 0 {
162 | tx = tx.Where("id != ?", pid)
163 | }
164 |
165 | err = tx.Where(subQuery2).Order(order.ToString()).Limit(limit).Offset(offset).Find(&posts).Error
166 | }
167 |
168 | if err2 == nil && page == 1 {
169 | posts = append([]Post{thePost}, posts...)
170 | }
171 | return
172 | }
173 |
174 | func GetVerificationCode(emailHash string) (string, int64, int, error) {
175 | var vc VerificationCode
176 | err := db.Where("email_hash = ?", emailHash).First(&vc).Error
177 | return vc.Code, vc.UpdatedAt.Unix(), vc.FailedTimes, err
178 | }
179 |
180 | func SavePost(uid int32, text string, tag string, typ string, filePath string, metaStr string, voteData string) (id int32, err error) {
181 | post := Post{Tag: tag, UserID: uid, Text: text, Type: typ, FilePath: filePath, LikeNum: 0, ReplyNum: 0,
182 | ReportNum: 0, FileMetadata: metaStr, VoteData: voteData}
183 | err = db.Save(&post).Error
184 | id = post.ID
185 | return
186 | }
187 |
188 | func GetHotPosts() (posts []Post, err error) {
189 | err = db.Where("id>(SELECT MAX(id)-2000 FROM posts)").
190 | Order("like_num*3+reply_num+UNIX_TIMESTAMP(created_at)/1800-report_num*10 DESC").
191 | Limit(200).Find(&posts).Error
192 | return
193 | }
194 |
195 | func SaveComment(tx *gorm.DB, uid int32, text string, tag string, typ string, filePath string, pid int32, replyTo int32, name string,
196 | metaStr string) (id int32, err error) {
197 | comment := Comment{Tag: tag, UserID: uid, PostID: pid, ReplyTo: replyTo, Text: text, Type: typ, FilePath: filePath,
198 | Name: name, FileMetadata: metaStr}
199 | err = tx.Save(&comment).Error
200 | id = comment.ID
201 | if err == nil {
202 | err = DelCommentCache(int(pid))
203 | }
204 | return
205 | }
206 |
207 | func GenCommenterName(tx *gorm.DB, dzUserID int32, czUserID int32, postID int32, names0 []string, names1 []string) (string, error) {
208 | var name string
209 | var err error
210 | if dzUserID == czUserID {
211 | name = consts.DzName
212 | } else {
213 | var comment Comment
214 | err = tx.Unscoped().Where("user_id = ? AND post_id=?", czUserID, postID).First(&comment).Error
215 | if err != nil { // token is not in comments
216 | var count int64
217 | err = tx.Unscoped().Model(&Comment{}).Where("user_id != ? AND post_id=?", dzUserID, postID).
218 | Distinct("user_id").Count(&count).Error
219 | if err != nil {
220 | return "", err
221 | }
222 | name = utils.GetCommenterName(int(count)+1, names0, names1)
223 | } else {
224 | name = comment.Name
225 | }
226 | }
227 | return name, nil
228 | }
229 |
230 | func GetBannedTime(tx *gorm.DB, uid int32, startTime int64) (times int64, err error) {
231 | err = tx.Model(&Ban{}).Where("user_id = ? and expire_at > ?", uid, startTime).Count(×).Error
232 | return
233 | }
234 |
235 | func calcBanExpireTime(times int64) int64 {
236 | return utils.GetTimeStamp() + (times+1)*86400
237 | }
238 |
239 | func generateBanReason(report Report, originalText string) (rtn string) {
240 | var pre string
241 | if report.IsComment {
242 | pre = "您的树洞评论#" + strconv.Itoa(int(report.PostID)) + "-" + strconv.Itoa(int(report.CommentID))
243 | } else {
244 | pre = "您的树洞#" + strconv.Itoa(int(report.PostID))
245 | }
246 | switch report.Type {
247 | case UserReport:
248 | rtn = pre + "\n\"" + originalText + "\"\n因为用户举报过多被删除。"
249 | case AdminDeleteAndBan:
250 | rtn = pre + "\n\"" + originalText + "\"\n被管理员删除。管理员的删除理由是:【" + report.Reason + "】。"
251 | }
252 | return
253 | }
254 |
255 | func DeleteByReport(tx *gorm.DB, report Report) (err error) {
256 | if report.IsComment {
257 | err = tx.Where("id = ?", report.CommentID).Delete(&Comment{}).Error
258 | if err == nil {
259 | err = tx.Model(&Post{}).Where("id = ?", report.PostID).Update("reply_num",
260 | gorm.Expr("reply_num - 1")).Error
261 | if err == nil {
262 | err = DelCommentCache(int(report.PostID))
263 | go func() {
264 | SendDeletionToPushService(report.CommentID)
265 | }()
266 | }
267 | }
268 | } else {
269 | err = tx.Where("id = ?", report.PostID).Delete(&Post{}).Error
270 | }
271 | return
272 | }
273 |
274 | func DeleteAndBan(tx *gorm.DB, report Report, text string) (err error) {
275 | err = DeleteByReport(tx, report)
276 | if err == nil {
277 | times, err2 := GetBannedTime(tx, report.ReportedUserID, 0)
278 | if err2 == nil {
279 | tx.Create(&Ban{
280 | UserID: report.ReportedUserID,
281 | ReportID: report.ID,
282 | Reason: generateBanReason(report, text),
283 | ExpireAt: calcBanExpireTime(times),
284 | })
285 | }
286 | }
287 | return
288 | }
289 |
290 | func SetTagByReport(tx *gorm.DB, report Report) (err error) {
291 | if report.IsComment {
292 | err = tx.Model(&Comment{}).Where("id = ?", report.CommentID).
293 | Update("tag", report.Reason).Error
294 | if err == nil {
295 | err = tx.Model(&Post{}).Where("id = ?", report.PostID).
296 | Update("updated_at", time.Now()).Error
297 | if err == nil {
298 | err = DelCommentCache(int(report.PostID))
299 | }
300 | }
301 | } else {
302 | err = tx.Model(&Post{}).Where("id = ?", report.PostID).
303 | Update("tag", report.Reason).Error
304 | }
305 | return
306 | }
307 |
308 | func UnbanByReport(tx *gorm.DB, report Report) (err error) {
309 | var ban Ban
310 | subQuery := tx.Model(&Report{}).Distinct().
311 | Where("post_id = ? and comment_id = ? and is_comment = ? and type in (?)",
312 | report.PostID, report.CommentID, report.IsComment,
313 | []ReportType{UserReport, AdminDeleteAndBan}).
314 | Select("id")
315 | err = tx.Model(&Ban{}).Where("report_id in (?)", subQuery).First(&ban).Error
316 | if err == nil {
317 | err = tx.Delete(&ban).Error
318 | }
319 | return
320 | }
321 |
--------------------------------------------------------------------------------
/cmd/treehollow-migrate-v2-to-v3/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
import (
	"bufio"
	"encoding/json"
	"errors"
	"image"
	_ "image/jpeg" // registers the JPEG decoder used by image.DecodeConfig
	"log"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/spf13/viper"
	"gorm.io/gorm"

	"treehollow-v3-backend/pkg/base"
	"treehollow-v3-backend/pkg/config"
	"treehollow-v3-backend/pkg/logger"
	"treehollow-v3-backend/pkg/model"
	"treehollow-v3-backend/pkg/utils"
)
21 |
// UserRole is the type of User.Role; migrateUser copies it from the v2 role column.
type UserRole int32

// ReportType is the type of Report.Type (stored as varchar(20)).
type ReportType string

// User is the v3 users table as created by this migration tool. migrateUser
// fills it from v2_users: the v2 token and email hash go into OldToken and
// OldEmailHash, and EmailEncrypted and ForgetPwNonce are left empty.
// NOTE(review): presumably kept in sync with base.User; verify before reuse.
// codebeat:disable[TOO_MANY_IVARS]
type User struct {
	ID int32 `gorm:"primaryKey;autoIncrement;not null"`
	// OldEmailHash and OldToken hold the v2 credentials of a migrated user.
	OldEmailHash   string `gorm:"index;type:varchar(64) NOT NULL"`
	OldToken       string `gorm:"index;type:varchar(32) NOT NULL"`
	EmailEncrypted string `gorm:"index;type:varchar(200) NOT NULL"`
	//KeyEncrypted string `gorm:"type:varchar(200) NOT NULL"`
	ForgetPwNonce string `gorm:"type:varchar(36) NOT NULL"`
	Role          UserRole
	//SystemMessages []SystemMessage
	//Bans []Ban
	//Posts []Post
	//Comments []Comment
	//Devices []Device
	CreatedAt time.Time
	UpdatedAt time.Time
}

// DecryptionKeyShares is the table of PGP-encrypted key-keeper shares, one row
// per keeper per account (keyed by EmailEncrypted). The migration only creates
// this table; no rows are copied.
type DecryptionKeyShares struct {
	EmailEncrypted string         `gorm:"index;type:varchar(200) NOT NULL"`
	PGPMessage     string         `gorm:"type:varchar(5000) NOT NULL"`
	PGPEmail       string         `gorm:"index;type:varchar(100) NOT NULL"`
	CreatedAt      time.Time
	DeletedAt      gorm.DeletedAt `gorm:"index"`
}

// Email is the table of registered email hashes (created empty by the migration).
type Email struct {
	EmailHash string `gorm:"primaryKey;type:char(64) NOT NULL"`
}

// DeviceType is the type of Device.Type.
type DeviceType int32

// Device is the table of logged-in devices (created empty by the migration).
type Device struct {
	ID             string         `gorm:"type:char(36);primary_key"`
	UserID         int32          `gorm:"index;not null"`
	DeviceInfo     string         `gorm:"type:varchar(100) NOT NULL"`
	Type           DeviceType
	IOSDeviceToken string         `gorm:"type:varchar(100)"`
	Token          string         `gorm:"index;type:char(32) NOT NULL"`
	LoginIP        string         `gorm:"type:varchar(50) NOT NULL"`
	LoginCity      string         `gorm:"type:varchar(50) NOT NULL"`
	CreatedAt      time.Time      `gorm:"index"`
	DeletedAt      gorm.DeletedAt `gorm:"index"`
}

// PushSettings is the per-user push preference table (created empty by the migration).
type PushSettings struct {
	UserID   int32 `gorm:"primaryKey;not null"`
	Settings model.PushType
}

// VerificationCode is the table of email verification codes, keyed by email
// hash (created empty by the migration).
type VerificationCode struct {
	EmailHash   string `gorm:"primaryKey;type:char(64) NOT NULL"`
	Code        string `gorm:"type:varchar(20) NOT NULL"`
	FailedTimes int
	CreatedAt   time.Time
	UpdatedAt   time.Time
}
83 |
// Post is the v3 posts table. migratePost fills it from v2_posts: FileMetadata
// holds image-dimension JSON from getImgMetadata and VoteData is set to "{}".
type Post struct {
	ID int32 `gorm:"primaryKey;autoIncrement;not null"`
	//User User
	UserID int32
	// Text carries an ngram FULLTEXT index (used for MATCH ... AGAINST searches).
	Text         string `gorm:"index:,class:FULLTEXT,option:WITH PARSER ngram;type: varchar(10000) NOT NULL"`
	Tag          string `gorm:"index;type:varchar(60) NOT NULL"`
	Type         string `gorm:"type:varchar(20) NOT NULL"`
	FilePath     string `gorm:"type:varchar(60) NOT NULL"`
	FileMetadata string `gorm:"type:varchar(40) NOT NULL"`
	VoteData     string `gorm:"type:varchar(200) NOT NULL"`
	LikeNum      int32  `gorm:"index"`
	ReplyNum     int32  `gorm:"index"`
	ReportNum    int32
	//Comments []Comment
	CreatedAt time.Time      `gorm:"index"`
	UpdatedAt time.Time      `gorm:"index"`
	DeletedAt gorm.DeletedAt `gorm:"index"`
}

// Comment is the v3 comments table. migrateComment fills it from v2_comments
// and sets ReplyTo to -1 for every migrated comment.
type Comment struct {
	ID      int32 `gorm:"primaryKey;autoIncrement;not null"`
	ReplyTo int32 `gorm:"index"`
	//Post Post
	PostID int32 `gorm:"index"`
	//User User
	UserID       int32
	Text         string `gorm:"index:,class:FULLTEXT,option:WITH PARSER ngram;type: varchar(10000) NOT NULL"`
	Tag          string `gorm:"index;type:varchar(60) NOT NULL"`
	Type         string `gorm:"type:varchar(20) NOT NULL"`
	FilePath     string `gorm:"type:varchar(60) NOT NULL"`
	FileMetadata string `gorm:"type:varchar(40) NOT NULL"`
	Name         string `gorm:"type:varchar(60) NOT NULL"`
	CreatedAt    time.Time
	UpdatedAt    time.Time
	DeletedAt    gorm.DeletedAt `gorm:"index"`
}

// Report is the v3 reports table (created empty by the migration).
type Report struct {
	ID int32 `gorm:"primaryKey;autoIncrement;not null"`
	//User User
	UserID int32
	//ReportedUser User
	ReportedUserID int32
	//Post Post
	PostID int32
	//Comment Comment
	CommentID int32
	Reason    string     `gorm:"type: varchar(1000) NOT NULL"`
	Type      ReportType `gorm:"type:varchar(20) NOT NULL"`
	IsComment bool
	Weight    int32
	CreatedAt time.Time      `gorm:"index"`
	DeletedAt gorm.DeletedAt `gorm:"index"`
}

//TODO: (low priority)undelete = remove user reports

// Attention is the user/post follow table with a composite primary key. The
// embedded User and Post fields are GORM associations. base.InitDb, which main
// calls, disables foreign-key constraints during migration.
type Attention struct {
	User   User
	UserID int32 `gorm:"primaryKey;index"`
	Post   Post
	PostID int32 `gorm:"primaryKey;index"`
}

// Vote records the option a user chose in a post's vote (one row per user and post).
type Vote struct {
	User   User
	UserID int32 `gorm:"primaryKey;index"`
	Post   Post
	PostID int32 `gorm:"primaryKey;index"`
	Option string `gorm:"type:varchar(100) NOT NULL"`
}

// SystemMessage is the per-user system notification table (created empty by
// the migration). BanID links a message to a Ban.
type SystemMessage struct {
	ID int32 `gorm:"primaryKey;autoIncrement;not null"`
	//User User
	UserID int32
	Text   string `gorm:"type: varchar(11000) NOT NULL"`
	Title  string `gorm:"type: varchar(100) NOT NULL"`
	//Ban Ban
	BanID     int32          `gorm:"index"`
	CreatedAt time.Time      `gorm:"index"`
	DeletedAt gorm.DeletedAt `gorm:"index"`
}

// Ban is the v3 bans table (created empty by the migration). ExpireAt is a
// timestamp; presumably Unix seconds, as produced by base.calcBanExpireTime.
type Ban struct {
	ID int32 `gorm:"primaryKey;autoIncrement;not null"`
	//User User
	UserID int32
	//Report Report
	ReportID  int32
	Reason    string `gorm:"type: varchar(11000) NOT NULL"`
	ExpireAt  int64
	CreatedAt time.Time      `gorm:"index"`
	DeletedAt gorm.DeletedAt `gorm:"index"`
}

//type Messages struct {
//	ID        int32 `gorm:"primaryKey;autoIncrement;not null"`
//	UserID    int32 `gorm:"index"`
//	CommentID int32
//}
185 |
186 | func migrateUser(page int) (count int) {
187 | var results []map[string]interface{}
188 | var users []User
189 | err := base.GetDb(false).Table("v2_users").Order("id asc").Offset(batchSize * page).
190 | Limit(batchSize).Find(&results).Error
191 | utils.FatalErrorHandle(&err, "error reading v2_users!")
192 | for _, result := range results {
193 | var updateAt time.Time
194 | if result["updated_at"] != nil {
195 | updateAt = result["updated_at"].(time.Time)
196 | } else {
197 | updateAt = result["created_at"].(time.Time)
198 | }
199 | user := User{
200 | ID: result["id"].(int32),
201 | OldToken: result["token"].(string),
202 | OldEmailHash: result["email_hash"].(string),
203 | EmailEncrypted: "",
204 | ForgetPwNonce: "",
205 | Role: UserRole(result["role"].(int64)),
206 | CreatedAt: result["created_at"].(time.Time),
207 | UpdatedAt: updateAt,
208 | }
209 | users = append(users, user)
210 | }
211 | count = len(results)
212 | if count > 0 {
213 | err = base.GetDb(false).Create(&users).Error
214 | utils.FatalErrorHandle(&err, "error writing v2_users!")
215 | }
216 | return
217 | }
218 |
219 | func migrateComment(page int) (count int) {
220 | var results []map[string]interface{}
221 | var comments []Comment
222 | err := base.GetDb(false).Table("v2_comments").Order("id asc").Offset(batchSize * page).
223 | Limit(batchSize).Find(&results).Error
224 | utils.FatalErrorHandle(&err, "error reading v2_comments!")
225 | for _, result := range results {
226 | var deletedAt gorm.DeletedAt
227 | _ = deletedAt.Scan(result["deleted_at"])
228 | comment := Comment{
229 | ID: result["id"].(int32),
230 | ReplyTo: -1,
231 | PostID: int32(result["post_id"].(int64)),
232 | UserID: int32(result["user_id"].(int64)),
233 | Text: result["text"].(string),
234 | Tag: result["tag"].(string),
235 | Type: result["type"].(string),
236 | FilePath: result["file_path"].(string),
237 | FileMetadata: getImgMetadata(result["file_path"].(string)),
238 | Name: result["name"].(string),
239 | CreatedAt: result["created_at"].(time.Time),
240 | UpdatedAt: result["updated_at"].(time.Time),
241 | DeletedAt: deletedAt,
242 | }
243 | comments = append(comments, comment)
244 | }
245 | count = len(results)
246 | if count > 0 {
247 | err = base.GetDb(false).Create(&comments).Error
248 | utils.FatalErrorHandle(&err, "error writing v2_comments!")
249 | }
250 | return
251 | }
252 |
253 | func migratePost(page int) (count int) {
254 | var results []map[string]interface{}
255 | var posts []Post
256 | err := base.GetDb(false).Table("v2_posts").Order("id asc").Offset(batchSize * page).
257 | Limit(batchSize).Find(&results).Error
258 | utils.FatalErrorHandle(&err, "error reading v2_posts!")
259 | for _, result := range results {
260 | var deletedAt gorm.DeletedAt
261 | _ = deletedAt.Scan(result["deleted_at"])
262 |
263 | tag := result["tag"].(string)
264 | if tag == "折叠" {
265 | tag = "令人不适"
266 | }
267 | post := Post{
268 | ID: result["id"].(int32),
269 | UserID: int32(result["user_id"].(int64)),
270 | Text: result["text"].(string),
271 | Tag: tag,
272 | Type: result["type"].(string),
273 | FilePath: result["file_path"].(string),
274 | FileMetadata: getImgMetadata(result["file_path"].(string)),
275 | LikeNum: int32(result["like_num"].(int64)),
276 | ReplyNum: int32(result["reply_num"].(int64)),
277 | ReportNum: int32(result["report_num"].(int64)),
278 | VoteData: "{}",
279 | CreatedAt: result["created_at"].(time.Time),
280 | UpdatedAt: result["updated_at"].(time.Time),
281 | DeletedAt: deletedAt,
282 | }
283 | posts = append(posts, post)
284 | }
285 | count = len(results)
286 | if count > 0 {
287 | err = base.GetDb(false).Create(&posts).Error
288 | utils.FatalErrorHandle(&err, "error writing v2_posts!")
289 | }
290 | return
291 | }
292 |
// batchSize is the number of v2 rows read and written per migration batch.
const batchSize = 3000

// metaData maps an image file name to its dimension JSON ({"h":…,"w":…}).
// main fills it before any rows are migrated.
var metaData map[string]string

// getImgMetadata returns the dimension JSON recorded for imgName. It returns
// "{}" when imgName is empty (no attachment) or when no metadata was collected
// for it; a missing entry is logged.
func getImgMetadata(imgName string) string {
	if imgName == "" {
		return "{}"
	}
	if meta, found := metaData[imgName]; found {
		return meta
	}
	log.Printf("img %s not found\n", imgName)
	// Fall back to "{}" rather than "". The column is NOT NULL and every other
	// row stores JSON there; an empty string is not valid JSON.
	return "{}"
}
308 |
// migrate calls foo with page numbers 0, 1, 2, … and stops right after the
// first call that returns 0 (a page with no rows).
func migrate(foo func(int) int) {
	for page := 0; ; page++ {
		if rows := foo(page); rows == 0 {
			return
		}
	}
}
318 |
319 | func main() {
320 | logger.InitLog("migration.log")
321 | config.InitConfigFile()
322 | log.Println("starting migration...")
323 |
324 | var err error
325 | base.InitDb()
326 |
327 | metaData = make(map[string]string)
328 | err = filepath.Walk(viper.GetString("images_path"), func(path string, info os.FileInfo, err error) error {
329 | if err != nil {
330 | return err
331 | }
332 |
333 | if strings.HasSuffix(path, "jpeg") && !info.IsDir() {
334 | f, err2 := os.Open(path)
335 | if err2 != nil {
336 | log.Printf("error opening file %s %s\n", path, err2)
337 | return err2
338 | }
339 | defer f.Close()
340 |
341 | im, _, err3 := image.DecodeConfig(bufio.NewReader(f))
342 | if err3 != nil {
343 | log.Printf("error decoding image %s %s\n", path, err3)
344 | } else {
345 | metadataBytes, err4 := json.Marshal(map[string]int{"w": im.Width, "h": im.Height})
346 | if err4 != nil {
347 | log.Printf("error json.Marshal while decoding image %s , err=%s\n", path, err4.Error())
348 | return errors.New("图片大小解析失败")
349 | }
350 | metaData[info.Name()] = string(metadataBytes)
351 | }
352 | }
353 |
354 | return nil
355 | })
356 | utils.FatalErrorHandle(&err, "error walking images folder")
357 |
358 | err = base.GetDb(false).Migrator().RenameTable("users", "v2_users")
359 | utils.FatalErrorHandle(&err, "error rename table")
360 | err = base.GetDb(false).Migrator().RenameTable("posts", "v2_posts")
361 | utils.FatalErrorHandle(&err, "error rename table")
362 | err = base.GetDb(false).Migrator().RenameTable("comments", "v2_comments")
363 | utils.FatalErrorHandle(&err, "error rename table")
364 |
365 | err = base.GetDb(false).
366 | AutoMigrate(&User{}, &DecryptionKeyShares{}, &Email{},
367 | &Device{}, &PushSettings{}, &Vote{},
368 | &VerificationCode{}, &Post{}, //&Messages{},
369 | &Comment{}, &Attention{}, &Report{}, &SystemMessage{}, Ban{})
370 | utils.FatalErrorHandle(&err, "error migrating database!")
371 |
372 | migrate(migrateUser)
373 | log.Println("done migrating users")
374 |
375 | migrate(migratePost)
376 | log.Println("done migrating posts")
377 |
378 | migrate(migrateComment)
379 | log.Println("done migrating comments")
380 | log.Println("done all migration")
381 | }
382 |
--------------------------------------------------------------------------------
/pkg/route/contents/adminCommands.go:
--------------------------------------------------------------------------------
1 | package contents
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "github.com/gin-gonic/gin"
7 | "github.com/spf13/viper"
8 | "gorm.io/gorm"
9 | "log"
10 | "net/http"
11 | "os"
12 | "regexp"
13 | "strconv"
14 | "time"
15 | "treehollow-v3-backend/pkg/base"
16 | "treehollow-v3-backend/pkg/consts"
17 | "treehollow-v3-backend/pkg/logger"
18 | "treehollow-v3-backend/pkg/utils"
19 | )
20 |
21 | //TODO: (middle priority) better result for `reports`
22 | func adminDecryptionCommand() gin.HandlerFunc {
23 | return func(c *gin.Context) {
24 | if !viper.GetBool("allow_admin_commands") {
25 | c.Next()
26 | return
27 | }
28 | user := c.MustGet("user").(base.User)
29 | keywords := c.Query("keywords")
30 | if base.CanViewDecryptionMessages(&user) {
31 | info := ""
32 | var uid int32 = -1
33 | reg := regexp.MustCompile("decrypt pid=([0-9]+)")
34 | if reg.MatchString(keywords) {
35 | pidStr := reg.FindStringSubmatch(keywords)[1]
36 | pid, _ := strconv.Atoi(pidStr)
37 | var post base.Post
38 | err3 := base.GetDb(true).First(&post, int32(pid)).Error
39 | if err3 != nil {
40 | if errors.Is(err3, gorm.ErrRecordNotFound) {
41 | base.HttpReturnWithErrAndAbort(c, -101, logger.NewSimpleError("DecryptPostNoPid", "找不到这条树洞", logger.WARN))
42 | } else {
43 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err3, "GetSavedPostFailed", consts.DatabaseReadFailedString))
44 |
45 | }
46 | return
47 | }
48 | uid = post.UserID
49 | info += fmt.Sprintf("Decryption information for post #%d:", post.ID)
50 | }
51 |
52 | reg = regexp.MustCompile("decrypt cid=([0-9]+)")
53 | if reg.MatchString(keywords) {
54 | cidStr := reg.FindStringSubmatch(keywords)[1]
55 | cid, _ := strconv.Atoi(cidStr)
56 | var comment base.Comment
57 | err3 := base.GetDb(true).First(&comment, int32(cid)).Error
58 | if err3 != nil {
59 | if errors.Is(err3, gorm.ErrRecordNotFound) {
60 | base.HttpReturnWithErrAndAbort(c, -101, logger.NewSimpleError("DecryptCommentNoPid", "找不到这条树洞评论", logger.WARN))
61 | } else {
62 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err3, "GetSavedCommentFailed", consts.DatabaseReadFailedString))
63 |
64 | }
65 | return
66 | }
67 | uid = comment.UserID
68 | info += fmt.Sprintf("Decryption information for comment #%d-%d:", comment.PostID, comment.ID)
69 | }
70 |
71 | if uid > 0 {
72 | var toBeDecryptedUser base.User
73 | err3 := base.GetDb(true).First(&toBeDecryptedUser, uid).Error
74 | if err3 != nil {
75 | if errors.Is(err3, gorm.ErrRecordNotFound) {
76 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewSimpleError("DecryptNoUser", "找不到发帖用户", logger.WARN))
77 | } else {
78 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err3, "GetDecryptionUserFailed", consts.DatabaseReadFailedString))
79 |
80 | }
81 | return
82 | }
83 |
84 | var decryptionMsgs []base.DecryptionKeyShares
85 | err := base.GetDb(false).Where("email_encrypted = ?", toBeDecryptedUser.EmailEncrypted).
86 | Find(&decryptionMsgs).Error
87 | if err != nil {
88 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "GetDecryptionMsgsFailed", consts.DatabaseReadFailedString))
89 | return
90 | }
91 | info += "\nEncrypted email = " + toBeDecryptedUser.EmailEncrypted + "\n"
92 | for _, msg := range decryptionMsgs {
93 | info += "\n***\nKeykeeper email:" + msg.PGPEmail + "\nPGP encrypted message:\n```\n" +
94 | msg.PGPMessage + "\n```"
95 | }
96 |
97 | httpReturnInfo(c, info)
98 | return
99 | }
100 | }
101 | c.Next()
102 | }
103 | }
104 |
105 | func adminHelpCommand() gin.HandlerFunc {
106 | return func(c *gin.Context) {
107 | if !viper.GetBool("allow_admin_commands") {
108 | c.Next()
109 | return
110 | }
111 | user := c.MustGet("user").(base.User)
112 | keywords := c.Query("keywords")
113 | if base.CanShowHelp(&user) && keywords == "help" {
114 | info := ""
115 | if base.CanViewStatistics(&user) {
116 | info += "`stats`: 查看树洞统计信息\n"
117 | }
118 | if base.CanViewDecryptionMessages(&user) {
119 | info += "`decrypt pid=123`, `decrypt cid=1234`: 查看树洞发帖人个人信息的待解密消息\n"
120 | }
121 | if base.CanViewDeletedPost(&user) {
122 | info += "`dels`: 搜索所有被管理员删除的树洞和回复(包括删除后恢复的)\n"
123 | info += "`//setflag NOT_SHOW_DELETED=on`(注意大小写): 在除了`deleted`搜索界面外的其他界面隐藏被删除的树洞\n"
124 | }
125 | if base.CanViewAllSystemMessages(&user) {
126 | info += "`msgs`: 查看所有用户收到的系统消息\n"
127 | }
128 | if base.CanViewReports(&user) {
129 | info += "`rep_dels`: 查看所有用户的【删除举报】(树洞or回复)\n"
130 | }
131 | if base.CanViewLogs(&user) {
132 | info += "`rep_recalls`: 查看所有用户的【撤回】(树洞or回复)\n"
133 | info += "`rep_folds`: 查看所有用户的【折叠举报】(树洞or回复)\n"
134 | info += "`log_tags`: 查看所有【管理员打Tag】的操作日志\n"
135 | info += "`log_dels`: 查看所有的【管理员删除】\n"
136 | info += "`log_unbans`: 查看所有【撤销删除】、【解禁】的操作日志\n"
137 | info += "`logs`: 查看所有举报、删帖、打tag的操作日志\n"
138 | }
139 | if base.CanShutdown(&user) {
140 | info += "`shutdown`: 关闭树洞, 请谨慎使用此命令\n"
141 | }
142 |
143 | if base.GetDeletePostRateLimitIn24h(user.Role) > 0 {
144 | uidStr := strconv.Itoa(int(user.ID))
145 | ctx, err := deleteBanLimiter.Peek(c, uidStr)
146 | if err != nil {
147 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "DeleteBanLimiterFailed", consts.DatabaseReadFailedString))
148 | return
149 | }
150 | limit := base.GetDeletePostRateLimitIn24h(user.Role)
151 |
152 | info += "\n---\n"
153 | info += fmt.Sprintf("您的【删帖禁言】操作次数额度剩余(24h内):%d/%d\n",
154 | limit+ctx.Remaining-ctx.Limit,
155 | limit)
156 | }
157 |
158 | httpReturnInfo(c, info)
159 | return
160 | }
161 | c.Next()
162 | }
163 | }
164 |
165 | func adminStatisticsCommand() gin.HandlerFunc {
166 | return func(c *gin.Context) {
167 | if !viper.GetBool("allow_admin_commands") {
168 | c.Next()
169 | return
170 | }
171 | user := c.MustGet("user").(base.User)
172 | keywords := c.Query("keywords")
173 | if base.CanViewStatistics(&user) && keywords == "stats" {
174 | var count int64
175 | var count2 int64
176 |
177 | info := ""
178 | err := base.GetDb(true).Model(&base.User{}).Where("email_encrypted != \"\"").Count(&count).Error
179 | if err != nil {
180 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "GetTotalUserFailed", consts.DatabaseReadFailedString))
181 | return
182 | }
183 | info += "总注册人数(包含已注销账户):" + strconv.Itoa(int(count)) + "\n"
184 |
185 | err = base.GetDb(true).Model(&base.Email{}).Count(&count).Error
186 | if err != nil {
187 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "GetTotalRegisteredUserFailed", consts.DatabaseReadFailedString))
188 | return
189 | }
190 | info += "总注册人数(不包含已注销账户):" + strconv.Itoa(int(count)) + "\n"
191 |
192 | err = base.GetDb(false).Model(&base.User{}).Where("email_encrypted != \"\"").Count(&count2).Error
193 | if err != nil {
194 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "GetTotalRegisteredUser2Failed", consts.DatabaseReadFailedString))
195 | return
196 | }
197 | if count != count2 {
198 | info += "警告:数据库邮箱验证系统自洽性检验失败,请修复!(" + strconv.Itoa(int(count2)) + ")\n"
199 | }
200 |
201 | err = base.GetDb(true).Model(&base.User{}).Count(&count).Error
202 | if err != nil {
203 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "GetTotalNewOldUserFailed", consts.DatabaseReadFailedString))
204 | return
205 | }
206 | info += "总注册人数(包含老版本树洞账户):" + strconv.Itoa(int(count)) + "\n"
207 |
208 | err = base.GetDb(true).Model(&base.Post{}).
209 | Where("created_at > ?", time.Now().AddDate(0, 0, -1)).
210 | Count(&count).Error
211 | if err != nil {
212 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "GetPostPerDayFailed", consts.DatabaseReadFailedString))
213 | return
214 | }
215 | info += "24h内发帖数:" + strconv.Itoa(int(count)) + "\n"
216 |
217 | err = base.GetDb(true).Model(&base.Post{}).
218 | Where("deleted_at is not null and created_at > ?", time.Now().AddDate(0, 0, -1)).
219 | Count(&count).Error
220 | if err != nil {
221 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "GetDeletedPostStatsFailed", consts.DatabaseReadFailedString))
222 | return
223 | }
224 | info += "24h内树洞删帖数:" + strconv.Itoa(int(count)) + "\n"
225 |
226 | err = base.GetDb(true).Model(&base.Comment{}).
227 | Where("deleted_at is not null and created_at > ?", time.Now().AddDate(0, 0, -1)).
228 | Count(&count).Error
229 | if err != nil {
230 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "GetDeletedCommentStatsFailed", consts.DatabaseReadFailedString))
231 | return
232 | }
233 | info += "24h内评论删帖数:" + strconv.Itoa(int(count)) + "\n"
234 |
235 | httpReturnInfo(c, info)
236 | return
237 | }
238 | c.Next()
239 | }
240 | }
241 |
242 | func adminReportsCommand() gin.HandlerFunc {
243 | return func(c *gin.Context) {
244 | if !viper.GetBool("allow_admin_commands") {
245 | c.Next()
246 | return
247 | }
248 | user := c.MustGet("user").(base.User)
249 | keywords := c.Query("keywords")
250 | if base.CanViewReports(&user) && keywords == "rep_dels" {
251 | page := c.MustGet("page").(int)
252 | offset := (page - 1) * consts.SearchPageSize
253 | limit := consts.SearchPageSize
254 | var reports []base.Report
255 |
256 | err := base.GetDb(false).Order("id desc").Where("type = ?", base.UserReport).
257 | Where("created_at > ?", time.Now().AddDate(0, 0, -1)).
258 | Limit(limit).Offset(offset).Find(&reports).Error
259 | if err != nil {
260 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "GetReportedPostsFailed", consts.DatabaseReadFailedString))
261 | return
262 | }
263 | var data []gin.H
264 | for _, report := range reports {
265 | data = append(data, textToFrontendJson(report.ID, report.CreatedAt.Unix(), report.ToString()))
266 | }
267 |
268 | c.JSON(http.StatusOK, gin.H{
269 | "code": 0,
270 | "data": utils.IfThenElse(data != nil, data, []string{}),
271 | //"timestamp": utils.GetTimeStamp(),
272 | "count": utils.IfThenElse(data != nil, len(data), 0),
273 | "comments": map[string]string{},
274 | })
275 | c.Abort()
276 | return
277 | }
278 | c.Next()
279 | }
280 | }
281 |
282 | func adminLogsCommand() gin.HandlerFunc {
283 | return func(c *gin.Context) {
284 | if !viper.GetBool("allow_admin_commands") {
285 | c.Next()
286 | return
287 | }
288 | user := c.MustGet("user").(base.User)
289 | keywords := c.Query("keywords")
290 | if base.CanViewLogs(&user) {
291 | if _, ok := utils.ContainsString([]string{"logs", "rep_dels", "rep_folds", "log_tags", "log_dels",
292 | "rep_recalls", "log_unbans"}, keywords); ok {
293 |
294 | page := c.MustGet("page").(int)
295 | offset := (page - 1) * consts.SearchPageSize
296 | limit := consts.SearchPageSize
297 | var reports []base.Report
298 |
299 | var err error
300 | if keywords == "logs" {
301 | err = base.GetDb(false).Order("id desc").
302 | Limit(limit).Offset(offset).Find(&reports).Error
303 | } else if keywords == "log_dels" {
304 | err = base.GetDb(false).Order("id desc").Where(base.GetDb(false).
305 | Where("type = ?", base.UserDelete).
306 | Where("user_id != reported_user_id")).
307 | Or("type = ?", base.AdminDeleteAndBan).Limit(limit).Offset(offset).Find(&reports).Error
308 | } else if keywords == "rep_recalls" {
309 | err = base.GetDb(false).Order("id desc").Where("type = ?", base.UserDelete).
310 | Where("user_id = reported_user_id").Limit(limit).Offset(offset).Find(&reports).Error
311 | } else if keywords == "rep_dels" {
312 | err = base.GetDb(false).Order("id desc").Where("type = ?", base.UserReport).
313 | Limit(limit).Offset(offset).Find(&reports).Error
314 | } else if keywords == "rep_folds" {
315 | err = base.GetDb(false).Order("id desc").Where("type = ?", base.UserReportFold).
316 | Limit(limit).Offset(offset).Find(&reports).Error
317 | } else if keywords == "log_tags" {
318 | err = base.GetDb(false).Order("id desc").Where("type = ?", base.AdminTag).
319 | Limit(limit).Offset(offset).Find(&reports).Error
320 | } else if keywords == "log_unbans" {
321 | err = base.GetDb(false).Order("id desc").Where("type in (?)",
322 | []base.ReportType{base.AdminUnban, base.AdminUndelete}).
323 | Limit(limit).Offset(offset).Find(&reports).Error
324 | }
325 | if err != nil {
326 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "LogsCommandFailed", consts.DatabaseReadFailedString))
327 | return
328 | }
329 | var data []gin.H
330 | for _, report := range reports {
331 | data = append(data, textToFrontendJson(report.ID, report.CreatedAt.Unix(), report.ToDetailedString()))
332 | }
333 |
334 | c.JSON(http.StatusOK, gin.H{
335 | "code": 0,
336 | "data": utils.IfThenElse(data != nil, data, []string{}),
337 | //"timestamp": utils.GetTimeStamp(),
338 | "count": utils.IfThenElse(data != nil, len(data), 0),
339 | "comments": map[string]string{},
340 | })
341 | c.Abort()
342 | return
343 | }
344 | }
345 | c.Next()
346 | }
347 | }
348 |
349 | func adminSysMsgsCommand() gin.HandlerFunc {
350 | return func(c *gin.Context) {
351 | if !viper.GetBool("allow_admin_commands") {
352 | c.Next()
353 | return
354 | }
355 | user := c.MustGet("user").(base.User)
356 | keywords := c.Query("keywords")
357 | if base.CanViewAllSystemMessages(&user) && keywords == "msgs" {
358 | page := c.MustGet("page").(int)
359 | offset := (page - 1) * consts.SearchPageSize
360 | limit := consts.SearchPageSize
361 | var msgs []base.SystemMessage
362 |
363 | err := base.GetDb(false).Where("title != ?", "新的登录").Order("id desc").Limit(limit).Offset(offset).Find(&msgs).Error
364 | if err != nil {
365 | base.HttpReturnWithCodeMinusOneAndAbort(c, logger.NewError(err, "SysMsgCommandFailed", consts.DatabaseReadFailedString))
366 | return
367 | }
368 | var data []gin.H
369 | for _, msg := range msgs {
370 | data = append(data, textToFrontendJson(msg.ID, msg.CreatedAt.Unix(), msg.ToString()))
371 | }
372 |
373 | c.JSON(http.StatusOK, gin.H{
374 | "code": 0,
375 | "data": utils.IfThenElse(data != nil, data, []string{}),
376 | //"timestamp": utils.GetTimeStamp(),
377 | "count": utils.IfThenElse(data != nil, len(data), 0),
378 | "comments": map[string]string{},
379 | })
380 | c.Abort()
381 | return
382 | }
383 | c.Next()
384 | }
385 | }
386 |
387 | var shutdownCountDown int
388 |
389 | func adminShutdownCommand() gin.HandlerFunc {
390 | return func(c *gin.Context) {
391 | if !viper.GetBool("allow_admin_commands") {
392 | c.Next()
393 | return
394 | }
395 | user := c.MustGet("user").(base.User)
396 | keywords := c.Query("keywords")
397 | if base.CanShutdown(&user) && keywords == "shutdown" {
398 | uidStr := strconv.Itoa(int(user.ID))
399 | log.Printf("Super user " + uidStr + " shutdown. shutdownCountDown=" + strconv.Itoa(shutdownCountDown))
400 | if shutdownCountDown > 0 {
401 | httpReturnInfo(c, strconv.Itoa(shutdownCountDown)+" more times to fully shutdown.")
402 | shutdownCountDown -= 1
403 | c.Abort()
404 | } else {
405 | os.Exit(0)
406 | }
407 | return
408 | }
409 | c.Next()
410 | }
411 | }
412 |
--------------------------------------------------------------------------------
/pkg/route/contents/routeApiGET.go:
--------------------------------------------------------------------------------
1 | package contents
2 |
3 | import (
4 | "encoding/json"
5 | "errors"
6 | "github.com/gin-gonic/gin"
7 | "github.com/iancoleman/orderedmap"
8 | "github.com/spf13/viper"
9 | "gorm.io/gorm"
10 | "log"
11 | "math/rand"
12 | "net/http"
13 | "strconv"
14 | "strings"
15 | "time"
16 | "treehollow-v3-backend/pkg/base"
17 | "treehollow-v3-backend/pkg/config"
18 | "treehollow-v3-backend/pkg/consts"
19 | "treehollow-v3-backend/pkg/logger"
20 | "treehollow-v3-backend/pkg/model"
21 | "treehollow-v3-backend/pkg/utils"
22 | "unicode/utf8"
23 | )
24 |
25 | func commentToJson(comment *base.Comment, user *base.User) gin.H {
26 | offset := utils.CalcExtra(user.ForgetPwNonce, strconv.Itoa(int(comment.ID)))
27 | imageMetadata := map[string]int{}
28 | err2 := json.Unmarshal([]byte(comment.FileMetadata), &imageMetadata)
29 | if err2 != nil {
30 | log.Printf("bad image metadata in cid=%d: err=%s\n", comment.ID, err2)
31 | }
32 | return gin.H{
33 | "cid": comment.ID,
34 | "pid": comment.PostID,
35 | "text": comment.Text,
36 | "type": comment.Type,
37 | "timestamp": comment.CreatedAt.Unix() - offset,
38 | "reply_to": comment.ReplyTo,
39 | "url": utils.GetHashedFilePath(comment.FilePath),
40 | "tag": utils.IfThenElse(len(comment.Tag) != 0, comment.Tag, nil),
41 | "permissions": base.GetPermissionsByComment(user, comment),
42 | "deleted": comment.DeletedAt.Valid,
43 | "name": comment.Name,
44 | "is_dz": comment.Name == consts.DzName,
45 | "image_metadata": imageMetadata,
46 | }
47 | }
48 |
49 | func commentsToJson(comments []base.Comment, user *base.User) []gin.H {
50 | data := make([]gin.H, 0, len(comments))
51 | for _, comment := range comments {
52 | if !comment.DeletedAt.Valid || base.CanViewDeletedPost(user) {
53 | data = append(data, commentToJson(&comment, user))
54 | }
55 | }
56 | return data
57 | }
58 |
59 | func detailPost(c *gin.Context) {
60 | pid, err := strconv.Atoi(c.Query("pid"))
61 | if err != nil {
62 | base.HttpReturnWithCodeMinusOne(c, logger.NewSimpleError("DetailPostPidNotInt", "获取失败,pid不合法", logger.WARN))
63 | return
64 | }
65 |
66 | user := c.MustGet("user").(base.User)
67 | canViewDelete := base.CanViewDeletedPost(&user)
68 |
69 | var post base.Post
70 | err3 := base.GetDb(canViewDelete).First(&post, int32(pid)).Error
71 | if err3 != nil {
72 | if errors.Is(err3, gorm.ErrRecordNotFound) {
73 | base.HttpReturnWithErr(c, -101, logger.NewSimpleError("DetailPostPidNotFound", "找不到这条树洞", logger.WARN))
74 | } else {
75 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err3, "DetailPostError", consts.DatabaseReadFailedString))
76 | }
77 | return
78 | }
79 |
80 | offset := utils.CalcExtra(user.ForgetPwNonce, strconv.Itoa(int(post.ID)))
81 | var attention int64
82 | _ = base.GetDb(false).Model(&base.Attention{}).Where(&base.Attention{PostID: post.ID, UserID: user.ID}).Count(&attention).Error
83 |
84 | votes, err4 := getVotesInPosts(base.GetDb(false), &user, []base.Post{post})
85 | if err4 != nil {
86 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err4, "GetVotesInPostsFailed", consts.DatabaseReadFailedString))
87 | return
88 | }
89 |
90 | if (c.Query("include_comment") == "0") ||
91 | (c.Query("old_updated_at") == strconv.Itoa(int(post.UpdatedAt.Unix()-offset))) {
92 | c.JSON(http.StatusOK, gin.H{
93 | "code": 1,
94 | "data": nil,
95 | "post": postToJson(&post, &user, attention == 1, votes[post.ID]),
96 | })
97 | return
98 | }
99 | comments, err2 := base.GetCommentsWithCache(&post, time.Now())
100 | if err2 != nil {
101 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err2, "GetCommentsWithCacheFailed", consts.DatabaseReadFailedString))
102 | return
103 | }
104 |
105 | data := commentsToJson(comments, &user)
106 | post.ReplyNum = int32(len(data))
107 | c.JSON(http.StatusOK, gin.H{
108 | "code": 0,
109 | "data": utils.IfThenElse(data != nil, data, []string{}),
110 | "post": postToJson(&post, &user, attention == 1, votes[post.ID]),
111 | })
112 | return
113 | }
114 |
115 | func postToJson(post *base.Post, user *base.User, attention bool, voted string) gin.H {
116 | offset := utils.CalcExtra(user.ForgetPwNonce, strconv.Itoa(int(post.ID)))
117 | imageMetadata := map[string]int{}
118 | err2 := json.Unmarshal([]byte(post.FileMetadata), &imageMetadata)
119 | if err2 != nil {
120 | log.Printf("bad image metadata in pid=%d: err=%s\n", post.ID, err2)
121 | }
122 | tag := post.Tag
123 | if post.ReportNum >= 3 && !post.DeletedAt.Valid && tag == "" {
124 | tag = "举报较多"
125 | }
126 | vote := gin.H{}
127 | if len(post.VoteData) > 2 {
128 | voteData := orderedmap.New()
129 | err := json.Unmarshal([]byte(post.VoteData), &voteData)
130 | if err == nil {
131 | if len(voted) == 0 {
132 | for _, k := range voteData.Keys() {
133 | voteData.Set(k, -1)
134 | }
135 | }
136 | vote = gin.H{
137 | "voted": voted,
138 | "vote_options": voteData.Keys(),
139 | "vote_data": voteData,
140 | }
141 | } else {
142 | log.Printf("bad vote_data in pid=%d: err=%s\n", post.ID, err)
143 | }
144 | }
145 | return gin.H{
146 | "pid": post.ID,
147 | "text": post.Text,
148 | "type": post.Type,
149 | "timestamp": post.CreatedAt.Unix() - offset,
150 | "updated_at": post.UpdatedAt.Unix() - offset,
151 | "reply": post.ReplyNum,
152 | "likenum": post.LikeNum,
153 | "attention": attention,
154 | "permissions": base.GetPermissionsByPost(user, post),
155 | "deleted": post.DeletedAt.Valid,
156 | "url": utils.GetHashedFilePath(post.FilePath),
157 | "tag": utils.IfThenElse(len(tag) == 0, nil, tag),
158 | "image_metadata": imageMetadata,
159 | "vote": vote,
160 | }
161 | }
162 |
163 | func postsToJson(posts []base.Post, user *base.User, attentionPids []int32, voted map[int32]string) []gin.H {
164 | data := make([]gin.H, 0, len(posts))
165 | attentionPidsSet := utils.Int32SliceToSet(attentionPids)
166 | for _, post := range posts {
167 | data = append(data, postToJson(&post, user, utils.Int32IsInSet(post.ID, attentionPidsSet), voted[post.ID]))
168 | }
169 | return data
170 | }
171 |
172 | func getAttentionPidsInPosts(tx *gorm.DB, user *base.User, posts []base.Post) (attentionPids []int32, err error) {
173 | pids := make([]int32, 0, len(posts))
174 | for _, post := range posts {
175 | pids = append(pids, post.ID)
176 | }
177 | err = tx.Model(&base.Attention{}).Where("user_id = ? and post_id in ?", user.ID, pids).
178 | Pluck("post_id", &attentionPids).Error
179 | return
180 | }
181 |
182 | func getVotesInPosts(tx *gorm.DB, user *base.User, posts []base.Post) (map[int32]string, error) {
183 | pids := make([]int32, 0, len(posts))
184 | for _, post := range posts {
185 | if len(post.VoteData) > 2 {
186 | pids = append(pids, post.ID)
187 | }
188 | }
189 | if len(pids) == 0 {
190 | return make(map[int32]string), nil
191 | }
192 |
193 | var votes []base.Vote
194 | err := tx.Model(&base.Vote{}).
195 | Where("user_id = ? and post_id in ?", user.ID, pids).
196 | Find(&votes).Error
197 | if err != nil {
198 | return nil, err
199 | }
200 | rtn := make(map[int32]string)
201 | for _, vote := range votes {
202 | rtn[vote.PostID] = vote.Option
203 | }
204 | return rtn, nil
205 | }
206 |
207 | func appendPostDetail(tx *gorm.DB, posts []base.Post, user *base.User) ([]gin.H, *logger.InternalError) {
208 | attentionPids, err3 := getAttentionPidsInPosts(tx, user, posts)
209 | if err3 != nil {
210 | return nil, logger.NewError(err3, "getAttentionPidsInPosts failed", consts.DatabaseReadFailedString)
211 | }
212 | votes, err4 := getVotesInPosts(tx, user, posts)
213 | if err4 != nil {
214 | return nil, logger.NewError(err4, "getVotesInPosts failed", consts.DatabaseReadFailedString)
215 | }
216 | jsPosts := postsToJson(posts, user, attentionPids, votes)
217 | return jsPosts, nil
218 | }
219 |
220 | func getCommentsByPosts(posts []base.Post, user *base.User) (map[int32][]gin.H, *logger.InternalError) {
221 | comments := make(map[int32][]gin.H)
222 | commentsMap, err4 := base.GetMultipleCommentsWithCache(base.GetDb(false), posts, time.Now())
223 | if err4 != nil {
224 | return nil, err4
225 | }
226 | //TODO: (low priority) update reply_num
227 | for pid, tmp := range commentsMap {
228 | if len(tmp) > 3 {
229 | tmp = tmp[:3]
230 | }
231 | if len(tmp) > 0 {
232 | comments[pid] = commentsToJson(tmp, user)
233 | }
234 | }
235 | return comments, nil
236 | }
237 |
238 | func listPost(c *gin.Context) {
239 | user := c.MustGet("user").(base.User)
240 | canViewDelete := base.CanViewDeletedPost(&user)
241 | page := c.MustGet("page").(int)
242 | posts, err2 := base.ListPosts(base.GetDb(false), page, &user)
243 | if err2 != nil {
244 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err2, "ListPostsFailed", consts.DatabaseReadFailedString))
245 | return
246 | }
247 |
248 | pinnedPids := viper.GetIntSlice("pin_pids")
249 |
250 | var configInfo gin.H
251 | if page == 1 {
252 | configInfo = config.GetFrontendConfigInfo()
253 | if len(pinnedPids) > 0 {
254 | var pinnedPosts []base.Post
255 | err3 := base.GetDb(canViewDelete).Where(pinnedPids).Order("id desc").Find(&pinnedPosts).Error
256 | if err3 != nil {
257 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err3, "GetPinnedPostsFailed", consts.DatabaseReadFailedString))
258 | return
259 | } else {
260 | posts = append(pinnedPosts, posts...)
261 | }
262 | }
263 | }
264 |
265 | jsPosts, err := appendPostDetail(base.GetDb(false), posts, &user)
266 | if err != nil {
267 | base.HttpReturnWithCodeMinusOne(c, err)
268 | return
269 | }
270 |
271 | comments, err4 := getCommentsByPosts(posts, &user)
272 | if err4 != nil {
273 | base.HttpReturnWithCodeMinusOne(c, err4)
274 | return
275 | }
276 |
277 | c.JSON(http.StatusOK, gin.H{
278 | "code": 0,
279 | "data": utils.IfThenElse(jsPosts != nil, jsPosts, []string{}),
280 | "config": configInfo,
281 | //"timestamp": utils.GetTimeStamp(),
282 | "count": utils.IfThenElse(jsPosts != nil, len(jsPosts), 0),
283 | "comments": comments,
284 | })
285 | return
286 | }
287 |
288 | func wanderListPost(c *gin.Context) {
289 | user := c.MustGet("user").(base.User)
290 | canViewDelete := base.CanViewDeletedPost(&user)
291 | var posts []base.Post
292 |
293 | var maxId int32
294 | err2 := base.GetDb(canViewDelete).Model(&base.Post{}).Select("max(id)").First(&maxId).Error
295 | if err2 != nil {
296 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err2, "GetMaxPidFailed", consts.DatabaseReadFailedString))
297 | return
298 | }
299 | pids := make([]int32, 0, consts.WanderPageSize)
300 | for i := 0; i < consts.WanderPageSize; i++ {
301 | pids = append(pids, 1+int32(rand.Intn(int(maxId))))
302 | }
303 | err2 = base.GetDb(canViewDelete).Where("id in (?)", pids).Order("RAND()").Find(&posts).Error
304 | if err2 != nil {
305 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err2, "GetWanderPosts", consts.DatabaseReadFailedString))
306 | return
307 | }
308 |
309 | var posts2 []base.Post
310 |
311 | reportableTags := viper.GetStringSlice("reportable_tags")
312 | inactiveRangeStart := viper.GetIntSlice("inactive_pid_range_start")
313 | inactiveRangeEnd := viper.GetIntSlice("inactive_pid_range_end")
314 | checkID := func(id int) bool {
315 | if len(inactiveRangeStart) != len(inactiveRangeEnd) {
316 | return false
317 | }
318 | for i := range inactiveRangeStart {
319 | if id >= inactiveRangeStart[i] && id < inactiveRangeEnd[i] {
320 | return true
321 | }
322 | }
323 | return false
324 | }
325 | for i := 0; i < 10; i++ {
326 | posts2 = make([]base.Post, 0, len(posts))
327 | for _, post := range posts {
328 | if _, b := utils.ContainsString(reportableTags, post.Tag); b || checkID(int(post.ID)) {
329 | if rand.Float32() > 0.1 {
330 | continue
331 | }
332 | }
333 | posts2 = append(posts2, post)
334 | }
335 | if len(posts2) > 0 {
336 | break
337 | }
338 | }
339 | jsPosts, err := appendPostDetail(base.GetDb(false), posts2, &user)
340 | if err != nil {
341 | base.HttpReturnWithCodeMinusOne(c, err)
342 | return
343 | }
344 |
345 | c.JSON(http.StatusOK, gin.H{
346 | "code": 0,
347 | "data": utils.IfThenElse(jsPosts != nil, jsPosts, []string{}),
348 | //"timestamp": utils.GetTimeStamp(),
349 | "count": utils.IfThenElse(jsPosts != nil, len(jsPosts), 0),
350 | })
351 | return
352 | }
353 |
354 | func searchPost(c *gin.Context) {
355 | page := c.MustGet("page").(int)
356 | user := c.MustGet("user").(base.User)
357 | keywords := c.Query("keywords")
358 | includeComment := c.Query("include_comment") != "false"
359 | beforeDate := c.Query("before")
360 | beforeTimestamp, err := strconv.ParseInt(beforeDate, 10, 64)
361 | if err != nil {
362 | beforeTimestamp = -1
363 | }
364 | afterDate := c.Query("after")
365 | afterTimestamp, err := strconv.ParseInt(afterDate, 10, 64)
366 | if err != nil {
367 | afterTimestamp = -1
368 | }
369 |
370 | if utf8.RuneCountInString(keywords) > consts.SearchMaxLength {
371 | base.HttpReturnWithCodeMinusOne(c, logger.NewSimpleError("TooLongKeywords", "搜索内容过长", logger.WARN))
372 | return
373 | }
374 |
375 | posts, err2 := base.SearchPosts(page, keywords, nil, user,
376 | model.SearchOrderFromString(c.Query("order")), includeComment, beforeTimestamp, afterTimestamp)
377 | if err2 != nil {
378 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err2, "SearchPostsFailed", consts.DatabaseReadFailedString))
379 | return
380 | }
381 |
382 | jsPosts, err3 := appendPostDetail(base.GetDb(false), posts, &user)
383 | if err3 != nil {
384 | base.HttpReturnWithCodeMinusOne(c, err3)
385 | return
386 | }
387 |
388 | keywordsSlice := strings.Split(keywords, " ")
389 | comments := make(map[int32][]gin.H)
390 | commentsMap, err5 := base.GetMultipleCommentsWithCache(base.GetDb(false), posts, time.Now())
391 | if err5 != nil {
392 | base.HttpReturnWithCodeMinusOne(c, err5)
393 | return
394 | }
395 | //TODO: (low priority) update reply_num
396 | for pid, tmp := range commentsMap {
397 | var commentsContainsKeywords []base.Comment
398 | for _, comment := range tmp {
399 | //TODO: (low priority) check if keyword is #tag
400 | for _, keyword := range keywordsSlice {
401 | if strings.Contains(comment.Text, keyword) {
402 | commentsContainsKeywords = append(commentsContainsKeywords, comment)
403 | break
404 | }
405 | }
406 | //if len(commentsContainsKeywords) >= 3 {
407 | // break
408 | //}
409 | }
410 | if len(commentsContainsKeywords) > 0 {
411 | comments[pid] = commentsToJson(commentsContainsKeywords, &user)
412 | }
413 | }
414 |
415 | c.JSON(http.StatusOK, gin.H{
416 | "code": 0,
417 | "data": utils.IfThenElse(jsPosts != nil, jsPosts, []string{}),
418 | //"timestamp": utils.GetTimeStamp(),
419 | "count": utils.IfThenElse(jsPosts != nil, len(jsPosts), 0),
420 | "comments": comments,
421 | })
422 | return
423 | }
424 |
425 | func searchAttentionPost(c *gin.Context) {
426 | page := c.MustGet("page").(int)
427 | user := c.MustGet("user").(base.User)
428 | canViewDelete := base.CanViewDeletedPost(&user)
429 | keywords := c.Query("keywords")
430 |
431 | if utf8.RuneCountInString(keywords) > consts.SearchMaxLength {
432 | base.HttpReturnWithCodeMinusOne(c, logger.NewSimpleError("TooLongKeywords", "搜索内容过长", logger.WARN))
433 | return
434 | }
435 |
436 | var attentionPids []int32
437 | err3 := base.GetDb(canViewDelete).Model(&base.Attention{}).
438 | Where("user_id = ?", user.ID).
439 | Pluck("post_id", &attentionPids).Error
440 | if err3 != nil {
441 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err3, "GetAttentionPidsFailed", consts.DatabaseReadFailedString))
442 | return
443 | }
444 |
445 | posts, err2 := base.SearchPosts(page, keywords, attentionPids, user,
446 | model.SearchOrderFromString(c.Query("order")), true, -1, -1)
447 | if err2 != nil {
448 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err2, "SearchPostsFailed", consts.DatabaseReadFailedString))
449 | return
450 | }
451 | votes, err4 := getVotesInPosts(base.GetDb(false), &user, posts)
452 | if err4 != nil {
453 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err4, "GetVotesInPostsFailed", consts.DatabaseReadFailedString))
454 | return
455 | }
456 | jsPosts := postsToJson(posts, &user, attentionPids, votes)
457 |
458 | c.JSON(http.StatusOK, gin.H{
459 | "code": 0,
460 | "data": utils.IfThenElse(jsPosts != nil, jsPosts, []string{}),
461 | //"timestamp": utils.GetTimeStamp(),
462 | "count": utils.IfThenElse(jsPosts != nil, len(jsPosts), 0),
463 | })
464 | return
465 | }
466 |
467 | func attentionPosts(c *gin.Context) {
468 | page := c.MustGet("page").(int)
469 |
470 | user := c.MustGet("user").(base.User)
471 | canViewDelete := base.CanViewDeletedPost(&user)
472 | offset := (page - 1) * consts.PageSize
473 | limit := consts.PageSize
474 |
475 | var attentionPids []int32
476 | err3 := base.GetDb(canViewDelete).Model(&base.Attention{}).
477 | Where("user_id = ?", user.ID).Order("post_id desc").Limit(limit).Offset(offset).
478 | Pluck("post_id", &attentionPids).Error
479 | if err3 != nil {
480 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err3, "GetAttentionPidsFailed", consts.DatabaseReadFailedString))
481 | return
482 | }
483 |
484 | var posts []base.Post
485 | err2 := base.GetDb(canViewDelete).Where("id in ?", attentionPids).Order("id desc").Find(&posts).Error
486 | if err2 != nil {
487 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err2, "GetAttentionPostsFailed", consts.DatabaseReadFailedString))
488 | return
489 | }
490 | votes, err4 := getVotesInPosts(base.GetDb(false), &user, posts)
491 | if err4 != nil {
492 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err4, "GetAttentionPostsVoteFailed", consts.DatabaseReadFailedString))
493 | return
494 | }
495 |
496 | comments, err5 := getCommentsByPosts(posts, &user)
497 | if err5 != nil {
498 | base.HttpReturnWithCodeMinusOne(c, err5)
499 | return
500 | }
501 |
502 | data := postsToJson(posts, &user, attentionPids, votes)
503 | c.JSON(http.StatusOK, gin.H{
504 | "code": 0,
505 | "data": utils.IfThenElse(data != nil, data, []string{}),
506 | //"timestamp": utils.GetTimeStamp(),
507 | "count": utils.IfThenElse(data != nil, len(data), 0),
508 | "comments": comments,
509 | })
510 | return
511 |
512 | }
513 |
514 | func systemMsg(c *gin.Context) {
515 | var msgs []base.SystemMessage
516 | user := c.MustGet("user").(base.User)
517 | err2 := base.GetDb(false).Where("user_id = ?", user.ID).Order("created_at desc").Find(&msgs).Error
518 | data := make([]gin.H, 0, len(msgs))
519 | for _, msg := range msgs {
520 | data = append(data, gin.H{
521 | "content": msg.Text,
522 | "timestamp": msg.CreatedAt.Unix(),
523 | "title": msg.Title,
524 | })
525 | }
526 |
527 | if err2 != nil {
528 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err2, "GetSysMsgFailed", consts.DatabaseReadFailedString))
529 | return
530 | } else {
531 | c.JSON(http.StatusOK, gin.H{
532 | "code": 0,
533 | "data": utils.IfThenElse(data != nil, data, []gin.H{{
534 | "content": "目前尚无系统消息",
535 | "timestamp": 0,
536 | "title": "提示",
537 | }}),
538 | })
539 | }
540 | }
541 |
542 | func myMsgs(c *gin.Context) {
543 | user := c.MustGet("user").(base.User)
544 | page := c.MustGet("page").(int)
545 | pushOnly := c.Query("push_only") == "1"
546 |
547 | sinceId, err := strconv.Atoi(c.Query("since_id"))
548 | if err != nil {
549 | sinceId = -1
550 | }
551 |
552 | msgs, err2 := base.ListMsgs(page, int32(sinceId), user.ID, pushOnly)
553 | if err2 != nil {
554 | base.HttpReturnWithCodeMinusOne(c, logger.NewError(err2, "ListMsgsFailed", consts.DatabaseReadFailedString))
555 | return
556 | }
557 | var data []gin.H
558 | for _, msg := range msgs {
559 | p := gin.H{
560 | "id": msg.ID,
561 | "title": msg.Title,
562 | "body": utils.TrimText(msg.Message, 100),
563 | "type": msg.Type,
564 | "timestamp": msg.UpdatedAt.Unix(),
565 | }
566 | if (msg.Type & (model.ReplyMeComment | model.CommentInFavorited)) > 0 {
567 | p["pid"] = msg.PostID
568 | p["cid"] = msg.CommentID
569 | }
570 | data = append(data, p)
571 | }
572 |
573 | c.JSON(http.StatusOK, gin.H{
574 | "code": 0,
575 | "data": utils.IfThenElse(data != nil, data, []string{}),
576 | })
577 | return
578 | }
579 |
--------------------------------------------------------------------------------