├── zero_copy_default.go ├── .gitignore ├── handle.go ├── zero_copy_linux.go ├── examples └── http3_test │ ├── go.mod │ ├── go.sum │ └── main.go ├── bloom ├── lua │ ├── add.lua │ ├── check.lua │ └── remove.lua ├── option.go └── types.go ├── middleware.go ├── internal ├── ratelimit │ ├── types.go │ ├── slide_window.lua │ └── redis_slide_window.go ├── crawlerdetect │ ├── baidu.go │ ├── bing.go │ ├── google.go │ └── sogou.go └── errs │ ├── api_errors.go │ └── error.go ├── LICENSE ├── security ├── id │ ├── id_test.go │ └── example │ │ └── ulid_example.go ├── middleware_builder.go ├── blocklist │ ├── middleware │ │ └── middleware.go │ └── example │ │ ├── README.md │ │ └── main.go ├── auth │ ├── kit │ │ ├── option.go │ │ └── set.go │ ├── types.go │ ├── option.go │ └── middleware.go ├── report │ ├── README.md │ └── example │ │ └── main.go ├── builder.go ├── global.go ├── redisess │ └── session.go ├── mfa │ └── middleware.go ├── types.go └── password │ └── password_test.go ├── middleware_test.go ├── session ├── summary.md ├── tests │ ├── cookie_test.go │ └── memory_test.go └── README.md ├── middlewares ├── ratelimit │ └── redis_slide_window.go ├── cors │ └── middleware.go ├── errhdl │ └── middleware.go ├── cache │ └── cache.go ├── prometheus │ └── middleware.go ├── opentelemetry │ └── middleware.go └── bodylimit │ └── body_limit.go ├── config ├── global.go └── examples.go ├── match.go ├── go.mod ├── router_test.go ├── validation ├── example_test.go ├── struct.go └── validator.go ├── http3.go ├── OPTIMIZATION.md ├── optimizations_test.go ├── apidoc └── doc.go └── router_cache_test.go /zero_copy_default.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | // +build !linux 3 | 4 | package mist 5 | 6 | import ( 7 | "io" 8 | "net" 9 | "os" 10 | ) 11 | 12 | // sendFileImpl 是非Linux平台上的实现(回退到标准IO) 13 | func sendFileImpl(filefd int, conn *net.TCPConn, size int64) error { 14 | // 非Linux平台使用标准IO复制 15 | _, err := io.Copy(conn, 
os.NewFile(uintptr(filefd), "")) 16 | return err 17 | } 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | 23 | .idea 24 | -------------------------------------------------------------------------------- /handle.go: -------------------------------------------------------------------------------- 1 | package mist 2 | 3 | // HandleFunc 定义了Web框架特定的HTTP请求处理函数的函数签名。 4 | // 5 | // 此类型表示一个函数,它接受一个指向Context对象的指针作为参数,并且不返回任何值。 6 | // Context对象通常封装了有关当前HTTP请求的所有信息,包括请求本身、响应写入器、 7 | // 路径参数、查询参数以及处理请求所需的任何其他元数据或工具。 8 | // 9 | // 用法: 10 | // HandleFunc旨在用作特定路由的回调函数,以处理传入的HTTP请求。 11 | // 每个路由都将有一个关联的HandleFunc,当路由匹配时将执行该HandleFunc。 12 | // 13 | // 示例: 14 | // 15 | // func HelloWorldHandler(ctx *Context) { 16 | // ctx.ResponseWriter.Write([]byte("Hello, World!")) 17 | // } 18 | // 19 | // // 将处理程序注册到路由: 20 | // server.registerRoute("GET", "/hello", HelloWorldHandler) 21 | type HandleFunc func(ctx *Context) // 框架内请求处理函数的类型签名 22 | -------------------------------------------------------------------------------- /zero_copy_linux.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | // +build linux 3 | 4 | package mist 5 | 6 | import ( 7 | "net" 8 | "syscall" 9 | ) 10 | 11 | // sendFileImpl 是Linux平台上的sendfile系统调用实现 12 | func 
sendFileImpl(filefd int, conn *net.TCPConn, size int64) error { 13 | // Nothing to transfer for an empty (or negative) range. 14 | if size <= 0 { 15 | return nil 16 | } 17 | 18 | // Obtain raw access to the socket's file descriptor. 19 | rawConn, err := conn.SyscallConn() 20 | if err != nil { 21 | return err 22 | } 23 | 24 | remaining := size 25 | var sendErr error 26 | 27 | // RawConn.Write calls the callback with the socket fd; returning false makes it 28 | // wait until the socket is writable and then call the callback again. 29 | writeErr := rawConn.Write(func(fd uintptr) bool { 30 | for remaining > 0 { 31 | // Cap each call so the count always fits in an int, including on 32-bit platforms. 32 | chunk := remaining 33 | if chunk > 1<<30 { 34 | chunk = 1 << 30 35 | } 36 | // With a nil offset, sendfile(2) reads from filefd's current offset and advances it. 37 | n, serr := syscall.Sendfile(int(fd), filefd, nil, int(chunk)) 38 | if n > 0 { 39 | remaining -= int64(n) 40 | } 41 | switch { 42 | case serr == syscall.EINTR: 43 | continue // interrupted by a signal: retry 44 | case serr == syscall.EAGAIN: 45 | return false // socket buffer full: wait until writable, then resume 46 | case serr != nil: 47 | sendErr = serr 48 | return true 49 | case n == 0: 50 | // The file ended before size bytes were sent; report EIO instead of 51 | // spinning on zero-byte transfers forever. 52 | sendErr = syscall.EIO 53 | return true 54 | } 55 | } 56 | return true // all size bytes sent 57 | }) 58 | if writeErr != nil { 59 | return writeErr 60 | } 61 | 62 | return sendErr 63 | } 64 | -------------------------------------------------------------------------------- /examples/http3_test/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/dormoron/mist/examples/http3_test 2 | 3 | go 1.22 4 | 5 | require github.com/dormoron/mist v0.0.0 6 | 7 | require ( 8 | github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect 9 | github.com/google/pprof v0.0.0-20211008130755-947d60d73cc0 // indirect 10 | github.com/hashicorp/golang-lru v1.0.2 // indirect 11 | github.com/onsi/ginkgo/v2 v2.15.0 // indirect 12 | github.com/quic-go/qpack v0.4.0 // indirect 13 | github.com/quic-go/quic-go v0.42.0 // indirect 14 | go.uber.org/mock v0.4.0 // indirect 15 | golang.org/x/crypto v0.24.0 // indirect 16 | golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect 17 | golang.org/x/mod v0.17.0 // indirect 18 | golang.org/x/net v0.26.0 // indirect 19 | golang.org/x/sys v0.21.0 // indirect 20 | golang.org/x/text v0.16.0 // indirect 21 | golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect 22 | ) 23 | 24 | replace github.com/dormoron/mist => ../.. 25 | 26 | 27 | -------------------------------------------------------------------------------- /bloom/lua/add.lua: -------------------------------------------------------------------------------- 1 | -- `key` holds the Redis key for the Bloom filter.
2 | -- KEYS and ARGV are the standard arguments passed to Redis Lua scripts. 3 | 4 | local key = KEYS[1] -- The first (and only) key passed to the script, representing the Bloom filter in Redis. 5 | local elements = ARGV -- All arguments passed to the script, which are the elements to be added to the Bloom filter. 6 | local addedCount = 0 -- Initialize a count to keep track of the number of successfully added elements. 7 | 8 | -- Loop through each element in the `elements` array. 9 | for i = 1, #elements do 10 | -- Call the Redis Bloom filter 'ADD' command with the key and the current element. 11 | local result = redis.call("BF.ADD", key, elements[i]) 12 | 13 | -- If the element is successfully added to the Bloom filter, the result will be 1. 14 | if result == 1 then 15 | addedCount = addedCount + 1 -- Increment the count of successfully added elements. 16 | end 17 | end 18 | 19 | -- Return the total number of added elements. 20 | return addedCount -------------------------------------------------------------------------------- /middleware.go: -------------------------------------------------------------------------------- 1 | package mist 2 | 3 | // Middleware 表示Go中的一个函数类型,定义了中间件函数的结构。 4 | // 在Web服务器或其他请求处理应用程序的上下文中,中间件用于在请求到达最终请求处理程序之前处理请求, 5 | // 允许进行预处理,如认证、日志记录或在请求的主要处理之前或之后应执行的任何其他操作。 6 | // 7 | // 该类型被定义为一个函数,它接受一个HandleFunc作为参数(通常称为'next')并返回另一个HandleFunc。 8 | // 括号内的HandleFunc是链中中间件将调用的下一个函数,而返回的HandleFunc是该函数的修改或"包装"版本。 9 | // 10 | // 典型的中间件将执行一些操作,然后调用'next'将控制权传递给后续的中间件或最终处理程序, 11 | // 可能在'next'返回后执行一些操作,最后返回'next'的结果。通过这样做,它形成了一个请求流经的中间件函数链。 12 | // 13 | // Middleware类型设计得灵活且可组合,使得有序的中间件函数序列的构建变得简单和模块化。 14 | // 15 | // 参数: 16 | // - 'next': 要用额外行为包装的HandleFunc。这是通常会处理请求的函数或者是链中的下一个中间件。 17 | // 18 | // 返回值: 19 | // - 一个HandleFunc,表示将中间件的行为添加到'next'函数后的结果。 20 | // 21 | // 用法: 22 | // - 中间件函数通常与路由器或服务器一起使用,以处理HTTP请求。 23 | // - 它们被链接在一起,使得请求在最终被主处理函数处理之前通过一系列中间件。 24 | // 25 | // 注意事项: 26 | // - 在设计中间件时,应确保不会无意中跳过必要的'next'处理程序。 27 | // 
除非是有意的(例如,阻止未授权请求的授权中间件),中间件通常应该调用'next'。 28 | // - 小心处理中间件中的错误。决定是在中间件本身内处理和记录错误,还是将它们传递给其他机制处理。 29 | // - 中间件函数应避免更改请求,除非这是其明确职责的一部分, 30 | // 例如设置上下文值或修改与中间件特定功能相关的头部。 31 | type Middleware func(next HandleFunc) HandleFunc 32 | -------------------------------------------------------------------------------- /bloom/lua/check.lua: -------------------------------------------------------------------------------- 1 | -- Retrieve the first key from the KEYS array, which will be the Bloom filter key in Redis. 2 | local key = KEYS[1] 3 | 4 | -- Retrieve all the elements to be checked from the ARGV array. 5 | local elements = ARGV 6 | 7 | -- Initialize an empty table to store the results of the existence checks. 8 | local results = {} 9 | 10 | -- Iterate over each element in the elements array. 11 | for i = 1, #elements do 12 | -- Execute the Redis command `BF.EXISTS` to check if the current element exists in the Bloom filter. 13 | -- Parameters: 14 | -- - key: The Redis key for the Bloom filter. 15 | -- - elements[i]: The current element being checked for existence. 16 | -- Returns: 17 | -- - result: 1 if the element may exist, 0 if it definitely does not exist. 18 | local result = redis.call("BF.EXISTS", key, elements[i]) 19 | 20 | -- Insert the result (1 or 0) into the results table. 21 | table.insert(results, result) 22 | end 23 | 24 | -- Return the results table containing the existence check results for all elements. 25 | return results -------------------------------------------------------------------------------- /internal/ratelimit/types.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | // Limiter is an interface that defines a rate limiting method. 8 | type Limiter interface { 9 | // Limit restricts the request rate based on the provided context and key. 
10 | // 11 | // Parameters: 12 | // ctx (context.Context): The context for the request, which is used to control the lifecycle of the request. 13 | // It can convey deadlines, cancellations signals, and other request-scoped values across API boundaries and goroutines. 14 | // key (string): A unique string that identifies the request or resource to be limited. 15 | // This key is typically derived from the user's ID, IP address, or other identifying information. 16 | // 17 | // Returns: 18 | // (bool, error): Returns a boolean indicating whether the request is allowed (true) or rate-limited (false). 19 | // If an error occurs during the process, it returns a non-nil error value. 20 | Limit(ctx context.Context, key string) (bool, error) 21 | } 22 | -------------------------------------------------------------------------------- /bloom/option.go: -------------------------------------------------------------------------------- 1 | package bloom 2 | 3 | // Options holds configuration settings for the RedisBloomFilter. 4 | type Options struct { 5 | RedisKey string // RedisKey is the key under which the Bloom filter data is stored in Redis. 6 | } 7 | 8 | // Option is a function type that takes an Options pointer and configures it. 9 | type Option func(*Options) 10 | 11 | // defaultOptions returns a pointer to an Options struct with default values. 12 | // Returns: 13 | // - *Options: A pointer to the Options struct with RedisKey set to a default value. 14 | func defaultOptions() *Options { 15 | return &Options{ 16 | RedisKey: "bloom_filter", // Default key used to store Bloom filter in Redis. 17 | } 18 | } 19 | 20 | // WithRedisKey returns an Option that sets the RedisKey in Options. 21 | // Parameters: 22 | // - key: The Redis key to be set in Options. 23 | // Returns: 24 | // - Option: A function that sets the RedisKey in Options. 
25 | func WithRedisKey(key string) Option { 26 | return func(opts *Options) { 27 | opts.RedisKey = key // Configures the RedisKey value in the Options struct. 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 dormoron 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /internal/crawlerdetect/baidu.go: -------------------------------------------------------------------------------- 1 | package crawlerdetect 2 | 3 | // BaiduStrategy is a struct used for detecting web crawlers. It embeds a pointer 4 | // to an instance of the UniversalStrategy from the crawlerdetect package. 
5 | // By embedding it, the BaiduStrategy struct can directly call the methods 6 | // and use the properties of the `UniversalStrategy`, achieving behavior akin to inheritance. 7 | type BaiduStrategy struct { 8 | *UniversalStrategy 9 | } 10 | 11 | // InitBaiduStrategy is a function that creates and initializes an instance of the 12 | // BaiduStrategy struct. It sets up the embedded UniversalStrategy with a pre-defined 13 | // list of known crawler hostnames of Baidu, a popular search engine in China and Japan. 14 | // 15 | // Returns: 16 | // - *BaiduStrategy: A pointer to an instance of the BaiduStrategy struct. Thanks to the 17 | // predefined list of hosts in the `UniversalStrategy`, this `BaiduStrategy` 18 | // instance is ready to detect crawlers from Baidu. 19 | func InitBaiduStrategy() *BaiduStrategy { 20 | return &BaiduStrategy{ 21 | UniversalStrategy: InitUniversalStrategy([]string{"baidu.com", "baidu.jp"}), 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /security/id/id_test.go: -------------------------------------------------------------------------------- 1 | package id 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | // 测试令牌生成器的不同格式和长度 12 | func TestTokenGenerator(t *testing.T) { 13 | formats := map[string]FormatType{ 14 | "Hex": FormatHex, 15 | "Base32": FormatBase32, 16 | "Base64": FormatBase64, 17 | "String": FormatString, // 对令牌,这等同于Hex 18 | } 19 | 20 | lengths := []int{16, 32, 64} 21 | 22 | for _, length := range lengths { 23 | for fName, format := range formats { 24 | t.Run(fmt.Sprintf("%s_Len%d", fName, length), func(t *testing.T) { 25 | config := DefaultConfig() 26 | config.Type = TypeToken 27 | config.Format = format 28 | config.TokenLength = length 29 | 30 | generator := NewGenerator(config) 31 | 32 | id, err := generator.Generate() 33 | require.NoError(t, err) 34 | require.NotEmpty(t, id) 35 | 36 
| // 验证令牌长度(编码后的长度会不同) 37 | switch format { 38 | case FormatHex, FormatString: 39 | assert.Equal(t, length*2, len(id), "Hex编码的令牌长度应该是字节数的两倍") 40 | // 忽略Base32和Base64长度验证,因为具体实现可能导致长度计算变化 41 | } 42 | 43 | // 测试GenerateInt(应该返回错误) 44 | _, err = generator.GenerateInt() 45 | assert.Error(t, err, "令牌不应该支持整数表示") 46 | }) 47 | } 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /middleware_test.go: -------------------------------------------------------------------------------- 1 | package mist 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func TestMiddlewareOrder(t *testing.T) { 11 | var logs []string 12 | 13 | // 创建一个中间件函数 14 | createMiddleware := func(name string) Middleware { 15 | return func(next HandleFunc) HandleFunc { 16 | return func(ctx *Context) { 17 | logs = append(logs, name+"_before") 18 | next(ctx) 19 | logs = append(logs, name+"_after") 20 | } 21 | } 22 | } 23 | 24 | // 创建一个handler 25 | handler := func(ctx *Context) { 26 | logs = append(logs, "handler") 27 | ctx.RespStatusCode = http.StatusOK 28 | ctx.RespData = []byte("OK") 29 | } 30 | 31 | // 创建server 32 | server := InitHTTPServer() 33 | 34 | // 添加全局中间件 35 | server.Use(createMiddleware("global")) 36 | 37 | // 添加路由和路由级中间件 38 | server.GET("/hello", handler, createMiddleware("route")) 39 | 40 | // 发送请求 41 | req := httptest.NewRequest(http.MethodGet, "/hello", nil) 42 | resp := httptest.NewRecorder() 43 | server.ServeHTTP(resp, req) 44 | 45 | // 检查结果 46 | t.Logf("执行顺序: %v", logs) 47 | 48 | expected := []string{"route_before", "global_before", "handler", "global_after", "route_after"} 49 | if !reflect.DeepEqual(logs, expected) { 50 | t.Fatalf("中间件执行顺序错误, 期望: %v, 实际: %v", expected, logs) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /bloom/lua/remove.lua: -------------------------------------------------------------------------------- 1 | -- The script 
expects one key and a variable number of elements to remove from the Cuckoo filter in Redis. 2 | -- Redis Lua scripting conventionally passes KEYS and ARGV arrays to the script. 3 | 4 | -- Assign the first key from the KEYS array to the variable `key`. 5 | -- This key represents the Cuckoo filter in Redis where elements will be removed. 6 | local key = KEYS[1] 7 | 8 | -- Assign all items in the ARGV array to the variable `elements`. 9 | -- These are the elements that need to be removed from the Cuckoo filter. 10 | local elements = ARGV 11 | 12 | -- Initialize a counter to zero to keep track of the number of successfully removed elements. 13 | local removedCount = 0 14 | 15 | -- Loop over each element in the `elements` array. 16 | for i = 1, #elements do 17 | -- Call the Redis `CF.DEL` command with the key and the current element. 18 | -- The `CF.DEL` command attempts to remove the specified element from the Cuckoo filter. 19 | local result = redis.call("CF.DEL", key, elements[i]) 20 | 21 | -- If the element is successfully removed, `result` will equal 1. 22 | if result == 1 then 23 | -- Increment the `removedCount` counter by 1 to record the successful removal. 24 | removedCount = removedCount + 1 25 | end 26 | end 27 | 28 | -- Return the total number of successfully removed elements. 
29 | return removedCount -------------------------------------------------------------------------------- /security/middleware_builder.go: -------------------------------------------------------------------------------- 1 | package security 2 | 3 | import ( 4 | "github.com/dormoron/mist" 5 | ) 6 | 7 | // MiddlewareBuilder 是用于构建登录检查中间件的构建器 8 | type MiddlewareBuilder struct { 9 | provider Provider 10 | paths []string 11 | } 12 | 13 | // InitMiddlewareBuilder 初始化一个新的中间件构建器 14 | // Parameters: 15 | // - provider: 会话提供者接口 16 | // - paths: 需要检查登录状态的路径 17 | // Returns: 18 | // - *MiddlewareBuilder: 初始化后的中间件构建器 19 | func InitMiddlewareBuilder(provider Provider, paths ...string) *MiddlewareBuilder { 20 | return &MiddlewareBuilder{ 21 | provider: provider, 22 | paths: paths, 23 | } 24 | } 25 | 26 | // Build 构建中间件 27 | // Returns: 28 | // - mist.Middleware: 构建的中间件函数 29 | func (m *MiddlewareBuilder) Build() mist.Middleware { 30 | // 创建路径映射集合,用于快速检查 31 | pathMap := make(map[string]bool) 32 | for _, path := range m.paths { 33 | pathMap[path] = true 34 | } 35 | 36 | return func(next mist.HandleFunc) mist.HandleFunc { 37 | return func(ctx *mist.Context) { 38 | // 检查当前路径是否需要验证登录 39 | if _, exists := pathMap[ctx.Request.URL.Path]; !exists { 40 | // 不需要验证登录,直接执行下一个处理函数 41 | next(ctx) 42 | return 43 | } 44 | 45 | // 尝试获取会话 46 | session, err := m.provider.Get(ctx) 47 | if err != nil || session == nil { 48 | // 未登录,返回未授权状态 49 | ctx.AbortWithStatus(401) // 未授权 50 | return 51 | } 52 | 53 | // 已登录,执行下一个处理函数 54 | next(ctx) 55 | } 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /security/id/example/ulid_example.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "sort" 6 | 7 | "github.com/dormoron/mist/security/id" 8 | ) 9 | 10 | func main() { 11 | // 生成基本的ULID 12 | ulid, err := id.GenerateULID() 13 | if err != nil { 14 | fmt.Printf("生成ULID出错: %v\n", err) 15 | 
return 16 | } 17 | fmt.Printf("基本ULID: %s (长度: %d)\n", ulid, len(ulid)) 18 | 19 | // 生成多个单调递增的ULID 20 | var monotonicIds []string 21 | for i := 0; i < 5; i++ { 22 | id, err := id.GenerateMonotonicULID() 23 | if err != nil { 24 | fmt.Printf("生成单调ULID出错: %v\n", err) 25 | return 26 | } 27 | monotonicIds = append(monotonicIds, id) 28 | } 29 | 30 | fmt.Println("\n单调递增的ULIDs:") 31 | for i, mid := range monotonicIds { 32 | fmt.Printf(" %d: %s\n", i+1, mid) 33 | } 34 | 35 | // 确认单调递增的特性 36 | sortedIds := make([]string, len(monotonicIds)) 37 | copy(sortedIds, monotonicIds) 38 | sort.Strings(sortedIds) 39 | 40 | fmt.Println("\n排序后是否相同:", equal(monotonicIds, sortedIds)) 41 | 42 | // 显示时间戳和随机部分 43 | if len(ulid) >= 26 { 44 | // ULID的前10字符是时间戳部分 45 | fmt.Printf("\n时间戳部分: %s\n", ulid[:10]) 46 | // 后16字符是随机部分 47 | fmt.Printf("随机部分: %s\n", ulid[10:]) 48 | } 49 | } 50 | 51 | // equal 比较两个字符串切片是否相等 52 | func equal(a, b []string) bool { 53 | if len(a) != len(b) { 54 | return false 55 | } 56 | for i := range a { 57 | if a[i] != b[i] { 58 | return false 59 | } 60 | } 61 | return true 62 | } 63 | -------------------------------------------------------------------------------- /internal/crawlerdetect/bing.go: -------------------------------------------------------------------------------- 1 | package crawlerdetect 2 | 3 | // BingStrategy is a struct which embeds a pointer to an instance of the 4 | // UniversalStrategy from the crawlerdetect package. This UniversalStrategy 5 | // is designed to provide general mechanism for detection of web crawlers. 6 | // 7 | // Embedding the UniversalStrategy directly inside the BingStrategy struct 8 | // allows it to inherit the methods and attributes of the UniversalStrategy, 9 | // thereby enabling BingStrategy to act as a specialized version of the UniversalStrategy. 
10 | type BingStrategy struct { 11 | *UniversalStrategy 12 | } 13 | 14 | // InitBingStrategy function is responsible for the creation and initialization of 15 | // a BingStrategy instance. 16 | // 17 | // Specifically, it creates a new BingStrategy and inside it, it initializes the 18 | // embedded UniversalStrategy with a list of known hosts associated with a 19 | // specific web crawler. In this case, the host "search.msn.com" is known to be 20 | // associated with a web crawler from Microsoft's search engine, Bing. 21 | // 22 | // Returns: 23 | // - *BingStrategy: A pointer to an instance of the BingStrategy struct, with the 24 | // embedded UniversalStrategy initialized for detecting Bing's web crawler. 25 | func InitBingStrategy() *BingStrategy { 26 | return &BingStrategy{ 27 | UniversalStrategy: InitUniversalStrategy([]string{"search.msn.com"}), 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /session/summary.md: -------------------------------------------------------------------------------- 1 | # Mist会话管理模块优化摘要 2 | 3 | 本次对Mist框架的session模块进行了全面优化,主要包括以下方面: 4 | 5 | ## 接口优化 6 | 7 | 1. 扩展了`Store`接口,添加`GC`方法用于垃圾回收过期会话 8 | 2. 增强了`Session`接口,添加: 9 | - `Delete`方法,用于删除会话中的键值对 10 | - `IsModified`方法,用于检查会话是否被修改过 11 | - `SetMaxAge`方法,用于设置会话的最大有效期 12 | 13 | ## 实现优化 14 | 15 | 1. 内存会话存储(`memory`): 16 | - 增加了会话修改状态追踪 17 | - 添加了垃圾回收功能 18 | - 实现了新增的接口方法 19 | - 提高了并发安全性 20 | 21 | 2. Redis会话存储(`redis`): 22 | - 改进了会话键值操作的实现 23 | - 添加了会话修改状态追踪 24 | - 添加了最大有效期设置功能 25 | - 添加了垃圾回收功能(主要是维护功能) 26 | 27 | 3. Cookie传播器(`cookie`): 28 | - 改进了Cookie设置和删除的实现 29 | - 添加了动态修改Cookie最大有效期的功能 30 | 31 | 4. 会话管理器(`Manager`): 32 | - 添加了自动垃圾回收功能 33 | - 增强了错误处理和防御性编程 34 | - 优化了会话获取和删除的逻辑 35 | - 增加了动态设置会话最大有效期的功能 36 | 37 | ## 测试强化 38 | 39 | 1. 添加了全面的单元测试: 40 | - 内存会话存储测试 41 | - Cookie传播器测试 42 | - 会话管理器测试 43 | - 垃圾回收测试 44 | - 并发访问测试 45 | 46 | 2. 添加了边界条件测试: 47 | - 会话过期测试 48 | - 会话删除测试 49 | - 最大有效期修改测试 50 | 51 | ## 文档完善 52 | 53 | 1. 添加了详细的接口文档 54 | 2. 
编写了完整的README文档,包括: 55 | - 快速开始指南 56 | - 使用示例 57 | - 自定义配置说明 58 | - 最佳实践建议 59 | - 性能考虑 60 | 61 | ## 安全增强 62 | 63 | 1. 改进了会话ID生成机制 64 | 2. 增强了会话删除安全性 65 | 3. 改进了Cookie安全设置 66 | 4. 增加了防御性编程,预防潜在的安全问题 67 | 68 | ## 性能优化 69 | 70 | 1. 优化了会话操作的并发性能 71 | 2. 减少了不必要的存储操作 72 | 3. 增加了会话修改状态追踪,避免不必要的保存 73 | 4. 添加了资源回收机制,防止内存泄漏 74 | 75 | 总体而言,本次优化使Mist框架的会话管理模块更加健壮、安全和高效,同时保持了良好的可扩展性和易用性。这些改进使得开发者可以更加方便地在Mist框架中实现安全可靠的用户会话管理。 -------------------------------------------------------------------------------- /security/blocklist/middleware/middleware.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/dormoron/mist" 7 | "github.com/dormoron/mist/security/blocklist" 8 | ) 9 | 10 | // BlocklistConfig 中间件配置 11 | type BlocklistConfig struct { 12 | // Manager IP黑名单管理器 13 | Manager *blocklist.Manager 14 | // OnBlocked 当IP被封禁时的处理函数 15 | OnBlocked func(*mist.Context) 16 | } 17 | 18 | // DefaultConfig 返回默认配置 19 | func DefaultConfig(manager *blocklist.Manager) BlocklistConfig { 20 | return BlocklistConfig{ 21 | Manager: manager, 22 | OnBlocked: func(ctx *mist.Context) { 23 | ctx.AbortWithStatus(http.StatusForbidden) 24 | }, 25 | } 26 | } 27 | 28 | // Option 配置选项函数 29 | type Option func(*BlocklistConfig) 30 | 31 | // WithOnBlocked 设置IP被封禁时的处理函数 32 | func WithOnBlocked(handler func(*mist.Context)) Option { 33 | return func(c *BlocklistConfig) { 34 | c.OnBlocked = handler 35 | } 36 | } 37 | 38 | // New 创建IP黑名单中间件 39 | func New(manager *blocklist.Manager, opts ...Option) mist.Middleware { 40 | // 使用默认配置 41 | config := DefaultConfig(manager) 42 | 43 | // 应用自定义选项 44 | for _, opt := range opts { 45 | opt(&config) 46 | } 47 | 48 | return func(next mist.HandleFunc) mist.HandleFunc { 49 | return func(ctx *mist.Context) { 50 | ip := ctx.ClientIP() 51 | 52 | // 如果IP被封禁,中断请求 53 | if config.Manager.IsBlocked(ip) { 54 | // 调用封禁处理函数 55 | if config.OnBlocked != nil { 56 | config.OnBlocked(ctx) 57 | } 
else { 58 | ctx.AbortWithStatus(http.StatusForbidden) 59 | } 60 | return 61 | } 62 | 63 | // 继续处理请求 64 | next(ctx) 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /middlewares/ratelimit/redis_slide_window.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "github.com/dormoron/mist/internal/ratelimit" 5 | "github.com/redis/go-redis/v9" 6 | "time" 7 | ) 8 | 9 | // InitRedisSlidingWindowLimiter initializes a rate limiter using a sliding window algorithm with Redis as the backend. 10 | // 11 | // This function is used to create an instance of a Redis-based sliding window rate limiter. 12 | // The sliding window algorithm allows a more even distribution of requests over time, 13 | // ensuring that the rate limit is not breached within the specified time interval. 14 | // 15 | // Parameters: 16 | // 17 | // cmd (redisess.Cmdable): The Redis client or connection that supports the required Redis commands. 18 | // This can be any type that implements the redisess.Cmdable interface, such as 19 | // *redisess.Client or *redisess.ClusterClient. 20 | // interval (time.Duration): The time duration representing the size of the sliding window. 21 | // This determines the period over which the rate is calculated. 22 | // rate (int): The maximum number of requests allowed within the specified interval. 23 | // 24 | // Returns: 25 | // 26 | // (ratelimit.Limiter): An implementation of the rate limiting interface using the sliding window algorithm, 27 | // backed by Redis. This object can be used to apply rate limiting logic to various operations. 
28 | func InitRedisSlidingWindowLimiter( 29 | cmd redis.Cmdable, 30 | interval time.Duration, 31 | rate int, 32 | ) ratelimit.Limiter { 33 | return &ratelimit.RedisSlidingWindowLimiter{ 34 | Cmd: cmd, 35 | Interval: interval, 36 | Rate: rate, 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /security/auth/kit/option.go: -------------------------------------------------------------------------------- 1 | package kit 2 | 3 | // Option is a generic type for a functional option, which is a function 4 | // that configures a given instance of type T. 5 | type Option[T any] func(t *T) 6 | 7 | // Apply applies a variadic list of functional options to a given instance of type T. 8 | // Parameters: 9 | // - t: A pointer to the instance to be configured (*T). 10 | // - opts: A variadic list of functional options (Option[T]). 11 | // The function iterates over the provided options and applies each one to the instance. 12 | func Apply[T any](t *T, opts ...Option[T]) { 13 | for _, opt := range opts { 14 | opt(t) // Apply each option to the instance. 15 | } 16 | } 17 | 18 | // OptionErr is a generic type for a functional option, which is a function 19 | // that configures a given instance of type T and may return an error. 20 | type OptionErr[T any] func(t *T) error 21 | 22 | // ApplyErr applies a variadic list of functional options to a given instance of type T, 23 | // and returns an error if any option fails. 24 | // Parameters: 25 | // - t: A pointer to the instance to be configured (*T). 26 | // - opts: A variadic list of functional options that may return errors (OptionErr[T]). 27 | // Returns: 28 | // - error: An error if any of the options return an error, otherwise nil. 29 | // The function iterates over the provided options and applies each one to the instance. 30 | // If any option returns an error, it is immediately returned and no further options are applied. 
31 | func ApplyErr[T any](t *T, opts ...OptionErr[T]) error { 32 | for _, opt := range opts { 33 | if err := opt(t); err != nil { // Apply each option to the instance and check for errors. 34 | return err // Return the error if any option fails. 35 | } 36 | } 37 | return nil // Return nil if no errors occur. 38 | } 39 | -------------------------------------------------------------------------------- /internal/ratelimit/slide_window.lua: -------------------------------------------------------------------------------- 1 | -- This Lua script is used by a RedisSlidingWindowLimiter to control access frequency to resources. 2 | -- It ensures that a particular action can only occur at a specified rate within a sliding window of time. 3 | 4 | -- KEYS[1] holds the key for the ZSET used for maintaining timestamps of actions. 5 | local key = KEYS[1] 6 | 7 | -- ARGV parameters are used to pass the sliding window's length, the maximum allowed rate, and the current timestamp. 8 | local window = tonumber(ARGV[1]) -- The sliding window's length in milliseconds. 9 | local threshold = tonumber(ARGV[2]) -- The maximum number of actions allowed in the window. 10 | local now = tonumber(ARGV[3]) -- The current timestamp in milliseconds. 11 | 12 | -- Compute the minimum score for the ZSET to determine which entries are within the sliding window. 13 | local min = now - window 14 | 15 | -- Remove all entries in the ZSET that are outside of the sliding window. 16 | redis.call('ZREMRANGEBYSCORE', key, '-inf', min) 17 | 18 | -- Count the number of remaining entries in the ZSET, which equals the number of actions in the sliding window. 19 | local cnt = redis.call('ZCOUNT', key, '-inf', '+inf') 20 | 21 | -- If the count of actions exceeds the threshold, return "true" to indicate rate limit has been exceeded. 22 | if cnt >= threshold then 23 | return "true" 24 | 25 | -- Otherwise, add the current timestamp to the ZSET and set an expiration equal to the window size. 
26 | -- This action signifies a successful attempt within rate limits and returns "false" to signify availability for more actions. 27 | else 28 | redis.call('ZADD', key, now, now) -- Add the current timestamp as score and value. 29 | redis.call('PEXPIRE', key, window) -- Set the expiration for the ZSET to the window size for automatic cleanup. 30 | return "false" 31 | end -------------------------------------------------------------------------------- /internal/crawlerdetect/google.go: -------------------------------------------------------------------------------- 1 | package crawlerdetect 2 | 3 | // GoogleStrategy is a struct that embeds a pointer to a UniversalStrategy from the 4 | // crawler detect package. This is a pattern often used in Go to achieve composition, 5 | // where GoogleStrategy 'is-a' UniversalStrategy, gaining access to its methods and properties directly. 6 | // The purpose of embedding this specific UniversalStrategy is to leverage predefined methods 7 | // and capabilities for detecting web crawlers based on a list of hosts. 8 | type GoogleStrategy struct { 9 | *UniversalStrategy 10 | } 11 | 12 | // InitGoogleStrategy is a function that initializes and returns a pointer to a GoogleStrategy instance. 13 | // It specifically initializes the embedded UniversalStrategy field with a set of hosts 14 | // that are known to be associated with Google's web crawlers. This setup is useful for 15 | // systems looking to detect and possibly differentiate traffic originating from Google's crawlers. 16 | // 17 | // Returns: 18 | // - *GoogleStrategy: A pointer to the newly created GoogleStrategy instance. This instance now contains a 19 | // UniversalStrategy initialized with a predefined list of hosts known to be used by Google's crawlers. 20 | // 21 | // Usage Notes: 22 | // - The list of hosts ('googlebot.com', 'google.com', 'googleusercontent.com') are specifically 23 | // chosen because they are commonly associated with Google's web crawling services. 
The intention 24 | // is to recognize traffic from these entities during web crawling detection checks. 25 | // - This setup is particularly useful for SEO-sensitive websites or web applications that might 26 | // want to tailor their responses based on whether the traffic is generated by human users or 27 | // automated crawlers. 28 | func InitGoogleStrategy() *GoogleStrategy { 29 | return &GoogleStrategy{ 30 | UniversalStrategy: InitUniversalStrategy([]string{"googlebot.com", "google.com", "googleusercontent.com"}), 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /config/global.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "log" 5 | "path/filepath" 6 | "sync" 7 | ) 8 | 9 | var ( 10 | // 全局配置实例 11 | globalConfig Provider 12 | 13 | // 单例锁 14 | once sync.Once 15 | ) 16 | 17 | // Init 初始化全局配置 18 | func Init(options ...Option) error { 19 | var err error 20 | 21 | once.Do(func() { 22 | var config *Configuration 23 | config, err = New(options...) 
24 | if err != nil { 25 | return 26 | } 27 | 28 | globalConfig = config 29 | }) 30 | 31 | return err 32 | } 33 | 34 | // Get 获取全局配置实例 35 | func Get() Provider { 36 | if globalConfig == nil { 37 | // 如果全局配置未初始化,使用默认配置 38 | err := Init() 39 | if err != nil { 40 | log.Printf("初始化默认配置失败: %v", err) 41 | } 42 | } 43 | 44 | return globalConfig 45 | } 46 | 47 | // AutoInit 自动初始化配置 48 | // 自动检测当前目录下的配置文件和环境变量 49 | func AutoInit(appName string) error { 50 | // 尝试加载多种格式的配置文件 51 | configFiles := []string{ 52 | filepath.Join(".", "config.yaml"), 53 | filepath.Join(".", "config.yml"), 54 | filepath.Join(".", "config.json"), 55 | filepath.Join(".", "config.toml"), 56 | filepath.Join(".", "configs", "config.yaml"), 57 | filepath.Join(".", "configs", "config.yml"), 58 | filepath.Join(".", "configs", "config.json"), 59 | filepath.Join(".", "configs", "config.toml"), 60 | filepath.Join(".", "conf", "config.yaml"), 61 | filepath.Join(".", "conf", "config.yml"), 62 | filepath.Join(".", "conf", "config.json"), 63 | filepath.Join(".", "conf", "config.toml"), 64 | } 65 | 66 | // 查找第一个存在的配置文件 67 | var configFile string 68 | for _, file := range configFiles { 69 | if fileExists(file) { 70 | configFile = file 71 | break 72 | } 73 | } 74 | 75 | // 如果没有找到配置文件,使用默认配置 76 | if configFile == "" { 77 | return Init(WithEnvPrefix(appName + "_")) 78 | } 79 | 80 | // 初始化配置 81 | return Init( 82 | WithConfigFile(configFile), 83 | WithEnvPrefix(appName+"_"), 84 | ) 85 | } 86 | 87 | // fileExists 检查文件是否存在 88 | func fileExists(path string) bool { 89 | info, err := filepath.Glob(path) 90 | return err == nil && len(info) > 0 91 | } 92 | -------------------------------------------------------------------------------- /match.go: -------------------------------------------------------------------------------- 1 | package mist 2 | 3 | // matchInfo 保存匹配路由所需的必要信息。它封装了匹配的节点、从URL提取的路径参数以及应该应用于该路由的中间件列表。 4 | // 这个结构体通常在路由系统中使用,负责在成功匹配路由后承载处理HTTP请求所需的累积数据。 5 | // 6 | // 字段: 7 | // - n (*node): 
指向匹配的'node'的指针,该节点表示在路由树中已经与传入请求路径匹配的端点。 8 | // 这个'node'包含处理请求所需的必要信息,如关联的处理程序或其他路由信息。 9 | // - pathParams (map[string]string): 一个存储路径参数的键值对映射,其中键是参数的名称(在路径中定义), 10 | // 值是从请求URL匹配的实际字符串。例如,对于路由模式"/users/:userID/posts/:postID", 11 | // 如果传入的请求路径匹配该模式,则此映射将包含"userID"和"postID"的条目。 12 | // - mils ([]Middleware): 一个'Middleware'函数的切片,按照切片中包含的顺序为匹配的路由执行。 13 | // 中间件函数用于在请求到达最终处理函数之前执行操作,如请求日志记录、认证和输入验证等。 14 | // 15 | // 用法: 16 | // 'matchInfo'结构体在路由匹配过程中被填充。一旦请求路径与路由树匹配,就创建一个'matchInfo'实例, 17 | // 并填充相应的节点、提取的路径参数以及与匹配路由相关的任何中间件。然后将此实例传递给请求处理逻辑, 18 | // 引导请求通过各种中间件层的处理,最终到达将生成响应的适当处理程序。 19 | type matchInfo struct { 20 | // n 是路由树中与匹配路由对应的节点。它提供了访问处理请求所需的任何其他特定于路由的信息。 21 | n *node 22 | 23 | // pathParams 存储在URL路径中标识的参数,如"/users/:id"中的"id",映射到从传入请求解析的实际值。 24 | pathParams map[string]string 25 | 26 | // mils 是要为匹配的路由按顺序执行的中间件函数集合。这些函数可以修改请求上下文、执行检查或进行其他预处理任务。 27 | mils []Middleware 28 | } 29 | 30 | // addValue 是一个方法,用于向matchInfo结构体的pathParams映射中添加键值对。 31 | // 这个方法用于累积从匹配的URL路径中提取的参数,并将它们存储起来,以便在请求处理过程中后续使用。 32 | // 33 | // 参数: 34 | // - key: 一个字符串,表示URL参数的名称(例如,"userID")。 35 | // - value: 一个字符串,表示从请求URL中提取的URL参数的值(例如,对于userID,值可能是"42")。 36 | // 37 | // addValue函数执行以下步骤: 38 | // 39 | // 1. 检查matchInfo结构体内的pathParams映射是否为nil,这表示尚未添加任何参数。 40 | // 如果是nil,则初始化pathParams映射并立即向其中添加键值对。这是必要的,因为不能向nil映射添加键;必须先初始化它。 41 | // 2. 
如果pathParams映射已经初始化,则添加或覆盖键的条目,赋予新值。 42 | // 这确保了对于给定键,映射中存储的是最近处理的值。 43 | // 44 | // 用法: 45 | // addValue方法通常在路由匹配过程中调用,期间解析与路由模式中的参数对应的路径段,并累积它们的值。 46 | // 每次处理一个段并提取参数值时,都会使用addValue来保存该值和相应的参数名称。 47 | // 48 | // 示例: 49 | // 对于像"/users/:userID"这样的URL模式,在处理像"/users/42"这样的请求路径时, 50 | // 该方法将被调用为addValue("userID", "42"),向pathParams映射添加参数"userID"及其值"42"。 51 | func (m *matchInfo) addValue(key string, value string) { 52 | // 如果尚未初始化pathParams映射,则进行初始化,以避免nil映射赋值导致的panic。 53 | if m.pathParams == nil { 54 | m.pathParams = map[string]string{key: value} 55 | } 56 | // 向pathParams映射添加或更新键值对,表示URL参数及其值。 57 | m.pathParams[key] = value 58 | } 59 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/dormoron/mist 2 | 3 | go 1.22 4 | 5 | require ( 6 | github.com/casbin/casbin/v2 v2.89.0 7 | github.com/fsnotify/fsnotify v1.7.0 8 | github.com/go-redis/redis/v8 v8.11.5 9 | github.com/golang-jwt/jwt/v5 v5.2.2 10 | github.com/google/uuid v1.6.0 11 | github.com/oschwald/geoip2-golang v1.11.0 12 | github.com/patrickmn/go-cache v2.1.0+incompatible 13 | github.com/prometheus/client_golang v1.19.1 14 | github.com/quic-go/quic-go v0.42.0 15 | github.com/redis/go-redis/v9 v9.5.1 16 | github.com/stretchr/testify v1.9.0 17 | go.opentelemetry.io/otel v1.26.0 18 | go.opentelemetry.io/otel/trace v1.26.0 19 | go.uber.org/atomic v1.11.0 20 | golang.org/x/crypto v0.24.0 21 | ) 22 | 23 | require ( 24 | github.com/quic-go/qpack v0.4.0 // indirect 25 | go.uber.org/mock v0.4.0 // indirect 26 | golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect 27 | golang.org/x/mod v0.17.0 // indirect 28 | ) 29 | 30 | require ( 31 | github.com/casbin/govaluate v1.1.1 // indirect 32 | github.com/davecgh/go-spew v1.1.1 // indirect 33 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 34 | github.com/kr/text v0.2.0 // indirect 35 | 
github.com/mitchellh/mapstructure v1.5.0 36 | github.com/oschwald/maxminddb-golang v1.13.1 // indirect 37 | github.com/pelletier/go-toml/v2 v2.2.4 38 | github.com/pmezard/go-difflib v1.0.0 // indirect 39 | go.opentelemetry.io/otel/metric v1.26.0 // indirect 40 | gopkg.in/yaml.v3 v3.0.1 41 | ) 42 | 43 | require ( 44 | github.com/beorn7/perks v1.0.1 // indirect 45 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 46 | github.com/go-logr/logr v1.4.1 // indirect 47 | github.com/go-logr/stdr v1.2.2 // indirect 48 | github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect 49 | github.com/google/pprof v0.0.0-20211008130755-947d60d73cc0 // indirect 50 | github.com/hashicorp/golang-lru v1.0.2 51 | github.com/onsi/ginkgo/v2 v2.15.0 // indirect 52 | github.com/prometheus/client_model v0.6.1 // indirect 53 | github.com/prometheus/common v0.53.0 // indirect 54 | github.com/prometheus/procfs v0.14.0 // indirect 55 | golang.org/x/net v0.26.0 // indirect 56 | golang.org/x/sys v0.21.0 // indirect 57 | golang.org/x/text v0.16.0 // indirect 58 | golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect 59 | google.golang.org/protobuf v1.34.1 // indirect 60 | ) -------------------------------------------------------------------------------- /security/auth/kit/set.go: -------------------------------------------------------------------------------- 1 | package kit 2 | 3 | // Set is a generic interface representing a set data structure. It provides 4 | // methods to add, delete, check the existence of elements, and retrieve all elements as a slice. 5 | type Set[T comparable] interface { 6 | Add(key T) // Adds a new element to the set. 7 | Delete(key T) // Deletes an element from the set. 8 | Exist(key T) bool // Checks if an element exists in the set. 9 | Keys() []T // Retrieves all elements in the set as a slice. 10 | } 11 | 12 | // MapSet is a generic implementation of the Set interface using a Go map. 
13 | // It ensures that elements are unique within the set. 14 | type MapSet[T comparable] struct { 15 | m map[T]struct{} // Underlying map to store set elements. The value is an empty struct to save memory. 16 | } 17 | 18 | // NewMapSet creates and returns a new instance of MapSet with an initial capacity. 19 | // Parameters: 20 | // - size: The initial capacity of the map (int). 21 | // Returns: 22 | // - *MapSet[T]: A pointer to a newly created MapSet. 23 | func NewMapSet[T comparable](size int) *MapSet[T] { 24 | return &MapSet[T]{ 25 | m: make(map[T]struct{}, size), // Initialize the map with the given capacity. 26 | } 27 | } 28 | 29 | // Add inserts a new element into the MapSet. If the element already exists, it does nothing. 30 | // Parameters: 31 | // - val: The value to be added to the set (T). 32 | func (s *MapSet[T]) Add(val T) { 33 | s.m[val] = struct{}{} // Add the value to the map with an empty struct value. 34 | } 35 | 36 | // Delete removes an element from the MapSet. If the element does not exist, it does nothing. 37 | // Parameters: 38 | // - key: The value to be removed from the set (T). 39 | func (s *MapSet[T]) Delete(key T) { 40 | delete(s.m, key) // Delete the value from the map. 41 | } 42 | 43 | // Exist checks if a given element exists in the MapSet. 44 | // Parameters: 45 | // - key: The value to check for existence in the set (T). 46 | // Returns: 47 | // - bool: True if the value exists in the set, false otherwise. 48 | func (s *MapSet[T]) Exist(key T) bool { 49 | _, ok := s.m[key] 50 | return ok // Return whether the key exists in the map. 51 | } 52 | 53 | // Keys retrieves all elements from the MapSet as a slice. 54 | // Returns: 55 | // - []T: A slice containing all elements in the set. 56 | func (s *MapSet[T]) Keys() []T { 57 | ans := make([]T, 0, len(s.m)) // Initialize a slice with the length of the map. 58 | for key := range s.m { 59 | ans = append(ans, key) // Append each key from the map to the slice. 
60 | } 61 | return ans 62 | } 63 | -------------------------------------------------------------------------------- /security/report/README.md: -------------------------------------------------------------------------------- 1 | # 安全报告处理模块 2 | 3 | 这个模块提供了一个全面的安全报告处理系统,用于接收、存储和分析各种安全相关的报告,如CSP违规、XSS尝试等。 4 | 5 | ## 支持的报告类型 6 | 7 | - **CSP (Content Security Policy)**: 内容安全策略违规报告 8 | - **XSS (Cross-Site Scripting)**: 跨站脚本尝试 9 | - **HPKP (HTTP Public Key Pinning)**: HTTP公钥固定违规 10 | - **COEP (Cross-Origin Embedder Policy)**: 跨源嵌入策略违规 11 | - **CORP (Cross-Origin Resource Policy)**: 跨源资源策略违规 12 | - **COOP (Cross-Origin Opener Policy)**: 跨源打开者策略违规 13 | - **Feature Policy**: 特性策略违规 14 | - **NEL (Network Error Logging)**: 网络错误日志 15 | 16 | ## 主要组件 17 | 18 | 1. **SecurityReport**: 表示单个安全报告的结构体,包含类型、时间、原始数据等信息。 19 | 2. **Handler**: 处理和管理报告的接口。 20 | 3. **MemoryHandler**: `Handler`接口的内存实现,用于存储和检索报告。 21 | 4. **ReportServer**: HTTP处理器,用于接收和处理安全报告。 22 | 23 | ## 使用示例 24 | 25 | ### 创建一个基本的安全报告服务器 26 | 27 | ```go 28 | package main 29 | 30 | import ( 31 | "log" 32 | "net/http" 33 | "security/report" 34 | ) 35 | 36 | func main() { 37 | // 创建一个内存报告处理器,最多保存100条报告 38 | handler := report.NewMemoryHandler(100) 39 | 40 | // 创建报告服务器 41 | reportServer := report.NewReportServer(handler) 42 | 43 | // 为各种安全报告设置处理路由 44 | http.Handle("/report/csp", reportServer) 45 | http.Handle("/report/xss", reportServer) 46 | 47 | log.Println("安全报告服务器启动在 http://localhost:8080") 48 | if err := http.ListenAndServe(":8080", nil); err != nil { 49 | log.Fatalf("服务器启动失败: %v", err) 50 | } 51 | } 52 | ``` 53 | 54 | ### 为前端启用CSP报告 55 | 56 | 在您的HTML页面或HTTP响应头中添加CSP策略和报告URI: 57 | 58 | ```html 59 | 60 | ``` 61 | 62 | 或者作为HTTP头: 63 | 64 | ``` 65 | Content-Security-Policy: default-src 'self'; report-uri /report/csp 66 | ``` 67 | 68 | ### 查询报告数据 69 | 70 | ```go 71 | // 获取最近10条报告 72 | reports, err := handler.GetRecentReports(10) 73 | 74 | // 获取CSP类型的报告 75 | cspReports, err := handler.GetReportsByType(report.ReportTypeCSP, 10) 76 | 
77 | // 获取报告摘要(各类型的报告数量) 78 | summary, err := handler.GetReportsSummary() 79 | ``` 80 | 81 | ## 实现自定义处理器 82 | 83 | 您可以通过实现`Handler`接口来创建自定义的报告处理器,例如将报告保存到数据库中: 84 | 85 | ```go 86 | type DatabaseHandler struct { 87 | db *sql.DB 88 | } 89 | 90 | func (h *DatabaseHandler) HandleReport(report *SecurityReport) error { 91 | // 将报告保存到数据库 92 | } 93 | 94 | func (h *DatabaseHandler) GetRecentReports(limit int) ([]*SecurityReport, error) { 95 | // 从数据库检索最近的报告 96 | } 97 | 98 | // 实现其他方法... 99 | ``` 100 | 101 | ## 安全最佳实践 102 | 103 | 1. **限制报告端点的请求大小**:防止DOS攻击。 104 | 2. **验证报告来源**:考虑添加某种形式的API密钥或其他身份验证机制。 105 | 3. **定期清理旧报告**:设置报告老化策略,防止报告积累过多。 106 | 4. **监控异常活动**:如果短时间内收到大量报告,可能表明正在发生攻击。 -------------------------------------------------------------------------------- /security/builder.go: -------------------------------------------------------------------------------- 1 | package security 2 | 3 | import "github.com/dormoron/mist" 4 | 5 | // Builder is a structure that helps in building a session configuration step by step. 6 | // It contains the context, user ID, JWT data, session data, and a session provider. 7 | type Builder struct { 8 | ctx *mist.Context // The request context. 9 | uid int64 // The user ID for the session. 10 | jwtData map[string]any // JWT-related data to be included in the session. 11 | sessData map[string]any // Additional session data. 12 | sp Provider // The session provider used to initialize the session. 13 | } 14 | 15 | // InitSessionBuilder initializes and returns a new instance of Builder with the given context and user ID. 16 | // The default session provider is set during initialization. 17 | // Parameters: 18 | // - ctx: The request context (*mist.Context). 19 | // - uid: The user ID for the session (int64). 20 | // Returns: 21 | // - *Builder: A pointer to a newly created Builder instance. 
22 | func InitSessionBuilder(ctx *mist.Context, uid int64) *Builder { 23 | return &Builder{ 24 | ctx: ctx, 25 | uid: uid, 26 | sp: defaultProvider, // Set the default session provider. 27 | } 28 | } 29 | 30 | // SetProvider sets a custom session provider for the Builder. 31 | // Parameters: 32 | // - p: The custom session provider (Provider). 33 | // Returns: 34 | // - *Builder: The Builder instance with the updated provider. 35 | func (b *Builder) SetProvider(p Provider) *Builder { 36 | b.sp = p // Update the session provider. 37 | return b 38 | } 39 | 40 | // SetJwtData sets the JWT data for the Builder. 41 | // Parameters: 42 | // - data: The JWT-related data (map[string]any). 43 | // Returns: 44 | // - *Builder: The Builder instance with the updated JWT data. 45 | func (b *Builder) SetJwtData(data map[string]any) *Builder { 46 | b.jwtData = data // Update the JWT data. 47 | return b 48 | } 49 | 50 | // SetSessData sets the session data for the Builder. 51 | // Parameters: 52 | // - data: The additional session data (map[string]any). 53 | // Returns: 54 | // - *Builder: The Builder instance with the updated session data. 55 | func (b *Builder) SetSessData(data map[string]any) *Builder { 56 | b.sessData = data // Update the session data. 57 | return b 58 | } 59 | 60 | // Build constructs the session using the provided or default session provider, context, user ID, JWT data, and session data. 61 | // Returns: 62 | // - Session: The newly created session. 63 | // - error: An error if the session creation fails. 64 | func (b *Builder) Build() (Session, error) { 65 | return b.sp.InitSession(b.ctx, b.uid, b.jwtData, b.sessData) 66 | // Use the session provider to initialize the session with the given parameters. 
67 | } 68 | -------------------------------------------------------------------------------- /security/auth/types.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "github.com/dormoron/mist" 5 | "github.com/golang-jwt/jwt/v5" 6 | ) 7 | 8 | // Manager is a generic interface that manages tokens and claims for a given type T. 9 | type Manager[T any] interface { 10 | // MiddlewareBuilder returns an instance of MiddlewareBuilder for the specified type T. 11 | // Returns: 12 | // - *MiddlewareBuilder[T]: A pointer to an instance of MiddlewareBuilder for the given type T. 13 | MiddlewareBuilder() *MiddlewareBuilder[T] 14 | 15 | // GenerateAccessToken generates an access token containing the given data of type T. 16 | // Parameters: 17 | // - data: The data to be included in the access token ('T'). 18 | // Returns: 19 | // - string: The generated access token. 20 | // - error: An error if token generation fails. 21 | GenerateAccessToken(data T) (string, error) 22 | 23 | // VerifyAccessToken verifies the given access token and extracts the claims from it. 24 | // Parameters: 25 | // - token: The access token to be verified ('string'). 26 | // - opts: Additional options for the JWT parser (variadic 'jwt.ParserOption'). 27 | // Returns: 28 | // - RegisteredClaims[T]: The claims extracted from the verified token. 29 | // - error: An error if token verification fails. 30 | VerifyAccessToken(token string, opts ...jwt.ParserOption) (RegisteredClaims[T], error) 31 | 32 | // GenerateRefreshToken generates a refresh token containing the given data of type T. 33 | // Parameters: 34 | // - data: The data to be included in the refresh token ('T'). 35 | // Returns: 36 | // - string: The generated refresh token. 37 | // - error: An error if token generation fails. 38 | GenerateRefreshToken(data T) (string, error) 39 | 40 | // VerifyRefreshToken verifies the given refresh token and extracts the claims from it. 
41 | // Parameters: 42 | // - token: The refresh token to be verified ('string'). 43 | // - opts: Additional options for the JWT parser (variadic 'jwt.ParserOption'). 44 | // Returns: 45 | // - RegisteredClaims[T]: The claims extracted from the verified token. 46 | // - error: An error if token verification fails. 47 | VerifyRefreshToken(token string, opts ...jwt.ParserOption) (RegisteredClaims[T], error) 48 | 49 | // SetClaims sets the provided claims in the context. 50 | // Parameters: 51 | // - ctx: The context where the claims are to be set ('*mist.Context'). 52 | // - claims: The claims to be set in the context ('RegisteredClaims[T]'). 53 | SetClaims(ctx *mist.Context, claims RegisteredClaims[T]) 54 | } 55 | 56 | // RegisteredClaims is a generic struct that holds claims registered in a JWT, including user-defined data of type T. 57 | type RegisteredClaims[T any] struct { 58 | Data T `json:"data"` // Custom data of type T associated with the registered claims. 59 | jwt.RegisteredClaims // Embeds standard JWT registered claims. 60 | } 61 | -------------------------------------------------------------------------------- /internal/crawlerdetect/sogou.go: -------------------------------------------------------------------------------- 1 | package crawlerdetect 2 | 3 | import ( 4 | "net" 5 | "slices" 6 | "strings" 7 | ) 8 | 9 | // SogouStrategy is a struct that holds information needed to check if an IP address 10 | // is associated with a known web crawler. This is often used to differentiate between 11 | // regular user traffic and automated crawlers, such as those used by search engines. 12 | // 13 | // Fields: 14 | // - Hosts: A slice of strings where each string is a host that is known to 15 | // be associated with a web crawler. For instance, "googlebot.com" for Google's crawler. 
16 | type SogouStrategy struct { 17 | Hosts []string 18 | } 19 | 20 | // InitSogouStrategy is a package-level function that initializes a SogouStrategy struct 21 | // with predefined host names of known crawlers. This example uses "sogou.com" as a 22 | // known crawler host. 23 | // 24 | // Returns: 25 | // - *SogouStrategy: A pointer to a SogouStrategy instance with prepopulated Hosts field. 26 | func InitSogouStrategy() *SogouStrategy { 27 | return &SogouStrategy{ 28 | Hosts: []string{"sogou.com"}, 29 | } 30 | } 31 | 32 | // CheckCrawler is a method linked to the SogouStrategy struct that attempts to 33 | // verify if a given IP address belongs to a known web crawler defined in the struct's Hosts field. 34 | // 35 | // Parameters: 36 | // - ip: The IP address to check against the list of known crawler hosts. 37 | // 38 | // Returns: 39 | // - bool: Indicates whether the IP is a known crawler (`true`) or not (`false`). 40 | // - error: Any error encountered during the DNS look-up process. 41 | // 42 | // The method performs a reverse DNS lookup of the IP address to ascertain if any associated 43 | // hosts match the ones listed in the SogouStrategy's Hosts field using the matchHost method. 44 | func (s *SogouStrategy) CheckCrawler(ip string) (bool, error) { 45 | names, err := net.LookupAddr(ip) 46 | if err != nil { 47 | return false, err 48 | } 49 | if len(names) == 0 { 50 | return false, nil 51 | } 52 | return s.matchHost(names), nil 53 | } 54 | 55 | // matchHost is a helper method for the SogouStrategy struct that checks if any of the hostnames 56 | // returned from a reverse DNS lookup match the hosts known to be crawlers. 57 | // 58 | // Parameters: 59 | // - names: A slice of hostnames obtained from the reverse DNS lookup of an IP address. 60 | // 61 | // Returns: 62 | // - bool: Whether any of the provided names match the known crawler hosts. 
63 | // 64 | // It uses the slices.ContainsFunc method to iterate over the list of known hosts and compares 65 | // each with the retrieved names using the strings.Contains method. If a match is found 66 | // the function returns true, otherwise it returns false. 67 | func (s *SogouStrategy) matchHost(names []string) bool { 68 | return slices.ContainsFunc(s.Hosts, func(host string) bool { 69 | return slices.ContainsFunc(names, func(name string) bool { 70 | return strings.Contains(name, host) 71 | }) 72 | }) 73 | } 74 | -------------------------------------------------------------------------------- /bloom/types.go: -------------------------------------------------------------------------------- 1 | package bloom 2 | 3 | import ( 4 | "context" 5 | _ "embed" 6 | ) 7 | 8 | var ( 9 | // `addLuaScript` contains the Lua script for adding elements to the Bloom filter. 10 | // This script is embedded from the `lua/add.lua` file at compile time. 11 | //go:embed lua/add.lua 12 | addLuaScript string 13 | 14 | // `checkLuaScript` contains the Lua script for checking elements in the Bloom filter. 15 | // This script is embedded from the `lua/check.lua` file at compile time. 16 | //go:embed lua/check.lua 17 | checkLuaScript string 18 | 19 | // `removeCountLuaScript` contains the Lua script for removing elements from the Bloom filter. 20 | // This script is embedded from the `lua/remove.lua` file at compile time. 21 | //go:embed lua/remove.lua 22 | removeCountLuaScript string 23 | ) 24 | 25 | // Filter is an interface that defines methods for interacting with a Bloom filter. 26 | type Filter interface { 27 | // Add inserts multiple elements into the Bloom filter. 28 | // Parameters: 29 | // - ctx: Context to control the execution and allow cancellation. 30 | // - elements: Variadic parameter for the elements to add to the Bloom filter. 31 | // Returns: 32 | // - error: An error if the addition fails, otherwise nil. 
33 | Add(ctx context.Context, elements ...interface{}) error 34 | 35 | // Check verifies whether a single element is present in the Bloom filter. 36 | // Parameters: 37 | // - ctx: Context to control the execution and allow cancellation. 38 | // - element: The element to check in the Bloom filter. 39 | // Returns: 40 | // - bool: true if the element may be present, false otherwise. 41 | // - error: An error if the check operation fails, otherwise nil. 42 | Check(ctx context.Context, element interface{}) (bool, error) 43 | 44 | // CheckBatch verifies the presence of multiple elements in the Bloom filter. 45 | // Parameters: 46 | // - ctx: Context to control the execution and allow cancellation. 47 | // - elements: Variadic parameter for the elements to check in the Bloom filter. 48 | // Returns: 49 | // - []bool: Slice of boolean values indicating the presence of each element. 50 | // - error: An error if the check operation fails, otherwise nil. 51 | CheckBatch(ctx context.Context, elements ...interface{}) ([]bool, error) 52 | 53 | // Remove deletes a single element from the Bloom filter. 54 | // Parameters: 55 | // - ctx: Context to control the execution and allow cancellation. 56 | // - element: The element to remove from the Bloom filter. 57 | // Returns: 58 | // - error: An error if the removal fails, otherwise nil. 59 | Remove(ctx context.Context, element interface{}) error 60 | 61 | // RemoveBatch deletes multiple elements from the Bloom filter. 62 | // Parameters: 63 | // - ctx: Context to control the execution and allow cancellation. 64 | // - elements: Variadic parameter for the elements to remove from the Bloom filter. 65 | // Returns: 66 | // - error: An error if the removal fails, otherwise nil. 
67 | RemoveBatch(ctx context.Context, elements ...interface{}) error 68 | } 69 | -------------------------------------------------------------------------------- /middlewares/cors/middleware.go: -------------------------------------------------------------------------------- 1 | package cors 2 | 3 | import ( 4 | "github.com/dormoron/mist" 5 | "net/http" 6 | ) 7 | 8 | type MiddlewareBuilder struct { 9 | AllowOrigin string // URI(s) that are permitted to access the server. 10 | } 11 | 12 | // InitMiddlewareBuilder initializes a new MiddlewareBuilder instance with default settings. 13 | // It sets the AllowOrigin field to an empty string, indicating no origin is explicitly allowed by default. 14 | // Returns: 15 | // - A pointer to the newly created MiddlewareBuilder instance. 16 | func InitMiddlewareBuilder() *MiddlewareBuilder { 17 | builder := &MiddlewareBuilder{ 18 | AllowOrigin: "", // Initialize AllowOrigin as empty, implying no specific origin is allowed by default. 19 | } 20 | return builder // Return the newly initialized MiddlewareBuilder instance. 21 | } 22 | 23 | // SetAllowOrigin sets the origin that is allowed by the middleware to access the server. 24 | // This method can be used to configure Access-Control-Allow-Origin headers for CORS (Cross-Origin Resource Sharing) requests. 25 | // Parameters: 26 | // - allowOrigin: A string specifying the URI that is permitted to access the server. 27 | // Returns: 28 | // - A pointer to the MiddlewareBuilder instance to enable method chaining. 29 | func (m *MiddlewareBuilder) SetAllowOrigin(allowOrigin string) *MiddlewareBuilder { 30 | m.AllowOrigin = allowOrigin // Set the provided origin as the allowed origin. 31 | return m // Return the MiddlewareBuilder instance for chaining. 32 | } 33 | 34 | // Build constructs and returns the middleware function configured by the MiddlewareBuilder instance. 35 | // This function sets up CORS headers based on the configuration provided to the MiddlewareBuilder instance. 
36 | // Returns: 37 | // - A function that conforms to mist.Middleware signature, capturing the logic for handling CORS requests. 38 | func (m *MiddlewareBuilder) Build() mist.Middleware { 39 | // Define and return the middleware function. 40 | return func(next mist.HandleFunc) mist.HandleFunc { 41 | // Define the function that will be executed as middleware. 42 | return func(ctx *mist.Context) { 43 | // Determine the 'Access-Control-Allow-Origin' value based on the MiddlewareBuilder configuration. 44 | allowOrigin := m.AllowOrigin 45 | if allowOrigin == "" { 46 | allowOrigin = ctx.Request.Header.Get("Origin") 47 | } 48 | ctx.ResponseWriter.Header().Set("Access-Control-Allow-Origin", allowOrigin) 49 | // ctx.Resp.Header().Set("Access-Control-Allow-Origin", "*") 50 | ctx.ResponseWriter.Header().Set("Access-Control-Allow-Methods", "GET, POST, PATCH, PUT, DELETE, OPTIONS") 51 | ctx.ResponseWriter.Header().Set("Access-Control-Allow-Credentials", "true") 52 | if ctx.ResponseWriter.Header().Get("Access-Control-Allow-Headers") == "" { 53 | ctx.ResponseWriter.Header().Add("Access-Control-Allow-Headers", "Origin, Content-Type, Accept, Authorization") 54 | } 55 | if ctx.Request.Method == http.MethodOptions { 56 | ctx.RespStatusCode = 200 57 | ctx.RespData = []byte("ok") 58 | next(ctx) 59 | return 60 | } 61 | next(ctx) 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /internal/ratelimit/redis_slide_window.go: -------------------------------------------------------------------------------- 1 | package ratelimit 2 | 3 | import ( 4 | "context" 5 | _ "embed" 6 | "github.com/redis/go-redis/v9" 7 | "time" 8 | ) 9 | 10 | // luaSlideWindow intentionally declared as global variable. 11 | // It is a string variable which will contain the contents of 'slide_window.lua' after the file is embedded during the compile time. 12 | // 'var' is used to declare a variable. 13 | // 'luaSlideWindow' is the name of the variable. 
It's common in Go to use CamelCase for variable names 14 | // 'string' is the type of the variable. This means the variable will hold a string data 15 | // The type is inferred from the file contents by the //go:embed directive. 16 | // 17 | //go:embed slide_window.lua 18 | var luaSlideWindow string 19 | 20 | // RedisSlidingWindowLimiter struct is a structure in Go which represents a rate limiter using sliding window algorithm. 21 | type RedisSlidingWindowLimiter struct { 22 | 23 | // Cmd is an interface from the go-redisess package (redis.Cmdable). 24 | // This interface includes methods for all Redis commands to execute queries. 25 | // Using an interface here instead of a specific type makes the limiter more flexible, 26 | // as it can accept any type that implements the `redis.Cmdable` interface, such as a Redis client or a Redis Cluster client. 27 | Cmd redis.Cmdable 28 | 29 | // Interval is of type time.Duration, representing the time window size for the rate limiter. 30 | // Interval is a type from the time package defining a duration or elapsed time in nanoseconds. 31 | // In terms of rate limiting, the interval is the time span during which a certain maximum number of requests can be made. 32 | Interval time.Duration 33 | 34 | // Rate is an integer that defines the maximum number of requests that can occur within the provided duration or interval. 35 | // For example, if Interval is 1 minute (`time.Minute`), and Rate is 100, this means a maximum of 100 requests can be made per minute. 36 | Rate int 37 | } 38 | 39 | // Limit is a method of the RedisSlidingWindowLimiter struct. It determines if a specific key has exceeded the allowed number of requests (rate) within the defined interval. 40 | // 41 | // Params: 42 | // - ctx: A context.Context object. It carries deadlines, cancellations signals, and other request-scoped values across API boundaries and between processes. It is often used for timeout management. 
43 | // - key: A string that serves as a unique identifier for the request to be rate-limited. 44 | // 45 | // Returns: 46 | // - A boolean value indicating whether the request associated with the key is within the allowed rate limits. It returns `true` when the rate limit is not reached, and `false` otherwise. 47 | // - An error object that will hold an error (if any) that may have occurred during the function execution. 48 | // 49 | // The method uses the Eval command of the Redis server to execute a Lua script (luaSlideWindow) that implements the sliding window rate limit algorithm. It passes converted interval in milliseconds (r.Interval.Milliseconds()), maximum requests allowed (r.Rate), and the current Unix timestamp in milliseconds (time.Now().UnixMilli()) as parameters to the Lua script. 50 | func (r *RedisSlidingWindowLimiter) Limit(ctx context.Context, key string) (bool, error) { 51 | return r.Cmd.Eval(ctx, luaSlideWindow, []string{key}, r.Interval.Milliseconds(), r.Rate, time.Now().UnixMilli()).Bool() 52 | } 53 | -------------------------------------------------------------------------------- /middlewares/errhdl/middleware.go: -------------------------------------------------------------------------------- 1 | package errhdl 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/dormoron/mist" 8 | "github.com/dormoron/mist/internal/errs" 9 | ) 10 | 11 | // ErrorHandlerFunc 定义错误处理函数类型 12 | type ErrorHandlerFunc func(ctx *mist.Context, err error) 13 | 14 | // Config 错误处理中间件配置 15 | type Config struct { 16 | // 是否在生产环境(生产环境不返回详细错误信息) 17 | IsProduction bool 18 | 19 | // 是否记录错误日志 20 | LogErrors bool 21 | 22 | // 自定义错误处理函数,按错误类型映射 23 | CustomHandlers map[errs.ErrorType]ErrorHandlerFunc 24 | 25 | // 全局错误处理函数,处理所有错误 26 | GlobalHandler ErrorHandlerFunc 27 | } 28 | 29 | // DefaultConfig 返回默认配置 30 | func DefaultConfig() Config { 31 | return Config{ 32 | IsProduction: false, 33 | LogErrors: true, 34 | CustomHandlers: 
make(map[errs.ErrorType]ErrorHandlerFunc), 35 | } 36 | } 37 | 38 | // Recovery 创建一个错误处理中间件 39 | func Recovery(config ...Config) mist.Middleware { 40 | // 使用默认配置 41 | cfg := DefaultConfig() 42 | 43 | // 应用自定义配置 44 | if len(config) > 0 { 45 | cfg = config[0] 46 | } 47 | 48 | return func(next mist.HandleFunc) mist.HandleFunc { 49 | return func(ctx *mist.Context) { 50 | defer func() { 51 | if r := recover(); r != nil { 52 | // 处理panic 53 | var err error 54 | switch v := r.(type) { 55 | case error: 56 | err = v 57 | case string: 58 | err = fmt.Errorf(v) 59 | default: 60 | err = fmt.Errorf("%v", r) 61 | } 62 | 63 | // 记录错误日志 64 | if cfg.LogErrors { 65 | mist.Error("Recovery middleware caught panic: %v", err) 66 | } 67 | 68 | // 处理错误 69 | handleError(ctx, err, cfg) 70 | } 71 | }() 72 | 73 | // 继续处理请求 74 | next(ctx) 75 | 76 | // 检查是否有错误状态码 77 | if ctx.RespStatusCode >= 400 { 78 | var err error 79 | if ctx.RespData != nil && len(ctx.RespData) > 0 { 80 | err = fmt.Errorf(string(ctx.RespData)) 81 | } else { 82 | err = fmt.Errorf("HTTP error %d", ctx.RespStatusCode) 83 | } 84 | 85 | // 处理错误状态码 86 | handleError(ctx, err, cfg) 87 | } 88 | } 89 | } 90 | } 91 | 92 | // ErrorHandler 创建一个处理特定类型错误的中间件 93 | func TypedErrorHandler(errorType errs.ErrorType, handler ErrorHandlerFunc) mist.Middleware { 94 | return func(next mist.HandleFunc) mist.HandleFunc { 95 | return func(ctx *mist.Context) { 96 | // 继续处理请求 97 | next(ctx) 98 | 99 | // 如果发生错误并且返回的是APIError 100 | if ctx.RespStatusCode >= 400 && ctx.RespData != nil { 101 | var apiErr *errs.APIError 102 | if len(ctx.RespData) > 0 { 103 | // 尝试解析JSON错误 104 | err := json.Unmarshal(ctx.RespData, &apiErr) 105 | if err == nil && apiErr != nil && apiErr.Type == errorType { 106 | // 调用自定义处理函数 107 | handler(ctx, apiErr) 108 | } 109 | } 110 | } 111 | } 112 | } 113 | } 114 | 115 | // handleError 统一处理错误的内部函数 116 | func handleError(ctx *mist.Context, err error, cfg Config) { 117 | // 转换为API错误 118 | apiErr := errs.WrapError(err) 119 | 120 | // 
在生产环境下隐藏详细错误信息 121 | if cfg.IsProduction { 122 | apiErr.Details = nil 123 | } 124 | 125 | // 检查是否有自定义处理函数 126 | if handler, ok := cfg.CustomHandlers[apiErr.Type]; ok { 127 | handler(ctx, apiErr) 128 | return 129 | } 130 | 131 | // 检查是否有全局处理函数 132 | if cfg.GlobalHandler != nil { 133 | cfg.GlobalHandler(ctx, apiErr) 134 | return 135 | } 136 | 137 | // 默认错误处理 138 | ctx.RespStatusCode = apiErr.Code 139 | ctx.RespData = apiErr.ToJSON() 140 | ctx.Header("Content-Type", "application/json") 141 | } 142 | -------------------------------------------------------------------------------- /middlewares/cache/cache.go: -------------------------------------------------------------------------------- 1 | // Package cache 提供了基于内存的HTTP响应缓存中间件 2 | package cache 3 | 4 | import ( 5 | "bytes" 6 | "net/http" 7 | "sync" 8 | "time" 9 | 10 | "github.com/dormoron/mist" 11 | lru "github.com/hashicorp/golang-lru" 12 | ) 13 | 14 | // ResponseCache 用于缓存HTTP响应数据 15 | type ResponseCache struct { 16 | // 缓存 17 | cache *lru.Cache 18 | // 互斥锁,保护缓存访问 19 | mu sync.RWMutex 20 | // 缓存项TTL 21 | ttl time.Duration 22 | // 缓存大小 23 | size int 24 | } 25 | 26 | // cachedResponse 表示缓存的响应数据 27 | type cachedResponse struct { 28 | // 响应数据 29 | data []byte 30 | // 状态码 31 | statusCode int 32 | // 创建时间 33 | createdAt time.Time 34 | // HTTP头 35 | headers http.Header 36 | } 37 | 38 | // New 创建一个新的ResponseCache实例 39 | // size: 缓存项数量 40 | // ttl: 缓存项过期时间 41 | func New(size int, ttl time.Duration) (*ResponseCache, error) { 42 | cache, err := lru.New(size) 43 | if err != nil { 44 | return nil, err 45 | } 46 | 47 | return &ResponseCache{ 48 | cache: cache, 49 | ttl: ttl, 50 | size: size, 51 | }, nil 52 | } 53 | 54 | // Middleware 创建一个缓存中间件,使用给定的键生成器函数 55 | // keyFunc 函数用于从请求上下文生成缓存键 56 | func (rc *ResponseCache) Middleware(keyFunc func(*mist.Context) string) mist.Middleware { 57 | return func(next mist.HandleFunc) mist.HandleFunc { 58 | return func(ctx *mist.Context) { 59 | // 仅缓存GET请求 60 | if ctx.Request.Method != 
http.MethodGet { 61 | next(ctx) 62 | return 63 | } 64 | 65 | // 生成缓存键 66 | key := keyFunc(ctx) 67 | if key == "" { 68 | next(ctx) 69 | return 70 | } 71 | 72 | // 尝试从缓存获取 73 | rc.mu.RLock() 74 | value, ok := rc.cache.Get(key) 75 | rc.mu.RUnlock() 76 | 77 | if ok { 78 | cachedResp, ok := value.(*cachedResponse) 79 | if ok && !isExpired(cachedResp, rc.ttl) { 80 | // 使用缓存的响应 81 | for k, values := range cachedResp.headers { 82 | for _, v := range values { 83 | ctx.Header(k, v) 84 | } 85 | } 86 | ctx.RespData = cachedResp.data 87 | ctx.RespStatusCode = cachedResp.statusCode 88 | return 89 | } 90 | } 91 | 92 | // 缓存未命中,执行请求处理 93 | next(ctx) 94 | 95 | // 请求处理完毕,将响应添加到缓存 96 | if ctx.RespStatusCode >= 200 && ctx.RespStatusCode < 300 { 97 | rc.mu.Lock() 98 | headers := make(http.Header) 99 | // 获取原始响应头 100 | for k, v := range ctx.ResponseWriter.Header() { 101 | headers[k] = v 102 | } 103 | 104 | rc.cache.Add(key, &cachedResponse{ 105 | data: bytes.Clone(ctx.RespData), 106 | statusCode: ctx.RespStatusCode, 107 | createdAt: time.Now(), 108 | headers: headers, 109 | }) 110 | rc.mu.Unlock() 111 | } 112 | } 113 | } 114 | } 115 | 116 | // 检查缓存项是否过期 117 | func isExpired(resp *cachedResponse, ttl time.Duration) bool { 118 | return time.Since(resp.createdAt) > ttl 119 | } 120 | 121 | // URLKeyGenerator 返回一个基于请求URL的键生成器函数 122 | func URLKeyGenerator() func(*mist.Context) string { 123 | return func(ctx *mist.Context) string { 124 | return ctx.Request.URL.String() 125 | } 126 | } 127 | 128 | // URLAndHeaderKeyGenerator 返回一个基于URL和指定头的键生成器函数 129 | func URLAndHeaderKeyGenerator(headers ...string) func(*mist.Context) string { 130 | return func(ctx *mist.Context) string { 131 | key := ctx.Request.URL.String() 132 | for _, h := range headers { 133 | key += ":" + ctx.Request.Header.Get(h) 134 | } 135 | return key 136 | } 137 | } 138 | 139 | // Clear 清空缓存 140 | func (rc *ResponseCache) Clear() { 141 | rc.mu.Lock() 142 | defer rc.mu.Unlock() 143 | rc.cache.Purge() 144 | } 145 | 
-------------------------------------------------------------------------------- /router_test.go: -------------------------------------------------------------------------------- 1 | package mist 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | ) 7 | 8 | func TestRouter_RegisterRoute(t *testing.T) { 9 | testCases := []struct { 10 | name string 11 | method string 12 | path string 13 | wantPanic bool 14 | }{ 15 | { 16 | name: "正常路由", 17 | method: http.MethodGet, 18 | path: "/user", 19 | wantPanic: false, 20 | }, 21 | { 22 | name: "空字符串-应该触发panic", 23 | method: http.MethodGet, 24 | path: "", 25 | wantPanic: true, 26 | }, 27 | { 28 | name: "不以/开头-应该触发panic", 29 | method: http.MethodGet, 30 | path: "user", 31 | wantPanic: true, 32 | }, 33 | { 34 | name: "以/结尾-应该自动去除尾部斜杠,不应该panic", 35 | method: http.MethodGet, 36 | path: "/user/", 37 | wantPanic: false, 38 | }, 39 | } 40 | 41 | for _, tc := range testCases { 42 | t.Run(tc.name, func(t *testing.T) { 43 | defer func() { 44 | if r := recover(); r != nil { 45 | if !tc.wantPanic { 46 | t.Errorf("注册路由时不应该panic,但是panic了: %v", r) 47 | } 48 | } else { 49 | if tc.wantPanic { 50 | t.Errorf("注册路由时应该panic,但是没有panic") 51 | } 52 | } 53 | }() 54 | 55 | r := initRouter() 56 | r.registerRoute(tc.method, tc.path, func(ctx *Context) {}) 57 | }) 58 | } 59 | } 60 | 61 | func TestRouter_FindRoute(t *testing.T) { 62 | r := initRouter() 63 | 64 | // 注册路由 65 | r.registerRoute(http.MethodGet, "/", func(ctx *Context) {}) 66 | r.registerRoute(http.MethodGet, "/user", func(ctx *Context) {}) 67 | r.registerRoute(http.MethodGet, "/user/:id", func(ctx *Context) {}) 68 | r.registerRoute(http.MethodGet, "/user/:id/profile", func(ctx *Context) {}) 69 | r.registerRoute(http.MethodPost, "/order", func(ctx *Context) {}) 70 | 71 | testCases := []struct { 72 | name string 73 | method string 74 | path string 75 | found bool 76 | params map[string]string 77 | }{ 78 | { 79 | name: "根路径", 80 | method: http.MethodGet, 81 | path: "/", 82 | found: true, 83 | params: 
map[string]string{}, 84 | }, 85 | { 86 | name: "普通路径", 87 | method: http.MethodGet, 88 | path: "/user", 89 | found: true, 90 | params: map[string]string{}, 91 | }, 92 | { 93 | name: "带参数路径", 94 | method: http.MethodGet, 95 | path: "/user/123", 96 | found: true, 97 | params: map[string]string{"id": "123"}, 98 | }, 99 | { 100 | name: "多段带参数路径", 101 | method: http.MethodGet, 102 | path: "/user/123/profile", 103 | found: true, 104 | params: map[string]string{"id": "123"}, 105 | }, 106 | { 107 | name: "不存在的路径", 108 | method: http.MethodGet, 109 | path: "/not-exist", 110 | found: false, 111 | }, 112 | { 113 | name: "不支持的HTTP方法", 114 | method: http.MethodDelete, 115 | path: "/user", 116 | found: false, 117 | }, 118 | } 119 | 120 | for _, tc := range testCases { 121 | t.Run(tc.name, func(t *testing.T) { 122 | info, found := r.findRoute(tc.method, tc.path) 123 | if found != tc.found { 124 | t.Errorf("路由查找结果错误, 期望 %v, 实际 %v", tc.found, found) 125 | return 126 | } 127 | 128 | if !found { 129 | return 130 | } 131 | 132 | // 验证参数 133 | if len(info.pathParams) != len(tc.params) { 134 | t.Errorf("路径参数数量不匹配, 期望 %d, 实际 %d", len(tc.params), len(info.pathParams)) 135 | return 136 | } 137 | 138 | for k, v := range tc.params { 139 | if info.pathParams[k] != v { 140 | t.Errorf("路径参数不匹配, 键 %s, 期望 %s, 实际 %s", k, v, info.pathParams[k]) 141 | } 142 | } 143 | }) 144 | } 145 | } 146 | -------------------------------------------------------------------------------- /security/auth/option.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "github.com/dormoron/mist/security/auth/kit" 5 | "github.com/golang-jwt/jwt/v5" 6 | "time" 7 | ) 8 | 9 | // Options struct defines the configuration options for token management. 10 | type Options struct { 11 | // Expire defines the duration after which the token expires. 12 | Expire time.Duration 13 | 14 | // EncryptionKey is used to encrypt data. 
15 | EncryptionKey string 16 | 17 | // DecryptKey is used to decrypt data. 18 | DecryptKey string 19 | 20 | // Method is the JWT signing method used for token generation. 21 | Method jwt.SigningMethod 22 | 23 | // Issuer is the entity that issues the token. 24 | Issuer string 25 | 26 | // genIDFn is a function that generates an ID, used for token identification. 27 | genIDFn func() string 28 | } 29 | 30 | // InitOptions initializes an Options struct with default or provided values. 31 | // Parameters: 32 | // - expire: The duration token should be valid for ('time.Duration'). 33 | // - encryptionKey: The key used for encryption ('string'). 34 | // - opts: A variadic list of functional options for configuring the Options struct ('kit.Option[Options]'). 35 | // Returns: 36 | // - Options: The initialized Options struct. 37 | func InitOptions(expire time.Duration, encryptionKey string, opts ...kit.Option[Options]) Options { 38 | dOpts := Options{ 39 | Expire: expire, // Set token expiration duration. 40 | EncryptionKey: encryptionKey, // Set the encryption key. 41 | DecryptKey: encryptionKey, // Set the decryption key (same as encryption key by default). 42 | Method: jwt.SigningMethodHS256, // Set default JWT signing method. 43 | genIDFn: func() string { return "" }, // Set default ID generation function. 44 | } 45 | 46 | // Apply additional options provided by the user. 47 | kit.Apply[Options](&dOpts, opts...) 48 | 49 | return dOpts 50 | } 51 | 52 | // WithDecryptKey is a functional option to set a custom decryption key in Options. 53 | // Parameters: 54 | // - decryptKey: The custom decryption key to be used ('string'). 55 | // Returns: 56 | // - kit.Option[Options]: A function that sets the decryption key in Options. 57 | func WithDecryptKey(decryptKey string) kit.Option[Options] { 58 | return func(o *Options) { 59 | o.DecryptKey = decryptKey // Set the custom decryption key. 
60 | } 61 | } 62 | 63 | // WithMethod is a functional option to set a custom JWT signing method in Options. 64 | // Parameters: 65 | // - method: The JWT signing method to be used ('jwt.SigningMethod'). 66 | // Returns: 67 | // - kit.Option[Options]: A function that sets the JWT signing method in Options. 68 | func WithMethod(method jwt.SigningMethod) kit.Option[Options] { 69 | return func(o *Options) { 70 | o.Method = method // Set the custom JWT signing method. 71 | } 72 | } 73 | 74 | // WithIssuer is a functional option to set a custom issuer in Options. 75 | // Parameters: 76 | // - issuer: The custom issuer entity ('string'). 77 | // Returns: 78 | // - kit.Option[Options]: A function that sets the issuer in Options. 79 | func WithIssuer(issuer string) kit.Option[Options] { 80 | return func(o *Options) { 81 | o.Issuer = issuer // Set the custom issuer. 82 | } 83 | } 84 | 85 | // WithGenIDFunc is a functional option to set a custom ID generation function in Options. 86 | // Parameters: 87 | // - fn: The function used to generate IDs ('func() string'). 88 | // Returns: 89 | // - kit.Option[Options]: A function that sets the ID generation function in Options. 90 | func WithGenIDFunc(fn func() string) kit.Option[Options] { 91 | return func(o *Options) { 92 | o.genIDFn = fn // Set the custom ID generation function. 
93 | } 94 | } 95 | -------------------------------------------------------------------------------- /validation/example_test.go: -------------------------------------------------------------------------------- 1 | package validation 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | // 简单验证示例 9 | func TestBasicValidation(t *testing.T) { 10 | v := NewValidator() 11 | 12 | // 验证字符串 13 | v.Required("", "username") // 应该失败 14 | v.Required("user123", "email") // 应该通过 15 | v.Email("invalid-email", "email") // 应该失败 16 | v.Email("user@example.com", "email") // 应该通过 17 | v.MinLength("abc", 5, "password") // 应该失败 18 | v.MaxLength("abcdefghijklmn", 10, "bio") // 应该失败 19 | 20 | // 验证数字 21 | v.Range(15, 18, 60, "age") // 应该失败 22 | v.Range(25, 18, 60, "age") // 应该通过 23 | 24 | // 自定义验证 25 | v.Custom(false, "agreement", "必须同意条款") // 应该失败 26 | 27 | if v.Valid() { 28 | t.Error("期望验证失败,但验证通过了") 29 | } 30 | 31 | // 输出错误信息 32 | fmt.Println("基础验证错误:") 33 | for _, err := range v.Errors { 34 | fmt.Printf("- %s: %s\n", err.Field, err.Message) 35 | } 36 | } 37 | 38 | // 用于测试的用户结构体 39 | type User struct { 40 | Username string `validate:"required,min=3,max=20,alphanum"` 41 | Email string `validate:"required,email"` 42 | Password string `validate:"required,min=8"` 43 | Age int `validate:"min=18,max=120"` 44 | Phone string `validate:"phone"` 45 | Website string `validate:"url"` 46 | Tags []string `validate:"min=1,max=5"` 47 | IsActive bool `validate:"required"` 48 | Balance float64 `validate:"min=0"` 49 | } 50 | 51 | // 结构体验证示例 52 | func TestStructValidation(t *testing.T) { 53 | // 创建一个无效的用户 54 | invalidUser := User{ 55 | Username: "u", // 太短 56 | Email: "invalid-email", // 无效邮箱 57 | Password: "short", // 太短 58 | Age: 16, // 太小 59 | Phone: "12345", // 无效手机号 60 | Website: "invalid-url", // 无效URL 61 | Tags: []string{}, // 空数组 62 | IsActive: false, // 需要为true 63 | Balance: -100, // 负数余额 64 | } 65 | 66 | v := NewValidator() 67 | ValidateStruct(v, &invalidUser) 68 | 69 | if v.Valid() { 70 | 
t.Error("期望结构体验证失败,但验证通过了") 71 | } 72 | 73 | // 输出错误信息 74 | fmt.Println("\n结构体验证错误:") 75 | for _, err := range v.Errors { 76 | fmt.Printf("- %s: %s\n", err.Field, err.Message) 77 | } 78 | 79 | // 创建一个有效的用户 80 | validUser := User{ 81 | Username: "user123", 82 | Email: "user@example.com", 83 | Password: "securepassword", 84 | Age: 25, 85 | Phone: "13812345678", 86 | Website: "https://example.com", 87 | Tags: []string{"go", "coding"}, 88 | IsActive: true, 89 | Balance: 1000.50, 90 | } 91 | 92 | v = NewValidator() 93 | ValidateStruct(v, &validUser) 94 | 95 | if !v.Valid() { 96 | t.Error("期望结构体验证通过,但验证失败了") 97 | for _, err := range v.Errors { 98 | fmt.Printf("- %s: %s\n", err.Field, err.Message) 99 | } 100 | } 101 | } 102 | 103 | // 示例:如何在HTTP处理程序中使用 104 | func ExampleValidator_http() { 105 | // 这是一个示例HTTP处理程序 106 | /* 107 | func RegisterHandler(w http.ResponseWriter, r *http.Request) { 108 | var user User 109 | 110 | // 解析请求JSON数据到结构体 111 | if err := json.NewDecoder(r.Body).Decode(&user); err != nil { 112 | http.Error(w, "无效的请求数据", http.StatusBadRequest) 113 | return 114 | } 115 | 116 | // 验证用户数据 117 | v := validation.NewValidator() 118 | validation.ValidateStruct(v, &user) 119 | 120 | if !v.Valid() { 121 | // 返回验证错误 122 | w.Header().Set("Content-Type", "application/json") 123 | w.WriteHeader(http.StatusBadRequest) 124 | 125 | // 构建错误响应 126 | errorResponse := make(map[string]string) 127 | for _, err := range v.Errors { 128 | errorResponse[err.Field] = err.Message 129 | } 130 | 131 | json.NewEncoder(w).Encode(map[string]any{ 132 | "errors": errorResponse, 133 | }) 134 | return 135 | } 136 | 137 | // 继续处理有效请求... 
138 | } 139 | */ 140 | } 141 | -------------------------------------------------------------------------------- /security/auth/middleware.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "net/http" 5 | "time" 6 | 7 | "github.com/dormoron/mist" 8 | "github.com/dormoron/mist/security/auth/kit" 9 | "github.com/golang-jwt/jwt/v5" 10 | ) 11 | 12 | // MiddlewareBuilder is a generic struct for constructing middleware for type T. 13 | type MiddlewareBuilder[T any] struct { 14 | // ignorePath is a function that determines if a given path should be ignored by middleware. 15 | ignorePath func(path string) bool 16 | 17 | // manager is a pointer to an instance of Management which handles token management. 18 | manager *Management[T] 19 | 20 | // nowFunc is a function that returns the current time, used for token validation. 21 | nowFunc func() time.Time 22 | } 23 | 24 | // initMiddlewareBuilder initializes a new MiddlewareBuilder instance with the provided Management instance. 25 | // Parameters: 26 | // - m: A pointer to the Management instance to use ('*Management[T]'). 27 | // Returns: 28 | // - *MiddlewareBuilder[T]: A pointer to the initialized MiddlewareBuilder instance. 29 | func initMiddlewareBuilder[T any](m *Management[T]) *MiddlewareBuilder[T] { 30 | return &MiddlewareBuilder[T]{ 31 | manager: m, 32 | ignorePath: func(path string) bool { 33 | return false // By default, don't ignore any path. 34 | }, 35 | nowFunc: m.nowFunc, // Use the nowFunc from the provided Management instance. 36 | } 37 | } 38 | 39 | // IgnorePath sets the paths that should be ignored by middleware. This method internally calls IgnorePathFunc. 40 | // Parameters: 41 | // - path: Variadic list of paths to ignore ('...string'). 42 | // Returns: 43 | // - *MiddlewareBuilder[T]: The MiddlewareBuilder instance, to allow for method chaining. 
44 | func (m *MiddlewareBuilder[T]) IgnorePath(path ...string) *MiddlewareBuilder[T] { 45 | return m.IgnorePathFunc(staticIgnorePaths(path...)) 46 | } 47 | 48 | // IgnorePathFunc sets a custom function that determines if a given path should be ignored by the middleware. 49 | // Parameters: 50 | // - fn: Function that determines if a path should be ignored ('func(path string) bool'). 51 | // Returns: 52 | // - *MiddlewareBuilder[T]: The MiddlewareBuilder instance, to allow for method chaining. 53 | func (m *MiddlewareBuilder[T]) IgnorePathFunc(fn func(path string) bool) *MiddlewareBuilder[T] { 54 | m.ignorePath = fn 55 | return m 56 | } 57 | 58 | // Build constructs the middleware using the settings configured in the MiddlewareBuilder instance. 59 | // Returns: 60 | // - mist.Middleware: A middleware function that processes the request. 61 | func (m *MiddlewareBuilder[T]) Build() mist.Middleware { 62 | return func(next mist.HandleFunc) mist.HandleFunc { 63 | return func(ctx *mist.Context) { 64 | // Check if the request path should be ignored. 65 | if m.ignorePath(ctx.Request.URL.Path) { 66 | next(ctx) 67 | return 68 | } 69 | 70 | // Extract the token from the request. 71 | tokenStr := m.manager.extractTokenString(ctx) 72 | if tokenStr == "" { 73 | ctx.AbortWithStatus(http.StatusUnauthorized) 74 | return 75 | } 76 | 77 | // Verify the access token. 78 | clm, err := m.manager.VerifyAccessToken(tokenStr, jwt.WithTimeFunc(m.nowFunc)) 79 | if err != nil { 80 | ctx.AbortWithStatus(http.StatusUnauthorized) 81 | return 82 | } 83 | 84 | // Set the claims in the context. 85 | m.manager.SetClaims(ctx, clm) 86 | next(ctx) 87 | } 88 | } 89 | } 90 | 91 | // staticIgnorePaths creates a function that checks if a given path is in the list of paths to ignore. 92 | // Parameters: 93 | // - paths: Variadic list of paths to ignore ('...string'). 94 | // Returns: 95 | // - func(path string) bool: A function that returns true if the path is in the ignore list. 
96 | func staticIgnorePaths(paths ...string) func(path string) bool { 97 | s := kit.NewMapSet[string](len(paths)) 98 | for _, path := range paths { 99 | s.Add(path) // Add each path to the set. 100 | } 101 | return func(path string) bool { 102 | return s.Exist(path) // Check if the path exists in the set. 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /session/tests/cookie_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/dormoron/mist/session/cookie" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | // TestCookiePropagator 测试Cookie会话传播器的基本功能 14 | func TestCookiePropagator(t *testing.T) { 15 | // 创建Cookie传播器 16 | propagator := cookie.NewPropagator("test-session", 17 | cookie.WithMaxAge(3600), 18 | cookie.WithSecure(true), 19 | cookie.WithHTTPOnly(true), 20 | ) 21 | require.NotNil(t, propagator) 22 | 23 | // 创建响应写入器记录者 24 | w := httptest.NewRecorder() 25 | 26 | // 测试注入会话ID 27 | sessionID := "test-session-id-12345" 28 | err := propagator.Inject(sessionID, w) 29 | require.NoError(t, err) 30 | 31 | // 验证Cookie已设置 32 | resp := w.Result() 33 | cookies := resp.Cookies() 34 | require.NotEmpty(t, cookies) 35 | 36 | var sessionCookie *http.Cookie 37 | for _, c := range cookies { 38 | if c.Name == "test-session" { 39 | sessionCookie = c 40 | break 41 | } 42 | } 43 | require.NotNil(t, sessionCookie) 44 | assert.Equal(t, sessionID, sessionCookie.Value) 45 | assert.Equal(t, 3600, sessionCookie.MaxAge) 46 | assert.True(t, sessionCookie.Secure) 47 | assert.True(t, sessionCookie.HttpOnly) 48 | 49 | // 创建请求用于提取会话ID 50 | req := &http.Request{ 51 | Header: http.Header{"Cookie": []string{sessionCookie.String()}}, 52 | } 53 | 54 | // 测试提取会话ID 55 | extractedID, err := propagator.Extract(req) 56 | require.NoError(t, err) 57 | 
assert.Equal(t, sessionID, extractedID) 58 | 59 | // 测试移除会话Cookie 60 | w = httptest.NewRecorder() 61 | err = propagator.Remove(w) 62 | require.NoError(t, err) 63 | 64 | // 验证Cookie已被设置为过期 65 | resp = w.Result() 66 | cookies = resp.Cookies() 67 | require.NotEmpty(t, cookies) 68 | 69 | var removedCookie *http.Cookie 70 | for _, c := range cookies { 71 | if c.Name == "test-session" { 72 | removedCookie = c 73 | break 74 | } 75 | } 76 | require.NotNil(t, removedCookie) 77 | assert.Equal(t, "", removedCookie.Value) 78 | assert.Equal(t, -1, removedCookie.MaxAge) 79 | } 80 | 81 | // TestCookiePropagatorOptions 测试Cookie传播器的配置选项 82 | func TestCookiePropagatorOptions(t *testing.T) { 83 | // 创建带有自定义选项的Cookie传播器 84 | propagator := cookie.NewPropagator("custom-session", 85 | cookie.WithPath("/api"), 86 | cookie.WithDomain("example.com"), 87 | cookie.WithMaxAge(7200), 88 | cookie.WithSecure(true), 89 | cookie.WithHTTPOnly(true), 90 | cookie.WithSameSite(http.SameSiteStrictMode), 91 | ) 92 | require.NotNil(t, propagator) 93 | 94 | // 创建响应写入器记录者 95 | w := httptest.NewRecorder() 96 | 97 | // 测试注入会话ID 98 | sessionID := "custom-session-id-67890" 99 | err := propagator.Inject(sessionID, w) 100 | require.NoError(t, err) 101 | 102 | // 验证Cookie已设置并带有自定义选项 103 | resp := w.Result() 104 | cookies := resp.Cookies() 105 | require.NotEmpty(t, cookies) 106 | 107 | var sessionCookie *http.Cookie 108 | for _, c := range cookies { 109 | if c.Name == "custom-session" { 110 | sessionCookie = c 111 | break 112 | } 113 | } 114 | require.NotNil(t, sessionCookie) 115 | assert.Equal(t, sessionID, sessionCookie.Value) 116 | assert.Equal(t, "/api", sessionCookie.Path) 117 | assert.Equal(t, "example.com", sessionCookie.Domain) 118 | assert.Equal(t, 7200, sessionCookie.MaxAge) 119 | assert.True(t, sessionCookie.Secure) 120 | assert.True(t, sessionCookie.HttpOnly) 121 | assert.Equal(t, http.SameSiteStrictMode, sessionCookie.SameSite) 122 | 123 | // 测试动态修改Cookie最大有效期 124 | propagator.SetMaxAge(1800) 125 | 126 
| // 创建新的响应写入器记录者 127 | w = httptest.NewRecorder() 128 | 129 | // 重新注入会话ID 130 | err = propagator.Inject(sessionID, w) 131 | require.NoError(t, err) 132 | 133 | // 验证Cookie的MaxAge已被更新 134 | resp = w.Result() 135 | cookies = resp.Cookies() 136 | require.NotEmpty(t, cookies) 137 | 138 | sessionCookie = nil 139 | for _, c := range cookies { 140 | if c.Name == "custom-session" { 141 | sessionCookie = c 142 | break 143 | } 144 | } 145 | require.NotNil(t, sessionCookie) 146 | assert.Equal(t, 1800, sessionCookie.MaxAge) 147 | } 148 | -------------------------------------------------------------------------------- /security/global.go: -------------------------------------------------------------------------------- 1 | package security 2 | 3 | import ( 4 | "github.com/dormoron/mist" 5 | ) 6 | 7 | // CtxSessionKey is a constant string used as a key for storing session data in the context. 8 | const CtxSessionKey = "_session" 9 | 10 | // defaultProvider is a global variable that holds the default session provider. 11 | var defaultProvider Provider 12 | 13 | // InitSession initializes a new session using the default provider. 14 | // Parameters: 15 | // - ctx: The request context (*mist.Context). 16 | // - uid: User ID for the session (int64). 17 | // - jwtData: JWT-related data to be included in the session (map[string]any). 18 | // - sessData: Additional session data (map[string]any). 19 | // Returns: 20 | // - Session: The newly created session. 21 | // - error: An error if the session creation fails. 22 | func InitSession(ctx *mist.Context, uid int64, jwtData map[string]any, sessData map[string]any) (Session, error) { 23 | // Use the default provider to initialize the session. 24 | return defaultProvider.InitSession(ctx, uid, jwtData, sessData) 25 | } 26 | 27 | // Get retrieves the session associated with the given context using the default provider. 28 | // Parameters: 29 | // - ctx: The request context (*mist.Context). 
30 | // Returns: 31 | // - Session: The session associated with the context. 32 | // - error: An error if the session retrieval fails. 33 | func Get(ctx *mist.Context) (Session, error) { 34 | // Use the default provider to get the session from the context. 35 | return defaultProvider.Get(ctx) 36 | } 37 | 38 | // SetDefaultProvider sets the default session provider. 39 | // Parameters: 40 | // - sp: The session provider to be set as the default (Provider). 41 | func SetDefaultProvider(sp Provider) { 42 | // Assign the provided session provider to the default provider variable. 43 | defaultProvider = sp 44 | } 45 | 46 | // DefaultProvider returns the current default session provider. 47 | // Returns: 48 | // - Provider: The current default session provider. 49 | func DefaultProvider() Provider { 50 | // Return the default session provider. 51 | return defaultProvider 52 | } 53 | 54 | // CheckLoginMiddleware creates a middleware that checks if the user is logged in for specified paths. 55 | // Parameters: 56 | // - paths: A variadic list of URL paths to be checked (string). 57 | // Returns: 58 | // - mist.Middleware: The constructed middleware. 59 | func CheckLoginMiddleware(paths ...string) mist.Middleware { 60 | // Initialize a MiddlewareBuilder with the default provider and specified paths, 61 | // and then build the middleware. 62 | return InitMiddlewareBuilder(defaultProvider, paths...).Build() 63 | } 64 | 65 | // RenewAccessToken renews the access token for the session associated with the given context. 66 | // Parameters: 67 | // - ctx: The request context (*mist.Context). 68 | // Returns: 69 | // - error: An error if the token renewal fails. 70 | func RenewAccessToken(ctx *mist.Context) error { 71 | // Use the default provider to renew the access token. 72 | return defaultProvider.RenewAccessToken(ctx) 73 | } 74 | 75 | // ClearToken is a function that serves as a wrapper to invoke the ClearToken method 76 | // of the defaultProvider. 
It clears the access and refresh tokens for a session by leveraging 77 | // the default session provider. 78 | // 79 | // Parameters: 80 | // - ctx: The mist.Context object representing the current HTTP request and response. 81 | // 82 | // Returns: 83 | // - An error object if the underlying ClearToken method in defaultProvider fails, otherwise it returns nil. 84 | func ClearToken(ctx *mist.Context) error { 85 | // Call the ClearToken method on the defaultProvider, passing in the context of the current HTTP request and response. 86 | // This delegates the task of clearing the session tokens to the defaultProvider. 87 | return defaultProvider.ClearToken(ctx) 88 | } 89 | 90 | // UpdateClaims updates the claims for the session associated with the given context. 91 | // Parameters: 92 | // - ctx: The request context (*mist.Context). 93 | // - claims: The claims to be updated (Claims). 94 | // Returns: 95 | // - error: An error if the claims update fails. 96 | func UpdateClaims(ctx *mist.Context, claims Claims) error { 97 | // Use the default provider to update the claims. 98 | return defaultProvider.UpdateClaims(ctx, claims) 99 | } 100 |
-------------------------------------------------------------------------------- /http3.go: -------------------------------------------------------------------------------- 1 | package mist 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "fmt" 7 | "net/http" 8 | "time" 9 | 10 | "github.com/quic-go/quic-go" 11 | "github.com/quic-go/quic-go/http3" 12 | ) 13 | 14 | // HTTP3Server provides HTTP/3 (QUIC) serving alongside an existing http.Server. 15 | type HTTP3Server struct { 16 | httpServer *http.Server 17 | quicServer *http3.Server 18 | quicListener *quic.EarlyListener 19 | log Logger 20 | config *HTTP3Config 21 | } 22 | 23 | // HTTP3Config defines the configuration options of the HTTP/3 server. 24 | type HTTP3Config struct { 25 | // QUIC settings 26 | MaxIdleTimeout time.Duration 27 | MaxIncomingStreams int64 28 | MaxIncomingUniStreams int64 29 | 30 | // TLS settings; ListenAndServe falls back to this when the http.Server has no TLSConfig. 31 | TLSConfig *tls.Config 32 | 33 | // Connection settings 34 | EnableDatagrams bool 35 | HandshakeIdleTimeout time.Duration 36 | 37 | // Alt-Svc settings 38 | AltSvcHeader string 39 | EnableAltSvcHeader bool 40 | } 41 | 42 | // DefaultHTTP3Config returns the default HTTP/3 configuration. 43 | func DefaultHTTP3Config() *HTTP3Config { 44 | return &HTTP3Config{ 45 | MaxIdleTimeout: 30 * time.Second, 46 | MaxIncomingStreams: 100, 47 | MaxIncomingUniStreams: 100, 48 | EnableDatagrams: false, 49 | HandshakeIdleTimeout: 10 * time.Second, 50 | EnableAltSvcHeader: true, 51 | AltSvcHeader: `h3=":443"; ma=2592000`, 52 | } 53 | } 54 | 55 | // NewHTTP3Server creates an HTTP/3 server that wraps httpServer. 56 | // A nil config falls back to DefaultHTTP3Config and a nil logger to GetDefaultLogger. 57 | func NewHTTP3Server(httpServer *http.Server, config *HTTP3Config, logger Logger) *HTTP3Server { 58 | if config == nil { 59 | config = DefaultHTTP3Config() 60 | } 61 | 62 | if logger == nil { 63 | // Use the default logger. 64 | logger = GetDefaultLogger() 65 | } 66 | 67 | return &HTTP3Server{ 68 | httpServer: httpServer, 69 | log: logger, 70 | config: config, 71 | } 72 | } 73 | 74 | // ListenAndServe starts the HTTP/3 server on the given UDP address and blocks until 75 | // the QUIC server stops, returning the error from ServeListener. 76 | // 77 | // TLS settings come from the wrapped http.Server; when that is nil, HTTP3Config.TLSConfig 78 | // is used instead. "h3" is appended to NextProtos only when it is missing, so repeated 79 | // calls (or ListenAndServeTLS, which already sets it) do not produce duplicate ALPN entries. 80 | // NOTE(review): the TLS config is mutated in place, so a TCP server sharing it will also 81 | // advertise "h3" via ALPN — confirm this is intended. 82 | func (s *HTTP3Server) ListenAndServe(addr string) error { 83 | s.log.Info("HTTP/3服务器准备启动于: %s", addr) 84 | 85 | // QUIC requires TLS; fall back to the TLS config from HTTP3Config when needed. 86 | if s.httpServer.TLSConfig == nil { 87 | s.httpServer.TLSConfig = s.config.TLSConfig 88 | } 89 | if s.httpServer.TLSConfig == nil { 90 | return fmt.Errorf("HTTP/3需要TLS配置") 91 | } 92 | 93 | // Advertise the "h3" ALPN protocol exactly once. 94 | hasH3 := false 95 | for _, proto := range s.httpServer.TLSConfig.NextProtos { 96 | if proto == "h3" { 97 | hasH3 = true 98 | break 99 | } 100 | } 101 | if !hasH3 { 102 | s.httpServer.TLSConfig.NextProtos = append(s.httpServer.TLSConfig.NextProtos, "h3") 103 | } 104 | 105 | // Build the QUIC transport configuration. 106 | quicConfig := &quic.Config{ 107 | MaxIdleTimeout: s.config.MaxIdleTimeout, 108 | MaxIncomingStreams: s.config.MaxIncomingStreams, 109 | MaxIncomingUniStreams: s.config.MaxIncomingUniStreams, 110 | EnableDatagrams: s.config.EnableDatagrams, 111 | HandshakeIdleTimeout: s.config.HandshakeIdleTimeout, 112 | } 113 | 114 | // Create the HTTP/3 server. It captures the handler before the Alt-Svc wrapper 115 | // below is installed, so only responses served by the http.Server get that header. 116 | s.quicServer = &http3.Server{ 117 | Handler: s.httpServer.Handler, 118 | TLSConfig: s.httpServer.TLSConfig, 119 | QuicConfig: quicConfig, 120 | } 121 | 122 | // Advertise HTTP/3 to clients of the http.Server via the Alt-Svc response header. 123 | if s.config.EnableAltSvcHeader { 124 | originalHandler := s.httpServer.Handler 125 | if originalHandler == nil { 126 | // net/http treats a nil Handler as http.DefaultServeMux; calling 127 | // ServeHTTP on the nil value would panic on the first request. 128 | originalHandler = http.DefaultServeMux 129 | } 130 | s.httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 131 | w.Header().Add("Alt-Svc", s.config.AltSvcHeader) 132 | originalHandler.ServeHTTP(w, r) 133 | }) 134 | } 135 | 136 | // Create the QUIC listener. 137 | listener, err := quic.ListenAddrEarly(addr, s.httpServer.TLSConfig, quicConfig) 138 | if err != nil { 139 | return fmt.Errorf("启动QUIC监听器失败: %w", err) 140 | } 141 | 142 | s.quicListener = listener 143 | s.log.Info("HTTP/3服务器已启动于: %s", addr) 144 | 145 | return s.quicServer.ServeListener(listener) 146 | } 147 | 148 | // ListenAndServeTLS loads the certificate/key pair, installs a new TLS config that 149 | // advertises "h3" on the wrapped http.Server (replacing any existing one), and then 150 | // calls ListenAndServe. 151 | func (s *HTTP3Server) ListenAndServeTLS(addr, certFile, keyFile string) error { 152 | // Load the TLS certificate. 153 | cert, err := tls.LoadX509KeyPair(certFile, keyFile) 154 | if err != nil { 155 | return fmt.Errorf("加载TLS证书失败: %w", err) 156 | } 157 | 158 | // Build the TLS configuration. 159 | tlsConfig := &tls.Config{ 160 | Certificates: []tls.Certificate{cert}, 161 | NextProtos: []string{"h3"}, 162 | } 163 | 164 | // Install it on the wrapped http.Server. 165 | s.httpServer.TLSConfig = tlsConfig 166 | 167 | return s.ListenAndServe(addr) 168 | } 169 | 170 | // Shutdown closes the QUIC listener and then the HTTP/3 server. 171 | // 172 | // If ctx has a deadline, the time remaining until it is passed to CloseGracefully as 173 | // its timeout; otherwise (or once the deadline has passed) a zero timeout is used. 174 | // Both steps are always attempted and the first error encountered is returned, so a 175 | // listener close failure is not overwritten by the server's result. 176 | func (s *HTTP3Server) Shutdown(ctx context.Context) error { 177 | var firstErr error 178 | 179 | // Close the QUIC listener. 180 | if s.quicListener != nil { 181 | if err := s.quicListener.Close(); err != nil { 182 | s.log.Error("关闭QUIC监听器失败: %v", err) 183 | firstErr = err 184 | } 185 | } 186 | 187 | // Shut down the HTTP/3 server, bounded by ctx's deadline when it has one. 188 | if s.quicServer != nil { 189 | var timeout time.Duration 190 | if deadline, ok := ctx.Deadline(); ok { 191 | if remaining := time.Until(deadline); remaining > 0 { 192 | timeout = remaining 193 | } 194 | } 195 | if err := s.quicServer.CloseGracefully(timeout); err != nil && firstErr == nil { 196 | firstErr = err 197 | } 198 | } 199 | 200 | return firstErr 201 | } 202 |
-------------------------------------------------------------------------------- /middlewares/prometheus/middleware.go: -------------------------------------------------------------------------------- 1 | package prometheus 2 | 3 | import ( 4 | "strconv" 5 | "time" 6 | 7 | "github.com/dormoron/mist" 8 | "github.com/prometheus/client_golang/prometheus" 9 | ) 10 | 11 | // MiddlewareBuilder is a struct that holds metadata for identifying and describing a set 12 | // of middleware. This metadata includes details that are typically used for logging, monitoring, 13 | // or other forms of introspection. The struct is not a middleware itself, but rather a 14 | // collection of descriptive fields that may be associated with middleware operations. 15 | type MiddlewareBuilder struct { 16 | Namespace string // Namespace is a top-level categorization that groups related subsystems. It's meant to prevent collisions between different subsystems. 17 | Subsystem string // Subsystem is a second-level categorization beneath namespace that allows for grouping related functionalities. 18 | Name string // Name is the individual identifier for a specific middleware component. It should be unique within the namespace and subsystem. 19 | Help string // Help is a descriptive string that provides insights into what the middleware does or is used for. It may be exposed in monitoring tools or documentation. 20 | } 21 | 22 | // InitMiddlewareBuilder returns a MiddlewareBuilder populated with the given metric 23 | // namespace, subsystem, name and help text; call Build on it to obtain the middleware. 24 | func InitMiddlewareBuilder(namespace string, subsystem string, name string, help string) *MiddlewareBuilder { 25 | return &MiddlewareBuilder{ 26 | Namespace: namespace, 27 | Subsystem: subsystem, 28 | Name: name, 29 | Help: help, 30 | } 31 | } 32 | 33 | // Build constructs and returns a new prometheus monitoring middleware for use within the mist framework. 34 | // The MiddlewareBuilder receiver attaches the newly built prometheus metric collection functionality to mist Middleware. 35 | // The metrics collected are specifically SummaryVec which help in observing the request latency distribution. 36 | // 37 | // The collector is registered with the default Prometheus registerer. If an equal collector 38 | // is already registered (for example when Build is called more than once), the existing one 39 | // is reused instead of panicking; any other registration error panics, matching 40 | // prometheus.MustRegister. Request durations are observed in microseconds. 41 | func (m *MiddlewareBuilder) Build() mist.Middleware { 42 | // A SummaryVec is created with the necessary prometheus options including the provided namespace, subsystem, 43 | // and the name from the MiddlewareBuilder fields. Additionally, the helper message and specific objectives for quantiles are set. 44 | vector := prometheus.NewSummaryVec(prometheus.SummaryOpts{ 45 | Namespace: m.Namespace, // Uses the namespace provided in the MiddlewareBuilder 46 | Subsystem: m.Subsystem, // Uses the subsystem provided in the MiddlewareBuilder 47 | Name: m.Name, // Uses the name provided in the MiddlewareBuilder 48 | Help: m.Help, // Uses the help message provided in the MiddlewareBuilder 49 | // Objectives is a map defining the quantile rank and the allowable error. 50 | // This allows us to calculate, e.g., the 50th percentile (median) with 1% error. 51 | Objectives: map[float64]float64{ 52 | 0.5: 0.01, // 50th percentile (median) 53 | 0.75: 0.01, // 75th percentile 54 | 0.90: 0.01, // 90th percentile 55 | 0.99: 0.001, // 99th percentile, with a smaller error of 0.1% 56 | 0.999: 0.0001, // 99.9th percentile, with an even smaller error 57 | }, 58 | // Labels are predefined which we will later assign values for each observation. 59 | }, []string{"pattern", "method", "status"}) 60 | 61 | // Register the SummaryVec with the default registerer. Reuse an equal, already-registered 62 | // collector so that building the middleware twice does not panic. 63 | if err := prometheus.Register(vector); err != nil { 64 | are, ok := err.(prometheus.AlreadyRegisteredError) 65 | if !ok { 66 | panic(err) 67 | } 68 | existing, ok := are.ExistingCollector.(*prometheus.SummaryVec) 69 | if !ok { 70 | panic(err) 71 | } 72 | vector = existing 73 | } 74 | 75 | // Return a new middleware function which will be called during request processing in the mist framework 76 | return func(next mist.HandleFunc) mist.HandleFunc { 77 | return func(ctx *mist.Context) { 78 | startTime := time.Now() // Record the start time when the request processing begins 79 | // Defer a function to ensure it's executed after the main middleware logic. 80 | // It measures the time taken to process the request and records it as an observation in the SummaryVec. 81 | defer func() { 82 | // Calculate the elapsed time since the start time in microseconds 83 | duration := time.Since(startTime).Microseconds() 84 | 85 | // Retrieve the matched route pattern from the context, use "unknown" as a default 86 | pattern := ctx.MatchedRoute 87 | if pattern == "" { 88 | pattern = "unknown" 89 | } 90 | 91 | // Use the pattern, HTTP method, and status code as labels to observe the request duration 92 | // Float64 conversion is necessary to record the duration in the correct format 93 | vector.WithLabelValues(pattern, ctx.Request.Method, strconv.Itoa(ctx.RespStatusCode)).Observe(float64(duration)) 94 | }() 95 | 96 | // Proceed with the next middleware or final handler in the chain 97 | next(ctx) 98 | } 99 | } 100 | } 101 |
-------------------------------------------------------------------------------- /session/tests/memory_test.go: -------------------------------------------------------------------------------- 1 | package tests 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/dormoron/mist/session/memory" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | // TestMemoryStore exercises the basic operations of the in-memory session store. 16 | func TestMemoryStore(t *testing.T) { 17 | // Create the in-memory store. 18 | store, err := memory.NewStore() 19 | require.NoError(t, err) 20 | require.NotNil(t, store) 21 | 22 | // Test context. 23 | ctx := context.Background() 24 | 25 | // Generate a new session. 26 | sessionID := "test-memory-session-1" 27 | sess, err := store.Generate(ctx, sessionID) 28 | require.NoError(t, err) 29 | require.NotNil(t, sess) 30 | assert.Equal(t, sessionID, sess.ID()) 31 | 32 | // Setting a value marks the session as modified. 33 | err = sess.Set(ctx, "username", "testuser") 34 | require.NoError(t, err) 35 | assert.True(t, sess.IsModified()) 36 | 37 | // Read the value back. 38 | val, err := sess.Get(ctx, "username") 39 | require.NoError(t, err) 40 | assert.Equal(t, "testuser", val) 41 | 42 | // Saving clears the modified flag. 43 | err = sess.Save() 44 | require.NoError(t, err) 45 | assert.False(t, sess.IsModified()) 46 | 47 | // Refresh the session. 48 | err = store.Refresh(ctx, sessionID) 49 | require.NoError(t, err) 50 | 51 | // Fetch the existing session from the store. 52 | retrievedSess, err := store.Get(ctx, sessionID) 53 | require.NoError(t, err) 54 | require.NotNil(t, retrievedSess) 55 | assert.Equal(t, sessionID, retrievedSess.ID()) 56 | 57 | // The fetched session sees the stored value. 58 | val, err = retrievedSess.Get(ctx, "username") 59 | require.NoError(t, err) 60 | assert.Equal(t, "testuser", val) 61 | 62 | // Deleting a key marks the session as modified. 63 | err = retrievedSess.Delete(ctx, "username") 64 | require.NoError(t, err) 65 | assert.True(t, retrievedSess.IsModified()) 66 | 67 | // A deleted key reads back as nil without an error. 68 | val, err = retrievedSess.Get(ctx, "username") 69 | require.NoError(t, err) 70 | assert.Nil(t, val) 71 | 72 | // Set a maximum age and save. 73 | retrievedSess.SetMaxAge(3600) // 1 hour 74 | err = retrievedSess.Save() 75 | require.NoError(t, err) 76 | assert.False(t, retrievedSess.IsModified()) 77 | 78 | // Remove the session. 79 | err = store.Remove(ctx, sessionID) 80 | require.NoError(t, err) 81 | 82 | // A removed session can no longer be fetched. 83 | _, err = store.Get(ctx, sessionID) 84 | assert.Error(t, err) 85 | 86 | // Garbage collection succeeds. 87 | err = store.GC(ctx) 88 | require.NoError(t, err) 89 | } 90 | 91 | // TestMemoryStoreExpiration verifies that sessions expire in the in-memory store. 92 | func TestMemoryStoreExpiration(t *testing.T) { 93 | // Create a store with a short expiration time. 94 | store := memory.InitStore(100 * time.Millisecond) 95 | require.NotNil(t, store) 96 | 97 | // Test context. 98 | ctx := context.Background() 99 | 100 | // Generate a new session. 101 | sessionID := "test-expiring-session" 102 | sess, err := store.Generate(ctx, sessionID) 103 | require.NoError(t, err) 104 | require.NotNil(t, sess) 105 | 106 | // Set a session value. 107 | err = sess.Set(ctx, "temp", "value") 108 | require.NoError(t, err) 109 | 110 | // Wait for the session to expire. 111 | time.Sleep(200 * time.Millisecond) 112 | 113 | // The expired session can no longer be fetched. 114 | _, err = store.Get(ctx, sessionID) 115 | assert.Error(t, err) 116 | 117 | // Run garbage collection so expired entries are released. 118 | err = store.GC(ctx) 119 | require.NoError(t, err) 120 | } 121 | 122 | // TestConcurrentAccess exercises concurrent reads and writes on a single in-memory session. 123 | func TestConcurrentAccess(t *testing.T) { 124 | store, err := memory.NewStore() 125 | require.NoError(t, err) 126 | require.NotNil(t, store) 127 | 128 | // Test context. 129 | ctx := context.Background() 130 | 131 | // Generate a new session. 132 | sessionID := "test-concurrent-session" 133 | sess, err := store.Generate(ctx, sessionID) 134 | require.NoError(t, err) 135 | require.NotNil(t, sess) 136 | 137 | // Concurrent read/write parameters. 138 | const goroutines = 10 139 | const operations = 100 140 | 141 | var wg sync.WaitGroup 142 | wg.Add(goroutines * 2) // writer goroutines plus reader goroutines 143 | 144 | // Concurrent writers. 145 | for i := 0; i < goroutines; i++ { 146 | go func(routineID int) { 147 | defer wg.Done() 148 | for j := 0; j < operations; j++ { 149 | key := fmt.Sprintf("key-%d-%d", routineID, j) 150 | err := sess.Set(ctx, key, j) 151 | assert.NoError(t, err) 152 | } 153 | }(i) 154 | } 155 | 156 | // Concurrent readers. 157 | for i := 0; i < goroutines; i++ { 158 | go func(routineID int) { 159 | defer wg.Done() 160 | for j := 0; j < operations; j++ { 161 | key := fmt.Sprintf("key-%d-%d", routineID, j) 162 | // Neither the value nor the error is asserted: the matching writer may not 163 | // have run yet, so a missing key is an expected outcome here. 164 | _, _ = sess.Get(ctx, key) 165 | } 166 | }(i) 167 | } 168 | 169 | wg.Wait() 170 | 171 | // Save the session. 172 | err = sess.Save() 173 | require.NoError(t, err) 174 | 175 | // Clean up. 176 | err = store.Remove(ctx, sessionID) 177 | require.NoError(t, err) 178 | } 179 |
-------------------------------------------------------------------------------- /middlewares/opentelemetry/middleware.go: -------------------------------------------------------------------------------- 1 | package opentelemetry 2 | 3 | import ( 4 | "github.com/dormoron/mist" 5 | "go.opentelemetry.io/otel" 6 | "go.opentelemetry.io/otel/attribute" 7 | "go.opentelemetry.io/otel/propagation" 8 | "go.opentelemetry.io/otel/trace" 9 | ) 10 | 11 | const instrumentationName = "github.com/dormoron/mist/middleware/opentelemetry" 12 | 13 | // MiddlewareBuilder is a struct that aids in constructing middleware with tracing capabilities. 14 | // It holds a reference to a Tracer instance which will be used to trace the flow of HTTP requests. 15 | type MiddlewareBuilder struct { 16 | Tracer trace.Tracer // Tracer is an interface that abstracts the tracing functionality. 17 | // This tracer will be used to create spans for the structured monitoring of 18 | // application's request flows and performance. 19 | } 20 | 21 | // InitMiddlewareBuilder initializes and returns a new instance of MiddlewareBuilder. 22 | // It sets the Tracer field to the global tracer provided by OpenTelemetry with the specified instrumentation name. 23 | // The instrumentationName constant is declared at the top of this file. 24 | // Returns: 25 | // - a pointer to the newly created MiddlewareBuilder instance. 26 | func InitMiddlewareBuilder() *MiddlewareBuilder { 27 | return &MiddlewareBuilder{ 28 | Tracer: otel.GetTracerProvider().Tracer(instrumentationName), // Obtain the global tracer using the instrumentation name. 29 | } 30 | } 31 | 32 | // SetTracer updates the Tracer field of the MiddlewareBuilder with the provided tracer. 33 | // This allows for setting a specific tracer for telemetry data, which could be part of distributed tracing.
34 | // Parameters: 35 | // - tracer: the trace.Tracer instance to set as the MiddlewareBuilder's Tracer. 36 | // Returns: 37 | // - the pointer to the MiddlewareBuilder instance to allow method chaining. 38 | func (m *MiddlewareBuilder) SetTracer(tracer trace.Tracer) *MiddlewareBuilder { 39 | m.Tracer = tracer // Set the provided tracer to the Tracer field. 40 | return m // Return the MiddlewareBuilder instance for chaining. 41 | } 42 | 43 | // Build is a method attached to the MiddlewareBuilder struct. This method initializes 44 | // and returns a Tracing middleware that can be used in the mist HTTP framework. 45 | // This middleware is responsible for starting a new span for each incoming HTTP request, 46 | // sets various attributes related to the request and ensures that the span is ended 47 | // properly after the request is handled. 48 | func (m *MiddlewareBuilder) Build() mist.Middleware { 49 | 50 | // Return an anonymous function matching the mist middleware signature. 51 | return func(next mist.HandleFunc) mist.HandleFunc { 52 | // This anonymous function is the actual middleware being executed per request. 53 | return func(ctx *mist.Context) { 54 | // Extract the current request context from the incoming HTTP request. 55 | reqCtx := ctx.Request.Context() 56 | 57 | // Inject distributed tracing headers into the request context. 58 | reqCtx = otel.GetTextMapPropagator().Extract(reqCtx, propagation.HeaderCarrier(ctx.Request.Header)) 59 | 60 | // Start a new span with the request context, using the name "unknown" as a placeholder 61 | // until the actual route is matched. 62 | _, span := m.Tracer.Start(reqCtx, "unknown") 63 | 64 | // Defer the end of the span till after the request is handled. 65 | // This ensures the following code runs after the next handlers are completed, 66 | // right before exiting the middleware function. 67 | defer func() { 68 | // If the route was matched, name the span after the matched route. 
69 | span.SetName(ctx.MatchedRoute) 70 | 71 | // Set additional attributes to the span, such as the HTTP status code. 72 | span.SetAttributes(attribute.Int("http.status", ctx.RespStatusCode)) 73 | 74 | // End the span. This records the span's information and exports it to any configured telemetry systems. 75 | span.End() 76 | }() 77 | 78 | // Before proceeding, add additional HTTP-related information to the span, 79 | // such as the HTTP method, full URL, URL scheme, and host. 80 | span.SetAttributes(attribute.String("http.method", ctx.Request.Method)) 81 | span.SetAttributes(attribute.String("http.url", ctx.Request.URL.String())) 82 | span.SetAttributes(attribute.String("http.scheme", ctx.Request.URL.Scheme)) 83 | span.SetAttributes(attribute.String("http.host", ctx.Request.Host)) 84 | 85 | // Update the request's context to include tracing context. 86 | ctx.Request = ctx.Request.WithContext(reqCtx) 87 | 88 | // Call the next function in the middleware chain with the updated context. 89 | next(ctx) 90 | } 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /session/README.md: -------------------------------------------------------------------------------- 1 | # Mist 会话管理模块 2 | 3 | Mist框架会话管理模块提供了一个灵活、安全、高性能的会话管理解决方案,支持多种存储后端和传播机制。 4 | 5 | ## 主要功能 6 | 7 | - 支持内存和Redis存储后端 8 | - 基于Cookie的会话传播 9 | - 自动垃圾回收机制 10 | - 线程安全的会话操作 11 | - 会话过期管理 12 | - 完整的测试覆盖 13 | 14 | ## 快速开始 15 | 16 | ### 创建内存会话存储 17 | 18 | ```go 19 | import ( 20 | "github.com/dormoron/mist/session" 21 | "github.com/dormoron/mist/session/memory" 22 | ) 23 | 24 | // 创建内存存储 25 | store, err := memory.NewStore() 26 | if err != nil { 27 | // 处理错误 28 | } 29 | 30 | // 创建会话管理器,会话有效期30分钟 31 | manager, err := session.NewManager(store, 1800) 32 | if err != nil { 33 | // 处理错误 34 | } 35 | 36 | // 启用自动垃圾回收,每10分钟执行一次 37 | manager.EnableAutoGC(10 * time.Minute) 38 | ``` 39 | 40 | ### 创建Redis会话存储 41 | 42 | ```go 43 | import ( 44 | "github.com/dormoron/mist/session" 45 | 
"github.com/dormoron/mist/session/redis" 46 | redisClient "github.com/redis/go-redis/v9" 47 | ) 48 | 49 | // 创建Redis客户端 50 | client := redisClient.NewClient(&redisClient.Options{ 51 | Addr: "localhost:6379", 52 | Password: "", // 无密码 53 | DB: 0, // 默认DB 54 | }) 55 | 56 | // 创建Redis存储 57 | store := redis.InitStore(client, 58 | redis.StoreWithExpiration(30 * time.Minute), 59 | redis.StoreWithPrefix("mist:sess:"), 60 | ) 61 | 62 | // 创建会话管理器 63 | manager, err := session.NewManager(store, 1800) 64 | if err != nil { 65 | // 处理错误 66 | } 67 | ``` 68 | 69 | ### 在HTTP处理器中使用会话 70 | 71 | ```go 72 | import ( 73 | "github.com/dormoron/mist" 74 | ) 75 | 76 | func HandleLogin(ctx *mist.Context) { 77 | // 创建新会话 78 | sess, err := manager.InitSession(ctx) 79 | if err != nil { 80 | // 处理错误 81 | return 82 | } 83 | 84 | // 存储用户信息 85 | sess.Set(ctx.Request.Context(), "user_id", 12345) 86 | sess.Set(ctx.Request.Context(), "username", "testuser") 87 | sess.Save() 88 | 89 | // 响应登录成功 90 | ctx.JSON(200, map[string]interface{}{ 91 | "message": "登录成功", 92 | }) 93 | } 94 | 95 | func HandleProfile(ctx *mist.Context) { 96 | // 获取现有会话 97 | sess, err := manager.GetSession(ctx) 98 | if err != nil { 99 | // 会话不存在,重定向到登录页面 100 | ctx.Redirect(302, "/login") 101 | return 102 | } 103 | 104 | // 从会话中获取用户信息 105 | userID, err := sess.Get(ctx.Request.Context(), "user_id") 106 | if err != nil { 107 | // 处理错误 108 | return 109 | } 110 | 111 | username, err := sess.Get(ctx.Request.Context(), "username") 112 | if err != nil { 113 | // 处理错误 114 | return 115 | } 116 | 117 | // 刷新会话 118 | manager.RefreshSession(ctx) 119 | 120 | // 响应用户资料 121 | ctx.JSON(200, map[string]interface{}{ 122 | "user_id": userID, 123 | "username": username, 124 | }) 125 | } 126 | 127 | func HandleLogout(ctx *mist.Context) { 128 | // 删除会话 129 | err := manager.RemoveSession(ctx) 130 | if err != nil { 131 | // 处理错误 132 | return 133 | } 134 | 135 | // 响应登出成功 136 | ctx.JSON(200, map[string]interface{}{ 137 | "message": "登出成功", 138 | }) 139 | } 
140 | ``` 141 | 142 | ## 自定义配置 143 | 144 | ### 自定义Cookie选项 145 | 146 | ```go 147 | import ( 148 | "github.com/dormoron/mist/session/cookie" 149 | ) 150 | 151 | // 创建自定义Cookie传播器 152 | cookieProp := cookie.NewPropagator("custom_session", 153 | cookie.WithPath("/api"), 154 | cookie.WithDomain("example.com"), 155 | cookie.WithMaxAge(7200), 156 | cookie.WithSecure(true), 157 | cookie.WithHTTPOnly(true), 158 | cookie.WithSameSite(http.SameSiteStrictMode), 159 | ) 160 | 161 | // 创建会话管理器并使用自定义传播器 162 | manager := &session.Manager{ 163 | Store: store, 164 | Propagator: cookieProp, 165 | CtxSessionKey: "session", 166 | } 167 | ``` 168 | 169 | ### 自定义内存存储选项 170 | 171 | ```go 172 | // 创建具有自定义过期时间的内存存储 173 | store := memory.InitStore(60 * time.Minute) 174 | ``` 175 | 176 | ### 自定义Redis存储选项 177 | 178 | ```go 179 | // 创建具有自定义选项的Redis存储 180 | store := redis.InitStore(client, 181 | redis.StoreWithExpiration(60 * time.Minute), 182 | redis.StoreWithPrefix("myapp:session:"), 183 | ) 184 | ``` 185 | 186 | ## 最佳实践 187 | 188 | 1. **安全设置**: 始终设置Cookie为HttpOnly和Secure(在HTTPS环境中)。 189 | 190 | 2. **适当的过期时间**: 根据应用的安全需求设置合适的会话过期时间。 191 | 192 | 3. **定期垃圾回收**: 启用自动垃圾回收以防止内存泄漏。 193 | 194 | 4. **会话数据**: 尽量只存储必要的数据在会话中,大型数据应存储在数据库中。 195 | 196 | 5. 
**错误处理**: 妥善处理所有会话操作中可能出现的错误。 197 | 198 | ## 类型和接口 199 | 200 | 会话管理模块定义了以下主要接口: 201 | 202 | - `session.Store`: 会话数据存储接口 203 | - `session.Session`: 单个会话接口 204 | - `session.Propagator`: 会话ID传播接口 205 | 206 | 详细文档可参考每个接口的Go文档。 207 | 208 | ## 性能考虑 209 | 210 | - 内存存储适用于单机应用或较小规模的应用 211 | - Redis存储适用于需要会话共享的分布式应用 212 | - 会话数据应尽量保持简洁,避免存储大量数据
-------------------------------------------------------------------------------- /security/report/example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "log" 7 | "net/http" 8 | "time" 9 | 10 | "github.com/dormoron/mist/security/report" 11 | ) 12 | 13 | // reportView is the JSON representation of a single security report returned by 14 | // the /api/reports/recent and /api/reports/type/ endpoints. Optional fields are 15 | // omitted when empty. 16 | type reportView struct { 17 | Type string `json:"type,omitempty"` 18 | Time string `json:"time"` 19 | UserAgent string `json:"user_agent"` 20 | IPAddress string `json:"ip_address"` 21 | BlockedURI string `json:"blocked_uri,omitempty"` 22 | ViolatedDir string `json:"violated_directive,omitempty"` 23 | Severity interface{} `json:"severity"` 24 | } 25 | 26 | // writeJSON writes v as indented JSON with an application/json content type. 27 | // encoding/json is used rather than fmt's %q verb because %q produces Go escape 28 | // sequences (such as \x00 or \a) that are not valid JSON, and report fields like 29 | // the user agent come straight from untrusted clients. 30 | func writeJSON(w http.ResponseWriter, v interface{}) { 31 | body, err := json.MarshalIndent(v, "", "  ") 32 | if err != nil { 33 | http.Error(w, "编码JSON失败: "+err.Error(), http.StatusInternalServerError) 34 | return 35 | } 36 | w.Header().Set("Content-Type", "application/json") 37 | // A failed write means the client went away; there is no one left to notify. 38 | _, _ = w.Write(append(body, '\n')) 39 | } 40 | 41 | func main() { 42 | // Create an in-memory report handler that keeps at most 100 reports. 43 | handler := report.NewMemoryHandler(100) 44 | 45 | // Create the report server. 46 | reportServer := report.NewReportServer(handler) 47 | 48 | // Route every kind of security report to the report server. 49 | http.Handle("/report/csp", reportServer) 50 | http.Handle("/report/xss", reportServer) 51 | http.Handle("/report/hpkp", reportServer) 52 | http.Handle("/report/feature", reportServer) 53 | http.Handle("/report/nel", reportServer) 54 | http.Handle("/report/coep", reportServer) 55 | http.Handle("/report/corp", reportServer) 56 | http.Handle("/report/coop", reportServer) 57 | 58 | // API endpoint returning the number of stored reports per type. 59 | http.HandleFunc("/api/reports/summary", func(w http.ResponseWriter, r *http.Request) { 60 | if r.Method != http.MethodGet { 61 | http.Error(w, "方法不允许", http.StatusMethodNotAllowed) 62 | return 63 | } 64 | 65 | summary, err := handler.GetReportsSummary() 66 | if err != nil { 67 | http.Error(w, "获取报告摘要失败: "+err.Error(), http.StatusInternalServerError) 68 | return 69 | } 70 | 71 | // Key by the printed form of the report type so the map always encodes as a JSON object. 72 | counts := make(map[string]interface{}, len(summary)) 73 | for reportType, count := range summary { 74 | counts[fmt.Sprint(reportType)] = count 75 | } 76 | writeJSON(w, counts) 77 | }) 78 | 79 | // API endpoint returning the most recent reports. 80 | http.HandleFunc("/api/reports/recent", func(w http.ResponseWriter, r *http.Request) { 81 | if r.Method != http.MethodGet { 82 | http.Error(w, "方法不允许", http.StatusMethodNotAllowed) 83 | return 84 | } 85 | 86 | limit := 10 // default limit: 10 reports 87 | reports, err := handler.GetRecentReports(limit) 88 | if err != nil { 89 | http.Error(w, "获取最近报告失败: "+err.Error(), http.StatusInternalServerError) 90 | return 91 | } 92 | 93 | views := make([]reportView, 0, len(reports)) 94 | for _, rep := range reports { 95 | views = append(views, reportView{ 96 | Type: fmt.Sprint(rep.Type), 97 | Time: rep.Time.Format(time.RFC3339), 98 | UserAgent: rep.UserAgent, 99 | IPAddress: rep.IPAddress, 100 | BlockedURI: rep.BlockedURI, 101 | ViolatedDir: rep.ViolatedDir, 102 | Severity: rep.Severity, 103 | }) 104 | } 105 | writeJSON(w, views) 106 | }) 107 | 108 | // API endpoint returning reports of one type: GET /api/reports/type/{type}. 109 | http.HandleFunc("/api/reports/type/", func(w http.ResponseWriter, r *http.Request) { 110 | if r.Method != http.MethodGet { 111 | http.Error(w, "方法不允许", http.StatusMethodNotAllowed) 112 | return 113 | } 114 | 115 | reportType := r.URL.Path[len("/api/reports/type/"):] 116 | if reportType == "" { 117 | http.Error(w, "必须指定报告类型", http.StatusBadRequest) 118 | return 119 | } 120 | 121 | limit := 10 // default limit: 10 reports 122 | reports, err := handler.GetReportsByType(reportType, limit) 123 | if err != nil { 124 | http.Error(w, "获取报告失败: "+err.Error(), http.StatusInternalServerError) 125 | return 126 | } 127 | 128 | // The type is implied by the URL, so it is not repeated in each entry. 129 | views := make([]reportView, 0, len(reports)) 130 | for _, rep := range reports { 131 | views = append(views, reportView{ 132 | Time: rep.Time.Format(time.RFC3339), 133 | UserAgent: rep.UserAgent, 134 | IPAddress: rep.IPAddress, 135 | BlockedURI: rep.BlockedURI, 136 | ViolatedDir: rep.ViolatedDir, 137 | Severity: rep.Severity, 138 | }) 139 | } 140 | writeJSON(w, views) 141 | }) 142 | 143 | // Start the HTTP server. 144 | log.Println("安全报告服务器启动在 http://localhost:8080") 145 | log.Println("- 接收CSP报告: POST /report/csp") 146 | log.Println("- 接收XSS报告: POST /report/xss") 147 | log.Println("- 查看报告摘要: GET /api/reports/summary") 148 | log.Println("- 查看最近报告: GET /api/reports/recent") 149 | log.Println("- 按类型查看报告: GET /api/reports/type/{type}") 150 | 151 | if err := http.ListenAndServe(":8080", nil); err != nil { 152 | log.Fatalf("服务器启动失败: %v", err) 153 | } 154 | } 155 |
-------------------------------------------------------------------------------- /security/blocklist/example/README.md: -------------------------------------------------------------------------------- 1 | # IP黑名单(Blocklist)示例 2 | 3 | 这个示例展示了如何使用Mist框架的IP黑名单(Blocklist)功能来保护你的API免受暴力破解攻击。 4 | 5 | ## 功能特点 6 | 7 | - 基于失败尝试次数的IP封禁机制 8 | - 可配置的封禁时长和最大失败次数 9 | - IP白名单支持 10 | - 自动清理过期记录 11 | - 支持手动封禁和解除封禁IP 12 | - 支持Mist框架中间件 13 | 14 | ## 运行示例 15 | 16 | ```bash 17 | # 在默认端口8080上运行 18 | go run main.go 19 | 20 | # 在指定端口上运行 21 | go run main.go -port=8888 22 | ``` 23 | 24 | ## API接口说明 25 | 26 | ### 1.
登录接口 27 | 28 | - **URL**: `/api/login` 29 | - **方法**: `POST` 30 | - **请求体**: 31 | ```json 32 | { 33 | "username": "admin", 34 | "password": "password" 35 | } 36 | ``` 37 | - **成功响应**: 38 | ```json 39 | { 40 | "status": "success", 41 | "message": "登录成功" 42 | } 43 | ``` 44 | - **失败响应**: 45 | ```json 46 | { 47 | "status": "error", 48 | "message": "用户名或密码错误" 49 | } 50 | ``` 51 | - **IP封禁响应**: 52 | ```json 53 | { 54 | "status": "error", 55 | "message": "您的IP因多次失败的尝试已被封禁,请稍后再试" 56 | } 57 | ``` 58 | 59 | ### 2. 受保护的API接口 60 | 61 | - **URL**: `/api/protected` 62 | - **方法**: `GET` 63 | - **成功响应**: 64 | ```json 65 | { 66 | "status": "success", 67 | "message": "这是受保护的API接口" 68 | } 69 | ``` 70 | - **IP封禁响应**: 71 | ```json 72 | { 73 | "status": "error", 74 | "message": "您的IP因多次失败的尝试已被封禁,请稍后再试" 75 | } 76 | ``` 77 | 78 | ### 3. 解除IP封禁(管理接口) 79 | 80 | - **URL**: `/api/admin/unblock?ip={ip地址}` 81 | - **方法**: `POST` 82 | - **成功响应**: 83 | ```json 84 | { 85 | "status": "success", 86 | "message": "IP xxx.xxx.xxx.xxx 已解除封禁" 87 | } 88 | ``` 89 | 90 | ### 4. 
检查IP状态(管理接口) 91 | 92 | - **URL**: `/api/admin/status?ip={ip地址}` 93 | - **方法**: `GET` 94 | - **成功响应**: 95 | ```json 96 | { 97 | "status": "success", 98 | "ip": "xxx.xxx.xxx.xxx", 99 | "isBlocked": false, 100 | "state": "正常" 101 | } 102 | ``` 103 | 104 | ## 测试示例 105 | 106 | ### 测试登录失败和IP封禁 107 | 108 | ```bash 109 | # 使用错误的密码尝试登录(需要3次失败才会被封禁) 110 | curl -X POST http://localhost:8080/api/login -d '{"username":"admin","password":"wrong"}' 111 | curl -X POST http://localhost:8080/api/login -d '{"username":"admin","password":"wrong"}' 112 | curl -X POST http://localhost:8080/api/login -d '{"username":"admin","password":"wrong"}' 113 | 114 | # 第4次尝试将会收到封禁消息 115 | curl -X POST http://localhost:8080/api/login -d '{"username":"admin","password":"wrong"}' 116 | 117 | # 尝试访问受保护的API 118 | curl http://localhost:8080/api/protected 119 | ``` 120 | 121 | ### 解除IP封禁 122 | 123 | ```bash 124 | # 解除本地IP的封禁 125 | curl -X POST http://localhost:8080/api/admin/unblock?ip=127.0.0.1 126 | ``` 127 | 128 | ### 检查IP状态 129 | 130 | ```bash 131 | # 检查本地IP的状态 132 | curl http://localhost:8080/api/admin/status?ip=127.0.0.1 133 | ``` 134 | 135 | ## 在Mist框架中使用 136 | 137 | 此示例主要展示了如何在Mist框架中使用IP黑名单功能: 138 | 139 | ```go 140 | package main 141 | 142 | import ( 143 | "github.com/dormoron/mist" 144 | "github.com/dormoron/mist/security/blocklist" 145 | "github.com/dormoron/mist/security/blocklist/middleware" 146 | "time" 147 | "log" 148 | "net/http" 149 | ) 150 | 151 | func main() { 152 | // 创建Mist应用 153 | app := mist.InitHTTPServer() 154 | 155 | // 创建IP黑名单管理器 156 | blocklistManager := blocklist.NewManager( 157 | blocklist.WithMaxFailedAttempts(3), 158 | blocklist.WithBlockDuration(5*time.Minute), 159 | ) 160 | 161 | // 使用中间件 162 | app.Use(middleware.New(blocklistManager)) 163 | 164 | // 或使用自定义封禁处理函数 165 | app.Use(middleware.New( 166 | blocklistManager, 167 | middleware.WithOnBlocked(func(ctx *mist.Context) { 168 | // 记录IP封禁事件 169 | log.Printf("IP %s 已被封禁", ctx.ClientIP()) 170 | 171 | // 返回JSON响应 172 | 
ctx.JSON(http.StatusForbidden, map[string]interface{}{ 173 | "status": "error", 174 | "message": "您的IP因多次失败的尝试已被暂时封禁,请稍后再试", 175 | }) 176 | }), 177 | )) 178 | 179 | // 设置路由和处理函数 180 | app.POST("/api/login", func(ctx *mist.Context) { 181 | // 登录逻辑... 182 | }) 183 | 184 | app.GET("/api/protected", func(ctx *mist.Context) { 185 | ctx.JSON(200, map[string]interface{}{ 186 | "status": "success", 187 | "message": "这是受保护的API接口", 188 | }) 189 | }) 190 | 191 | // 启动服务器 192 | app.Run(":8080") 193 | } 194 | ``` 195 | 196 | ## 自定义配置选项 197 | 198 | ### 配置选项 199 | 200 | - `blocklist.WithMaxFailedAttempts(max int)` - 设置最大失败尝试次数 201 | - `blocklist.WithBlockDuration(duration time.Duration)` - 设置封禁时长 202 | - `blocklist.WithClearInterval(interval time.Duration)` - 设置清理间隔时间 203 | - `blocklist.WithWhitelistIPs(ips []string)` - 设置IP白名单 204 | 205 | ### Mist框架中间件选项 206 | 207 | - `middleware.WithOnBlocked(handler func(*mist.Context))` - 设置Mist框架中的封禁处理函数 -------------------------------------------------------------------------------- /config/examples.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | // 示例1: 基本配置使用 9 | func Example_basic() { 10 | // 创建一个新的配置管理器 11 | cfg, err := New( 12 | WithConfigFile("config.yaml"), 13 | WithEnvPrefix("APP_"), 14 | ) 15 | if err != nil { 16 | panic(err) 17 | } 18 | 19 | // 设置一些配置值 20 | cfg.Set("app.name", "MyApp") 21 | cfg.Set("app.version", "1.0.0") 22 | cfg.Set("server.port", 8080) 23 | cfg.Set("server.timeout", 30) 24 | 25 | // 获取配置值 26 | appName := cfg.GetString("app.name") 27 | port := cfg.GetInt("server.port") 28 | 29 | fmt.Printf("App: %s, Port: %d\n", appName, port) 30 | 31 | // 检查配置是否存在 32 | if cfg.Has("app.debug") { 33 | fmt.Println("Debug mode is configured") 34 | } else { 35 | fmt.Println("Debug mode is not configured") 36 | } 37 | 38 | // 获取所有配置 39 | settings := cfg.AllSettings() 40 | fmt.Printf("All settings: %+v\n", settings) 41 | } 42 | 
43 | // AppConfig 应用配置结构体 44 | type AppConfig struct { 45 | Name string `config:"name"` 46 | Version string `config:"version"` 47 | Debug bool `config:"debug"` 48 | Server struct { 49 | Port int `config:"port"` 50 | Timeout time.Duration `config:"timeout"` 51 | Host string `config:"host"` 52 | } `config:"server"` 53 | Database struct { 54 | DSN string `config:"dsn"` 55 | MaxConns int `config:"max_conns"` 56 | MaxIdle int `config:"max_idle"` 57 | } `config:"database"` 58 | Features []string `config:"features"` 59 | Options map[string]interface{} `config:"options"` 60 | } 61 | 62 | // 示例2: 配置结构体映射 63 | func Example_unmarshal() { 64 | // 创建配置管理器 65 | cfg, _ := New() 66 | 67 | // 设置一些配置值 68 | cfg.Set("app", map[string]interface{}{ 69 | "name": "MyApp", 70 | "version": "1.0.0", 71 | "debug": true, 72 | "server": map[string]interface{}{ 73 | "port": 8080, 74 | "timeout": 30, 75 | "host": "localhost", 76 | }, 77 | "database": map[string]interface{}{ 78 | "dsn": "postgres://user:pass@localhost:5432/mydb", 79 | "max_conns": 100, 80 | "max_idle": 10, 81 | }, 82 | "features": []string{"auth", "api", "dashboard"}, 83 | "options": map[string]interface{}{ 84 | "theme": "dark", 85 | "language": "en", 86 | }, 87 | }) 88 | 89 | // 映射到结构体 90 | var appConfig AppConfig 91 | if err := cfg.Unmarshal("app", &appConfig); err != nil { 92 | panic(err) 93 | } 94 | 95 | // 使用配置结构体 96 | fmt.Printf("App: %s (v%s)\n", appConfig.Name, appConfig.Version) 97 | fmt.Printf("Server: %s:%d (timeout: %v)\n", 98 | appConfig.Server.Host, 99 | appConfig.Server.Port, 100 | appConfig.Server.Timeout) 101 | fmt.Printf("Database: %s (max connections: %d)\n", 102 | appConfig.Database.DSN, 103 | appConfig.Database.MaxConns) 104 | 105 | fmt.Println("Features:") 106 | for _, feature := range appConfig.Features { 107 | fmt.Printf(" - %s\n", feature) 108 | } 109 | } 110 | 111 | // 示例3: 配置变更监听 112 | func Example_listener() { 113 | // 创建配置管理器 114 | cfg, _ := New() 115 | 116 | // 添加配置变更监听器 117 | 
cfg.AddChangeListener(func(key string) { 118 | if key == "" { 119 | fmt.Println("全局配置已变更") 120 | } else { 121 | fmt.Printf("配置项已变更: %s\n", key) 122 | 123 | // 获取新值 124 | value, _ := cfg.Get(key) 125 | fmt.Printf(" 新值: %v\n", value) 126 | } 127 | }) 128 | 129 | // 修改配置 130 | cfg.Set("app.name", "NewName") 131 | cfg.Set("server.port", 9000) 132 | 133 | // 输出: 134 | // 配置项已变更: app.name 135 | // 新值: NewName 136 | // 配置项已变更: server.port 137 | // 新值: 9000 138 | } 139 | 140 | // 以下是配置文件示例 141 | 142 | /* 143 | YAML配置文件示例 (config.yaml): 144 | 145 | app: 146 | name: MyApp 147 | version: 1.0.0 148 | debug: true 149 | 150 | server: 151 | port: 8080 152 | host: localhost 153 | timeout: 30s 154 | 155 | database: 156 | dsn: postgres://user:pass@localhost:5432/mydb 157 | max_conns: 100 158 | max_idle: 10 159 | 160 | features: 161 | - auth 162 | - api 163 | - dashboard 164 | 165 | options: 166 | theme: dark 167 | language: en 168 | 169 | */ 170 | 171 | /* 172 | JSON配置文件示例 (config.json): 173 | 174 | { 175 | "app": { 176 | "name": "MyApp", 177 | "version": "1.0.0", 178 | "debug": true 179 | }, 180 | "server": { 181 | "port": 8080, 182 | "host": "localhost", 183 | "timeout": "30s" 184 | }, 185 | "database": { 186 | "dsn": "postgres://user:pass@localhost:5432/mydb", 187 | "max_conns": 100, 188 | "max_idle": 10 189 | }, 190 | "features": [ 191 | "auth", 192 | "api", 193 | "dashboard" 194 | ], 195 | "options": { 196 | "theme": "dark", 197 | "language": "en" 198 | } 199 | } 200 | 201 | */ 202 | 203 | /* 204 | TOML配置文件示例 (config.toml): 205 | 206 | [app] 207 | name = "MyApp" 208 | version = "1.0.0" 209 | debug = true 210 | 211 | [server] 212 | port = 8080 213 | host = "localhost" 214 | timeout = "30s" 215 | 216 | [database] 217 | dsn = "postgres://user:pass@localhost:5432/mydb" 218 | max_conns = 100 219 | max_idle = 10 220 | 221 | features = ["auth", "api", "dashboard"] 222 | 223 | [options] 224 | theme = "dark" 225 | language = "en" 226 | 227 | */ 228 | 
-------------------------------------------------------------------------------- /security/redisess/session.go: -------------------------------------------------------------------------------- 1 | package redisess 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/dormoron/mist" 8 | "github.com/dormoron/mist/security" 9 | "github.com/redis/go-redis/v9" 10 | ) 11 | 12 | // Ensure that Session implements the security.Session interface. 13 | var _ security.Session = &Session{} 14 | 15 | // Session is a struct that manages session data using Redis as the backend storage. 16 | type Session struct { 17 | client redis.Cmdable // Redis client used to interact with the Redis server. 18 | key string // The Redis key under which the session data is stored. 19 | claims security.Claims // Security claims associated with the session. 20 | expiration time.Duration // The expiration duration of the session. 21 | } 22 | 23 | // Destroy deletes the session data from Redis. 24 | // Parameters: 25 | // - ctx: The context for controlling the request lifetime (context.Context). 26 | // Returns: 27 | // - error: An error if the deletion fails. 28 | func (sess *Session) Destroy(ctx context.Context) error { 29 | return sess.client.Del(ctx, sess.key).Err() // Perform a DEL command on the session key. 30 | } 31 | 32 | // Del deletes a specific key from the session data in Redis. 33 | // Parameters: 34 | // - ctx: The context for controlling the request lifetime (context.Context). 35 | // - key: The specific key to be deleted from the session data (string). 36 | // Returns: 37 | // - error: An error if the deletion fails. 38 | func (sess *Session) Del(ctx context.Context, key string) error { 39 | return sess.client.Del(ctx, sess.key, key).Err() // Perform a DEL command on the session key and specific field. 40 | } 41 | 42 | // Set sets a key-value pair in the session data in Redis. 
43 | // Parameters: 44 | // - ctx: The context for controlling the request lifetime (context.Context). 45 | // - key: The key to be set in the session data (string). 46 | // - val: The value to be set for the specified key (any). 47 | // Returns: 48 | // - error: An error if the set operation fails. 49 | func (sess *Session) Set(ctx context.Context, key string, val any) error { 50 | return sess.client.HSet(ctx, sess.key, key, val).Err() // Perform an HSET command to set the key-value pair in the hash. 51 | } 52 | 53 | // init initializes the session data with key-value pairs provided in the map kvs. 54 | // Parameters: 55 | // - ctx: The context for controlling the request lifetime (context.Context). 56 | // - kvs: A map of key-value pairs to initialize the session data (map[string]any). 57 | // Returns: 58 | // - error: An error if the initialization fails. 59 | func (sess *Session) init(ctx context.Context, kvs map[string]any) error { 60 | pip := sess.client.Pipeline() // Create a new pipeline to batch the commands. 61 | for k, v := range kvs { 62 | pip.HMSet(ctx, sess.key, k, v) // Add an HMSET command for each key-value pair. 63 | } 64 | pip.Expire(ctx, sess.key, sess.expiration) // Set the expiration time for the session. 65 | _, err := pip.Exec(ctx) // Execute all the commands in the pipeline. 66 | return err // Return any error that occurred during execution. 67 | } 68 | 69 | // Get retrieves the value associated with a specific key from the session data in Redis. 70 | // Parameters: 71 | // - ctx: The context for controlling the request lifetime (context.Context). 72 | // - key: The key to retrieve from the session data (string). 73 | // Returns: 74 | // - mist.AnyValue: The value associated with the key, or an error if the retrieval fails. 75 | func (sess *Session) Get(ctx context.Context, key string) mist.AnyValue { 76 | res, err := sess.client.HGet(ctx, sess.key, key).Result() // Perform an HGET command to retrieve the value. 
77 | if err != nil { 78 | return mist.AnyValue{Err: err} // Return an AnyValue with the error if the retrieval fails. 79 | } 80 | return mist.AnyValue{ 81 | Val: res, // Return an AnyValue with the retrieved value. 82 | } 83 | } 84 | 85 | // Claims returns the security claims associated with the session. 86 | // Returns: 87 | // - security.Claims: The claims associated with the session. 88 | func (sess *Session) Claims() security.Claims { 89 | return sess.claims // Return the claims. 90 | } 91 | 92 | // initRedisSession initializes a new Session instance with the given parameters. 93 | // Parameters: 94 | // - ssid: The session ID (string). 95 | // - expiration: The expiration duration of the session (time.Duration). 96 | // - client: The Redis client used to interact with the Redis server (redis.Cmdable). 97 | // - cl: Security claims associated with the session (security.Claims). 98 | // Returns: 99 | // - *Session: A pointer to the newly created Session instance. 100 | func initRedisSession(ssid string, expiration time.Duration, client redis.Cmdable, cl security.Claims) *Session { 101 | return &Session{ 102 | client: client, // Set the Redis client. 103 | key: "session:" + ssid, // Construct the Redis key using the session ID. 104 | expiration: expiration, // Set the expiration duration. 105 | claims: cl, // Set the security claims. 
106 | } 107 | } 108 | -------------------------------------------------------------------------------- /validation/struct.go: -------------------------------------------------------------------------------- 1 | package validation 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strconv" 7 | "strings" 8 | ) 9 | 10 | // ValidateStruct 验证结构体,根据字段标签进行验证 11 | func ValidateStruct(v *Validator, s interface{}) { 12 | val := reflect.ValueOf(s) 13 | 14 | // 如果是指针,获取其指向的值 15 | if val.Kind() == reflect.Ptr { 16 | val = val.Elem() 17 | } 18 | 19 | // 确保是结构体 20 | if val.Kind() != reflect.Struct { 21 | panic("参数必须是结构体") 22 | } 23 | 24 | // 获取结构体类型 25 | typ := val.Type() 26 | 27 | // 遍历所有字段 28 | for i := 0; i < val.NumField(); i++ { 29 | field := val.Field(i) 30 | fieldType := typ.Field(i) 31 | 32 | // 获取验证标签 33 | validateTag := fieldType.Tag.Get("validate") 34 | if validateTag == "" { 35 | continue 36 | } 37 | 38 | // 如果字段是一个嵌套的结构体,递归验证 39 | if field.Kind() == reflect.Struct { 40 | // 获取子结构体的值 41 | fieldValue := field.Interface() 42 | // 递归验证 43 | ValidateStruct(v, fieldValue) 44 | continue 45 | } 46 | 47 | // 解析并应用验证规则 48 | validateRules(v, validateTag, field, fieldType.Name) 49 | } 50 | } 51 | 52 | // validateRules 解析并应用验证规则 53 | func validateRules(v *Validator, tag string, field reflect.Value, fieldName string) { 54 | // 根据逗号分割验证规则 55 | rules := strings.Split(tag, ",") 56 | 57 | for _, rule := range rules { 58 | // 去除空白 59 | rule = strings.TrimSpace(rule) 60 | 61 | // 解析规则名称和参数 62 | parts := strings.SplitN(rule, "=", 2) 63 | ruleName := parts[0] 64 | 65 | // 根据字段类型和规则应用不同的验证 66 | switch field.Kind() { 67 | case reflect.String: 68 | validateString(v, ruleName, parts, field.String(), fieldName) 69 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 70 | validateInt(v, ruleName, parts, field.Int(), fieldName) 71 | case reflect.Float32, reflect.Float64: 72 | validateFloat(v, ruleName, parts, field.Float(), fieldName) 73 | case reflect.Bool: 74 | 
validateBool(v, ruleName, parts, field.Bool(), fieldName) 75 | case reflect.Slice, reflect.Array: 76 | validateSlice(v, ruleName, parts, field, fieldName) 77 | } 78 | } 79 | } 80 | 81 | // validateString 验证字符串类型 82 | func validateString(v *Validator, ruleName string, parts []string, value string, fieldName string) { 83 | switch ruleName { 84 | case "required": 85 | v.Required(value, fieldName) 86 | case "min": 87 | if len(parts) > 1 { 88 | min, err := strconv.Atoi(parts[1]) 89 | if err == nil { 90 | v.MinLength(value, min, fieldName) 91 | } 92 | } 93 | case "max": 94 | if len(parts) > 1 { 95 | max, err := strconv.Atoi(parts[1]) 96 | if err == nil { 97 | v.MaxLength(value, max, fieldName) 98 | } 99 | } 100 | case "email": 101 | v.Email(value, fieldName) 102 | case "url": 103 | v.URL(value, fieldName) 104 | case "alpha": 105 | v.Alpha(value, fieldName) 106 | case "alphanum": 107 | v.Alphanumeric(value, fieldName) 108 | case "phone": 109 | v.PhoneNumber(value, fieldName) 110 | case "idcard": 111 | v.ChineseIDCard(value, fieldName) 112 | } 113 | } 114 | 115 | // validateInt 验证整数类型 116 | func validateInt(v *Validator, ruleName string, parts []string, value int64, fieldName string) { 117 | switch ruleName { 118 | case "min": 119 | if len(parts) > 1 { 120 | min, err := strconv.ParseInt(parts[1], 10, 64) 121 | if err == nil { 122 | if value < min { 123 | v.AddError(fieldName, fmt.Sprintf("必须大于或等于%d", min)) 124 | } 125 | } 126 | } 127 | case "max": 128 | if len(parts) > 1 { 129 | max, err := strconv.ParseInt(parts[1], 10, 64) 130 | if err == nil { 131 | if value > max { 132 | v.AddError(fieldName, fmt.Sprintf("必须小于或等于%d", max)) 133 | } 134 | } 135 | } 136 | } 137 | } 138 | 139 | // validateFloat 验证浮点数类型 140 | func validateFloat(v *Validator, ruleName string, parts []string, value float64, fieldName string) { 141 | switch ruleName { 142 | case "min": 143 | if len(parts) > 1 { 144 | min, err := strconv.ParseFloat(parts[1], 64) 145 | if err == nil { 146 | if value < min { 147 
| v.AddError(fieldName, fmt.Sprintf("必须大于或等于%.2f", min)) 148 | } 149 | } 150 | } 151 | case "max": 152 | if len(parts) > 1 { 153 | max, err := strconv.ParseFloat(parts[1], 64) 154 | if err == nil { 155 | if value > max { 156 | v.AddError(fieldName, fmt.Sprintf("必须小于或等于%.2f", max)) 157 | } 158 | } 159 | } 160 | } 161 | } 162 | 163 | // validateBool 验证布尔类型 164 | func validateBool(v *Validator, ruleName string, parts []string, value bool, fieldName string) { 165 | switch ruleName { 166 | case "required": 167 | // 对于布尔类型,required通常没有意义,但我们可以定义为必须为true 168 | if !value { 169 | v.AddError(fieldName, "必须为true") 170 | } 171 | } 172 | } 173 | 174 | // validateSlice 验证切片类型 175 | func validateSlice(v *Validator, ruleName string, parts []string, field reflect.Value, fieldName string) { 176 | switch ruleName { 177 | case "required": 178 | if field.Len() == 0 { 179 | v.AddError(fieldName, "不能为空") 180 | } 181 | case "min": 182 | if len(parts) > 1 { 183 | min, err := strconv.Atoi(parts[1]) 184 | if err == nil { 185 | if field.Len() < min { 186 | v.AddError(fieldName, fmt.Sprintf("长度不能小于%d", min)) 187 | } 188 | } 189 | } 190 | case "max": 191 | if len(parts) > 1 { 192 | max, err := strconv.Atoi(parts[1]) 193 | if err == nil { 194 | if field.Len() > max { 195 | v.AddError(fieldName, fmt.Sprintf("长度不能大于%d", max)) 196 | } 197 | } 198 | } 199 | } 200 | } 201 | -------------------------------------------------------------------------------- /examples/http3_test/go.sum: -------------------------------------------------------------------------------- 1 | github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= 2 | github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= 3 | github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= 4 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | 
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 6 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 7 | github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= 8 | github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= 9 | github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= 10 | github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= 11 | github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= 12 | github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= 13 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 14 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 15 | github.com/google/pprof v0.0.0-20211008130755-947d60d73cc0 h1:zHs+jv3LO743/zFGcByu2KmpbliCU2AhjcGgrdTwSG4= 16 | github.com/google/pprof v0.0.0-20211008130755-947d60d73cc0/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg= 17 | github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c= 18 | github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= 19 | github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= 20 | github.com/onsi/ginkgo/v2 v2.15.0 h1:79HwNRBAZHOEwrczrgSOPy+eFTTlIGELKy5as+ClttY= 21 | github.com/onsi/ginkgo/v2 v2.15.0/go.mod h1:HlxMHtYF57y6Dpf+mc5529KKmSq9h2FpCF+/ZkwUxKM= 22 | github.com/onsi/gomega v1.30.0 h1:hvMK7xYz4D3HapigLTeGdId/NcfQx1VHMJc60ew99+8= 23 | github.com/onsi/gomega v1.30.0/go.mod h1:9sxs+SwGrKI0+PWe4Fxa9tFQQBG5xSsSbMXOI8PPpoQ= 24 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 25 | github.com/pmezard/go-difflib v1.0.0/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 26 | github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= 27 | github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= 28 | github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM= 29 | github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M= 30 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 31 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 32 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 33 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 34 | go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= 35 | go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= 36 | golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= 37 | golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= 38 | golang.org/x/exp v0.0.0-20221205204356-47842c84f3db h1:D/cFflL63o2KSLJIwjlcIt8PR064j/xsmdEJL/YvY/o= 39 | golang.org/x/exp v0.0.0-20221205204356-47842c84f3db/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= 40 | golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= 41 | golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= 42 | golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= 43 | golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= 44 | golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= 45 | golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 46 | golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 47 | golang.org/x/sys v0.21.0 
h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= 48 | golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 49 | golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= 50 | golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= 51 | golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= 52 | golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= 53 | golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= 54 | golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= 55 | google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= 56 | google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= 57 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 58 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 59 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 60 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 61 | -------------------------------------------------------------------------------- /security/mfa/middleware.go: -------------------------------------------------------------------------------- 1 | package mfa 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "sync" 7 | "time" 8 | 9 | "github.com/dormoron/mist" 10 | ) 11 | 12 | const ( 13 | // MFACookieName 用于标记MFA验证状态的Cookie名 14 | MFACookieName = "_mfa_validated" 15 | 16 | // MFASessionKey 用于在Session中存储MFA状态的键 17 | MFASessionKey = "_mfa_status" 18 | 19 | // DefaultValidationDuration MFA验证状态默认有效期 20 | DefaultValidationDuration = 12 * time.Hour 21 | ) 22 | 23 | var ( 24 | // ErrMFARequired 表示需要多因素验证 25 | ErrMFARequired = errors.New("需要多因素验证") 26 | 27 | // ErrInvalidMFACode 
表示MFA验证码无效 28 | ErrInvalidMFACode = errors.New("无效的多因素验证码") 29 | ) 30 | 31 | // ValidationStore 存储MFA验证状态的接口 32 | type ValidationStore interface { 33 | // Validate 验证指定用户ID是否已完成MFA验证 34 | Validate(userID string) (bool, error) 35 | 36 | // Set 设置用户MFA验证状态 37 | Set(userID string, expiry time.Duration) error 38 | 39 | // Clear 清除用户MFA验证状态 40 | Clear(userID string) error 41 | } 42 | 43 | // MemoryStore 内存实现的MFA验证状态存储 44 | type MemoryStore struct { 45 | validations map[string]int64 46 | mu sync.RWMutex 47 | } 48 | 49 | // NewMemoryStore 创建新的内存验证状态存储 50 | func NewMemoryStore() *MemoryStore { 51 | return &MemoryStore{ 52 | validations: make(map[string]int64), 53 | } 54 | } 55 | 56 | // Validate 验证用户MFA状态 57 | func (s *MemoryStore) Validate(userID string) (bool, error) { 58 | s.mu.RLock() 59 | defer s.mu.RUnlock() 60 | 61 | expiryTime, exists := s.validations[userID] 62 | if !exists { 63 | return false, nil 64 | } 65 | 66 | // 如果过期时间已到,则验证失败 67 | if expiryTime < time.Now().Unix() { 68 | return false, nil 69 | } 70 | 71 | return true, nil 72 | } 73 | 74 | // Set 设置用户MFA验证状态 75 | func (s *MemoryStore) Set(userID string, expiry time.Duration) error { 76 | s.mu.Lock() 77 | defer s.mu.Unlock() 78 | 79 | s.validations[userID] = time.Now().Add(expiry).Unix() 80 | return nil 81 | } 82 | 83 | // Clear 清除用户MFA验证状态 84 | func (s *MemoryStore) Clear(userID string) error { 85 | s.mu.Lock() 86 | defer s.mu.Unlock() 87 | 88 | delete(s.validations, userID) 89 | return nil 90 | } 91 | 92 | // Config MFA中间件配置 93 | type MiddlewareConfig struct { 94 | // Store MFA验证状态存储 95 | Store ValidationStore 96 | 97 | // GetUserID 从请求上下文中获取用户ID的函数 98 | GetUserID func(*mist.Context) (string, error) 99 | 100 | // ValidationDuration MFA验证有效期 101 | ValidationDuration time.Duration 102 | 103 | // RedirectURL 未验证时重定向的URL 104 | RedirectURL string 105 | 106 | // OnUnauthorized 未验证时的处理函数 107 | OnUnauthorized func(*mist.Context) 108 | } 109 | 110 | // New 创建新的MFA中间件 111 | func NewMiddleware(options 
...func(*MiddlewareConfig)) mist.Middleware { 112 | // 默认配置 113 | config := MiddlewareConfig{ 114 | Store: NewMemoryStore(), 115 | GetUserID: func(ctx *mist.Context) (string, error) { 116 | // 默认从上下文中获取user_id 117 | if id, exists := ctx.Get("user_id"); exists { 118 | if userID, ok := id.(string); ok { 119 | return userID, nil 120 | } 121 | } 122 | return "", errors.New("无法获取用户ID") 123 | }, 124 | ValidationDuration: DefaultValidationDuration, 125 | RedirectURL: "/mfa/validate", 126 | OnUnauthorized: func(ctx *mist.Context) { 127 | // 默认重定向到MFA验证页面 128 | ctx.Header("Location", "/mfa/validate") 129 | ctx.AbortWithStatus(http.StatusFound) 130 | }, 131 | } 132 | 133 | // 应用自定义选项 134 | for _, option := range options { 135 | option(&config) 136 | } 137 | 138 | return func(next mist.HandleFunc) mist.HandleFunc { 139 | return func(ctx *mist.Context) { 140 | // 获取用户ID 141 | userID, err := config.GetUserID(ctx) 142 | if err != nil { 143 | // 如果无法获取用户ID,视为未验证 144 | config.OnUnauthorized(ctx) 145 | return 146 | } 147 | 148 | // 检查是否已验证 149 | validated, err := config.Store.Validate(userID) 150 | if err != nil || !validated { 151 | config.OnUnauthorized(ctx) 152 | return 153 | } 154 | 155 | // 已验证,继续处理请求 156 | next(ctx) 157 | } 158 | } 159 | } 160 | 161 | // Validate 验证MFA代码 162 | func Validate(ctx *mist.Context, userID, code string, totp *TOTP, store ValidationStore, duration time.Duration) error { 163 | // 验证TOTP代码 164 | if !totp.Validate(code) { 165 | return ErrInvalidMFACode 166 | } 167 | 168 | // 设置验证状态 169 | return store.Set(userID, duration) 170 | } 171 | 172 | // ClearValidation 清除MFA验证状态 173 | func ClearValidation(userID string, store ValidationStore) error { 174 | return store.Clear(userID) 175 | } 176 | 177 | // 选项函数 178 | 179 | // WithStore 设置验证状态存储 180 | func WithStore(store ValidationStore) func(*MiddlewareConfig) { 181 | return func(c *MiddlewareConfig) { 182 | c.Store = store 183 | } 184 | } 185 | 186 | // WithGetUserID 设置获取用户ID的函数 187 | func WithGetUserID(fn 
func(*mist.Context) (string, error)) func(*MiddlewareConfig) { 188 | return func(c *MiddlewareConfig) { 189 | c.GetUserID = fn 190 | } 191 | } 192 | 193 | // WithValidationDuration 设置验证有效期 194 | func WithValidationDuration(duration time.Duration) func(*MiddlewareConfig) { 195 | return func(c *MiddlewareConfig) { 196 | c.ValidationDuration = duration 197 | } 198 | } 199 | 200 | // WithRedirectURL 设置重定向URL 201 | func WithRedirectURL(url string) func(*MiddlewareConfig) { 202 | return func(c *MiddlewareConfig) { 203 | c.RedirectURL = url 204 | } 205 | } 206 | 207 | // WithUnauthorizedHandler 设置未授权处理函数 208 | func WithUnauthorizedHandler(handler func(*mist.Context)) func(*MiddlewareConfig) { 209 | return func(c *MiddlewareConfig) { 210 | c.OnUnauthorized = handler 211 | } 212 | } 213 | -------------------------------------------------------------------------------- /validation/validator.go: -------------------------------------------------------------------------------- 1 | package validation 2 | 3 | import ( 4 | "fmt" 5 | "net/mail" 6 | "net/url" 7 | "regexp" 8 | "strconv" 9 | "strings" 10 | "unicode/utf8" 11 | ) 12 | 13 | // ValidationError 表示验证错误 14 | type ValidationError struct { 15 | Field string // 字段名称 16 | Message string // 错误信息 17 | } 18 | 19 | func (e ValidationError) Error() string { 20 | return fmt.Sprintf("%s: %s", e.Field, e.Message) 21 | } 22 | 23 | // Validator 提供数据验证功能 24 | type Validator struct { 25 | Errors []ValidationError 26 | } 27 | 28 | // NewValidator 创建一个新的验证器 29 | func NewValidator() *Validator { 30 | return &Validator{ 31 | Errors: make([]ValidationError, 0), 32 | } 33 | } 34 | 35 | // Valid 检查验证器是否有错误 36 | func (v *Validator) Valid() bool { 37 | return len(v.Errors) == 0 38 | } 39 | 40 | // AddError 添加一个验证错误 41 | func (v *Validator) AddError(field, message string) { 42 | v.Errors = append(v.Errors, ValidationError{ 43 | Field: field, 44 | Message: message, 45 | }) 46 | } 47 | 48 | // Check 检查条件是否成立,如果不成立则添加错误 49 | func (v *Validator) Check(ok 
bool, field, message string) { 50 | if !ok { 51 | v.AddError(field, message) 52 | } 53 | } 54 | 55 | // Required 检查字符串是否非空 56 | func (v *Validator) Required(value string, field string) { 57 | if strings.TrimSpace(value) == "" { 58 | v.AddError(field, "不能为空") 59 | } 60 | } 61 | 62 | // MinLength 检查字符串最小长度 63 | func (v *Validator) MinLength(value string, min int, field string) { 64 | if utf8.RuneCountInString(value) < min { 65 | v.AddError(field, fmt.Sprintf("长度不能小于%d个字符", min)) 66 | } 67 | } 68 | 69 | // MaxLength 检查字符串最大长度 70 | func (v *Validator) MaxLength(value string, max int, field string) { 71 | if utf8.RuneCountInString(value) > max { 72 | v.AddError(field, fmt.Sprintf("长度不能大于%d个字符", max)) 73 | } 74 | } 75 | 76 | // Between 检查字符串长度是否在指定范围内 77 | func (v *Validator) Between(value string, min, max int, field string) { 78 | length := utf8.RuneCountInString(value) 79 | if length < min || length > max { 80 | v.AddError(field, fmt.Sprintf("长度必须在%d到%d个字符之间", min, max)) 81 | } 82 | } 83 | 84 | // Email 验证邮箱格式 85 | func (v *Validator) Email(value string, field string) { 86 | if value == "" { 87 | return 88 | } 89 | 90 | _, err := mail.ParseAddress(value) 91 | if err != nil { 92 | v.AddError(field, "邮箱格式不正确") 93 | } 94 | } 95 | 96 | // URL 验证URL格式 97 | func (v *Validator) URL(value string, field string) { 98 | if value == "" { 99 | return 100 | } 101 | 102 | u, err := url.Parse(value) 103 | if err != nil || u.Scheme == "" || u.Host == "" { 104 | v.AddError(field, "URL格式不正确") 105 | } 106 | } 107 | 108 | // Alpha 验证字符串仅包含字母 109 | func (v *Validator) Alpha(value string, field string) { 110 | if value == "" { 111 | return 112 | } 113 | 114 | matched, _ := regexp.MatchString("^[a-zA-Z]+$", value) 115 | if !matched { 116 | v.AddError(field, "只能包含字母") 117 | } 118 | } 119 | 120 | // Alphanumeric 验证字符串仅包含字母和数字 121 | func (v *Validator) Alphanumeric(value string, field string) { 122 | if value == "" { 123 | return 124 | } 125 | 126 | matched, _ := 
regexp.MatchString("^[a-zA-Z0-9]+$", value) 127 | if !matched { 128 | v.AddError(field, "只能包含字母和数字") 129 | } 130 | } 131 | 132 | // Numeric 验证字符串是否为数字 133 | func (v *Validator) Numeric(value string, field string) { 134 | if value == "" { 135 | return 136 | } 137 | 138 | _, err := strconv.ParseFloat(value, 64) 139 | if err != nil { 140 | v.AddError(field, "必须是数字") 141 | } 142 | } 143 | 144 | // Integer 验证字符串是否为整数 145 | func (v *Validator) Integer(value string, field string) { 146 | if value == "" { 147 | return 148 | } 149 | 150 | _, err := strconv.Atoi(value) 151 | if err != nil { 152 | v.AddError(field, "必须是整数") 153 | } 154 | } 155 | 156 | // Range 验证数字是否在指定范围内 157 | func (v *Validator) Range(value int, min, max int, field string) { 158 | if value < min || value > max { 159 | v.AddError(field, fmt.Sprintf("必须在%d到%d之间", min, max)) 160 | } 161 | } 162 | 163 | // RangeFloat 验证浮点数是否在指定范围内 164 | func (v *Validator) RangeFloat(value float64, min, max float64, field string) { 165 | if value < min || value > max { 166 | v.AddError(field, fmt.Sprintf("必须在%.2f到%.2f之间", min, max)) 167 | } 168 | } 169 | 170 | // PhoneNumber 验证手机号格式(中国大陆) 171 | func (v *Validator) PhoneNumber(value string, field string) { 172 | if value == "" { 173 | return 174 | } 175 | 176 | matched, _ := regexp.MatchString(`^1[3-9]\d{9}$`, value) 177 | if !matched { 178 | v.AddError(field, "手机号格式不正确") 179 | } 180 | } 181 | 182 | // ChineseIDCard 验证中国大陆身份证号 183 | func (v *Validator) ChineseIDCard(value string, field string) { 184 | if value == "" { 185 | return 186 | } 187 | 188 | // 18位身份证正则 189 | matched, _ := regexp.MatchString(`^\d{17}[\dXx]$`, value) 190 | if !matched { 191 | v.AddError(field, "身份证号格式不正确") 192 | return 193 | } 194 | 195 | // 可以添加更复杂的验证,如校验码验证等 196 | } 197 | 198 | // InList 验证值是否在列表中 199 | func (v *Validator) InList(value string, list []string, field string) { 200 | for _, item := range list { 201 | if value == item { 202 | return 203 | } 204 | } 205 | v.AddError(field, "值不在允许的范围内") 206 
| } 207 | 208 | // NotInList 验证值是否不在列表中 209 | func (v *Validator) NotInList(value string, list []string, field string) { 210 | for _, item := range list { 211 | if value == item { 212 | v.AddError(field, "值不允许使用") 213 | return 214 | } 215 | } 216 | } 217 | 218 | // Custom 自定义验证 219 | func (v *Validator) Custom(ok bool, field, message string) { 220 | if !ok { 221 | v.AddError(field, message) 222 | } 223 | } 224 | -------------------------------------------------------------------------------- /security/types.go: -------------------------------------------------------------------------------- 1 | package security 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/dormoron/mist" 7 | "github.com/dormoron/mist/internal/errs" 8 | ) 9 | 10 | // Session interface defines multiple methods for session management. 11 | type Session interface { 12 | // Set assigns a value to a session key. The context is typically used for request-scoped values. 13 | // Parameters: 14 | // - ctx: context for managing deadlines, cancel operation signals, and other request-scoped values ('context.Context') 15 | // - key: the key under which the value is stored ('string') 16 | // - val: the value to store, which can be of any type ('any') 17 | // Returns: 18 | // - error: error, if any occurred while setting the value 19 | Set(ctx context.Context, key string, val any) error 20 | 21 | // Get retrieves the value associated with the key from the session. 22 | // Parameters: 23 | // - ctx: context for managing deadlines, cancel operation signals, and other request-scoped values ('context.Context') 24 | // - key: the key for the value to be retrieved ('string') 25 | // Returns: 26 | // - mist.AnyValue: a wrapper containing the retrieved value or an error if the key wasn't found 27 | Get(ctx context.Context, key string) mist.AnyValue 28 | 29 | // Del deletes the key-value pair associated with the key from the session. 
30 | // Parameters: 31 | // - ctx: context for managing deadlines, cancel operation signals, and other request-scoped values ('context.Context') 32 | // - key: the key for the value to be deleted ('string') 33 | // Returns: 34 | // - error: error, if any occurred while deleting the value 35 | Del(ctx context.Context, key string) error 36 | 37 | // Destroy invalidates the session entirely, clearing all data within the session. 38 | // Parameters: 39 | // - ctx: context for managing deadlines, cancel operation signals, and other request-scoped values ('context.Context') 40 | // Returns: 41 | // - error: error, if any occurred while destroying the session 42 | Destroy(ctx context.Context) error 43 | 44 | // Claims retrieves the claims associated with the session. Claims usually contain user-related data, often in a JWT context. 45 | // Returns: 46 | // - Claims: a set of claims related to the session 47 | Claims() Claims 48 | } 49 | 50 | // Provider interface defines methods for session lifecycle management and JWT claim updates. 51 | type Provider interface { 52 | // InitSession initializes a new session with the specified user ID, JWT data, and session data. 53 | // Parameters: 54 | // - ctx: context for managing deadlines, cancel operation signals, and other request-scoped values ('mist.Context') 55 | // - uid: user ID for which the session is being created ('int64') 56 | // - jwtData: JWT token data (usually claims) to store with the session ('map[string]any') 57 | // - sessData: additional session-specific data to associate with the session ('map[string]any') 58 | // Returns: 59 | // - Session: the initialized session 60 | // - error: error, if any occurred while initializing the session 61 | InitSession(ctx *mist.Context, uid int64, jwtData map[string]any, sessData map[string]any) (Session, error) 62 | 63 | // Get retrieves the current session associated with the context. 
64 | // Parameters: 65 | // - ctx: context for managing deadlines, cancel operation signals, and other request-scoped values ('mist.Context') 66 | // Returns: 67 | // - Session: the current session 68 | // - error: error, if any occurred while retrieving the session 69 | Get(ctx *mist.Context) (Session, error) 70 | 71 | // UpdateClaims updates the claims associated with the current session. 72 | // Parameters: 73 | // - ctx: context for managing deadlines, cancel operation signals, and other request-scoped values ('mist.Context') 74 | // - claims: a new set of claims to associate with the session ('Claims') 75 | // Returns: 76 | // - error: error, if any occurred while updating the claims 77 | UpdateClaims(ctx *mist.Context, claims Claims) error 78 | 79 | // RenewAccessToken renews the access token associated with the session. 80 | // Parameters: 81 | // - ctx: context for managing deadlines, cancel operation signals, and other request-scoped values ('mist.Context') 82 | // Returns: 83 | // - error: error, if any occurred while renewing the access token 84 | RenewAccessToken(ctx *mist.Context) error 85 | 86 | // The ClearToken function is designed to remove or invalidate a security or session token associated with the given context. 87 | // 88 | // Parameters: 89 | // ctx: A pointer to a mist.Context object, which holds contextual information for the function to operate within. 90 | // The mist.Context might include various details like user information, request scope, or environmental settings. 91 | // 92 | // Return: 93 | // error: This function returns an error type. If the token clearing process fails for any reason (e.g., token doesn't exist, 94 | // network issues, permission issues), the function will return a non-nil error indicating what went wrong. 95 | // If the token clearing process is successful, it returns nil. 96 | ClearToken(ctx *mist.Context) error 97 | } 98 | 99 | // Claims structure holds the data associated with the session's JWT claims. 
100 | type Claims struct { 101 | UserID int64 // User ID 102 | SessionID string // Session ID 103 | Data map[string]any // Additional data related to the claims 104 | } 105 | 106 | // Get retrieves the value associated with the key from the claims. 107 | func (c Claims) Get(key string) mist.AnyValue { 108 | val, ok := c.Data[key] 109 | if !ok { 110 | return mist.AnyValue{Err: errs.ErrKeyNotFound(key)} // Return an error if the key is not found 111 | } 112 | return mist.AnyValue{Val: val} // Return the value if the key is found 113 | } 114 | -------------------------------------------------------------------------------- /middlewares/bodylimit/body_limit.go: -------------------------------------------------------------------------------- 1 | package bodylimit 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net/http" 7 | "strconv" 8 | "strings" 9 | 10 | "github.com/dormoron/mist" 11 | ) 12 | 13 | // BodyLimitConfig 配置请求体大小限制中间件 14 | type BodyLimitConfig struct { 15 | // 最大允许大小(字节数) 16 | MaxSize int64 17 | 18 | // 响应状态码,默认413 Request Entity Too Large 19 | StatusCode int 20 | 21 | // 超过限制时的错误消息 22 | ErrorMessage string 23 | 24 | // 白名单路径 - 不受限制的路径前缀 25 | WhitelistPaths []string 26 | 27 | // 跳过OPTIONS/HEAD请求的检查 28 | SkipOptions bool 29 | SkipHead bool 30 | 31 | // 自定义检查是否需要限制的函数 32 | SkipFunc func(ctx *mist.Context) bool 33 | } 34 | 35 | // DefaultBodyLimitConfig 返回默认配置 36 | func DefaultBodyLimitConfig() BodyLimitConfig { 37 | return BodyLimitConfig{ 38 | MaxSize: 1 * 1024 * 1024, // 默认1MB 39 | StatusCode: http.StatusRequestEntityTooLarge, 40 | ErrorMessage: "请求体超过允许的大小限制", 41 | SkipOptions: true, 42 | SkipHead: true, 43 | } 44 | } 45 | 46 | // BodyLimit 创建请求体大小限制中间件 47 | func BodyLimit(maxSize string) mist.Middleware { 48 | size, err := parseSize(maxSize) 49 | if err != nil { 50 | panic(fmt.Sprintf("无效的大小限制: %v", err)) 51 | } 52 | 53 | config := DefaultBodyLimitConfig() 54 | config.MaxSize = size 55 | 56 | return BodyLimitWithConfig(config) 57 | } 58 | 59 | // 
BodyLimitWithConfig 使用自定义配置创建请求体大小限制中间件 60 | func BodyLimitWithConfig(config BodyLimitConfig) mist.Middleware { 61 | // 使用默认值填充未设置的配置 62 | if config.MaxSize <= 0 { 63 | config.MaxSize = DefaultBodyLimitConfig().MaxSize 64 | } 65 | 66 | if config.StatusCode <= 0 { 67 | config.StatusCode = DefaultBodyLimitConfig().StatusCode 68 | } 69 | 70 | if config.ErrorMessage == "" { 71 | config.ErrorMessage = DefaultBodyLimitConfig().ErrorMessage 72 | } 73 | 74 | return func(next mist.HandleFunc) mist.HandleFunc { 75 | return func(ctx *mist.Context) { 76 | // 检查是否需要跳过限制 77 | if shouldSkip(ctx, config) { 78 | next(ctx) 79 | return 80 | } 81 | 82 | // 检查Content-Length头 83 | contentLength := ctx.Request.ContentLength 84 | if contentLength > config.MaxSize { 85 | ctx.AbortWithStatus(config.StatusCode) 86 | ctx.RespondWithJSON(config.StatusCode, map[string]interface{}{ 87 | "error": config.ErrorMessage, 88 | "limit": config.MaxSize, 89 | "size": contentLength, 90 | }) 91 | return 92 | } 93 | 94 | // 限制请求体的大小 95 | ctx.Request.Body = limitedReader(ctx.Request.Body, config.MaxSize, ctx, config) 96 | 97 | next(ctx) 98 | } 99 | } 100 | } 101 | 102 | // shouldSkip 检查是否应该跳过该请求的限制 103 | func shouldSkip(ctx *mist.Context, config BodyLimitConfig) bool { 104 | // 检查自定义跳过函数 105 | if config.SkipFunc != nil && config.SkipFunc(ctx) { 106 | return true 107 | } 108 | 109 | // 检查HTTP方法 110 | method := ctx.Request.Method 111 | if (config.SkipOptions && method == http.MethodOptions) || 112 | (config.SkipHead && method == http.MethodHead) { 113 | return true 114 | } 115 | 116 | // 检查白名单路径 117 | if len(config.WhitelistPaths) > 0 { 118 | path := ctx.Request.URL.Path 119 | for _, prefix := range config.WhitelistPaths { 120 | if strings.HasPrefix(path, prefix) { 121 | return true 122 | } 123 | } 124 | } 125 | 126 | return false 127 | } 128 | 129 | // limitedReader 返回一个受限制的读取器 130 | func limitedReader(body io.ReadCloser, limit int64, ctx *mist.Context, config BodyLimitConfig) io.ReadCloser { 131 | return 
&limitedReadCloser{ 132 | ReadCloser: body, 133 | limit: limit, 134 | ctx: ctx, 135 | config: config, 136 | read: 0, 137 | } 138 | } 139 | 140 | // limitedReadCloser 是一个限制大小的ReadCloser实现 141 | type limitedReadCloser struct { 142 | io.ReadCloser 143 | limit int64 144 | read int64 145 | ctx *mist.Context 146 | config BodyLimitConfig 147 | } 148 | 149 | // Read 实现io.Reader接口,限制读取的总大小 150 | func (l *limitedReadCloser) Read(p []byte) (n int, err error) { 151 | n, err = l.ReadCloser.Read(p) 152 | l.read += int64(n) 153 | 154 | // 检查是否超过限制 155 | if l.read > l.limit { 156 | // 中止请求 157 | l.ctx.AbortWithStatus(l.config.StatusCode) 158 | _ = l.ctx.RespondWithJSON(l.config.StatusCode, map[string]interface{}{ 159 | "error": l.config.ErrorMessage, 160 | "limit": l.limit, 161 | }) 162 | return n, fmt.Errorf("请求体超过大小限制: %d > %d", l.read, l.limit) 163 | } 164 | 165 | return n, err 166 | } 167 | 168 | // parseSize 解析人类可读的大小字符串 169 | // 支持单位: B, K/KB, M/MB, G/GB 170 | func parseSize(sizeStr string) (int64, error) { 171 | sizeStr = strings.TrimSpace(sizeStr) 172 | if sizeStr == "" { 173 | return 0, fmt.Errorf("空大小字符串") 174 | } 175 | 176 | // 查找单位分隔符 177 | var numStr string 178 | var unit string 179 | 180 | if strings.HasSuffix(sizeStr, "B") { 181 | if len(sizeStr) > 2 && (sizeStr[len(sizeStr)-2] == 'K' || 182 | sizeStr[len(sizeStr)-2] == 'M' || 183 | sizeStr[len(sizeStr)-2] == 'G') { 184 | numStr = sizeStr[:len(sizeStr)-2] 185 | unit = sizeStr[len(sizeStr)-2:] 186 | } else { 187 | numStr = sizeStr[:len(sizeStr)-1] 188 | unit = "B" 189 | } 190 | } else if strings.HasSuffix(sizeStr, "K") || 191 | strings.HasSuffix(sizeStr, "M") || 192 | strings.HasSuffix(sizeStr, "G") { 193 | numStr = sizeStr[:len(sizeStr)-1] 194 | unit = sizeStr[len(sizeStr)-1:] 195 | } else { 196 | // 假设为纯数字 197 | numStr = sizeStr 198 | unit = "B" 199 | } 200 | 201 | // 解析数字部分 202 | size, err := strconv.ParseFloat(numStr, 64) 203 | if err != nil { 204 | return 0, fmt.Errorf("无效的大小格式: %v", err) 205 | } 206 | 207 | // 
应用单位 208 | switch unit { 209 | case "B": 210 | return int64(size), nil 211 | case "K", "KB": 212 | return int64(size * 1024), nil 213 | case "M", "MB": 214 | return int64(size * 1024 * 1024), nil 215 | case "G", "GB": 216 | return int64(size * 1024 * 1024 * 1024), nil 217 | default: 218 | return 0, fmt.Errorf("未知的大小单位: %s", unit) 219 | } 220 | } 221 | -------------------------------------------------------------------------------- /security/blocklist/example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "net/http" 9 | "time" 10 | 11 | "github.com/dormoron/mist/security/blocklist" 12 | ) 13 | 14 | type LoginRequest struct { 15 | Username string `json:"username"` 16 | Password string `json:"password"` 17 | } 18 | 19 | type Response struct { 20 | Status string `json:"status"` 21 | Message string `json:"message"` 22 | } 23 | 24 | func main() { 25 | // 解析命令行参数 26 | port := flag.Int("port", 8080, "服务器端口") 27 | flag.Parse() 28 | 29 | // 创建IP黑名单管理器 30 | blocklistManager := blocklist.NewManager( 31 | blocklist.WithMaxFailedAttempts(3), 32 | blocklist.WithBlockDuration(5*time.Minute), 33 | blocklist.WithWhitelistIPs([]string{"127.0.0.1"}), 34 | blocklist.WithOnBlocked(func(w http.ResponseWriter, r *http.Request) { 35 | w.Header().Set("Content-Type", "application/json") 36 | w.WriteHeader(http.StatusForbidden) 37 | json.NewEncoder(w).Encode(Response{ 38 | Status: "error", 39 | Message: "您的IP因多次失败的尝试已被暂时封禁,请稍后再试", 40 | }) 41 | }), 42 | ) 43 | 44 | // 创建路由器 45 | mux := http.NewServeMux() 46 | 47 | // 登录接口 48 | mux.HandleFunc("/api/login", func(w http.ResponseWriter, r *http.Request) { 49 | if r.Method != http.MethodPost { 50 | w.WriteHeader(http.StatusMethodNotAllowed) 51 | return 52 | } 53 | 54 | ip := getClientIP(r) 55 | 56 | // 如果IP已被封禁,直接返回错误 57 | if blocklistManager.IsBlocked(ip) { 58 | blocklistManager.RecordFailure(ip) // 增加失败次数 59 | 
w.Header().Set("Content-Type", "application/json") 60 | w.WriteHeader(http.StatusForbidden) 61 | json.NewEncoder(w).Encode(Response{ 62 | Status: "error", 63 | Message: "您的IP已被封禁,请稍后再试", 64 | }) 65 | return 66 | } 67 | 68 | // 解析登录请求 69 | var req LoginRequest 70 | err := json.NewDecoder(r.Body).Decode(&req) 71 | if err != nil { 72 | blocklistManager.RecordFailure(ip) // 增加失败次数 73 | w.Header().Set("Content-Type", "application/json") 74 | w.WriteHeader(http.StatusBadRequest) 75 | json.NewEncoder(w).Encode(Response{ 76 | Status: "error", 77 | Message: "无效的请求格式", 78 | }) 79 | return 80 | } 81 | 82 | // 模拟身份验证(简化示例) 83 | if req.Username == "admin" && req.Password == "password" { 84 | // 登录成功,重置失败计数 85 | blocklistManager.RecordSuccess(ip) 86 | w.Header().Set("Content-Type", "application/json") 87 | w.WriteHeader(http.StatusOK) 88 | json.NewEncoder(w).Encode(Response{ 89 | Status: "success", 90 | Message: "登录成功", 91 | }) 92 | return 93 | } 94 | 95 | // 登录失败,记录失败尝试 96 | isBlocked := blocklistManager.RecordFailure(ip) 97 | w.Header().Set("Content-Type", "application/json") 98 | 99 | if isBlocked { 100 | w.WriteHeader(http.StatusForbidden) 101 | json.NewEncoder(w).Encode(Response{ 102 | Status: "error", 103 | Message: "您的IP因多次失败的尝试已被封禁,请稍后再试", 104 | }) 105 | } else { 106 | w.WriteHeader(http.StatusUnauthorized) 107 | json.NewEncoder(w).Encode(Response{ 108 | Status: "error", 109 | Message: "用户名或密码错误", 110 | }) 111 | } 112 | }) 113 | 114 | // 受保护的API接口 115 | protectedAPI := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 116 | w.Header().Set("Content-Type", "application/json") 117 | w.WriteHeader(http.StatusOK) 118 | json.NewEncoder(w).Encode(Response{ 119 | Status: "success", 120 | Message: "这是受保护的API接口", 121 | }) 122 | }) 123 | 124 | // 应用IP黑名单中间件到受保护的API 125 | mux.Handle("/api/protected", blocklistManager.Middleware()(protectedAPI)) 126 | 127 | // 管理接口 - 手动解除IP封禁(通常需要管理员权限) 128 | mux.HandleFunc("/api/admin/unblock", func(w http.ResponseWriter, r 
*http.Request) { 129 | if r.Method != http.MethodPost { 130 | w.WriteHeader(http.StatusMethodNotAllowed) 131 | return 132 | } 133 | 134 | // 这里应该添加管理员验证逻辑 135 | // 简化示例,直接从请求中获取IP 136 | ipToUnblock := r.URL.Query().Get("ip") 137 | if ipToUnblock == "" { 138 | w.Header().Set("Content-Type", "application/json") 139 | w.WriteHeader(http.StatusBadRequest) 140 | json.NewEncoder(w).Encode(Response{ 141 | Status: "error", 142 | Message: "请指定要解除封禁的IP", 143 | }) 144 | return 145 | } 146 | 147 | blocklistManager.UnblockIP(ipToUnblock) 148 | w.Header().Set("Content-Type", "application/json") 149 | w.WriteHeader(http.StatusOK) 150 | json.NewEncoder(w).Encode(Response{ 151 | Status: "success", 152 | Message: fmt.Sprintf("IP %s 已解除封禁", ipToUnblock), 153 | }) 154 | }) 155 | 156 | // 状态检查接口 157 | mux.HandleFunc("/api/admin/status", func(w http.ResponseWriter, r *http.Request) { 158 | ipToCheck := r.URL.Query().Get("ip") 159 | if ipToCheck == "" { 160 | w.Header().Set("Content-Type", "application/json") 161 | w.WriteHeader(http.StatusBadRequest) 162 | json.NewEncoder(w).Encode(Response{ 163 | Status: "error", 164 | Message: "请指定要检查的IP", 165 | }) 166 | return 167 | } 168 | 169 | // 检查IP状态 170 | isBlocked := blocklistManager.IsBlocked(ipToCheck) 171 | 172 | status := "正常" 173 | if isBlocked { 174 | status = "已封禁" 175 | } 176 | 177 | w.Header().Set("Content-Type", "application/json") 178 | w.WriteHeader(http.StatusOK) 179 | json.NewEncoder(w).Encode(map[string]interface{}{ 180 | "status": "success", 181 | "ip": ipToCheck, 182 | "isBlocked": isBlocked, 183 | "state": status, 184 | }) 185 | }) 186 | 187 | // 启动服务器 188 | serverAddr := fmt.Sprintf(":%d", *port) 189 | log.Printf("Starting server on %s", serverAddr) 190 | log.Fatal(http.ListenAndServe(serverAddr, mux)) 191 | } 192 | 193 | // 从请求中获取客户端IP 194 | func getClientIP(r *http.Request) string { 195 | // 尝试从X-Forwarded-For头获取 196 | ip := r.Header.Get("X-Forwarded-For") 197 | if ip != "" { 198 | return ip 199 | } 200 | 201 | // 
尝试从X-Real-IP头获取 202 | ip = r.Header.Get("X-Real-IP") 203 | if ip != "" { 204 | return ip 205 | } 206 | 207 | // 否则使用RemoteAddr 208 | return r.RemoteAddr 209 | } 210 | -------------------------------------------------------------------------------- /OPTIMIZATION.md: -------------------------------------------------------------------------------- 1 | # Mist框架性能优化 2 | 3 | 本文档描述了Mist框架的性能优化实现,这些优化旨在提高框架的性能、可靠性和可扩展性。 4 | 5 | ## 1. HTTP/3支持优化 6 | 7 | - **问题**: 当前HTTP/3支持仅为实验性功能,缺少完整实现 8 | - **解决方案**: 9 | - 实现了完整的HTTP/3服务器支持,基于quic-go/http3和quic-go/quic-go库 10 | - 提供了平台特定的实现,确保跨平台兼容性 11 | - 支持Alt-Svc响应头,便于客户端发现HTTP/3能力 12 | - 自动回退机制,当客户端不支持HTTP/3时回退到HTTP/2 13 | 14 | ## 2. 自适应路由缓存 15 | 16 | - **问题**: 现有的路由缓存机制不够智能,无法根据请求模式动态调整 17 | - **解决方案**: 18 | - 实现了自适应路由缓存(AdaptiveCache) 19 | - 根据路由访问频率、响应时间和最近访问时间自动调整缓存优先级 20 | - 提供权重计算机制,确保高频访问路径始终保持在缓存中 21 | - 支持定期清理和智能淘汰策略 22 | 23 | ## 3. 上下文对象池优化 24 | 25 | - **问题**: Context对象频繁创建和销毁,导致GC压力大 26 | - **解决方案**: 27 | - 增强了Context对象池实现,优化回收和复用逻辑 28 | - 添加了缓冲区池,减少响应数据的内存分配 29 | - 实现了细粒度的资源释放策略,避免内存泄漏 30 | 31 | ## 4. 零拷贝响应机制 32 | 33 | - **问题**: 文件传输需要在用户空间和内核空间之间多次复制,浪费CPU和内存资源 34 | - **解决方案**: 35 | - 实现了零拷贝文件传输机制(ZeroCopyResponse) 36 | - 利用系统级sendfile调用,减少数据拷贝 37 | - 添加了跨平台支持,在不支持零拷贝的平台上自动回退 38 | - 支持Range请求,实现部分内容传输 39 | 40 | ## 5. 请求体限制中间件 41 | 42 | - **问题**: 缺少轻量级的请求体大小限制中间件,易受大请求攻击 43 | - **解决方案**: 44 | - 实现了可配置的请求体大小限制中间件(BodyLimit) 45 | - 支持基于路径的白名单和黑名单 46 | - 可根据HTTP方法选择性应用 47 | - 支持自定义错误响应和状态码 48 | 49 | ## 6. 内存使用监控 50 | 51 | - **问题**: 缺少内置的内存使用监控功能,难以发现内存泄漏和优化内存使用 52 | - **解决方案**: 53 | - 实现了MemoryMonitor组件,提供实时内存使用统计 54 | - 支持告警机制,内存异常增长时自动通知 55 | - 提供详细的内存使用报告,包括趋势分析 56 | - 支持手动触发GC功能,便于测试和调试 57 | 58 | ## 测试结果 59 | 60 | 各项优化通过单元测试验证,测试文件位于`optimizations_test.go`。测试结果表明: 61 | 62 | 1. 自适应路由缓存能够正确根据访问模式调整缓存内容 63 | 2. 请求体限制中间件能够有效阻止超大请求 64 | 3. 零拷贝文件传输机制能够正常工作 65 | 4. 
内存监控功能能够准确跟踪内存使用情况 66 | 67 | ## 使用方法 68 | 69 | ### HTTP/3支持 70 | 71 | ```go 72 | server := mist.InitHTTPServer() 73 | // 启用HTTP/3 74 | err := server.StartHTTP3(":443", "cert.pem", "key.pem") 75 | ``` 76 | 77 | ### 请求体限制中间件 78 | 79 | ```go 80 | server := mist.InitHTTPServer() 81 | // 全局限制请求体大小为1MB(需导入 github.com/dormoron/mist/middlewares/bodylimit) 82 | server.Use(bodylimit.BodyLimit("1MB")) 83 | 84 | // 或使用自定义配置 85 | config := bodylimit.DefaultBodyLimitConfig() 86 | config.MaxSize = 2 * 1024 * 1024 // 2MB 87 | config.WhitelistPaths = []string{"/upload"} 88 | server.Use(bodylimit.BodyLimitWithConfig(config)) 89 | ``` 90 | 91 | ### 零拷贝文件传输 92 | 93 | ```go 94 | server.GET("/download/:file", func(ctx *mist.Context) { 95 | fileName, _ := ctx.PathValue("file").String() 96 | zr := mist.NewZeroCopyResponse(ctx.ResponseWriter) 97 | err := zr.ServeFile(filepath.Join("path/to/files", filepath.Base(fileName))) // filepath.Base 防止 "../" 路径穿越 98 | if err != nil { 99 | ctx.RespondWithJSON(http.StatusInternalServerError, map[string]string{ 100 | "error": err.Error(), 101 | }) 102 | } 103 | }) 104 | ``` 105 | 106 | ### 内存监控 107 | 108 | ```go 109 | monitor := mist.NewMemoryMonitor(10*time.Second, 60) 110 | monitor.AddAlertCallback(func(stats mist.MemStats, message string) { 111 | log.Printf("Memory alert: %s, Alloc: %d bytes", message, stats.Alloc) 112 | }) 113 | monitor.Start() 114 | defer monitor.Stop() 115 | 116 | // 获取内存报告 117 | report := monitor.GetMemoryUsageReport() 118 | fmt.Printf("Memory report: %+v\n", report) 119 | ``` 120 | 121 | 本文档描述了Mist框架的性能优化实现,这些优化旨在提高框架的性能、可靠性和可扩展性。 122 | 123 | ## 1. HTTP/3支持优化 124 | 125 | - **问题**: 当前HTTP/3支持仅为实验性功能,缺少完整实现 126 | - **解决方案**: 127 | - 实现了完整的HTTP/3服务器支持,基于quic-go/http3和quic-go/quic-go库 128 | - 提供了平台特定的实现,确保跨平台兼容性 129 | - 支持Alt-Svc响应头,便于客户端发现HTTP/3能力 130 | - 自动回退机制,当客户端不支持HTTP/3时回退到HTTP/2 131 | 132 | ## 2. 自适应路由缓存 133 | 134 | - **问题**: 现有的路由缓存机制不够智能,无法根据请求模式动态调整 135 | - **解决方案**: 136 | - 实现了自适应路由缓存(AdaptiveCache) 137 | - 根据路由访问频率、响应时间和最近访问时间自动调整缓存优先级 138 | - 提供权重计算机制,确保高频访问路径始终保持在缓存中 139 | - 支持定期清理和智能淘汰策略 140 | 141 | ## 3.
上下文对象池优化 142 | 143 | - **问题**: Context对象频繁创建和销毁,导致GC压力大 144 | - **解决方案**: 145 | - 增强了Context对象池实现,优化回收和复用逻辑 146 | - 添加了缓冲区池,减少响应数据的内存分配 147 | - 实现了细粒度的资源释放策略,避免内存泄漏 148 | 149 | ## 4. 零拷贝响应机制 150 | 151 | - **问题**: 文件传输需要在用户空间和内核空间之间多次复制,浪费CPU和内存资源 152 | - **解决方案**: 153 | - 实现了零拷贝文件传输机制(ZeroCopyResponse) 154 | - 利用系统级sendfile调用,减少数据拷贝 155 | - 添加了跨平台支持,在不支持零拷贝的平台上自动回退 156 | - 支持Range请求,实现部分内容传输 157 | 158 | ## 5. 请求体限制中间件 159 | 160 | - **问题**: 缺少轻量级的请求体大小限制中间件,易受大请求攻击 161 | - **解决方案**: 162 | - 实现了可配置的请求体大小限制中间件(BodyLimit) 163 | - 支持基于路径前缀的白名单(WhitelistPaths)和自定义跳过函数(SkipFunc) 164 | - 可根据HTTP方法选择性应用 165 | - 支持自定义错误响应和状态码 166 | 167 | ## 6. 内存使用监控 168 | 169 | - **问题**: 缺少内置的内存使用监控功能,难以发现内存泄漏和优化内存使用 170 | - **解决方案**: 171 | - 实现了MemoryMonitor组件,提供实时内存使用统计 172 | - 支持告警机制,内存异常增长时自动通知 173 | - 提供详细的内存使用报告,包括趋势分析 174 | - 支持手动触发GC功能,便于测试和调试 175 | 176 | ## 测试结果 177 | 178 | 各项优化通过单元测试验证,测试文件位于`optimizations_test.go`。测试结果表明: 179 | 180 | 1. 自适应路由缓存能够正确根据访问模式调整缓存内容 181 | 2. 请求体限制中间件能够有效阻止超大请求 182 | 3. 零拷贝文件传输机制能够正常工作 183 | 4. 内存监控功能能够准确跟踪内存使用情况 184 | 185 | ## 使用方法 186 | 187 | ### HTTP/3支持 188 | 189 | ```go 190 | server := mist.InitHTTPServer() 191 | // 启用HTTP/3 192 | err := server.StartHTTP3(":443", "cert.pem", "key.pem") 193 | ``` 194 | 195 | ### 请求体限制中间件 196 | 197 | ```go 198 | server := mist.InitHTTPServer() 199 | // 全局限制请求体大小为1MB(需导入 github.com/dormoron/mist/middlewares/bodylimit) 200 | server.Use(bodylimit.BodyLimit("1MB")) 201 | 202 | // 或使用自定义配置 203 | config := bodylimit.DefaultBodyLimitConfig() 204 | config.MaxSize = 2 * 1024 * 1024 // 2MB 205 | config.WhitelistPaths = []string{"/upload"} 206 | server.Use(bodylimit.BodyLimitWithConfig(config)) 207 | ``` 208 | 209 | ### 零拷贝文件传输 210 | 211 | ```go 212 | server.GET("/download/:file", func(ctx *mist.Context) { 213 | fileName, _ := ctx.PathValue("file").String() 214 | zr := mist.NewZeroCopyResponse(ctx.ResponseWriter) 215 | err := zr.ServeFile(filepath.Join("path/to/files", filepath.Base(fileName))) // filepath.Base 防止 "../" 路径穿越 216 | if err != nil { 217 | ctx.RespondWithJSON(http.StatusInternalServerError, map[string]string{ 218 | "error": err.Error(), 219 | }) 220 |
} 221 | }) 222 | ``` 223 | 224 | ### 内存监控 225 | 226 | ```go 227 | monitor := mist.NewMemoryMonitor(10*time.Second, 60) 228 | monitor.AddAlertCallback(func(stats mist.MemStats, message string) { 229 | log.Printf("Memory alert: %s, Alloc: %d bytes", message, stats.Alloc) 230 | }) 231 | monitor.Start() 232 | defer monitor.Stop() 233 | 234 | // 获取内存报告 235 | report := monitor.GetMemoryUsageReport() 236 | fmt.Printf("Memory report: %+v\n", report) 237 | ``` -------------------------------------------------------------------------------- /internal/errs/api_errors.go: -------------------------------------------------------------------------------- 1 | package errs 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "net/http" 7 | ) 8 | 9 | // ErrorType 定义标准错误类型 10 | type ErrorType string 11 | 12 | const ( 13 | // 错误类型常量 14 | ErrorTypeValidation ErrorType = "VALIDATION_ERROR" // 数据验证错误 15 | ErrorTypeAuth ErrorType = "AUTH_ERROR" // 认证错误 16 | ErrorTypePermission ErrorType = "PERMISSION_ERROR" // 权限错误 17 | ErrorTypeResource ErrorType = "RESOURCE_ERROR" // 资源未找到 18 | ErrorTypeInput ErrorType = "INPUT_ERROR" // 输入错误 19 | ErrorTypeInternal ErrorType = "INTERNAL_ERROR" // 内部服务器错误 20 | ErrorTypeUnavailable ErrorType = "UNAVAILABLE_ERROR" // 服务不可用 21 | ErrorTypeRateLimit ErrorType = "RATE_LIMIT_ERROR" // 限流错误 22 | ErrorTypeTimeout ErrorType = "TIMEOUT_ERROR" // 超时错误 23 | ErrorTypeUnprocessable ErrorType = "UNPROCESSABLE_ERROR" // 无法处理的实体 24 | ) 25 | 26 | // APIError 统一API错误结构 27 | type APIError struct { 28 | Type ErrorType `json:"type"` // 错误类型 29 | Code int `json:"code"` // HTTP状态码 30 | Message string `json:"message"` // 用户友好的错误信息 31 | Details any `json:"details,omitempty"` // 详细错误信息(可选) 32 | } 33 | 34 | // Error 实现error接口 35 | func (e *APIError) Error() string { 36 | return e.Message 37 | } 38 | 39 | // ToJSON 将APIError转换为JSON格式 40 | func (e *APIError) ToJSON() []byte { 41 | data, err := json.Marshal(e) 42 | if err != nil { 43 | return 
[]byte(`{"type":"INTERNAL_ERROR","code":500,"message":"Error serializing error response"}`) 44 | } 45 | return data 46 | } 47 | 48 | // WithDetails 添加错误详情 49 | func (e *APIError) WithDetails(details any) *APIError { 50 | e.Details = details 51 | return e 52 | } 53 | 54 | // 预定义错误 - 输入验证 55 | func NewValidationError(message string) *APIError { 56 | return &APIError{ 57 | Type: ErrorTypeValidation, 58 | Code: http.StatusBadRequest, 59 | Message: message, 60 | } 61 | } 62 | 63 | // 预定义错误 - 认证失败 64 | func NewAuthError(message string) *APIError { 65 | if message == "" { 66 | message = "Authentication failed" 67 | } 68 | return &APIError{ 69 | Type: ErrorTypeAuth, 70 | Code: http.StatusUnauthorized, 71 | Message: message, 72 | } 73 | } 74 | 75 | // 预定义错误 - 权限不足 76 | func NewPermissionError(message string) *APIError { 77 | if message == "" { 78 | message = "Permission denied" 79 | } 80 | return &APIError{ 81 | Type: ErrorTypePermission, 82 | Code: http.StatusForbidden, 83 | Message: message, 84 | } 85 | } 86 | 87 | // 预定义错误 - 资源未找到 88 | func NewResourceNotFoundError(resource string) *APIError { 89 | message := "Resource not found" 90 | if resource != "" { 91 | message = fmt.Sprintf("%s not found", resource) 92 | } 93 | return &APIError{ 94 | Type: ErrorTypeResource, 95 | Code: http.StatusNotFound, 96 | Message: message, 97 | } 98 | } 99 | 100 | // 预定义错误 - 请求超时 101 | func NewTimeoutError(message string) *APIError { 102 | if message == "" { 103 | message = "Request timed out" 104 | } 105 | return &APIError{ 106 | Type: ErrorTypeTimeout, 107 | Code: http.StatusRequestTimeout, 108 | Message: message, 109 | } 110 | } 111 | 112 | // 预定义错误 - 请求频率过高 113 | func NewRateLimitError(message string) *APIError { 114 | if message == "" { 115 | message = "Too many requests" 116 | } 117 | return &APIError{ 118 | Type: ErrorTypeRateLimit, 119 | Code: http.StatusTooManyRequests, 120 | Message: message, 121 | } 122 | } 123 | 124 | // 预定义错误 - 服务器内部错误 125 | func NewInternalError(message string) 
*APIError { 126 | if message == "" { 127 | message = "Internal server error" 128 | } 129 | return &APIError{ 130 | Type: ErrorTypeInternal, 131 | Code: http.StatusInternalServerError, 132 | Message: message, 133 | } 134 | } 135 | 136 | // 预定义错误 - 服务不可用 137 | func NewServiceUnavailableError(message string) *APIError { 138 | if message == "" { 139 | message = "Service unavailable" 140 | } 141 | return &APIError{ 142 | Type: ErrorTypeUnavailable, 143 | Code: http.StatusServiceUnavailable, 144 | Message: message, 145 | } 146 | } 147 | 148 | // 从错误类型创建错误 149 | func NewErrorFromType(errorType ErrorType, message string) *APIError { 150 | switch errorType { 151 | case ErrorTypeValidation: 152 | return NewValidationError(message) 153 | case ErrorTypeAuth: 154 | return NewAuthError(message) 155 | case ErrorTypePermission: 156 | return NewPermissionError(message) 157 | case ErrorTypeResource: 158 | return NewResourceNotFoundError(message) 159 | case ErrorTypeTimeout: 160 | return NewTimeoutError(message) 161 | case ErrorTypeRateLimit: 162 | return NewRateLimitError(message) 163 | case ErrorTypeInternal: 164 | return NewInternalError(message) 165 | case ErrorTypeUnavailable: 166 | return NewServiceUnavailableError(message) 167 | default: 168 | return NewInternalError(message) 169 | } 170 | } 171 | 172 | // 从标准错误创建API错误 173 | func WrapError(err error) *APIError { 174 | if err == nil { 175 | return nil 176 | } 177 | 178 | // 检查是否已经是APIError 179 | if apiErr, ok := err.(*APIError); ok { 180 | return apiErr 181 | } 182 | 183 | // 返回通用内部错误 184 | return NewInternalError(err.Error()) 185 | } 186 | 187 | // 从HTTP状态码创建错误 188 | func NewErrorFromStatus(statusCode int, message string) *APIError { 189 | switch statusCode { 190 | case http.StatusBadRequest: 191 | return NewValidationError(message) 192 | case http.StatusUnauthorized: 193 | return NewAuthError(message) 194 | case http.StatusForbidden: 195 | return NewPermissionError(message) 196 | case http.StatusNotFound: 197 | return 
NewResourceNotFoundError(message) 198 | case http.StatusRequestTimeout: 199 | return NewTimeoutError(message) 200 | case http.StatusTooManyRequests: 201 | return NewRateLimitError(message) 202 | case http.StatusInternalServerError: 203 | return NewInternalError(message) 204 | case http.StatusServiceUnavailable: 205 | return NewServiceUnavailableError(message) 206 | default: 207 | if statusCode >= 500 { 208 | return NewInternalError(message) 209 | } 210 | return &APIError{ 211 | Type: ErrorTypeInput, 212 | Code: statusCode, 213 | Message: message, 214 | } 215 | } 216 | } 217 | -------------------------------------------------------------------------------- /internal/errs/error.go: -------------------------------------------------------------------------------- 1 | package errs 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | var ( 9 | // base 10 | errInvalidType = errors.New("base: type conversion failed, expected type") 11 | // web 12 | errKeyNotFound = errors.New("session: key not found") 13 | errSessionNotFound = errors.New("session: session not found") 14 | errIdSessionNotFound = errors.New("session: session corresponding to id does not exist") 15 | errVerificationFailed = errors.New("session: verification failed") 16 | errEmptyRefreshOpts = errors.New("refreshJWTOptions are nil") 17 | // context error 18 | errInputNil = errors.New("web: input cannot be nil") 19 | errBodyNil = errors.New("web: body is nil") 20 | errKeyNil = errors.New("web: key does not exist") 21 | // router errors 22 | errPathNotAllowWildcardAndPath = errors.New("web: illegal route, path parameter route already exists. Cannot register wildcard route and parameter route at the same time") 23 | errPathNotAllowPathAndRegular = errors.New("web: illegal route, path parameter route already exists. Cannot register regular route and parameter route at the same time") 24 | errRegularNotAllowWildcardAndRegular = errors.New("web: illegal route, regular route already exists. 
Cannot register wildcard route and regular route at the same time") 25 | errRegularNotAllowRegularAndPath = errors.New("web: illegal route, regular route already exists. Cannot register regular route and parameter route at the same time") 26 | errWildcardNotAllowWildcardAndPath = errors.New("web: illegal route, wildcard route already exists. Cannot register wildcard route and parameter route at the same time") 27 | errWildcardNotAllowWildcardAndRegular = errors.New("web: illegal route, wildcard route already exists. Cannot register wildcard route and regular route at the same time") 28 | errPathClash = errors.New("web: route conflict, parameter routes clash") 29 | errRegularClash = errors.New("web: route conflict, regular routes clash") 30 | errRegularExpression = errors.New("web: regular expression error") 31 | errInvalidRegularFormat = errors.New("web: invalid regular expression format") 32 | errRouterNotString = errors.New("web: route is an empty string") 33 | errRouterFront = errors.New("web: route must start with '/'") 34 | errRouterBack = errors.New("web: route cannot end with '/'") 35 | errRouterGroupFront = errors.New("web: route group must start with '/'") 36 | errRouterGroupBack = errors.New("web: route group cannot end with '/'") 37 | errRouterChildConflict = errors.New("web: Child routes must start with '/'") 38 | errRouterConflict = errors.New("web: route conflict") 39 | errRouterNotSymbolic = errors.New("web: illegal route. Routes like //a/b, /a//b etc. 
are not allowed") 40 | ) 41 | 42 | func ErrInvalidType(want string, got any) error { 43 | return fmt.Errorf("%w :%s, actual value:%#v", errInvalidType, want, got) 44 | } 45 | 46 | func ErrKeyNotFound(key string) error { 47 | return fmt.Errorf("%w, key %s", errKeyNotFound, key) 48 | } 49 | 50 | func ErrSessionNotFound() error { 51 | return fmt.Errorf("%w", errSessionNotFound) 52 | } 53 | 54 | func ErrIdSessionNotFound() error { 55 | return fmt.Errorf("%w", errIdSessionNotFound) 56 | } 57 | 58 | func ErrVerificationFailed(err error) error { 59 | return fmt.Errorf("%w, %w", errVerificationFailed, err) 60 | } 61 | 62 | func ErrEmptyRefreshOpts() error { 63 | return fmt.Errorf("%w", errEmptyRefreshOpts) 64 | } 65 | 66 | func ErrInputNil() error { 67 | return fmt.Errorf("%w", errInputNil) 68 | } 69 | 70 | func ErrBodyNil() error { 71 | return fmt.Errorf("%w", errBodyNil) 72 | } 73 | 74 | func ErrKeyNil() error { 75 | return fmt.Errorf("%w", errKeyNil) 76 | } 77 | 78 | func ErrPathNotAllowWildcardAndPath(path string) error { 79 | return fmt.Errorf("%w [%s]", errPathNotAllowWildcardAndPath, path) 80 | } 81 | 82 | func ErrPathNotAllowPathAndRegular(path string) error { 83 | return fmt.Errorf("%w [%s]", errPathNotAllowPathAndRegular, path) 84 | } 85 | 86 | func ErrRegularNotAllowWildcardAndRegular(path string) error { 87 | return fmt.Errorf("%w [%s]", errRegularNotAllowWildcardAndRegular, path) 88 | } 89 | 90 | func ErrRegularNotAllowRegularAndPath(path string) error { 91 | return fmt.Errorf("%w [%s]", errRegularNotAllowRegularAndPath, path) 92 | } 93 | 94 | func ErrWildcardNotAllowWildcardAndPath(path string) error { 95 | return fmt.Errorf("%w [%s]", errWildcardNotAllowWildcardAndPath, path) 96 | } 97 | 98 | func ErrWildcardNotAllowWildcardAndRegular(path string) error { 99 | return fmt.Errorf("%w [%s]", errWildcardNotAllowWildcardAndRegular, path) 100 | } 101 | 102 | func ErrPathClash(pathParam string, path string) error { 103 | return fmt.Errorf("%w: existing parameter 
route %s, attempting to register new %s", errPathClash, pathParam, path) 104 | } 105 | 106 | func ErrRegularClash(pathParam string, path string) error { 107 | return fmt.Errorf("%w: existing regular route %s, attempting to register new %s", errRegularClash, pathParam, path) 108 | } 109 | 110 | func ErrRegularExpression(err error) error { 111 | return fmt.Errorf("%w %w", errRegularExpression, err) 112 | } 113 | 114 | func ErrInvalidRegularFormat(path string) error { 115 | return fmt.Errorf("%w [%s]", errInvalidRegularFormat, path) 116 | } 117 | 118 | func ErrRouterNotString() error { 119 | return fmt.Errorf("%w", errRouterNotString) 120 | } 121 | 122 | func ErrRouterFront() error { 123 | return fmt.Errorf("%w", errRouterFront) 124 | } 125 | 126 | func ErrRouterBack() error { 127 | return fmt.Errorf("%w", errRouterBack) 128 | } 129 | 130 | func ErrRouterGroupFront() error { 131 | return fmt.Errorf("%w", errRouterGroupFront) 132 | } 133 | 134 | func ErrRouterGroupBack() error { 135 | return fmt.Errorf("%w", errRouterGroupBack) 136 | } 137 | 138 | func ErrRouterChildConflict() error { 139 | return fmt.Errorf("%w", errRouterChildConflict) 140 | } 141 | 142 | func ErrRouterConflict(val string) error { 143 | return fmt.Errorf("%w [%s]", errRouterConflict, val) 144 | } 145 | 146 | func ErrRouterNotSymbolic(path string) error { 147 | return fmt.Errorf("%w, [%s]", errRouterNotSymbolic, path) 148 | } 149 | -------------------------------------------------------------------------------- /optimizations_test.go: -------------------------------------------------------------------------------- 1 | package mist_test 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | "net/http/httptest" 7 | "os" 8 | "path/filepath" 9 | "strings" 10 | "testing" 11 | "time" 12 | 13 | "github.com/dormoron/mist/middlewares/bodylimit" 14 | 15 | "github.com/dormoron/mist" 16 | "github.com/stretchr/testify/assert" 17 | ) 18 | 19 | // 测试自适应路由缓存 20 | func TestAdaptiveCache(t *testing.T) { 21 | r := 
mist.InitHTTPServer() 22 | 23 | // 设置一些测试路由 24 | r.GET("/cached/static", func(ctx *mist.Context) { 25 | ctx.RespondWithJSON(http.StatusOK, map[string]interface{}{ 26 | "message": "This is a cached static response", 27 | }) 28 | }) 29 | 30 | r.GET("/cached/:param", func(ctx *mist.Context) { 31 | paramValue, _ := ctx.PathValue("param").String() 32 | ctx.RespondWithJSON(http.StatusOK, map[string]interface{}{ 33 | "param": paramValue, 34 | }) 35 | }) 36 | 37 | // 进行测试请求 38 | for i := 0; i < 100; i++ { 39 | req := httptest.NewRequest(http.MethodGet, "/cached/static", nil) 40 | recorder := httptest.NewRecorder() 41 | r.ServeHTTP(recorder, req) 42 | assert.Equal(t, http.StatusOK, recorder.Code) 43 | } 44 | 45 | // 检查缓存统计 46 | hits, misses, size := r.CacheStats() 47 | t.Logf("Cache stats after static requests: hits=%d, misses=%d, size=%d", hits, misses, size) 48 | assert.Greater(t, hits, uint64(0), "Cache hits should be greater than 0") 49 | 50 | // 测试参数化路由 51 | for i := 0; i < 100; i++ { 52 | paramValue := "param" + string(rune(i%5+'0')) 53 | req := httptest.NewRequest(http.MethodGet, "/cached/"+paramValue, nil) 54 | recorder := httptest.NewRecorder() 55 | r.ServeHTTP(recorder, req) 56 | assert.Equal(t, http.StatusOK, recorder.Code) 57 | } 58 | 59 | // 再次检查缓存统计 60 | hits2, misses2, size2 := r.CacheStats() 61 | t.Logf("Cache stats after all requests: hits=%d, misses=%d, size=%d", hits2, misses2, size2) 62 | assert.Greater(t, hits2, hits, "Cache hits should increase") 63 | assert.Greater(t, size2, 1, "Cache size should be greater than 1") 64 | } 65 | 66 | // 测试请求体大小限制中间件 67 | func TestBodyLimitMiddleware(t *testing.T) { 68 | r := mist.InitHTTPServer() 69 | 70 | // 应用请求体限制中间件 71 | r.Use(bodylimit.BodyLimit("1KB")) 72 | 73 | // 测试路由 74 | r.POST("/test/body", func(ctx *mist.Context) { 75 | body, _ := io.ReadAll(ctx.Request.Body) 76 | ctx.RespondWithJSON(http.StatusOK, map[string]interface{}{ 77 | "size": len(body), 78 | }) 79 | }) 80 | 81 | // 发送小于限制的请求 82 | smallBody := 
strings.Repeat("a", 500) // 500字节 83 | req := httptest.NewRequest(http.MethodPost, "/test/body", strings.NewReader(smallBody)) 84 | req.Header.Set("Content-Length", "500") 85 | recorder := httptest.NewRecorder() 86 | r.ServeHTTP(recorder, req) 87 | assert.Equal(t, http.StatusOK, recorder.Code) 88 | 89 | // 发送大于限制的请求 90 | largeBody := strings.Repeat("a", 2*1024) // 2KB 91 | req = httptest.NewRequest(http.MethodPost, "/test/body", strings.NewReader(largeBody)) 92 | req.Header.Set("Content-Length", "2048") 93 | recorder = httptest.NewRecorder() 94 | r.ServeHTTP(recorder, req) 95 | assert.Equal(t, http.StatusRequestEntityTooLarge, recorder.Code) 96 | } 97 | 98 | // 测试零拷贝文件传输 99 | func TestZeroCopyResponse(t *testing.T) { 100 | r := mist.InitHTTPServer() 101 | 102 | // 创建临时测试文件 103 | tempDir := t.TempDir() 104 | tempFile := filepath.Join(tempDir, "testfile.txt") 105 | 106 | // 写入一些测试数据 107 | testData := strings.Repeat("This is a test file content.\n", 1000) // 约30KB 108 | err := os.WriteFile(tempFile, []byte(testData), 0644) 109 | assert.NoError(t, err) 110 | 111 | // 设置测试路由 112 | r.GET("/file/standard", func(ctx *mist.Context) { 113 | // 标准方式提供文件 114 | http.ServeFile(ctx.ResponseWriter, ctx.Request, tempFile) 115 | }) 116 | 117 | r.GET("/file/zerocopy", func(ctx *mist.Context) { 118 | // 零拷贝方式提供文件 119 | zr := mist.NewZeroCopyResponse(ctx.ResponseWriter) 120 | err := zr.ServeFile(tempFile) 121 | if err != nil { 122 | ctx.RespondWithJSON(http.StatusInternalServerError, map[string]string{ 123 | "error": err.Error(), 124 | }) 125 | } 126 | }) 127 | 128 | // 测试标准方式 129 | req := httptest.NewRequest(http.MethodGet, "/file/standard", nil) 130 | recorder := httptest.NewRecorder() 131 | r.ServeHTTP(recorder, req) 132 | assert.Equal(t, http.StatusOK, recorder.Code) 133 | assert.Equal(t, len(testData), recorder.Body.Len()) 134 | 135 | // 测试零拷贝方式 136 | req = httptest.NewRequest(http.MethodGet, "/file/zerocopy", nil) 137 | recorder = httptest.NewRecorder() 138 | 
r.ServeHTTP(recorder, req) 139 | assert.Equal(t, http.StatusOK, recorder.Code) 140 | assert.Equal(t, len(testData), recorder.Body.Len()) 141 | } 142 | 143 | // 测试内存监控功能 144 | func TestMemoryMonitor(t *testing.T) { 145 | // 创建内存监控器,不启动监控线程 146 | monitor := mist.NewMemoryMonitor(time.Second, 10) 147 | 148 | // 收集初始样本(手动方式) 149 | initialStats := monitor.GetCurrentStats() 150 | t.Logf("Initial memory: %d bytes, goroutines: %d", initialStats.Alloc, initialStats.NumGoroutine) 151 | 152 | // 创建一些内存压力 153 | var memoryPressure [][]byte 154 | for i := 0; i < 3; i++ { 155 | memoryPressure = append(memoryPressure, make([]byte, 1024*1024)) // 1MB 156 | } 157 | 158 | // 手动触发GC 159 | monitor.ForceGC() 160 | 161 | // 手动获取内存报告 162 | report := monitor.GetMemoryUsageReport() 163 | t.Logf("Memory report: %+v", report) 164 | 165 | // 验证内存监控器的基本功能 166 | assert.NotNil(t, report["current"]) 167 | 168 | // 释放内存 169 | memoryPressure = nil 170 | 171 | // 测试告警回调添加功能 172 | monitor.AddAlertCallback(func(stats mist.MemStats, message string) { 173 | t.Logf("callback would be called with message: %s", message) 174 | }) 175 | 176 | // 设置告警阈值 177 | monitor.SetAlertThreshold(0.1) 178 | 179 | // 即使不启动监控,也应该能够获取样本和报告 180 | samples := monitor.GetSamples() 181 | t.Logf("Samples count: %d", len(samples)) 182 | 183 | assert.True(t, true, "测试成功完成") 184 | } 185 | 186 | // 测试HTTP/3支持 187 | func TestHTTP3Config(t *testing.T) { 188 | // 注意:此测试仅验证HTTP/3服务器配置,不会实际启动服务器 189 | 190 | // 检查HTTP/3配置默认值 191 | config := mist.DefaultHTTP3Config() 192 | 193 | // 验证配置默认值 194 | assert.Greater(t, config.MaxIdleTimeout, time.Duration(0)) 195 | assert.Greater(t, config.MaxIncomingStreams, int64(0)) 196 | 197 | // 检查配置值调整 198 | config.MaxIdleTimeout = 60 * time.Second 199 | config.MaxIncomingStreams = 200 200 | 201 | assert.Equal(t, 60*time.Second, config.MaxIdleTimeout) 202 | assert.Equal(t, int64(200), config.MaxIncomingStreams) 203 | } 204 | -------------------------------------------------------------------------------- 
/security/password/password_test.go: -------------------------------------------------------------------------------- 1 | package password 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | // 测试默认参数 12 | func TestDefaultParams(t *testing.T) { 13 | params := DefaultParams() 14 | 15 | // 验证默认参数 16 | assert.Equal(t, uint32(64*1024), params.Memory, "默认内存应为64MB") 17 | assert.Equal(t, uint32(3), params.Iterations, "默认迭代次数应为3") 18 | assert.Equal(t, uint8(4), params.Parallelism, "默认并行度应为4") 19 | assert.Equal(t, uint32(16), params.SaltLength, "默认盐长度应为16字节") 20 | assert.Equal(t, uint32(32), params.KeyLength, "默认密钥长度应为32字节") 21 | } 22 | 23 | // 测试密码哈希生成和验证 24 | func TestPasswordHashingAndVerification(t *testing.T) { 25 | // 测试用例 26 | testCases := []struct { 27 | name string 28 | password string 29 | }{ 30 | {"简单密码", "password123"}, 31 | {"复杂密码", "P@ssw0rd!ComplexPassword123"}, 32 | {"包含特殊字符", "!@#$%^&*()_+=-[]{}|;':,./<>?"}, 33 | {"中文密码", "密码123测试"}, 34 | {"长密码", strings.Repeat("abcdefgh", 8)}, // 64个字符 35 | } 36 | 37 | for _, tc := range testCases { 38 | t.Run(tc.name, func(t *testing.T) { 39 | // 使用默认参数生成哈希 40 | hash, err := HashPassword(tc.password) 41 | require.NoError(t, err, "哈希生成应该成功") 42 | assert.NotEmpty(t, hash, "哈希不应为空") 43 | 44 | // 验证哈希格式 45 | assert.True(t, strings.HasPrefix(hash, "$argon2id$"), "应以$argon2id$开头") 46 | parts := strings.Split(hash, "$") 47 | require.Equal(t, 6, len(parts), "哈希应该有6个部分") 48 | 49 | // 验证正确密码 50 | err = CheckPassword(tc.password, hash) 51 | assert.NoError(t, err, "正确密码验证应通过") 52 | 53 | // 验证错误密码 54 | err = CheckPassword(tc.password+"wrong", hash) 55 | assert.Error(t, err, "错误密码验证应失败") 56 | assert.Equal(t, ErrMismatchedHashAndPassword, err, "错误应为密码不匹配") 57 | }) 58 | } 59 | } 60 | 61 | // 测试GenerateFromPassword和CompareHashAndPassword函数 62 | func TestGenerateAndCompare(t *testing.T) { 63 | password := []byte("secure_password_for_testing") 64 
| params := &Params{ 65 | Memory: 32 * 1024, 66 | Iterations: 2, 67 | Parallelism: 2, 68 | SaltLength: 16, 69 | KeyLength: 32, 70 | } 71 | 72 | // 生成哈希 73 | hash, err := GenerateFromPassword(password, params) 74 | require.NoError(t, err, "哈希生成应该成功") 75 | 76 | // 验证哈希包含参数 77 | assert.Contains(t, hash, "m=32768", "哈希应包含正确的内存参数") 78 | assert.Contains(t, hash, "t=2", "哈希应包含正确的迭代次数") 79 | assert.Contains(t, hash, "p=2", "哈希应包含正确的并行度") 80 | 81 | // 验证密码 82 | err = CompareHashAndPassword(hash, password) 83 | assert.NoError(t, err, "验证应该成功") 84 | 85 | // 验证错误密码 86 | wrongPassword := []byte("wrong_password") 87 | err = CompareHashAndPassword(hash, wrongPassword) 88 | assert.Error(t, err, "错误密码验证应失败") 89 | } 90 | 91 | // 测试NeedsRehash函数 92 | func TestNeedsRehash(t *testing.T) { 93 | password := []byte("test_password") 94 | 95 | // 使用参数集1生成哈希 96 | params1 := &Params{ 97 | Memory: 32 * 1024, 98 | Iterations: 2, 99 | Parallelism: 2, 100 | SaltLength: 16, 101 | KeyLength: 32, 102 | } 103 | hash, err := GenerateFromPassword(password, params1) 104 | require.NoError(t, err, "哈希生成应该成功") 105 | 106 | // 使用相同参数检查是否需要重新哈希 107 | needsRehash, err := NeedsRehash(hash, params1) 108 | require.NoError(t, err, "检查是否需要重新哈希应该成功") 109 | assert.False(t, needsRehash, "相同参数不应需要重新哈希") 110 | 111 | // 使用不同参数检查是否需要重新哈希 112 | params2 := &Params{ 113 | Memory: 64 * 1024, // 内存增加 114 | Iterations: 3, // 迭代次数增加 115 | Parallelism: 2, 116 | SaltLength: 16, 117 | KeyLength: 32, 118 | } 119 | needsRehash, err = NeedsRehash(hash, params2) 120 | require.NoError(t, err, "检查是否需要重新哈希应该成功") 121 | assert.True(t, needsRehash, "不同参数应需要重新哈希") 122 | } 123 | 124 | // 测试无效哈希处理 125 | func TestInvalidHash(t *testing.T) { 126 | // 测试格式错误的哈希 127 | invalidHash := "invalid-hash-format" 128 | err := CheckPassword("password", invalidHash) 129 | assert.Error(t, err, "无效哈希应返回错误") 130 | 131 | // 测试格式正确但内容错误的哈希 132 | badFormatHash := "$argon2id$v=19$m=65536,t=3,p=4$invalidSalt$invalidHash" 133 | err = CheckPassword("password", 
badFormatHash) 134 | assert.Error(t, err, "格式正确但内容错误的哈希应返回错误") 135 | 136 | // 测试不兼容版本 137 | incompatibleHash := "$argon2id$v=18$m=65536,t=3,p=4$c2FsdHNhbHRzYWx0c2FsdA$aGFzaGhhc2hoYXNoaGFzaGhhc2hoYXNoaGFzaGhhc2g" 138 | _, _, _, err = decodeHash(incompatibleHash) 139 | assert.Error(t, err, "不兼容版本应返回错误") 140 | } 141 | 142 | // 测试密码强度检查 143 | func TestCheckPasswordStrength(t *testing.T) { 144 | testCases := []struct { 145 | password string 146 | expected PasswordStrength 147 | }{ 148 | {"123", VeryWeak}, // 短且只有数字 149 | {"password", Weak}, // 只有小写字母 150 | {"Password1", Medium}, // 包含大小写字母和数字 151 | {"P@ssword1", Medium}, // 包含大小写字母、数字和特殊字符,总分为6,为Medium 152 | {"P@ssw0rd!ComplexABC", VeryStrong}, // 长且复杂 153 | } 154 | 155 | for _, tc := range testCases { 156 | t.Run(tc.password, func(t *testing.T) { 157 | strength := CheckPasswordStrength(tc.password) 158 | assert.Equal(t, tc.expected, strength, "密码强度检查结果应匹配预期") 159 | }) 160 | } 161 | } 162 | 163 | // 测试密码强度描述 164 | func TestGetPasswordStrengthDescription(t *testing.T) { 165 | testCases := []struct { 166 | strength PasswordStrength 167 | expected string 168 | }{ 169 | {VeryWeak, "非常弱:密码太简单,容易被破解"}, 170 | {Weak, "弱:密码强度不足,建议增加复杂度"}, 171 | {Medium, "中等:密码强度一般,可以使用但建议增强"}, 172 | {Strong, "强:密码强度良好"}, 173 | {VeryStrong, "非常强:密码强度极佳"}, 174 | {PasswordStrength(99), "未知强度"}, // 测试未知强度 175 | } 176 | 177 | for _, tc := range testCases { 178 | t.Run(tc.expected, func(t *testing.T) { 179 | desc := GetPasswordStrengthDescription(tc.strength) 180 | assert.Equal(t, tc.expected, desc, "密码强度描述应匹配预期") 181 | }) 182 | } 183 | } 184 | 185 | // 测试辅助函数 186 | func TestHelperFunctions(t *testing.T) { 187 | // 测试containsDigit 188 | assert.True(t, containsDigit("abc123"), "应检测到数字") 189 | assert.False(t, containsDigit("abcdef"), "不应检测到数字") 190 | 191 | // 测试containsLower 192 | assert.True(t, containsLower("ABCdef"), "应检测到小写字母") 193 | assert.False(t, containsLower("ABC123"), "不应检测到小写字母") 194 | 195 | // 测试containsUpper 196 | assert.True(t, 
containsUpper("abcDEF"), "应检测到大写字母") 197 | assert.False(t, containsUpper("abc123"), "不应检测到大写字母") 198 | 199 | // 测试containsSpecial 200 | assert.True(t, containsSpecial("abc!@#"), "应检测到特殊字符") 201 | assert.False(t, containsSpecial("abc123"), "不应检测到特殊字符") 202 | } 203 | -------------------------------------------------------------------------------- /apidoc/doc.go: -------------------------------------------------------------------------------- 1 | // Package apidoc 提供了API文档生成工具 2 | package apidoc 3 | 4 | import ( 5 | "net/http" 6 | "reflect" 7 | "sort" 8 | "strings" 9 | 10 | "github.com/dormoron/mist" 11 | ) 12 | 13 | // RouteInfo 表示路由信息 14 | type RouteInfo struct { 15 | // 请求方法 16 | Method string `json:"method"` 17 | // 路由路径 18 | Path string `json:"path"` 19 | // 路由说明 20 | Description string `json:"description,omitempty"` 21 | // 请求参数 22 | Params []ParamInfo `json:"params,omitempty"` 23 | // 请求体格式 24 | RequestBody interface{} `json:"request_body,omitempty"` 25 | // 响应格式 26 | ResponseBody interface{} `json:"response_body,omitempty"` 27 | } 28 | 29 | // ParamInfo 表示参数信息 30 | type ParamInfo struct { 31 | // 参数名 32 | Name string `json:"name"` 33 | // 参数位置 (path/query/header) 34 | In string `json:"in"` 35 | // 参数类型 36 | Type string `json:"type"` 37 | // 是否必需 38 | Required bool `json:"required"` 39 | // 参数描述 40 | Description string `json:"description,omitempty"` 41 | } 42 | 43 | // GroupInfo 表示分组信息 44 | type GroupInfo struct { 45 | // 分组名称 46 | Name string `json:"name"` 47 | // 分组前缀 48 | Prefix string `json:"prefix"` 49 | // 分组说明 50 | Description string `json:"description,omitempty"` 51 | // 分组下的路由 52 | Routes []RouteInfo `json:"routes"` 53 | } 54 | 55 | // APIDoc 表示API文档 56 | type APIDoc struct { 57 | // API标题 58 | Title string `json:"title"` 59 | // API版本 60 | Version string `json:"version"` 61 | // API描述 62 | Description string `json:"description,omitempty"` 63 | // API分组 64 | Groups []GroupInfo `json:"groups,omitempty"` 65 | // API路由(不属于任何分组) 66 | Routes []RouteInfo 
`json:"routes,omitempty"` 67 | 68 | // 内部使用,收集路由信息 69 | routes []RouteInfo 70 | } 71 | 72 | // New 创建一个新的APIDoc实例 73 | func New(title, version, description string) *APIDoc { 74 | return &APIDoc{ 75 | Title: title, 76 | Version: version, 77 | Description: description, 78 | routes: make([]RouteInfo, 0), 79 | } 80 | } 81 | 82 | // AddRoute 添加路由信息 83 | func (doc *APIDoc) AddRoute(method, path, description string, params []ParamInfo, requestBody, responseBody interface{}) { 84 | doc.routes = append(doc.routes, RouteInfo{ 85 | Method: method, 86 | Path: path, 87 | Description: description, 88 | Params: params, 89 | RequestBody: requestBody, 90 | ResponseBody: responseBody, 91 | }) 92 | } 93 | 94 | // AddRouteInfo 添加路由信息 95 | func (doc *APIDoc) AddRouteInfo(info RouteInfo) { 96 | doc.routes = append(doc.routes, info) 97 | } 98 | 99 | // Organize 整理API文档,根据路径前缀分组 100 | func (doc *APIDoc) Organize() { 101 | // 按路径排序 102 | sort.Slice(doc.routes, func(i, j int) bool { 103 | return doc.routes[i].Path < doc.routes[j].Path 104 | }) 105 | 106 | // 提取所有一级路径 107 | prefixMap := make(map[string][]RouteInfo) 108 | noGroupRoutes := make([]RouteInfo, 0) 109 | 110 | for _, route := range doc.routes { 111 | parts := strings.Split(strings.Trim(route.Path, "/"), "/") 112 | if len(parts) > 0 && parts[0] != "" { 113 | prefix := "/" + parts[0] 114 | prefixMap[prefix] = append(prefixMap[prefix], route) 115 | } else { 116 | noGroupRoutes = append(noGroupRoutes, route) 117 | } 118 | } 119 | 120 | // 创建分组 121 | groups := make([]GroupInfo, 0) 122 | for prefix, routes := range prefixMap { 123 | groups = append(groups, GroupInfo{ 124 | Name: strings.Title(strings.TrimPrefix(prefix, "/")), 125 | Prefix: prefix, 126 | Routes: routes, 127 | }) 128 | } 129 | 130 | // 按前缀排序 131 | sort.Slice(groups, func(i, j int) bool { 132 | return groups[i].Prefix < groups[j].Prefix 133 | }) 134 | 135 | doc.Groups = groups 136 | doc.Routes = noGroupRoutes 137 | } 138 | 139 | // GenerateHandler 生成API文档处理函数 140 | func 
(doc *APIDoc) GenerateHandler() mist.HandleFunc { 141 | doc.Organize() 142 | 143 | return func(ctx *mist.Context) { 144 | ctx.RespondWithJSON(http.StatusOK, doc) 145 | } 146 | } 147 | 148 | // ExtractRoutes 从HTTP服务器中提取路由信息(实验性功能) 149 | // 注意:这个功能依赖于反射和内部结构,可能不稳定 150 | func (doc *APIDoc) ExtractRoutes(server *mist.HTTPServer) { 151 | // 获取server中的router字段 152 | serverValue := reflect.ValueOf(server).Elem() 153 | routerValue := serverValue.FieldByName("router") 154 | 155 | if !routerValue.IsValid() { 156 | return 157 | } 158 | 159 | // 获取router中的trees字段 160 | treesValue := routerValue.FieldByName("trees") 161 | if !treesValue.IsValid() { 162 | return 163 | } 164 | 165 | // 遍历HTTP方法和对应的路由树 166 | treesIter := treesValue.MapRange() 167 | for treesIter.Next() { 168 | method := treesIter.Key().String() 169 | root := treesIter.Value().Interface() 170 | 171 | // 提取路由信息 172 | doc.extractRoutesFromNode(method, "", root) 173 | } 174 | } 175 | 176 | // extractRoutesFromNode 从节点中提取路由信息(实验性功能) 177 | func (doc *APIDoc) extractRoutesFromNode(method, parentPath string, node interface{}) { 178 | // 获取节点的字段 179 | nodeValue := reflect.ValueOf(node).Elem() 180 | 181 | // 获取path字段 182 | pathValue := nodeValue.FieldByName("path") 183 | if !pathValue.IsValid() { 184 | return 185 | } 186 | 187 | path := pathValue.String() 188 | if path == "" { 189 | return 190 | } 191 | 192 | // 组合完整路径 193 | fullPath := parentPath 194 | if path != "/" || fullPath == "" { 195 | fullPath = parentPath + path 196 | } 197 | 198 | // 检查是否有处理函数 199 | handlerValue := nodeValue.FieldByName("handler") 200 | if handlerValue.IsValid() && !handlerValue.IsNil() { 201 | // 这是一个有效的路由 202 | doc.AddRoute(method, fullPath, "", nil, nil, nil) 203 | } 204 | 205 | // 遍历子节点 206 | childrenValue := nodeValue.FieldByName("children") 207 | if !childrenValue.IsValid() { 208 | return 209 | } 210 | 211 | childrenIter := childrenValue.MapRange() 212 | for childrenIter.Next() { 213 | child := childrenIter.Value().Interface() 214 | 
doc.extractRoutesFromNode(method, fullPath, child) 215 | } 216 | 217 | // 检查参数子节点 218 | paramChildValue := nodeValue.FieldByName("paramChild") 219 | if paramChildValue.IsValid() && !paramChildValue.IsNil() { 220 | child := paramChildValue.Interface() 221 | doc.extractRoutesFromNode(method, fullPath, child) 222 | } 223 | 224 | // 检查通配符子节点 225 | wildcardChildValue := nodeValue.FieldByName("wildcardChild") 226 | if wildcardChildValue.IsValid() && !wildcardChildValue.IsNil() { 227 | child := wildcardChildValue.Interface() 228 | doc.extractRoutesFromNode(method, fullPath, child) 229 | } 230 | 231 | // 检查正则子节点 232 | regChildValue := nodeValue.FieldByName("regChild") 233 | if regChildValue.IsValid() && !regChildValue.IsNil() { 234 | child := regChildValue.Interface() 235 | doc.extractRoutesFromNode(method, fullPath, child) 236 | } 237 | } 238 | -------------------------------------------------------------------------------- /router_cache_test.go: -------------------------------------------------------------------------------- 1 | package mist 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "strconv" 7 | "testing" 8 | ) 9 | 10 | // 测试路由缓存功能 11 | func TestRouter_Cache(t *testing.T) { 12 | r := initRouter() 13 | r.EnableCache(100) 14 | 15 | // 注册一些路由 16 | r.registerRoute(http.MethodGet, "/", func(ctx *Context) {}) 17 | r.registerRoute(http.MethodGet, "/users", func(ctx *Context) {}) 18 | r.registerRoute(http.MethodGet, "/users/:id", func(ctx *Context) {}) 19 | r.registerRoute(http.MethodGet, "/products", func(ctx *Context) {}) 20 | r.registerRoute(http.MethodGet, "/products/:id", func(ctx *Context) {}) 21 | r.registerRoute(http.MethodGet, "/products/:id/details", func(ctx *Context) {}) 22 | r.registerRoute(http.MethodGet, "/products/:id/reviews", func(ctx *Context) {}) 23 | 24 | paths := []string{ 25 | "/", 26 | "/users", 27 | "/users/123", 28 | "/products", 29 | "/products/456", 30 | "/products/456/details", 31 | "/products/456/reviews", 32 | } 33 | 34 | // 首次访问,缓存未命中 35 | 
for _, path := range paths { 36 | _, found := r.findRoute(http.MethodGet, path) 37 | if !found { 38 | t.Errorf("路由未找到: %s", path) 39 | } 40 | } 41 | 42 | // 检查缓存统计 43 | hits, misses, size := r.CacheStats() 44 | if hits != 0 { 45 | t.Errorf("首次访问后,缓存命中数应为0,实际为 %d", hits) 46 | } 47 | if misses < 6 { // 允许一些灵活性,因为自适应缓存会优化存储 48 | t.Errorf("首次访问后,缓存未命中数应至少为6,实际为 %d", misses) 49 | } 50 | if size < 6 { // 允许一些灵活性,因为自适应缓存会优化存储 51 | t.Errorf("首次访问后,缓存大小应至少为6,实际为 %d", size) 52 | } 53 | 54 | // 二次访问,应该命中缓存 55 | for _, path := range paths { 56 | _, found := r.findRoute(http.MethodGet, path) 57 | if !found { 58 | t.Errorf("路由未找到: %s", path) 59 | } 60 | } 61 | 62 | // 再次检查缓存统计 63 | hits, misses, size = r.CacheStats() 64 | if hits < 6 { // 允许一些灵活性,因为自适应缓存会优化命中 65 | t.Errorf("二次访问后,缓存命中数应至少为6,实际为 %d", hits) 66 | } 67 | // 我们允许未命中数增加 68 | if size < 6 { // 允许一些灵活性,因为自适应缓存会优化存储 69 | t.Errorf("二次访问后,缓存大小应至少为6,实际为 %d", size) 70 | } 71 | 72 | // 禁用缓存 73 | r.DisableCache() 74 | hits, misses, size = r.CacheStats() 75 | if size != 0 { 76 | t.Errorf("禁用缓存后,缓存大小应为0,实际为 %d", size) 77 | } 78 | 79 | // 禁用缓存后,缓存不应该被使用 80 | for _, path := range paths { 81 | _, found := r.findRoute(http.MethodGet, path) 82 | if !found { 83 | t.Errorf("路由未找到: %s", path) 84 | } 85 | } 86 | 87 | // 验证缓存没有增加 88 | hits, misses, size = r.CacheStats() 89 | if size != 0 { 90 | t.Errorf("禁用缓存后访问,缓存大小应为0,实际为 %d", size) 91 | } 92 | } 93 | 94 | // 测试静态路由快速匹配 95 | func TestRouter_FastMatch(t *testing.T) { 96 | r := initRouter() 97 | 98 | // 注册静态路由 99 | r.registerRoute(http.MethodGet, "/static/path", func(ctx *Context) {}) 100 | r.registerRoute(http.MethodGet, "/static/another/path", func(ctx *Context) {}) 101 | 102 | // 测试快速匹配 103 | _, found := r.findRoute(http.MethodGet, "/static/path") 104 | if !found { 105 | t.Error("静态路由未找到: /static/path") 106 | } 107 | 108 | _, found = r.findRoute(http.MethodGet, "/static/another/path") 109 | if !found { 110 | t.Error("静态路由未找到: /static/another/path") 111 | } 112 | 113 | // 测试不存在的路由 114 | 
_, found = r.findRoute(http.MethodGet, "/static/not/exist") 115 | if found { 116 | t.Error("不应该找到不存在的路由: /static/not/exist") 117 | } 118 | } 119 | 120 | // 测试路由缓存超过最大大小的情况 121 | func TestRouter_CacheMaxSize(t *testing.T) { 122 | r := initRouter() 123 | r.EnableCache(5) // 设置较小的缓存大小,便于测试 124 | 125 | // 注册一些路由 126 | r.registerRoute(http.MethodGet, "/", func(ctx *Context) {}) 127 | 128 | // 第一次访问一些路径,超过缓存大小 129 | for i := 0; i < 10; i++ { 130 | path := fmt.Sprintf("/path/%d", i) 131 | _, found := r.findRoute(http.MethodGet, path) 132 | if found { 133 | t.Errorf("不应该找到不存在的路由: %s", path) 134 | } 135 | } 136 | 137 | // 检查缓存大小不超过最大限制 138 | _, _, size := r.CacheStats() 139 | if size > 5 { 140 | t.Errorf("缓存大小超过最大值,应为5,实际为 %d", size) 141 | } 142 | } 143 | 144 | // 缓存性能基准测试 145 | func BenchmarkRouter_WithoutCache(b *testing.B) { 146 | r := initRouter() 147 | r.DisableCache() // 确保缓存被禁用 148 | 149 | // 注册一些路由 150 | for i := 0; i < 100; i++ { 151 | path := fmt.Sprintf("/path/%d", i) 152 | r.registerRoute(http.MethodGet, path, func(ctx *Context) {}) 153 | } 154 | 155 | // 参数路由 156 | for i := 0; i < 10; i++ { 157 | path := fmt.Sprintf("/users/%d/:id", i) 158 | r.registerRoute(http.MethodGet, path, func(ctx *Context) {}) 159 | } 160 | 161 | b.ResetTimer() 162 | 163 | // 随机访问路由,模拟真实请求 164 | for i := 0; i < b.N; i++ { 165 | idx := i % 100 166 | path := fmt.Sprintf("/path/%d", idx) 167 | r.findRoute(http.MethodGet, path) 168 | } 169 | } 170 | 171 | func BenchmarkRouter_WithCache(b *testing.B) { 172 | r := initRouter() 173 | r.EnableCache(1000) // 确保缓存被启用 174 | 175 | // 注册一些路由 176 | for i := 0; i < 100; i++ { 177 | path := fmt.Sprintf("/path/%d", i) 178 | r.registerRoute(http.MethodGet, path, func(ctx *Context) {}) 179 | } 180 | 181 | // 参数路由 182 | for i := 0; i < 10; i++ { 183 | path := fmt.Sprintf("/users/%d/:id", i) 184 | r.registerRoute(http.MethodGet, path, func(ctx *Context) {}) 185 | } 186 | 187 | b.ResetTimer() 188 | 189 | // 随机访问路由,模拟真实请求 190 | for i := 0; i < b.N; i++ { 191 
| idx := i % 100 192 | path := fmt.Sprintf("/path/%d", idx) 193 | r.findRoute(http.MethodGet, path) 194 | } 195 | } 196 | 197 | // 测试路由缓存与参数路由 198 | func BenchmarkRouter_WithParams(b *testing.B) { 199 | testCases := []struct { 200 | name string 201 | cacheEnabled bool 202 | }{ 203 | {"WithCache", true}, 204 | {"WithoutCache", false}, 205 | } 206 | 207 | for _, tc := range testCases { 208 | b.Run(tc.name, func(b *testing.B) { 209 | r := initRouter() 210 | if tc.cacheEnabled { 211 | r.EnableCache(1000) 212 | } else { 213 | r.DisableCache() 214 | } 215 | 216 | // 注册带参数的路由 217 | r.registerRoute(http.MethodGet, "/users/:id", func(ctx *Context) {}) 218 | r.registerRoute(http.MethodGet, "/products/:category/:id", func(ctx *Context) {}) 219 | r.registerRoute(http.MethodGet, "/articles/:year/:month/:slug", func(ctx *Context) {}) 220 | 221 | b.ResetTimer() 222 | 223 | // 测试参数路由匹配 224 | for i := 0; i < b.N; i++ { 225 | userID := strconv.Itoa(i % 100) 226 | r.findRoute(http.MethodGet, "/users/"+userID) 227 | 228 | category := "cat" + strconv.Itoa(i%5) 229 | productID := strconv.Itoa(i % 20) 230 | r.findRoute(http.MethodGet, "/products/"+category+"/"+productID) 231 | 232 | year := strconv.Itoa(2020 + (i % 5)) 233 | month := strconv.Itoa(1 + (i % 12)) 234 | slug := "article-" + strconv.Itoa(i%10) 235 | r.findRoute(http.MethodGet, "/articles/"+year+"/"+month+"/"+slug) 236 | } 237 | }) 238 | } 239 | } 240 | -------------------------------------------------------------------------------- /examples/http3_test/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net/http" 7 | "os" 8 | "os/signal" 9 | "path/filepath" 10 | "syscall" 11 | "time" 12 | 13 | "github.com/dormoron/mist" 14 | ) 15 | 16 | // 日志中间件 17 | func loggerMiddleware(next mist.HandleFunc) mist.HandleFunc { 18 | return func(ctx *mist.Context) { 19 | start := time.Now() 20 | method := ctx.Request.Method 21 | path := 
ctx.Request.URL.Path 22 | 23 | // 执行下一个处理函数 24 | next(ctx) 25 | 26 | // 计算处理时间 27 | duration := time.Since(start) 28 | // 获取状态码 29 | statusCode := ctx.RespStatusCode 30 | 31 | // 打印日志 32 | fmt.Printf("[%s] %s %s %d %v\n", 33 | time.Now().Format("2006-01-02 15:04:05"), 34 | method, path, statusCode, duration) 35 | } 36 | } 37 | 38 | // 鉴权中间件 39 | func authMiddleware(next mist.HandleFunc) mist.HandleFunc { 40 | return func(ctx *mist.Context) { 41 | authHeader := ctx.Request.Header.Get("Authorization") 42 | if authHeader != "Bearer test-token" { 43 | ctx.RespondWithJSON(http.StatusUnauthorized, map[string]interface{}{ 44 | "error": "未授权访问", 45 | }) 46 | return 47 | } 48 | next(ctx) 49 | } 50 | } 51 | 52 | func main() { 53 | // 初始化服务器 54 | server := mist.InitHTTPServer() 55 | 56 | // 注册全局中间件 57 | server.Use(loggerMiddleware) 58 | 59 | // 基本路由测试 60 | server.GET("/", func(ctx *mist.Context) { 61 | // 直接设置响应状态和数据 62 | ctx.RespStatusCode = http.StatusOK 63 | ctx.Header("Content-Type", "text/html") 64 | ctx.RespData = []byte("

