├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── ac └── ac.go ├── bitmap ├── bitmap.go ├── bitmap1.go └── bitmap_test.go ├── cache ├── cache.go └── cache_test.go ├── closes └── closes.go ├── code_err ├── code_err.go └── context.go ├── csvx └── csvx.go ├── db ├── emysql │ ├── client.go │ ├── default_client.go │ ├── emysql.go │ ├── emysql_config │ │ └── emysql_config.go │ ├── emysql_test.go │ ├── gorm.go │ ├── interceptor │ │ ├── explain.go │ │ ├── interceptor.go │ │ ├── log.go │ │ ├── logger.go │ │ ├── metrics.go │ │ ├── start_time.go │ │ └── trace.go │ ├── internal │ │ └── dsn │ │ │ ├── mysql.go │ │ │ └── mysql_test.go │ ├── manager │ │ ├── dsn.go │ │ └── manager.go │ ├── map.go │ ├── mysql_map.go │ ├── sqlx.go │ └── table.go └── eredis │ ├── eredis.go │ ├── eredis_config │ └── eredis_config.go │ ├── eredis_test.go │ ├── interceptor │ ├── interceptor.go │ ├── log.go │ ├── metric.go │ ├── start_time.go │ └── trace.go │ └── redis_map.go ├── ecodes ├── convert.go └── convert_test.go ├── econfig ├── const.go ├── econfig.go ├── eviper │ ├── viper.go │ └── viper_test.go └── nacos │ ├── local.go │ ├── nacos.go │ └── viper.go ├── ectx └── ectx.go ├── eerror ├── atomicerror.go ├── atomicerror_test.go ├── batcherror.go ├── batcherror_test.go └── bizerror.go ├── ekafka ├── config.go ├── consumergroup.go ├── container.go ├── message.go ├── metrics.go └── producer.go ├── elog ├── elog.go ├── ezap │ ├── config.go │ ├── console.go │ ├── ezap.go │ ├── ezap_test.go │ └── file.go ├── fields.go ├── level.go ├── log.go ├── log_test.go ├── logx │ └── logx.go └── sls │ └── sls.go ├── emath └── emath.go ├── env └── env.go ├── eprometheus ├── config.go └── eprometheus.go ├── etrace ├── baggage.go ├── baggage_test.go ├── ejaeger │ └── ejaeger.go ├── fskywalking │ └── skywalking.go ├── grpc.go ├── http.go └── trace.go ├── filex ├── filex.go ├── filex_test.go └── test.csv ├── fmetric ├── README.md ├── counter.go ├── gauge.go ├── grpc.go ├── grpc_test.go ├── histogram.go ├── 
metric.go └── summary.go ├── go.mod ├── go.sum ├── gpool └── gpool.go ├── grpc ├── grpc_client │ ├── README.md │ ├── grpc_client.go │ ├── grpc_client_config │ │ └── grpc_client_config.go │ ├── grpc_client_test.go │ └── interceptor │ │ ├── grpc_header_carrier.go │ │ ├── log.go │ │ ├── metric.go │ │ └── timeout.go ├── grpc_server │ ├── grpc_server.go │ ├── grpc_server_config │ │ └── grpc_server_config.go │ ├── grpc_server_test.go │ ├── interceptor.go │ └── interceptor │ │ ├── grpc_header_carrier.go │ │ ├── log.go │ │ ├── metric.go │ │ ├── metrics.go │ │ ├── recovery.go │ │ ├── timeout.go │ │ └── timeout_test.go └── proto │ ├── response │ ├── response.pb.go │ └── response.proto │ └── user │ ├── user.pb.go │ ├── user.proto │ └── user_grpc.pb.go ├── http ├── http_client │ ├── http_client.go │ ├── http_client_config │ │ └── config.go │ ├── http_client_test.go │ ├── interceptor │ │ ├── header_carrier.go │ │ ├── log.go │ │ ├── metric.go │ │ ├── start_time.go │ │ └── trace.go │ ├── ip.go │ ├── log.go │ ├── query.go │ └── trace.go └── http_server │ ├── http_server.go │ ├── http_server_config │ └── http_server_config.go │ ├── http_server_test.go │ ├── interceptor │ ├── auth.go │ ├── header_carrier.go │ ├── log.go │ ├── metric.go │ ├── start_time.go │ ├── timeout.go │ ├── token.go │ └── trace.go │ └── service │ ├── response.go │ └── service_context.go ├── invite_code ├── invite_code.go └── invite_code_test.go ├── lang ├── lang.go └── lang_test.go ├── list └── list.go ├── logx └── logx.go ├── mapreduce ├── mapreduce.go └── mapreduce_test.go ├── mapx ├── mapx.go └── order_map.go ├── monitor ├── dingtalk │ └── dingtalk.go ├── larkbot │ └── larkbot.go └── monitor.go ├── print └── print.go ├── retry ├── retry.go └── retry_test.go ├── run └── daemon.go ├── set ├── set.go ├── stringset.go └── stringset_test.go ├── sortedset ├── border.go ├── skiplist.go ├── skiplist_test.go ├── sortedset.go └── sortedset_test.go ├── sortx └── sort.go ├── stringx └── stringx.go ├── syncx ├── 
atomicbool.go ├── atomicbool_test.go ├── atomicduration.go ├── atomicduration_test.go ├── atomicfloat64.go ├── atomicfloat64_test.go ├── concurrentdoublemap.go ├── concurrentmap.go ├── cond.go ├── cond_test.go ├── donechan.go ├── donechan_test.go ├── limit.go ├── limit_test.go ├── lockedcalls.go ├── lockedcalls_test.go ├── managedresource.go ├── managedresource_test.go ├── once.go ├── once_test.go ├── onceguard.go ├── onceguard_test.go ├── pool.go ├── pool_test.go ├── refresource.go ├── refresource_test.go ├── resourcemanager.go ├── resourcemanager_test.go ├── sharedcalls.go ├── sharedcalls_test.go ├── spinlock.go ├── spinlock_test.go ├── timeoutlimit.go └── timeoutlimit_test.go ├── system ├── shutdown+polyfill.go ├── shutdown.go └── signals.go ├── threading ├── rescue.go ├── rescue_test.go ├── routinegroup.go ├── routinegroup_test.go ├── routines.go ├── routines_test.go ├── taskrunner.go ├── taskrunner_test.go ├── timeout.go └── timeout_test.go ├── timex ├── relativetime.go ├── relativetime_test.go ├── ticker.go ├── ticker_test.go ├── timex.go └── utils.go ├── timingwheel └── timingwheel.go ├── transport ├── grpc_transport.go ├── transport.go └── transport_test.go ├── utils └── run │ └── daemon.go └── webrtc_sfu └── webrtc_sfu.go /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weblazy/easy/3e8d8967888f582a7ebc394b4495ed84889ed7d9/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | dist 3 | .idea 4 | vendor 5 | *.log* 6 | tmp 7 | *.yaml 8 | .DS_Store 9 | main -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 功能列表 2 | 3 | - grpc_server: 4 | - 日志插件 5 | - metric插件 6 | - recovery插件 7 | - timeout插件 8 | - trace插件 9 
| - grpc_client: 10 | - 日志插件 11 | - metric插件 12 | - timeout插件 13 | - trace插件 14 | - http_server: github.com/gin-gonic/gin 15 | - 日志插件 16 | - metric插件 17 | - recovery插件 18 | - timeout插件 19 | - trace插件 20 | - token验签插件 21 | - 解密插件 22 | - header头透传插件 23 | - http_client: github.com/go-resty/resty/v2 24 | - 日志插件 25 | - metric插件 26 | - timeout插件 27 | - trace插件 28 | - db: gorm.io/gorm 29 | - 日志插件 30 | - metric插件 31 | - timeout插件 32 | - trace插件 33 | - 脚手架: orm 34 | - redis: github.com/go-redis/redis/v8 35 | - 日志插件 36 | - metric插件 37 | - timeout插件 38 | - trace插件 39 | - log: go.uber.org/zap 40 | - config: github.com/spf13/viper 41 | - 监控面板: prometheus+grafana 42 | - 告警: lark+钉钉 43 | - 脚手架: github.com/weblazy/easy-cli 44 | - cli: github.com/urfave/cli/v2 45 | - cron: github.com/robfig/cron 46 | - trace: go.opentelemetry.io/otel/trace 47 | 48 | # easy 49 | 50 | 如果大家觉得好用,右上角帮忙点个star吧。(^_^) 51 | > 欢迎感兴趣的小伙伴一同开发,收集日常好用的golang工具包。 52 | # 联系我们 53 | - 技术支持/合作/咨询请联系作者QQ: 2276282419 54 | - 作者邮箱: 2276282419@qq.com 55 | - 即时通讯技术交流QQ群: 33280853 56 | ##### 单元测试 57 | ``` 58 | go test -coverpkg=./... -coverprofile=coverage.data -timeout=5s ./... 
59 | go tool cover -html=coverage.data -o coverage.html 60 | ```` 61 | [![Go Report Card](https://goreportcard.com/badge/github.com/sunmi-OS/gocore)](https://goreportcard.com/report/github.com/sunmi-OS/gocore/v2.0.9) 62 | 63 | -------------------------------------------------------------------------------- /ac/ac.go: -------------------------------------------------------------------------------- 1 | package ac 2 | 3 | type AcNode struct { 4 | fail *AcNode 5 | next map[byte]*AcNode 6 | length []int 7 | } 8 | 9 | func newAcNode() *AcNode { 10 | return &AcNode{ 11 | fail: nil, 12 | next: map[byte]*AcNode{}, 13 | } 14 | } 15 | 16 | type AcAutoMachine struct { 17 | root *AcNode 18 | size int64 19 | } 20 | 21 | func NewAcAutoMachine() *AcAutoMachine { 22 | return &AcAutoMachine{ 23 | root: newAcNode(), 24 | } 25 | } 26 | 27 | // 构造前缀树 28 | func (ac *AcAutoMachine) AddPattern(pattern string) { 29 | chars := []byte(pattern) 30 | iter := ac.root 31 | var length int 32 | for _, c := range chars { 33 | if _, ok := iter.next[c]; !ok { 34 | iter.next[c] = newAcNode() 35 | } 36 | iter = iter.next[c] 37 | length++ 38 | } 39 | iter.length = append(iter.length, length) 40 | ac.size++ 41 | } 42 | 43 | // 构建fail指针 44 | func (ac *AcAutoMachine) Build() { 45 | queue := []*AcNode{} 46 | queue = append(queue, ac.root) 47 | // 规则1:对节点层序遍历 48 | for len(queue) != 0 { 49 | parent := queue[0] 50 | queue = queue[1:] 51 | // 遍历第一个元素的子节点 52 | for char, child := range parent.next { 53 | if parent == ac.root { 54 | // fail指针规则2:第二层节点(根节点的孩子)fail指针指向根节点 55 | child.fail = ac.root 56 | } else { 57 | // 规则3:查找父节点的fail指针是否有与自己相同的子节点 58 | failAcNode := parent.fail 59 | for failAcNode != nil { 60 | if _, ok := failAcNode.next[char]; ok { 61 | child.fail = failAcNode.next[char] 62 | child.length = append(child.length, failAcNode.next[char].length...) 
63 | break 64 | } 65 | failAcNode = failAcNode.fail 66 | } 67 | // failAcNode == ac.root找不到匹配的fail指针,指向根节点 68 | if failAcNode == nil { 69 | child.fail = ac.root 70 | } 71 | } 72 | queue = append(queue, child) 73 | } 74 | } 75 | } 76 | 77 | // 匹配敏感词 78 | func (ac *AcAutoMachine) Query(content string) (results []string) { 79 | chars := []byte(content) 80 | iter := ac.root 81 | 82 | respMap := make(map[string]bool) 83 | data := []string{} 84 | for i, c := range chars { 85 | _, ok := iter.next[c] 86 | for !ok && iter != ac.root { 87 | // 匹配失败从fail指针开始尝试子串 88 | iter = iter.fail 89 | _, ok = iter.next[c] 90 | } 91 | 92 | iter = iter.next[c] 93 | if iter == nil { 94 | iter = ac.root 95 | } 96 | parent := iter 97 | if parent != ac.root && len(parent.length) > 0 { 98 | // 匹配成功 99 | for _, length := range parent.length { 100 | respMap[string([]byte(content)[i+1-length:i+1])] = true 101 | data = append(data, string([]byte(content)[i+1-length:i+1])) 102 | } 103 | } 104 | } 105 | for word := range respMap { 106 | results = append(results, word) 107 | } 108 | return 109 | } 110 | -------------------------------------------------------------------------------- /bitmap/bitmap1.go: -------------------------------------------------------------------------------- 1 | package bitmap 2 | 3 | type BitMap []byte 4 | 5 | func New() *BitMap { 6 | b := BitMap(make([]byte, 0)) 7 | return &b 8 | } 9 | 10 | func toByteSize(bitSize int64) int64 { 11 | if bitSize%8 == 0 { 12 | return bitSize / 8 13 | } 14 | return bitSize/8 + 1 15 | } 16 | 17 | func (b *BitMap) grow(bitSize int64) { 18 | byteSize := toByteSize(bitSize) 19 | gap := byteSize - int64(len(*b)) 20 | if gap <= 0 { 21 | return 22 | } 23 | *b = append(*b, make([]byte, gap)...) 
24 | } 25 | 26 | func (b *BitMap) BitSize() int { 27 | return len(*b) * 8 28 | } 29 | 30 | func FromBytes(bytes []byte) *BitMap { 31 | bm := BitMap(bytes) 32 | return &bm 33 | } 34 | 35 | func (b *BitMap) ToBytes() []byte { 36 | return *b 37 | } 38 | 39 | func (b *BitMap) SetBit(offset int64, val byte) { 40 | byteIndex := offset / 8 41 | bitOffset := offset % 8 42 | mask := byte(1 << bitOffset) 43 | b.grow(offset + 1) 44 | if val > 0 { 45 | // set bit 46 | (*b)[byteIndex] |= mask 47 | } else { 48 | // clear bit 49 | (*b)[byteIndex] &^= mask 50 | } 51 | } 52 | 53 | func (b *BitMap) GetBit(offset int64) byte { 54 | byteIndex := offset / 8 55 | bitOffset := offset % 8 56 | if byteIndex >= int64(len(*b)) { 57 | return 0 58 | } 59 | return ((*b)[byteIndex] >> bitOffset) & 0x01 60 | } 61 | 62 | type Callback func(offset int64, val byte) bool 63 | 64 | func (b *BitMap) ForEachBit(begin int64, end int64, cb Callback) { 65 | offset := begin 66 | byteIndex := offset / 8 67 | bitOffset := offset % 8 68 | for byteIndex < int64(len(*b)) { 69 | b := (*b)[byteIndex] 70 | for bitOffset < 8 { 71 | bit := byte(b >> bitOffset & 0x01) 72 | if !cb(offset, bit) { 73 | return 74 | } 75 | bitOffset++ 76 | offset++ 77 | if offset >= end && end != 0 { 78 | break 79 | } 80 | } 81 | byteIndex++ 82 | bitOffset = 0 83 | if end > 0 && offset >= end { 84 | break 85 | } 86 | } 87 | } 88 | 89 | func (b *BitMap) ForEachByte(begin int, end int, cb Callback) { 90 | if end == 0 { 91 | end = len(*b) 92 | } else if end > len(*b) { 93 | end = len(*b) 94 | } 95 | for i := begin; i < end; i++ { 96 | if !cb(int64(i), (*b)[i]) { 97 | return 98 | } 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /closes/closes.go: -------------------------------------------------------------------------------- 1 | package closes 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "os/signal" 7 | "sort" 8 | "syscall" 9 | ) 10 | 11 | type ( 12 | ModuleClose struct { 13 | Name string 14 | 
Priority int 15 | Func func() 16 | } 17 | closes []ModuleClose 18 | ) 19 | 20 | var closeHandler closes 21 | 22 | const ( 23 | MQPriority = 100 24 | GormPriority = 500 25 | RedisPriority = 500 26 | AliLogPriority = 2000 27 | ) 28 | 29 | func (c closes) Len() int { return len(c) } 30 | func (c closes) Less(i, j int) bool { return c[i].Priority < c[j].Priority } 31 | func (c closes) Swap(i, j int) { c[i], c[j] = c[j], c[i] } 32 | 33 | // AddShutdown 增加程序结束时需要关闭的服务 34 | func AddShutdown(c ...ModuleClose) { 35 | closeHandler = append(closeHandler, c...) 36 | } 37 | 38 | // SignalClose 监听信号阻塞关闭 39 | func SignalClose() { 40 | c := make(chan os.Signal, 1) 41 | signal.Notify(c, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGTSTP) 42 | sig := <-c 43 | fmt.Printf("Got %s signal. Aborting...\n", sig) 44 | Close() 45 | } 46 | 47 | // Close 按照优先级调用关闭方法 48 | func Close() { 49 | sort.Sort(closeHandler) 50 | if len(closeHandler) > 0 { 51 | for _, f := range closeHandler { 52 | fmt.Printf("Close %s ...\n", f.Name) 53 | f.Func() 54 | } 55 | } 56 | os.Exit(0) 57 | } 58 | -------------------------------------------------------------------------------- /code_err/code_err.go: -------------------------------------------------------------------------------- 1 | package code_err 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/weblazy/easy/elog" 7 | "go.uber.org/zap" 8 | ) 9 | 10 | var ( 11 | SystemErr = NewCodeErr(-1, "系统错误") 12 | ParamsErr = NewCodeErr(100001, "参数错误") 13 | TokenErr = NewCodeErr(100002, "无效Token") 14 | EncryptErr = NewCodeErr(100003, "加密失败") 15 | DecryptErr = NewCodeErr(100004, "解密失败") 16 | SignErr = NewCodeErr(100005, "签名失败") 17 | ) 18 | 19 | type CodeErr struct { 20 | Code int64 21 | Msg string 22 | } 23 | 24 | func (err *CodeErr) Error() string { 25 | return err.Msg 26 | } 27 | 28 | func NewCodeErr(code int64, msg string) *CodeErr { 29 | return &CodeErr{ 30 | Code: code, 31 | Msg: msg, 32 | } 33 | } 34 | 35 | func GetCodeErr(err error) *CodeErr { 36 | 
if err == nil { 37 | return nil 38 | } 39 | if v, ok := err.(*CodeErr); ok { 40 | return v 41 | } 42 | return SystemErr 43 | } 44 | 45 | // 打印msg和err 46 | func LogErr(ctx context.Context, codeErr *CodeErr, msg string, err error) *CodeErr { 47 | if v, ok := err.(*CodeErr); ok { 48 | return v 49 | } 50 | elog.ErrorCtx(elog.AddCtxSkip(ctx, 2), msg, elog.FieldError(err)) 51 | return codeErr 52 | } 53 | 54 | // 打印field 55 | func LogField(ctx context.Context, codeErr *CodeErr, msg string, fields ...zap.Field) *CodeErr { 56 | elog.ErrorCtx(elog.AddCtxSkip(ctx, 2), msg, fields...) 57 | return codeErr 58 | } 59 | -------------------------------------------------------------------------------- /code_err/context.go: -------------------------------------------------------------------------------- 1 | package code_err 2 | 3 | import ( 4 | "context" 5 | 6 | "go.uber.org/zap" 7 | ) 8 | 9 | type Log struct { 10 | Ctx context.Context 11 | } 12 | 13 | func NewLog(ctx context.Context) *Log { 14 | return &Log{Ctx: ctx} 15 | } 16 | 17 | // 打印log 18 | func (c *Log) LogErr(codeErr *CodeErr, msg string, err error) *CodeErr { 19 | return LogErr(c.Ctx, codeErr, msg, err) 20 | } 21 | 22 | // 打印log 23 | func (c *Log) LogField(codeErr *CodeErr, msg string, fields ...zap.Field) *CodeErr { 24 | return LogField(c.Ctx, codeErr, msg, fields...) 
25 | } 26 | -------------------------------------------------------------------------------- /csvx/csvx.go: -------------------------------------------------------------------------------- 1 | package csvx 2 | 3 | import ( 4 | "bufio" 5 | "encoding/csv" 6 | "os" 7 | "strings" 8 | ) 9 | 10 | type CSV struct { 11 | path string 12 | wfile *os.File 13 | rfile *os.File 14 | w *csv.Writer 15 | r *bufio.Reader 16 | rowSeparator string 17 | } 18 | 19 | // NewCSV return a CSV 20 | func NewCSV(path string, rowSeparator rune, lineSeparator string) (*CSV, error) { 21 | wfile, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666) 22 | if err != nil { 23 | return nil, err 24 | } 25 | rfile, err := os.OpenFile(path, os.O_RDONLY, 0666) 26 | if err != nil { 27 | return nil, err 28 | } 29 | w := csv.NewWriter(wfile) 30 | w.Comma = rowSeparator 31 | if lineSeparator == `\r\n` { 32 | w.UseCRLF = true 33 | } 34 | r := bufio.NewReader(rfile) 35 | return &CSV{ 36 | path: path, 37 | wfile: wfile, 38 | rfile: rfile, 39 | w: w, 40 | r: r, 41 | rowSeparator: string(rowSeparator), 42 | }, nil 43 | } 44 | 45 | // Write truncate and write one line 46 | func (this *CSV) Write(str []string) error { 47 | err := this.wfile.Truncate(0) 48 | if err != nil { 49 | return err 50 | } 51 | err = this.w.Write(str) 52 | if err != nil { 53 | return err 54 | } 55 | this.w.Flush() 56 | return nil 57 | } 58 | 59 | // Append append one line 60 | func (this *CSV) Append(str []string) error { 61 | err := this.w.Write(str) 62 | if err != nil { 63 | return err 64 | } 65 | this.w.Flush() 66 | return nil 67 | } 68 | 69 | // Truncate 70 | func (this *CSV) Truncate() error { 71 | return this.wfile.Truncate(0) 72 | } 73 | 74 | // Reset 75 | func (this *CSV) Reset() (int64, error) { 76 | return this.rfile.Seek(0, 0) 77 | } 78 | 79 | // ReadLine read one line 80 | func (this *CSV) ReadLine() ([]string, error) { 81 | line, _, err := this.r.ReadLine() //以'\n'为结束符读入一行 82 | return strings.Split(string(line), 
this.rowSeparator), err 83 | } 84 | 85 | // Close close file 86 | func (this *CSV) Close() error { 87 | err := this.wfile.Close() 88 | if err != nil { 89 | return err 90 | } 91 | return this.rfile.Close() 92 | } 93 | -------------------------------------------------------------------------------- /db/emysql/client.go: -------------------------------------------------------------------------------- 1 | package emysql 2 | 3 | import "gorm.io/gorm" 4 | 5 | // var MysqlClient Client 6 | 7 | type GetMysqlDB func(key string) *gorm.DB 8 | 9 | var GetDB GetMysqlDB 10 | 11 | type Client interface { 12 | GetDB(key string) *gorm.DB 13 | } 14 | -------------------------------------------------------------------------------- /db/emysql/default_client.go: -------------------------------------------------------------------------------- 1 | package emysql 2 | 3 | func init() { 4 | // MysqlClient = &DefaultClient{} 5 | // GetDB = func(key string) *gorm.DB { 6 | // db := GetORM(key) 7 | // return db 8 | // } 9 | } 10 | 11 | type DefaultClient struct { 12 | Client 13 | } 14 | 15 | // func (*DefaultClient) GetDB(key string) *gorm.DB { 16 | // db := GetORM(key) 17 | // return db 18 | // } 19 | -------------------------------------------------------------------------------- /db/emysql/emysql_config/emysql_config.go: -------------------------------------------------------------------------------- 1 | package emysql_config 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/weblazy/easy/db/emysql/manager" 7 | "gorm.io/gorm" 8 | ) 9 | 10 | const ( 11 | PkgName = "emysql" 12 | ) 13 | 14 | // Config options 15 | type Config struct { 16 | Name string 17 | Dialect string // 选择数据库种类,默认mysql,postgres 18 | DSN string // DSN地址: username:password@tcp(127.0.0.1:3306)/mysql?charset=utf8mb4&collation=utf8mb4_general_ci&parseTime=True&loc=Local&timeout=1s&readTimeout=3s&writeTimeout=3s 19 | Debug bool // 是否开启调试,默认不开启,开启后,可以看到每次请求,配置名、地址、耗时、请求数据、响应数据, 标准输出, 生产请使用 access log 20 | RawDebug bool // 
是否开启原生调试开关,默认不开启 21 | MaxIdleConns int // 最大空闲连接数,默认10 22 | MaxOpenConns int // 最大活动连接数,默认100 23 | ConnMaxLifetime time.Duration // 连接的最大存活时间,默认300s 24 | OnFail string // 创建连接的错误级别,=panic时,如果创建失败,立即panic,默认连接不上panic 25 | SlowLogThreshold time.Duration // 慢日志阈值,默认500ms 26 | EnableMetricInterceptor bool // 是否开启监控,默认开启 27 | EnableTraceInterceptor bool // 是否开启链路追踪,默认开启 28 | EnableDetailSQL bool // 是否打印包含参数的完整sql语句,select * from aid = ?; 29 | EnableAccessInterceptor bool // 是否开启,记录请求数据 30 | EnableAccessInterceptorReq bool // 是否开启记录请求参数 31 | EnableAccessInterceptorRes bool // 是否开启记录响应参数 32 | EnableRecordNotFoundLog bool // ErrRecordNotFound 错误时是否打印 warn 日志, 默认开启 33 | // Deprecated: not affect anything 34 | EnableSkyWalking bool // 是否额外开启 skywalking, 默认关闭 35 | 36 | Interceptors []Interceptor 37 | DsnCfg *manager.DSN 38 | } 39 | 40 | // Interceptor ... 41 | type Interceptor func(string, *manager.DSN, string, *Config) func(next Handler) Handler 42 | 43 | // Handler ... 44 | type Handler func(*gorm.DB) 45 | 46 | // DefaultConfig 返回默认配置 47 | func DefaultConfig() *Config { 48 | return &Config{ 49 | DSN: "", 50 | Dialect: "mysql", 51 | Debug: false, 52 | MaxIdleConns: 10, 53 | MaxOpenConns: 100, 54 | ConnMaxLifetime: time.Second * 300, 55 | OnFail: "panic", 56 | SlowLogThreshold: time.Millisecond * 500, 57 | EnableMetricInterceptor: false, 58 | EnableTraceInterceptor: true, 59 | EnableRecordNotFoundLog: true, 60 | // EnableAccessInterceptor: true, 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /db/emysql/emysql_test.go: -------------------------------------------------------------------------------- 1 | package emysql 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/smartystreets/goconvey/convey" 7 | ) 8 | 9 | type User struct { 10 | Id int64 11 | Name string 12 | } 13 | 14 | func (*User) TableName() string { 15 | return "user" 16 | } 17 | 18 | func TestNewMysqlClient(t *testing.T) { 19 | convey.Convey("TestNewMysqlClient", 
t, func() { 20 | // cfg := emysql_config.DefaultConfig() 21 | // cfg.DSN = "root:123456@tcp(localhost:13306)/test?charset=utf8mb4&collation=utf8mb4_general_ci&parseTime=True&loc=Local&timeout=1s&readTimeout=3s&writeTimeout=3s" 22 | // client, err := NewMysqlClient(cfg) 23 | // convey.So(err, convey.ShouldBeNil) 24 | // resp := User{} 25 | // err = client.WithContext(context.Background()).Where("id != ?", 1).Find(&resp).Error 26 | // convey.So(err, convey.ShouldBeNil) 27 | // fmt.Printf("resp%#v\n", resp) 28 | // convey.So(resp, convey.ShouldNotBeNil) 29 | // convey.So(err, convey.ShouldBeNil) 30 | }) 31 | } 32 | -------------------------------------------------------------------------------- /db/emysql/interceptor/interceptor.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "gorm.io/gorm" 5 | ) 6 | 7 | const ( 8 | TypeGorm = "gorm" 9 | ) 10 | 11 | var ( 12 | // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct 13 | ErrRecordNotFound = gorm.ErrRecordNotFound 14 | // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` 15 | ErrInvalidTransaction = gorm.ErrInvalidTransaction 16 | ) 17 | 18 | // // 确保在生产不要开 debug 19 | // func DebugInterceptor(compName string, dsn *manager.DSN, op string, options *emysql_config.Config) func(emysql_config.Handler) emysql_config.Handler { 20 | // return func(next emysql_config.Handler) emysql_config.Handler { 21 | // return func(db *gorm.DB) { 22 | // beg := time.Now() 23 | // next(db) 24 | // duration := time.Since(beg) 25 | // if db.Error != nil { 26 | // elog.ErrorCtx(db.Statement.Context, "fgorm.response", elog.MakeReqResError(1, compName, dsn.Addr+"/"+dsn.DBName, duration, logSQL(db.Statement.SQL.String(), db.Statement.Vars, true), db.Error.Error())) 27 | // } else { 28 | // elog.InfoCtx(db.Statement.Context, "fgorm.response", elog.MakeReqResInfo(1, 
compName, dsn.Addr+"/"+dsn.DBName, duration, logSQL(db.Statement.SQL.String(), db.Statement.Vars, true), fmt.Sprintf("%v", db.Statement.Dest))) 29 | // } 30 | // } 31 | // } 32 | // } 33 | -------------------------------------------------------------------------------- /db/emysql/interceptor/logger.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "reflect" 7 | "regexp" 8 | "time" 9 | "unicode" 10 | ) 11 | 12 | var ( 13 | sqlRegexp = regexp.MustCompile(`\?`) 14 | numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`) 15 | ) 16 | 17 | func logSQL(sql string, args []interface{}, containArgs bool) string { 18 | if containArgs { 19 | return bindSQL(sql, args) 20 | } 21 | return sql 22 | } 23 | 24 | // from gorm.LogFormatter 25 | func bindSQL(oriSql string, args []interface{}) (sql string) { 26 | formattedValues := make([]string, 0) 27 | for _, value := range args { 28 | indirectValue := reflect.Indirect(reflect.ValueOf(value)) 29 | if indirectValue.IsValid() { 30 | value = indirectValue.Interface() 31 | if t, ok := value.(time.Time); ok { //nolint 32 | formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05"))) 33 | } else if b, ok := value.([]byte); ok { 34 | if str := string(b); isPrintable(str) { 35 | formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) 36 | } else { 37 | formattedValues = append(formattedValues, "''") 38 | } 39 | } else if r, ok := value.(driver.Valuer); ok { 40 | if value, err := r.Value(); err == nil && value != nil { 41 | formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) 42 | } else { 43 | formattedValues = append(formattedValues, "NULL") 44 | } 45 | } else { 46 | switch value.(type) { 47 | case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: 48 | formattedValues = append(formattedValues, fmt.Sprintf("%v", value)) 49 
| default: 50 | formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) 51 | } 52 | } 53 | } else { 54 | formattedValues = append(formattedValues, "NULL") 55 | } 56 | } 57 | 58 | // differentiate between $n placeholders or else treat like ? 59 | if numericPlaceHolderRegexp.MatchString(oriSql) { 60 | for index, value := range formattedValues { 61 | placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1) 62 | sql = regexp.MustCompile(placeholder).ReplaceAllString(oriSql, value+"$1") 63 | } 64 | } else { 65 | formattedValuesLength := len(formattedValues) 66 | for index, value := range sqlRegexp.Split(oriSql, -1) { 67 | sql += value 68 | if index < formattedValuesLength { 69 | sql += formattedValues[index] 70 | } 71 | } 72 | } 73 | return 74 | } 75 | 76 | func isPrintable(s string) bool { 77 | for _, r := range s { 78 | if !unicode.IsPrint(r) { 79 | return false 80 | } 81 | } 82 | return true 83 | } 84 | -------------------------------------------------------------------------------- /db/emysql/interceptor/start_time.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/weblazy/easy/elog" 8 | "go.uber.org/zap" 9 | "gorm.io/gorm" 10 | ) 11 | 12 | type StartTimePlugin struct{} 13 | 14 | type ctxStartTimeKey struct{} 15 | 16 | func NewStartTimePlugin() *StartTimePlugin { 17 | return &StartTimePlugin{} 18 | } 19 | 20 | func (e *StartTimePlugin) Name() string { 21 | return "start_time" 22 | } 23 | 24 | func (e *StartTimePlugin) Initialize(db *gorm.DB) error { 25 | var lastErr error 26 | beforeErrMsg := "SetStartTimeErr" 27 | beforeName := "SetStartTime" 28 | beforeFn := SetStartTime 29 | err := db.Callback().Query().Before("gorm:query").Register(beforeName, beforeFn) 30 | if err != nil { 31 | lastErr = err 32 | elog.ErrorCtx(db.Statement.Context, beforeErrMsg, zap.Error(err)) 33 | } 34 | err = 
db.Callback().Create().Before("gorm:create").Register(beforeName, beforeFn) 35 | if err != nil { 36 | lastErr = err 37 | elog.ErrorCtx(db.Statement.Context, beforeErrMsg, zap.Error(err)) 38 | } 39 | err = db.Callback().Update().Before("gorm:update").Register(beforeName, beforeFn) 40 | if err != nil { 41 | lastErr = err 42 | elog.ErrorCtx(db.Statement.Context, beforeErrMsg, zap.Error(err)) 43 | } 44 | err = db.Callback().Delete().Before("gorm:delete").Register(beforeName, beforeFn) 45 | if err != nil { 46 | lastErr = err 47 | elog.ErrorCtx(db.Statement.Context, beforeErrMsg, zap.Error(err)) 48 | } 49 | err = db.Callback().Row().Before("gorm:row").Register(beforeName, beforeFn) 50 | if err != nil { 51 | lastErr = err 52 | elog.ErrorCtx(db.Statement.Context, beforeErrMsg, zap.Error(err)) 53 | } 54 | err = db.Callback().Raw().Before("gorm:raw").Register(beforeName, beforeFn) 55 | if err != nil { 56 | lastErr = err 57 | elog.ErrorCtx(db.Statement.Context, beforeErrMsg, zap.Error(err)) 58 | } 59 | return lastErr 60 | } 61 | 62 | func SetStartTime(db *gorm.DB) { 63 | startTime := time.Now() 64 | db.Statement.Context = context.WithValue(db.Statement.Context, ctxStartTimeKey{}, startTime) 65 | } 66 | 67 | func GetStartTime(db *gorm.DB) time.Time { 68 | return db.Statement.Context.Value(ctxStartTimeKey{}).(time.Time) 69 | } 70 | 71 | func GetDuration(ctx context.Context) time.Duration { 72 | startTime, _ := ctx.Value(ctxStartTimeKey{}).(time.Time) 73 | return time.Since(startTime) 74 | } 75 | 76 | func GetDurationMilliseconds(ctx context.Context) float64 { 77 | startTime, _ := ctx.Value(ctxStartTimeKey{}).(time.Time) 78 | return float64(time.Since(startTime).Microseconds()) / 1000 79 | 80 | } 81 | -------------------------------------------------------------------------------- /db/emysql/internal/dsn/mysql_test.go: -------------------------------------------------------------------------------- 1 | package dsn 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | 
"github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestMysqlDSNParser_ParseDSN(t *testing.T) { 11 | dsn := "user:password@tcp(localhost:9910)/dbname?charset=utf8&parseTime=True" 12 | parser := &MysqlDSNParser{} 13 | cfg, err := parser.ParseDSN(dsn) 14 | assert.NoError(t, err) 15 | assert.Equal(t, "user", cfg.User) 16 | assert.Equal(t, "password", cfg.Password) 17 | assert.Equal(t, "dbname", cfg.DBName) 18 | assert.Equal(t, "localhost:9910", cfg.Addr) 19 | assert.Equal(t, "tcp", cfg.Net) 20 | assert.Equal(t, "utf8", cfg.Params["charset"]) 21 | assert.Equal(t, "True", cfg.Params["parseTime"]) 22 | fmt.Println(cfg) 23 | } 24 | -------------------------------------------------------------------------------- /db/emysql/manager/dsn.go: -------------------------------------------------------------------------------- 1 | package manager 2 | 3 | import ( 4 | "gorm.io/gorm" 5 | ) 6 | 7 | type DSN struct { 8 | User string // Username 9 | Password string // Password (requires User) 10 | Net string // Network type 11 | Addr string // Network address (requires Net) 12 | DBName string // Database name 13 | Params map[string]string // Connection parameters 14 | } 15 | 16 | type DSNParser interface { 17 | GetDialector(dsn string) gorm.Dialector 18 | ParseDSN(dsn string) (cfg *DSN, err error) 19 | Scheme() string 20 | } 21 | -------------------------------------------------------------------------------- /db/emysql/manager/manager.go: -------------------------------------------------------------------------------- 1 | package manager 2 | 3 | var ( 4 | // m is a map from scheme to dsn builder. 5 | m = make(map[string]DSNParser) 6 | ) 7 | 8 | func Register(b DSNParser) { 9 | m[b.Scheme()] = b 10 | } 11 | 12 | // Get returns the dsn builder registered with the given scheme. 13 | // 14 | // If no builder is register with the scheme, nil will be returned. 
15 | func Get(scheme string) DSNParser { 16 | if b, ok := m[scheme]; ok { 17 | return b 18 | } 19 | return nil 20 | } 21 | -------------------------------------------------------------------------------- /db/emysql/map.go: -------------------------------------------------------------------------------- 1 | package emysql 2 | 3 | import ( 4 | "database/sql/driver" 5 | "encoding/json" 6 | ) 7 | 8 | type Map map[string]interface{} 9 | 10 | // Value 实现方法 11 | func (m Map) Value() (driver.Value, error) { 12 | return json.Marshal(m) 13 | } 14 | 15 | // Scan 实现方法 16 | func (m *Map) Scan(input interface{}) error { 17 | return json.Unmarshal(input.([]byte), m) 18 | } 19 | -------------------------------------------------------------------------------- /db/emysql/mysql_map.go: -------------------------------------------------------------------------------- 1 | package emysql 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/weblazy/easy/db/emysql/emysql_config" 7 | "github.com/weblazy/easy/econfig" 8 | ) 9 | 10 | var MysqlMap sync.Map 11 | 12 | // GetMysql return a MysqlClient 13 | func GetMysql(dbName string) *MysqlClient { 14 | if v, ok := MysqlMap.Load(dbName); ok { 15 | return v.(*MysqlClient) 16 | } 17 | conf := emysql_config.DefaultConfig() 18 | econfig.GlobalViper.UnmarshalKey(dbName, conf) 19 | mysqlClient, err := NewMysqlClient(conf) 20 | if err != nil { 21 | return nil 22 | } 23 | MysqlMap.Store(dbName, mysqlClient) 24 | return mysqlClient 25 | } 26 | -------------------------------------------------------------------------------- /db/emysql/sqlx.go: -------------------------------------------------------------------------------- 1 | package emysql 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strings" 7 | 8 | "gorm.io/gorm" 9 | ) 10 | 11 | var ( 12 | FieldsError = fmt.Errorf("fileds length is 0") 13 | ) 14 | 15 | // @desc 16 | // @auth liuguoqiang 2020-11-27 17 | // @param 18 | // @return 19 | func BulkInsert(db *gorm.DB, table string, fields []string, params 
[]map[string]interface{}) error { 20 | if len(params) == 0 { 21 | return nil 22 | } 23 | if len(fields) == 0 { 24 | return FieldsError 25 | } 26 | sql := "INSERT INTO `" + table + "` (`" + strings.Join(fields, "`,`") + "`) VALUES " 27 | args := make([]interface{}, 0) 28 | valueArr := make([]string, 0) 29 | varArr := make([]string, 0) 30 | for _, obj := range params { 31 | varArr = varArr[:0] 32 | varStr := "(" 33 | for _, value := range fields { 34 | if _, ok := obj[value]; !ok { 35 | return fmt.Errorf("%s:not found in fields", value) 36 | } 37 | varArr = append(varArr, "?") 38 | args = append(args, obj[value]) 39 | } 40 | varStr += strings.Join(varArr, ",") + ")" 41 | valueArr = append(valueArr, varStr) 42 | } 43 | sql += strings.Join(valueArr, ",") 44 | err := db.Exec(sql, args...).Error 45 | return err 46 | } 47 | 48 | // @desc 批量插入 49 | // @auth liuguoqiang 2020-11-27 50 | // @param 51 | // @return 52 | func BulkSave(db *gorm.DB, table string, fields []string, params []map[string]interface{}) error { 53 | if len(params) == 0 { 54 | return nil 55 | } 56 | if len(fields) == 0 { 57 | return FieldsError 58 | } 59 | sql := "INSERT INTO `" + table + "` (`" + strings.Join(fields, "`,`") + "`) VALUES " 60 | updateArr := make([]string, 0) 61 | args := make([]interface{}, 0) 62 | valueArr := make([]string, 0) 63 | varArr := make([]string, 0) 64 | for _, value := range fields { 65 | updateArr = append(updateArr, "`"+value+"`=VALUES(`"+value+"`)") 66 | } 67 | for _, obj := range params { 68 | varArr = varArr[:0] 69 | varStr := "(" 70 | for _, value := range fields { 71 | if _, ok := obj[value]; !ok { 72 | return fmt.Errorf("%s字段在map中不存在", value) 73 | } 74 | varArr = append(varArr, "?") 75 | args = append(args, obj[value]) 76 | } 77 | varStr += strings.Join(varArr, ",") + ")" 78 | valueArr = append(valueArr, varStr) 79 | } 80 | sql += strings.Join(valueArr, ",") 81 | sql += " ON DUPLICATE KEY UPDATE " + strings.Join(updateArr, ",") 82 | err := db.Exec(sql, args...).Error 83 
| return err 84 | } 85 | 86 | // @desc 87 | // @auth liuguoqiang 2020-04-08 88 | // @param 89 | // @return 90 | func Validate(data, model interface{}) bool { 91 | if _, ok := data.(map[string]interface{}); ok { 92 | return true 93 | } 94 | if reflect.TypeOf(data).Kind() == reflect.TypeOf(model).Kind() { 95 | return true 96 | } 97 | return false 98 | } 99 | -------------------------------------------------------------------------------- /db/emysql/table.go: -------------------------------------------------------------------------------- 1 | package emysql 2 | -------------------------------------------------------------------------------- /db/eredis/eredis_config/eredis_config.go: -------------------------------------------------------------------------------- 1 | package eredis_config 2 | 3 | import ( 4 | "strings" 5 | "time" 6 | 7 | "github.com/go-redis/redis/v8" 8 | ) 9 | 10 | const ( 11 | // ClusterMode using clusterClient 12 | ClusterMode string = "cluster" 13 | // SimpleMode using Client 14 | SimpleMode string = "simple" 15 | // FailoverMode using Failover sentinel client 16 | FailoverMode string = "failover" 17 | 18 | PkgName = "eredis" 19 | ) 20 | 21 | // Config for redis, contains RedisStubConfig, RedisClusterConfig and RedisSentinelConfig 22 | type Config struct { 23 | Name string // Name redis名称 24 | Addrs []string // Addrs Cluster,Failover实例配置地址 25 | Addr string // Addr Simple 实例配置地址 26 | Mode string // Mode Redis模式 cluster|simple|failover 27 | MasterName string // MasterName 哨兵主节点名称,sentinel模式下需要配置此项 28 | Password string // Password 密码 29 | DB int // DB,默认为0, 一般应用不推荐使用DB分片 30 | PoolSize int // PoolSize 集群内每个节点的最大连接池限制 默认每个CPU10个连接 31 | 32 | MaxRetries int // MaxRetries 网络相关的错误最大重试次数 默认8次 33 | MinIdleConns int // MinIdleConns 最小空闲连接数 34 | DialTimeout time.Duration // DialTimeout 拨超时时间 35 | ReadTimeout time.Duration // ReadTimeout 读超时 默认3s 36 | WriteTimeout time.Duration // WriteTimeout 读超时 默认3s 37 | IdleTimeout time.Duration // IdleTimeout 
连接最大空闲时间,默认60s, 超过该时间,连接会被主动关闭 38 | ReadOnly bool // ReadOnly 集群模式 在从属节点上启用读模式 39 | 40 | EnableMetricInterceptor bool // 是否开启监控,默认开启 41 | EnableTraceInterceptor bool // 是否开启链路,默认开启 42 | 43 | SlowLogThreshold time.Duration // 慢日志门限值,超过该门限值的请求,将被记录到慢日志中 44 | EnableLogAccess bool // 是否开启,成功时也记录请求日志 45 | EnableLogReq bool // 是否开启记录请求参数 46 | EnableLogRes bool // 是否开启记录响应参数 47 | Hooks []redis.Hook 48 | } 49 | 50 | // DefaultConfig default config ... 51 | func DefaultConfig() *Config { 52 | return &Config{ 53 | Mode: SimpleMode, 54 | DB: 0, 55 | PoolSize: 0, // will be handled by redis v8 56 | MaxRetries: 0, 57 | MinIdleConns: 20, 58 | DialTimeout: time.Second, 59 | ReadTimeout: time.Second, 60 | WriteTimeout: time.Second, 61 | IdleTimeout: time.Second * 60, 62 | ReadOnly: false, 63 | SlowLogThreshold: time.Millisecond * 250, 64 | EnableMetricInterceptor: true, 65 | EnableTraceInterceptor: true, 66 | EnableLogAccess: false, 67 | EnableLogReq: true, 68 | EnableLogRes: true, 69 | } 70 | } 71 | 72 | // AddrString 获取地址, 用于监控 73 | // 多个地址会用 , 连接 74 | func (c Config) AddrString() string { 75 | addr := c.Addr 76 | if len(c.Addrs) > 0 { 77 | addr = strings.Join(c.Addrs, ",") 78 | } 79 | return addr 80 | } 81 | -------------------------------------------------------------------------------- /db/eredis/eredis_test.go: -------------------------------------------------------------------------------- 1 | package eredis 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/smartystreets/goconvey/convey" 7 | ) 8 | 9 | func TestNewRedisClient(t *testing.T) { 10 | convey.Convey("TestNewRedisClient", t, func() { 11 | // cfg := eredis_config.DefaultConfig() 12 | // cfg.Addr = "127.0.0.1:16379" 13 | // cfg.Name = "user_redis" 14 | // client := NewRedisClient(cfg) 15 | // cmd := client.Get(context.Background(), "test") 16 | // resp, err := cmd.Result() 17 | // fmt.Printf("%#v\n", resp) 18 | // convey.So(resp, convey.ShouldNotBeNil) 19 | // convey.So(err, convey.ShouldBeNil) 20 | }) 21 | } 22 | 
-------------------------------------------------------------------------------- /db/eredis/interceptor/interceptor.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/go-redis/redis/v8" 7 | ) 8 | 9 | // https://stackoverflow.com/questions/40891345/fix-should-not-use-basic-type-string-as-key-in-context-withvalue-golint 10 | // https://blog.golang.org/context#TOC_3.2. 11 | // https://golang.org/pkg/context/#WithValue ,这边文章说明了用struct,可以避免分配 12 | 13 | type RedisHook struct { 14 | redis.Hook 15 | beforeProcess func(ctx context.Context, cmd redis.Cmder) (context.Context, error) 16 | afterProcess func(ctx context.Context, cmd redis.Cmder) error 17 | beforeProcessPipeline func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) 18 | afterProcessPipeline func(ctx context.Context, cmds []redis.Cmder) error 19 | } 20 | 21 | func (i *RedisHook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) { 22 | return i.beforeProcess(ctx, cmd) 23 | } 24 | 25 | func (i *RedisHook) AfterProcess(ctx context.Context, cmd redis.Cmder) error { 26 | return i.afterProcess(ctx, cmd) 27 | } 28 | 29 | func (i *RedisHook) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { 30 | return i.beforeProcessPipeline(ctx, cmds) 31 | } 32 | 33 | func (i *RedisHook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) error { 34 | return i.afterProcessPipeline(ctx, cmds) 35 | } 36 | 37 | func NewRedisHook() *RedisHook { 38 | return &RedisHook{ 39 | beforeProcess: func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { 40 | return ctx, nil 41 | }, 42 | afterProcess: func(ctx context.Context, cmd redis.Cmder) error { 43 | return nil 44 | }, 45 | beforeProcessPipeline: func(ctx context.Context, cmds []redis.Cmder) (context.Context, error) { 46 | return ctx, nil 47 | }, 48 | afterProcessPipeline: func(ctx 
context.Context, cmds []redis.Cmder) error { 49 | return nil 50 | }, 51 | } 52 | } 53 | 54 | func (i *RedisHook) SetBeforeProcess(p func(ctx context.Context, cmd redis.Cmder) (context.Context, error)) *RedisHook { 55 | i.beforeProcess = p 56 | return i 57 | } 58 | 59 | func (i *RedisHook) SetAfterProcess(p func(ctx context.Context, cmd redis.Cmder) error) *RedisHook { 60 | i.afterProcess = p 61 | return i 62 | } 63 | 64 | func (i *RedisHook) SetBeforeProcessPipeline(p func(ctx context.Context, cmds []redis.Cmder) (context.Context, error)) *RedisHook { //nolint 65 | i.beforeProcessPipeline = p 66 | return i 67 | } 68 | 69 | func (i *RedisHook) SetAfterProcessPipeline(p func(ctx context.Context, cmds []redis.Cmder) error) *RedisHook { //nolint 70 | i.afterProcessPipeline = p 71 | return i 72 | } 73 | -------------------------------------------------------------------------------- /db/eredis/interceptor/log.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "time" 8 | 9 | "github.com/go-redis/redis/v8" 10 | "github.com/weblazy/easy/db/eredis/eredis_config" 11 | "github.com/weblazy/easy/elog" 12 | "go.uber.org/zap" 13 | ) 14 | 15 | func LogHook(config *eredis_config.Config) redis.Hook { 16 | return NewRedisHook().SetAfterProcess( 17 | func(ctx context.Context, cmd redis.Cmder) error { 18 | var fields = make([]zap.Field, 0, 15) 19 | var err = cmd.Err() 20 | duration := GetDuration(ctx) 21 | fields = append(fields, elog.FieldName(config.Name), 22 | elog.FieldMethod(cmd.Name()), 23 | elog.FieldDuration(duration)) 24 | 25 | if config.EnableLogReq { 26 | fields = append(fields, elog.FieldReq(cmd.Args())) 27 | } 28 | if config.EnableLogRes && err == nil { 29 | fields = append(fields, elog.FieldResp(response(cmd))) 30 | } 31 | 32 | // 开启了链路,那么就记录链路id 33 | // if config.EnableTraceInterceptor && etrace.IsGlobalTracerRegistered() { 34 | // fields = append(fields, 
elog.FieldTrace(etrace.ExtractTraceID(ctx))) 35 | // } 36 | var isSlow bool 37 | if config.SlowLogThreshold > time.Duration(0) && duration > config.SlowLogThreshold { 38 | isSlow = true 39 | } 40 | fields = append(fields, elog.FieldSlow(isSlow)) 41 | if err != nil { 42 | fields = append(fields, elog.FieldError(err)) 43 | if errors.Is(err, redis.Nil) { 44 | elog.WarnCtx(ctx, eredis_config.PkgName, fields...) 45 | return err 46 | } 47 | elog.ErrorCtx(ctx, eredis_config.PkgName, fields...) 48 | return err 49 | } 50 | if isSlow { 51 | elog.WarnCtx(ctx, eredis_config.PkgName, fields...) 52 | return nil 53 | } 54 | if config.EnableLogAccess { 55 | elog.InfoCtx(ctx, eredis_config.PkgName, fields...) 56 | } 57 | return nil 58 | }, 59 | ) 60 | } 61 | 62 | func response(cmd redis.Cmder) string { 63 | switch t := cmd.(type) { 64 | case *redis.Cmd: 65 | return fmt.Sprintf("%v", t.Val()) 66 | case *redis.StringCmd: 67 | return t.Val() 68 | case *redis.StatusCmd: 69 | return t.Val() 70 | case *redis.IntCmd: 71 | return fmt.Sprintf("%v", t.Val()) 72 | case *redis.DurationCmd: 73 | return t.Val().String() 74 | case *redis.BoolCmd: 75 | return fmt.Sprintf("%v", t.Val()) 76 | case *redis.CommandsInfoCmd: 77 | return fmt.Sprintf("%v", t.Val()) 78 | case *redis.StringSliceCmd: 79 | return fmt.Sprintf("%v", t.Val()) 80 | default: 81 | return "" 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /db/eredis/interceptor/metric.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | 7 | "github.com/go-redis/redis/v8" 8 | "github.com/prometheus/client_golang/prometheus" 9 | "github.com/weblazy/easy/db/eredis/eredis_config" 10 | ) 11 | 12 | var ( 13 | 14 | // ClientHandleCounter ... 
15 | RedisHandleCounter = prometheus.NewCounterVec( 16 | prometheus.CounterOpts{ 17 | Namespace: "", 18 | Name: "redis_handle_total", 19 | }, []string{"name", "method", "addr", "result"}) 20 | 21 | // ClientHandleHistogram ... 22 | RedisHandleHistogram = prometheus.NewHistogramVec(prometheus.HistogramOpts{ 23 | Namespace: "", 24 | Name: "redis_handle_seconds", 25 | }, []string{"name", "method", "addr"}) 26 | ) 27 | 28 | func init() { 29 | prometheus.MustRegister(RedisHandleCounter) 30 | prometheus.MustRegister(RedisHandleHistogram) 31 | } 32 | 33 | func MetricHook(config *eredis_config.Config) redis.Hook { 34 | return NewRedisHook().SetAfterProcess( 35 | func(ctx context.Context, cmd redis.Cmder) error { 36 | duration := GetDuration(ctx) 37 | err := cmd.Err() 38 | RedisHandleHistogram.WithLabelValues(config.Name, cmd.Name(), config.AddrString()).Observe(duration.Seconds()) 39 | if err != nil { 40 | if errors.Is(err, redis.Nil) { 41 | RedisHandleCounter.WithLabelValues(config.Name, cmd.Name(), config.AddrString(), "NotFound").Inc() 42 | return err 43 | } 44 | RedisHandleCounter.WithLabelValues(config.Name, cmd.Name(), config.AddrString(), "Error").Inc() 45 | return err 46 | } 47 | 48 | RedisHandleCounter.WithLabelValues(config.Name, cmd.Name(), config.AddrString(), "OK").Inc() 49 | return nil 50 | }, 51 | ) 52 | } 53 | -------------------------------------------------------------------------------- /db/eredis/interceptor/start_time.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/go-redis/redis/v8" 8 | ) 9 | 10 | type ctxStartTimeKey struct{} 11 | 12 | // StartTimeHook 13 | func StartTimeHook() redis.Hook { 14 | return NewRedisHook(). 
15 | SetBeforeProcess(func(ctx context.Context, cmd redis.Cmder) (context.Context, error) { 16 | return context.WithValue(ctx, ctxStartTimeKey{}, time.Now()), nil 17 | }) 18 | } 19 | 20 | func GetStartTime(ctx context.Context) time.Time { 21 | return ctx.Value(ctxStartTimeKey{}).(time.Time) 22 | } 23 | 24 | func GetDuration(ctx context.Context) time.Duration { 25 | startTime, _ := ctx.Value(ctxStartTimeKey{}).(time.Time) 26 | return time.Since(startTime) 27 | } 28 | 29 | func GetDurationMilliseconds(ctx context.Context) float64 { 30 | startTime, _ := ctx.Value(ctxStartTimeKey{}).(time.Time) 31 | return float64(time.Since(startTime).Microseconds()) / 1000 32 | 33 | } 34 | -------------------------------------------------------------------------------- /db/eredis/redis_map.go: -------------------------------------------------------------------------------- 1 | package eredis 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/weblazy/easy/db/eredis/eredis_config" 7 | "github.com/weblazy/easy/econfig" 8 | ) 9 | 10 | var RedisMap sync.Map 11 | 12 | // GetRedis return a RedisClient 13 | func GetRedis(dbName string) *RedisClient { 14 | if v, ok := RedisMap.Load(dbName); ok { 15 | return v.(*RedisClient) 16 | } 17 | conf := eredis_config.DefaultConfig() 18 | econfig.GlobalViper.UnmarshalKey(dbName, conf) 19 | redisClient := NewRedisClient(conf) 20 | RedisMap.Store(dbName, redisClient) 21 | return redisClient 22 | } 23 | -------------------------------------------------------------------------------- /ecodes/convert.go: -------------------------------------------------------------------------------- 1 | package ecodes 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | 7 | "google.golang.org/grpc/status" 8 | 9 | "google.golang.org/grpc/codes" 10 | ) 11 | 12 | // GrpcToHTTPStatusCode gRPC转HTTP Code 13 | // example: 14 | // spbStatus := status.FromContextError(err) 15 | // httpStatusCode := ecode.GrpcToHTTPStatusCode(spbStatus.Code()) 16 | func GrpcToHTTPStatusCode(statusCode 
codes.Code) int { 17 | switch statusCode { 18 | case codes.OK: 19 | return http.StatusOK 20 | case codes.Canceled: 21 | return http.StatusRequestTimeout 22 | case codes.Unknown: 23 | return http.StatusInternalServerError 24 | case codes.InvalidArgument: 25 | return http.StatusBadRequest 26 | case codes.DeadlineExceeded: 27 | return http.StatusRequestTimeout 28 | case codes.NotFound: 29 | return http.StatusNotFound 30 | case codes.AlreadyExists: 31 | return http.StatusConflict 32 | case codes.PermissionDenied: 33 | return http.StatusForbidden 34 | case codes.Unauthenticated: 35 | return http.StatusUnauthorized 36 | case codes.ResourceExhausted: 37 | return http.StatusServiceUnavailable 38 | case codes.FailedPrecondition: 39 | return http.StatusPreconditionFailed 40 | case codes.Aborted: 41 | return http.StatusConflict 42 | case codes.OutOfRange: 43 | return http.StatusBadRequest 44 | case codes.Unimplemented: 45 | return http.StatusNotImplemented 46 | case codes.Internal: 47 | return http.StatusInternalServerError 48 | case codes.Unavailable: 49 | return http.StatusServiceUnavailable 50 | case codes.DataLoss: 51 | return http.StatusInternalServerError 52 | default: 53 | return http.StatusInternalServerError 54 | } 55 | } 56 | 57 | // Convert 内部转换,为了让err=nil的时候,监控数据里有OK信息 58 | func Convert(err error) *status.Status { 59 | if err == nil { 60 | return status.New(codes.OK, "OK") 61 | } 62 | 63 | if se, ok := err.(interface { 64 | GRPCStatus() *status.Status 65 | }); ok { 66 | return se.GRPCStatus() 67 | } 68 | 69 | switch err { 70 | case context.DeadlineExceeded: 71 | return status.New(codes.DeadlineExceeded, err.Error()) 72 | case context.Canceled: 73 | return status.New(codes.Canceled, err.Error()) 74 | } 75 | 76 | return status.New(codes.Unknown, err.Error()) 77 | } 78 | -------------------------------------------------------------------------------- /econfig/const.go: -------------------------------------------------------------------------------- 1 | package 
econfig 2 | 3 | const ( 4 | EasyConfigType = "EASY_CONFIG_TYPE" 5 | EasyConfigFile = "EASY_CONFIG_FILE" 6 | LocalType = "local" 7 | FielType = "file" 8 | NacosType = "nacos" 9 | ) 10 | -------------------------------------------------------------------------------- /econfig/econfig.go: -------------------------------------------------------------------------------- 1 | package econfig 2 | 3 | import ( 4 | "os" 5 | "strings" 6 | 7 | "github.com/weblazy/easy/econfig/eviper" 8 | "github.com/weblazy/easy/econfig/nacos" 9 | ) 10 | 11 | var GlobalViper *eviper.Viper 12 | 13 | func InitGlobalViper(config interface{}, localConfig ...string) { 14 | switch os.Getenv(EasyConfigType) { 15 | case LocalType: 16 | GlobalViper = eviper.NewViperFromString(localConfig[0]) 17 | case FielType: 18 | GlobalViper = eviper.NewViperFromFile("", os.Getenv(EasyConfigFile)) 19 | case NacosType: 20 | nacos.NewNacosEnv() 21 | vt := nacos.GetViper() 22 | vt.SetDataIds(os.Getenv("ServiceName"), os.Getenv("DataId")) 23 | // 注册配置更新回调 24 | vt.NacosToViper() 25 | GlobalViper = vt.Viper 26 | default: 27 | GlobalViper = eviper.NewViperFromString(localConfig[0]) 28 | } 29 | GlobalViper.Unmarshal(&config) 30 | } 31 | 32 | func GetEnvConfig(key string) string { 33 | env := os.Getenv(strings.Replace(strings.ToUpper(key), ".", "_", -1)) 34 | if env != "" { 35 | return env 36 | } 37 | return GlobalViper.GetString(key) 38 | } 39 | -------------------------------------------------------------------------------- /econfig/eviper/viper.go: -------------------------------------------------------------------------------- 1 | package eviper 2 | 3 | import ( 4 | "bytes" 5 | "log" 6 | "os" 7 | "path" 8 | "path/filepath" 9 | "strings" 10 | "sync" 11 | 12 | "github.com/BurntSushi/toml" 13 | "github.com/spf13/viper" 14 | ) 15 | 16 | type Viper struct { 17 | *viper.Viper 18 | } 19 | 20 | var multipleViper sync.Map 21 | 22 | func NewViperFromString(configs string) *Viper { 23 | v := viper.New() 24 | CheckToml(configs) 25 | 
v.SetConfigType("toml") 26 | err := v.ReadConfig(bytes.NewBuffer([]byte(configs))) 27 | if err != nil { 28 | print(err) 29 | } 30 | return &Viper{v} 31 | } 32 | 33 | func (v *Viper) MergeViperFromString(configs string) { 34 | CheckToml(configs) 35 | v.SetConfigType("toml") 36 | err := v.MergeConfig(bytes.NewBuffer([]byte(configs))) 37 | if err != nil { 38 | print(err) 39 | } 40 | } 41 | 42 | func CheckToml(configs string) { 43 | var tmp interface{} 44 | if _, err := toml.Decode(configs, &tmp); err != nil { 45 | log.Fatalf("Error decoding TOML: %s", err) 46 | return 47 | } 48 | } 49 | 50 | func NewViperFromFile(filePath string, fileName string) *Viper { 51 | return newConfig(filePath, fileName) 52 | } 53 | 54 | func newConfig(filePath string, fileName string) *Viper { 55 | v := viper.New() 56 | v.SetConfigName(fileName) 57 | //filePath支持相对路径和绝对路径 etc:"/a/b" "b" "./b" 58 | if filePath == "" || filePath[:1] != "/" { 59 | v.AddConfigPath(path.Join(GetPath(), filePath)) 60 | } else { 61 | v.AddConfigPath(filePath) 62 | } 63 | v.WatchConfig() 64 | // 找到并读取配置文件并且 处理错误读取配置文件 65 | if err := v.ReadInConfig(); err != nil { 66 | panic(err) 67 | } 68 | return &Viper{v} 69 | } 70 | 71 | // GetPath 获取项目路径 72 | func GetPath() string { 73 | dir, err := filepath.Abs(filepath.Dir(os.Args[0])) 74 | if err != nil { 75 | print(err.Error()) 76 | } 77 | path := strings.Replace(dir, "\\", "/", -1) 78 | return path 79 | } 80 | 81 | func BuildVipers(filePath string, fileName ...string) { 82 | for _, v := range fileName { 83 | _, found := multipleViper.Load(v) 84 | if !found { //can not remap 85 | A := newConfig(filePath, v) 86 | multipleViper.Store(v, A) 87 | } 88 | } 89 | } 90 | 91 | func LoadViperByFilename(filename string) *Viper { 92 | value, _ := multipleViper.Load(filename) 93 | if value == nil { 94 | return nil 95 | } else { 96 | return value.(*Viper) 97 | } 98 | } 99 | 100 | func (v *Viper) GetEnvConfig(key string) string { 101 | env := os.Getenv(strings.Replace(strings.ToUpper(key), 
".", "_", -1)) 102 | if env != "" { 103 | return env 104 | } 105 | return v.GetString(key) 106 | } 107 | -------------------------------------------------------------------------------- /econfig/eviper/viper_test.go: -------------------------------------------------------------------------------- 1 | package eviper 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | ) 7 | 8 | func TestGetEnvConfig(t *testing.T) { 9 | v := NewViperFromString("") 10 | err := os.Setenv("TEST_DEMO", "6666") 11 | if err != nil { 12 | t.Error(err) 13 | return 14 | } 15 | s := v.GetEnvConfig("test.demo") 16 | if s != "6666" { 17 | t.Failed() 18 | } 19 | 20 | } 21 | -------------------------------------------------------------------------------- /econfig/nacos/local.go: -------------------------------------------------------------------------------- 1 | package nacos 2 | 3 | import ( 4 | "io/ioutil" 5 | 6 | "github.com/nacos-group/nacos-sdk-go/clients/config_client" 7 | "github.com/nacos-group/nacos-sdk-go/vo" 8 | ) 9 | 10 | type LocalNacos struct { 11 | configs string 12 | config_client.IConfigClient 13 | } 14 | 15 | // SetLocalConfigFile 注入本地配置 指定目录 16 | func SetLocalConfigFile(filePath string) { 17 | bytes, err := ioutil.ReadFile(filePath) 18 | if err != nil { 19 | panic(err) 20 | } 21 | SetLocalConfig(string(bytes)) 22 | } 23 | 24 | // SetLocalConfig 注入本地配置 25 | func SetLocalConfig(configs string) { 26 | localNacos := NewLocalNacos(configs) 27 | nacosHarder.icc = localNacos 28 | nacosHarder.local = true 29 | } 30 | 31 | func NewLocalNacos(configs string) config_client.IConfigClient { 32 | return &LocalNacos{configs: configs} 33 | } 34 | 35 | func (l *LocalNacos) GetConfig(param vo.ConfigParam) (string, error) { 36 | str := l.configs 37 | return str, nil 38 | } 39 | 40 | func (l *LocalNacos) PublishConfig(param vo.ConfigParam) (bool, error) { 41 | return true, nil 42 | } 43 | 44 | func (l *LocalNacos) DeleteConfig(param vo.ConfigParam) (bool, error) { 45 | return true, nil 46 | } 47 | 48 | func 
(l *LocalNacos) ListenConfig(params vo.ConfigParam) (err error) { 49 | return nil 50 | } 51 | -------------------------------------------------------------------------------- /ectx/ectx.go: -------------------------------------------------------------------------------- 1 | package ectx 2 | 3 | import ( 4 | "context" 5 | "time" 6 | ) 7 | 8 | // NoCancelContext remove context deadline 9 | type NoCancelContext struct { 10 | ctx context.Context 11 | } 12 | 13 | func (c NoCancelContext) Deadline() (time.Time, bool) { return time.Time{}, false } 14 | func (c NoCancelContext) Done() <-chan struct{} { return nil } 15 | func (c NoCancelContext) Err() error { return nil } 16 | func (c NoCancelContext) Value(key interface{}) interface{} { return c.ctx.Value(key) } 17 | 18 | // NoCancel remove ctx deadline then return a new context 19 | func NewNoCancelContext(ctx context.Context) context.Context { 20 | return NoCancelContext{ctx} 21 | } 22 | -------------------------------------------------------------------------------- /eerror/atomicerror.go: -------------------------------------------------------------------------------- 1 | package eerror 2 | 3 | import "sync" 4 | 5 | type AtomicError struct { 6 | err error 7 | lock sync.Mutex 8 | } 9 | 10 | func (ae *AtomicError) Set(err error) { 11 | ae.lock.Lock() 12 | ae.err = err 13 | ae.lock.Unlock() 14 | } 15 | 16 | func (ae *AtomicError) Load() error { 17 | ae.lock.Lock() 18 | err := ae.err 19 | ae.lock.Unlock() 20 | return err 21 | } 22 | -------------------------------------------------------------------------------- /eerror/atomicerror_test.go: -------------------------------------------------------------------------------- 1 | package eerror 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | var errDummy = errors.New("hello") 11 | 12 | func TestAtomicError(t *testing.T) { 13 | var err AtomicError 14 | err.Set(errDummy) 15 | assert.Equal(t, errDummy, err.Load()) 16 | } 17 | 
18 | func TestAtomicErrorNil(t *testing.T) { 19 | var err AtomicError 20 | assert.Nil(t, err.Load()) 21 | } 22 | -------------------------------------------------------------------------------- /eerror/batcherror.go: -------------------------------------------------------------------------------- 1 | package eerror 2 | 3 | import "bytes" 4 | 5 | type BatchError []error 6 | 7 | func (be BatchError) Error() string { 8 | var buf bytes.Buffer 9 | 10 | for i := range be { 11 | if i > 0 { 12 | buf.WriteByte('\n') 13 | } 14 | buf.WriteString(be[i].Error()) 15 | } 16 | 17 | return buf.String() 18 | } 19 | -------------------------------------------------------------------------------- /eerror/batcherror_test.go: -------------------------------------------------------------------------------- 1 | package eerror 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | const ( 12 | err1 = "first error" 13 | err2 = "second error" 14 | ) 15 | 16 | func TestBatchErrorNil(t *testing.T) { 17 | var batch BatchError 18 | assert.Nil(t, batch) 19 | } 20 | 21 | func TestBatchErrorOneError(t *testing.T) { 22 | var batch BatchError 23 | batch = append(batch, errors.New(err1)) 24 | assert.NotNil(t, batch) 25 | assert.Equal(t, err1, batch.Error()) 26 | } 27 | 28 | func TestBatchErrorWithErrors(t *testing.T) { 29 | var batch BatchError 30 | batch = append(batch, errors.New(err1)) 31 | batch = append(batch, errors.New(err2)) 32 | assert.NotNil(t, batch) 33 | assert.Equal(t, fmt.Sprintf("%s\n%s", err1, err2), batch.Error()) 34 | } 35 | -------------------------------------------------------------------------------- /eerror/bizerror.go: -------------------------------------------------------------------------------- 1 | package eerror 2 | 3 | import ( 4 | "strconv" 5 | 6 | "github.com/weblazy/easy/set" 7 | ) 8 | 9 | const ( 10 | UnknownErrCode = "9999" 11 | SuccessCode = "20000" 12 | ) 13 | 14 | var ( 15 | DefaultSuccessCodes = 
[]string{SuccessCode} 16 | ) 17 | 18 | type commonErrResp interface { 19 | GetError() error 20 | } 21 | 22 | type codeMsgResp interface { 23 | GetCode() int64 24 | GetMsg() string 25 | } 26 | 27 | type codeMsgStringResp interface { 28 | GetCode() string 29 | GetMessage() string 30 | } 31 | 32 | type retCodeMsgResp interface { 33 | GetRetCode() int32 34 | GetRetMsg() string 35 | } 36 | 37 | func ExtractBizCode(successCodes []string) func(resp interface{}, err error) (string, bool) { 38 | sc := successCodes 39 | if len(sc) == 0 { 40 | sc = DefaultSuccessCodes 41 | } 42 | scs := set.NewStringSet() 43 | scs.BatchAdd(sc...) 44 | 45 | replacer := func(s string) string { 46 | if scs.Has(s) { 47 | return SuccessCode 48 | } 49 | return s 50 | } 51 | 52 | return func(resp interface{}, err error) (string, bool) { 53 | // if err != nil { 54 | // commonErr, ok := FromErrorIsDetail(err) 55 | // // 1. rich error 56 | // if ok { 57 | // return commonErr.BizCode, true 58 | // } 59 | // non biz error 60 | // return "", false 61 | // } 62 | 63 | // // 2. 内嵌 commonError 64 | // if cer, ok := resp.(commonErrResp); ok { 65 | // // 内嵌 error nil 当做成功处理 66 | // if cer.GetError() == nil { 67 | // return SuccessCode, true 68 | // } 69 | // return replacer(cer.GetError().GetBizCode()), true 70 | // } 71 | 72 | // 3. 内嵌 code msg 73 | if cmr, ok := resp.(codeMsgResp); ok { 74 | return replacer(strconv.Itoa(int(cmr.GetCode()))), true 75 | } 76 | 77 | // 4. 内嵌 code msg string 78 | if cmr, ok := resp.(codeMsgStringResp); ok { 79 | return replacer(cmr.GetCode()), true 80 | } 81 | 82 | // 5. 
内嵌 ret_code ret_msg 83 | if cmr, ok := resp.(retCodeMsgResp); ok { 84 | return replacer(strconv.Itoa(int(cmr.GetRetCode()))), true 85 | } 86 | 87 | // response 不符合上面任何标准, 且 err 为 nil 88 | // 当做正常请求计算 89 | if err == nil { 90 | return SuccessCode, true 91 | } 92 | 93 | // 不属于以上任意一种 94 | return "", false 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /ekafka/config.go: -------------------------------------------------------------------------------- 1 | package ekafka 2 | 3 | import ( 4 | "strings" 5 | 6 | "errors" 7 | 8 | "github.com/IBM/sarama" 9 | ) 10 | 11 | const ( 12 | passwordAuthType = "password" 13 | noAuthType = "none" 14 | ) 15 | 16 | const ( 17 | offsetOldest = "oldest" 18 | offsetNewest = "newest" 19 | ) 20 | 21 | type Config struct { 22 | Debug bool 23 | EnableAccessInterceptorReq bool // 是否开启记录 publish 消息,默认开启 24 | EnableAccessInterceptorRes bool // 是否开启记录 consumer 消费消息, 默认开启 25 | 26 | ClientConfig ClientConfig 27 | ProducerConfig ProducerConfig 28 | ConsumerGroupConfigs map[string]ConsumerGroupConfig 29 | } 30 | 31 | type ClientConfig struct { 32 | Brokers []string // Brokers brokers地址 33 | AuthType string // 鉴权方式, password / none. 
默认为: none 34 | SaslUsername string // 鉴权方式为 password 时, 必填 35 | SaslPassword string // 鉴权方式为 password 时, 必填 36 | Version string // kafka version, 默认为 2.0.0.0 37 | } 38 | 39 | type ProducerConfig struct { 40 | MaxMessageBytes int 41 | } 42 | 43 | type RetryConfig struct { 44 | MaxRetries int64 // consumer 消费重试次数, 默认 0 不重试 45 | } 46 | 47 | type ConsumerGroupConfig struct { 48 | Topics []string 49 | GroupID string 50 | InitialOffset string // 初始化 offset, oldest / newest, 默认 oldest 51 | RetryConfig RetryConfig 52 | } 53 | 54 | func DefaultConfig() *Config { 55 | return &Config{ 56 | Debug: false, 57 | EnableAccessInterceptorReq: true, 58 | EnableAccessInterceptorRes: true, 59 | ClientConfig: ClientConfig{AuthType: noAuthType}, 60 | } 61 | } 62 | 63 | func (c *Config) toSaramaConfig() (*sarama.Config, error) { //nolint 64 | sc := sarama.NewConfig() 65 | clientConfig := c.ClientConfig 66 | 67 | if len(c.ClientConfig.Brokers) == 0 { 68 | return nil, errors.New("empty brokers") 69 | } 70 | 71 | if clientConfig.AuthType == passwordAuthType { 72 | if clientConfig.SaslUsername == "" || clientConfig.SaslPassword == "" { 73 | return nil, errors.New("username and password are required when using password auth type") 74 | } 75 | sc.Net.SASL.Enable = true 76 | sc.Net.SASL.User = clientConfig.SaslUsername 77 | sc.Net.SASL.Password = clientConfig.SaslPassword 78 | sc.Net.SASL.Mechanism = sarama.SASLTypePlaintext 79 | } 80 | 81 | if clientConfig.Version != "" { 82 | v, err := sarama.ParseKafkaVersion(clientConfig.Version) 83 | if err != nil { 84 | return nil, errors.New("invalid kafka version") 85 | } 86 | sc.Version = v 87 | } else { 88 | sc.Version = sarama.V2_0_0_0 89 | } 90 | 91 | return sc, nil 92 | } 93 | 94 | func (c *Config) brokers() string { //nolint 95 | return strings.Join(c.ClientConfig.Brokers, ",") 96 | } 97 | -------------------------------------------------------------------------------- /ekafka/container.go: 
-------------------------------------------------------------------------------- 1 | package ekafka 2 | 3 | const ( 4 | PackageNameProducer = "ekafka.producer" 5 | PackageNameConsumerGroup = "ekafka.consumerGroup" 6 | ) 7 | 8 | func NewProducer(name string, config *Config) (*Producer, error) { 9 | sc, err := config.toSaramaConfig() 10 | 11 | if err != nil { 12 | return nil, err 13 | } 14 | 15 | return newProducer(config, sc) 16 | } 17 | 18 | func NewConsumerGroup(name string, config *Config, groupConfig *ConsumerGroupConfig) (*ConsumerGroup, error) { 19 | sc, err := config.toSaramaConfig() 20 | 21 | if err != nil { 22 | return nil, err 23 | } 24 | 25 | return newConsumerGroup(config, groupConfig, sc) 26 | } 27 | -------------------------------------------------------------------------------- /ekafka/message.go: -------------------------------------------------------------------------------- 1 | package ekafka 2 | 3 | import "github.com/IBM/sarama" 4 | 5 | // Message sarama.ProducerMessage for kafka publish 6 | type Message struct { 7 | Topic string 8 | Key []byte 9 | Value []byte 10 | // The headers are key-value pairs that are transparently passed 11 | // by Kafka between producers and consumers. 12 | Headers []sarama.RecordHeader 13 | 14 | // This field is used to hold arbitrary data you wish to include, so it 15 | // will be available when receiving on the Successes and Errors channels. 16 | // Sarama completely ignores this field and is only to be used for 17 | // pass-through data. 
18 | Metadata interface{} 19 | } 20 | 21 | func (m *Message) ToMap() map[string]interface{} { 22 | mp := make(map[string]interface{}, 5) 23 | mp["topic"] = m.Topic 24 | mp["value"] = string(m.Value) 25 | 26 | if m.Key != nil { 27 | mp["key"] = m.Key 28 | } 29 | 30 | if len(m.Headers) > 0 { 31 | headers := make(map[string]string, len(m.Headers)) 32 | for _, h := range m.Headers { 33 | headers[string(h.Key)] = string(h.Value) 34 | } 35 | mp["headers"] = headers 36 | } 37 | 38 | if m.Metadata != nil { 39 | mp["metadata"] = m.Metadata 40 | } 41 | 42 | return mp 43 | } 44 | 45 | func ConsumerMessageToMap(m *sarama.ConsumerMessage) map[string]interface{} { 46 | mp := make(map[string]interface{}, 7) 47 | mp["topic"] = m.Topic 48 | mp["value"] = string(m.Value) 49 | mp["partition"] = m.Partition 50 | mp["offset"] = m.Offset 51 | mp["timestamp"] = m.Timestamp 52 | 53 | if m.Key != nil { 54 | mp["key"] = m.Key 55 | } 56 | 57 | if len(m.Headers) > 0 { 58 | headers := make(map[string]string, len(m.Headers)) 59 | for _, h := range m.Headers { 60 | headers[string(h.Key)] = string(h.Value) 61 | } 62 | mp["headers"] = headers 63 | } 64 | 65 | return mp 66 | } 67 | -------------------------------------------------------------------------------- /ekafka/metrics.go: -------------------------------------------------------------------------------- 1 | package ekafka 2 | 3 | import ( 4 | "github.com/prometheus/client_golang/prometheus" 5 | ) 6 | 7 | var ( 8 | kafkaPublishCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ 9 | Name: "kafka_publish_total", 10 | }, []string{"brokers", "topic", "code"}) 11 | 12 | kafkaConsumerGroupCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ 13 | Name: "kafka_consumer_group_handle_total", 14 | }, []string{"brokers", "group_id", "topic", "code"}) 15 | 16 | kafkaConsumerGroupHistogram = prometheus.NewHistogramVec(prometheus.HistogramOpts{ 17 | Name: "kafka_consumer_group_handle_seconds", 18 | Buckets: []float64{.025, .05, .1, .25, .5, 1, 
2.5, 5, 10, 30}, 19 | }, []string{"brokers", "group_id", "topic"}) 20 | ) 21 | 22 | func init() { 23 | prometheus.MustRegister(kafkaPublishCounter) 24 | prometheus.MustRegister(kafkaConsumerGroupCounter) 25 | prometheus.MustRegister(kafkaConsumerGroupHistogram) 26 | } 27 | -------------------------------------------------------------------------------- /ekafka/producer.go: -------------------------------------------------------------------------------- 1 | package ekafka 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/IBM/sarama" 7 | "go.uber.org/zap" 8 | 9 | "github.com/weblazy/easy/elog" 10 | ) 11 | 12 | const ( 13 | CodeOK = "OK" 14 | CodeError = "Error" 15 | ) 16 | 17 | type Producer struct { 18 | config *Config 19 | producer sarama.SyncProducer 20 | } 21 | 22 | func newProducer(config *Config, sc *sarama.Config) (*Producer, error) { //nolint 23 | c := &Producer{ 24 | config: config, 25 | } 26 | 27 | producer, err := getSyncProducer(config, *sc) 28 | if err != nil { 29 | return nil, err 30 | } 31 | 32 | c.producer = producer 33 | 34 | return c, nil 35 | } 36 | 37 | func (c *Producer) SendMessage(ctx context.Context, msg *Message) error { 38 | smsg := &sarama.ProducerMessage{ 39 | Topic: msg.Topic, 40 | Key: sarama.ByteEncoder(msg.Key), 41 | Value: sarama.ByteEncoder(msg.Value), 42 | Headers: msg.Headers, 43 | Metadata: msg.Metadata, 44 | } 45 | 46 | partition, offset, err := c.producer.SendMessage(smsg) 47 | 48 | labels := make([]zap.Field, 0) 49 | labels = append(labels, zap.String("topic", smsg.Topic)) 50 | 51 | if c.config.EnableAccessInterceptorReq { 52 | labels = append(labels, zap.Any("req", msg.ToMap())) 53 | } 54 | 55 | // if tid := etrace.ExtractTraceID(ctx); tid != "" { 56 | // labels = append(labels, elog.FieldTrace(tid)) 57 | // } 58 | 59 | if err != nil { 60 | labels = append(labels, elog.FieldError(err)) 61 | elog.ErrorCtx(ctx, "kafka publish failed", labels...) 
62 | kafkaPublishCounter.WithLabelValues(c.config.brokers(), msg.Topic, CodeError).Inc() 63 | return err 64 | } 65 | 66 | labels = append(labels, zap.Int64("partition", int64(partition)), zap.Int64("offset", offset)) 67 | 68 | elog.InfoCtx(ctx, "kafka publish success", labels...) 69 | kafkaPublishCounter.WithLabelValues(c.config.brokers(), msg.Topic, CodeOK).Inc() 70 | 71 | return nil 72 | } 73 | 74 | func (c *Producer) Close() error { 75 | if c.producer != nil { 76 | elog.InfoCtx(context.Background(), "producer exit") 77 | return c.producer.Close() 78 | } 79 | 80 | return nil 81 | } 82 | 83 | func getSyncProducer(config *Config, sc sarama.Config) (sarama.SyncProducer, error) { //nolint 84 | // Add SyncProducer specific properties to copy of base config 85 | sc.Producer.RequiredAcks = sarama.WaitForAll 86 | sc.Producer.Retry.Max = 5 87 | sc.Producer.Return.Successes = true 88 | 89 | maxMessageBytes := config.ProducerConfig.MaxMessageBytes 90 | 91 | if maxMessageBytes > 0 { 92 | sc.Producer.MaxMessageBytes = maxMessageBytes 93 | } 94 | 95 | producer, err := sarama.NewSyncProducer(config.ClientConfig.Brokers, &sc) 96 | if err != nil { 97 | return nil, err 98 | } 99 | 100 | // wrap tracing 101 | return producer, nil 102 | } 103 | -------------------------------------------------------------------------------- /elog/elog.go: -------------------------------------------------------------------------------- 1 | package elog 2 | 3 | import ( 4 | "go.uber.org/zap" 5 | "golang.org/x/net/context" 6 | ) 7 | 8 | type CtxFieldKey struct{} 9 | type CtxSkipKey struct{} 10 | 11 | const DefaultSkip = 2 12 | 13 | func AddCtxSkip(ctx context.Context, skip int) context.Context { 14 | v, _ := ctx.Value(CtxSkipKey{}).(int) 15 | return context.WithValue(ctx, CtxSkipKey{}, v+skip) 16 | } 17 | 18 | func GetCtxSkip(ctx context.Context) int { 19 | v, _ := ctx.Value(CtxSkipKey{}).(int) 20 | return v 21 | } 22 | 23 | func DebugCtx(ctx context.Context, msg string, fields ...zap.Field) { 24 | 
logLevel, ok := ctx.Value(CtxSkipKey{}).(LogLevel) 25 | if ok && logLevel < Debug { 26 | return 27 | } 28 | fields = MergeCtxFields(ctx, fields...) 29 | logger := GetLoggerFromCtx(ctx) 30 | if logger != nil { 31 | logger.DebugCtx(ctx, msg, fields...) 32 | } 33 | } 34 | 35 | func InfoCtx(ctx context.Context, msg string, fields ...zap.Field) { 36 | logLevel, ok := ctx.Value(CtxSkipKey{}).(LogLevel) 37 | if ok && logLevel < Info { 38 | return 39 | } 40 | fields = MergeCtxFields(ctx, fields...) 41 | 42 | logger := GetLoggerFromCtx(ctx) 43 | if logger != nil { 44 | logger.InfoCtx(ctx, msg, fields...) 45 | } 46 | } 47 | 48 | func WarnCtx(ctx context.Context, msg string, fields ...zap.Field) { 49 | logLevel, ok := ctx.Value(CtxSkipKey{}).(LogLevel) 50 | if ok && logLevel < Warn { 51 | return 52 | } 53 | fields = MergeCtxFields(ctx, fields...) 54 | logger := GetLoggerFromCtx(ctx) 55 | if logger != nil { 56 | logger.WarnCtx(ctx, msg, fields...) 57 | } 58 | } 59 | 60 | func ErrorCtx(ctx context.Context, msg string, fields ...zap.Field) { 61 | logLevel, ok := ctx.Value(CtxSkipKey{}).(LogLevel) 62 | if ok && logLevel < Error { 63 | return 64 | } 65 | fields = MergeCtxFields(ctx, fields...) 66 | logger := GetLoggerFromCtx(ctx) 67 | if logger != nil { 68 | logger.ErrorCtx(ctx, msg, fields...) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /elog/ezap/config.go: -------------------------------------------------------------------------------- 1 | package ezap 2 | 3 | import ( 4 | "os" 5 | 6 | "go.uber.org/zap" 7 | "go.uber.org/zap/zapcore" 8 | ) 9 | 10 | type Config struct { 11 | Logfile *os.File 12 | ZapConfig zap.Config 13 | MaxAge int64 //定期清理日志文件,日志保留天数 14 | } 15 | 16 | // DefaultConfig default config ... 
17 | func DefaultConfig() *Config { 18 | zapConfig := zap.NewProductionConfig() 19 | // zapConfig.EncoderConfig.TimeKey = zapcore.OmitKey 20 | // zapConfig.EncoderConfig.LevelKey = zapcore.OmitKey 21 | // zapConfig.EncoderConfig.NameKey = zapcore.OmitKey 22 | zapConfig.EncoderConfig.CallerKey = zapcore.OmitKey 23 | // zapConfig.EncoderConfig.FunctionKey = zapcore.OmitKey 24 | // zapConfig.EncoderConfig.MessageKey = "msg" 25 | zapConfig.EncoderConfig.StacktraceKey = zapcore.OmitKey 26 | // zapConfig.EncoderConfig.LineEnding = zapcore.DefaultLineEnding 27 | // zapConfig.EncoderConfig.EncodeLevel = zapcore.LowercaseLevelEncoder 28 | zapConfig.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder 29 | // zapConfig.EncoderConfig.EncodeDuration = zapcore.SecondsDurationEncoder 30 | // zapConfig.EncoderConfig.EncodeCaller = zapcore.ShortCallerEncoder 31 | zapConfig.Level = zap.NewAtomicLevelAt(zap.DebugLevel) 32 | return &Config{ 33 | MaxAge: 7, 34 | ZapConfig: zapConfig, 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /elog/ezap/console.go: -------------------------------------------------------------------------------- 1 | package ezap 2 | 3 | import ( 4 | "log" 5 | ) 6 | 7 | func NewConsoleEzap() *Ezap { 8 | 9 | var err error 10 | cfg := DefaultConfig() 11 | l, err := cfg.ZapConfig.Build() 12 | if err != nil { 13 | log.Printf("l.initZap(),err:%+v", err) 14 | return nil 15 | } 16 | return &Ezap{ 17 | Logger: l, 18 | Config: cfg, 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /elog/ezap/ezap.go: -------------------------------------------------------------------------------- 1 | package ezap 2 | 3 | import ( 4 | "context" 5 | "os" 6 | 7 | "github.com/weblazy/easy/elog/logx" 8 | 9 | "go.uber.org/zap" 10 | ) 11 | 12 | // Ezap 将文件输出到终端或者文件 13 | type Ezap struct { 14 | logx.GLog 15 | Logger *zap.Logger 16 | Logfile *os.File 17 | Config *Config 18 | } 19 | 20 | func (e 
*Ezap) DebugCtx(ctx context.Context, msg string, fields ...zap.Field) { 21 | e.Logger.Debug(msg, fields...) 22 | } 23 | 24 | func (e *Ezap) InfoCtx(ctx context.Context, msg string, fields ...zap.Field) { 25 | e.Logger.Info(msg, fields...) 26 | } 27 | 28 | func (e *Ezap) WarnCtx(ctx context.Context, msg string, fields ...zap.Field) { 29 | e.Logger.Warn(msg, fields...) 30 | } 31 | 32 | func (e *Ezap) ErrorCtx(ctx context.Context, msg string, fields ...zap.Field) { 33 | e.Logger.Error(msg, fields...) 34 | } 35 | -------------------------------------------------------------------------------- /elog/ezap/ezap_test.go: -------------------------------------------------------------------------------- 1 | package ezap 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/weblazy/easy/filex" 7 | ) 8 | 9 | func TestName(t *testing.T) { 10 | deleteLog(filex.GetPath()+"/RunTime", 7) 11 | } 12 | -------------------------------------------------------------------------------- /elog/level.go: -------------------------------------------------------------------------------- 1 | package elog 2 | 3 | import "fmt" 4 | 5 | type CtxLogLevelKey struct{} 6 | 7 | // LogLevel log level 8 | type LogLevel int 9 | 10 | const ( 11 | // Silent silent log level 12 | Silent LogLevel = iota + 1 13 | // Error error log level 14 | Error 15 | // Warn warn log level 16 | Warn 17 | // Info info log level 18 | Info 19 | // Debug debug log level 20 | Debug 21 | ) 22 | 23 | // String returns a lower-case ASCII representation of the log level. 
24 | func (l LogLevel) String() string { 25 | switch l { 26 | case Silent: 27 | return "silent" 28 | case Error: 29 | return "error" 30 | case Warn: 31 | return "warn" 32 | case Info: 33 | return "info" 34 | case Debug: 35 | return "debug" 36 | default: 37 | return fmt.Sprintf("Level(%d)", l) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /elog/log.go: -------------------------------------------------------------------------------- 1 | package elog 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | 7 | "github.com/weblazy/easy/elog/ezap" 8 | "github.com/weblazy/easy/elog/logx" 9 | ) 10 | 11 | var ( 12 | Logger sync.Map 13 | DefaultLogger = ezap.NewConsoleEzap() 14 | ) 15 | 16 | const ( 17 | Ezap = "ezap" 18 | ) 19 | 20 | type CtxLoggerNameKey struct{} 21 | 22 | // 默认加入zap组件 23 | func init() { 24 | Logger.Store(Ezap, DefaultLogger) 25 | } 26 | 27 | // SetLogger 设置日志打印实例,选择输出到文件,终端,阿里云日志等 28 | func SetLogger(name string, logger logx.GLog) { 29 | Logger.Store(name, logger) 30 | } 31 | 32 | // DelLogger 删除日志插件 33 | func DelLogger(name string) { 34 | Logger.Delete(name) 35 | } 36 | 37 | // GetLoggerFromCtx 38 | func GetLoggerFromCtx(ctx context.Context) logx.GLog { 39 | loggerName, ok := ctx.Value(CtxLoggerNameKey{}).(string) 40 | if ok { 41 | logger, ok := Logger.Load(loggerName) 42 | if !ok { 43 | // 指定了logger,但是没有找到 44 | return nil 45 | } 46 | return logger.(logx.GLog) 47 | } 48 | // 没有指定logger使用默认全局logger 49 | return DefaultLogger 50 | } 51 | 52 | func SetLogerName(ctx context.Context, name string) context.Context { 53 | return context.WithValue(ctx, CtxLoggerNameKey{}, name) 54 | } 55 | -------------------------------------------------------------------------------- /elog/log_test.go: -------------------------------------------------------------------------------- 1 | package elog 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "go.uber.org/zap" 8 | ) 9 | 10 | func TestLog(t *testing.T) { 11 | s := struct { 12 | 
Name string 13 | Age int 14 | }{ 15 | Name: "Jerry", 16 | Age: 18, 17 | } 18 | // zap log 19 | ctx := context.Background() 20 | DebugCtx(ctx, "zap debug") 21 | InfoCtx(ctx, "", zap.Any("obj", s)) 22 | WarnCtx(ctx, "zap warn") 23 | ErrorCtx(ctx, "zap error") 24 | 25 | ctx = context.WithValue(ctx, CtxSkipKey{}, Info) 26 | // ctx = AddCtxSkip(ctx, 2) 27 | ctx = AppendCtxFields(ctx, zap.String("name", "lazy")) 28 | DebugCtx(ctx, "zap debug") 29 | InfoCtx(ctx, "", zap.Any("obj", s)) 30 | WarnCtx(ctx, "zap warn") 31 | ErrorCtx(ctx, "zap error") 32 | 33 | // logger := ezap.NewFileEzap("test1") 34 | // loggerName := "test" 35 | // SetLogger(loggerName, logger) 36 | // ctx = SetLogerName(ctx, loggerName) 37 | // DebugCtx(ctx, "zap debug") 38 | // InfoCtx(ctx, "", zap.Any("obj", s)) 39 | // WarnCtx(ctx, "zap warn") 40 | // ErrorCtx(ctx, "zap error") 41 | 42 | } 43 | -------------------------------------------------------------------------------- /elog/logx/logx.go: -------------------------------------------------------------------------------- 1 | package logx 2 | 3 | import ( 4 | "context" 5 | 6 | "go.uber.org/zap" 7 | ) 8 | 9 | type GLog interface { 10 | ErrorCtx(ctx context.Context, msg string, fields ...zap.Field) 11 | WarnCtx(ctx context.Context, msg string, fields ...zap.Field) 12 | InfoCtx(ctx context.Context, msg string, fields ...zap.Field) 13 | DebugCtx(ctx context.Context, msg string, fields ...zap.Field) 14 | } 15 | -------------------------------------------------------------------------------- /emath/emath.go: -------------------------------------------------------------------------------- 1 | package emath 2 | 3 | import ( 4 | "math" 5 | "math/big" 6 | ) 7 | 8 | func BigIntQuoDecimal(amount *big.Int, decimals int) float64 { 9 | b := new(big.Float).SetInt(amount) 10 | r, _ := new(big.Float).Quo(b, big.NewFloat(math.Pow10(decimals))).Float64() 11 | return r 12 | } 13 | 14 | func BigFloatQuoDecimal(b *big.Float, decimal float64) float64 { 15 | r, _ := 
new(big.Float).Quo(b, big.NewFloat(math.Pow(10, decimal))).Float64() 16 | return r 17 | } 18 | 19 | func BigIntToFloat64(b *big.Int) float64 { 20 | f := new(big.Float).SetInt(b) 21 | result, _ := f.Float64() 22 | return result 23 | 24 | } 25 | -------------------------------------------------------------------------------- /env/env.go: -------------------------------------------------------------------------------- 1 | package env 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | ) 7 | 8 | const ( 9 | ReleaseEnv = "onl" 10 | ) 11 | 12 | var releaseFlag = false //为true时表示线上环境 13 | 14 | // GetRunTime 获取当前系统环境 15 | func GetRunTime() string { 16 | RunTime := os.Getenv("RUN_TIME") 17 | if RunTime == "" { 18 | fmt.Println("No RUN_TIME Can't start") 19 | } 20 | return RunTime 21 | } 22 | 23 | // OnRelease 开启线上环境 24 | func OnRelease() { 25 | releaseFlag = true 26 | } 27 | 28 | // IsRelease 如果是线上环境返回true 29 | func IsRelease() bool { 30 | return releaseFlag || GetRunTime() == ReleaseEnv 31 | } 32 | -------------------------------------------------------------------------------- /eprometheus/config.go: -------------------------------------------------------------------------------- 1 | package eprometheus 2 | 3 | type Config struct { 4 | Path string 5 | Port string 6 | } 7 | 8 | // DefaultConfig ... 
9 | func DefaultConfig() *Config { 10 | return &Config{ 11 | Path: "/metrics", 12 | Port: ":2112", 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /eprometheus/eprometheus.go: -------------------------------------------------------------------------------- 1 | package eprometheus 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/prometheus/client_golang/prometheus/promhttp" 7 | ) 8 | 9 | func RunPrometheus(cfg *Config) { 10 | http.Handle(cfg.Path, promhttp.Handler()) 11 | http.ListenAndServe(cfg.Port, nil) 12 | } 13 | -------------------------------------------------------------------------------- /etrace/baggage.go: -------------------------------------------------------------------------------- 1 | package etrace 2 | 3 | import ( 4 | "context" 5 | 6 | "go.opentelemetry.io/otel/baggage" 7 | ) 8 | 9 | // GetBaggageValue get baggage info from context, if key not exists, return "", false. 10 | func GetBaggageValue(ctx context.Context, key string) (string, bool) { 11 | b := baggage.FromContext(ctx) 12 | m := b.Member(key) 13 | 14 | if m.Key() == "" { 15 | return "", false 16 | } 17 | 18 | return m.Value(), true 19 | } 20 | 21 | // WithBaggage append baggage by string key val. 22 | func WithBaggage(parent context.Context, key, val string) (context.Context, error) { 23 | member, err := baggage.NewMember(key, val) 24 | if err != nil { 25 | return parent, err 26 | } 27 | 28 | b := baggage.FromContext(parent) 29 | b, err = b.SetMember(member) 30 | if err != nil { 31 | return parent, err 32 | } 33 | 34 | return baggage.ContextWithBaggage(parent, b), nil 35 | } 36 | 37 | // AppendBaggageByMap append map kvs to current ctx baggage, will return origin ctx if error. 
38 | func AppendBaggageByMap(ctx context.Context, mp map[string]string) (context.Context, error) { 39 | b := baggage.FromContext(ctx) 40 | 41 | for k, v := range mp { 42 | m, err := baggage.NewMember(k, v) 43 | if err != nil { 44 | return ctx, err 45 | } 46 | 47 | b, err = b.SetMember(m) 48 | if err != nil { 49 | return ctx, err 50 | } 51 | } 52 | 53 | return baggage.ContextWithBaggage(ctx, b), nil 54 | } 55 | -------------------------------------------------------------------------------- /etrace/baggage_test.go: -------------------------------------------------------------------------------- 1 | package etrace 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | "go.opentelemetry.io/otel/baggage" 9 | ) 10 | 11 | func TestAppendBaggageByMap(t *testing.T) { 12 | ctx, err := AppendBaggageByMap(context.Background(), map[string]string{"test": "test", "aaa": "aaa"}) 13 | assert.Nil(t, err) 14 | b := baggage.FromContext(ctx) 15 | assert.Equal(t, "test", b.Member("test").Value()) 16 | assert.Equal(t, "aaa", b.Member("aaa").Value()) 17 | } 18 | 19 | func TestWithBaggage(t *testing.T) { 20 | ctx := context.Background() 21 | ctx, err := WithBaggage(ctx, "aaa", "aaa") 22 | assert.Nil(t, err) 23 | val, ok := GetBaggageValue(ctx, "aaa") 24 | assert.True(t, ok) 25 | assert.Equal(t, "aaa", val) 26 | 27 | ctx, err = WithBaggage(ctx, "aaa", "bbb") 28 | assert.Nil(t, err) 29 | val, ok = GetBaggageValue(ctx, "aaa") 30 | assert.True(t, ok) 31 | assert.Equal(t, "bbb", val) 32 | } 33 | 34 | func TestGetBaggageValue(t *testing.T) { 35 | ctx, _ := AppendBaggageByMap(context.Background(), map[string]string{"test": "aaa"}) 36 | 37 | type args struct { 38 | ctx context.Context 39 | key string 40 | } 41 | tests := []struct { 42 | name string 43 | args args 44 | want string 45 | want1 bool 46 | }{ 47 | { 48 | "not exists", 49 | args{ 50 | ctx: context.Background(), 51 | key: "test", 52 | }, 53 | "", 54 | false, 55 | }, 56 | { 57 | "exists", 58 | args{ 
59 | ctx: ctx, 60 | key: "test", 61 | }, 62 | "aaa", 63 | true, 64 | }, 65 | } 66 | for _, tt := range tests { 67 | t.Run(tt.name, func(t *testing.T) { 68 | got, got1 := GetBaggageValue(tt.args.ctx, tt.args.key) 69 | if got != tt.want { 70 | t.Errorf("GetBaggageValue() got = %v, want %v", got, tt.want) 71 | } 72 | if got1 != tt.want1 { 73 | t.Errorf("GetBaggageValue() got1 = %v, want %v", got1, tt.want1) 74 | } 75 | }) 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /etrace/fskywalking/skywalking.go: -------------------------------------------------------------------------------- 1 | package fskywalking 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | 8 | "go.uber.org/zap" 9 | 10 | "github.com/SkyAPM/go2sky" 11 | "github.com/SkyAPM/go2sky/reporter" 12 | "github.com/weblazy/easy/elog" 13 | ) 14 | 15 | const ( 16 | ENV_KEY = "MY_ENV_NAME" 17 | PROJECT_NAME_KEY = "MY_PROJECT_NAME" 18 | MY_PROJECT_TRACE_HOST = "MY_PROJECT_TRACE_HOST" 19 | ) 20 | 21 | var emptyCtx = context.Background() 22 | 23 | type Config struct { 24 | Enable bool 25 | ServiceName string 26 | EnvName string 27 | 28 | AgentEndPoint string 29 | Sampler float64 30 | } 31 | 32 | func DefaultConfig() *Config { 33 | return &Config{ 34 | Enable: true, 35 | ServiceName: os.Getenv(PROJECT_NAME_KEY), 36 | EnvName: os.Getenv(ENV_KEY), 37 | AgentEndPoint: os.Getenv(MY_PROJECT_TRACE_HOST), 38 | Sampler: 0.1, 39 | } 40 | } 41 | 42 | // Option 可选项 43 | type Option func(c *Config) 44 | 45 | func (config *Config) Build(ops ...Option) *go2sky.Tracer { 46 | lc := zap.Any("config", config) 47 | 48 | if !config.Enable { 49 | elog.InfoCtx(emptyCtx, "skywalking not enable", lc) 50 | return nil 51 | } 52 | 53 | r, err := reporter.NewGRPCReporter(config.AgentEndPoint) 54 | if err != nil { 55 | elog.InfoCtx(emptyCtx, "skywalking new reporter error", lc, zap.Error(err)) 56 | return nil 57 | } 58 | 59 | tracer, err := go2sky.NewTracer(fmt.Sprintf("%s-%s", config.EnvName, 
config.ServiceName), go2sky.WithReporter(r), go2sky.WithSampler(config.Sampler)) 60 | if err != nil { 61 | elog.InfoCtx(emptyCtx, "skywalking new tracer error", lc, zap.Error(err)) 62 | return nil 63 | } 64 | 65 | // registers `tracer` as the global Tracer 66 | go2sky.SetGlobalTracer(tracer) 67 | 68 | elog.InfoCtx(emptyCtx, "skywalking init success", lc) 69 | 70 | return tracer 71 | } 72 | -------------------------------------------------------------------------------- /etrace/grpc.go: -------------------------------------------------------------------------------- 1 | package etrace 2 | 3 | import "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" 4 | 5 | var ( 6 | // UnaryServerInterceptor is alias for otelgrpc.UnaryServerInterceptor. 7 | UnaryServerInterceptor = otelgrpc.UnaryServerInterceptor 8 | // StreamServerInterceptor is alias for otelgrpc.StreamServerInterceptor. 9 | StreamServerInterceptor = otelgrpc.StreamServerInterceptor 10 | ) 11 | 12 | var ( 13 | // UnaryClientInterceptor is alias for otelgrpc.UnaryClientInterceptor. 14 | UnaryClientInterceptor = otelgrpc.UnaryClientInterceptor 15 | // StreamClientInterceptor is alias for otelgrpc.StreamClientInterceptor. 16 | StreamClientInterceptor = otelgrpc.StreamClientInterceptor 17 | ) 18 | -------------------------------------------------------------------------------- /etrace/http.go: -------------------------------------------------------------------------------- 1 | package etrace 2 | 3 | import ( 4 | "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" 5 | ) 6 | 7 | var ( 8 | // HTTPMiddleware is alias for telhttp.NewHandler. 9 | HTTPMiddleware = otelhttp.NewHandler 10 | ) 11 | 12 | // HTTPTransport is alias for otelhttp.NewTransport. 
13 | var HTTPTransport = otelhttp.NewTransport 14 | -------------------------------------------------------------------------------- /etrace/trace.go: -------------------------------------------------------------------------------- 1 | package etrace 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | "go.opentelemetry.io/contrib/propagators/jaeger" 8 | 9 | "go.opentelemetry.io/otel" 10 | "go.opentelemetry.io/otel/trace" 11 | ) 12 | 13 | const ( 14 | fixSpanIDPrefix = "0000000000000000" 15 | ) 16 | 17 | type registeredTracer struct { 18 | isRegistered bool 19 | } 20 | 21 | var ( 22 | globalTracer = registeredTracer{false} 23 | ) 24 | 25 | func SetGlobalTracer(tp trace.TracerProvider) { 26 | globalTracer = registeredTracer{true} 27 | otel.SetTracerProvider(tp) 28 | // use jaeger propagator, header uber-trace-id 29 | otel.SetTextMapPropagator(jaeger.Jaeger{}) 30 | } 31 | 32 | // IsGlobalTracerRegistered returns a `bool` to indicate if a tracer has been globally registered. 33 | func IsGlobalTracerRegistered() bool { 34 | return globalTracer.isRegistered 35 | } 36 | 37 | // ExtractTraceID HTTP使用request.Context,不要使用错了. 
38 | func ExtractTraceID(ctx context.Context) string { 39 | if !IsGlobalTracerRegistered() { 40 | return "" 41 | } 42 | span := trace.SpanContextFromContext(ctx) 43 | if span.HasTraceID() { 44 | sp := span.TraceID().String() 45 | // https://github.com/open-telemetry/opentelemetry-go/issues/686 46 | // remove left padding for 64-bit TraceIDs 47 | return strings.TrimPrefix(sp, fixSpanIDPrefix) 48 | } 49 | return "" 50 | } 51 | -------------------------------------------------------------------------------- /filex/filex.go: -------------------------------------------------------------------------------- 1 | package filex 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | "path/filepath" 7 | "strings" 8 | ) 9 | 10 | func Read(path string) ([]byte, error) { 11 | fi, err := os.Open(path) 12 | if err != nil { 13 | return nil, err 14 | } 15 | defer fi.Close() 16 | fd, err := ioutil.ReadAll(fi) 17 | if err != nil { 18 | return nil, err 19 | } 20 | return fd, nil 21 | } 22 | 23 | func Write(path string, b []byte, isAppend bool) error { 24 | flag := os.O_WRONLY | os.O_TRUNC | os.O_CREATE 25 | if isAppend { 26 | flag = os.O_WRONLY | os.O_APPEND | os.O_CREATE 27 | } 28 | fd, err := os.OpenFile(path, flag, os.ModePerm) 29 | if err != nil { 30 | return err 31 | } 32 | defer fd.Close() 33 | fd.Write(b) 34 | return nil 35 | } 36 | 37 | // GetPath 获取项目路径 38 | func GetPath() string { 39 | dir, err := filepath.Abs(filepath.Dir(os.Args[0])) 40 | if err != nil { 41 | print(err.Error()) 42 | } 43 | return strings.Replace(dir, "\\", "/", -1) 44 | } 45 | 46 | // CheckDir 判断文件目录否存在 47 | func CheckDir(path string) bool { 48 | fi, err := os.Stat(path) 49 | if err != nil { 50 | return os.IsExist(err) 51 | } else { 52 | return fi.IsDir() 53 | } 54 | } 55 | 56 | // MkdirDir 创建文件夹,支持x/a/a 多层级 57 | func MkdirDir(path string) error { 58 | return os.MkdirAll(path, os.ModePerm) 59 | } 60 | 61 | // RemoveDir 删除文件 62 | func RemoveDir(filePath string) error { 63 | return os.RemoveAll(filePath) 64 | } 65 | 
-------------------------------------------------------------------------------- /filex/filex_test.go: -------------------------------------------------------------------------------- 1 | package filex 2 | 3 | import "testing" 4 | 5 | func TestWrite(t *testing.T) { 6 | path := "test.csv" 7 | 8 | err := Write(path, []byte("666"), false) 9 | if err != nil { 10 | t.Fatalf("err:%#v,", err) 11 | } 12 | } 13 | 14 | func TestRead(t *testing.T) { 15 | path := "test.csv" 16 | excepted := "666" 17 | ret, _ := Read(path) 18 | result := string(ret) 19 | if result != excepted { 20 | t.Fatalf("result:%#v,excepted:%#v", result, excepted) 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /filex/test.csv: -------------------------------------------------------------------------------- 1 | 666 -------------------------------------------------------------------------------- /fmetric/README.md: -------------------------------------------------------------------------------- 1 | # 业务指标监控接入 2 | 3 | `fmetric` 只是对 `prometheus` 的简单封装, 仅仅将初始化和注册指标包装成一个函数而已. 如果有疑惑请先学习 `prometheus` 相关知识. 4 | 5 | 接入需要注意的是: 6 | 7 | 1. 非业务相关 label 不需要关注, 例如: app, projectEnv (k8s prometheus operator 会在收集时自动注入 k8s pod 相关信息) 8 | 2. 注意 labels 组合种类总数**必须**是常数级别有限的 9 | 10 | 对于业务指标, 绝大多数情况 `Counter` 类型就够了, 别的类型也是同理(请确保你在充分了解不同指标类型后选择合适的类型使用). 下面用此类型为例. 11 | 12 | ```go 13 | 14 | // 1. 在项目里初始化指标 15 | var SomeBizCounter = fmetric.CounterVecOpts{ 16 | // 指标名称命名要清晰, 并且是下划线形式 17 | Name: "monitor_some_biz_total", 18 | // help 字段增加指标说明 19 | Help: "Total number of some biz logic on the server", 20 | // 业务 label 名称 21 | Labels: []string{"channel", "status"}, 22 | }.Build() 23 | 24 | 25 | // 2. 
在适当的业务逻辑中上报指标 26 | // WithLabelValues 值顺序必须和 Labels 声明顺序一致 27 | // 真正的业务使用 label 值可以定义成全局常量, 避免拼写错误 28 | SomeBizCounter.WithLabelValues("channel1", "OK").Inc() 29 | // other case 30 | SomeBizCounter.WithLabelValues("channel2", "ERROR").Inc() 31 | ``` 32 | -------------------------------------------------------------------------------- /fmetric/counter.go: -------------------------------------------------------------------------------- 1 | package fmetric 2 | 3 | import ( 4 | "github.com/prometheus/client_golang/prometheus" 5 | ) 6 | 7 | // CounterVecOpts ... 8 | type CounterVecOpts struct { 9 | Namespace string 10 | Subsystem string 11 | Name string 12 | Help string 13 | Labels []string 14 | } 15 | 16 | // Build ... 17 | func (opts CounterVecOpts) Build() *CounterVec { 18 | vec := prometheus.NewCounterVec( 19 | prometheus.CounterOpts{ 20 | Namespace: opts.Namespace, 21 | Subsystem: opts.Subsystem, 22 | Name: opts.Name, 23 | Help: opts.Help, 24 | }, opts.Labels) 25 | prometheus.MustRegister(vec) 26 | return &CounterVec{ 27 | CounterVec: vec, 28 | } 29 | } 30 | 31 | // NewCounterVec ... 32 | func NewCounterVec(name string, labels []string) *CounterVec { 33 | return CounterVecOpts{ 34 | Namespace: DefaultNamespace, 35 | Name: name, 36 | Help: name, 37 | Labels: labels, 38 | }.Build() 39 | } 40 | 41 | // CounterVec ... 42 | type CounterVec struct { 43 | *prometheus.CounterVec 44 | } 45 | 46 | // Inc ... 47 | func (counter *CounterVec) Inc(labels ...string) { 48 | counter.WithLabelValues(labels...).Inc() 49 | } 50 | 51 | // Add ... 52 | func (counter *CounterVec) Add(v float64, labels ...string) { 53 | counter.WithLabelValues(labels...).Add(v) 54 | } 55 | -------------------------------------------------------------------------------- /fmetric/gauge.go: -------------------------------------------------------------------------------- 1 | package fmetric 2 | 3 | import "github.com/prometheus/client_golang/prometheus" 4 | 5 | // GaugeVecOpts ... 
6 | type GaugeVecOpts struct { 7 | Namespace string 8 | Subsystem string 9 | Name string 10 | Help string 11 | Labels []string 12 | } 13 | 14 | // GaugeVec wraps *prometheus.GaugeVec with helpers that take label values as variadic arguments. 15 | type GaugeVec struct { 16 | *prometheus.GaugeVec 17 | } 18 | 19 | // Build creates the gauge vector and registers it with the default Prometheus registry via prometheus.MustRegister, which panics if registration fails (e.g. a duplicate metric). 20 | func (opts GaugeVecOpts) Build() *GaugeVec { 21 | vec := prometheus.NewGaugeVec( 22 | prometheus.GaugeOpts{ 23 | Namespace: opts.Namespace, 24 | Subsystem: opts.Subsystem, 25 | Name: opts.Name, 26 | Help: opts.Help, 27 | }, opts.Labels) 28 | prometheus.MustRegister(vec) 29 | return &GaugeVec{ 30 | GaugeVec: vec, 31 | } 32 | } 33 | 34 | // NewGaugeVec builds and registers a gauge called name in DefaultNamespace, reusing name as its help text. 35 | func NewGaugeVec(name string, labels []string) *GaugeVec { 36 | return GaugeVecOpts{ 37 | Namespace: DefaultNamespace, 38 | Name: name, 39 | Help: name, 40 | Labels: labels, 41 | }.Build() 42 | } 43 | 44 | // Inc increments the gauge for the given label values by 1. 45 | func (gv *GaugeVec) Inc(labels ...string) { 46 | gv.WithLabelValues(labels...).Inc() 47 | } 48 | 49 | // Add adds v, which may be negative, to the gauge for the given label values. 50 | func (gv *GaugeVec) Add(v float64, labels ...string) { 51 | gv.WithLabelValues(labels...).Add(v) 52 | } 53 | 54 | // Set sets the gauge for the given label values to v. 55 | func (gv *GaugeVec) Set(v float64, labels ...string) { 56 | gv.WithLabelValues(labels...).Set(v) 57 | } 58 | -------------------------------------------------------------------------------- /fmetric/grpc.go: -------------------------------------------------------------------------------- 1 | package fmetric 2 | 3 | import "strings" 4 | 5 | const unknown = "unknown" 6 | 7 | // SplitGrpcMethodName splits a gRPC full method name such as "/pkg.Service/Method" into service and method at the first "/" after the optional leading slash; it returns unknown, unknown when there is no such separator.
8 | func SplitGrpcMethodName(fullMethodName string) (service string, method string) { 9 | fullMethodName = strings.TrimPrefix(fullMethodName, "/") // remove leading slash 10 | if i := strings.Index(fullMethodName, "/"); i >= 0 { 11 | return fullMethodName[:i], fullMethodName[i+1:] 12 | } 13 | return unknown, unknown // no separator: not a "/service/method" name 14 | } 15 | -------------------------------------------------------------------------------- /fmetric/grpc_test.go: -------------------------------------------------------------------------------- 1 | package fmetric 2 | 3 | import "testing" 4 | 5 | func TestSplitGrpcMethodName(t *testing.T) { 6 | type args struct { 7 | fullMethodName string 8 | } 9 | 10 | tests := []struct { 11 | name string 12 | args args 13 | wantService string 14 | wantMethod string 15 | }{ 16 | { 17 | "normal", 18 | args{fullMethodName: "/channel.quote.v1.QuoteAPI/QuotationGenerate"}, 19 | "channel.quote.v1.QuoteAPI", 20 | "QuotationGenerate", 21 | }, 22 | { 23 | "invalid", 24 | args{fullMethodName: "/channel.quote.v1.QuoteAPI"}, 25 | unknown, 26 | unknown, 27 | }, 28 | } 29 | 30 | for _, tt := range tests { 31 | t.Run(tt.name, func(t *testing.T) { 32 | gotService, gotMethod := SplitGrpcMethodName(tt.args.fullMethodName) 33 | if gotService != tt.wantService { 34 | t.Errorf("SplitGrpcMethodName() gotService = %v, want %v", gotService, tt.wantService) 35 | } 36 | if gotMethod != tt.wantMethod { 37 | t.Errorf("SplitGrpcMethodName() gotMethod = %v, want %v", gotMethod, tt.wantMethod) 38 | } 39 | }) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /fmetric/histogram.go: -------------------------------------------------------------------------------- 1 | package fmetric 2 | 3 | import ( 4 | "github.com/prometheus/client_golang/prometheus" 5 | ) 6 | 7 | // HistogramVecOpts describes a labeled histogram vector; Build creates and registers it. Nil Buckets falls back to prometheus.DefBuckets.
8 | type HistogramVecOpts struct { 9 | Namespace string 10 | Subsystem string 11 | Name string 12 | Help string 13 | Labels []string 14 | Buckets []float64 15 | } 16 | 17 | // HistogramVec wraps *prometheus.HistogramVec with helpers that take label values as variadic arguments. 18 | type HistogramVec struct { 19 | *prometheus.HistogramVec 20 | } 21 | 22 | // Build creates the histogram vector and registers it with the default Prometheus registry via prometheus.MustRegister, which panics if registration fails (e.g. a duplicate metric). 23 | func (opts HistogramVecOpts) Build() *HistogramVec { 24 | vec := prometheus.NewHistogramVec( 25 | prometheus.HistogramOpts{ 26 | Namespace: opts.Namespace, 27 | Subsystem: opts.Subsystem, 28 | Name: opts.Name, 29 | Help: opts.Help, 30 | Buckets: opts.Buckets, 31 | }, opts.Labels) 32 | prometheus.MustRegister(vec) 33 | return &HistogramVec{ 34 | HistogramVec: vec, 35 | } 36 | } 37 | 38 | // Observe records v in the histogram for the given label values. 39 | func (histogram *HistogramVec) Observe(v float64, labels ...string) { 40 | histogram.WithLabelValues(labels...).Observe(v) 41 | } 42 | // ObserveWithExemplar records v together with the exemplar labels; the unchecked type assertion panics if the observer does not implement prometheus.ExemplarObserver. 43 | func (histogram *HistogramVec) ObserveWithExemplar(v float64, exemplar prometheus.Labels, labels ...string) { 44 | histogram.WithLabelValues(labels...).(prometheus.ExemplarObserver).ObserveWithExemplar(v, exemplar) 45 | } 46 | -------------------------------------------------------------------------------- /fmetric/metric.go: -------------------------------------------------------------------------------- 1 | package fmetric 2 | 3 | var ( 4 | // TypeHTTP ... 5 | TypeHTTP = "http" 6 | // TypeGRPCClient ... 7 | TypeGRPCClient = "grpc_client" 8 | // TypeGRPCServer ... 9 | TypeGRPCServer = "grpc_server" 10 | // TypeRedis ... 11 | TypeRedis = "redis" 12 | // TypeGorm ... 13 | TypeGorm = "gorm" 14 | // TypeMySQL ... 15 | TypeMySQL = "mysql" 16 | 17 | // DefaultNamespace is the namespace used by NewCounterVec and NewGaugeVec; an empty value adds no namespace prefix to metric names.
18 | DefaultNamespace = "" 19 | ) 20 | 21 | const ( 22 | CodeOK = "OK" 23 | CodeError = "Error" 24 | ) 25 | -------------------------------------------------------------------------------- /fmetric/summary.go: -------------------------------------------------------------------------------- 1 | package fmetric 2 | 3 | import "github.com/prometheus/client_golang/prometheus" 4 | 5 | // SummaryVecOpts describes a labeled summary vector; Build creates and registers it. Objectives maps each quantile to its allowed absolute error. 6 | type SummaryVecOpts struct { 7 | Namespace string 8 | Subsystem string 9 | Name string 10 | Help string 11 | Objectives map[float64]float64 12 | Labels []string 13 | } 14 | 15 | // SummaryVec wraps *prometheus.SummaryVec with helpers that take label values as variadic arguments. 16 | type SummaryVec struct { 17 | *prometheus.SummaryVec 18 | } 19 | 20 | // Build creates the summary vector and registers it with the default Prometheus registry via prometheus.MustRegister, which panics if registration fails (e.g. a duplicate metric). 21 | func (opts SummaryVecOpts) Build() *SummaryVec { 22 | vec := prometheus.NewSummaryVec( 23 | prometheus.SummaryOpts{ 24 | Namespace: opts.Namespace, 25 | Subsystem: opts.Subsystem, 26 | Name: opts.Name, 27 | Help: opts.Help, 28 | Objectives: opts.Objectives, 29 | }, opts.Labels) 30 | prometheus.MustRegister(vec) 31 | return &SummaryVec{ 32 | SummaryVec: vec, 33 | } 34 | } 35 | 36 | // Observe records v in the summary for the given label values.
37 | func (summary *SummaryVec) Observe(v float64, labels ...string) { 38 | summary.WithLabelValues(labels...).Observe(v) 39 | } 40 | -------------------------------------------------------------------------------- /gpool/gpool.go: -------------------------------------------------------------------------------- 1 | package gpool 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | ) 7 | 8 | type GPool struct { 9 | lock sync.Mutex 10 | maxCount int64 11 | curCount int64 12 | waitGroup sync.WaitGroup 13 | jobs chan interface{} 14 | fun func(param interface{}) 15 | closeCh chan bool 16 | } 17 | 18 | var NilErr = fmt.Errorf("param can not be nil") 19 | var CloseErr = fmt.Errorf("pool was closed") 20 | 21 | func NewGPool(maxCount int64, fun func(param interface{})) *GPool { 22 | return &GPool{ 23 | maxCount: maxCount, 24 | fun: fun, 25 | jobs: make(chan interface{}, 10), 26 | closeCh: make(chan bool), 27 | } 28 | } 29 | 30 | func (g *GPool) Run(param interface{}) error { 31 | if param == nil { 32 | return NilErr 33 | } 34 | g.lock.Lock() 35 | defer g.lock.Unlock() 36 | select { 37 | case <-g.closeCh: 38 | return CloseErr 39 | default: 40 | if g.curCount < g.maxCount { 41 | g.waitGroup.Add(1) 42 | g.curCount++ 43 | go func() { 44 | defer func() { 45 | if p := recover(); p != nil { 46 | fmt.Printf("%#v\n", p) 47 | } 48 | g.waitGroup.Done() 49 | }() 50 | // consumer 51 | g.worker() 52 | }() 53 | } 54 | // producer 55 | g.jobs <- param 56 | return nil 57 | } 58 | } 59 | 60 | func (g *GPool) Clear() { 61 | g.lock.Lock() 62 | defer g.lock.Unlock() 63 | select { 64 | case <-g.closeCh: 65 | default: 66 | for g.curCount > 0 { 67 | g.curCount-- 68 | g.jobs <- nil 69 | } 70 | } 71 | } 72 | 73 | func (g *GPool) Close() { 74 | g.lock.Lock() 75 | select { 76 | case <-g.closeCh: 77 | default: 78 | close(g.closeCh) 79 | close(g.jobs) 80 | } 81 | g.lock.Unlock() 82 | g.waitGroup.Wait() 83 | } 84 | 85 | func (g *GPool) worker() { 86 | for j := range g.jobs { 87 | if j == nil { 88 | break 89 | } 90 
| g.fun(j) 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /grpc/grpc_client/README.md: -------------------------------------------------------------------------------- 1 | # grpc client 2 | 3 | ## 配置说明 4 | 5 | ```go 6 | type Config struct { 7 | Debug bool // 是否开启调试,默认不开启, 开启可以打印请求日志 8 | Addr string // 连接地址,直连为 127.0.0.1:9001,服务发现为 nacos:///appname 9 | BalancerName string // 负载均衡方式,默认round robin 10 | DialTimeout time.Duration // 连接超时,默认3s 11 | ReadTimeout time.Duration // 读超时,默认1s 12 | SlowLogThreshold time.Duration // 慢日志记录的阈值,默认600ms 13 | EnableBlock bool // 是否开启阻塞,默认开启 14 | // EnableOfficialGrpcLog bool // 是否开启官方grpc日志,默认关闭 // blog 和 zap 类型不兼容, 没法做 15 | EnableWithInsecure bool // 是否开启非安全传输,默认开启 16 | EnableMetricInterceptor bool // 是否开启监控,默认开启 17 | EnableTraceInterceptor bool // 是否开启链路追踪,默认开启 18 | EnableAppNameInterceptor bool // 是否开启传递应用名,默认开启 19 | EnableTimeoutInterceptor bool // 是否开启超时传递,默认开启 20 | EnableAccessInterceptor bool // 是否开启记录请求数据,默认不开启 21 | EnableAccessInterceptorReq bool // 是否开启记录请求参数,默认不开启 22 | EnableAccessInterceptorRes bool // 是否开启记录响应参数,默认不开启 23 | EnableServiceConfig bool // 是否开启服务配置,默认关闭 24 | EnableFailOnNonTempDialError bool 25 | } 26 | ``` 27 | 28 | ## 连接服务问题 29 | 30 | 默认情况下(我们组件逻辑), grpc 连接会设置 3s 超时, 超时没连接上就会 `panic`. 31 | 32 | 这样做是为了保证依赖的服务正常, 也是为了尽早暴露错误, `fail fast`. 33 | 34 | 但是测试环境不稳定, 或者我们允许循环依赖情况出现时, 默认配置没法满足需求. 35 | 36 | 如果你需要支持上述场景, 需要增加配置: 37 | 38 | ```toml 39 | enableBlock = false 40 | ``` 41 | 42 | 这种情况 grpc 连接不会 block 程序启动. 后续依赖服务正常后 grpc client 功能也会正常, 不需要做重启等操作. 43 | 44 | ### EnableBlock 和 OnFail 参数区别 45 | 46 | 更新: OnFail 参数已删除, 不再支持 `onFail = "error"` 这种行为. 47 | 48 | 先说结论, 大多数时候你应该使用 `EnableBlock = false` 配置. 49 | 50 | `OnFail` 是我们 component 通用参数, 基本意义就是开发者是否将当前 component 视为强依赖. 51 | 52 | `OnFail` 参数控制 `grpc.Dial` 建立连接时出错时的处理方式, 默认为 `panic` 结束程序. 53 | 54 | 当 `OnFail` 设置为 `error` 时, 仅仅会在连接失败时打印错误日志, 但是 conn 返回值其实是 nil, 所以当你的程序后续依赖这个连接时, 程序依旧会 `panic`. 
所以这种情况仅适合程序不依赖这个 grpc client 逻辑时使用. 55 | 56 | `EnableBlock` 设置为 false 时, `grpc.Dial` 不会返回错误, 所以 `OnFail` 参数其实没有效果. 57 | -------------------------------------------------------------------------------- /grpc/grpc_client/grpc_client.go: -------------------------------------------------------------------------------- 1 | package grpc_client 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/weblazy/easy/elog" 9 | "github.com/weblazy/easy/grpc/grpc_client/grpc_client_config" 10 | "go.uber.org/zap" 11 | "google.golang.org/grpc/credentials/insecure" 12 | 13 | "google.golang.org/grpc" 14 | ) 15 | 16 | var emptyCtx = context.Background() 17 | 18 | // PackageName 设置包名 19 | const PackageName = "client.fgrpc" 20 | 21 | const grpcServiceConfig = `{"loadBalancingPolicy":"%s"}` 22 | 23 | type GrpcClient struct { 24 | config *grpc_client_config.Config 25 | *grpc.ClientConn 26 | err error 27 | } 28 | 29 | func NewGrpcClient(config *grpc_client_config.Config) *GrpcClient { 30 | var ctx = context.Background() 31 | 32 | if config == nil { 33 | config = grpc_client_config.DefaultConfig() 34 | } 35 | config.BuildDialOptions() 36 | 37 | var dialOptions = config.DialOptions 38 | // 默认配置使用block 39 | if config.EnableBlock { 40 | if config.DialTimeout > time.Duration(0) { 41 | var cancel context.CancelFunc 42 | ctx, cancel = context.WithTimeout(ctx, config.DialTimeout) 43 | defer cancel() 44 | } 45 | 46 | dialOptions = append(dialOptions, grpc.WithBlock()) 47 | } 48 | 49 | if config.EnableWithInsecure { 50 | dialOptions = append(dialOptions, grpc.WithTransportCredentials(insecure.NewCredentials())) 51 | } 52 | 53 | if config.KeepAlive != nil { 54 | dialOptions = append(dialOptions, grpc.WithKeepaliveParams(*config.KeepAlive)) 55 | } 56 | 57 | //// 因为默认是开启这个配置 58 | //// 并且开启后,在grpc 1.40以上会导致dns多一次解析txt内容(目测是为了做grpc的load balance策略,但我们实际上不会用到) 59 | //// 因为这个service config dns域名通常是没有设置dns解析,所以会跳过k8s的dns,穿透到上一级的dns,而如果dns配置有问题或者不存在,那么会查询非常长的时间(通常在20s或者更长) 60 | ////
那么为false的时候,禁用他,可以加快我们的启动时间或者提升我们的性能 61 | //if !config.EnableServiceConfig { 62 | // dialOptions = append(dialOptions, grpc.WithDisableServiceConfig()) 63 | //} 64 | 65 | // always apply the default service config, whose load-balancing policy comes from config.BalancerName 66 | dialOptions = append(dialOptions, grpc.WithDefaultServiceConfig(fmt.Sprintf(grpcServiceConfig, config.BalancerName)), grpc.FailOnNonTempDialError(config.EnableFailOnNonTempDialError)) 67 | 68 | startTime := time.Now() 69 | cc, err := grpc.DialContext(ctx, config.Addr, dialOptions...) 70 | 71 | client := &GrpcClient{ 72 | config: config, 73 | ClientConn: cc, 74 | // keep the dial error so callers can detect a failed dial through Error() 75 | err: err, 76 | } 77 | 78 | if err != nil { 79 | elog.ErrorCtx(emptyCtx, "dial grpc server", elog.FieldError(err), elog.FieldName(config.Name), zap.String("addr", config.Addr), elog.FieldCost(time.Since(startTime))) 80 | return client 81 | } 82 | 83 | elog.InfoCtx(emptyCtx, "start grpc client", elog.FieldName(config.Name), elog.FieldCost(time.Since(startTime))) 84 | return client 85 | } 86 | 87 | // Error returns the error from dialing the server, or nil if the dial succeeded. 88 | func (c *GrpcClient) Error() error { 89 | return c.err 90 | } 91 | -------------------------------------------------------------------------------- /grpc/grpc_client/grpc_client_test.go: -------------------------------------------------------------------------------- 1 | package grpc_client 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/smartystreets/goconvey/convey" 7 | ) 8 | 9 | func TestNewGrpcClient(t *testing.T) { 10 | convey.Convey("TestNewGrpcClient", t, func() { 11 | // cfg := grpc_client_config.DefaultConfig() 12 | // client := NewGrpcClient(cfg) 13 | // userClient := user.NewUserServiceClient(client) 14 | // resp, err := userClient.GetUserInfo(context.Background(), &user.GetUserInfoRequest{}) 15 | // convey.So(err, convey.ShouldBeNil) 16 | // convey.So(resp, convey.ShouldNotBeNil) 17 | }) 18 | } 19 | -------------------------------------------------------------------------------- /grpc/grpc_client/interceptor/grpc_header_carrier.go:
-------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/weblazy/easy/transport" 7 | "google.golang.org/grpc" 8 | "google.golang.org/grpc/metadata" 9 | ) 10 | 11 | // GrpcHeaderCarrierInterceptor injects the keys of transport.CustomKeysMapPropagator from ctx into the outgoing gRPC metadata, working on a copy of any metadata already attached to ctx. 12 | func GrpcHeaderCarrierInterceptor() grpc.UnaryClientInterceptor { 13 | return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, 14 | invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 15 | var md metadata.MD 16 | // try to append custom metadata to client request metadata 17 | if m, ok := metadata.FromOutgoingContext(ctx); ok { 18 | md = m.Copy() 19 | } else { 20 | md = metadata.MD{} 21 | } 22 | transport.CustomKeysMapPropagator.Inject(ctx, transport.GrpcHeaderCarrier(md)) 23 | ctx = metadata.NewOutgoingContext(ctx, md) 24 | return invoker(ctx, method, req, reply, cc, opts...) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /grpc/grpc_client/interceptor/metric.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "context" 5 | 6 | grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" 7 | "github.com/prometheus/client_golang/prometheus" 8 | "github.com/weblazy/easy/eerror" 9 | "github.com/weblazy/easy/fmetric" 10 | "google.golang.org/grpc" 11 | "google.golang.org/grpc/status" 12 | ) 13 | 14 | var ( 15 | // ClientHandleCounter counts client calls by type, name, method, peer and code. NOTE(review): it is neither registered nor updated in this file; confirm it is used elsewhere. 16 | ClientHandleCounter = prometheus.NewCounterVec( 17 | prometheus.CounterOpts{ 18 | Name: "grpc_client_handle_total", 19 | }, []string{"type", "name", "method", "peer", "code"}) 20 | 21 | // ClientHandleHistogram is a histogram of client calls by type, name, method and peer. NOTE(review): it is neither registered nor observed in this file; confirm it is used elsewhere.
22 | ClientHandleHistogram = prometheus.NewHistogramVec( 23 | prometheus.HistogramOpts{ 24 | Name: "grpc_client_handle_seconds", 25 | }, []string{"type", "name", "method", "peer"}) 26 | ) 27 | // ClientWithBizHandledCounter counts unary calls by service, method, gRPC code and business result code. NOTE(review): its registration is not visible in this file; confirm it is registered elsewhere. 28 | var ( 29 | ClientWithBizHandledCounter = prometheus.NewCounterVec( 30 | prometheus.CounterOpts{ 31 | Name: "monitor_grpc_client_accept_result_total", 32 | Help: "Total number of RPCs accept on the client, regardless of success or failure.", 33 | }, []string{"grpc_type", "grpc_service", "grpc_method", "grpc_code", "result_code"}) 34 | ) 35 | // MetricUnaryClientInterceptor enables the grpc_prometheus client handling-time histogram, delegates to grpc_prometheus.UnaryClientInterceptor, and additionally increments ClientWithBizHandledCounter whenever the business-code extractor built from successCodes reports a code. 36 | func MetricUnaryClientInterceptor(successCodes []string) grpc.UnaryClientInterceptor { 37 | grpc_prometheus.EnableClientHandlingTimeHistogram() 38 | originMw := grpc_prometheus.UnaryClientInterceptor 39 | extractor := eerror.ExtractBizCode(successCodes) 40 | 41 | return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 42 | err := originMw(ctx, method, req, reply, cc, invoker, opts...) 43 | st, _ := status.FromError(err) 44 | 45 | bizCode, ok := extractor(reply, err) 46 | if ok { 47 | service, method := fmetric.SplitGrpcMethodName(method) 48 | ClientWithBizHandledCounter.WithLabelValues("unary", service, method, st.Code().String(), bizCode).Inc() 49 | } 50 | return err 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /grpc/grpc_client/interceptor/timeout.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "google.golang.org/grpc" 8 | ) 9 | 10 | // copy from go-zero https://github.com/zeromicro/go-zero/blob/2732d3cdae5bf35dc07e926d3b5ed35e3c506393/zrpc/internal/clientinterceptors/timeoutinterceptor.go 11 | 12 | type contextKeyType struct{} 13 | 14 | var ctxKey contextKeyType 15 | 16 | // TimeoutInterceptor applies a per-call deadline: timeout by default, overridden by a ForceTimeout context value and then by a WithForceTimeout call option; a non-positive result disables the deadline.
17 | func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor { 18 | return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, 19 | invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 20 | // copy into a per-call variable so the overrides below never change the timeout captured for later calls 21 | t := timeout 22 | 23 | if v, ok := ctx.Value(ctxKey).(time.Duration); ok { 24 | t = v 25 | } 26 | // a WithForceTimeout call option takes precedence over a ForceTimeout context value 27 | if v, ok := getForceTimeout(opts); ok { 28 | t = v 29 | } 30 | 31 | if t <= 0 { 32 | return invoker(ctx, method, req, reply, cc, opts...) 33 | } 34 | 35 | var cancel context.CancelFunc 36 | ctx, cancel = context.WithTimeout(ctx, t) 37 | defer cancel() 38 | 39 | return invoker(ctx, method, req, reply, cc, opts...) 40 | } 41 | } 42 | 43 | // CallOption is a grpc.CallOption that is local to timeout interceptor. 44 | type CallOption struct { 45 | grpc.EmptyCallOption 46 | 47 | forceTimeout time.Duration 48 | } 49 | 50 | // WithForceTimeout sets the RPC timeout for this call only. 51 | func WithForceTimeout(forceTimeout time.Duration) CallOption { 52 | return CallOption{forceTimeout: forceTimeout} 53 | } 54 | // getForceTimeout returns the timeout carried by the first CallOption in callOptions, if any. 55 | func getForceTimeout(callOptions []grpc.CallOption) (time.Duration, bool) { 56 | for _, opt := range callOptions { 57 | if co, ok := opt.(CallOption); ok { 58 | return co.forceTimeout, true 59 | } 60 | } 61 | 62 | return 0, false 63 | } 64 | 65 | // ForceTimeout returns a context that makes TimeoutInterceptor use timeout for calls made with it; a non-positive timeout returns ctx unchanged. 66 | // Deprecated: use WithForceTimeout 67 | func ForceTimeout(ctx context.Context, timeout time.Duration) context.Context { 68 | if timeout <= 0 { 69 | return ctx 70 | } 71 | 72 | return context.WithValue(ctx, ctxKey, timeout) 73 | } 74 | -------------------------------------------------------------------------------- /grpc/grpc_server/grpc_server_test.go: -------------------------------------------------------------------------------- 1 | package grpc_server 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/smartystreets/goconvey/convey" 8 |
"github.com/weblazy/easy/grpc/proto/user" 9 | ) 10 | 11 | func TestNewGrpcServer(t *testing.T) { 12 | convey.Convey("TestNewGrpcServer", t, func() { 13 | // cfg := grpc_server_config.DefaultConfig() 14 | // server := NewGrpcServer(cfg, &elog.LogConf{}) 15 | // user.RegisterUserServiceServer(server, &User{}) 16 | // err := server.Init() 17 | // convey.So(err, convey.ShouldBeNil) 18 | // err = server.Start() 19 | // convey.So(err, convey.ShouldBeNil) 20 | }) 21 | } 22 | 23 | type User struct { 24 | user.UserServiceServer 25 | } 26 | 27 | func (*User) GetUserInfo(ctx context.Context, req *user.GetUserInfoRequest) (*user.GetUserInfoResponse, error) { 28 | return &user.GetUserInfoResponse{ 29 | Detail: &user.User{ 30 | Name: "lazy", 31 | }, 32 | }, nil 33 | } 34 | -------------------------------------------------------------------------------- /grpc/grpc_server/interceptor/grpc_header_carrier.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/weblazy/easy/transport" 7 | "google.golang.org/grpc" 8 | "google.golang.org/grpc/metadata" 9 | ) 10 | 11 | func GrpcHeaderCarrierInterceptor() grpc.UnaryServerInterceptor { 12 | return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { 13 | if md, ok := metadata.FromIncomingContext(ctx); ok { 14 | ctx = transport.CustomKeysMapPropagator.Extract(ctx, transport.GrpcHeaderCarrier(md)) 15 | } 16 | return handler(ctx, req) 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /grpc/grpc_server/interceptor/log.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "time" 7 | 8 | "go.uber.org/zap" 9 | 10 | "github.com/google/uuid" 11 | "github.com/weblazy/easy/elog" 12 | "github.com/weblazy/easy/etrace" 13 | 
"github.com/weblazy/easy/grpc/grpc_server/grpc_server_config" 14 | "google.golang.org/grpc" 15 | "google.golang.org/grpc/metadata" 16 | ) 17 | 18 | var once sync.Once 19 | 20 | func GrpcLogger(config *grpc_server_config.Config) grpc.UnaryServerInterceptor { 21 | once.Do(config.InitLogger) 22 | return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { 23 | start := time.Now() 24 | 25 | // otel trace 26 | traceId := etrace.ExtractTraceID(ctx) 27 | 28 | md, _ := metadata.FromIncomingContext(ctx) 29 | // 尝试获取网关 traceid 30 | if traceId == "" { 31 | v := md.Get("traceid") 32 | if len(v) > 0 { 33 | traceId = v[0] 34 | } 35 | } 36 | 37 | // 服务内部生成 38 | if traceId == "" { 39 | traceId = uuid.NewString() 40 | } 41 | fields := make([]zap.Field, 0) 42 | fields = append(fields, elog.FieldMethod(info.FullMethod), elog.FieldReq(req), zap.Any("metadata", md)) 43 | 44 | resp, err = handler(ctx, req) 45 | ctx = elog.SetLogerName(ctx, grpc_server_config.PkgName) 46 | if err != nil { 47 | fields = append(fields, elog.FieldError(err), elog.FieldDuration(time.Since(start))) 48 | elog.ErrorCtx(ctx, grpc_server_config.PkgName, fields...) 49 | } else { 50 | fields = append(fields, elog.FieldResp(resp), elog.FieldDuration(time.Since(start))) 51 | elog.InfoCtx(ctx, grpc_server_config.PkgName, fields...) 
52 | } 53 | return resp, err 54 | } 55 | } 56 | 57 | // func GrpcLoggerLite() grpc.UnaryServerInterceptor { 58 | // return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { 59 | // // otel trace 60 | // traceId := etrace.ExtractTraceID(ctx) 61 | 62 | // md, _ := metadata.FromIncomingContext(ctx) 63 | // // 尝试获取网关 traceid 64 | // if traceId == "" { 65 | // v := md.Get("traceid") 66 | // if len(v) > 0 { 67 | // traceId = v[0] 68 | // } 69 | // } 70 | 71 | // // 服务内部生成 72 | // if traceId == "" { 73 | // traceId = uuid.NewString() 74 | // } 75 | // logConf.Name = "server.grpc" 76 | // logConf.Fields = append(logConf.Fields, zap.String("trace_id", traceId), zap.String("method", info.FullMethod)) 77 | // return handler(ctx, req) 78 | // } 79 | // } 80 | -------------------------------------------------------------------------------- /grpc/grpc_server/interceptor/metric.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "context" 5 | 6 | grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" 7 | "github.com/prometheus/client_golang/prometheus" 8 | "github.com/weblazy/easy/eerror" 9 | "google.golang.org/grpc" 10 | "google.golang.org/grpc/status" 11 | 12 | "github.com/weblazy/easy/fmetric" 13 | ) 14 | 15 | // 目前图表只用到了 resultCode 和 app 字段(收集时自动注入) 16 | 17 | var ( 18 | ServerWithBizHandledCounter = prometheus.NewCounterVec( 19 | prometheus.CounterOpts{ 20 | Namespace: fmetric.DefaultNamespace, 21 | Name: "monitor_grpc_server_result_total", 22 | Help: "Total number of RPCs completed on the server, regardless of success or failure.", 23 | }, []string{"grpc_type", "grpc_service", "grpc_method", "grpc_code", "result_code"}) 24 | ) 25 | 26 | func MetricUnaryServerInterceptor(successCodes []string) grpc.UnaryServerInterceptor { 27 | grpc_prometheus.EnableHandlingTimeHistogram() 28 | originMw := 
grpc_prometheus.UnaryServerInterceptor 29 | extractor := eerror.ExtractBizCode(successCodes) 30 | 31 | return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { 32 | resp, err = originMw(ctx, req, info, handler) 33 | st, _ := status.FromError(err) 34 | 35 | bizCode, ok := extractor(resp, err) 36 | if ok { 37 | service, method := fmetric.SplitGrpcMethodName(info.FullMethod) 38 | ServerWithBizHandledCounter.WithLabelValues("unary", service, method, st.Code().String(), bizCode).Inc() 39 | } 40 | return // resp and err from the wrapped handler, unchanged 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /grpc/grpc_server/interceptor/metrics.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "github.com/prometheus/client_golang/prometheus" 5 | ) 6 | // ServerHandledCounter and ServerHandledHistogram live in the "fgo" namespace. NOTE(review): neither is registered or updated in this file; confirm they are used elsewhere. 7 | var ( 8 | ServerHandledCounter = prometheus.NewCounterVec( 9 | prometheus.CounterOpts{ 10 | Namespace: "fgo", 11 | Name: "grpc_server_handled_total", 12 | Help: "Total number of RPCs completed on the server, regardless of success or failure.", 13 | }, []string{"grpc_type", "method", "code", "uniform_code"}) 14 | 15 | ServerHandledHistogram = prometheus.NewHistogramVec( 16 | prometheus.HistogramOpts{ 17 | Namespace: "fgo", 18 | Name: "grpc_server_handling_seconds", 19 | Help: "Histogram of response latency (seconds) of gRPC that had been application-level handled by the server.", 20 | }, []string{"grpc_type", "method"}) 21 | ) 22 | -------------------------------------------------------------------------------- /grpc/grpc_server/interceptor/recovery.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "context" 5 | "runtime/debug" 6 | 7 | "go.uber.org/zap" 8 | 9 | grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" 10 | "github.com/weblazy/easy/elog" 11 |
"google.golang.org/grpc" 12 | 13 | "google.golang.org/grpc/codes" 14 | "google.golang.org/grpc/status" 15 | ) 16 | 17 | // 18 | 19 | func GrpcRecoveryHandler(ctx context.Context, p interface{}) (err error) { 20 | elog.ErrorCtx(ctx, "panic", zap.Any("err", p), zap.String("stack", string(debug.Stack()))) 21 | // 返回一个 grpc status 错误, 像 grpc_recovery 中间件默认行为那样 22 | return status.Errorf(codes.Internal, "panic: %v", p) 23 | } 24 | 25 | func UnaryRecoveryInterceptor() grpc.UnaryServerInterceptor { 26 | return grpc_recovery.UnaryServerInterceptor( 27 | grpc_recovery.WithRecoveryHandlerContext(GrpcRecoveryHandler)) 28 | } 29 | -------------------------------------------------------------------------------- /grpc/grpc_server/interceptor/timeout.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "runtime/debug" 7 | "strings" 8 | "sync" 9 | "time" 10 | 11 | "google.golang.org/grpc" 12 | "google.golang.org/grpc/codes" 13 | "google.golang.org/grpc/status" 14 | ) 15 | 16 | // copy from go-zero https://github.com/zeromicro/go-zero/blob/2732d3cdae5bf35dc07e926d3b5ed35e3c506393/zrpc/internal/serverinterceptors/timeoutinterceptor.go 17 | 18 | // UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests. 
19 | func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor { 20 | return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, 21 | handler grpc.UnaryHandler) (interface{}, error) { 22 | // A non-positive timeout disables the limit (as the client-side TimeoutInterceptor 23 | // does); context.WithTimeout would otherwise expire immediately and fail every call. 24 | if timeout <= 0 { 25 | return handler(ctx, req) 26 | } 27 | 28 | ctx, cancel := context.WithTimeout(ctx, timeout) 29 | defer cancel() 30 | 31 | var resp interface{} 32 | var err error 33 | var lock sync.Mutex 34 | done := make(chan struct{}) 35 | // create channel with buffer size 1 to avoid goroutine leak 36 | panicChan := make(chan interface{}, 1) 37 | go func() { 38 | defer func() { 39 | if p := recover(); p != nil { 40 | // attach call stack to avoid missing in different goroutine 41 | panicChan <- fmt.Sprintf("%+v\n\n%s", p, strings.TrimSpace(string(debug.Stack()))) 42 | } 43 | }() 44 | 45 | lock.Lock() 46 | defer lock.Unlock() 47 | resp, err = handler(ctx, req) 48 | close(done) 49 | }() 50 | 51 | select { 52 | case p := <-panicChan: 53 | panic(p) 54 | case <-done: 55 | lock.Lock() 56 | defer lock.Unlock() 57 | return resp, err 58 | case <-ctx.Done(): 59 | err := ctx.Err() 60 | 61 | if err == context.Canceled { 62 | err = status.Error(codes.Canceled, err.Error()) 63 | } else if err == context.DeadlineExceeded { 64 | err = status.Error(codes.DeadlineExceeded, err.Error()) 65 | } 66 | return nil, err 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /grpc/proto/response/response.proto: -------------------------------------------------------------------------------- 1 | 2 | syntax = "proto3"; 3 | 4 | package response; 5 | 6 | option go_package = "github.com/weblazy/easy/grpc/proto/response"; 7 | 8 | message CodeError{ 9 | int64 code = 1; 10 | string msg = 2; 11 | } 12 | -------------------------------------------------------------------------------- /grpc/proto/user/user.proto: -------------------------------------------------------------------------------- 1 | 2 | syntax = "proto3"; 3 | 4 | package user; 5 | 6 | option go_package = "./user"; 7 |
8 | message GetUserInfoRequest{ 9 | int64 uid = 1; 10 | 11 | } 12 | message GetUserInfoResponse{ 13 | User detail = 1; 14 | repeated User list = 2; 15 | 16 | } 17 | message User{ 18 | int64 uid = 1; 19 | string name = 2; 20 | 21 | } 22 | 23 | service UserService{ 24 | rpc GetUserInfo(GetUserInfoRequest) returns (GetUserInfoResponse); 25 | } 26 | -------------------------------------------------------------------------------- /http/http_client/http_client_config/config.go: -------------------------------------------------------------------------------- 1 | package http_client_config 2 | 3 | import ( 4 | "crypto/tls" 5 | "runtime" 6 | "time" 7 | ) 8 | 9 | const ( 10 | PkgName = "http_client" 11 | ) 12 | 13 | // Config holds the HTTP client options. 14 | type Config struct { 15 | Name string // name 16 | Addr string // address to connect to 17 | RawDebug bool // enable raw debug output, default off 18 | ReadTimeout time.Duration // read timeout, default 3s 19 | SlowLogThreshold time.Duration // threshold for slow-request logging, default 1s 20 | IdleConnTimeout time.Duration // idle connection timeout, default 90 * time.Second 21 | MaxIdleConns int // maximum number of idle connections 22 | MaxIdleConnsPerHost int // number of keep-alive (idle) connections kept per host 23 | EnableKeepAlives bool // enable keep-alive connections, default on 24 | Proxy string // explicit proxy to use, e.g. http:// 25 | 26 | EnableMetricInterceptor bool // enable metrics, default off 27 | MetricPathRewriter MetricPathRewriter // rewrites the path used as a metric label, keeping label values bounded 28 | 29 | EnableTraceInterceptor bool // enable tracing, default on 30 | EnableAccessInterceptor bool // log request data, default on 31 | EnableAccessInterceptorReq bool // log request parameters, default on 32 | EnableAccessInterceptorReqHeader bool // log request headers, default off 33 | EnableAccessInterceptorRes bool // log response data, default on 34 | TLSClientConfig *tls.Config 35 | DisableCompression bool 36 | } 37 | 38 | // DefaultConfig returns a Config populated with the defaults documented on each field.
39 | func DefaultConfig() *Config {
40 | 	return &Config{
41 | 		RawDebug:                   false,
42 | 		ReadTimeout:                time.Second * 3,
43 | 		SlowLogThreshold:           time.Second,
44 | 		MaxIdleConns:               100,
45 | 		MaxIdleConnsPerHost:        runtime.GOMAXPROCS(0) + 1, // GOMAXPROCS(0) only reads the current setting
46 | 		IdleConnTimeout:            90 * time.Second,
47 | 		EnableKeepAlives:           true,
48 | 		EnableTraceInterceptor:     true,
49 | 		EnableAccessInterceptor:    true,
50 | 		EnableAccessInterceptorReq: true,
51 | 		EnableAccessInterceptorRes: true,
52 | 		MetricPathRewriter:         DefaultMetricPathRewriter,
53 | 	}
54 | }
55 | // MetricPathRewriter maps a raw request path to the value used for the metrics "path" label.
56 | type MetricPathRewriter func(origin string) string
57 | // DefaultMetricPathRewriter returns the path unchanged.
58 | func DefaultMetricPathRewriter(origin string) string {
59 | 	return origin
60 | }
61 |
--------------------------------------------------------------------------------
/http/http_client/http_client_test.go:
--------------------------------------------------------------------------------
1 | package http_client
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/smartystreets/goconvey/convey"
7 | )
8 |
9 | func TestNewHttpClient(t *testing.T) {
10 | 	convey.Convey("TestNewHttpClient", t, func() {
11 | 		// cfg := http_client_config.DefaultConfig()
12 | 		// client := NewHttpClient(cfg)
13 | 		// request := client.Request.SetContext(context.Background())
14 | 		// resp, err := request.Get("https://www.baidu.com/")
15 | 		// body := string(resp.Body())
16 | 		// convey.So(body, convey.ShouldNotBeNil)
17 | 		// convey.So(err, convey.ShouldBeNil)
18 | 	})
19 | }
20 | // NOTE(review): the test body above is commented out, so TestNewHttpClient currently asserts nothing.
--------------------------------------------------------------------------------
/http/http_client/interceptor/header_carrier.go:
--------------------------------------------------------------------------------
1 | package interceptor
2 |
3 | import (
4 | 	"github.com/go-resty/resty/v2"
5 | 	"github.com/weblazy/easy/transport"
6 | 	"go.opentelemetry.io/otel/propagation"
7 | )
8 |
9 | // HeaderCarrierInterceptor passes parameters through between services: it injects transport.CustomKeysMapPropagator's values from the request context into the outgoing request headers. Only the request hook is set.
10 | func HeaderCarrierInterceptor() (resty.RequestMiddleware, resty.ResponseMiddleware, resty.ErrorHook) {
11 | 	beforeFn := func(cli *resty.Client, req *resty.Request) error {
12 | 		transport.CustomKeysMapPropagator.Inject(req.Context(), propagation.HeaderCarrier(req.Header))
13 | 		return nil
14 | 	}
15 |
16 | 	return beforeFn, nil, nil
17 | }
18 |
--------------------------------------------------------------------------------
/http/http_client/interceptor/metric.go:
--------------------------------------------------------------------------------
1 | package interceptor
2 |
3 | import (
4 | 	"net/url"
5 | 	"strconv"
6 | 	"time"
7 |
8 | 	"github.com/go-resty/resty/v2"
9 | 	"github.com/prometheus/client_golang/prometheus"
10 | 	"github.com/weblazy/easy/http/http_client/http_client_config"
11 | )
12 |
13 | var (
14 |
15 | 	// ClientHandleCounter counts client requests by name, method, path, peer and status code.
16 | 	ClientHandleCounter = prometheus.NewCounterVec(
17 | 		prometheus.CounterOpts{
18 | 			Namespace: "",
19 | 			Name:      "http_client_handle_total",
20 | 		}, []string{"name", "method", "path", "peer", "code"})
21 |
22 | 	// ClientHandleHistogram observes client request durations in seconds by name, method, path and peer.
23 | 	ClientHandleHistogram = prometheus.NewHistogramVec(prometheus.HistogramOpts{
24 | 		Namespace: "",
25 | 		Name:      "http_client_handle_seconds",
26 | 	}, []string{"name", "method", "path", "peer"})
27 | )
28 | // Registered with the default Prometheus registry; MustRegister panics on a duplicate registration.
29 | func init() {
30 | 	prometheus.MustRegister(ClientHandleCounter)
31 | 	prometheus.MustRegister(ClientHandleHistogram)
32 | }
33 | // MetricInterceptor returns response and error hooks that record ClientHandleCounter/ClientHandleHistogram; the request hook is nil.
34 | func MetricInterceptor(name, addr string, rewriter http_client_config.MetricPathRewriter) (resty.RequestMiddleware, resty.ResponseMiddleware, resty.ErrorHook) {
35 | 	if rewriter == nil {
36 | 		rewriter = http_client_config.DefaultMetricPathRewriter
37 | 	}
38 | 	// Responses: the duration comes from resty's own res.Time().
39 | 	afterFn := func(cli *resty.Client, res *resty.Response) error {
40 | 		method := res.Request.Method
41 | 		path := rewriter(res.Request.RawRequest.URL.Path)
42 | 		ClientHandleCounter.WithLabelValues(name, method, path, addr, strconv.Itoa(res.StatusCode())).Inc()
43 | 		ClientHandleHistogram.WithLabelValues(name, method, path, addr).Observe(res.Time().Seconds())
44 | 		return nil
45 | 	}
46 | 	// Errors: the duration is measured from the time stored by SetStartTimeInterceptor. NOTE(review): without that hook GetStartTime is the zero time and the observed value is huge.
47 | 	errorFn := func(req *resty.Request, err error) {
48 | 		method := req.Method
49 | 		var path string
50 |
51 | 		// When OnBeforeRequest fails, req.RawRequest is not available, so the path is parsed from req.URL instead.
52 | 		u, err2 := url.Parse(req.URL)
53 | 		if err2 != nil {
54 | 			path = "invalidUrl"
55 | 		} else {
56 | 			path = rewriter(u.Path)
57 | 		}
58 | 		// A *resty.ResponseError carries a response, so its status code is used; anything else is labelled "unknown".
59 | 		if v, ok := err.(*resty.ResponseError); ok {
60 | 			ClientHandleCounter.WithLabelValues(name, method, path, addr, strconv.Itoa(v.Response.StatusCode())).Inc()
61 | 		} else {
62 | 			ClientHandleCounter.WithLabelValues(name, method, path, addr, "unknown").Inc()
63 | 		}
64 |
65 | 		ClientHandleHistogram.WithLabelValues(name, method, path, addr).Observe(time.Since(GetStartTime(req.Context())).Seconds())
66 | 	}
67 |
68 | 	return nil, afterFn, errorFn
69 | }
70 |
--------------------------------------------------------------------------------
/http/http_client/interceptor/start_time.go:
--------------------------------------------------------------------------------
1 | package interceptor
2 |
3 | import (
4 | 	"context"
5 | 	"time"
6 |
7 | 	"github.com/go-resty/resty/v2"
8 | )
9 |
10 | // https://stackoverflow.com/questions/40891345/fix-should-not-use-basic-type-string-as-key-in-context-withvalue-golint
11 | // https://blog.golang.org/context#TOC_3.2.
12 | // https://golang.org/pkg/context/#WithValue: this article explains that using a struct key type avoids an allocation.
13 | type startTimeKey struct{}
14 |
15 | // SetStartTimeInterceptor returns a request hook that stores time.Now() in the request context; the response and error hooks are nil.
16 | func SetStartTimeInterceptor() (resty.RequestMiddleware, resty.ResponseMiddleware, resty.ErrorHook) {
17 | 	return func(cli *resty.Client, req *resty.Request) error {
18 | 		req.SetContext(context.WithValue(req.Context(), startTimeKey{}, time.Now()))
19 | 		return nil
20 | 	}, nil, nil
21 | }
22 |
23 | // GetStartTime returns the time stored by SetStartTimeInterceptor, or the zero time.Time when none was stored.
24 | func GetStartTime(ctx context.Context) time.Time {
25 | 	startTime, _ := ctx.Value(startTimeKey{}).(time.Time)
26 | 	return startTime
27 | }
28 |
--------------------------------------------------------------------------------
/http/http_client/interceptor/trace.go:
--------------------------------------------------------------------------------
1 | package interceptor
2 |
3 | import (
4 | 	"github.com/go-resty/resty/v2"
5 | 	"github.com/spf13/cast"
6 | 	"github.com/weblazy/easy/http/http_client/http_client_config"
7 | 	"go.opentelemetry.io/otel"
8 | 	"go.opentelemetry.io/otel/attribute"
9 | 	"go.opentelemetry.io/otel/codes"
10 | 	"go.opentelemetry.io/otel/trace"
11 | )
12 |
13 | // Deprecated: use otel http transport
14 | func TraceInterceptor(name string, cfg *http_client_config.Config) (resty.RequestMiddleware, resty.ResponseMiddleware, resty.ErrorHook) { //nolint
15 | 	tracer := otel.Tracer("")
16 | 	// beforeFn starts a client span named after the HTTP method; cfg is not used by this implementation.
17 | 	beforeFn := func(cli *resty.Client, req *resty.Request) error {
18 | 		ctx, span := tracer.Start(req.Context(), req.Method, trace.WithSpanKind(trace.SpanKindClient))
19 |
20 | 		span.SetAttributes(
21 | 			attribute.String("peer.service", name),
22 | 			attribute.String("http.method", req.Method),
23 | 			attribute.String("http.url", req.URL),
24 | 		)
25 |
26 | 		req.SetContext(ctx)
27 | 		return nil
28 | 	}
29 | 	// afterFn records the status code and ends the span started in beforeFn.
30 | 	afterFn := func(cli *resty.Client, res *resty.Response) error {
31 | 		span := trace.SpanFromContext(res.Request.Context())
32 | 		span.SetAttributes(
33 | 			attribute.String("http.status_code", cast.ToString(res.StatusCode())),
34 | 		)
35 |
36 | 		span.End()
37 | 		return nil
38 | 	}
39 | 	// errorFn records err on the span (when non-nil) and ends it.
40 | 	errorFn := func(req *resty.Request, err error) {
41 | 		span := trace.SpanFromContext(req.Context())
42 |
43 | 		if err != nil {
44 | 			span.RecordError(err)
45 | 			span.SetStatus(codes.Error, err.Error())
46 | 		}
47 |
48 | 		span.End()
49 | 	}
50 | 	return beforeFn, afterFn, errorFn
51 | }
52 |
--------------------------------------------------------------------------------
/http/http_client/log.go:
--------------------------------------------------------------------------------
1 | package http_client
2 |
3 | import (
4 | 	"context"
5 | 	"encoding/json"
6 | 	"fmt"
7 | 	"net/http"
8 | 	"time"
9 |
10 | 	"github.com/go-resty/resty/v2"
11 | 	"github.com/spf13/cast"
12 | 	"github.com/weblazy/easy/elog"
13 | 	"github.com/weblazy/easy/elog/sls"
14 | )
15 | // Log receives one LogObject per completed response once registered with SetLog.
16 | type Log interface {
17 | 	Info(obj *LogObject) error
18 | }
19 | // LogObject is the per-response record handed to Log.Info.
20 | type LogObject struct {
21 | 	Url             string        `json:"url"`
22 | 	Method          string        `json:"method"`
23 | 	RequestHders    http.Header   `json:"request_headers"`
24 | 	RequestRawBody  interface{}   `json:"request_raw_body"`
25 | 	ResponseHeaders http.Header   `json:"response_headers"`
26 | 	ResponseBody    string        `json:"response_body"`
27 | 	StartTime       string        `json:"start_time"`
28 | 	Duration        time.Duration `json:"duration"`
29 | 	Status          int           `json:"status"`
30 | }
31 | // SetLog registers an after-response hook on h.Client that builds a LogObject for every response and passes it to log.
32 | func (h *HttpClient) SetLog(log Log) *HttpClient {
33 | 	// resty's OnAfterResponse returns the *resty.Client for chaining, not an error, so there is nothing to check on it.
34 | 	h.Client.OnAfterResponse(func(client *resty.Client, resp *resty.Response) error {
35 | 		r := resp.Request
36 | 		obj := &LogObject{
37 | 			Url:             r.URL,
38 | 			Method:          r.Method,
39 | 			RequestHders:    r.Header,
40 | 			RequestRawBody:  r.Body,
41 | 			ResponseHeaders: resp.Header(),
42 | 			ResponseBody:    string(resp.Body()),
43 | 			StartTime:       r.Time.Format("2006-01-02 15:04:05"),
44 | 			Duration:        resp.Time() / time.Millisecond, // the numeric value is a millisecond count, not a real Duration
45 | 			Status:          resp.StatusCode(),
46 | 		}
47 | 		if err := log.Info(obj); err != nil { // best-effort: report the failure, never fail the request because of logging
48 | 			fmt.Printf("http_client: log.Info: %v\n", err)
49 | 		}
50 | 		return nil
51 | 	})
52 | 	return h
53 | }
54 |
55 | type GocoreLog struct {
56 | }
57 | // NewGocoreLog returns a Log that writes each LogObject as JSON through elog.
58 | func NewGocoreLog() *GocoreLog {
59 | 	return &GocoreLog{}
60 | }
61 | // Info marshals obj to JSON and logs it with elog.InfoCtx; the marshal error is ignored and it always returns nil.
62 | func (l *GocoreLog) Info(obj *LogObject) error {
63 | 	data, _ := json.Marshal(obj)
64 | 	elog.InfoCtx(context.Background(), string(data))
65 | 	return nil
66 | }
67 | // NewAliyunLog returns a Log that writes to the given SLS topic.
68 | func NewAliyunLog(topic string) *AliyunLog {
69 | 	return &AliyunLog{topic: topic}
70 | }
71 |
72 | type AliyunLog struct {
73 | 	topic string
74 | }
75 | // Info writes obj to Aliyun SLS; sls.InitLog must be called beforehand. It always returns nil.
76 | // NOTE(review): ResponseBody is not included in the SLS record; confirm that is intentional.
77 | func (l *AliyunLog) Info(obj *LogObject) error {
78 | 	requestHeaderBytes, _ := json.Marshal(obj.RequestHders)
79 | 	requestBodyBytes, _ := json.Marshal(obj.RequestRawBody)
80 | 	responseHeaderBytes, _ := json.Marshal(obj.ResponseHeaders)
81 | 	_ = sls.Info(l.topic, map[string]string{
82 | 		"url":              obj.Url,
83 | 		"method":           obj.Method,
84 | 		"request_headers":  string(requestHeaderBytes),
85 | 		"request_raw_body": string(requestBodyBytes),
86 | 		"response_headers": string(responseHeaderBytes),
87 | 		"start_time":       obj.StartTime,
88 | 		"duration":         cast.ToString(obj.Duration),
89 | 		"status":           cast.ToString(obj.Status),
90 | 	})
91 | 	return nil
92 | }
93 |
--------------------------------------------------------------------------------
/http/http_client/query.go:
--------------------------------------------------------------------------------
1 | package http_client
2 |
3 | import (
4 | 	"fmt"
5 | 	"net/url"
6 |
7 | 	"github.com/weblazy/easy/stringx"
8 | )
9 |
10 | // MapToQuery builds a query string from params (keys sorted by url.Values.Encode); the result is percent-decoded unless urlEncode[0] is true. A nil map or a value stringx.ToString cannot convert is reported as an error.
11 | func MapToQuery(params map[string]interface{}, urlEncode ...bool) (string, error) {
12 | 	if params == nil {
13 | 		return "", fmt.Errorf("param is nil")
14 | 	}
15 | 	v := make(url.Values)
16 | 	for key := range params {
17 | 		value, err := stringx.ToString(params[key])
18 | 		if err != nil {
19 | 			return "", fmt.Errorf("param %q: %w", key, err) // must not report an empty query as success
20 | 		}
21 | 		v.Add(key, value)
22 | 	}
23 | 	encodeStr := v.Encode()
24 | 	if len(urlEncode) > 0 && urlEncode[0] {
25 | 		return encodeStr, nil
26 | 	}
27 | 	decodeStr, _ := url.QueryUnescape(encodeStr) // Encode only emits valid escapes, so this cannot fail
28 | 	return decodeStr, nil
29 |
30 | }
31 |
--------------------------------------------------------------------------------
/http/http_server/http_server.go: -------------------------------------------------------------------------------- 1 | package http_server 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/fvbock/endless" 8 | "github.com/gin-gonic/gin" 9 | "github.com/spf13/viper" 10 | "go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin" 11 | 12 | "github.com/weblazy/easy/http/http_server/http_server_config" 13 | "github.com/weblazy/easy/http/http_server/interceptor" 14 | ) 15 | 16 | type HttpServer struct { 17 | Config *http_server_config.Config 18 | *gin.Engine 19 | } 20 | 21 | func NewHttpServerViper(key string, cfg *viper.Viper) (*HttpServer, error) { 22 | c := http_server_config.DefaultConfig() 23 | cfg.UnmarshalKey(key, c) 24 | server := &HttpServer{ 25 | Config: c, 26 | } 27 | return server, nil 28 | } 29 | 30 | func NewHttpServer(c *http_server_config.Config) (*HttpServer, error) { 31 | if c == nil { 32 | c = http_server_config.DefaultConfig() 33 | } 34 | 35 | server := &HttpServer{ 36 | Config: c, 37 | } 38 | ctx := context.Background() 39 | // opts = append([]RunOption{WithNotFoundHandler(nil)}, opts...) 
40 | // for _, opt := range opts { 41 | // opt(server) 42 | // } 43 | r := gin.New() 44 | r.Use(interceptor.SetStartTimeInterceptor()) 45 | if server.Config.EnableTraceInterceptor { 46 | r.Use(otelgin.Middleware(c.Name)) 47 | r.Use(interceptor.Trace(ctx)) 48 | } 49 | r.Use(interceptor.HeaderCarrierInterceptor()) 50 | 51 | if server.Config.EnableLogInterceptor { 52 | r.Use(interceptor.Log(ctx, c)) 53 | } 54 | if server.Config.EnableMetricInterceptor { 55 | r.Use(interceptor.MetricInterceptor(c)) 56 | } 57 | if server.Config.Timeout > 0 { 58 | r.Use(interceptor.Timeout(server.Config.Timeout)) 59 | } 60 | r.Use(gin.Recovery()) 61 | server.Engine = r 62 | return server, nil 63 | } 64 | 65 | func (s *HttpServer) Start() error { 66 | return endless.ListenAndServe(fmt.Sprintf("%s:%d", s.Config.Host, s.Config.Port), s) 67 | } 68 | -------------------------------------------------------------------------------- /http/http_server/http_server_config/http_server_config.go: -------------------------------------------------------------------------------- 1 | package http_server_config 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/spf13/viper" 7 | "github.com/weblazy/easy/elog" 8 | "github.com/weblazy/easy/elog/ezap" 9 | ) 10 | 11 | const PkgName = "http_server" 12 | 13 | type Config struct { 14 | Name string 15 | Host string // IP地址,默认0.0.0.0 16 | Port int // Port端口,默认80 17 | 18 | Timeout time.Duration 19 | SlowLogThreshold time.Duration // 慢日志记录的阈值,默认 1s 20 | 21 | EnableTraceInterceptor bool 22 | EnableMetricInterceptor bool 23 | EnableLogInterceptor bool 24 | EnableAccessInterceptor bool // 是否开启记录请求数据,默认开启 25 | 26 | EnableFielLogger bool // 将日志输出到文件 27 | FielLoggerPath string 28 | MetricPathRewriter MetricPathRewriter 29 | } 30 | 31 | // DefaultConfig default config ... 
32 | func DefaultConfig() *Config {
33 | 	return &Config{
34 | 		Host:                    "0.0.0.0",
35 | 		Port:                    80,
36 | 		Timeout:                 3 * time.Second,
37 | 		SlowLogThreshold:        time.Second,
38 | 		EnableTraceInterceptor:  true,
39 | 		EnableMetricInterceptor: true,
40 | 		EnableLogInterceptor:    true,
41 | 		EnableAccessInterceptor: true,
42 | 		FielLoggerPath:          PkgName,
43 | 		MetricPathRewriter:      DefaultMetricPathRewriter,
44 | 	}
45 | }
46 | // MetricPathRewriter maps a raw request path to the value used for the metrics "path" label.
47 | type MetricPathRewriter func(origin string) string
48 | // DefaultMetricPathRewriter returns the path unchanged.
49 | func DefaultMetricPathRewriter(origin string) string {
50 | 	return origin
51 | }
52 | // GetViperConfig returns DefaultConfig() overlaid with the value stored under key in cfg.
53 | func GetViperConfig(key string, cfg *viper.Viper) (*Config, error) {
54 | 	c := DefaultConfig()
55 | 	err := cfg.UnmarshalKey(key, c)
56 | 	if err != nil {
57 | 		return nil, err
58 | 	}
59 | 	return c, nil
60 | }
61 | // InitLogger registers the PkgName logger: a file logger at FielLoggerPath when EnableFielLogger is set, elog.DefaultLogger otherwise.
62 | func (config Config) InitLogger() {
63 | 	if config.EnableFielLogger {
64 | 		logger := ezap.NewFileEzap(config.FielLoggerPath)
65 | 		elog.SetLogger(PkgName, logger)
66 | 		return
67 | 	}
68 | 	elog.SetLogger(PkgName, elog.DefaultLogger)
69 | }
70 |
--------------------------------------------------------------------------------
/http/http_server/http_server_test.go:
--------------------------------------------------------------------------------
1 | package http_server
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/smartystreets/goconvey/convey"
7 | )
8 |
9 | func TestNewHttpServer(t *testing.T) {
10 | 	convey.Convey("TestNewHttpServer", t, func() {
11 | 		// fcfg := ejaeger.DefaultConfig()
12 | 		// etrace.SetGlobalTracer(fcfg.Build())
13 | 		// cfg := http_server_config.DefaultConfig()
14 | 		// server, err := NewHttpServer(cfg)
15 | 		// convey.So(err, convey.ShouldBeNil)
16 | 		// err = server.Start()
17 | 		// convey.So(err, convey.ShouldBeNil)
18 | 	})
19 | }
20 | // NOTE(review): the test bodies in this file are commented out, so nothing is asserted.
21 | // func TestNewHttpServerViper(t *testing.T) {
22 | // 	convey.Convey("test config", t, func() {
23 | // 		cfg := viper.New()
24 | // 		cfg.SetConfigType("toml")
25 | // 		s := strings.NewReader(`Name="go"
26 | // [http_server]
27 | // name=6666
28 | // level=50`)
29 |
30 | // 		err := cfg.ReadConfig(s)
31 | // 		convey.So(err, convey.ShouldBeNil)
32 | // 		resp, err := NewHttpServerViper("http_server", cfg)
33 | // 		fmt.Printf("%#v\n", resp.Config)
34 | // 		convey.So(resp, convey.ShouldNotBeNil)
35 | // 		convey.So(err, convey.ShouldBeNil)
36 | // 	})
37 | // }
38 |
--------------------------------------------------------------------------------
/http/http_server/interceptor/header_carrier.go:
--------------------------------------------------------------------------------
1 | package interceptor
2 |
3 | import (
4 | 	"github.com/gin-gonic/gin"
5 | 	"github.com/weblazy/easy/transport"
6 | 	"go.opentelemetry.io/otel/propagation"
7 | )
8 |
9 | // HeaderCarrierInterceptor carries pass-through parameters between services: it extracts transport.CustomKeysMapPropagator's keys from the request headers into the request context.
10 | func HeaderCarrierInterceptor() gin.HandlerFunc {
11 | 	return func(c *gin.Context) {
12 | 		c.Request = c.Request.WithContext(transport.CustomKeysMapPropagator.Extract(c.Request.Context(), propagation.HeaderCarrier(c.Request.Header)))
13 | 	}
14 | }
15 |
--------------------------------------------------------------------------------
/http/http_server/interceptor/metric.go:
--------------------------------------------------------------------------------
1 | package interceptor
2 |
3 | import (
4 | 	"strconv"
5 | 	"time"
6 |
7 | 	"github.com/gin-gonic/gin"
8 | 	"github.com/prometheus/client_golang/prometheus"
9 | 	"github.com/weblazy/easy/http/http_server/http_server_config"
10 | )
11 |
12 | var (
13 |
14 | 	// ServerHandleCounter counts handled requests by server name, method, path, host and status code.
15 | 	ServerHandleCounter = prometheus.NewCounterVec(
16 | 		prometheus.CounterOpts{
17 | 			Namespace: "",
18 | 			Name:      "http_server_handle_total",
19 | 		}, []string{"name", "method", "path", "host", "code"})
20 |
21 | 	// ServerHandleHistogram observes request durations in seconds by server name, method, path and host.
22 | 	ServerHandleHistogram = prometheus.NewHistogramVec(prometheus.HistogramOpts{
23 | 		Namespace: "",
24 | 		Name:      "http_server_handle_seconds",
25 | 	}, []string{"name", "method", "path", "host"})
26 | )
27 | // Registered with the default Prometheus registry; MustRegister panics on a duplicate registration.
28 | func init() {
29 | 	prometheus.MustRegister(ServerHandleCounter)
30 | 	prometheus.MustRegister(ServerHandleHistogram)
31 | }
32 |
33 | // MetricInterceptor records ServerHandleCounter and ServerHandleHistogram for each request after the rest of the chain has run.
34 | func MetricInterceptor(cfg *http_server_config.Config) gin.HandlerFunc {
35 | 	// Resolve the rewriter once, up front: assigning cfg.MetricPathRewriter inside the handler raced across concurrent requests.
36 | 	rewriter := cfg.MetricPathRewriter
37 | 	if rewriter == nil {
38 | 		rewriter = http_server_config.DefaultMetricPathRewriter
39 | 	}
40 | 	return func(c *gin.Context) {
41 | 		c.Next()
42 | 		path := rewriter(c.Request.URL.Path)
43 | 		// NOTE(review): for server requests URL.Host is usually empty; c.Request.Host has the Host header but is client-controlled (label cardinality).
44 | 		ServerHandleCounter.WithLabelValues(cfg.Name, c.Request.Method, path, c.Request.URL.Host, strconv.Itoa(c.Writer.Status())).Inc()
45 | 		ServerHandleHistogram.WithLabelValues(cfg.Name, c.Request.Method, path, c.Request.URL.Host).Observe(time.Since(GetStartTime(c.Request.Context())).Seconds())
46 | 	}
47 | }
48 |
--------------------------------------------------------------------------------
/http/http_server/interceptor/start_time.go:
--------------------------------------------------------------------------------
1 | package interceptor
2 |
3 | import (
4 | 	"context"
5 | 	"time"
6 |
7 | 	"github.com/gin-gonic/gin"
8 | )
9 |
10 | // https://stackoverflow.com/questions/40891345/fix-should-not-use-basic-type-string-as-key-in-context-withvalue-golint
11 | // https://blog.golang.org/context#TOC_3.2.
12 | // https://golang.org/pkg/context/#WithValue ,这边文章说明了用struct,可以避免分配 13 | type startTimeKey struct{} 14 | 15 | func SetStartTimeInterceptor() gin.HandlerFunc { 16 | return func(c *gin.Context) { 17 | c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), startTimeKey{}, time.Now())) 18 | } 19 | } 20 | 21 | func GetStartTime(ctx context.Context) time.Time { 22 | startTime, _ := ctx.Value(startTimeKey{}).(time.Time) 23 | return startTime 24 | } 25 | 26 | func GetDuration(ctx context.Context) time.Duration { 27 | startTime, _ := ctx.Value(startTimeKey{}).(time.Time) 28 | return time.Since(startTime) 29 | } 30 | 31 | func GetDurationMilliseconds(ctx context.Context) float64 { 32 | startTime, _ := ctx.Value(startTimeKey{}).(time.Time) 33 | return float64(time.Since(startTime).Microseconds()) / 1000 34 | 35 | } 36 | -------------------------------------------------------------------------------- /http/http_server/interceptor/timeout.go: -------------------------------------------------------------------------------- 1 | package interceptor 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/gin-gonic/gin" 9 | ) 10 | 11 | // timeout middleware wraps the request context with a timeout 12 | func Timeout(timeout time.Duration) func(c *gin.Context) { 13 | return func(c *gin.Context) { 14 | 15 | // wrap the request context with a timeout 16 | ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) 17 | 18 | defer func() { 19 | // check if context timeout was reached 20 | if ctx.Err() == context.DeadlineExceeded { 21 | 22 | // write response and abort the request 23 | c.Writer.WriteHeader(http.StatusGatewayTimeout) 24 | c.Abort() 25 | } 26 | 27 | //cancel to clear resources after finished 28 | cancel() 29 | }() 30 | 31 | // replace request with context wrapped request 32 | c.Request = c.Request.WithContext(ctx) 33 | c.Next() 34 | } 35 | } 36 | -------------------------------------------------------------------------------- 
/http/http_server/interceptor/token.go:
--------------------------------------------------------------------------------
1 | package interceptor
2 |
3 | import (
4 | 	"bytes"
5 | 	"fmt"
6 | 	"io"
7 |
8 | 	"github.com/gin-gonic/gin"
9 | 	"github.com/weblazy/easy/code_err"
10 | 	"github.com/weblazy/easy/econfig"
11 | )
12 |
13 | // Token checks TokenHeader with validateToken and stores the returned uid in the userIdHeader request header; the check is skipped only when BaseConfig.Debug is on and DebugHeader equals BaseConfig.XDebugKey.
14 | func Token(userIdHeader string, validateToken func(token string) (uid string, err error)) gin.HandlerFunc {
15 | 	return func(c *gin.Context) {
16 | 		req := c.Request
17 | 		header := req.Header
18 | 		debugKey := header.Get(DebugHeader)
19 | 		if !econfig.GlobalViper.GetBool("BaseConfig.Debug") || debugKey != econfig.GlobalViper.GetString("BaseConfig.XDebugKey") { // NOTE(review): with Debug on and an empty XDebugKey, a request without DebugHeader also bypasses
20 | 			token := c.Request.Header.Get(TokenHeader)
21 | 			if token == "" {
22 | 				Error(c, code_err.TokenErr, fmt.Errorf("token 不存在"))
23 | 				return
24 | 			}
25 | 			uid, err := validateToken(token)
26 | 			if err != nil {
27 | 				Error(c, code_err.TokenErr, err)
28 | 				return
29 | 			}
30 | 			header.Set(userIdHeader, uid)
31 | 		}
32 | 		c.Next()
33 | 	}
34 | }
35 |
36 | // Sign passes SignHeader, the token (TokenHeader, or nonce+timestamp when no token is sent) and body+timestamp+nonce to ValidateSign; it has the same debug bypass as Token.
37 | func Sign() gin.HandlerFunc {
38 | 	return func(c *gin.Context) {
39 | 		req := c.Request
40 | 		header := req.Header
41 | 		debugKey := header.Get(DebugHeader)
42 | 		// The whole body is buffered: it is signed and must be restored for later handlers. NOTE(review): io.ReadAll is unbounded; consider http.MaxBytesReader.
43 | 		bodyBytes, err := io.ReadAll(c.Request.Body)
44 | 		if err != nil {
45 | 			Error(c, code_err.ParamsErr, fmt.Errorf("Invalid request body"))
46 | 			return
47 | 		}
48 | 		// Replace Request.Body with a fresh reader over the buffered bytes.
49 | 		c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
50 | 		if !econfig.GlobalViper.GetBool("BaseConfig.Debug") || debugKey != econfig.GlobalViper.GetString("BaseConfig.XDebugKey") {
51 | 			sign := header.Get(SignHeader)
52 | 			token := header.Get(TokenHeader)
53 | 			timestamp := header.Get(TimestampHeader)
54 | 			nonce := header.Get(NonceHeader)
55 | 			if token == "" {
56 | 				token = nonce + timestamp
57 | 			}
58 | 			err = ValidateSign(sign, token, []byte(string(bodyBytes)+timestamp+nonce))
59 | 			if err != nil {
60 | 				Error(c, code_err.SignErr, err)
61 | 				return
62 | 			}
63 | 		}
64 | 	}
65 | }
66 |
--------------------------------------------------------------------------------
/http/http_server/interceptor/trace.go:
--------------------------------------------------------------------------------
1 | package interceptor
2 |
3 | import (
4 | 	"context"
5 |
6 | 	"github.com/gin-gonic/gin"
7 | 	"github.com/weblazy/easy/etrace"
8 | 	"github.com/weblazy/easy/transport"
9 | )
10 | // Trace sets the transport.PrefixPass+"traceid" response header to the trace ID of the request context; the ctx argument is unused.
11 | func Trace(ctx context.Context) gin.HandlerFunc {
12 | 	return func(c *gin.Context) {
13 | 		c.Writer.Header().Set(transport.PrefixPass+"traceid", etrace.ExtractTraceID(c.Request.Context()))
14 | 		c.Next()
15 | 	}
16 | }
17 |
--------------------------------------------------------------------------------
/http/http_server/service/response.go:
--------------------------------------------------------------------------------
1 | package service
2 | // Response is the JSON body ServiceContext writes.
3 | type Response struct {
4 | 	Code int64       `json:"code"`
5 | 	Data interface{} `json:"data"`
6 | 	Msg  string      `json:"msg"`
7 | }
8 | // defaultResponse is the template NewResponse copies (Code 1 after init); the SetDefault* setters change it without any locking.
9 | var defaultResponse Response
10 |
11 | func init() {
12 | 	defaultResponse = Response{
13 | 		Code: 1,
14 | 		Data: nil,
15 | 		Msg:  "",
16 | 	}
17 | }
18 |
19 | // NewResponse returns a copy of the default response.
20 | func NewResponse() Response {
21 | 	return defaultResponse
22 | }
23 |
24 | // SetDefaultCode sets the default response code.
25 | func SetDefaultCode(code int64) {
26 | 	defaultResponse.Code = code
27 | }
28 |
29 | // SetDefaultData sets the default response data.
30 | func SetDefaultData(data interface{}) {
31 | 	defaultResponse.Data = data
32 | }
33 |
34 | // SetDefaultMsg sets the default response message.
35 | func SetDefaultMsg(msg string) {
36 | 	defaultResponse.Msg = msg
37 | }
38 |
--------------------------------------------------------------------------------
/http/http_server/service/service_context.go:
--------------------------------------------------------------------------------
1 | package service
2 |
3 | import (
4 | 	"context"
5 | 	"errors"
6 | 	"net/http"
7 |
8 | 	"github.com/gin-gonic/gin"
9 | 	"github.com/weblazy/easy/code_err"
10 | 	"github.com/weblazy/easy/ectx"
11 | 	"github.com/weblazy/easy/env"
12 | )
13 | // ServiceContext wraps a gin.Context with a code_err.Log, the Response R being built, and Ctx derived from the request context.
14 | type ServiceContext struct {
15 | 	*gin.Context
16 | 	*code_err.Log
17 | 	R   Response
18 | 	Ctx context.Context // derived from gin.Context.Request.Context via ectx.NewNoCancelContext
19 | }
20 | // ErrorBind replaces binding errors in release mode (see BindValidator); defaultErrCode is the code Error uses when err carries none.
21 | var (
22 | 	ErrorBind            = errors.New("missing required parameters")
23 | 	defaultErrCode int64 = -1
24 | )
25 |
26 | // NewServiceContext wraps g: Log is built from g's request context and Ctx is ectx.NewNoCancelContext of that context.
27 | func NewServiceContext(g *gin.Context) *ServiceContext {
28 | 	ctx := g.Request.Context()
29 | 	c := ServiceContext{
30 | 		Context: g,
31 | 		Log:     code_err.NewLog(ctx),
32 | 		R:       NewResponse(),
33 | 		Ctx:     ectx.NewNoCancelContext(ctx),
34 | 	}
35 |
36 | 	return &c
37 | }
38 |
39 | // Success sets R.Data to data and writes R as JSON with HTTP 200.
40 | func (c *ServiceContext) Success(data interface{}) {
41 | 	c.R.Data = data
42 | 	c.JSON(http.StatusOK, c.R)
43 | }
44 |
45 | // Error writes err.Error() as R.Msg, with the Code of the *code_err.CodeErr found via errors.As (so wrapped errors count) or defaultErrCode, as JSON with HTTP 200.
46 | func (c *ServiceContext) Error(err error) {
47 | 	c.R.Code = defaultErrCode
48 | 	if e := (*code_err.CodeErr)(nil); errors.As(err, &e) {
49 | 		c.R.Code = e.Code
50 | 	}
51 | 	c.R.Msg = err.Error()
52 | 	c.JSON(http.StatusOK, c.R)
53 | }
54 |
55 | // ErrorCodeMsg writes R with the given code and msg as JSON with HTTP 200.
56 | func (c *ServiceContext) ErrorCodeMsg(code int64, msg string) {
57 | 	c.R.Code = code
58 | 	c.R.Msg = msg
59 | 	c.JSON(http.StatusOK, c.R)
60 | }
61 |
62 | // Response writes R with the given code, msg and data as JSON with HTTP 200.
63 | func (c *ServiceContext) Response(code int64, msg string, data interface{}) {
64 | 	c.R.Code = code
65 | 	c.R.Msg = msg
66 | 	c.R.Data = data
67 | 	c.JSON(http.StatusOK, c.R)
68 | }
69 | func (c *ServiceContext) Return(err *code_err.CodeErr) { // copies err's Code/Msg into R when err is non-nil, then writes R as JSON with HTTP 200
70 | 	if err != nil {
71 | 		c.R.Code = err.Code
72 | 		c.R.Msg = err.Msg
73 | 	}
74 | 	c.JSON(http.StatusOK, c.R)
75 | }
76 |
77 | // SetData stores data in R.Data without writing the response, and always returns nil.
78 | func (c *ServiceContext) SetData(data interface{}) *code_err.CodeErr {
79 | 	c.R.Data = data
80 | 	return nil
81 | }
82 |
83 | // BindValidator binds the request into obj with gin's ShouldBind (validation follows the binding struct tags); in release mode the error is replaced with ErrorBind.
84 | func (c *ServiceContext) BindValidator(obj interface{}) error {
85 | 	err := c.ShouldBind(obj)
86 | 	if err != nil {
87 | 		if env.IsRelease() {
88 | 			return ErrorBind
89 | 		}
90 | 		return err
91 | 	}
92 | 	return nil
93 | }
94 |
--------------------------------------------------------------------------------
/invite_code/invite_code.go:
--------------------------------------------------------------------------------
1 | package invite_code
2 |
3 | import (
4 | 	"math/rand"
5 | 	"time"
6 | )
7 |
8 | // BaseByte is the custom digit alphabet. 0 and 1 are left out because they are easily confused with o and l; the order may be shuffled to make codes harder to reverse; A is reserved for padding, so it is not included. It has 32 characters.
9 | var BaseByte = []byte{'H', 'V', 'E', '8', 'S', '2', 'D', 'Z', 'X', '9', 'C', '7', 'P',
10 | 	'5', 'I', 'K', '3', 'M', 'J', 'U', 'F', 'R', '4', 'W', 'Y', 'L', 'T', 'N', '6', 'B', 'G', 'Q'}
11 |
12 | // CodeLenth is the default invite-code length.
13 | var CodeLenth = 8
14 |
15 | // SuffixByte marks where the padding starts; it must not also appear in the alphabet.
16 | var SuffixByte byte = 'A'
17 |
18 | // DefaultInviteCodeHandler uses BaseByte, CodeLenth and SuffixByte.
19 | var DefaultInviteCodeHandler = NewInviteCodeHandler(BaseByte, CodeLenth, SuffixByte)
20 | // InviteCodeHandler converts between non-negative ids and invite codes written with BaseByte digits.
21 | type InviteCodeHandler struct {
22 | 	BaseByte   []byte
23 | 	baseLength int
24 | 	CodeLength int
25 | 	SuffixByte byte
26 | }
27 | // NewInviteCodeHandler returns a handler for the given alphabet, code length and padding marker; the alphabet length is captured here.
28 | func NewInviteCodeHandler(baseByte []byte, codeLength int, suffixByte byte) *InviteCodeHandler {
29 | 	return &InviteCodeHandler{
30 | 		baseLength: len(baseByte),
31 | 		BaseByte:   baseByte,
32 | 		CodeLength: codeLength,
33 | 		SuffixByte: suffixByte,
34 | 	}
35 | }
36 | // IdToCode writes id in base len(BaseByte), most significant digit first; a result shorter than CodeLength gets SuffixByte plus random alphabet characters appended up to CodeLength. id must be non-negative (a negative id indexes out of range and panics).
37 | func (c *InviteCodeHandler) IdToCode(id int) string {
38 | 	buf := make([]byte, 64) // 64 digits hold any non-negative int in any base >= 2; baseLength bytes overflowed for small alphabets
39 | 	charPos := len(buf)
40 | 	for id/c.baseLength > 0 {
41 | 		index := id % c.baseLength
42 | 		charPos--
43 | 		buf[charPos] = c.BaseByte[index]
44 | 		id /= c.baseLength
45 | 	}
46 | 	charPos--
47 | 	buf[charPos] = c.BaseByte[id%c.baseLength]
48 | 	// The encoded digits, most significant first.
49 | 	result := buf[charPos:]
50 | 	// Pad up to CodeLength: the SuffixByte marker, then random alphabet characters.
51 | 	length := len(result)
52 | 	if length < c.CodeLength {
53 | 		result = append(result, c.SuffixByte)
54 | 		// A local generator instead of rand.Seed, which reseeded the process-wide source on every call.
55 | 		rng := rand.New(rand.NewSource(time.Now().UnixNano()))
56 | 		// Number of characters still needed after the SuffixByte itself.
57 | 		for i := 0; i < c.CodeLength-length-1; i++ {
58 | 			randomNum := rng.Intn(c.baseLength)
59 | 			result = append(result, c.BaseByte[randomNum])
60 | 		}
61 | 	}
62 | 	return string(result)
63 | }
64 | // CodeToId decodes code back to the id, reading digits up to the first SuffixByte; characters outside BaseByte count as digit 0.
65 | func (c *InviteCodeHandler) CodeToId(code string) int {
66 | 	var result int
67 | 	for i := 0; i < len(code); i++ {
68 | 		if code[i] == c.SuffixByte { // everything from the marker on is random padding
69 | 			break
70 | 		}
71 | 		var index int
72 | 		for j := 0; j < c.baseLength; j++ {
73 | 			if code[i] == c.BaseByte[j] {
74 | 				index = j
75 | 				break
76 | 			}
77 | 		}
78 | 		// result starts at 0, so this also handles the first digit.
79 | 		result = result*c.baseLength + index
80 | 	}
81 | 	return result
82 | }
83 |
--------------------------------------------------------------------------------
/invite_code/invite_code_test.go:
--------------------------------------------------------------------------------
1 | package invite_code
2 |
3 | import (
4 | 	"fmt"
5 | 	"testing"
6 |
7 | 	"gotest.tools/assert"
8 | )
9 |
10 | func TestIdToCode(t *testing.T) {
11 | 	id := 5
12 | 	invteCode := DefaultInviteCodeHandler.IdToCode(id)
13 | 	fmt.Println(invteCode)
14 | 	assert.Equal(t, CodeLenth, len(invteCode))
15 | 	assert.Equal(t, id, DefaultInviteCodeHandler.CodeToId(invteCode))
16 | }
17 |
18 | // A two-character alphabet needs more digits than it has characters, which used to overflow IdToCode's buffer.
19 | func TestIdToCodeSmallBase(t *testing.T) {
20 | 	h := NewInviteCodeHandler([]byte{'X', 'Y'}, 4, 'Z')
21 | 	code := h.IdToCode(5)
22 | 	assert.Equal(t, "YXYZ", code)
23 | 	assert.Equal(t, 5, h.CodeToId(code))
24 | }
25 |
26 | func TestCodeToId(t *testing.T) {
27 | 	code := "2A99UCYP"
28 | 	id := DefaultInviteCodeHandler.CodeToId(code)
29 | 	assert.Equal(t, 5, id)
30 | 	fmt.Println(id)
31 | }
32 |
--------------------------------------------------------------------------------
/lang/lang.go:
--------------------------------------------------------------------------------
1 | package lang
2 |
3 | import "log"
4 | // Placeholder is the zero-size PlaceholderType value.
5 | var Placeholder PlaceholderType
6 |
7 | type (
8 | 	GenericType     = interface{}
9 | 	PlaceholderType = struct{}
10 | )
11 | // Must terminates the process with log.Fatal when err is non-nil.
12 | func Must(err error) {
13 | 	if err != nil {
14 | 		log.Fatal(err)
15 | 	}
16 | }
17 |
--------------------------------------------------------------------------------
/lang/lang_test.go:
--------------------------------------------------------------------------------
1 | package lang
2 |
3 | import "testing"
4 |
5 | func TestMust(t *testing.T) {
6 | 	Must(nil)
7 | }
8 |
--------------------------------------------------------------------------------
/logx/logx.go:
--------------------------------------------------------------------------------
1 | package logx
2 |
3 | import (
4 |
"encoding/json" 5 | "fmt" 6 | "runtime" 7 | "time" 8 | ) 9 | 10 | type ( 11 | Param struct { 12 | Time string `json:"time"` 13 | File string `json:"file"` 14 | Data interface{} `json:"data"` 15 | } 16 | ) 17 | 18 | func Info(args ...interface{}) { 19 | _, file, line, ok := runtime.Caller(1) 20 | if ok { 21 | data, _ := json.Marshal(&Param{ 22 | Time: time.Now().Format("2006-01-02 15:04:05"), 23 | File: fmt.Sprintf("%s:%d", file, line), 24 | Data: args, 25 | }) 26 | fmt.Printf("%s\n", string(data)) 27 | } 28 | } 29 | 30 | func Infof(format string, a ...interface{}) { 31 | _, file, line, ok := runtime.Caller(1) 32 | if ok { 33 | data, _ := json.Marshal(&Param{ 34 | Time: time.Now().Format("2006-01-02 15:04:05"), 35 | File: fmt.Sprintf("%s:%d", file, line), 36 | Data: fmt.Sprintf(format, a...), 37 | }) 38 | fmt.Printf("%s\n", string(data)) 39 | } 40 | } 41 | 42 | func Stack(args ...interface{}) { 43 | param := &Param{ 44 | Time: time.Now().Format("2006-01-02 15:04:05"), 45 | File: "", 46 | Data: args, 47 | } 48 | 49 | for i := 1; ; i++ { 50 | pc, file, line, ok := runtime.Caller(i) 51 | if !ok { 52 | data, _ := json.Marshal(param) 53 | fmt.Printf("%s\n", string(data)) 54 | break 55 | } 56 | f := runtime.FuncForPC(pc) 57 | if f.Name() != "runtime.main" && f.Name() != "runtime.goexit" { 58 | param.File += fmt.Sprintf("%s:%d|------|", file, line) 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /mapx/mapx.go: -------------------------------------------------------------------------------- 1 | package mapx 2 | 3 | import ( 4 | "encoding/xml" 5 | "io" 6 | ) 7 | 8 | type Map map[string]string 9 | 10 | type xmlMapEntry struct { 11 | XMLName xml.Name 12 | Value string `xml:",chardata"` 13 | } 14 | 15 | func (m Map) MarshalXML(e *xml.Encoder, start xml.StartElement) error { 16 | //构建xml输出头部 17 | var err error 18 | for key, value := range m { 19 | name := xml.Name{Space: "", Local: key} 20 | err = 
e.EncodeToken(xml.StartElement{Name: name}) 21 | if err != nil { 22 | return err 23 | } 24 | err = e.EncodeToken(xml.CharData(value)) 25 | if err != nil { 26 | return err 27 | } 28 | err = e.EncodeToken(xml.EndElement{Name: name}) 29 | if err != nil { 30 | return err 31 | } 32 | } 33 | return nil 34 | } 35 | 36 | func (m *Map) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { 37 | *m = Map{} 38 | for { 39 | var e xmlMapEntry 40 | 41 | err := d.Decode(&e) 42 | if err == io.EOF { 43 | break 44 | } else if err != nil { 45 | return err 46 | } 47 | 48 | (*m)[e.XMLName.Local] = e.Value 49 | } 50 | return nil 51 | } 52 | 53 | /** 54 | * @desc 校验 55 | */ 56 | func IsExist(data map[string]interface{}, name string) bool { 57 | _, ok := data[name] 58 | return ok 59 | } 60 | -------------------------------------------------------------------------------- /mapx/order_map.go: -------------------------------------------------------------------------------- 1 | package mapx 2 | 3 | import ( 4 | "container/list" 5 | ) 6 | 7 | type Obj struct { 8 | key string 9 | val interface{} 10 | } 11 | type OrderMap struct { 12 | l *list.List 13 | m map[string]*list.Element 14 | } 15 | 16 | func (m *OrderMap) Put(key string, val interface{}) { 17 | if e, ok := m.m[key]; ok { 18 | e.Value.(*Obj).val = val 19 | return 20 | } 21 | e := m.l.PushBack(&Obj{key: key, val: val}) 22 | m.m[key] = e 23 | } 24 | 25 | func (m *OrderMap) Get(key string) interface{} { 26 | if e, ok := m.m[key]; ok { 27 | return e.Value.(*Obj).val 28 | } 29 | 30 | return nil 31 | } 32 | 33 | func (m *OrderMap) Delete(key string) { 34 | if e, ok := m.m[key]; ok { 35 | m.l.Remove(e) 36 | delete(m.m, key) 37 | } 38 | } 39 | 40 | func (m *OrderMap) Range(f func(key string, val interface{})) { 41 | e := m.l.Front() 42 | for e != nil { 43 | obj := e.Value.(*Obj) 44 | f(obj.key, obj.val) 45 | e = e.Next() 46 | } 47 | } 48 | -------------------------------------------------------------------------------- 
/monitor/dingtalk/dingtalk.go: -------------------------------------------------------------------------------- 1 | package dingtalk 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/weblazy/easy/http/http_client" 8 | "github.com/weblazy/easy/http/http_client/http_client_config" 9 | "github.com/weblazy/easy/monitor" 10 | ) 11 | 12 | type ( 13 | DingTalk struct { 14 | monitor.Handler 15 | Url string `json:"url"` 16 | atMobiles []string 17 | isAtAll bool 18 | } 19 | 20 | TextMsg struct { 21 | Msgtype string `json:"msgtype"` 22 | Text struct { 23 | Content string `json:"content"` 24 | } `json:"text"` 25 | At struct { 26 | AtMobiles []string `json:"atMobiles"` 27 | IsAtAll bool `json:"isAtAll"` 28 | } `json:"at"` 29 | } 30 | ) 31 | 32 | // @desc 33 | // @auth liuguoqiang 2020-12-07 34 | // @param 35 | // @return 36 | func NewDingTalk(url string) *DingTalk { 37 | return &DingTalk{ 38 | Url: url, 39 | atMobiles: []string{}, 40 | isAtAll: false, 41 | } 42 | } 43 | 44 | // @desc @部分成员 45 | // @auth liuguoqiang 2020-12-07 46 | // @param 47 | // @return 48 | func (dingTalk *DingTalk) WithAtMobiles(atMobiles []string) *DingTalk { 49 | if atMobiles != nil { 50 | dingTalk.atMobiles = atMobiles 51 | } 52 | return dingTalk 53 | } 54 | 55 | // @desc @所有成员 56 | // @auth liuguoqiang 2020-12-07 57 | // @param 58 | // @return 59 | func (dingTalk *DingTalk) WithIsAtAll(isAtAll bool) *DingTalk { 60 | dingTalk.isAtAll = isAtAll 61 | return dingTalk 62 | } 63 | 64 | // @desc 发送钉钉消息 65 | // @auth liuguoqiang 2020-12-07 66 | // @param 67 | // @return 68 | func (dingTalk *DingTalk) SendMsg(body interface{}) ([]byte, error) { 69 | cfg := http_client_config.DefaultConfig() 70 | client := http_client.NewHttpClient(cfg) 71 | request := client.Request.SetContext(context.Background()).SetBody(body) 72 | resp, err := request.Post(dingTalk.Url) 73 | return resp.Body(), err 74 | } 75 | 76 | // @desc 发送钉钉文本消息 77 | // @auth liuguoqiang 2020-12-07 78 | // @param 79 | // @return 80 | func 
(dingTalk *DingTalk) SendTextMsg(content string) error { 81 | if dingTalk.Url == "" { 82 | return fmt.Errorf("报警地址为空") 83 | } 84 | msg := TextMsg{ 85 | Msgtype: "text", 86 | } 87 | msg.Text.Content = content 88 | msg.At.IsAtAll = dingTalk.isAtAll 89 | msg.At.AtMobiles = dingTalk.atMobiles 90 | _, err := dingTalk.SendMsg(msg) 91 | return err 92 | } 93 | 94 | // // @desc Request 通用请求 95 | // // @auth liuguoqiang 2020-12-07 96 | // // @param 97 | // // @return 98 | // func Request(url string, body interface{}, headers map[string]string) ([]byte, error) { 99 | // client := http_request.New() 100 | // req := client.Request 101 | // if headers != nil { 102 | // req = req.SetHeaders(headers) 103 | // } 104 | // response, err := req. 105 | // SetBody(body). 106 | // Post(url) 107 | // if err != nil { 108 | // return nil, err 109 | // } 110 | // respByte := response.Body() 111 | // return respByte, err 112 | // } 113 | -------------------------------------------------------------------------------- /monitor/monitor.go: -------------------------------------------------------------------------------- 1 | package monitor 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | type Monitor struct { 11 | handlerList []Handler 12 | closeWaitGroup *sync.WaitGroup 13 | closed bool 14 | } 15 | 16 | type Handler interface { 17 | SendTextMsg(content string) error 18 | } 19 | 20 | // @desc 21 | // @auth liuguoqiang 2020-12-07 22 | // @param 23 | // @return 24 | func NewMonitor(handlerList ...Handler) *Monitor { 25 | return &Monitor{ 26 | handlerList: handlerList, 27 | closeWaitGroup: &sync.WaitGroup{}, 28 | closed: false, 29 | } 30 | } 31 | 32 | // @desc 发送本消息 33 | // @auth liuguoqiang 2020-12-07 34 | // @param 35 | // @return 36 | func (monitor *Monitor) SendTextMsg(content string) error { 37 | if monitor.closed { 38 | return fmt.Errorf("monitor closed") 39 | } 40 | monitor.closeWaitGroup.Add(1) 41 | go func() { 42 | defer func() { 43 | if r := recover(); r != 
nil { 44 | fmt.Printf("%#v\n", r) 45 | } 46 | monitor.closeWaitGroup.Done() 47 | }() 48 | for k1 := range monitor.handlerList { 49 | err := monitor.handlerList[k1].SendTextMsg(content) 50 | if err != nil { 51 | fmt.Printf("%#v\n", err) 52 | } 53 | } 54 | }() 55 | return nil 56 | } 57 | 58 | // @desc 程序退出前阻塞直到将数据发送出去,或者超时 59 | // @auth liuguoqiang 2020-12-07 60 | // @param 61 | // @return 62 | func (monitor *Monitor) Close(timeout int64) { 63 | monitor.closed = true 64 | ctx, cancel := context.WithCancel(context.Background()) 65 | go func(ctx context.Context) { 66 | monitor.closeWaitGroup.Wait() 67 | cancel() 68 | }(ctx) 69 | 70 | select { 71 | case <-ctx.Done(): 72 | fmt.Println("monitor safe closed") 73 | return 74 | case <-time.After(time.Second * time.Duration(timeout)): 75 | fmt.Println("monitor timeout!!!") 76 | return 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /print/print.go: -------------------------------------------------------------------------------- 1 | package print 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/common-nighthawk/go-figure" 7 | ) 8 | 9 | func PrintBanner(name string) { 10 | myFigure := figure.NewFigure(name, "", true) 11 | myFigure.Print() 12 | fmt.Println() 13 | } 14 | -------------------------------------------------------------------------------- /retry/retry_test.go: -------------------------------------------------------------------------------- 1 | package retry 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | var errRetry = errors.New("Testing") 8 | -------------------------------------------------------------------------------- /run/daemon.go: -------------------------------------------------------------------------------- 1 | package run 2 | 3 | import ( 4 | "context" 5 | "runtime/debug" 6 | "time" 7 | 8 | "emperror.dev/errors" 9 | "github.com/weblazy/easy/elog" 10 | "github.com/weblazy/easy/timex" 11 | "go.uber.org/zap" 12 | ) 13 | 14 | func DaemonRun(interval time.Duration, 
f func(), daemon func()) { 15 | ticker := timex.NewRealTicker(interval) 16 | stopChannel := make(chan struct{}) 17 | defer func() { 18 | stopChannel <- struct{}{} 19 | }() 20 | go func() { 21 | for { 22 | select { 23 | case <-ticker.Chan(): 24 | daemon() 25 | case <-stopChannel: 26 | ticker.Stop() 27 | return 28 | } 29 | } 30 | }() 31 | f() 32 | } 33 | 34 | // RunSafeWrap wrapper func () error with Recover 35 | func RunSafeWrap(ctx context.Context, fn func() error) (err error) { 36 | defer func() { 37 | if p := recover(); p != nil { 38 | elog.ErrorCtx(ctx, "panic", zap.Any("err", p), zap.String("stack", string(debug.Stack()))) 39 | err = errors.Errorf("panic: %v", p) 40 | } 41 | }() 42 | 43 | err = fn() 44 | 45 | return 46 | } 47 | -------------------------------------------------------------------------------- /set/set.go: -------------------------------------------------------------------------------- 1 | package set 2 | 3 | // 幂集。编写一种方法,返回某集合的所有子集。集合中不包含重复的元素。 4 | // 说明:解集不能包含重复的子集。 5 | // https://leetcode-cn.com/problems/power-set-lcci/ 6 | func Subsets(nums []int64) [][]int64 { 7 | res := make([][]int64, 0) 8 | res = append(res, []int64{}) 9 | for k1 := 0; k1 < len(nums); k1++ { 10 | a1 := []int64{nums[k1]} 11 | res = append(res, a1) 12 | loop(nums, &res, a1, k1) 13 | } 14 | return res 15 | } 16 | 17 | func loop(nums []int64, res *[][]int64, a1 []int64, k int) { 18 | for k1 := k + 1; k1 < len(nums); k1++ { 19 | a2 := make([]int64, len(a1)) 20 | copy(a2, a1) 21 | a2 = append(a2, nums[k1]) 22 | (*res) = append((*res), a2) 23 | loop(nums, res, a2, k1) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /set/stringset.go: -------------------------------------------------------------------------------- 1 | package set 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | type StringSet struct { 8 | store map[string]struct{} 9 | } 10 | 11 | func NewStringSet() *StringSet { 12 | return &StringSet{store: 
make(map[string]struct{})} 13 | } 14 | 15 | func (ss *StringSet) Add(val string) { 16 | ss.store[val] = struct{}{} 17 | } 18 | 19 | func (ss *StringSet) BatchAdd(vals ...string) { 20 | for _, vv := range vals { 21 | ss.Add(vv) 22 | } 23 | } 24 | 25 | func (ss *StringSet) Has(val string) bool { 26 | _, ok := ss.store[val] 27 | return ok 28 | } 29 | 30 | func (ss *StringSet) Delete(val string) { 31 | delete(ss.store, val) 32 | } 33 | 34 | func (ss *StringSet) Size() int { 35 | return len(ss.store) 36 | } 37 | 38 | func (ss *StringSet) ToArray() []string { 39 | arr := make([]string, 0, len(ss.store)) 40 | 41 | for v := range ss.store { 42 | arr = append(arr, v) 43 | } 44 | 45 | return arr 46 | } 47 | 48 | func (ss *StringSet) String() string { 49 | return strings.Join(ss.ToArray(), ",") 50 | } 51 | -------------------------------------------------------------------------------- /set/stringset_test.go: -------------------------------------------------------------------------------- 1 | package set_test 2 | 3 | // func TestStringSet(t *testing.T) { 4 | // s := set.NewStringSet() 5 | // assert.Equal(t, 0, s.Size()) 6 | 7 | // s.Add("1") 8 | // assert.Equal(t, 1, s.Size()) 9 | // assert.True(t, s.Has("1")) 10 | // s.Add("1") 11 | // assert.Equal(t, 1, s.Size()) 12 | // assert.True(t, s.Has("1")) 13 | // s.Delete("1") 14 | // assert.Equal(t, 0, s.Size()) 15 | // assert.False(t, s.Has("1")) 16 | 17 | // s.BatchAdd("2", "3") 18 | // assert.True(t, s.Has("2")) 19 | // assert.True(t, s.Has("3")) 20 | // assert.Equal(t, 2, s.Size()) 21 | // assert.Equal(t, "2,3", s.String()) 22 | // assert.Equal(t, []string{"2", "3"}, s.ToArray()) 23 | // } 24 | -------------------------------------------------------------------------------- /sortedset/border.go: -------------------------------------------------------------------------------- 1 | package sortedset 2 | 3 | import ( 4 | "errors" 5 | "strconv" 6 | ) 7 | 8 | /* 9 | * ScoreBorder is a struct represents `min` `max` parameter of redis 
command `ZRANGEBYSCORE` 10 | * can accept: 11 | * int or float value, such as 2.718, 2, -2.718, -2 ... 12 | * exclusive int or float value, such as (2.718, (2, (-2.718, (-2 ... 13 | * infinity: +inf, -inf, inf(same as +inf) 14 | */ 15 | 16 | const ( 17 | negativeInf int8 = -1 18 | positiveInf int8 = 1 19 | ) 20 | 21 | // ScoreBorder represents range of a float value, including: <, <=, >, >=, +inf, -inf 22 | type ScoreBorder struct { 23 | Inf int8 24 | Value float64 25 | Exclude bool 26 | } 27 | 28 | // if max.greater(score) then the score is within the upper border 29 | // do not use min.greater() 30 | func (border *ScoreBorder) greater(value float64) bool { 31 | if border.Inf == negativeInf { 32 | return false 33 | } else if border.Inf == positiveInf { 34 | return true 35 | } 36 | if border.Exclude { 37 | return border.Value > value 38 | } 39 | return border.Value >= value 40 | } 41 | 42 | func (border *ScoreBorder) less(value float64) bool { 43 | if border.Inf == negativeInf { 44 | return true 45 | } else if border.Inf == positiveInf { 46 | return false 47 | } 48 | if border.Exclude { 49 | return border.Value < value 50 | } 51 | return border.Value <= value 52 | } 53 | 54 | var positiveInfBorder = &ScoreBorder{ 55 | Inf: positiveInf, 56 | } 57 | 58 | var negativeInfBorder = &ScoreBorder{ 59 | Inf: negativeInf, 60 | } 61 | 62 | // ParseScoreBorder creates ScoreBorder from redis arguments 63 | func ParseScoreBorder(s string) (*ScoreBorder, error) { 64 | if s == "inf" || s == "+inf" { 65 | return positiveInfBorder, nil 66 | } 67 | if s == "-inf" { 68 | return negativeInfBorder, nil 69 | } 70 | if s[0] == '(' { 71 | value, err := strconv.ParseFloat(s[1:], 64) 72 | if err != nil { 73 | return nil, errors.New("ERR min or max is not a float") 74 | } 75 | return &ScoreBorder{ 76 | Inf: 0, 77 | Value: value, 78 | Exclude: true, 79 | }, nil 80 | } 81 | value, err := strconv.ParseFloat(s, 64) 82 | if err != nil { 83 | return nil, errors.New("ERR min or max is not a float") 
84 | } 85 | return &ScoreBorder{ 86 | Inf: 0, 87 | Value: value, 88 | Exclude: false, 89 | }, nil 90 | } 91 | -------------------------------------------------------------------------------- /sortedset/skiplist_test.go: -------------------------------------------------------------------------------- 1 | package sortedset 2 | 3 | import "testing" 4 | 5 | func TestRandomLevel(t *testing.T) { 6 | m := make(map[int16]int) 7 | for i := 0; i < 10000; i++ { 8 | level := randomLevel() 9 | m[level]++ 10 | } 11 | for i := 0; i <= maxLevel; i++ { 12 | t.Logf("level %d, count %d", i, m[int16(i)]) 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /sortedset/sortedset_test.go: -------------------------------------------------------------------------------- 1 | package sortedset 2 | 3 | import "testing" 4 | 5 | func TestSortedSet_PopMin(t *testing.T) { 6 | var set = Make() 7 | set.Add("s1", 1) 8 | set.Add("s2", 2) 9 | set.Add("s3", 3) 10 | set.Add("s4", 4) 11 | 12 | var results = set.PopMin(2) 13 | if results[0].Member != "s1" || results[1].Member != "s2" { 14 | t.Fail() 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /sortx/sort.go: -------------------------------------------------------------------------------- 1 | package sortx 2 | 3 | // 按照接口中的某个字段排序 4 | type Order string 5 | 6 | const ( 7 | ASC Order = "ASC" 8 | DESC Order = "DESC" 9 | ) 10 | 11 | type Sort struct { 12 | Obj interface{} `json:"obj"` 13 | Sort float64 `json:"sort"` 14 | } 15 | type SortList struct { 16 | List []Sort `json:"list"` 17 | Order Order `json:"order"` 18 | } 19 | 20 | func NewSortList(order Order) *SortList { 21 | return &SortList{ 22 | List: make([]Sort, 0), 23 | Order: order, 24 | } 25 | } 26 | 27 | func (list *SortList) Len() int { 28 | return len(list.List) 29 | } 30 | 31 | func (list *SortList) Less(i, j int) bool { 32 | if list.Order == ASC { 33 | return list.List[i].Sort <= list.List[j].Sort 
34 | } 35 | return list.List[i].Sort >= list.List[j].Sort 36 | } 37 | 38 | func (list *SortList) Swap(i, j int) { 39 | temp := list.List[i] 40 | list.List[i] = list.List[j] 41 | list.List[j] = temp 42 | } 43 | -------------------------------------------------------------------------------- /stringx/stringx.go: -------------------------------------------------------------------------------- 1 | package stringx 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "strconv" 7 | "time" 8 | ) 9 | 10 | var ( 11 | TimeLayout = "2006-01-02 15:04:05" 12 | ByteSeed = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890") 13 | ) 14 | 15 | func ToString(param interface{}) (string, error) { 16 | resp := "" 17 | switch param.(type) { 18 | case int64: 19 | resp = strconv.FormatInt(param.(int64), 10) 20 | case int32: 21 | resp = strconv.FormatInt(param.(int64), 10) 22 | case int: 23 | resp = strconv.Itoa(param.(int)) 24 | case float64: 25 | resp = strconv.FormatFloat(param.(float64), 'f', -1, 64) 26 | case float32: 27 | resp = strconv.FormatFloat(param.(float64), 'f', -1, 64) 28 | case string: 29 | resp = param.(string) 30 | case []byte: 31 | resp = string(param.([]byte)) 32 | case time.Time: 33 | resp = param.(time.Time).Format(TimeLayout) 34 | case *time.Time: 35 | resp = param.(*time.Time).Format(TimeLayout) 36 | default: 37 | return resp, fmt.Errorf("%v is not base type", param) 38 | } 39 | return resp, nil 40 | } 41 | 42 | func SplitN(s string, n int) []string { 43 | len := len(s) 44 | var resp []string 45 | var index, next int 46 | 47 | for len > index { 48 | next += n 49 | if len >= next { 50 | resp = append(resp, s[index:next]) 51 | } else { 52 | resp = append(resp, s[index:len]) 53 | } 54 | index = next 55 | } 56 | return resp 57 | } 58 | 59 | func ToStr(param interface{}) string { 60 | resp := "" 61 | switch param.(type) { 62 | case int64: 63 | resp = strconv.FormatInt(param.(int64), 10) 64 | case int32: 65 | resp = strconv.FormatInt(param.(int64), 10) 66 | 
case int: 67 | resp = strconv.Itoa(param.(int)) 68 | case float64: 69 | resp = strconv.FormatFloat(param.(float64), 'f', -1, 64) 70 | case float32: 71 | resp = strconv.FormatFloat(param.(float64), 'f', -1, 64) 72 | case string: 73 | resp = param.(string) 74 | case []byte: 75 | resp = string(param.([]byte)) 76 | case time.Time: 77 | resp = param.(time.Time).Format(TimeLayout) 78 | case *time.Time: 79 | resp = param.(*time.Time).Format(TimeLayout) 80 | default: 81 | return resp 82 | } 83 | return resp 84 | } 85 | 86 | func RandomString(len int) string { 87 | rand.Seed(time.Now().UnixNano()) 88 | resp := make([]byte, len) 89 | for i := 0; i < len; i++ { 90 | resp[i] = ByteSeed[rand.Intn(62)] 91 | } 92 | return string(resp) 93 | } 94 | -------------------------------------------------------------------------------- /syncx/atomicbool.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import "sync/atomic" 4 | 5 | type AtomicBool uint32 6 | 7 | func NewAtomicBool() *AtomicBool { 8 | return new(AtomicBool) 9 | } 10 | 11 | func ForAtomicBool(val bool) *AtomicBool { 12 | b := NewAtomicBool() 13 | b.Set(val) 14 | return b 15 | } 16 | 17 | func (b *AtomicBool) CompareAndSwap(old, val bool) bool { 18 | var ov, nv uint32 19 | if old { 20 | ov = 1 21 | } 22 | if val { 23 | nv = 1 24 | } 25 | return atomic.CompareAndSwapUint32((*uint32)(b), ov, nv) 26 | } 27 | 28 | func (b *AtomicBool) Set(v bool) { 29 | if v { 30 | atomic.StoreUint32((*uint32)(b), 1) 31 | } else { 32 | atomic.StoreUint32((*uint32)(b), 0) 33 | } 34 | } 35 | 36 | func (b *AtomicBool) True() bool { 37 | return atomic.LoadUint32((*uint32)(b)) == 1 38 | } 39 | -------------------------------------------------------------------------------- /syncx/atomicbool_test.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func 
TestAtomicBool(t *testing.T) { 10 | val := ForAtomicBool(true) 11 | assert.True(t, val.True()) 12 | val.Set(false) 13 | assert.False(t, val.True()) 14 | val.Set(true) 15 | assert.True(t, val.True()) 16 | val.Set(false) 17 | assert.False(t, val.True()) 18 | ok := val.CompareAndSwap(false, true) 19 | assert.True(t, ok) 20 | assert.True(t, val.True()) 21 | ok = val.CompareAndSwap(true, false) 22 | assert.True(t, ok) 23 | assert.False(t, val.True()) 24 | ok = val.CompareAndSwap(true, false) 25 | assert.False(t, ok) 26 | assert.False(t, val.True()) 27 | } 28 | -------------------------------------------------------------------------------- /syncx/atomicduration.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "sync/atomic" 5 | "time" 6 | ) 7 | 8 | type AtomicDuration int64 9 | 10 | func NewAtomicDuration() *AtomicDuration { 11 | return new(AtomicDuration) 12 | } 13 | 14 | func ForAtomicDuration(val time.Duration) *AtomicDuration { 15 | d := NewAtomicDuration() 16 | d.Set(val) 17 | return d 18 | } 19 | 20 | func (d *AtomicDuration) CompareAndSwap(old, val time.Duration) bool { 21 | return atomic.CompareAndSwapInt64((*int64)(d), int64(old), int64(val)) 22 | } 23 | 24 | func (d *AtomicDuration) Load() time.Duration { 25 | return time.Duration(atomic.LoadInt64((*int64)(d))) 26 | } 27 | 28 | func (d *AtomicDuration) Set(val time.Duration) { 29 | atomic.StoreInt64((*int64)(d), int64(val)) 30 | } 31 | -------------------------------------------------------------------------------- /syncx/atomicduration_test.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestAtomicDuration(t *testing.T) { 11 | d := ForAtomicDuration(time.Duration(100)) 12 | assert.Equal(t, time.Duration(100), d.Load()) 13 | d.Set(time.Duration(200)) 14 | assert.Equal(t, 
time.Duration(200), d.Load()) 15 | assert.True(t, d.CompareAndSwap(time.Duration(200), time.Duration(300))) 16 | assert.Equal(t, time.Duration(300), d.Load()) 17 | assert.False(t, d.CompareAndSwap(time.Duration(200), time.Duration(400))) 18 | assert.Equal(t, time.Duration(300), d.Load()) 19 | } 20 | -------------------------------------------------------------------------------- /syncx/atomicfloat64.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "math" 5 | "sync/atomic" 6 | ) 7 | 8 | type AtomicFloat64 uint64 9 | 10 | func NewAtomicFloat64() *AtomicFloat64 { 11 | return new(AtomicFloat64) 12 | } 13 | 14 | func ForAtomicFloat64(val float64) *AtomicFloat64 { 15 | f := NewAtomicFloat64() 16 | f.Set(val) 17 | return f 18 | } 19 | 20 | func (f *AtomicFloat64) Add(val float64) float64 { 21 | for { 22 | old := f.Load() 23 | nv := old + val 24 | if f.CompareAndSwap(old, nv) { 25 | return nv 26 | } 27 | } 28 | } 29 | 30 | func (f *AtomicFloat64) CompareAndSwap(old, val float64) bool { 31 | return atomic.CompareAndSwapUint64((*uint64)(f), math.Float64bits(old), math.Float64bits(val)) 32 | } 33 | 34 | func (f *AtomicFloat64) Load() float64 { 35 | return math.Float64frombits(atomic.LoadUint64((*uint64)(f))) 36 | } 37 | 38 | func (f *AtomicFloat64) Set(val float64) { 39 | atomic.StoreUint64((*uint64)(f), math.Float64bits(val)) 40 | } 41 | -------------------------------------------------------------------------------- /syncx/atomicfloat64_test.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestAtomicFloat64(t *testing.T) { 11 | f := ForAtomicFloat64(100) 12 | var wg sync.WaitGroup 13 | for i := 0; i < 5; i++ { 14 | wg.Add(1) 15 | go func() { 16 | for i := 0; i < 100; i++ { 17 | f.Add(1) 18 | } 19 | wg.Done() 20 | }() 21 | } 22 | wg.Wait() 23 
| assert.Equal(t, float64(600), f.Load()) 24 | } 25 | -------------------------------------------------------------------------------- /syncx/cond.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/weblazy/easy/timex" 7 | 8 | "github.com/weblazy/easy/lang" 9 | ) 10 | 11 | type Cond struct { 12 | signal chan lang.PlaceholderType 13 | } 14 | 15 | func NewCond() *Cond { 16 | return &Cond{ 17 | signal: make(chan lang.PlaceholderType), 18 | } 19 | } 20 | 21 | // WaitWithTimeout wait for signal return remain wait time or timed out 22 | func (cond *Cond) WaitWithTimeout(timeout time.Duration) (time.Duration, bool) { 23 | timer := time.NewTimer(timeout) 24 | defer timer.Stop() 25 | 26 | begin := timex.Now() 27 | select { 28 | case <-cond.signal: 29 | elapsed := timex.Since(begin) 30 | remainTimeout := timeout - elapsed 31 | return remainTimeout, true 32 | case <-timer.C: 33 | return 0, false 34 | } 35 | } 36 | 37 | // Wait for signal 38 | func (cond *Cond) Wait() { 39 | <-cond.signal 40 | } 41 | 42 | // Signal wakes one goroutine waiting on c, if there is any. 
43 | func (cond *Cond) Signal() { 44 | select { 45 | case cond.signal <- lang.Placeholder: 46 | default: 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /syncx/cond_test.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestTimeoutCondWait(t *testing.T) { 12 | var wait sync.WaitGroup 13 | cond := NewCond() 14 | wait.Add(2) 15 | go func() { 16 | cond.Wait() 17 | wait.Done() 18 | }() 19 | time.Sleep(time.Duration(50) * time.Millisecond) 20 | go func() { 21 | cond.Signal() 22 | wait.Done() 23 | }() 24 | wait.Wait() 25 | } 26 | 27 | func TestTimeoutCondWaitTimeout(t *testing.T) { 28 | var wait sync.WaitGroup 29 | cond := NewCond() 30 | wait.Add(1) 31 | go func() { 32 | cond.WaitWithTimeout(time.Duration(500) * time.Millisecond) 33 | wait.Done() 34 | }() 35 | wait.Wait() 36 | } 37 | 38 | func TestTimeoutCondWaitTimeoutRemain(t *testing.T) { 39 | var wait sync.WaitGroup 40 | cond := NewCond() 41 | wait.Add(2) 42 | ch := make(chan time.Duration, 1) 43 | defer close(ch) 44 | timeout := time.Duration(2000) * time.Millisecond 45 | go func() { 46 | remainTimeout, _ := cond.WaitWithTimeout(timeout) 47 | ch <- remainTimeout 48 | wait.Done() 49 | }() 50 | sleep(200) 51 | go func() { 52 | cond.Signal() 53 | wait.Done() 54 | }() 55 | wait.Wait() 56 | remainTimeout := <-ch 57 | assert.True(t, remainTimeout < timeout, "expect remainTimeout %v < %v", remainTimeout, timeout) 58 | assert.True(t, remainTimeout >= time.Duration(200)*time.Millisecond, 59 | "expect remainTimeout %v >= 200 millisecond", remainTimeout) 60 | } 61 | 62 | func TestSignalNoWait(t *testing.T) { 63 | cond := NewCond() 64 | cond.Signal() 65 | } 66 | 67 | func sleep(millisecond int) { 68 | time.Sleep(time.Duration(millisecond) * time.Millisecond) 69 | } 70 | 71 | func currentTimeMillis() int64 {
72 | return time.Now().UnixNano() / int64(time.Millisecond) 73 | } 74 | -------------------------------------------------------------------------------- /syncx/donechan.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/weblazy/easy/lang" 7 | ) 8 | 9 | type DoneChan struct { 10 | done chan lang.PlaceholderType 11 | once sync.Once 12 | } 13 | 14 | func NewDoneChan() *DoneChan { 15 | return &DoneChan{ 16 | done: make(chan lang.PlaceholderType), 17 | } 18 | } 19 | 20 | func (dc *DoneChan) Close() { 21 | dc.once.Do(func() { 22 | close(dc.done) 23 | }) 24 | } 25 | 26 | func (dc *DoneChan) Done() chan lang.PlaceholderType { 27 | return dc.done 28 | } 29 | -------------------------------------------------------------------------------- /syncx/donechan_test.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | ) 7 | 8 | func TestDoneChanClose(t *testing.T) { 9 | doneChan := NewDoneChan() 10 | 11 | for i := 0; i < 5; i++ { 12 | doneChan.Close() 13 | } 14 | } 15 | 16 | func TestDoneChanDone(t *testing.T) { 17 | var waitGroup sync.WaitGroup 18 | doneChan := NewDoneChan() 19 | 20 | waitGroup.Add(1) 21 | go func() { 22 | select { 23 | case <-doneChan.Done(): 24 | waitGroup.Done() 25 | } 26 | }() 27 | 28 | for i := 0; i < 5; i++ { 29 | doneChan.Close() 30 | } 31 | 32 | waitGroup.Wait() 33 | } 34 | -------------------------------------------------------------------------------- /syncx/limit.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/weblazy/easy/lang" 7 | ) 8 | 9 | var ErrReturn = errors.New("discarding limited token, resource pool is full, someone returned multiple times") 10 | 11 | type Limit struct { 12 | pool chan lang.PlaceholderType 13 | } 14 | 15 | func NewLimit(n int)
Limit { 16 | return Limit{ 17 | pool: make(chan lang.PlaceholderType, n), 18 | } 19 | } 20 | 21 | func (l Limit) Borrow() { 22 | l.pool <- lang.Placeholder 23 | } 24 | 25 | // Return returns the borrowed resource, returns error only if returned more than borrowed. 26 | func (l Limit) Return() error { 27 | select { 28 | case <-l.pool: 29 | return nil 30 | default: 31 | return ErrReturn 32 | } 33 | } 34 | 35 | func (l Limit) TryBorrow() bool { 36 | select { 37 | case l.pool <- lang.Placeholder: 38 | return true 39 | default: 40 | return false 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /syncx/limit_test.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestLimit(t *testing.T) { 10 | limit := NewLimit(2) 11 | limit.Borrow() 12 | assert.True(t, limit.TryBorrow()) 13 | assert.False(t, limit.TryBorrow()) 14 | assert.Nil(t, limit.Return()) 15 | assert.Nil(t, limit.Return()) 16 | assert.Equal(t, ErrReturn, limit.Return()) 17 | } 18 | -------------------------------------------------------------------------------- /syncx/lockedcalls.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import "sync" 4 | 5 | type ( 6 | // LockedCalls makes sure the calls with the same key are called sequentially. 7 | // For example, A called F, before it's done, B called F, then B's call would be blocked, 8 | // after A's call finished, B's call got executed. 9 | // The calls with the same key are independent, not sharing the returned values.
10 | // A ------->calls F with key and executes<------->returns 11 | // B ------------------>calls F with key<--------->executes<---->returns 12 | LockedCalls interface { 13 | Do(key string, fn func() (interface{}, error)) (interface{}, error) 14 | } 15 | 16 | lockedGroup struct { 17 | mu sync.Mutex 18 | m map[string]*sync.WaitGroup 19 | } 20 | ) 21 | 22 | func NewLockedCalls() LockedCalls { 23 | return &lockedGroup{ 24 | m: make(map[string]*sync.WaitGroup), 25 | } 26 | } 27 | 28 | func (lg *lockedGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) { 29 | begin: 30 | lg.mu.Lock() 31 | if wg, ok := lg.m[key]; ok { 32 | lg.mu.Unlock() 33 | wg.Wait() 34 | goto begin 35 | } 36 | 37 | return lg.makeCall(key, fn) 38 | } 39 | 40 | func (lg *lockedGroup) makeCall(key string, fn func() (interface{}, error)) (interface{}, error) { 41 | var wg sync.WaitGroup 42 | wg.Add(1) 43 | lg.m[key] = &wg 44 | lg.mu.Unlock() 45 | 46 | defer func() { 47 | // delete key first, done later.
can't reverse the order, because if reverse, 48 | // another Do call might wg.Wait() without getting notified by wg.Done() 49 | lg.mu.Lock() 50 | delete(lg.m, key) 51 | lg.mu.Unlock() 52 | wg.Done() 53 | }() 54 | 55 | return fn() 56 | } 57 | -------------------------------------------------------------------------------- /syncx/lockedcalls_test.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "sync" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestLockedCallDo(t *testing.T) { 12 | g := NewLockedCalls() 13 | v, err := g.Do("key", func() (interface{}, error) { 14 | return "bar", nil 15 | }) 16 | if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { 17 | t.Errorf("Do = %v; want %v", got, want) 18 | } 19 | if err != nil { 20 | t.Errorf("Do error = %v", err) 21 | } 22 | } 23 | 24 | func TestLockedCallDoErr(t *testing.T) { 25 | g := NewLockedCalls() 26 | someErr := errors.New("some error") 27 | v, err := g.Do("key", func() (interface{}, error) { 28 | return nil, someErr 29 | }) 30 | if err != someErr { 31 | t.Errorf("Do error = %v; want someErr", err) 32 | } 33 | if v != nil { 34 | t.Errorf("unexpected non-nil value %#v", v) 35 | } 36 | } 37 | 38 | func TestLockedCallDoDupSuppress(t *testing.T) { 39 | g := NewLockedCalls() 40 | c := make(chan string) 41 | var calls int 42 | fn := func() (interface{}, error) { 43 | calls++ 44 | ret := calls 45 | <-c 46 | calls-- 47 | return ret, nil 48 | } 49 | 50 | const n = 10 51 | var results []int 52 | var lock sync.Mutex 53 | var wg sync.WaitGroup 54 | for i := 0; i < n; i++ { 55 | wg.Add(1) 56 | go func() { 57 | v, err := g.Do("key", fn) 58 | if err != nil { 59 | t.Errorf("Do error: %v", err) 60 | } 61 | 62 | lock.Lock() 63 | results = append(results, v.(int)) 64 | lock.Unlock() 65 | wg.Done() 66 | }() 67 | } 68 | time.Sleep(100 * time.Millisecond) // let goroutines above block 69 | for i := 0; i < n; i++ { 70 | c
<- "bar" 71 | } 72 | wg.Wait() 73 | 74 | lock.Lock() 75 | defer lock.Unlock() 76 | 77 | for _, item := range results { 78 | if item != 1 { 79 | t.Errorf("number of calls = %d; want 1", item) 80 | } 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /syncx/managedresource.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import "sync" 4 | 5 | type ManagedResource struct { 6 | resource interface{} 7 | lock sync.RWMutex 8 | generate func() interface{} 9 | equals func(a, b interface{}) bool 10 | } 11 | 12 | func NewManagedResource(generate func() interface{}, equals func(a, b interface{}) bool) *ManagedResource { 13 | return &ManagedResource{ 14 | generate: generate, 15 | equals: equals, 16 | } 17 | } 18 | 19 | func (mr *ManagedResource) MarkBroken(resource interface{}) { 20 | mr.lock.Lock() 21 | defer mr.lock.Unlock() 22 | 23 | if mr.equals(mr.resource, resource) { 24 | mr.resource = nil 25 | } 26 | } 27 | 28 | func (mr *ManagedResource) Take() interface{} { 29 | mr.lock.RLock() 30 | resource := mr.resource 31 | mr.lock.RUnlock() 32 | 33 | if resource != nil { 34 | return resource 35 | } 36 | 37 | mr.lock.Lock() 38 | defer mr.lock.Unlock() 39 | // maybe another Take() call already generated the resource.
40 | if mr.resource == nil { 41 | mr.resource = mr.generate() 42 | } 43 | return mr.resource 44 | } 45 | -------------------------------------------------------------------------------- /syncx/managedresource_test.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "sync/atomic" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestManagedResource(t *testing.T) { 11 | var count int32 12 | resource := NewManagedResource(func() interface{} { 13 | return atomic.AddInt32(&count, 1) 14 | }, func(a, b interface{}) bool { 15 | return a == b 16 | }) 17 | 18 | assert.Equal(t, resource.Take(), resource.Take()) 19 | old := resource.Take() 20 | resource.MarkBroken(old) 21 | assert.NotEqual(t, old, resource.Take()) 22 | } 23 | -------------------------------------------------------------------------------- /syncx/once.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import "sync" 4 | 5 | func Once(fn func()) func() { 6 | once := new(sync.Once) 7 | return func() { 8 | once.Do(fn) 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /syncx/once_test.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestOnce(t *testing.T) { 10 | var v int 11 | add := Once(func() { 12 | v++ 13 | }) 14 | 15 | for i := 0; i < 5; i++ { 16 | add() 17 | } 18 | 19 | assert.Equal(t, 1, v) 20 | } 21 | -------------------------------------------------------------------------------- /syncx/onceguard.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import "sync/atomic" 4 | 5 | type OnceGuard struct { 6 | done uint32 7 | } 8 | 9 | func (og *OnceGuard) Taken() bool { 10 | return
atomic.LoadUint32(&og.done) == 1 11 | } 12 | 13 | func (og *OnceGuard) Take() bool { 14 | return atomic.CompareAndSwapUint32(&og.done, 0, 1) 15 | } 16 | -------------------------------------------------------------------------------- /syncx/onceguard_test.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestOnceGuard(t *testing.T) { 10 | var guard OnceGuard 11 | 12 | assert.False(t, guard.Taken()) 13 | assert.True(t, guard.Take()) 14 | assert.True(t, guard.Taken()) 15 | assert.False(t, guard.Take()) 16 | assert.True(t, guard.Taken()) 17 | } 18 | -------------------------------------------------------------------------------- /syncx/pool.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | 7 | "github.com/weblazy/easy/timex" 8 | ) 9 | 10 | type ( 11 | PoolOption func(*Pool) 12 | 13 | node struct { 14 | item interface{} 15 | next *node 16 | lastUsed time.Duration 17 | } 18 | 19 | Pool struct { 20 | limit int 21 | created int 22 | maxAge time.Duration 23 | lock sync.Locker 24 | cond *sync.Cond 25 | head *node 26 | create func() interface{} 27 | destroy func(interface{}) 28 | } 29 | ) 30 | 31 | func NewPool(n int, create func() interface{}, destroy func(interface{}), opts ...PoolOption) *Pool { 32 | if n <= 0 { 33 | panic("pool size can't be negative or zero") 34 | } 35 | 36 | lock := new(sync.Mutex) 37 | pool := &Pool{ 38 | limit: n, 39 | lock: lock, 40 | cond: sync.NewCond(lock), 41 | create: create, 42 | destroy: destroy, 43 | } 44 | 45 | for _, opt := range opts { 46 | opt(pool) 47 | } 48 | 49 | return pool 50 | } 51 | 52 | func (p *Pool) Get() interface{} { 53 | p.lock.Lock() 54 | defer p.lock.Unlock() 55 | 56 | for { 57 | if p.head != nil { 58 | head := p.head 59 | p.head = head.next 60 | if p.maxAge > 0 &&
head.lastUsed+p.maxAge < timex.Now() { 61 | p.created-- 62 | p.destroy(head.item) 63 | continue 64 | } else { 65 | return head.item 66 | } 67 | } 68 | 69 | if p.created < p.limit { 70 | p.created++ 71 | return p.create() 72 | } 73 | 74 | p.cond.Wait() 75 | } 76 | } 77 | 78 | func (p *Pool) Put(x interface{}) { 79 | if x == nil { 80 | return 81 | } 82 | 83 | p.lock.Lock() 84 | defer p.lock.Unlock() 85 | 86 | p.head = &node{ 87 | item: x, 88 | next: p.head, 89 | lastUsed: timex.Now(), 90 | } 91 | p.cond.Signal() 92 | } 93 | 94 | func WithMaxAge(duration time.Duration) PoolOption { 95 | return func(pool *Pool) { 96 | pool.maxAge = duration 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /syncx/pool_test.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "sync" 5 | "sync/atomic" 6 | "testing" 7 | "time" 8 | 9 | "github.com/weblazy/easy/lang" 10 | 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | const limit = 10 15 | 16 | func TestPoolGet(t *testing.T) { 17 | stack := NewPool(limit, create, destroy) 18 | ch := make(chan lang.PlaceholderType) 19 | 20 | for i := 0; i < limit; i++ { 21 | go func() { 22 | v := stack.Get() 23 | if v.(int) != 1 { 24 | t.Fatal("unmatch value") 25 | } 26 | ch <- lang.Placeholder 27 | }() 28 | 29 | select { 30 | case <-ch: 31 | case <-time.After(time.Second): 32 | t.Fail() 33 | } 34 | } 35 | } 36 | 37 | func TestPoolPopTooMany(t *testing.T) { 38 | stack := NewPool(limit, create, destroy) 39 | ch := make(chan lang.PlaceholderType, 1) 40 | 41 | for i := 0; i < limit; i++ { 42 | var wait sync.WaitGroup 43 | wait.Add(1) 44 | go func() { 45 | stack.Get() 46 | ch <- lang.Placeholder 47 | wait.Done() 48 | }() 49 | 50 | wait.Wait() 51 | select { 52 | case <-ch: 53 | default: 54 | t.Fail() 55 | } 56 | } 57 | 58 | var waitGroup, pushWait sync.WaitGroup 59 | waitGroup.Add(1) 60 | pushWait.Add(1) 61 | go func() { 62 |
pushWait.Done() 63 | stack.Get() 64 | waitGroup.Done() 65 | }() 66 | 67 | pushWait.Wait() 68 | stack.Put(1) 69 | waitGroup.Wait() 70 | } 71 | 72 | func TestPoolPopFirst(t *testing.T) { 73 | var value int32 74 | stack := NewPool(limit, func() interface{} { 75 | return atomic.AddInt32(&value, 1) 76 | }, destroy) 77 | 78 | for i := 0; i < 100; i++ { 79 | v := stack.Get().(int32) 80 | assert.Equal(t, 1, int(v)) 81 | stack.Put(v) 82 | } 83 | } 84 | 85 | func TestPoolWithMaxAge(t *testing.T) { 86 | var value int32 87 | stack := NewPool(limit, func() interface{} { 88 | return atomic.AddInt32(&value, 1) 89 | }, destroy, WithMaxAge(time.Millisecond)) 90 | 91 | v1 := stack.Get().(int32) 92 | // put nil should not matter 93 | stack.Put(nil) 94 | stack.Put(v1) 95 | time.Sleep(time.Millisecond * 10) 96 | v2 := stack.Get().(int32) 97 | assert.NotEqual(t, v1, v2) 98 | } 99 | 100 | func TestNewPoolPanics(t *testing.T) { 101 | assert.Panics(t, func() { 102 | NewPool(0, create, destroy) 103 | }) 104 | } 105 | 106 | func create() interface{} { 107 | return 1 108 | } 109 | 110 | func destroy(_ interface{}) { 111 | } 112 | -------------------------------------------------------------------------------- /syncx/refresource.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | ) 7 | 8 | var ErrUseOfCleaned = errors.New("using a cleaned resource") 9 | 10 | type RefResource struct { 11 | lock sync.Mutex 12 | ref int32 13 | cleaned bool 14 | clean func() 15 | } 16 | 17 | func NewRefResource(clean func()) *RefResource { 18 | return &RefResource{ 19 | clean: clean, 20 | } 21 | } 22 | 23 | func (r *RefResource) Use() error { 24 | r.lock.Lock() 25 | defer r.lock.Unlock() 26 | 27 | if r.cleaned { 28 | return ErrUseOfCleaned 29 | } 30 | 31 | r.ref++ 32 | return nil 33 | } 34 | 35 | func (r *RefResource) Clean() { 36 | r.lock.Lock() 37 | defer r.lock.Unlock() 38 | 39 | if r.cleaned { 40 | return 41 | } 42
| 43 | r.ref-- 44 | if r.ref == 0 { 45 | r.cleaned = true 46 | r.clean() 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /syncx/refresource_test.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestRefCleaner(t *testing.T) { 10 | var count int 11 | clean := func() { 12 | count += 1 13 | } 14 | 15 | cleaner := NewRefResource(clean) 16 | err := cleaner.Use() 17 | assert.Nil(t, err) 18 | err = cleaner.Use() 19 | assert.Nil(t, err) 20 | cleaner.Clean() 21 | cleaner.Clean() 22 | assert.Equal(t, 1, count) 23 | cleaner.Clean() 24 | cleaner.Clean() 25 | assert.Equal(t, 1, count) 26 | assert.Equal(t, ErrUseOfCleaned, cleaner.Use()) 27 | } 28 | -------------------------------------------------------------------------------- /syncx/resourcemanager.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "io" 5 | "sync" 6 | 7 | errorx "github.com/weblazy/easy/eerror" 8 | ) 9 | 10 | type ResourceManager struct { 11 | resources map[string]io.Closer 12 | sharedCalls SharedCalls 13 | lock sync.RWMutex 14 | } 15 | 16 | func NewResourceManager() *ResourceManager { 17 | return &ResourceManager{ 18 | resources: make(map[string]io.Closer), 19 | sharedCalls: NewSharedCalls(), 20 | } 21 | } 22 | 23 | func (manager *ResourceManager) Close() error { 24 | manager.lock.Lock() 25 | defer manager.lock.Unlock() 26 | 27 | var be errorx.BatchError 28 | for _, resource := range manager.resources { 29 | if err := resource.Close(); err != nil { 30 | be = append(be, err) 31 | } 32 | } 33 | 34 | // returning an empty BatchError as-is would yield a non-nil error interface holding a nil slice 35 | if len(be) == 0 { 36 | return nil 37 | } 38 | return be 39 | } 40 | 41 | func (manager *ResourceManager) GetResource(key string, create func() (io.Closer, error)) (io.Closer, error) { 42 | val, err := manager.sharedCalls.Do(key, func() (interface{}, error) { 43 | manager.lock.RLock() 44 |
resource, ok := manager.resources[key] 45 | manager.lock.RUnlock() 46 | if ok { 47 | return resource, nil 48 | } 49 | 50 | resource, err := create() 51 | if err != nil { 52 | return nil, err 53 | } 54 | 55 | manager.lock.Lock() 56 | manager.resources[key] = resource 57 | manager.lock.Unlock() 58 | 59 | return resource, nil 60 | }) 61 | if err != nil { 62 | return nil, err 63 | } 64 | 65 | return val.(io.Closer), nil 66 | } 67 | -------------------------------------------------------------------------------- /syncx/resourcemanager_test.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type dummyResource struct { 12 | age int 13 | } 14 | 15 | func (dr *dummyResource) Close() error { 16 | return errors.New("close") 17 | } 18 | 19 | func TestResourceManager_GetResource(t *testing.T) { 20 | manager := NewResourceManager() 21 | defer manager.Close() 22 | 23 | var age int 24 | for i := 0; i < 10; i++ { 25 | val, err := manager.GetResource("key", func() (io.Closer, error) { 26 | age++ 27 | return &dummyResource{ 28 | age: age, 29 | }, nil 30 | }) 31 | assert.Nil(t, err) 32 | assert.Equal(t, 1, val.(*dummyResource).age) 33 | } 34 | } 35 | 36 | func TestResourceManager_GetResourceError(t *testing.T) { 37 | manager := NewResourceManager() 38 | defer manager.Close() 39 | 40 | for i := 0; i < 10; i++ { 41 | _, err := manager.GetResource("key", func() (io.Closer, error) { 42 | return nil, errors.New("fail") 43 | }) 44 | assert.NotNil(t, err) 45 | } 46 | } 47 | 48 | func TestResourceManager_CloseNoResources(t *testing.T) { 49 | assert.NoError(t, NewResourceManager().Close()) 50 | } 51 | -------------------------------------------------------------------------------- /syncx/sharedcalls.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import "sync" 4 | 5 | type ( 6 | // SharedCalls lets concurrent calls with the same key share the call result.
7 | // For example, A called F, before it's done, B called F. Then B would not execute F, 8 | // and would share the result returned by F as called by A. 9 | // The calls with the same key are dependent, concurrent calls share the returned values. 10 | // A ------->calls F with key<------------------->returns val 11 | // B --------------------->calls F with key------>returns val 12 | SharedCalls interface { 13 | Do(key string, fn func() (interface{}, error)) (interface{}, error) 14 | DoEx(key string, fn func() (interface{}, error)) (interface{}, bool, error) 15 | } 16 | 17 | call struct { 18 | wg sync.WaitGroup 19 | val interface{} 20 | err error 21 | } 22 | 23 | sharedGroup struct { 24 | mu sync.Mutex 25 | m map[string]*call 26 | } 27 | ) 28 | 29 | func NewSharedCalls() SharedCalls { 30 | return &sharedGroup{ 31 | m: make(map[string]*call), 32 | } 33 | } 34 | 35 | func (g *sharedGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) { 36 | g.mu.Lock() 37 | if c, ok := g.m[key]; ok { 38 | g.mu.Unlock() 39 | c.wg.Wait() 40 | return c.val, c.err 41 | } 42 | 43 | c := g.makeCall(key, fn) 44 | return c.val, c.err 45 | } 46 | 47 | func (g *sharedGroup) DoEx(key string, fn func() (interface{}, error)) (val interface{}, fresh bool, err error) { 48 | g.mu.Lock() 49 | if c, ok := g.m[key]; ok { 50 | g.mu.Unlock() 51 | c.wg.Wait() 52 | return c.val, false, c.err 53 | } 54 | 55 | c := g.makeCall(key, fn) 56 | return c.val, true, c.err 57 | } 58 | 59 | func (g *sharedGroup) makeCall(key string, fn func() (interface{}, error)) *call { 60 | c := new(call) 61 | c.wg.Add(1) 62 | g.m[key] = c 63 | g.mu.Unlock() 64 | 65 | defer func() { 66 | // delete key first, done later.
can't reverse the order, because if reverse, 67 | // another Do call might wg.Wait() without getting notified by wg.Done() 68 | g.mu.Lock() 69 | delete(g.m, key) 70 | g.mu.Unlock() 71 | c.wg.Done() 72 | }() 73 | 74 | c.val, c.err = fn() 75 | return c 76 | } 77 | -------------------------------------------------------------------------------- /syncx/sharedcalls_test.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "sync" 7 | "sync/atomic" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestExclusiveCallDo(t *testing.T) { 13 | g := NewSharedCalls() 14 | v, err := g.Do("key", func() (interface{}, error) { 15 | return "bar", nil 16 | }) 17 | if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { 18 | t.Errorf("Do = %v; want %v", got, want) 19 | } 20 | if err != nil { 21 | t.Errorf("Do error = %v", err) 22 | } 23 | } 24 | 25 | func TestExclusiveCallDoErr(t *testing.T) { 26 | g := NewSharedCalls() 27 | someErr := errors.New("some error") 28 | v, err := g.Do("key", func() (interface{}, error) { 29 | return nil, someErr 30 | }) 31 | if err != someErr { 32 | t.Errorf("Do error = %v; want someErr", err) 33 | } 34 | if v != nil { 35 | t.Errorf("unexpected non-nil value %#v", v) 36 | } 37 | } 38 | 39 | func TestExclusiveCallDoDupSuppress(t *testing.T) { 40 | g := NewSharedCalls() 41 | c := make(chan string) 42 | var calls int32 43 | fn := func() (interface{}, error) { 44 | atomic.AddInt32(&calls, 1) 45 | return <-c, nil 46 | } 47 | 48 | const n = 10 49 | var wg sync.WaitGroup 50 | for i := 0; i < n; i++ { 51 | wg.Add(1) 52 | go func() { 53 | v, err := g.Do("key", fn) 54 | if err != nil { 55 | t.Errorf("Do error: %v", err) 56 | } 57 | if v.(string) != "bar" { 58 | t.Errorf("got %q; want %q", v, "bar") 59 | } 60 | wg.Done() 61 | }() 62 | } 63 | time.Sleep(100 * time.Millisecond) // let goroutines above block 64 | c <- "bar" 65 | wg.Wait() 66 | if got :=
atomic.LoadInt32(&calls); got != 1 { 67 | t.Errorf("number of calls = %d; want 1", got) 68 | } 69 | } 70 | 71 | func TestExclusiveCallDoExDupSuppress(t *testing.T) { 72 | g := NewSharedCalls() 73 | c := make(chan string) 74 | var calls int32 75 | fn := func() (interface{}, error) { 76 | atomic.AddInt32(&calls, 1) 77 | return <-c, nil 78 | } 79 | 80 | const n = 10 81 | var wg sync.WaitGroup 82 | var freshes int32 83 | for i := 0; i < n; i++ { 84 | wg.Add(1) 85 | go func() { 86 | v, fresh, err := g.DoEx("key", fn) 87 | if err != nil { 88 | t.Errorf("Do error: %v", err) 89 | } 90 | if fresh { 91 | atomic.AddInt32(&freshes, 1) 92 | } 93 | if v.(string) != "bar" { 94 | t.Errorf("got %q; want %q", v, "bar") 95 | } 96 | wg.Done() 97 | }() 98 | } 99 | time.Sleep(100 * time.Millisecond) // let goroutines above block 100 | c <- "bar" 101 | wg.Wait() 102 | if got := atomic.LoadInt32(&calls); got != 1 { 103 | t.Errorf("number of calls = %d; want 1", got) 104 | } 105 | if got := atomic.LoadInt32(&freshes); got != 1 { 106 | t.Errorf("freshes = %d; want 1", got) 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /syncx/spinlock.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "runtime" 5 | "sync/atomic" 6 | ) 7 | 8 | type SpinLock struct { 9 | lock uint32 10 | } 11 | 12 | func (sl *SpinLock) Lock() { 13 | for !sl.TryLock() { 14 | runtime.Gosched() 15 | } 16 | } 17 | 18 | func (sl *SpinLock) TryLock() bool { 19 | return atomic.CompareAndSwapUint32(&sl.lock, 0, 1) 20 | } 21 | 22 | func (sl *SpinLock) Unlock() { 23 | atomic.StoreUint32(&sl.lock, 0) 24 | } 25 | -------------------------------------------------------------------------------- /syncx/spinlock_test.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/assert"
9 | ) 10 | 11 | func TestTryLock(t *testing.T) { 12 | var lock SpinLock 13 | assert.True(t, lock.TryLock()) 14 | assert.False(t, lock.TryLock()) 15 | lock.Unlock() 16 | assert.True(t, lock.TryLock()) 17 | } 18 | 19 | func TestSpinLock(t *testing.T) { 20 | var lock SpinLock 21 | lock.Lock() 22 | assert.False(t, lock.TryLock()) 23 | lock.Unlock() 24 | assert.True(t, lock.TryLock()) 25 | } 26 | 27 | func TestSpinLockRace(t *testing.T) { 28 | var lock SpinLock 29 | lock.Lock() 30 | var wait sync.WaitGroup 31 | wait.Add(1) 32 | go func() { 33 | lock.Lock() 34 | lock.Unlock() 35 | wait.Done() 36 | }() 37 | time.Sleep(time.Millisecond * 100) 38 | lock.Unlock() 39 | wait.Wait() 40 | assert.True(t, lock.TryLock()) 41 | } 42 | -------------------------------------------------------------------------------- /syncx/timeoutlimit.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "errors" 5 | "time" 6 | ) 7 | 8 | var ErrTimeout = errors.New("borrow timeout") 9 | 10 | type TimeoutLimit struct { 11 | limit Limit 12 | cond *Cond 13 | } 14 | 15 | func NewTimeoutLimit(n int) TimeoutLimit { 16 | return TimeoutLimit{ 17 | limit: NewLimit(n), 18 | cond: NewCond(), 19 | } 20 | } 21 | 22 | func (l TimeoutLimit) Borrow(timeout time.Duration) error { 23 | if l.TryBorrow() { 24 | return nil 25 | } 26 | 27 | var ok bool 28 | for { 29 | timeout, ok = l.cond.WaitWithTimeout(timeout) 30 | if ok && l.TryBorrow() { 31 | return nil 32 | } 33 | 34 | if timeout <= 0 { 35 | return ErrTimeout 36 | } 37 | } 38 | } 39 | 40 | func (l TimeoutLimit) Return() error { 41 | if err := l.limit.Return(); err != nil { 42 | return err 43 | } 44 | 45 | l.cond.Signal() 46 | return nil 47 | } 48 | 49 | func (l TimeoutLimit) TryBorrow() bool { 50 | return l.limit.TryBorrow() 51 | } 52 | -------------------------------------------------------------------------------- /syncx/timeoutlimit_test.go:
-------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestTimeoutLimit(t *testing.T) { 12 | limit := NewTimeoutLimit(2) 13 | assert.Nil(t, limit.Borrow(time.Millisecond*200)) 14 | assert.Nil(t, limit.Borrow(time.Millisecond*200)) 15 | var wait1, wait2, wait3 sync.WaitGroup 16 | wait1.Add(1) 17 | wait2.Add(1) 18 | wait3.Add(1) 19 | go func() { 20 | wait1.Wait() 21 | wait2.Done() 22 | assert.Nil(t, limit.Return()) 23 | wait3.Done() 24 | }() 25 | wait1.Done() 26 | wait2.Wait() 27 | assert.Nil(t, limit.Borrow(time.Second)) 28 | wait3.Wait() 29 | assert.Equal(t, ErrTimeout, limit.Borrow(time.Millisecond*100)) 30 | assert.Nil(t, limit.Return()) 31 | assert.Nil(t, limit.Return()) 32 | assert.Equal(t, ErrReturn, limit.Return()) 33 | } 34 | -------------------------------------------------------------------------------- /system/shutdown+polyfill.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | // +build windows 3 | 4 | package system 5 | 6 | import "time" 7 | 8 | func AddShutdownListener(fn func()) func() { 9 | return func() {} 10 | } 11 | 12 | func AddWrapUpListener(fn func()) func() { 13 | return func() {} 14 | } 15 | 16 | func SetTimeoutToForceQuit(duration time.Duration) { 17 | } 18 | -------------------------------------------------------------------------------- /system/shutdown.go: -------------------------------------------------------------------------------- 1 | //go:build linux || darwin 2 | // +build linux darwin 3 | 4 | package system 5 | 6 | import ( 7 | "log" 8 | "os" 9 | "os/signal" 10 | "sync" 11 | "syscall" 12 | "time" 13 | ) 14 | 15 | const ( 16 | wrapUpTime = time.Second 17 | // why we use 5500 milliseconds is because most of our queue are blocking mode with 5 seconds 18 | waitTime = 5500 * time.Millisecond 19 | ) 20 | 21 | var (
22 | wrapUpListeners = new(listenerManager) 23 | shutdownListeners = new(listenerManager) 24 | delayTimeBeforeForceQuit = waitTime 25 | ) 26 | 27 | func AddShutdownListener(fn func()) (waitForCalled func()) { 28 | return shutdownListeners.addListener(fn) 29 | } 30 | 31 | func AddWrapUpListener(fn func()) (waitForCalled func()) { 32 | return wrapUpListeners.addListener(fn) 33 | } 34 | 35 | func SetTimeoutToForceQuit(duration time.Duration) { 36 | delayTimeBeforeForceQuit = duration 37 | } 38 | 39 | func gracefulStop(signals chan os.Signal) { 40 | signal.Stop(signals) 41 | 42 | log.Println("Got signal SIGTERM, shutting down...") 43 | wrapUpListeners.notifyListeners() 44 | 45 | time.Sleep(wrapUpTime) 46 | shutdownListeners.notifyListeners() 47 | 48 | time.Sleep(delayTimeBeforeForceQuit - wrapUpTime) 49 | log.Printf("Still alive after %v, going to force kill the process...\n", delayTimeBeforeForceQuit) 50 | syscall.Kill(syscall.Getpid(), syscall.SIGTERM) 51 | } 52 | 53 | type listenerManager struct { 54 | lock sync.Mutex 55 | waitGroup sync.WaitGroup 56 | listeners []func() 57 | } 58 | 59 | func (lm *listenerManager) addListener(fn func()) (waitForCalled func()) { 60 | lm.waitGroup.Add(1) 61 | 62 | lm.lock.Lock() 63 | lm.listeners = append(lm.listeners, func() { 64 | defer lm.waitGroup.Done() 65 | fn() 66 | }) 67 | lm.lock.Unlock() 68 | 69 | return func() { 70 | lm.waitGroup.Wait() 71 | } 72 | } 73 | 74 | func (lm *listenerManager) notifyListeners() { 75 | lm.lock.Lock() 76 | defer lm.lock.Unlock() 77 | 78 | for _, listener := range lm.listeners { 79 | listener() 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /system/signals.go: -------------------------------------------------------------------------------- 1 | //go:build linux || darwin 2 | // +build linux darwin 3 | 4 | package system 5 | 6 | import ( 7 | "log" 8 | "os" 9 | "os/signal" 10 | "syscall" 11 | ) 12 | 13 | func init() { 14 | go func() { 15 | 16 | //
https://golang.org/pkg/os/signal/#Notify 17 | signals := make(chan os.Signal, 1) 18 | signal.Notify(signals, syscall.SIGUSR1, syscall.SIGUSR2, syscall.SIGTERM) 19 | 20 | for { 21 | v := <-signals 22 | switch v { 23 | case syscall.SIGUSR1: 24 | log.Println("syscall.SIGUSR1") 25 | case syscall.SIGUSR2: 26 | log.Println("syscall.SIGUSR2") 27 | case syscall.SIGTERM: 28 | gracefulStop(signals) 29 | default: 30 | log.Println("Got unregistered signal:", v) 31 | } 32 | } 33 | }() 34 | } 35 | -------------------------------------------------------------------------------- /threading/rescue.go: -------------------------------------------------------------------------------- 1 | package threading 2 | 3 | import "github.com/weblazy/easy/logx" 4 | 5 | func Rescue(cleanups ...func()) { 6 | for _, cleanup := range cleanups { 7 | cleanup() 8 | } 9 | 10 | if p := recover(); p != nil { 11 | logx.Stack(p) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /threading/rescue_test.go: -------------------------------------------------------------------------------- 1 | package threading 2 | 3 | import ( 4 | "sync/atomic" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestRescue(t *testing.T) { 11 | var count int32 12 | assert.NotPanics(t, func() { 13 | defer Rescue(func() { 14 | atomic.AddInt32(&count, 2) 15 | }, func() { 16 | atomic.AddInt32(&count, 3) 17 | }) 18 | 19 | panic("hello") 20 | }) 21 | assert.Equal(t, int32(5), atomic.LoadInt32(&count)) 22 | } 23 | -------------------------------------------------------------------------------- /threading/routinegroup.go: -------------------------------------------------------------------------------- 1 | package threading 2 | 3 | import "sync" 4 | 5 | type RoutineGroup struct { 6 | waitGroup sync.WaitGroup 7 | } 8 | 9 | func NewRoutineGroup() *RoutineGroup { 10 | return new(RoutineGroup) 11 | } 12 | 13 | // Don't reference the variables from outside, 14 |
// because outside variables can be changed by other goroutines 15 | func (g *RoutineGroup) Run(fn func()) { 16 | g.waitGroup.Add(1) 17 | 18 | go func() { 19 | defer g.waitGroup.Done() 20 | fn() 21 | }() 22 | } 23 | 24 | // Don't reference the variables from outside, 25 | // because outside variables can be changed by other goroutines 26 | func (g *RoutineGroup) RunSafe(fn func()) { 27 | g.waitGroup.Add(1) 28 | 29 | GoSafe(func() { 30 | defer g.waitGroup.Done() 31 | fn() 32 | }) 33 | } 34 | 35 | func (g *RoutineGroup) Wait() { 36 | g.waitGroup.Wait() 37 | } 38 | -------------------------------------------------------------------------------- /threading/routinegroup_test.go: -------------------------------------------------------------------------------- 1 | package threading 2 | 3 | import ( 4 | "io/ioutil" 5 | "log" 6 | "sync" 7 | "sync/atomic" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestRoutineGroupRun(t *testing.T) { 14 | var count int32 15 | group := NewRoutineGroup() 16 | for i := 0; i < 3; i++ { 17 | group.Run(func() { 18 | atomic.AddInt32(&count, 1) 19 | }) 20 | } 21 | 22 | group.Wait() 23 | 24 | assert.Equal(t, int32(3), count) 25 | } 26 | 27 | func TestRoutingGroupRunSafe(t *testing.T) { 28 | log.SetOutput(ioutil.Discard) 29 | 30 | var count int32 31 | group := NewRoutineGroup() 32 | var once sync.Once 33 | for i := 0; i < 3; i++ { 34 | group.RunSafe(func() { 35 | once.Do(func() { 36 | panic("") 37 | }) 38 | atomic.AddInt32(&count, 1) 39 | }) 40 | } 41 | 42 | group.Wait() 43 | 44 | assert.Equal(t, int32(2), count) 45 | } 46 | -------------------------------------------------------------------------------- /threading/routines.go: -------------------------------------------------------------------------------- 1 | package threading 2 | 3 | import ( 4 | "bytes" 5 | "runtime" 6 | "strconv" 7 | ) 8 | 9 | func GoSafe(fn func()) { 10 | go RunSafe(fn) 11 | } 12 | 13 | // Only for debug, never use it in production 14 | func
RoutineId() uint64 { 15 | b := make([]byte, 64) 16 | b = b[:runtime.Stack(b, false)] 17 | b = bytes.TrimPrefix(b, []byte("goroutine ")) 18 | b = b[:bytes.IndexByte(b, ' ')] 19 | // if error, just return 0 20 | n, _ := strconv.ParseUint(string(b), 10, 64) 21 | 22 | return n 23 | } 24 | 25 | func RunSafe(fn func()) { 26 | defer Rescue() 27 | 28 | fn() 29 | } 30 | -------------------------------------------------------------------------------- /threading/routines_test.go: -------------------------------------------------------------------------------- 1 | package threading 2 | 3 | import ( 4 | "io/ioutil" 5 | "log" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestRoutineId(t *testing.T) { 12 | assert.True(t, RoutineId() > 0) 13 | } 14 | 15 | func TestRunSafe(t *testing.T) { 16 | log.SetOutput(ioutil.Discard) 17 | 18 | i := 0 19 | 20 | defer func() { 21 | assert.Equal(t, 1, i) 22 | }() 23 | 24 | ch := make(chan struct{}) 25 | go RunSafe(func() { 26 | defer func() { 27 | ch <- struct{}{} 28 | }() 29 | 30 | panic("panic") 31 | }) 32 | 33 | <-ch 34 | i++ 35 | } 36 | -------------------------------------------------------------------------------- /threading/taskrunner.go: -------------------------------------------------------------------------------- 1 | package threading 2 | 3 | import () 4 | 5 | type TaskRunner struct { 6 | limitChan chan struct{} 7 | } 8 | 9 | func NewTaskRunner(concurrency int) *TaskRunner { 10 | return &TaskRunner{ 11 | limitChan: make(chan struct{}, concurrency), 12 | } 13 | } 14 | 15 | func (rp *TaskRunner) Schedule(task func()) { 16 | rp.limitChan <- struct{}{} 17 | 18 | go func() { 19 | defer Rescue(func() { 20 | <-rp.limitChan 21 | }) 22 | 23 | task() 24 | }() 25 | } 26 | -------------------------------------------------------------------------------- /threading/taskrunner_test.go: -------------------------------------------------------------------------------- 1 | package threading 2 | 3 | import ( 4 |
"runtime" 5 | "sync" 6 | "sync/atomic" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestRoutinePool(t *testing.T) { 13 | times := 100 14 | pool := NewTaskRunner(runtime.NumCPU()) 15 | 16 | var counter int32 17 | var waitGroup sync.WaitGroup 18 | for i := 0; i < times; i++ { 19 | waitGroup.Add(1) 20 | pool.Schedule(func() { 21 | atomic.AddInt32(&counter, 1) 22 | waitGroup.Done() 23 | }) 24 | } 25 | 26 | waitGroup.Wait() 27 | 28 | assert.Equal(t, times, int(counter)) 29 | } 30 | 31 | func BenchmarkRoutinePool(b *testing.B) { 32 | queue := NewTaskRunner(runtime.NumCPU()) 33 | for i := 0; i < b.N; i++ { 34 | queue.Schedule(func() { 35 | }) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /threading/timeout.go: -------------------------------------------------------------------------------- 1 | package threading 2 | 3 | import ( 4 | "context" 5 | "time" 6 | ) 7 | 8 | var ( 9 | ErrCanceled = context.Canceled 10 | ErrTimeout = context.DeadlineExceeded 11 | ) 12 | 13 | type FxOption func() context.Context 14 | 15 | func DoWithTimeout(fn func() error, timeout time.Duration, opts ...FxOption) error { 16 | parentCtx := context.Background() 17 | for _, opt := range opts { 18 | parentCtx = opt() 19 | } 20 | ctx, cancel := context.WithTimeout(parentCtx, timeout) 21 | defer cancel() 22 | 23 | done := make(chan error, 1) // buffered so the worker can still send and exit after a timeout 24 | panicChan := make(chan interface{}, 1) 25 | go func() { 26 | defer func() { 27 | if p := recover(); p != nil { 28 | panicChan <- p 29 | } 30 | }() 31 | done <- fn() 32 | close(done) 33 | }() 34 | 35 | select { 36 | case p := <-panicChan: 37 | panic(p) 38 | case err := <-done: 39 | return err 40 | case <-ctx.Done(): 41 | return ctx.Err() 42 | } 43 | } 44 | 45 | func WithContext(ctx context.Context) FxOption { 46 | return func() context.Context { 47 | return ctx 48 | } 49 | } 50 | --------------------------------------------------------------------------------
/threading/timeout_test.go: -------------------------------------------------------------------------------- 1 | package threading 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestWithPanic(t *testing.T) { 12 | assert.Panics(t, func() { 13 | _ = DoWithTimeout(func() error { 14 | panic("hello") 15 | }, time.Millisecond*50) 16 | }) 17 | } 18 | 19 | func TestWithTimeout(t *testing.T) { 20 | assert.Equal(t, ErrTimeout, DoWithTimeout(func() error { 21 | time.Sleep(time.Millisecond * 50) 22 | return nil 23 | }, time.Millisecond)) 24 | } 25 | 26 | func TestWithoutTimeout(t *testing.T) { 27 | assert.Nil(t, DoWithTimeout(func() error { 28 | return nil 29 | }, time.Millisecond*50)) 30 | } 31 | 32 | func TestWithCancel(t *testing.T) { 33 | ctx, cancel := context.WithCancel(context.Background()) 34 | go func() { 35 | time.Sleep(time.Millisecond * 10) 36 | cancel() 37 | }() 38 | err := DoWithTimeout(func() error { 39 | time.Sleep(time.Minute) 40 | return nil 41 | }, time.Second, WithContext(ctx)) 42 | assert.Equal(t, ErrCanceled, err) 43 | } 44 | -------------------------------------------------------------------------------- /timex/relativetime.go: -------------------------------------------------------------------------------- 1 | package timex 2 | 3 | import "time" 4 | 5 | // Use the long enough past time as start time, in case timex.Now() - lastTime equals 0. 
6 | var initTime = time.Now().AddDate(-1, -1, -1) 7 | 8 | func Now() time.Duration { 9 | return time.Since(initTime) 10 | } 11 | 12 | func Since(d time.Duration) time.Duration { 13 | return time.Since(initTime) - d 14 | } 15 | 16 | func ToTime(d time.Duration) time.Time { 17 | return initTime.Add(d) 18 | } 19 | -------------------------------------------------------------------------------- /timex/relativetime_test.go: -------------------------------------------------------------------------------- 1 | package timex 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestRelativeTime(t *testing.T) { 11 | time.Sleep(time.Millisecond) 12 | now := Now() 13 | assert.True(t, now > 0) 14 | time.Sleep(time.Millisecond) 15 | assert.True(t, Since(now) > 0) 16 | } 17 | 18 | func TestRelativeTime_Time(t *testing.T) { 19 | diff := ToTime(Now()).Sub(time.Now()) 20 | if diff > 0 { 21 | assert.True(t, diff < time.Second) 22 | } else { 23 | assert.True(t, -diff < time.Second) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /timex/ticker.go: -------------------------------------------------------------------------------- 1 | package timex 2 | 3 | import ( 4 | "errors" 5 | "time" 6 | 7 | "github.com/weblazy/easy/lang" 8 | ) 9 | 10 | type ( 11 | Ticker interface { 12 | Chan() <-chan time.Time 13 | Stop() 14 | } 15 | 16 | FakeTicker interface { 17 | Ticker 18 | Done() 19 | Tick() 20 | Wait(d time.Duration) error 21 | } 22 | 23 | fakeTicker struct { 24 | c chan time.Time 25 | done chan lang.PlaceholderType 26 | } 27 | 28 | realTicker struct { 29 | *time.Ticker 30 | } 31 | ) 32 | 33 | func NewRealTicker(d time.Duration) Ticker { 34 | return &realTicker{ 35 | Ticker: time.NewTicker(d), 36 | } 37 | } 38 | 39 | func (rt *realTicker) Chan() <-chan time.Time { 40 | return rt.C 41 | } 42 | 43 | func NewFakeTicker() FakeTicker { 44 | return &fakeTicker{ 45 | c: make(chan time.Time, 1), 46 
| done: make(chan lang.PlaceholderType, 1), 47 | } 48 | } 49 | 50 | func (ft *fakeTicker) Chan() <-chan time.Time { 51 | return ft.c 52 | } 53 | 54 | func (ft *fakeTicker) Done() { 55 | ft.done <- lang.Placeholder 56 | } 57 | 58 | func (ft *fakeTicker) Stop() { 59 | close(ft.c) 60 | } 61 | 62 | func (ft *fakeTicker) Tick() { 63 | ft.c <- ToTime(Now()) 64 | } 65 | 66 | func (ft *fakeTicker) Wait(d time.Duration) error { 67 | select { 68 | case <-time.After(d): 69 | return errors.New("timeout") 70 | case <-ft.done: 71 | return nil 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /timex/ticker_test.go: -------------------------------------------------------------------------------- 1 | package timex 2 | 3 | import ( 4 | "sync/atomic" 5 | "testing" 6 | "time" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestRealTickerDoTick(t *testing.T) { 12 | ticker := NewRealTicker(time.Millisecond * 10) 13 | defer ticker.Stop() 14 | var count int 15 | for range ticker.Chan() { 16 | count++ 17 | if count > 5 { 18 | break 19 | } 20 | } 21 | } 22 | 23 | func TestFakeTicker(t *testing.T) { 24 | const total = 5 25 | ticker := NewFakeTicker() 26 | defer ticker.Stop() 27 | 28 | var count int32 29 | go func() { 30 | // range exits once the deferred Stop closes the channel; a bare receive in for/select would spin on the closed channel forever 31 | for range ticker.Chan() { 32 | if atomic.AddInt32(&count, 1) == total { 33 | ticker.Done() 34 | } 35 | } 36 | }() 37 | 38 | for i := 0; i < total; i++ { 39 | ticker.Tick() 40 | } 41 | 42 | assert.Nil(t, ticker.Wait(time.Second)) 43 | assert.Equal(t, int32(total), atomic.LoadInt32(&count)) 44 | } 45 | -------------------------------------------------------------------------------- /timex/timex.go: -------------------------------------------------------------------------------- 1 | package timex 2 | 3 | import ( 4 | "database/sql/driver" 5 | "time" 6 | ) 7 | 8 | type Time time.Time 9 | 10 | const ( 11 | timeFormart = "2006-01-02 15:04:05" 12 | ) 13 | 14 | func (t *Time)
UnmarshalJSON(data []byte) (err error) { 15 | now, err := time.ParseInLocation(`"`+timeFormart+`"`, string(data), time.Local) 16 | *t = Time(now) 17 | return 18 | } 19 | 20 | func (t Time) MarshalJSON() ([]byte, error) { 21 | b := make([]byte, 0, len(timeFormart)+2) 22 | b = append(b, '"') 23 | if !time.Time(t).IsZero() { 24 | b = time.Time(t).AppendFormat(b, timeFormart) 25 | } 26 | b = append(b, '"') 27 | return b, nil 28 | } 29 | 30 | func (t Time) String() string { 31 | return time.Time(t).Format(timeFormart) 32 | } 33 | 34 | func (t Time) Value() (driver.Value, error) { 35 | if time.Time(t).IsZero() { 36 | return nil, nil 37 | } 38 | return time.Time(t), nil 39 | } 40 | 41 | func (t *Time) Scan(v interface{}) error { 42 | value, ok := v.(time.Time) 43 | if ok { 44 | *t = Time(value) 45 | return nil 46 | } 47 | return nil 48 | } 49 | -------------------------------------------------------------------------------- /timex/utils.go: -------------------------------------------------------------------------------- 1 | package timex 2 | 3 | import "time" 4 | 5 | const ( 6 | TimeLayout = "2006-01-02 15:04:05" 7 | DateLayout = "2006-01-02" 8 | ) 9 | 10 | // @desc 获取某一天的0点时间 11 | // @auth liuguoqiang 2020-04-27 12 | // @param 13 | // @return 14 | func ZeroTime(d time.Time) time.Time { 15 | return time.Date(d.Year(), d.Month(), d.Day(), 0, 0, 0, 0, time.Local) 16 | } 17 | 18 | func ZeroTimeWithLocation(d time.Time, loc *time.Location) time.Time { 19 | return time.Date(d.Year(), d.Month(), d.Day(), 0, 0, 0, 0, loc) 20 | } 21 | 22 | func UTCZeroTime(d time.Time) time.Time { 23 | return time.Date(d.Year(), d.Month(), d.Day(), 0, 0, 0, 0, time.UTC) 24 | } 25 | 26 | func ShanghaiZeroTime(d time.Time) time.Time { 27 | t, _ := time.LoadLocation("Asia/Shanghai") 28 | return time.Date(d.Year(), d.Month(), d.Day(), 0, 0, 0, 0, t) 29 | } 30 | 31 | // @desc 返回一个月的开始时间和结束时间 32 | // @auth liuguoqiang 2020-04-27 33 | // @param 34 | // @return 35 | func MonthRange(timeStamp int64) 
(int64, int64) { 36 | d := time.Unix(timeStamp, 0) 37 | d = d.AddDate(0, 0, -d.Day()+1) 38 | start := ZeroTime(d) 39 | end := start.AddDate(0, 1, 0) 40 | return start.Unix(), end.Unix() 41 | } 42 | 43 | func MonthRangeWthiLocation(timeStamp int64, loc *time.Location) (int64, int64) { 44 | d := time.Unix(timeStamp, 0) 45 | d = d.AddDate(0, 0, -d.Day()+1) 46 | start := ZeroTimeWithLocation(d, loc) 47 | end := start.AddDate(0, 1, 0) 48 | return start.Unix(), end.Unix() 49 | } 50 | -------------------------------------------------------------------------------- /transport/grpc_transport.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "google.golang.org/grpc/metadata" 5 | ) 6 | 7 | // GrpcHeaderCarrier ... 8 | type GrpcHeaderCarrier metadata.MD 9 | 10 | // Get returns the value associated with the passed key. 11 | func (mc GrpcHeaderCarrier) Get(key string) string { 12 | vals := metadata.MD(mc).Get(key) 13 | if len(vals) > 0 { 14 | return vals[0] 15 | } 16 | return "" 17 | } 18 | 19 | // Set stores the key-value pair. 20 | func (mc GrpcHeaderCarrier) Set(key string, value string) { 21 | metadata.MD(mc).Set(key, value) 22 | } 23 | 24 | // Keys lists the keys stored in this carrier. 
25 | func (mc GrpcHeaderCarrier) Keys() []string { 26 | keys := make([]string, 0, len(mc)) 27 | for k := range metadata.MD(mc) { 28 | keys = append(keys, k) 29 | } 30 | return keys 31 | } 32 | -------------------------------------------------------------------------------- /utils/run/daemon.go: -------------------------------------------------------------------------------- 1 | package run 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/weblazy/easy/timex" 7 | ) 8 | 9 | // DaemonRun calls f on the calling goroutine and, while f runs, calls daemon on every 10 | // tick of interval from a background goroutine. When f returns (or panics) the background 11 | // loop is signalled to stop the ticker and exit; DaemonRun blocks until that signal is 12 | // received, i.e. until any in-flight daemon call finishes. daemon may still fire once more 13 | // after f returns if a tick is pending at that moment (select picks among ready cases at random). 14 | // interval must be positive: timex.NewRealTicker calls time.NewTicker, which panics otherwise. 15 | func DaemonRun(interval time.Duration, f func(), daemon func()) { 16 | ticker := timex.NewRealTicker(interval) 17 | stopChannel := make(chan struct{}) 18 | defer func() { 19 | stopChannel <- struct{}{} 20 | }() 21 | go func() { 22 | for { 23 | select { 24 | case <-ticker.Chan(): 25 | daemon() 26 | case <-stopChannel: 27 | ticker.Stop() 28 | return 29 | } 30 | } 31 | }() 32 | f() 33 | } 34 | -------------------------------------------------------------------------------- /webrtc_sfu/webrtc_sfu.go: -------------------------------------------------------------------------------- 1 | package webrtc_sfu 2 | --------------------------------------------------------------------------------