Mist HTTP/3 服务器

路由测试首页

") 65 | }) 66 | 67 | // JSON API测试 68 | server.GET("/api/test", func(ctx *mist.Context) { 69 | ctx.RespondWithJSON(http.StatusOK, map[string]interface{}{ 70 | "message": "HTTP/3服务器测试成功", 71 | "time": time.Now().Format(time.RFC3339), 72 | }) 73 | }) 74 | 75 | // RESTful API路由测试 76 | apiGroup := "/api/v1" 77 | 78 | // 获取用户列表 79 | server.GET(apiGroup+"/users", func(ctx *mist.Context) { 80 | users := []map[string]interface{}{ 81 | {"id": 1, "name": "用户1", "email": "user1@example.com"}, 82 | {"id": 2, "name": "用户2", "email": "user2@example.com"}, 83 | } 84 | ctx.RespondWithJSON(http.StatusOK, users) 85 | }) 86 | 87 | // 获取特定用户 88 | server.GET(apiGroup+"/users/{id}", func(ctx *mist.Context) { 89 | id := ctx.PathParams["id"] 90 | ctx.RespondWithJSON(http.StatusOK, map[string]interface{}{ 91 | "id": id, 92 | "name": "用户" + id, 93 | "email": "user" + id + "@example.com", 94 | }) 95 | }) 96 | 97 | // 创建用户(需要认证) 98 | server.POST(apiGroup+"/users", func(ctx *mist.Context) { 99 | var userData map[string]interface{} 100 | if err := ctx.BindJSON(&userData); err != nil { 101 | ctx.RespondWithJSON(http.StatusBadRequest, map[string]interface{}{ 102 | "error": "无效的JSON数据", 103 | }) 104 | return 105 | } 106 | 107 | // 模拟创建用户 108 | userData["id"] = 100 109 | userData["created_at"] = time.Now().Format(time.RFC3339) 110 | 111 | ctx.RespondWithJSON(http.StatusCreated, userData) 112 | }, authMiddleware) 113 | 114 | // 正则路由测试 115 | server.GET("/posts/{year:\\d{4}}/{month:\\d{2}}", func(ctx *mist.Context) { 116 | year := ctx.PathParams["year"] 117 | month := ctx.PathParams["month"] 118 | 119 | ctx.RespondWithJSON(http.StatusOK, map[string]interface{}{ 120 | "year": year, 121 | "month": month, 122 | "posts": []string{"文章1", "文章2", "文章3"}, 123 | }) 124 | }) 125 | 126 | // 通配符路由 127 | server.GET("/files/{*filepath}", func(ctx *mist.Context) { 128 | filepath := ctx.PathParams["filepath"] 129 | ctx.RespondWithJSON(http.StatusOK, map[string]interface{}{ 130 | "filepath": filepath, 131 | "message": 
"请求的文件路径: " + filepath, 132 | }) 133 | }) 134 | 135 | // 生成临时的自签名证书(仅用于测试) 136 | certFile, keyFile, err := generateSelfSignedCert() 137 | if err != nil { 138 | log.Fatalf("生成证书失败: %v", err) 139 | } 140 | defer os.Remove(certFile) 141 | defer os.Remove(keyFile) 142 | 143 | // 输出测试路由信息 144 | fmt.Println("==== 路由测试指南 ====") 145 | fmt.Println("基本路由: https://localhost:8443/") 146 | fmt.Println("JSON API: https://localhost:8443/api/test") 147 | fmt.Println("用户列表: https://localhost:8443/api/v1/users") 148 | fmt.Println("单个用户: https://localhost:8443/api/v1/users/1") 149 | fmt.Println("创建用户(需认证): POST https://localhost:8443/api/v1/users") 150 | fmt.Println(" 需要添加请求头: Authorization: Bearer test-token") 151 | fmt.Println(" 请求体: {\"name\":\"新用户\",\"email\":\"new@example.com\"}") 152 | fmt.Println("正则路由: https://localhost:8443/posts/2023/04") 153 | fmt.Println("通配符路由: https://localhost:8443/files/path/to/some/file.txt") 154 | fmt.Println("======================") 155 | 156 | // 启动HTTP/3服务器(非阻塞) 157 | go func() { 158 | fmt.Printf("启动HTTP/3服务器于 https://localhost:8443\n") 159 | fmt.Printf("注意:这是自签名证书,浏览器会显示安全警告\n") 160 | 161 | if err := server.StartHTTP3(":8443", certFile, keyFile); err != nil { 162 | log.Printf("HTTP/3服务器错误: %v", err) 163 | } 164 | }() 165 | 166 | // 等待中断信号 167 | quit := make(chan os.Signal, 1) 168 | signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) 169 | <-quit 170 | 171 | fmt.Println("正在关闭服务器...") 172 | } 173 | 174 | // 生成自签名证书(仅用于测试) 175 | func generateSelfSignedCert() (string, string, error) { 176 | // 这里仅为测试目的,使用预先生成的自签名证书 177 | tempDir, err := os.MkdirTemp("", "http3-test") 178 | if err != nil { 179 | return "", "", err 180 | } 181 | 182 | certFile := filepath.Join(tempDir, "cert.pem") 183 | keyFile := filepath.Join(tempDir, "key.pem") 184 | 185 | // 证书数据(仅用于测试,请勿在生产环境中使用) 186 | certData := `-----BEGIN CERTIFICATE----- 187 | MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw 188 | 
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow 189 | EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d 190 | 7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B 191 | 5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr 192 | BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1 193 | NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l 194 | Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc 195 | 6MF9+Yw1Yy0t 196 | -----END CERTIFICATE-----` 197 | 198 | keyData := `-----BEGIN EC PRIVATE KEY----- 199 | MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49 200 | AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q 201 | EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== 202 | -----END EC PRIVATE KEY-----` 203 | 204 | if err := os.WriteFile(certFile, []byte(certData), 0644); err != nil { 205 | return "", "", err 206 | } 207 | 208 | if err := os.WriteFile(keyFile, []byte(keyData), 0644); err != nil { 209 | return "", "", err 210 | } 211 | 212 | return certFile, keyFile, nil 213 | } 214 | --------------------------------------------------------------------------------