├── .gitignore ├── CHANGELOG.md ├── README.md ├── alimns ├── alimns.go ├── alimns_test.go ├── provider.go └── provider_test.go ├── aliyunmq ├── aliyunmq.go ├── aliyunmq_test.go ├── provider.go └── provider_test.go ├── cache ├── base_cache.go ├── base_cache_test.go ├── cache.go ├── cache_test.go ├── interface.go ├── memorycache │ ├── memory_cache.go │ └── memory_cache_test.go └── rediscache │ ├── redis_cache.go │ └── redis_cache_test.go ├── command ├── command.go └── command_test.go ├── config └── config.go ├── db ├── db.go ├── db_test.go ├── model.go ├── model_test.go └── provider.go ├── go.mod ├── go.sum ├── helper ├── helper.go └── helper_test.go ├── http ├── ctxkit │ ├── ctxkit.go │ └── ctxkit_test.go └── middleware │ ├── access_log.go │ ├── access_test.go │ ├── ctxkit.go │ ├── ctxkit_test.go │ └── request_id.go ├── kernel ├── close │ ├── close.go │ └── close_test.go ├── container │ ├── app.go │ ├── container.go │ └── container_test.go ├── provider │ └── provider.go └── server │ ├── command.go │ ├── console.go │ ├── console_test.go │ ├── http.go │ ├── http_test.go │ ├── job.go │ ├── job_test.go │ ├── server.go │ └── server_test.go ├── log ├── accesslogger │ ├── access_logger.go │ ├── provider.go │ └── provider_test.go └── logger │ ├── http_logger.go │ ├── logger.go │ ├── provider.go │ ├── provider_test.go │ ├── roll_hook.go │ ├── segment.go │ └── source_hook.go ├── queue ├── alimnsqueue │ ├── alimns_queue.go │ └── queue_test.go ├── alirocketqueue │ ├── alirocket_queue.go │ └── alirocketmq_test.go ├── interface.go ├── queue.go ├── queue_test.go ├── redisqueue │ ├── redis_queue.go │ └── redis_queue_test.go └── rocketqueue │ ├── rocket_queue.go │ └── rocket_queue_test.go ├── redis ├── provider.go ├── provider_test.go ├── redis.go └── redis_test.go ├── rocketmq ├── provider.go └── rocketmq.go └── utils ├── base62.go ├── base62_test.go ├── convert.go ├── convert_test.go ├── hash.go ├── hash_test.go ├── httputil ├── http.go └── http_test.go ├── iputil ├── ip.go └── 
ip_test.go ├── json.go ├── json_test.go ├── string.go ├── string_test.go ├── time.go ├── time_test.go ├── url.go ├── url_test.go ├── uuid.go └── uuid_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea 2 | /vendor 3 | /.env* 4 | /*.log 5 | /coverage.data 6 | coverage.txt 7 | coverage.html 8 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## v0.1.29(2023-08-03) 2 | 3 | ### New Features 4 | - 社区版rocketmq支持延时消息 5 | 6 | ## v0.1.23(2020-11-13) 7 | 8 | ### New Features 9 | - queue增加对社区版rocketmq的支持 10 | 11 | ## v0.1.20(2020-08-21) 12 | 13 | ### New Features 14 | - 日志文件可根据不同日志等级进行切分,不同级别日志记录到对应的日志文件 15 | 16 | ## v0.1.19(2020-06-16) 17 | 18 | ### New Features 19 | - 升级work包v0.3.4->v0.3.9,队列轮询机制修改为滚动递增 20 | 21 | ### Bug Fix 22 | - 升级xorm包v0.7.4->v1.0.2,修复若干bug 23 | 24 | ## v0.1.17(2019-11-18) 25 | 26 | ### New Features 27 | - 去掉测试用fmt 28 | 29 | 30 | ## v0.1.16(2019-11-18) 31 | 32 | ### New Features 33 | - 队列queue新增aliyunmq【rocketmq】驱动 34 | 35 | ## v0.1.15(2019-11-10) 36 | 37 | ### New Features 38 | - 队列queue出队增加返回参数:消费次数 39 | 40 | ## v0.1.14(2019-11-10) 41 | 42 | ### New Features 43 | - 更改mns队列默认可见时间 44 | 45 | ## v0.1.13(2019-10-20) 46 | 47 | ### New Features 48 | - 缓存驱动增加,内存缓存 49 | 50 | ## v0.1.12(2019-09-29) 51 | 52 | ### New Features 53 | - 缓存组件增加decr和incr方法 54 | - 增加单测文件 55 | 56 | ## v0.1.11(2019-08-25) 57 | 58 | ### New Features 59 | - 日志增加traceId 60 | - middleware组件和httputil组件增加内容串联traceId 61 | - ctxkit包优化,方便扩展 62 | 63 | 64 | ## v0.1.10(2019-08-02) 65 | 66 | ### New Features 67 | - 补充单测案例 68 | 69 | ### Bug Fix 70 | - cache和queue包在获取对象时读锁加锁未配对解锁 71 | 72 | ## v0.1.9(2019-08-01) 73 | 74 | ### New Features 75 | - 补充单测案例 76 | 77 | ### Changes 78 | - 优化utils包HttpBuildQuery的map嵌套转换实现 79 | 80 | ## v0.1.8(2019-07-26) 81 | 82 | ### New Features 83 | - rediscache的单元测试案例 84 | 85 | ### 
Changes 86 | - rediscache的Get返回优化。若key不存在之前是返回错误类型ErrNil,现在不返回错误,返回字符串为空 87 | 88 | ### Bug Fix 89 | - 修复rediscache的SetMulti实现bug 90 | 91 | ## v0.1.7(2019-07-25) 92 | 93 | ### Changes 94 | - 更新qit-team/work包的版本号v0.3.3->v0.3.4 95 | 96 | ## v0.1.6(2019-07-24) 97 | 98 | ### Bug Fix 99 | - 修复utils包HttpBuildQuery的对值非字符串的处理bug 100 | 101 | ## v0.1.5(2019-07-23) 102 | 103 | ### New Features 104 | - Command执行脚本模式支持 105 | 106 | ## v0.1.4(2019-07-23) 107 | 108 | ### Changes 109 | - utils工具包 110 | - HTTP请求工具包封装简易的Get Post PostJson Request方法 111 | 112 | ## v0.1.3(2019-07-22) 113 | 114 | ### New Features 115 | - Redis组件服务 116 | - Log组件服务 117 | - DB组件服务 118 | - Config通用配置结构 119 | - Cache缓存及驱动 120 | - Queue队列及驱动 121 | - Http的通用中间件和通用上下文kit 122 | - Kernel内核包 123 | - close服务注册 124 | - provider组件注册 125 | - container容器注入 126 | - server通用服务启动 127 | - utils工具包 128 | - HTTP请求工具包 129 | - 其他常用函数工具包 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 简介 2 | Snow框架的核心组件包 -------------------------------------------------------------------------------- /alimns/alimns.go: -------------------------------------------------------------------------------- 1 | package alimns 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/aliyun/aliyun-mns-go-sdk" 7 | "github.com/qit-team/snow-core/config" 8 | ) 9 | 10 | // NewMnsClient builds an MNS client for dependency injection; it returns a nil client and nil err when mnsConfig.Url is empty. 11 | func NewMnsClient(mnsConfig config.MnsConfig) (client ali_mns.MNSClient, err error) { 12 | // Initialize the MNS client; the deferred recover converts an SDK panic into err. 13 | defer func() { 14 | if e := recover(); e != nil { 15 | s := fmt.Sprintf("ali_mns client init panic: %s", fmt.Sprint(e)) 16 | err = errors.New(s) 17 | } 18 | }() 19 | 20 | if mnsConfig.Url != "" { 21 | client = ali_mns.NewAliMNSClient(mnsConfig.Url, 22 | mnsConfig.AccessKeyId, 23 | mnsConfig.AccessKeySecret) 24 | } 25 | return 26 | } 27 | 28 | func GetMnsBasicQueue(client ali_mns.MNSClient, queueName string) ali_mns.AliMNSQueue { // ensures queueName exists, then returns its handle (nil when creation fails)
29 | var defaultQueue ali_mns.AliMNSQueue 30 | 31 | // Create a queue manager from the client. 32 | queueManager := ali_mns.NewMNSQueueManager(client) 33 | 34 | // visibilityTimeout is temporarily fixed at 60; it should later be exposed to callers as a configurable parameter. 35 | err := queueManager.CreateQueue(queueName, 0, 65536, 345600, 60, 0, 3) 36 | if err != nil && !ali_mns.ERR_MNS_QUEUE_ALREADY_EXIST_AND_HAVE_SAME_ATTR.IsEqual(err) { 37 | fmt.Println(err) // NOTE(review): the error is only printed; callers just receive a nil queue 38 | return defaultQueue 39 | } 40 | // Return the queue handle, the smallest unit callers operate on. 41 | return ali_mns.NewMNSQueue(queueName, client) 42 | } 43 | -------------------------------------------------------------------------------- /alimns/alimns_test.go: -------------------------------------------------------------------------------- 1 | package alimns 2 | 3 | import ( 4 | "github.com/qit-team/snow-core/config" 5 | "testing" 6 | ) 7 | 8 | func TestNewMnsClient(t *testing.T) { 9 | conf := config.MnsConfig{ 10 | Url: "", 11 | AccessKeyId: "", 12 | AccessKeySecret: "", 13 | } 14 | c, err := NewMnsClient(conf) 15 | if err != nil { 16 | t.Error(err) 17 | return 18 | } else if c != nil { 19 | t.Error("client is not nil") 20 | return 21 | } 22 | } 23 | 24 | func TestNewMnsClient2(t *testing.T) { 25 | conf := config.MnsConfig{ 26 | Url: "http://www.baidu.com", 27 | AccessKeyId: "1", 28 | AccessKeySecret: "2", 29 | } 30 | 31 | _, err := NewMnsClient(conf) 32 | if err == nil { 33 | t.Error("invalid config must return err") 34 | } 35 | } 36 | 37 | func TestGetMnsBasicQueue(t *testing.T) { 38 | defer func() { 39 | if e := recover(); e == nil { 40 | t.Error("not panic") 41 | } 42 | }() 43 | GetMnsBasicQueue(nil, "test") 44 | } 45 | -------------------------------------------------------------------------------- /alimns/provider.go: -------------------------------------------------------------------------------- 1 | package alimns 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/aliyun/aliyun-mns-go-sdk" 7 | "github.com/qit-team/snow-core/config" 8 | "github.com/qit-team/snow-core/helper" 9 | "github.com/qit-team/snow-core/kernel/container"
10 | "sync" 11 | ) 12 | 13 | const ( 14 | SingletonMain = "ali_mns" 15 | ) 16 | 17 | var Pr *provider 18 | 19 | func init() { 20 | Pr = new(provider) 21 | Pr.mp = make(map[string]interface{}) 22 | } 23 | 24 | type provider struct { 25 | mu sync.RWMutex 26 | mp map[string]interface{} // registered configs keyed by DI name 27 | dn string // default DI name: the first one registered 28 | } 29 | 30 | /** Register stores an MNS config under a DI alias and, unless lazy, builds the client immediately. 31 | * @param string DI alias, required 32 | * @param config.MnsConfig config, required 33 | * @param bool enable lazy loading, optional 34 | */ 35 | func (p *provider) Register(args ...interface{}) (err error) { 36 | diName, lazy, err := helper.TransformArgs(args...) 37 | if err != nil { 38 | return 39 | } 40 | 41 | conf, ok := args[1].(config.MnsConfig) 42 | if !ok { 43 | return errors.New("args[1] is not config.MnsConfig") 44 | } 45 | 46 | p.mu.Lock() 47 | p.mp[diName] = args[1] 48 | if len(p.mp) == 1 { 49 | p.dn = diName 50 | } 51 | p.mu.Unlock() 52 | 53 | if !lazy { 54 | _, err = setSingleton(diName, conf) 55 | } 56 | return 57 | } 58 | 59 | // Provides returns the registered DI aliases. 60 | func (p *provider) Provides() []string { 61 | p.mu.RLock() 62 | defer p.mu.RUnlock() 63 | 64 | return helper.MapToArray(p.mp) 65 | } 66 | 67 | // Close releases resources; currently a no-op. 68 | func (p *provider) Close() error { 69 | return nil 70 | } 71 | 72 | // setSingleton builds the client and, on success, stores it in the container. 73 | func setSingleton(diName string, conf config.MnsConfig) (ins ali_mns.MNSClient, err error) { 74 | ins, err = NewMnsClient(conf) 75 | if err == nil { 76 | container.App.SetSingleton(diName, ins) 77 | } 78 | return 79 | } 80 | 81 | // getSingleton returns the container's client; when absent and lazy, it builds one from the registered config, panicking if none is registered or construction fails. 82 | func getSingleton(diName string, lazy bool) ali_mns.MNSClient { 83 | rc := container.App.GetSingleton(diName) 84 | if rc != nil { 85 | return rc.(ali_mns.MNSClient) 86 | } 87 | if !lazy { 88 | return nil 89 | } 90 | 91 | Pr.mu.RLock() 92 | conf, ok := Pr.mp[diName].(config.MnsConfig) 93 | Pr.mu.RUnlock() 94 | if !ok { 95 | panic(fmt.Sprintf("alimns di_name:%s not exist", diName)) 96 | } 97 | 98 | ins, err := setSingleton(diName, conf) 99 | if err != nil { 100 | panic(fmt.Sprintf("alimns di_name:%s err:%s", diName, err.Error())) 101 | }
102 | return ins 103 | } 104 | 105 | // GetMns returns the client for the DI alias resolved by helper.GetDiName from args and the default alias, so callers never construct it themselves. 106 | func GetMns(args ...string) ali_mns.MNSClient { 107 | diName := helper.GetDiName(Pr.dn, args...) 108 | return getSingleton(diName, true) 109 | } 110 | -------------------------------------------------------------------------------- /alimns/provider_test.go: -------------------------------------------------------------------------------- 1 | package alimns 2 | 3 | import ( 4 | "github.com/qit-team/snow-core/config" 5 | "testing" 6 | ) 7 | 8 | func Test_getSingleton(t *testing.T) { 9 | c := getSingleton("", false) 10 | if c != nil { 11 | t.Error("client is not equal nil") 12 | return 13 | } 14 | } 15 | 16 | func TestProvider(t *testing.T) { 17 | err := Pr.Register("mns", config.MnsConfig{}, true) 18 | if err != nil { 19 | t.Error(err) 20 | return 21 | } 22 | 23 | arr := Pr.Provides() 24 | if !(len(arr) == 1 && arr[0] == "mns") { 25 | t.Errorf("Provides is not match. %v", arr) 26 | return 27 | } 28 | 29 | err = Pr.Register("mns1", config.MnsConfig{}) 30 | if err != nil { 31 | t.Error(err) 32 | return 33 | } 34 | 35 | arr = Pr.Provides() 36 | if !(len(arr) == 2 && (arr[1] == "mns1" || arr[1] == "mns")) { 37 | t.Errorf("Provides is not match.
%v", arr) 38 | return 39 | } 40 | 41 | err = Pr.Close() 42 | if err != nil { 43 | t.Error(err) 44 | return 45 | } 46 | 47 | c := GetMns() 48 | if c != nil { 49 | t.Error("client is not equal nil") 50 | return 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /aliyunmq/aliyunmq.go: -------------------------------------------------------------------------------- 1 | package aliyunmq 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/aliyunmq/mq-http-go-sdk" 7 | "github.com/qit-team/snow-core/config" 8 | ) 9 | 10 | // NewAliyunMqClient builds an aliyun MQ client for dependency injection; an empty EndPoint yields an error. 11 | func NewAliyunMqClient(mqConfig config.AliyunMqConfig) (client mq_http_sdk.MQClient, err error) { 12 | // Initialize the aliyunmq client; the deferred recover converts an SDK panic into err. 13 | defer func() { 14 | if e := recover(); e != nil { 15 | s := fmt.Sprintf("aliyun_mq client init panic: %s", fmt.Sprint(e)) 16 | err = errors.New(s) 17 | } 18 | }() 19 | 20 | if mqConfig.EndPoint != "" { 21 | client = mq_http_sdk.NewAliyunMQClient(mqConfig.EndPoint, mqConfig.AccessKey, mqConfig.SecretKey, "") 22 | } else { 23 | err = errors.New("EndPoint empty,can not get client") 24 | } 25 | return 26 | } 27 | -------------------------------------------------------------------------------- /aliyunmq/aliyunmq_test.go: -------------------------------------------------------------------------------- 1 | package aliyunmq 2 | 3 | import ( 4 | "fmt" 5 | "github.com/qit-team/snow-core/config" 6 | "io/ioutil" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | func TestNewAliyunMqClient(t *testing.T) { 12 | 13 | conf := getConfig() 14 | fmt.Println("aliyun_mq endpoint:", conf.EndPoint) // avoid echoing the access/secret keys into test logs 15 | c, err := NewAliyunMqClient(conf) 16 | if err != nil { 17 | t.Error(err) 18 | return 19 | } else if c == nil { 20 | t.Error("client is nil") 21 | return 22 | } 23 | } 24 | 25 | func TestNewAliyunMqClient2(t *testing.T) { 26 | conf := config.AliyunMqConfig{ 27 | EndPoint: "", 28 | AccessKey: "", 29 | SecretKey: "", 30 | } 31 | _, err := NewAliyunMqClient(conf) 32 | if err == nil { 33 | 
t.Error("invalid config must return err") 34 | } 35 | } 36 | 37 | func getConfig() config.AliyunMqConfig { 38 | // Fill in ../.env.aliyunmq yourself: endpoint, access key and secret key on its first three lines. 39 | bs, err := ioutil.ReadFile("../.env.aliyunmq") 40 | 41 | conf := config.AliyunMqConfig{} 42 | if err == nil { 43 | str := string(bs) 44 | arr := strings.Split(str, "\n") 45 | if len(arr) >= 3 { 46 | conf.EndPoint = strings.TrimSpace(arr[0]) // TrimSpace also drops a trailing \r from CRLF files 47 | conf.AccessKey = strings.TrimSpace(arr[1]) 48 | conf.SecretKey = strings.TrimSpace(arr[2]) 49 | } 50 | } 51 | return conf 52 | } 53 | -------------------------------------------------------------------------------- /aliyunmq/provider.go: -------------------------------------------------------------------------------- 1 | package aliyunmq 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/aliyunmq/mq-http-go-sdk" 7 | "github.com/qit-team/snow-core/config" 8 | "github.com/qit-team/snow-core/helper" 9 | "github.com/qit-team/snow-core/kernel/container" 10 | "sync" 11 | ) 12 | 13 | const ( 14 | SingletonMain = "aliyun_mq" 15 | ) 16 | 17 | var Pr *provider 18 | 19 | func init() { 20 | Pr = new(provider) 21 | Pr.mp = make(map[string]interface{}) 22 | } 23 | 24 | type provider struct { 25 | mu sync.RWMutex 26 | mp map[string]interface{} // registered configs keyed by DI name 27 | dn string // default DI name: the first one registered 28 | } 29 | 30 | /** Register stores an aliyun MQ config under a DI alias and, unless lazy, builds the client immediately. 31 | * @param string DI alias, required 32 | * @param config.AliyunMqConfig config, required 33 | * @param bool enable lazy loading, optional 34 | */ 35 | func (p *provider) Register(args ...interface{}) (err error) { 36 | diName, lazy, err := helper.TransformArgs(args...)
37 | if err != nil { 38 | return 39 | } 40 | 41 | conf, ok := args[1].(config.AliyunMqConfig) 42 | if !ok { 43 | return errors.New("args[1] is not config.AliyunMqConfig") 44 | } 45 | 46 | p.mu.Lock() 47 | p.mp[diName] = args[1] 48 | if len(p.mp) == 1 { 49 | p.dn = diName 50 | } 51 | p.mu.Unlock() 52 | 53 | if !lazy { 54 | _, err = setSingleton(diName, conf) 55 | } 56 | return 57 | } 58 | 59 | // Provides returns the registered DI aliases. 60 | func (p *provider) Provides() []string { 61 | p.mu.RLock() 62 | defer p.mu.RUnlock() 63 | 64 | return helper.MapToArray(p.mp) 65 | } 66 | 67 | // Close releases resources; currently a no-op. 68 | func (p *provider) Close() error { 69 | return nil 70 | } 71 | 72 | // setSingleton builds the client and, on success, stores it in the container. 73 | func setSingleton(diName string, conf config.AliyunMqConfig) (ins mq_http_sdk.MQClient, err error) { 74 | ins, err = NewAliyunMqClient(conf) 75 | if err == nil { 76 | container.App.SetSingleton(diName, ins) 77 | } 78 | return 79 | } 80 | 81 | // getSingleton returns the container's client; when absent and lazy, it builds one from the registered config, panicking if none is registered or construction fails. 82 | func getSingleton(diName string, lazy bool) mq_http_sdk.MQClient { 83 | rc := container.App.GetSingleton(diName) 84 | if rc != nil { 85 | return rc.(mq_http_sdk.MQClient) 86 | } 87 | if !lazy { 88 | return nil 89 | } 90 | 91 | Pr.mu.RLock() 92 | conf, ok := Pr.mp[diName].(config.AliyunMqConfig) 93 | Pr.mu.RUnlock() 94 | if !ok { 95 | panic(fmt.Sprintf("aliyun_mq di_name:%s not exist", diName)) 96 | } 97 | 98 | ins, err := setSingleton(diName, conf) 99 | if err != nil { 100 | panic(fmt.Sprintf("aliyun_mq di_name:%s err:%s", diName, err.Error())) 101 | } 102 | return ins 103 | } 104 | 105 | // GetAliyunMq returns the client for the DI alias resolved by helper.GetDiName from args and the default alias, so callers never construct it themselves. 106 | func GetAliyunMq(args ...string) mq_http_sdk.MQClient { 107 | diName := helper.GetDiName(Pr.dn, args...)
108 | return getSingleton(diName, true) 109 | } 110 | -------------------------------------------------------------------------------- /aliyunmq/provider_test.go: -------------------------------------------------------------------------------- 1 | package aliyunmq 2 | 3 | import ( 4 | "fmt" 5 | "github.com/qit-team/snow-core/config" 6 | "testing" 7 | ) 8 | 9 | func Test_getSingleton(t *testing.T) { 10 | c := getSingleton("", false) 11 | if c != nil { 12 | t.Error("client is not equal nil") 13 | return 14 | } 15 | } 16 | 17 | func TestProvider(t *testing.T) { 18 | err := Pr.Register("aliyun_mq", config.AliyunMqConfig{}, true) 19 | if err != nil { 20 | t.Error(err) 21 | return 22 | } 23 | 24 | arr := Pr.Provides() 25 | if !(len(arr) == 1 && arr[0] == "aliyun_mq") { 26 | t.Errorf("Provides is not match. %v", arr) 27 | return 28 | } 29 | 30 | err = Pr.Register("aliyun_mq1", config.AliyunMqConfig{}) 31 | if err != nil { 32 | t.Error(err) 33 | return 34 | } 35 | 36 | arr = Pr.Provides() 37 | if !(len(arr) == 2 && (arr[1] == "aliyun_mq" || arr[1] == "aliyun_mq1")) { 38 | t.Errorf("Provides is not match.
%v", arr) 39 | return 40 | } 41 | 42 | err = Pr.Close() 43 | if err != nil { 44 | t.Error(err) 45 | return 46 | } 47 | 48 | c := GetAliyunMq() 49 | fmt.Println("providers.GetAliyunMq:", c) 50 | 51 | if c != nil { 52 | t.Error("client is not equal nil") 53 | return 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /cache/base_cache.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "context" 5 | "github.com/qit-team/snow-core/redis" 6 | "github.com/qit-team/snow-core/utils" 7 | ) 8 | 9 | const ( 10 | DefaultDiName = redis.SingletonMain 11 | DefaultDriverType = DriverTypeRedis 12 | DefaultPrefix = "" // default cache key prefix 13 | DefaultTTL = 86400 // default cache TTL 14 | ) 15 | 16 | // BaseCache is the base type for cache models: it prefixes keys and delegates to the configured cache driver. 17 | type BaseCache struct { 18 | cache Cache 19 | DiName string // DI alias of the cache instance 20 | Prefix string // cache key prefix 21 | DriverType string // cache driver type 22 | ttl int // cache TTL 23 | ttlIsSet bool // set by SetTTL so that an explicit TTL of 0 is not replaced by the default 24 | } 25 | 26 | // key prepends the prefix to key. 27 | func (m *BaseCache) key(key string) string { 28 | return m.Prefix + key 29 | } 30 | 31 | // keys prepends the prefix to every key. 32 | func (m *BaseCache) keys(keys ...string) []string { 33 | arr := make([]string, len(keys)) 34 | for i, key := range keys { 35 | arr[i] = m.key(key) 36 | } 37 | return arr 38 | } 39 | 40 | // removePrefix strips the prefix from key. NOTE(review): l is a byte length; confirm utils.Substr indexes bytes, not runes. 41 | func (m *BaseCache) removePrefix(key string) string { 42 | l := len(m.Prefix) 43 | return utils.Substr(key, l, len(key)-l) 44 | } 45 | 46 | func (m *BaseCache) GetPrefixOrDefault() string { 47 | if m.Prefix != "" { 48 | return m.Prefix 49 | } else { 50 | return DefaultPrefix 51 | } 52 | } 53 | 54 | func (m *BaseCache) GetDiNameOrDefault() string { 55 | if m.DiName != "" { 56 | return m.DiName 57 | } else { 58 | return DefaultDiName 59 | } 60 | } 61 | 62 | func (m *BaseCache) GetDriverTypeOrDefault() string { 63 | if m.DriverType != "" { 64 | return m.DriverType 65 | } else { 66 | return DefaultDriverType 67 | } 68 | } 69 | 70 | func (m *BaseCache) SetTTL(ttl int)
71 | { 72 | m.ttlIsSet = true 73 | m.ttl = ttl 74 | } 75 | 76 | func (m *BaseCache) GetTTLOrDefault() int { 77 | if m.ttlIsSet { 78 | return m.ttl 79 | } else { 80 | return DefaultTTL 81 | } 82 | } 83 | 84 | func (m *BaseCache) getTTL(ttl ...int) int { 85 | if len(ttl) > 0 { 86 | return ttl[0] 87 | } else { 88 | return m.GetTTLOrDefault() 89 | } 90 | } 91 | 92 | func (m *BaseCache) Get(ctx context.Context, key string) (interface{}, error) { 93 | key = m.key(key) 94 | return m.GetCache().Get(ctx, key) 95 | } 96 | 97 | func (m *BaseCache) Set(ctx context.Context, key string, value interface{}, ttl ...int) (bool, error) { 98 | key = m.key(key) 99 | return m.GetCache().Set(ctx, key, value, m.getTTL(ttl...)) 100 | } 101 | 102 | func (m *BaseCache) GetMulti(ctx context.Context, keys ...string) (map[string]interface{}, error) { 103 | keys = m.keys(keys...) 104 | items, err := m.GetCache().GetMulti(ctx, keys...) 105 | if err != nil { 106 | return nil, err 107 | } 108 | 109 | m2 := make(map[string]interface{}, len(items)) 110 | for key, val := range items { 111 | m2[m.removePrefix(key)] = val 112 | } 113 | return m2, nil 114 | } 115 | 116 | func (m *BaseCache) SetMulti(ctx context.Context, items map[string]interface{}, ttl ...int) (bool, error) { 117 | arr := make(map[string]interface{}, len(items)) 118 | for key, value := range items { 119 | key = m.key(key) 120 | arr[key] = value 121 | } 122 | return m.GetCache().SetMulti(ctx, arr, m.getTTL(ttl...)) 123 | } 124 | 125 | func (m *BaseCache) Delete(ctx context.Context, key string) (bool, error) { 126 | key = m.key(key) 127 | return m.GetCache().Delete(ctx, key) 128 | } 129 | 130 | func (m *BaseCache) DeleteMulti(ctx context.Context, keys ...string) (bool, error) { 131 | keys = m.keys(keys...) 132 | return m.GetCache().DeleteMulti(ctx, keys...)
132 | } 133 | 134 | func (m *BaseCache) Expire(ctx context.Context, key string, ttl ...int) (bool, error) { 135 | key = m.key(key) 136 | return m.GetCache().Expire(ctx, key, m.getTTL(ttl...)) 137 | } 138 | 139 | func (m *BaseCache) IsExist(ctx context.Context, key string) (bool, error) { 140 | key = m.key(key) 141 | return m.GetCache().IsExist(ctx, key) 142 | } 143 | 144 | func (m *BaseCache) IncrBy(ctx context.Context, key string, value int64) (int64, error) { 145 | key = m.key(key) 146 | return m.GetCache().IncrBy(ctx, key, value) 147 | } 148 | 149 | func (m *BaseCache) DecrBy(ctx context.Context, key string, value int64) (int64, error) { 150 | key = m.key(key) 151 | return m.GetCache().DecrBy(ctx, key, value) 152 | } 153 | 154 | // GetCache resolves the cache driver for this model's DI name and driver type (or their defaults). 155 | func (m *BaseCache) GetCache() Cache { 156 | // Not memoized with sync.Once because different BaseCache values may target different cache instances. 157 | diName := m.GetDiNameOrDefault() 158 | driverType := m.GetDriverTypeOrDefault() 159 | return GetCache(diName, driverType) 160 | } 161 | -------------------------------------------------------------------------------- /cache/base_cache_test.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/qit-team/snow-core/config" 7 | "github.com/qit-team/snow-core/redis" 8 | "testing" 9 | ) 10 | 11 | var m *BaseCache 12 | var ctx context.Context 13 | 14 | func init() { 15 | m = new(BaseCache) 16 | m.Prefix = "test:" 17 | ctx = context.TODO() 18 | 19 | redisConf := config.RedisConfig{ 20 | Master: config.RedisBaseConfig{ 21 | Host: "127.0.0.1", 22 | Port: 6379, 23 | }, 24 | } 25 | 26 | // register the redis provider 27 | err := redis.Pr.Register("redis", redisConf) 28 | if err != nil { 29 | fmt.Println(err) 30 | } 31 | } 32 | 33 | func TestBaseCache_GetPrefixOrDefault(t *testing.T) { 34 | m := new(BaseCache) 35 | s1 := m.GetPrefixOrDefault() 36 | if s1 != DefaultPrefix { 37 | t.Errorf("GetPrefixOrDefault is not equal default:%s", DefaultPrefix) 38 | return 39 | } 40 | 41 | 
m.Prefix = "m:" 42 | s2 := m.GetPrefixOrDefault() 43 | if s2 != m.Prefix { 44 | t.Errorf("GetPrefixOrDefault is not equal %s", m.Prefix) 45 | return 46 | } 47 | } 48 | 49 | func TestBaseCache_GetDiNameOrDefault(t *testing.T) { 50 | m := new(BaseCache) 51 | s1 := m.GetDiNameOrDefault() 52 | if s1 != DefaultDiName { 53 | t.Errorf("GetDiNameOrDefault is not equal default:%s", DefaultDiName) 54 | return 55 | } 56 | 57 | m.DiName = "di" 58 | s2 := m.GetDiNameOrDefault() 59 | if s2 != m.DiName { 60 | t.Errorf("GetDiNameOrDefault is not equal %s", m.DiName) 61 | return 62 | } 63 | } 64 | 65 | func TestBaseCache_GetDriverTypeOrDefault(t *testing.T) { 66 | m := new(BaseCache) 67 | s1 := m.GetDriverTypeOrDefault() 68 | if s1 != DefaultDriverType { 69 | t.Errorf("GetDriverTypeOrDefault is not equal default:%s", DefaultDriverType) 70 | return 71 | } 72 | 73 | m.DriverType = "dr" 74 | s2 := m.GetDriverTypeOrDefault() 75 | if s2 != m.DriverType { 76 | t.Errorf("GetDriverTypeOrDefault is not equal %s", m.DriverType) 77 | return 78 | } 79 | } 80 | 81 | func TestBaseCache_GetTTLOrDefault(t *testing.T) { 82 | m := new(BaseCache) 83 | t1 := m.GetTTLOrDefault() 84 | if t1 != DefaultTTL { 85 | t.Errorf("GetTTLOrDefault is not equal default:%d", DefaultTTL) 86 | return 87 | } 88 | 89 | m.SetTTL(1) 90 | t2 := m.GetTTLOrDefault() 91 | if t2 != 1 { 92 | t.Error("GetTTLOrDefault is not equal 1") 93 | return 94 | } 95 | 96 | m.SetTTL(0) 97 | t3 := m.GetTTLOrDefault() 98 | if t3 != 0 { 99 | t.Error("GetTTLOrDefault is not equal 0") 100 | return 101 | } 102 | } 103 | 104 | func TestBaseCache_getTTL(t *testing.T) { 105 | m := new(BaseCache) 106 | t1 := m.getTTL(1) 107 | if t1 != 1 { 108 | t.Error("getTTL is not equal 1") 109 | return 110 | } 111 | 112 | t2 := m.getTTL() 113 | if t2 != DefaultTTL { 114 | t.Errorf("getTTL is not equal %d", DefaultTTL) 115 | return 116 | } 117 | } 118 | 119 | func TestBaseCache_KeyRelated(t *testing.T) { 120 | m := new(BaseCache) 121 | key := "snow-test"
122 | redisKey := m.key(key) 123 | 124 | if len(redisKey) < 9 { 125 | t.Error("get redis key error") 126 | return 127 | } 128 | 129 | redisKeyList := m.keys("test-key-A", "test-key-B") 130 | 131 | for k, v := range redisKeyList { 132 | if len(v) < 10 { 133 | t.Error("get redis key error") 134 | return 135 | } 136 | fmt.Println("get redis key list:", k, v) 137 | } 138 | 139 | tempKey := m.removePrefix(redisKey) 140 | if len(tempKey) < 9 { 141 | t.Error("remove key error") 142 | return 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /cache/cache.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | ) 7 | 8 | const ( 9 | DriverTypeRedis = "redis" 10 | DriverTypeMemory = "memory" 11 | ) 12 | 13 | var ( 14 | drivers map[string]Instance 15 | mu sync.RWMutex 16 | ) 17 | 18 | type Instance func(diName string) Cache 19 | 20 | func Register(driverType string, driver Instance) { 21 | if driver == nil { 22 | panic("cache.Register driver is nil") 23 | } 24 | mu.Lock() 25 | defer mu.Unlock() 26 | 27 | if _, ok := drivers[driverType]; ok { 28 | panic("cache.Register called twice for driver " + driverType) 29 | } 30 | drivers[driverType] = driver 31 | } 32 | 33 | // GetCache returns the Cache for diName from the driverType driver; it panics on an unknown driver or a nil instance. 34 | func GetCache(diName string, driverType string) (q Cache) { 35 | mu.RLock() 36 | instanceFunc, ok := drivers[driverType] 37 | mu.RUnlock() 38 | if !ok { 39 | panic(fmt.Sprintf("cache.GetCache unknown driver %s", driverType)) 40 | } 41 | q = instanceFunc(diName) 42 | if q == nil { 43 | panic(fmt.Sprintf("cache.GetCache unknown diName %s", diName)) 44 | } 45 | return 46 | } 47 | 48 | // GetTTLOrDefault returns ttl[0] when given, otherwise DefaultTTL. 49 | func GetTTLOrDefault(ttl ...int) (t int) { 50 | if len(ttl) > 0 { 51 | t = ttl[0] 52 | } else { 53 | t = DefaultTTL 54 | } 55 | return 56 | } 57 | 58 | func init() { 59 | drivers = make(map[string]Instance) 60 | } 61 | 
-------------------------------------------------------------------------------- /cache/cache_test.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "fmt" 5 | "github.com/qit-team/snow-core/config" 6 | "github.com/qit-team/snow-core/redis" 7 | "testing" 8 | ) 9 | 10 | func init() { 11 | redisConf := config.RedisConfig{ 12 | Master: config.RedisBaseConfig{ 13 | Host: "127.0.0.1", 14 | Port: 6379, 15 | }, 16 | } 17 | 18 | //注册redis类 19 | err := redis.Pr.Register("redis", redisConf) 20 | if err != nil { 21 | fmt.Println(err) 22 | } 23 | //Register("redis", getRedisCache) 24 | Register("mock", getMockCache) 25 | } 26 | 27 | func getRedisCache(diName string) Cache { 28 | return nil 29 | } 30 | 31 | func getMockCache(diName string) Cache { 32 | return nil 33 | } 34 | 35 | func TestRegister(t *testing.T) { 36 | defer func() { 37 | if e := recover(); e == nil { 38 | t.Errorf("repeat register do not panic") 39 | } 40 | }() 41 | Register("mock", getMockCache) 42 | } 43 | 44 | func TestRegister_EmptyDriver(t *testing.T) { 45 | defer func() { 46 | if e := recover(); e == nil { 47 | t.Errorf("nil driver do not panic") 48 | } 49 | }() 50 | Register("mock", nil) 51 | } 52 | 53 | func TestGetCache_Empty(t *testing.T) { 54 | defer func() { 55 | if e := recover(); e == nil { 56 | t.Errorf("unknown driver do not panic") 57 | } 58 | }() 59 | GetCache("redis", "empty") 60 | } 61 | 62 | func TestGetCache_Nil(t *testing.T) { 63 | defer func() { 64 | if e := recover(); e == nil { 65 | t.Errorf("unknown diName do not panic") 66 | } 67 | }() 68 | GetCache("unknown", "mock") 69 | } 70 | -------------------------------------------------------------------------------- /cache/interface.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import "context" 4 | 5 | //缓存驱动接口,所以缓存驱动都需要实现以下接口 6 | type Cache interface { 7 | Get(ctx context.Context, key string) 
(interface{}, error) 8 | GetMulti(ctx context.Context, keys ...string) (map[string]interface{}, error) 9 | Set(ctx context.Context, key string, value interface{}, ttl ...int) (bool, error) 10 | SetMulti(ctx context.Context, items map[string]interface{}, ttl ...int) (bool, error) 11 | Delete(ctx context.Context, key string) (bool, error) 12 | DeleteMulti(ctx context.Context, key ...string) (bool, error) 13 | Expire(ctx context.Context, key string, ttl ...int) (bool, error) 14 | IsExist(ctx context.Context, key string) (bool, error) 15 | IncrBy(ctx context.Context, key string, value int64) (int64, error) 16 | DecrBy(ctx context.Context, key string, value int64) (int64, error) 17 | } 18 | -------------------------------------------------------------------------------- /cache/memorycache/memory_cache.go: -------------------------------------------------------------------------------- 1 | package memorycache 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "github.com/qit-team/snow-core/cache" 8 | "strconv" 9 | "sync" 10 | "time" 11 | ) 12 | 13 | const ( 14 | MaxPersistenceTime = 86400 * 365 * 10 15 | ) 16 | 17 | var ( 18 | mp map[string]cache.Cache 19 | mu sync.RWMutex 20 | ErrWrongDataType error 21 | ) 22 | 23 | type Item struct { 24 | data interface{} 25 | expireAt time.Time 26 | } 27 | 28 | type MemoryCache struct { 29 | items map[string]Item 30 | mu sync.RWMutex 31 | } 32 | 33 | func init() { 34 | ErrWrongDataType = errors.New("wrong data type") 35 | } 36 | 37 | //实例模式 38 | func newMemoryCache() cache.Cache { 39 | m := new(MemoryCache) 40 | m.items = make(map[string]Item) 41 | return m 42 | } 43 | 44 | //单例模式 45 | func GetMemoryCache(diName string) cache.Cache { 46 | key := diName 47 | mu.RLock() 48 | q, ok := mp[key] 49 | mu.RUnlock() 50 | if ok { 51 | return q 52 | } 53 | 54 | q = newMemoryCache() 55 | mu.Lock() 56 | mp[key] = q 57 | mu.Unlock() 58 | return q 59 | } 60 | 61 | /** 62 | * 获取缓存key的数据 63 | * 注意事项,如果key值不存在的话,返回的是空字符串,而不是nil 64 | */ 65 | func 
(c *MemoryCache) Get(ctx context.Context, key string) (interface{}, error) {
	c.mu.RLock()
	item, ok := c.items[key]
	c.mu.RUnlock()
	if ok && inExpire(item.expireAt) {
		return item.data, nil
	}
	// Missing and expired keys both read as "" rather than nil.
	return "", nil
}

// GetMulti returns one entry per requested key. Keys that are missing or expired
// map to "" (the same sentinel Get uses) instead of being omitted.
func (c *MemoryCache) GetMulti(ctx context.Context, keys ...string) (map[string]interface{}, error) {
	arr := make(map[string]interface{}, len(keys))
	c.mu.RLock()
	defer c.mu.RUnlock()
	for _, key := range keys {
		if item, ok := c.items[key]; ok && inExpire(item.expireAt) {
			arr[key] = item.data
		} else {
			arr[key] = ""
		}
	}
	return arr, nil
}

// Set stores value under key. The TTL in seconds is resolved by
// cache.GetTTLOrDefault; a resolved TTL of 0 keeps the value for
// MaxPersistenceTime seconds. It always returns (true, nil).
func (c *MemoryCache) Set(ctx context.Context, key string, value interface{}, ttl ...int) (bool, error) {
	item := Item{
		data:     value,
		expireAt: ttlToExpireAt(ttl...),
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	c.items[key] = item
	return true, nil
}

// SetMulti stores every entry of items with one shared deadline, computed the
// same way as in Set. It always returns (true, nil).
func (c *MemoryCache) SetMulti(ctx context.Context, items map[string]interface{}, ttl ...int) (bool, error) {
	expireAt := ttlToExpireAt(ttl...)
	c.mu.Lock()
	defer c.mu.Unlock()
	for key, value := range items {
		c.items[key] = Item{data: value, expireAt: expireAt}
	}
	return true, nil
}

// Delete removes key. It returns (true, nil) whether or not the key existed.
func (c *MemoryCache) Delete(ctx context.Context, key string) (bool, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.items, key) // no-op when key is absent
	return true, nil
}

// DeleteMulti removes every key in keys. It always returns (true, nil).
func (c *MemoryCache) DeleteMulti(ctx context.Context, keys ...string) (bool, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	for _, key := range keys {
		delete(c.items, key)
	}
	return true, nil
}

// Expire resets the deadline of a live key to now + ttl seconds (ttl resolved by
// cache.GetTTLOrDefault). A key that has already expired is removed instead.
// It always returns (true, nil).
func (c *MemoryCache) Expire(ctx context.Context, key string, ttl ...int) (bool, error) {
	t := cache.GetTTLOrDefault(ttl...)
	// ttl is expressed in seconds; without the time.Second factor the key would
	// expire after t nanoseconds.
	expireAt := time.Now().Add(time.Duration(t) * time.Second)
	c.mu.Lock()
	defer c.mu.Unlock()
	if item, ok := c.items[key]; ok {
		if inExpire(item.expireAt) {
			item.expireAt = expireAt
			c.items[key] = item
		} else {
			delete(c.items, key)
		}
	}
	return true, nil
}

// IsExist reports whether key is present and not yet expired.
func (c *MemoryCache) IsExist(ctx context.Context, key string) (bool, error) {
	c.mu.RLock()
	item, ok := c.items[key]
	c.mu.RUnlock()
	return ok && inExpire(item.expireAt), nil
}

// IncrBy adds value to the integer stored at key and returns the new value.
// A missing or expired key starts from 0 and is stored for MaxPersistenceTime
// seconds. If the stored value cannot be read as an integer it returns
// ErrWrongDataType and leaves the entry unchanged.
func (c *MemoryCache) IncrBy(ctx context.Context, key string, value int64) (int64, error) {
	// Read-modify-write of the map: this needs the exclusive lock, not RLock.
	c.mu.Lock()
	defer c.mu.Unlock()
	item, ok := c.items[key]
	if !ok || !inExpire(item.expireAt) {
		c.items[key] = Item{
			data:     value,
			expireAt: time.Now().Add(time.Duration(MaxPersistenceTime) * time.Second),
		}
		return value, nil
	}
	current, err := interfaceToInt64(item.data)
	if err != nil {
		return 0, ErrWrongDataType
	}
	newValue := current + value
	item.data = newValue
	c.items[key] = item // keeps the existing deadline
	return newValue, nil
}

// DecrBy subtracts value from the integer stored at key; see IncrBy.
func (c *MemoryCache) DecrBy(ctx context.Context, key string, value int64) (int64, error) {
	return c.IncrBy(ctx, key, -value)
}

// ttlToExpireAt converts an optional TTL in seconds into an absolute deadline.
// A resolved TTL of 0 maps to MaxPersistenceTime, as Set and SetMulti expect.
func ttlToExpireAt(ttl ...int) time.Time {
	t := cache.GetTTLOrDefault(ttl...)
	if t == 0 {
		t = MaxPersistenceTime
	}
	return time.Now().Add(time.Duration(t) * time.Second)
}

// inExpire reports whether deadline u is still in the future.
func inExpire(u time.Time) bool {
	return time.Now().Before(u)
}

// interfaceToInt64 converts a stored cache value to int64.
// Built-in signed integer types are converted directly. Decimal strings and byte
// slices are parsed, so a value written as "10" can be incremented, as in the
// Redis driver. Any other value goes through %d formatting, which only yields a
// parseable number for integer kinds (including unsigned and named integer types).
func interfaceToInt64(value interface{}) (int64, error) {
	switch v := value.(type) {
	case int64:
		return v, nil
	case int:
		return int64(v), nil
	case int32:
		return int64(v), nil
	case int16:
		return int64(v), nil
	case int8:
		return int64(v), nil
	case string:
		return strconv.ParseInt(v, 10, 64)
	case []byte:
		return strconv.ParseInt(string(v), 10, 64)
	}
	// ParseInt (not Atoi) keeps the full 64-bit range on 32-bit platforms and
	// rejects unsigned values that do not fit in int64.
	return strconv.ParseInt(fmt.Sprintf("%d", value), 10, 64)
}

func init() {
	mp = make(map[string]cache.Cache)
	cache.Register(cache.DriverTypeMemory, GetMemoryCache)
}

--------------------------------------------------------------------------------
/cache/memorycache/memory_cache_test.go:
--------------------------------------------------------------------------------
package memorycache

import (
	"context"
	"errors"
	"github.com/qit-team/snow-core/cache"
	"testing"
	"time"
)

// c is the cache under test, obtained from the cache registry with the memory driver.
var c cache.Cache

func init() {
	c = cache.GetCache("memory", cache.DriverTypeMemory)
}

// TestGetSetDelete checks a Set/Get round trip and that Get returns "" after Delete.
func TestGetSetDelete(t *testing.T) {
	ctx := context.TODO()
	key := "test-cache"
	value := "111"
	ok, err := c.Set(ctx, key, value)
	if err != nil {
		t.Error(err)
		return
	} else if !ok {
		t.Error("set is not ok")
		return
	}

	v, err := c.Get(ctx, key)
	if err != nil {
		t.Error(err)
		return
	} else if v != value {
		t.Error("get is not same", v)
		return
	}

	ok, err = c.Delete(ctx, key)
	if err != nil {
		t.Error(err)
		return
	} else if !ok {
		t.Error("delete is not ok")
		return
	}

	v, err = c.Get(ctx, key)
	if err != nil {
		t.Error(err)
		return
	} else if v != "" {
		t.Errorf("delete %s failed", key)
		return
	}
}

// TestSetMultiAndGetMulti stores two keys with a 1-second TTL, checks they read
// back unchanged, then waits past the TTL and expects "" for both keys.
func TestSetMultiAndGetMulti(t *testing.T) {
	ctx := context.TODO()
	items := map[string]interface{}{
		"test-key1": "111",
		"test-key2": "222",
	}
	_, err := c.SetMulti(ctx, items, 1)
	if err != nil {
		t.Error(err)
		return
	}

	m, err := c.GetMulti(ctx, "test-key1", "test-key2")
	if err != nil {
		t.Error(err)
		return
	} else if len(m) != 2 {
		t.Error("get values's length is not enough")
		return
	}
	var value interface{}
	var ok bool
	for k, v := range m {
		if value, ok = items[k]; !ok {
			t.Errorf("key %s is not exist", k)
			return
		}
		if value != v {
			t.Errorf("key %s is not same", k)
			return
		}
	}

	// Sleep slightly longer than the 1s TTL so both keys have expired.
	time.Sleep(time.Millisecond * 1100)
	m, err = c.GetMulti(ctx, "test-key1", "test-key2")
	if err != nil {
		t.Error(err)
		return
	} else if len(m) != 2 {
		t.Error("get values's length is not enough")
		return
	}

	for k, v := range m {
		if _, ok = items[k]; !ok {
			t.Errorf("key %s is not exist", k)
			return
		}
		if v != "" {
			t.Errorf("key %s is not empty", k)
			return
		}
	}
}

// TestDeleteMulti checks that GetMulti still returns an entry per key after
// DeleteMulti, with every value read as "".
func TestDeleteMulti(t *testing.T) {
	ctx := context.TODO()
	items := map[string]interface{}{
		"test-key3": "111",
		"test-key4": "222",
	}

	c.SetMulti(ctx, items)

	_, err := c.DeleteMulti(ctx, "test-key3", "test-key4")
	if err != nil {
		t.Error(err)
		return
	}

	var ok bool
	m, err := c.GetMulti(ctx, "test-key3", "test-key4")
	if err != nil {
		t.Error(err)
		return
	} else if len(m) != 2 {
		t.Error("get values's length is not enough")
		return
	}

	for k, v := range m {
		if _, ok = items[k]; !ok {
			t.Errorf("key %s is not exist", k)
			return
		}
		if v != "" {
			t.Errorf("key %s is not empty", k)
			return
		}
	}
}

// TestExpireExist checks IsExist before and after shortening a key's TTL to 1s via Expire.
func TestExpireExist(t *testing.T) {
	ctx := context.TODO()
	key := "test-expire"
	value := "222"
	c.Set(ctx, key, value)

	ok, err := c.IsExist(ctx, key)
	if err != nil {
		t.Error(err)
		return
	} else if !ok {
		t.Errorf("key %s is not exist", key)
		return
	}

	c.Expire(ctx, key, 1)
	time.Sleep(time.Millisecond * 1100)

	ok, err = c.IsExist(ctx, key)
	if err != nil {
		t.Error(err)
		return
	} else if ok {
		t.Errorf("key %s is exist", key)
		return
	}
}

// TestMemoryCache_IncrBy covers three cases: a non-numeric value
// (ErrWrongDataType), an int value (400+3), and a deleted key (starts from 0).
func TestMemoryCache_IncrBy(t *testing.T) {
	ctx := context.TODO()
	key := "test-incr"
	value := "ab"
	c.Set(ctx, key, value)

	_, err := c.IncrBy(ctx, key, 3)
	if !errors.Is(err, ErrWrongDataType) {
		t.Errorf("wrong error type %s", err)
		return
	}

	c.Set(ctx, key, 400)
	res, err := c.IncrBy(ctx, key, 3)
	if err != nil {
		t.Error(err)
		return
	} else if res != 403 {
		t.Errorf("wrong increment %d", res)
		return
	}

	c.Delete(ctx, key)
	res, err = c.IncrBy(ctx, key, -30)
	if err != nil {
		t.Error(err)
		return
	} else if res != -30 {
		t.Errorf("wrong increment %d", res)
		return
	}
}

// TestMemoryCache_DecrBy mirrors TestMemoryCache_IncrBy for DecrBy.
func TestMemoryCache_DecrBy(t *testing.T) {
	ctx := context.TODO()
	key := "test-desc"
	value := "ab"
	c.Set(ctx, key, value)

	_, err := c.DecrBy(ctx, key, 10)
	if !errors.Is(err, ErrWrongDataType) {
		t.Errorf("wrong error type %s", err)
		return
	}

	c.Set(ctx, key, 400)
	res, err := c.DecrBy(ctx, key, 10)
	if err != nil {
		t.Error(err)
		return
	} else if res != 390 {
		t.Errorf("wrong decrement %d", res)
		return
	}

	c.Delete(ctx, key)
	res, err = c.DecrBy(ctx, key, -30)
	if err != nil {
		t.Error(err)
		return
	} else if res != 30 {
		t.Errorf("wrong decrement %d", res)
		return
	}
}

--------------------------------------------------------------------------------
/cache/rediscache/redis_cache.go:
--------------------------------------------------------------------------------
package rediscache

import (
	"context"
	goredis "github.com/go-redis/redis/v8"
	"github.com/qit-team/snow-core/cache"
	"github.com/qit-team/snow-core/redis"
	"sync"
	"time"
)

var (
	// mp holds one RedisCache per DI name; mu guards it.
	mp map[string]cache.Cache
	mu sync.RWMutex
)

// RedisCache implements cache.Cache on top of a go-redis client.
type RedisCache struct {
	client *goredis.Client
}

// newRedisCache builds a new, unshared RedisCache wrapping the client returned
// by redis.GetRedis(diName).
func newRedisCache(diName string) cache.Cache {
	m := new(RedisCache)
	m.client = redis.GetRedis(diName)
	return m
}

// GetRedisCache (singleton accessor): returns the RedisCache cached in mp for diName, creating it on first use.
29 | func GetRedisCache(diName string) cache.Cache { 30 | key := diName 31 | mu.RLock() 32 | q, ok := mp[key] 33 | mu.RUnlock() 34 | if ok { 35 | return q 36 | } 37 | 38 | q = newRedisCache(diName) 39 | mu.Lock() 40 | mp[key] = q 41 | mu.Unlock() 42 | return q 43 | } 44 | 45 | /** 46 | * 获取缓存key的数据 47 | * 注意事项,如果key值不存在的话,返回的是空字符串,而不是nil 48 | */ 49 | func (c *RedisCache) Get(ctx context.Context, key string) (interface{}, error) { 50 | value, err := c.client.Get(ctx, key).Result() 51 | if err == goredis.Nil { 52 | return "", nil 53 | } 54 | return value, err 55 | } 56 | 57 | func (c *RedisCache) GetMulti(ctx context.Context, keys ...string) (map[string]interface{}, error) { 58 | values, err := c.client.MGet(ctx, keys...).Result() 59 | if err != nil { 60 | return nil, err 61 | } 62 | 63 | arr := make(map[string]interface{}) 64 | for index, key := range keys { 65 | arr[key] = values[index] 66 | } 67 | return arr, nil 68 | } 69 | 70 | func (c *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...int) (bool, error) { 71 | t := cache.GetTTLOrDefault(ttl...) 72 | _, err := c.client.SetEX(ctx, key, value, time.Duration(t)*time.Second).Result() 73 | if err != nil { 74 | return false, nil 75 | } 76 | return true, nil 77 | } 78 | 79 | func (c *RedisCache) SetMulti(ctx context.Context, items map[string]interface{}, ttl ...int) (bool, error) { 80 | arr := make([]interface{}, 0) 81 | for key, value := range items { 82 | arr = append(arr, key, value) 83 | } 84 | _, err := c.client.MSet(ctx, arr...).Result() 85 | if err != nil { 86 | return false, err 87 | } 88 | 89 | t := cache.GetTTLOrDefault(ttl...) 
90 | if t > 0 { 91 | t64 := int64(t) 92 | for key, _ := range items { 93 | c.client.Expire(ctx, key, time.Duration(t64)*time.Second) 94 | } 95 | } 96 | return true, nil 97 | } 98 | 99 | func (c *RedisCache) Delete(ctx context.Context, key string) (bool, error) { 100 | res, err := c.client.Del(ctx, key).Result() 101 | return res > 0, err 102 | } 103 | 104 | func (c *RedisCache) DeleteMulti(ctx context.Context, keys ...string) (bool, error) { 105 | res, err := c.client.Del(ctx, keys...).Result() 106 | return res > 0, err 107 | } 108 | 109 | func (c *RedisCache) Expire(ctx context.Context, key string, ttl ...int) (bool, error) { 110 | t := cache.GetTTLOrDefault(ttl...) 111 | return c.client.Expire(ctx, key, time.Duration(t)*time.Second).Result() 112 | } 113 | 114 | func (c *RedisCache) IsExist(ctx context.Context, key string) (bool, error) { 115 | num, err := c.client.Exists(ctx, key).Result() 116 | return num == 1, err 117 | } 118 | 119 | func convert(keys []string) []interface{} { 120 | arr := make([]interface{}, len(keys)) 121 | for i, v := range keys { 122 | arr[i] = v 123 | } 124 | return arr 125 | } 126 | 127 | func (c *RedisCache) IncrBy(ctx context.Context, key string, value int64) (int64, error) { 128 | newVal, err := c.client.IncrBy(ctx, key, value).Result() 129 | if err != nil { 130 | return 0, err 131 | } 132 | return newVal, err 133 | } 134 | 135 | func (c *RedisCache) DecrBy(ctx context.Context, key string, value int64) (int64, error) { 136 | newVal, err := c.client.DecrBy(ctx, key, value).Result() 137 | if err != nil { 138 | return 0, err 139 | } 140 | return newVal, err 141 | } 142 | 143 | func init() { 144 | mp = make(map[string]cache.Cache) 145 | cache.Register(cache.DriverTypeRedis, GetRedisCache) 146 | } 147 | -------------------------------------------------------------------------------- /cache/rediscache/redis_cache_test.go: -------------------------------------------------------------------------------- 1 | package rediscache 2 | 3 | import ( 4 
| "context" 5 | "fmt" 6 | "github.com/qit-team/snow-core/cache" 7 | "github.com/qit-team/snow-core/config" 8 | "github.com/qit-team/snow-core/redis" 9 | "github.com/qit-team/snow-core/utils" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | var c cache.Cache 15 | var m *cache.BaseCache 16 | var ctx context.Context 17 | 18 | func init() { 19 | var err error 20 | redisConf := config.RedisConfig{ 21 | Master: config.RedisBaseConfig{ 22 | Host: "127.0.0.1", 23 | Port: 6379, 24 | }, 25 | } 26 | 27 | //注册redis类 28 | err = redis.Pr.Register("redis", redisConf) 29 | if err != nil { 30 | fmt.Println(err) 31 | } 32 | 33 | c = cache.GetCache("redis", cache.DriverTypeRedis) 34 | } 35 | 36 | func TestGetSetDelete(t *testing.T) { 37 | c := cache.GetCache("redis", cache.DriverTypeRedis) 38 | ctx := context.TODO() 39 | key := "test-cache" 40 | value := "111" 41 | ok, err := c.Set(ctx, key, value) 42 | if err != nil { 43 | t.Error(err) 44 | return 45 | } else if !ok { 46 | t.Error("set is not ok") 47 | return 48 | } 49 | 50 | v, err := c.Get(ctx, key) 51 | if err != nil { 52 | t.Error(err) 53 | return 54 | } else if v != value { 55 | t.Error("get is not same", v) 56 | return 57 | } 58 | 59 | ok, err = c.Delete(ctx, key) 60 | if err != nil { 61 | t.Error(err) 62 | return 63 | } else if !ok { 64 | t.Error("delete is not ok") 65 | return 66 | } 67 | 68 | v, err = c.Get(ctx, key) 69 | if err != nil { 70 | t.Error(err) 71 | return 72 | } else if v != "" { 73 | t.Errorf("delete %s failed", key) 74 | return 75 | } 76 | } 77 | 78 | func TestSetMultiAndGetMulti(t *testing.T) { 79 | ctx := context.TODO() 80 | items := map[string]interface{}{ 81 | "test-key1": "111", 82 | "test-key2": "222", 83 | } 84 | _, err := c.SetMulti(ctx, items, 1) 85 | if err != nil { 86 | t.Error(err) 87 | return 88 | } 89 | 90 | m, err := c.GetMulti(ctx, "test-key1", "test-key2") 91 | if err != nil { 92 | t.Error(err) 93 | return 94 | } else if len(m) != 2 { 95 | t.Error("get values's length is not enough") 96 | return 97 
| } 98 | var value interface{} 99 | var ok bool 100 | for k, v := range m { 101 | if value, ok = items[k]; !ok { 102 | t.Errorf("key %s is not exist", k) 103 | return 104 | } 105 | if value != v { 106 | t.Errorf("key %s is not same", k) 107 | return 108 | } 109 | } 110 | 111 | time.Sleep(time.Millisecond * 1100) 112 | m, err = c.GetMulti(ctx, "test-key1", "test-key2") 113 | if err != nil { 114 | t.Error(err) 115 | return 116 | } else if len(m) != 2 { 117 | t.Error("get values's length is not enough") 118 | return 119 | } 120 | 121 | for k, v := range m { 122 | if _, ok = items[k]; !ok { 123 | t.Errorf("key %s is not exist", k) 124 | return 125 | } 126 | if v != "" { 127 | t.Errorf("key %s is not empty", k) 128 | return 129 | } 130 | } 131 | } 132 | 133 | func TestDeleteMulti(t *testing.T) { 134 | ctx := context.TODO() 135 | items := map[string]interface{}{ 136 | "test-key3": "111", 137 | "test-key4": "222", 138 | } 139 | 140 | c.SetMulti(ctx, items) 141 | 142 | _, err := c.DeleteMulti(ctx, "test-key3", "test-key4") 143 | if err != nil { 144 | t.Error(err) 145 | return 146 | } 147 | 148 | var ok bool 149 | m, err := c.GetMulti(ctx, "test-key3", "test-key4") 150 | if err != nil { 151 | t.Error(err) 152 | return 153 | } else if len(m) != 2 { 154 | t.Error("get values's length is not enough") 155 | return 156 | } 157 | 158 | for k, v := range m { 159 | if _, ok = items[k]; !ok { 160 | t.Errorf("key %s is not exist", k) 161 | return 162 | } 163 | if v != "" { 164 | t.Errorf("key %s is not empty", k) 165 | return 166 | } 167 | } 168 | } 169 | 170 | func TestExpireExist(t *testing.T) { 171 | ctx := context.TODO() 172 | key := "test-expire" 173 | value := "222" 174 | c.Set(ctx, key, value) 175 | 176 | ok, err := c.IsExist(ctx, key) 177 | if err != nil { 178 | t.Error(err) 179 | return 180 | } else if !ok { 181 | t.Errorf("key %s is not exist", key) 182 | return 183 | } 184 | 185 | c.Expire(ctx, key, 1) 186 | time.Sleep(time.Millisecond * 1100) 187 | 188 | ok, err = 
c.IsExist(ctx, key) 189 | if err != nil { 190 | t.Error(err) 191 | return 192 | } else if ok { 193 | t.Errorf("key %s is exist", key) 194 | return 195 | } 196 | } 197 | 198 | // 测试basecache 199 | func TestBaseCache_Get_Set_IsExist(t *testing.T) { 200 | m := new(cache.BaseCache) 201 | key := "test-snow-" + fmt.Sprint(utils.GetCurrentTime()) 202 | 203 | // key不存在情况下读取数据 204 | s, err := m.Get(ctx, key) 205 | if err != nil { 206 | t.Errorf("Get %s err:%s", key, err.Error()) 207 | return 208 | } else if s != "" { 209 | t.Errorf("Get %s is not empty", key) 210 | return 211 | } 212 | 213 | // 判断key是否存在 214 | ok, err := m.IsExist(ctx, key) 215 | if err != nil { 216 | t.Errorf("IsExist %s err:%s", key, err.Error()) 217 | return 218 | } else if ok { 219 | t.Errorf("IsExist %s is not equal false", key) 220 | return 221 | } 222 | 223 | value := "1" 224 | // 对key进行set操作且过期时间1秒 225 | ok, err = m.Set(ctx, key, value, 1) 226 | if err != nil { 227 | t.Errorf("Set %s err:%s", key, err.Error()) 228 | return 229 | } else if !ok { 230 | t.Errorf("Set %s is not ok", key) 231 | return 232 | } 233 | 234 | // set完之后马上执行get操作 235 | s, _ = m.Get(ctx, key) 236 | if s != value { 237 | t.Errorf("Get %s value(%s) is not equal %s", key, s, value) 238 | return 239 | } 240 | 241 | time.Sleep(time.Second) 242 | 243 | // 一秒之后再取值,因为set时候设置过期时间为1s,如果拿不到值是正常情况 244 | s, _ = m.Get(ctx, key) 245 | if s != "" { 246 | t.Errorf("Get %s is not empty", key) 247 | return 248 | } 249 | } 250 | 251 | func TestBaseCache_Delete(t *testing.T) { 252 | m := new(cache.BaseCache) 253 | key := "test-snow1" + fmt.Sprint(utils.GetCurrentTime()) 254 | value := "1" 255 | m.Set(ctx, key, value) 256 | 257 | ok, err := m.Delete(ctx, key) 258 | if err != nil { 259 | t.Errorf("Delete %s err:%s", key, err.Error()) 260 | return 261 | } else if !ok { 262 | t.Errorf("Delete %s is not ok", key) 263 | return 264 | } 265 | 266 | s, _ := m.Get(ctx, key) 267 | if s != "" { 268 | t.Errorf("Get %s is not empty", key) 269 | return 270 | } 271 
| } 272 | 273 | func TestBaseCache_SetMulti_GetMulti_DeleteMulti(t *testing.T) { 274 | m := new(cache.BaseCache) 275 | 276 | time := fmt.Sprint(utils.GetCurrentTime()) 277 | key2 := "test2-snow" + time 278 | key3 := "test3-snow" + time 279 | value := "1" 280 | 281 | items := map[string]interface{}{ 282 | key2: value, 283 | key3: value, 284 | } 285 | ok, err := m.SetMulti(ctx, items, 1) 286 | if err != nil { 287 | t.Errorf("SetMulti err:%s", err.Error()) 288 | return 289 | } else if !ok { 290 | t.Errorf("SetMulti is not ok, keys:%s, %s", key2, key3) 291 | return 292 | } 293 | 294 | retMulti, err := m.GetMulti(ctx, key2, key3) 295 | if err != nil { 296 | t.Errorf("GetMulti err:%s", err.Error()) 297 | return 298 | } else { 299 | for k, v := range retMulti { 300 | if v != value { 301 | t.Errorf("GetMulti %s value(%s) is not equal %s", k, v, value) 302 | return 303 | } 304 | } 305 | } 306 | 307 | ok, err = m.DeleteMulti(ctx, key2, key3) 308 | if err != nil { 309 | t.Errorf("DeleteMulti err:%s", err.Error()) 310 | return 311 | } else if !ok { 312 | t.Errorf("DeleteMulti is not ok, keys:%s, %s", key2, key3) 313 | return 314 | } 315 | 316 | retMulti, err = m.GetMulti(ctx, key2, key3) 317 | if err != nil { 318 | t.Errorf("GetMulti After Delete err:%s", err.Error()) 319 | return 320 | } else { 321 | for k, v := range retMulti { 322 | if v != "" { 323 | t.Errorf("GetMulti After Delete %s value(%s) is not empty", k, v) 324 | return 325 | } 326 | } 327 | } 328 | } 329 | 330 | func TestBaseCache_Expire(t *testing.T) { 331 | m := new(cache.BaseCache) 332 | key := "test-snow5-" + fmt.Sprint(utils.GetCurrentTime()) 333 | 334 | value := "1" 335 | // 对key进行set操作且不设置过期时间 336 | ok, err := m.Set(ctx, key, value) 337 | if err != nil { 338 | t.Errorf("Set %s err:%s", key, err.Error()) 339 | return 340 | } else if !ok { 341 | t.Errorf("Set %s is not ok", key) 342 | return 343 | } 344 | 345 | // 通expire函数设置过期时间 346 | ok, err = m.Expire(ctx, key, 1) 347 | if err != nil { 348 | 
t.Errorf("Expire %s err:%s", key, err.Error()) 349 | return 350 | } else if !ok { 351 | t.Errorf("Expire %s is not ok", key) 352 | return 353 | } 354 | 355 | // set完之后马上执行get操作 356 | s, _ := m.Get(ctx, key) 357 | if s != value { 358 | t.Errorf("Get after expire %s value(%s) is not equal %s", key, s, value) 359 | return 360 | } 361 | 362 | time.Sleep(time.Second) 363 | 364 | // 一秒之后再取值,expire设置的过期时间为1s,如果拿不到值是正常情况 365 | s, _ = m.Get(ctx, key) 366 | if s != "" { 367 | t.Errorf("Get after expire and wait 1s %s is not empty", key) 368 | return 369 | } 370 | } 371 | 372 | func TestRedisCache_DecrBy(t *testing.T) { 373 | m := new(cache.BaseCache) 374 | key := "test-snow-decr-and-incr-1-" + fmt.Sprint(utils.GetCurrentTime()) 375 | value := 10 376 | ok, err := m.Set(ctx, key, value) 377 | if err != nil { 378 | t.Errorf("Set %s err:%s", key, err.Error()) 379 | return 380 | } else if !ok { 381 | t.Errorf("Set %s is not ok", key) 382 | return 383 | } 384 | 385 | ret, err := m.DecrBy(ctx, key, 3) 386 | if err != nil { 387 | t.Errorf("Decr %s error:%s", key, err.Error()) 388 | } 389 | if ret != 7 { 390 | t.Errorf("Decr ret %s is not ok", key) 391 | } 392 | 393 | // test value is type of string 394 | keyStr := "test-snow-decr-and-incr-2-" + fmt.Sprint(utils.GetCurrentTime()) 395 | valueStr := "10" 396 | ok, err = m.Set(ctx, keyStr, valueStr) 397 | if err != nil { 398 | t.Errorf("Set %s err:%s", keyStr, err.Error()) 399 | return 400 | } else if !ok { 401 | t.Errorf("Set %s is not ok", keyStr) 402 | return 403 | } 404 | 405 | retStr, err := m.DecrBy(ctx, keyStr, 3) 406 | if err != nil { 407 | t.Errorf("DecrStr %s error:%s", keyStr, err.Error()) 408 | } 409 | 410 | if retStr != 7 { 411 | t.Errorf("DecrStr ret %s is not ok", keyStr) 412 | } 413 | } 414 | 415 | func TestRedisCache_IncrBy(t *testing.T) { 416 | m := new(cache.BaseCache) 417 | key := "test-snow-decr-and-incr-1-" + fmt.Sprint(utils.GetCurrentTime()) 418 | value := 10 419 | ok, err := m.Set(ctx, key, value) 420 | if err 
!= nil { 421 | t.Errorf("Set %s err:%s", key, err.Error()) 422 | return 423 | } else if !ok { 424 | t.Errorf("Set %s is not ok", key) 425 | return 426 | } 427 | 428 | ret, err := m.IncrBy(ctx, key, 6) 429 | if err != nil { 430 | t.Errorf("Incr %s error:%s", key, err.Error()) 431 | } 432 | if ret != 16 { 433 | t.Errorf("Incr ret %s is not ok", key) 434 | } 435 | 436 | // test value is type of string 437 | keyStr := "test-snow-decr-and-incr-2-" + fmt.Sprint(utils.GetCurrentTime()) 438 | valueStr := "10" 439 | ok, err = m.Set(ctx, keyStr, valueStr) 440 | if err != nil { 441 | t.Errorf("Set %s err:%s", keyStr, err.Error()) 442 | return 443 | } else if !ok { 444 | t.Errorf("Set %s is not ok", keyStr) 445 | return 446 | } 447 | 448 | retStr, err := m.IncrBy(ctx, keyStr, 6) 449 | if err != nil { 450 | t.Errorf("IncrStr %s error:%s", keyStr, err.Error()) 451 | } 452 | 453 | if retStr != 16 { 454 | t.Errorf("IncrStr ret %s is not ok", keyStr) 455 | } 456 | } 457 | -------------------------------------------------------------------------------- /command/command.go: -------------------------------------------------------------------------------- 1 | package command 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | ) 7 | 8 | var ( 9 | ErrUnknownName = errors.New("unknown name") 10 | ) 11 | 12 | //一次性任务脚本 13 | type Command struct { 14 | mu sync.RWMutex 15 | container map[string]func() 16 | } 17 | 18 | //new实例 19 | func New() *Command { 20 | c := new(Command) 21 | c.container = make(map[string]func()) 22 | return c 23 | } 24 | 25 | //绑定name与函数的关系 26 | func (c *Command) AddFunc(name string, f func()) { 27 | c.mu.Lock() 28 | defer c.mu.Unlock() 29 | c.container[name] = f 30 | } 31 | 32 | //通过name执行函数 33 | func (c *Command) Execute(name string) (err error) { 34 | c.mu.RLock() 35 | f, ok := c.container[name] 36 | c.mu.RUnlock() 37 | if ok { 38 | f() 39 | } else { 40 | panic(ErrUnknownName.Error()) 41 | } 42 | return 43 | } 44 | 
--------------------------------------------------------------------------------
/command/command_test.go:
--------------------------------------------------------------------------------
package command

import (
	"fmt"
	"testing"
)

// TestNew registers and runs a command, then checks that executing an unknown
// name panics. The deferred recover fails the test if no panic occurs.
func TestNew(t *testing.T) {
	cmd := New()
	cmd.AddFunc("test", test)
	cmd.Execute("test")

	defer func() {
		if e := recover(); e == nil {
			t.Error("unknown name do not panic")
		}
	}()
	cmd.Execute("test1")
}

// test is the command registered by TestNew.
func test() {
	fmt.Println("run test")
}

--------------------------------------------------------------------------------
/config/config.go:
--------------------------------------------------------------------------------
package config

import (
	"time"

	"github.com/apache/rocketmq-client-go/v2/consumer"
	"github.com/apache/rocketmq-client-go/v2/producer"
)

// RedisBaseConfig is the address and credentials of one Redis node.
type RedisBaseConfig struct {
	Host     string
	Port     int
	Password string
	DB       int // database index, defaults to 0
}

// RedisOptionConfig holds pool and timeout options for a Redis client.
// NOTE(review): how these fields, and the units of the Duration fields, are
// consumed lives in the redis package, which is not shown here; confirm there.
type RedisOptionConfig struct {
	MaxIdle        int
	MaxConns       int
	Wait           bool
	IdleTimeout    time.Duration
	ConnectTimeout time.Duration
	ReadTimeout    time.Duration
	WriteTimeout   time.Duration
}

// RedisConfig describes a Redis master, optional slaves, and shared options.
type RedisConfig struct {
	Master RedisBaseConfig
	Slaves []RedisBaseConfig
	Option RedisOptionConfig
}

// DbBaseConfig is the address, credentials and database name of one DB node.
type DbBaseConfig struct {
	Host     string
	Port     int
	User     string
	Password string
	DBName   string
}

// DbOptionConfig holds pool and connection options shared by all DB nodes.
// Package db multiplies IdleTimeout by time.Second and writes ConnectTimeout into
// the MySQL DSN as "<n>s", so both are used as plain second counts (e.g. 10),
// not as time.Duration values such as 10*time.Second.
type DbOptionConfig struct {
	MaxIdle        int
	MaxConns       int
	IdleTimeout    time.Duration
	ConnectTimeout time.Duration
	Charset        string
}

// DbConfig describes a database master, optional slaves, and shared options.
type DbConfig struct {
	Driver string // driver type; currently mysql, postgres, mssql and sqlite3 are supported
	Master DbBaseConfig
	Slaves []DbBaseConfig
	Option DbOptionConfig
}

// MnsConfig holds the endpoint and credentials for Aliyun MNS.
type MnsConfig struct {
	Url             string
	AccessKeyId     string
	AccessKeySecret string
}
62 | type LogConfig struct { 63 | Handler string 64 | Level string 65 | Segment bool 66 | Dir string 67 | FileName string 68 | } 69 | 70 | type ApiConfig struct { 71 | Host string 72 | Port int 73 | } 74 | 75 | type AliyunMqConfig struct { 76 | EndPoint string 77 | AccessKey string 78 | SecretKey string 79 | } 80 | 81 | type RocketMqConfig struct { 82 | EndPoint string 83 | AccessKey string 84 | SecretKey string 85 | GroupId string 86 | InstanceId string 87 | ConsumerOptions []consumer.Option 88 | ProducerOptions []producer.Option 89 | } 90 | -------------------------------------------------------------------------------- /db/db.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | //_ "github.com/go-sql-driver/mysql" 5 | //_ "github.com/lib/pq" //postgres 6 | //_ "github.com/mattn/go-sqlite3" //sqlite3 7 | //_ "github.com/denisenkom/go-mssqldb" //mssql 8 | "errors" 9 | "fmt" 10 | "github.com/qit-team/snow-core/config" 11 | "time" 12 | "xorm.io/core" 13 | "xorm.io/xorm" 14 | ) 15 | 16 | const ( 17 | defaultTimeout = 10 18 | defaultCharset = "utf8mb4" 19 | ) 20 | 21 | func NewEngineGroup(dbConf config.DbConfig) (*xorm.EngineGroup, error) { 22 | master, err := newConn(dbConf.Driver, dbConf.Master, dbConf.Option) 23 | if err != nil { 24 | panicConnectionErr(dbConf.Driver, dbConf.Master.Host, dbConf.Master.Port, err) 25 | } 26 | 27 | slaves := make([]*xorm.Engine, len(dbConf.Slaves)) 28 | for k, slaveConf := range dbConf.Slaves { 29 | slave, err := newConn(dbConf.Driver, slaveConf, dbConf.Option) 30 | if err != nil { 31 | panicConnectionErr(dbConf.Driver, slaveConf.Host, slaveConf.Port, err) 32 | } 33 | slaves[k] = slave 34 | } 35 | 36 | return xorm.NewEngineGroup(master, slaves) 37 | } 38 | 39 | func newConn(driver string, base config.DbBaseConfig, option config.DbOptionConfig) (db *xorm.Engine, err error) { 40 | dsn := formatDSN(driver, base, option) 41 | if dsn == "" { 42 | return nil, 
errors.New(fmt.Sprintf("missing db driver %s or db config", driver)) 43 | } 44 | db, err = xorm.NewEngine(driver, dsn) 45 | if err != nil { 46 | return 47 | } 48 | 49 | //设置表名和字段的映射规则:驼峰转下划线 50 | db.SetMapper(core.SnakeMapper{}) 51 | 52 | //设置资源池等配置 53 | if option.MaxIdle > 0 { 54 | db.SetMaxIdleConns(option.MaxIdle) 55 | } 56 | if option.MaxConns > 0 { 57 | db.SetMaxOpenConns(option.MaxConns) 58 | } 59 | if option.IdleTimeout > 0 { 60 | db.SetConnMaxLifetime(time.Second * option.IdleTimeout) 61 | } 62 | return 63 | } 64 | 65 | /** 66 | * 各驱动的dsn 67 | * @wiki http://gobook.io/read/github.com/go-xorm/manual-zh-CN/chapter-01/ 68 | */ 69 | func formatDSN(driver string, base config.DbBaseConfig, option config.DbOptionConfig) string { 70 | switch driver { 71 | case "mysql": 72 | return formatMysqlDSN(base, option) 73 | case "postgres": 74 | return formatPostgresDSN(base, option) 75 | case "sqlite3": 76 | return formatSqlite3DSN(base, option) 77 | case "mssql": 78 | return formatMssqlDSN(base, option) 79 | } 80 | return "" 81 | } 82 | 83 | //Mysql DSN 84 | func formatMysqlDSN(base config.DbBaseConfig, option config.DbOptionConfig) string { 85 | port := getPortOrDefault(base.Port, 3306) 86 | charset := option.Charset 87 | if charset == "" { 88 | charset = defaultCharset 89 | } 90 | timeout := option.ConnectTimeout 91 | if timeout <= 0 { 92 | timeout = defaultTimeout 93 | } 94 | return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?timeout=%ds&charset=%s&parseTime=true&loc=Local", 95 | base.User, base.Password, base.Host, port, base.DBName, timeout, charset) 96 | } 97 | 98 | //PostgreSQL DSN 99 | func formatPostgresDSN(base config.DbBaseConfig, option config.DbOptionConfig) string { 100 | port := getPortOrDefault(base.Port, 5432) 101 | return fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s", 102 | base.Host, port, base.User, base.DBName, base.Password) 103 | } 104 | 105 | //qlite3 DSN 106 | func formatSqlite3DSN(base config.DbBaseConfig, option config.DbOptionConfig) string { 
107 | return base.DBName 108 | } 109 | 110 | //SQL Server DSN 111 | func formatMssqlDSN(base config.DbBaseConfig, option config.DbOptionConfig) string { 112 | port := getPortOrDefault(base.Port, 1433) 113 | return fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s", 114 | base.User, base.Password, base.Host, port, base.DBName) 115 | } 116 | 117 | func getPortOrDefault(port int, defaultPort int) int { 118 | if port == 0 { 119 | return defaultPort 120 | } 121 | return port 122 | } 123 | 124 | func panicConnectionErr(driver string, host string, port int, err error) { 125 | panic(fmt.Sprintf("%s connect error %s:%d, error:%v", driver, host, port, err)) 126 | } 127 | -------------------------------------------------------------------------------- /db/db_test.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "fmt" 5 | "github.com/qit-team/snow-core/config" 6 | "testing" 7 | "xorm.io/xorm" 8 | //go test时需要开启 9 | _ "github.com/go-sql-driver/mysql" 10 | ) 11 | 12 | var engineGroup *xorm.EngineGroup 13 | 14 | /** 15 | * Banner实体 16 | */ 17 | type Banner struct { 18 | Id int64 `xorm:"pk autoincr"` 19 | Pid int 20 | Title string 21 | ImageUrl string `xorm:"'img_url'"` 22 | } 23 | 24 | /** 25 | * 表名规则 26 | */ 27 | func (m *Banner) TableName() string { 28 | return "banner" 29 | } 30 | 31 | func init() { 32 | dbInit(true) 33 | } 34 | 35 | func dbInit(lazyBool bool) { 36 | m := config.DbBaseConfig{ 37 | Host: "127.0.0.1", 38 | Port: 3306, 39 | User: "root", 40 | Password: "Snow_123", 41 | DBName: "test", 42 | } 43 | dbConf := config.DbConfig{ 44 | Driver: "mysql", 45 | Master: m, 46 | } 47 | 48 | err := Pr.Register("db", dbConf, lazyBool) 49 | if err != nil { 50 | fmt.Println(err) 51 | } 52 | 53 | engineGroup = GetDb() 54 | } 55 | 56 | func TestGet(t *testing.T) { 57 | banner := new(Banner) 58 | // sql是否打印开关 59 | //engineGroup.ShowSQL(true) 60 | _, err := engineGroup.ID(1).Get(banner) 61 | 62 | if err != nil { 63 
| t.Errorf("get error: %v", err) 64 | return 65 | } 66 | 67 | fmt.Println(banner) 68 | } 69 | 70 | func TestProvider_Provides(t *testing.T) { 71 | retList := Pr.Provides() 72 | if len(retList) == 0 { 73 | t.Error("Provides empty") 74 | return 75 | } 76 | 77 | for k, v := range retList { 78 | fmt.Println("Provides list", k, v) 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /db/model.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "errors" 5 | "xorm.io/xorm" 6 | ) 7 | 8 | var ( 9 | ErrIdsEmpty = errors.New("ids is empty") 10 | ) 11 | 12 | /** 13 | * 基础model 14 | */ 15 | type Model struct { 16 | DiName string //依赖注入的别名 17 | } 18 | 19 | /** 20 | * 获取数据库实例 21 | * @wiki http://gobook.io/read/github.com/go-xorm/manual-zh-CN/chapter-02/4.columns.html 22 | */ 23 | func (m *Model) GetDb(args ...string) *xorm.EngineGroup { 24 | if len(args) > 0 { 25 | return GetDb(args[0]) 26 | } else if m.DiName != "" { 27 | return GetDb(m.DiName) 28 | } else { 29 | return GetDb() 30 | } 31 | } 32 | 33 | /** 34 | * 查询主键ID的记录 35 | * @param id 主键ID 36 | * @param bean 数据结构实体 37 | * @return has 是否有记录 38 | */ 39 | func (m *Model) GetOne(id interface{}, bean interface{}) (has bool, err error) { 40 | return m.GetDb().ID(id).Get(bean) 41 | } 42 | 43 | /** 44 | * 查询多个主键ID的记录 45 | * @param ids 主键ID分片 46 | * @param beans 数据结构实体分片 47 | */ 48 | func (m *Model) GetMulti(ids []interface{}, beans interface{}) error { 49 | if len(ids) == 0 { 50 | return ErrIdsEmpty 51 | } 52 | return m.GetDb().In("id", ids...).Find(beans) 53 | } 54 | 55 | /** 56 | * 插入记录 57 | * @param beans... 可支持插入连续多个记录 58 | */ 59 | func (m *Model) Insert(beans ...interface{}) (int64, error) { 60 | return m.GetDb().Insert(beans...) 61 | } 62 | 63 | /** 64 | * 更新某个主键ID的数据 65 | * @param id 主键ID 66 | * @param bean 数据结构实体 67 | * @param mustColumns... 
因为默认Update只更新非0,非”“,非bool的字段,需要配合此字段 68 | * @param 69 | */ 70 | func (m *Model) Update(id interface{}, bean interface{}, mustColumns ...string) (int64, error) { 71 | if len(mustColumns) > 0 { 72 | return m.GetDb().MustCols(mustColumns...).ID(id).Update(bean) 73 | } else { 74 | return m.GetDb().ID(id).Update(bean) 75 | } 76 | } 77 | 78 | /** 79 | * 删除单个记录 -- 如果有开启delete特性,会触发软删除 80 | * @param id 主键ID 81 | * @param bean 数据结构实体 82 | */ 83 | func (m *Model) Delete(id interface{}, bean interface{}) (int64, error) { 84 | return m.GetDb().ID(id).Delete(bean) 85 | } 86 | 87 | /** 88 | * 查询多个主键ID的记录 89 | * @param ids 主键ID分片 90 | * @param bean 数据结构实体 91 | */ 92 | func (m *Model) DeleteMulti(ids []interface{}, bean interface{}) (int64, error) { 93 | if len(ids) == 0 { 94 | return 0, ErrIdsEmpty 95 | } 96 | return m.GetDb().In("id", ids...).Delete(bean) 97 | } 98 | 99 | /** 100 | * 查询多个主键ID的记录 101 | * @param beans 数据结构实体分片 eg. &banners 其中 banners := make([]*Banner, 0) 102 | * @params sql eg. "age > ? or name = ?" 103 | * @params values eg. []interfaces{}{30, "hts"} 104 | * @Param []int limit 可选 eg. []int{} 不限量 []int{30} 前30个 []int{30, 20} 从第20个后的前30个 105 | * @param string order 可选 eg. 
"id desc" 单个 "uid desc,status asc" 多个 106 | */ 107 | func (m *Model) GetList(beans interface{}, sql string, values []interface{}, args ...interface{}) (err error) { 108 | if len(args) > 0 { 109 | var ( 110 | order string 111 | limit int 112 | start int 113 | ) 114 | 115 | limits, ok := args[0].([]int) 116 | if ok && len(limits) > 0 { 117 | limit = limits[0] 118 | if len(limits) > 1 { 119 | start = limits[1] 120 | } 121 | } 122 | 123 | if len(args) > 1 { 124 | order, _ = args[1].(string) 125 | } 126 | 127 | return m.GetDb().Where(sql, values...).OrderBy(order).Limit(limit, start).Find(beans) 128 | } else { 129 | return m.GetDb().Where(sql, values...).Find(beans) 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /db/model_test.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "fmt" 5 | "github.com/qit-team/snow-core/config" 6 | "testing" 7 | 8 | //go test时需要开启 9 | _ "github.com/go-sql-driver/mysql" 10 | ) 11 | 12 | type bannerModel struct { 13 | Model 14 | } 15 | 16 | func init() { 17 | m := config.DbBaseConfig{ 18 | Host: "127.0.0.1", 19 | Port: 3306, 20 | User: "root", 21 | Password: "Snow_123", 22 | DBName: "test", 23 | } 24 | dbConf := config.DbConfig{ 25 | Driver: "mysql", 26 | Master: m, 27 | } 28 | 29 | err := Pr.Register("db", dbConf, true) 30 | if err != nil { 31 | fmt.Println(err) 32 | } 33 | 34 | engineGroup = GetDb() 35 | } 36 | 37 | func TestGetOne(t *testing.T) { 38 | model := new(bannerModel) 39 | ret := new(Banner) 40 | id := 1 41 | _, err := model.GetOne(id, ret) 42 | if err != nil { 43 | t.Errorf("getOne error: %v", err) 44 | return 45 | } 46 | fmt.Println("getOne.Ret", ret) 47 | } 48 | 49 | func TestGetMulti(t *testing.T) { 50 | model := new(bannerModel) 51 | ret := make([]*Banner, 0) 52 | var idList = []interface{}{1, 2} 53 | err := model.GetMulti(idList, &ret) 54 | if err != nil { 55 | t.Errorf("getMulti error: %v", err) 56 
| return 57 | } 58 | for _, v := range ret { 59 | fmt.Println("getMulti.ItemRet", v) 60 | } 61 | 62 | // 验证异常处理if分支 63 | var idListErr []interface{} 64 | err = model.GetMulti(idListErr, &ret) 65 | fmt.Println("getMulti.CheckExceptionBranch:", err) 66 | } 67 | 68 | func TestInsert(t *testing.T) { 69 | 70 | model := new(bannerModel) 71 | banner := new(Banner) 72 | banner.Id = 4 73 | banner.ImageUrl = "img666" 74 | banner.Pid = 66666 75 | banner.Title = "test insert" 76 | 77 | _, err := model.Insert(banner) 78 | if err != nil { 79 | t.Errorf("Insert error: %v", err) 80 | return 81 | } 82 | fmt.Println("Insert.Id", banner.Id) 83 | 84 | // 插入数据 为了测试批量删除功能,见函数:TestDeleteMulti 85 | banner.Id = 5 86 | model.Insert(banner) 87 | 88 | banner.Id = 6 89 | model.Insert(banner) 90 | } 91 | 92 | func TestUpdate(t *testing.T) { 93 | model := new(bannerModel) 94 | banner := new(Banner) 95 | banner.ImageUrl = "" 96 | banner.Pid = 77777 97 | banner.Title = "test update" 98 | var id = 7 99 | // 注意:直接用默认的update对上面的ImageUrl字段不会更新 100 | _, err := model.Update(id, banner) 101 | if err != nil { 102 | t.Errorf("Update error: %v", err) 103 | return 104 | } 105 | fmt.Println("Update.success") 106 | 107 | id = 8 108 | banner.Pid = 888 109 | banner.ImageUrl = "" 110 | banner.Title = "" 111 | // xorm默认对更新字段数据为""的不会执行,需要加mustColumns,这样保证为空的数据字段也能更新,详情搜索xorm手册 112 | _, err = model.Update(id, banner, "img_url", "title") 113 | 114 | if err != nil { 115 | t.Errorf("Update mustColumns error: %v", err) 116 | return 117 | } 118 | fmt.Println("Update mustColumns.success") 119 | } 120 | 121 | func TestDelete(t *testing.T) { 122 | model := new(bannerModel) 123 | banner := new(Banner) 124 | id := 4 125 | ret, err := model.Delete(id, banner) 126 | 127 | if err != nil { 128 | t.Errorf("Delete error: %v", err) 129 | return 130 | } 131 | fmt.Println("Delete.ret", ret) 132 | } 133 | 134 | func TestDeleteMulti(t *testing.T) { 135 | model := new(bannerModel) 136 | banner := new(Banner) 137 | var id = 
[]interface{}{5, 6} 138 | // 批量删除id 为5,6的 数据来源参考TestInsert 139 | ret, err := model.DeleteMulti(id, banner) 140 | 141 | if err != nil { 142 | t.Errorf("DeleteMulti error: %v", err) 143 | return 144 | } 145 | fmt.Println("DeleteMulti.ret", ret) 146 | 147 | // 测试参数为空的异常分支 148 | var idErr []interface{} 149 | _, err = model.DeleteMulti(idErr, banner) 150 | fmt.Println("DeleteMulti.CheckExceptionBranch.ret", err) 151 | } 152 | 153 | func TestGetList(t *testing.T) { 154 | model := new(bannerModel) 155 | banner := make([]*Banner, 0) 156 | 157 | sql := "status > ? and status < ? and pid = ?" 158 | var values = []interface{}{"1", "5", 10010} 159 | err := model.GetList(&banner, sql, values) 160 | if err != nil { 161 | t.Errorf("Getlist error: %v", err) 162 | return 163 | } 164 | for _, v := range banner { 165 | fmt.Println("GetList.ret", v) 166 | } 167 | 168 | // 测试其他if分支 覆盖getList所有代码 169 | banner1 := make([]*Banner, 0) 170 | 171 | sql = "status >= ? and status <= ?" 172 | var valuesTest = []interface{}{"1", "7"} 173 | err = model.GetList(&banner1, sql, valuesTest, []int{3, 3}, "pid desc") 174 | if err != nil { 175 | t.Errorf("GetlistLimitAndOrderBranch error: %v", err) 176 | return 177 | } 178 | for _, v := range banner1 { 179 | fmt.Println("GetlistLimitAndOrderBranch.ret", v) 180 | } 181 | } 182 | 183 | func TestProvider_Close(t *testing.T) { 184 | // 关闭链接,此时再执行sql都无法执行会报 sql: database is closed, 所以在sql执行完之后做close操作 185 | err := Pr.Close() 186 | 187 | if err != nil { 188 | t.Error("Close Fail") 189 | return 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /db/provider.go: -------------------------------------------------------------------------------- 1 | package db 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/qit-team/snow-core/config" 7 | "github.com/qit-team/snow-core/helper" 8 | "github.com/qit-team/snow-core/kernel/container" 9 | "sync" 10 | "xorm.io/xorm" 11 | ) 12 | 13 | const ( 14 | SingletonMain = 
"db" 15 | ) 16 | 17 | var Pr *provider 18 | 19 | func init() { 20 | Pr = new(provider) 21 | Pr.mp = make(map[string]interface{}) 22 | } 23 | 24 | type provider struct { 25 | mu sync.RWMutex 26 | mp map[string]interface{} //配置 27 | dn string //default name 28 | } 29 | 30 | /** 31 | * @param string 依赖注入别名 必选 32 | * @param config.LogConfig 配置 必选 33 | * @param bool 是否启用懒加载 可选 34 | */ 35 | func (p *provider) Register(args ...interface{}) (err error) { 36 | diName, lazy, err := helper.TransformArgs(args...) 37 | if err != nil { 38 | return 39 | } 40 | 41 | conf, ok := args[1].(config.DbConfig) 42 | if !ok { 43 | return errors.New("args[1] is not config.DbConfig") 44 | } 45 | 46 | p.mu.Lock() 47 | p.mp[diName] = args[1] 48 | if len(p.mp) == 1 { 49 | p.dn = diName 50 | } 51 | p.mu.Unlock() 52 | 53 | if !lazy { 54 | _, err = setSingleton(diName, conf) 55 | } 56 | return 57 | } 58 | 59 | //注册过的别名 60 | func (p *provider) Provides() []string { 61 | p.mu.RLock() 62 | defer p.mu.RUnlock() 63 | 64 | return helper.MapToArray(p.mp) 65 | } 66 | 67 | //释放资源 68 | func (p *provider) Close() error { 69 | arr := p.Provides() 70 | for _, k := range arr { 71 | c := getSingleton(k, false) 72 | if c != nil { 73 | c.Close() 74 | } 75 | } 76 | return nil 77 | } 78 | 79 | //注入单例 80 | func setSingleton(diName string, conf config.DbConfig) (ins *xorm.EngineGroup, err error) { 81 | ins, err = NewEngineGroup(conf) 82 | if err == nil { 83 | container.App.SetSingleton(diName, ins) 84 | } 85 | return 86 | } 87 | 88 | //获取单例 89 | func getSingleton(diName string, lazy bool) *xorm.EngineGroup { 90 | rc := container.App.GetSingleton(diName) 91 | if rc != nil { 92 | return rc.(*xorm.EngineGroup) 93 | } 94 | if lazy == false { 95 | return nil 96 | } 97 | 98 | Pr.mu.RLock() 99 | conf, ok := Pr.mp[diName].(config.DbConfig) 100 | Pr.mu.RUnlock() 101 | if !ok { 102 | panic(fmt.Sprintf("db di_name:%s not exist", diName)) 103 | } 104 | 105 | ins, err := setSingleton(diName, conf) 106 | if err != nil { 107 | 
panic(fmt.Sprintf("db di_name:%s err:%s", diName, err.Error())) 108 | } 109 | return ins 110 | } 111 | 112 | //外部通过注入别名获取资源,解耦资源的关系 113 | func GetDb(args ...string) *xorm.EngineGroup { 114 | diName := helper.GetDiName(Pr.dn, args...) 115 | return getSingleton(diName, true) 116 | } 117 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/qit-team/snow-core 2 | 3 | go 1.12 4 | 5 | require ( 6 | github.com/aliyun/aliyun-mns-go-sdk v1.0.2 7 | github.com/aliyunmq/mq-http-go-sdk v1.0.3 8 | github.com/apache/rocketmq-client-go/v2 v2.1.1 9 | github.com/bytedance/sonic v1.10.0-rc3 // indirect 10 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 11 | github.com/emirpasic/gods v1.18.1 // indirect 12 | github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect 13 | github.com/fsnotify/fsnotify v1.5.1 // indirect 14 | github.com/fvbock/endless v0.0.0-20170109170031-447134032cb6 15 | github.com/gin-gonic/gin v1.9.1 16 | github.com/go-playground/validator/v10 v10.14.1 // indirect 17 | github.com/go-redis/redis/v8 v8.11.5 18 | github.com/go-sql-driver/mysql v1.7.1 19 | github.com/gogap/errors v0.0.0-20210818113853-edfbba0ddea9 20 | github.com/gogap/stack v0.0.0-20150131034635-fef68dddd4f8 // indirect 21 | github.com/golang/mock v1.6.0 // indirect 22 | github.com/google/uuid v1.3.0 23 | github.com/gopherjs/gopherjs v0.0.0-20211111143520-d0d5ecc1a356 // indirect 24 | github.com/hetiansu5/accesslog v1.0.0 25 | github.com/hetiansu5/cores v1.0.0 26 | github.com/jonboulle/clockwork v0.2.2 // indirect 27 | github.com/klauspost/compress v1.16.7 // indirect 28 | github.com/klauspost/cpuid/v2 v2.2.5 // indirect 29 | github.com/kr/pretty v0.3.0 // indirect 30 | github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible 31 | github.com/lestrrat-go/strftime v1.0.5 // indirect 32 | github.com/pelletier/go-toml/v2 v2.0.9 // indirect 33 | 
github.com/pkg/errors v0.9.1 // indirect 34 | github.com/qit-team/work v0.3.11 35 | github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5 36 | github.com/robfig/cron v1.2.0 37 | github.com/rogpeppe/go-internal v1.8.0 // indirect 38 | github.com/sirupsen/logrus v1.9.3 39 | github.com/smartystreets/goconvey v1.7.2 // indirect 40 | github.com/tidwall/gjson v1.15.0 // indirect 41 | github.com/tidwall/pretty v1.2.1 // indirect 42 | github.com/ugorji/go v1.2.11 // indirect 43 | github.com/valyala/fasthttp v1.48.0 // indirect 44 | github.com/valyala/fasttemplate v1.2.1 // indirect 45 | go.uber.org/atomic v1.11.0 // indirect 46 | golang.org/x/arch v0.4.0 // indirect 47 | golang.org/x/net v0.13.0 // indirect 48 | google.golang.org/protobuf v1.31.0 // indirect 49 | gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect 50 | modernc.org/ccgo/v3 v3.12.95 // indirect 51 | modernc.org/tcl v1.9.2 // indirect 52 | xorm.io/builder v0.3.13 // indirect 53 | xorm.io/core v0.7.3 54 | xorm.io/xorm v1.3.2 55 | ) 56 | -------------------------------------------------------------------------------- /helper/helper.go: -------------------------------------------------------------------------------- 1 | package helper 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | func GetDiName(defaultName string, args ...string) string { 8 | var name string 9 | if len(args) > 0 { 10 | name = args[0] 11 | } 12 | if name == "" { 13 | return defaultName 14 | } 15 | return name 16 | } 17 | 18 | func TransformArgs(args ...interface{}) (diName string, lazy bool, err error) { 19 | if len(args) < 2 { 20 | err = errors.New("args is not enough") 21 | return 22 | } 23 | 24 | var ok bool 25 | diName, ok = args[0].(string) 26 | if !ok { 27 | err = errors.New("args[0] is not string") 28 | return 29 | } 30 | 31 | if len(args) > 2 { 32 | lazy, _ = args[2].(bool) 33 | } 34 | return 35 | } 36 | 37 | func MapToArray(mp map[string]interface{}) []string { 38 | arr := make([]string, len(mp)) 39 | i := 0 40 | for k := range mp 
{ 41 | arr[i] = k 42 | i++ 43 | } 44 | return arr 45 | } 46 | -------------------------------------------------------------------------------- /helper/helper_test.go: -------------------------------------------------------------------------------- 1 | package helper 2 | 3 | import "testing" 4 | 5 | func TestTransformArgs(t *testing.T) { 6 | _, _, err := TransformArgs("1") 7 | if err == nil { 8 | t.Error("length of args should be checked") 9 | return 10 | } 11 | 12 | _, _, err = TransformArgs(1, "", true) 13 | if err == nil { 14 | t.Error("args[0] should be string") 15 | return 16 | } 17 | 18 | diName, lazy, err := TransformArgs("1", "", true) 19 | if err != nil { 20 | t.Error(err) 21 | return 22 | } else if diName != "1" { 23 | t.Error("diName is not match") 24 | return 25 | } else if lazy != true { 26 | t.Error("lazy is not match") 27 | return 28 | } 29 | 30 | } 31 | 32 | func TestGetDiName(t *testing.T) { 33 | dn := "dn" 34 | a1 := GetDiName(dn) 35 | if a1 != dn { 36 | t.Error("must be default") 37 | return 38 | } 39 | 40 | a2 := GetDiName(dn, "22") 41 | if a2 != "22" { 42 | t.Error("must be args[0]") 43 | return 44 | } 45 | } 46 | 47 | func TestMapToArray(t *testing.T) { 48 | mp := map[string]interface{}{ 49 | "a1": 1, 50 | "b2": "bbd", 51 | } 52 | arr := MapToArray(mp) 53 | if len(arr) != 2 { 54 | t.Error("length of array is not equal 2") 55 | return 56 | } 57 | 58 | if arr[0] == "a1" { 59 | if arr[1] != "b2" { 60 | t.Error("part result of array is error") 61 | return 62 | } 63 | } else if arr[0] == "b2" { 64 | if arr[1] != "a1" { 65 | t.Error("part result of array is error") 66 | return 67 | } 68 | } else { 69 | t.Error("result of array is error") 70 | return 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /http/ctxkit/ctxkit.go: -------------------------------------------------------------------------------- 1 | package ctxkit 2 | 3 | import ( 4 | "context" 5 | "crypto/md5" 6 | "fmt" 7 | 
"github.com/gin-gonic/gin" 8 | "github.com/qit-team/snow-core/utils" 9 | "strings" 10 | ) 11 | 12 | const ( 13 | TraceId = "x-trace-id" 14 | ClientIp = "x-cip" 15 | ServerIp = "x-sip" 16 | HOST = "x-host" 17 | ) 18 | 19 | func SetTraceId(ctx context.Context, value string) context.Context { 20 | var newCtx context.Context 21 | if ctxGin, ok := ctx.(*gin.Context); ok { 22 | newCtx = SetGinTraceId(ctxGin.Request.Context(), value) 23 | ctxGin.Request = ctxGin.Request.WithContext(newCtx) 24 | } else { 25 | newCtx = SetGinTraceId(ctx, value) 26 | } 27 | return newCtx 28 | } 29 | 30 | func SetGinTraceId(ctx context.Context, value string) context.Context { 31 | return context.WithValue(ctx, TraceId, value) 32 | } 33 | 34 | func GetTraceId(ctx context.Context) string { 35 | if ctxGin, ok := ctx.(*gin.Context); ok { 36 | ctx = ctxGin.Request.Context() 37 | } 38 | s, _ := ctx.Value(TraceId).(string) 39 | return s 40 | } 41 | 42 | func SetClientId(ctx context.Context, value string) context.Context { 43 | var newCtx context.Context 44 | if ctxGin, ok := ctx.(*gin.Context); ok { 45 | newCtx = SetGinClientId(ctxGin.Request.Context(), value) 46 | ctxGin.Request = ctxGin.Request.WithContext(SetGinClientId(ctxGin.Request.Context(), value)) 47 | } else { 48 | newCtx = SetGinClientId(ctx, value) 49 | } 50 | return newCtx 51 | } 52 | 53 | func SetGinClientId(ctx context.Context, value string) context.Context { 54 | return context.WithValue(ctx, ClientIp, value) 55 | } 56 | 57 | func GetClientId(ctx context.Context) string { 58 | if ctxGin, ok := ctx.(*gin.Context); ok { 59 | ctx = ctxGin.Request.Context() 60 | } 61 | s, _ := ctx.Value(ClientIp).(string) 62 | return s 63 | } 64 | 65 | func SetServerId(ctx context.Context, value string) context.Context { 66 | var newCtx context.Context 67 | if ctxGin, ok := ctx.(*gin.Context); ok { 68 | newCtx = SetGinServerId(ctxGin.Request.Context(), value) 69 | ctxGin.Request = ctxGin.Request.WithContext(newCtx) 70 | } else { 71 | newCtx = 
SetGinServerId(ctx, value) 72 | } 73 | return newCtx 74 | } 75 | 76 | func SetGinServerId(ctx context.Context, value string) context.Context { 77 | return context.WithValue(ctx, ServerIp, value) 78 | } 79 | 80 | func GetServerId(ctx context.Context) string { 81 | if ctxGin, ok := ctx.(*gin.Context); ok { 82 | ctx = ctxGin.Request.Context() 83 | } 84 | s, _ := ctx.Value(ServerIp).(string) 85 | return s 86 | } 87 | 88 | // param to change 89 | func SetHost(ctx context.Context, value string) context.Context { 90 | var newCtx context.Context 91 | if ctxGin, ok := ctx.(*gin.Context); ok { 92 | newCtx = SetGinHost(ctxGin.Request.Context(), value) 93 | ctxGin.Request = ctxGin.Request.WithContext(newCtx) 94 | } else { 95 | newCtx = SetGinHost(ctx, value) 96 | } 97 | return newCtx 98 | } 99 | 100 | func SetGinHost(ctx context.Context, value string) context.Context { 101 | return context.WithValue(ctx, ServerIp, value) 102 | } 103 | 104 | func GetHost(ctx context.Context) string { 105 | if ctxGin, ok := ctx.(*gin.Context); ok { 106 | ctx = ctxGin.Request.Context() 107 | } 108 | s, _ := ctx.Value(HOST).(string) 109 | return s 110 | } 111 | 112 | //var once sync.Once 113 | func GenerateTraceId(ctx context.Context) (string, context.Context) { 114 | randomId := utils.GenUUID() 115 | mdTemp := md5.Sum([]byte(randomId)) 116 | mdCode := fmt.Sprintf("%x", mdTemp) 117 | mdStr := strings.ToUpper(mdCode) 118 | 119 | var traceId = mdStr 120 | if len(mdStr) >= 32 { 121 | traceId = mdStr[0:8] + "-" + mdStr[8:12] + "-" + mdStr[12:16] + "-" + mdStr[16:20] + "-" + mdStr[20:32] 122 | } 123 | return traceId, SetTraceId(ctx, traceId) 124 | } 125 | -------------------------------------------------------------------------------- /http/ctxkit/ctxkit_test.go: -------------------------------------------------------------------------------- 1 | package ctxkit 2 | 3 | import ( 4 | "fmt" 5 | "github.com/gin-gonic/gin" 6 | "testing" 7 | ) 8 | 9 | var c *gin.Context 10 | 11 | func init() { 12 | c = 
&gin.Context{} 13 | } 14 | 15 | func TestGetClientId(t *testing.T) { 16 | v := "1" 17 | SetClientId(c, v) 18 | v1 := GetClientId(c) 19 | if v1 != v { 20 | t.Error("ClientId miss match") 21 | return 22 | } 23 | } 24 | 25 | func TestGetTraceId(t *testing.T) { 26 | v := "2" 27 | SetTraceId(c, v) 28 | v1 := GetTraceId(c) 29 | fmt.Println("======traceId", v1) 30 | //logger.Info(c, "====testTrace") 31 | GenerateTraceId(c) 32 | v2 := GetTraceId(c) 33 | fmt.Println("======generateTraceId", v2) 34 | if v1 != v { 35 | t.Error("TraceId miss match") 36 | return 37 | } 38 | } 39 | 40 | func TestGetHost(t *testing.T) { 41 | v := "3" 42 | SetHost(c, v) 43 | v1 := GetHost(c) 44 | if v1 != v { 45 | t.Error("Host miss match") 46 | return 47 | } 48 | } 49 | 50 | func TestGetServerId(t *testing.T) { 51 | v := "4" 52 | SetServerId(c, v) 53 | v1 := GetServerId(c) 54 | if v1 != v { 55 | t.Error("ServerId miss match") 56 | return 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /http/middleware/access_log.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "github.com/gin-gonic/gin" 5 | "github.com/hetiansu5/accesslog" 6 | "github.com/qit-team/snow-core/log/accesslogger" 7 | "time" 8 | ) 9 | 10 | func AccessLog() gin.HandlerFunc { 11 | return func(c *gin.Context) { 12 | //忽略HEAD探针的日志 13 | if c.Request.Method != "HEAD" { 14 | AccessLogFunc(accesslogger.GetAccessLogger())(c) 15 | } 16 | } 17 | } 18 | 19 | // AccessLogFunc 用于记录 http access log 20 | func AccessLogFunc(accessLogger *accesslog.AccessLogger) gin.HandlerFunc { 21 | return func(c *gin.Context) { 22 | receivedAt := time.Now() 23 | originalWriter := c.Writer 24 | proxyWriter := newResponseWriter(c.Writer) 25 | c.Writer = proxyWriter.(gin.ResponseWriter) 26 | // Process request 27 | if c != nil { 28 | c.Next() 29 | } 30 | accessLogger.Log(proxyWriter, c.Request, receivedAt, time.Since(receivedAt)) 31 | 
c.Writer = originalWriter 32 | } 33 | } 34 | 35 | type ResponseWriter struct { 36 | gin.ResponseWriter 37 | fbt time.Time 38 | } 39 | 40 | func (rw *ResponseWriter) FirstByteTime() time.Time { 41 | return rw.fbt 42 | } 43 | 44 | func (rw *ResponseWriter) WriteHeaderNow() { 45 | rw.ResponseWriter.WriteHeaderNow() 46 | if rw.fbt.IsZero() { 47 | rw.fbt = time.Now() 48 | } 49 | } 50 | 51 | func (rw *ResponseWriter) Write(data []byte) (n int, err error) { 52 | rw.WriteHeaderNow() 53 | return rw.ResponseWriter.Write(data) 54 | } 55 | 56 | func (rw *ResponseWriter) WriteString(s string) (n int, err error) { 57 | rw.WriteHeaderNow() 58 | return rw.ResponseWriter.WriteString(s) 59 | } 60 | 61 | func newResponseWriter(writer gin.ResponseWriter) accesslog.ResponseWriter { 62 | return &ResponseWriter{ResponseWriter: writer} 63 | } 64 | -------------------------------------------------------------------------------- /http/middleware/access_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "fmt" 5 | "github.com/gin-gonic/gin" 6 | "github.com/hetiansu5/accesslog" 7 | "testing" 8 | ) 9 | 10 | var handle, handleFunc gin.HandlerFunc 11 | var accessLogger *accesslog.AccessLogger 12 | 13 | var reponseWriter accesslog.ResponseWriter 14 | 15 | func init() { 16 | handle = AccessLog() 17 | handleFunc = AccessLogFunc(accessLogger) 18 | var w gin.ResponseWriter 19 | reponseWriter = newResponseWriter(w) 20 | } 21 | 22 | func TestResponseWriter(t *testing.T) { 23 | ret := reponseWriter.FirstByteTime() 24 | fmt.Println("FirstByteTime", ret) 25 | } 26 | -------------------------------------------------------------------------------- /http/middleware/ctxkit.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "github.com/gin-gonic/gin" 5 | "github.com/qit-team/snow-core/http/ctxkit" 6 | ) 7 | 8 | func GenContextKit(c *gin.Context) { 9 | 
ctxkit.SetClientId(c, c.ClientIP()) 10 | ctxkit.SetServerId(c, c.Request.RemoteAddr) 11 | ctxkit.SetHost(c, c.Request.Host) 12 | traceId := c.GetHeader("X-TRACE-ID") 13 | if traceId != "" { 14 | c.Request = c.Request.WithContext(ctxkit.SetTraceId(c, traceId)) 15 | } else { 16 | _, ctx := ctxkit.GenerateTraceId(c) 17 | c.Request = c.Request.WithContext(ctx) 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /http/middleware/ctxkit_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "github.com/gin-gonic/gin" 5 | "testing" 6 | ) 7 | 8 | var c *gin.Context 9 | 10 | func init() { 11 | c = &gin.Context{} 12 | } 13 | 14 | func Test_GenContextKit(t *testing.T) { 15 | //c.Header("X-Forwarded-For", "127.0.0.111") 16 | //c1 := gin.Context{} 17 | ////fmt.Println("c.engine.ForwardedByClientIP", c.requestHeader("X-Forwarded-For")) 18 | //fmt.Println("=======111111") 19 | //GenContextKit(c1) 20 | //fmt.Println("========2222") // 校验traceId是否设置成功 21 | //traceId := ctxkit.GetTraceId(c) 22 | //if len(traceId) == 0 { 23 | // t.Error("GenContextKit error") 24 | // return 25 | //} 26 | //fmt.Println("GenContextKit traceId:", traceId) 27 | } 28 | -------------------------------------------------------------------------------- /http/middleware/request_id.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "github.com/gin-gonic/gin" 5 | "github.com/qit-team/snow-core/utils" 6 | ) 7 | 8 | func GenRequestId(c *gin.Context) { 9 | reqId := utils.GenUUID() 10 | c.Request.Header.Add("X-Request-Id", reqId) 11 | c.Header("X-Request-Id", reqId) 12 | c.Next() 13 | } 14 | -------------------------------------------------------------------------------- /kernel/close/close.go: -------------------------------------------------------------------------------- 1 | package close 2 | 3 | import "sync" 4 
| 5 | var ( 6 | closeSet []Closeable 7 | lock sync.RWMutex 8 | ) 9 | 10 | type Closeable interface { 11 | Close() error 12 | } 13 | 14 | //注册应用停止时需要释放链接的服务 15 | func Register(closeable Closeable) { 16 | lock.Lock() 17 | defer lock.Unlock() 18 | closeSet = append(closeSet, closeable) 19 | } 20 | 21 | //批量注册应用停止时需要释放链接的服务 22 | func MultiRegister(closeableSet ...Closeable) { 23 | lock.Lock() 24 | defer lock.Unlock() 25 | closeSet = append(closeSet, closeableSet...) 26 | } 27 | 28 | //释放链接 29 | func Free() { 30 | for _, v := range closeSet { 31 | if v != nil { 32 | v.Close() 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /kernel/close/close_test.go: -------------------------------------------------------------------------------- 1 | package close 2 | 3 | import "testing" 4 | 5 | type mockClose struct { 6 | } 7 | 8 | func (m *mockClose) Close() error { 9 | return nil 10 | } 11 | 12 | func TestRegister(t *testing.T) { 13 | defer func() { 14 | if e := recover(); e != nil { 15 | t.Error(e) 16 | } 17 | 18 | }() 19 | 20 | cl := new(mockClose) 21 | Register(cl) 22 | Register(nil) 23 | MultiRegister(new(mockClose), nil) 24 | Free() 25 | } 26 | -------------------------------------------------------------------------------- /kernel/container/app.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | var App = NewContainer() 4 | -------------------------------------------------------------------------------- /kernel/container/container.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "reflect" 7 | "strings" 8 | "sync" 9 | ) 10 | 11 | var ( 12 | ErrFactoryNotFound = errors.New("factory not found") 13 | ) 14 | 15 | type factory = func() (interface{}, error) 16 | 17 | // 容器 18 | type Container struct { 19 | mu sync.RWMutex 20 | singletons map[string]interface{} 21 
| factories map[string]factory 22 | } 23 | 24 | // 容器实例化 25 | func NewContainer() *Container { 26 | return &Container{ 27 | singletons: make(map[string]interface{}), 28 | factories: make(map[string]factory), 29 | } 30 | } 31 | 32 | // 注册单例对象 33 | func (p *Container) SetSingleton(name string, singleton interface{}) { 34 | p.mu.Lock() 35 | p.singletons[name] = singleton 36 | p.mu.Unlock() 37 | } 38 | 39 | // 获取单例对象 40 | func (p *Container) GetSingleton(name string) interface{} { 41 | p.mu.RLock() 42 | ins, _ := p.singletons[name] 43 | p.mu.RUnlock() 44 | return ins 45 | } 46 | 47 | // 获取实例对象 48 | func (p *Container) GetPrototype(name string) (interface{}, error) { 49 | p.mu.RLock() 50 | factory, ok := p.factories[name] 51 | p.mu.RUnlock() 52 | if !ok { 53 | return nil, ErrFactoryNotFound 54 | } 55 | return factory() 56 | } 57 | 58 | // 设置实例对象工厂 59 | func (p *Container) SetPrototype(name string, factory factory) { 60 | p.mu.Lock() 61 | p.factories[name] = factory 62 | p.mu.Unlock() 63 | } 64 | 65 | // 注入依赖 66 | func (p *Container) Ensure(instance interface{}) error { 67 | elemType := reflect.TypeOf(instance).Elem() 68 | ele := reflect.ValueOf(instance).Elem() 69 | for i := 0; i < elemType.NumField(); i++ { // 遍历字段 70 | fieldType := elemType.Field(i) 71 | tag := fieldType.Tag.Get("di") // 获取tag 72 | diName := p.injectName(tag) 73 | if diName == "" { 74 | continue 75 | } 76 | var ( 77 | diInstance interface{} 78 | err error 79 | ) 80 | if p.isSingleton(tag) { 81 | diInstance = p.GetSingleton(diName) 82 | } 83 | if p.isPrototype(tag) { 84 | diInstance, err = p.GetPrototype(diName) 85 | } 86 | if err != nil { 87 | return err 88 | } 89 | if diInstance == nil { 90 | return errors.New(diName + " dependency not found") 91 | } 92 | ele.Field(i).Set(reflect.ValueOf(diInstance)) 93 | } 94 | return nil 95 | } 96 | 97 | // 获取需要注入的依赖名称 98 | func (p *Container) injectName(tag string) string { 99 | tags := strings.Split(tag, ",") 100 | if len(tags) == 0 { 101 | return "" 102 | } 103 
| return tags[0] 104 | } 105 | 106 | // 检测是否单例依赖 107 | func (p *Container) isSingleton(tag string) bool { 108 | tags := strings.Split(tag, ",") 109 | for _, name := range tags { 110 | if name == "prototype" { 111 | return false 112 | } 113 | } 114 | return true 115 | } 116 | 117 | // 检测是否实例依赖 118 | func (p *Container) isPrototype(tag string) bool { 119 | tags := strings.Split(tag, ",") 120 | for _, name := range tags { 121 | if name == "prototype" { 122 | return true 123 | } 124 | } 125 | return false 126 | } 127 | 128 | // 打印容器内部实例 129 | func (p *Container) String() string { 130 | lines := make([]string, 0, len(p.singletons)+len(p.factories)+2) 131 | lines = append(lines, "singletons:") 132 | for name, item := range p.singletons { 133 | if item == nil { 134 | line := fmt.Sprintf(" %s: %s %s", name, "", "") 135 | lines = append(lines, line) 136 | continue 137 | } 138 | 139 | line := fmt.Sprintf(" %s: %p %s", name, &item, reflect.TypeOf(item).String()) 140 | lines = append(lines, line) 141 | } 142 | lines = append(lines, "factories:") 143 | for name, item := range p.factories { 144 | if item == nil { 145 | line := fmt.Sprintf(" %s: %s %s", name, "", "") 146 | lines = append(lines, line) 147 | continue 148 | } 149 | 150 | line := fmt.Sprintf(" %s: %p %s", name, &item, reflect.TypeOf(item).String()) 151 | lines = append(lines, line) 152 | } 153 | return strings.Join(lines, "\n") 154 | } 155 | -------------------------------------------------------------------------------- /kernel/container/container_test.go: -------------------------------------------------------------------------------- 1 | package container 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | ) 8 | 9 | func TestContainer_SetSingleton(t *testing.T) { 10 | App.SetSingleton("di1", "1") 11 | App.SetSingleton("di2", 2) 12 | a1 := App.GetSingleton("di1") 13 | if a1 != "1" { 14 | t.Error("not same") 15 | return 16 | } 17 | 18 | a3 := App.GetSingleton("di3") 19 | if a3 != nil { 20 | t.Error("not 
same") 21 | return 22 | } 23 | } 24 | 25 | func TestContainer_demo(t *testing.T) { 26 | nameStr := App.String() 27 | 28 | if len(nameStr) == 0 { 29 | t.Error("String() empty") 30 | return 31 | } 32 | 33 | if strings.Index(nameStr, "di1") == -1 || strings.Index(nameStr, "di2") == -1 { 34 | t.Error("String ret error") 35 | return 36 | } 37 | fmt.Println("=======string ret:", nameStr) 38 | 39 | bool1 := App.isSingleton("di1") 40 | if !bool1 { 41 | t.Error("isSingleton Error") 42 | return 43 | } 44 | 45 | bool2 := App.isPrototype("snow-test") 46 | if bool2 { 47 | t.Error("isPrototype Error") 48 | return 49 | } 50 | 51 | strTest := App.injectName("snow-test,snow-test111") 52 | if strTest != "snow-test" || len(strTest) == 0 { 53 | t.Error("injectName Error") 54 | return 55 | } 56 | 57 | //App.Ensure("di1") 58 | 59 | App.SetPrototype("snow", factoryDemo) 60 | 61 | ret, err := App.GetPrototype("snow") 62 | if err != nil { 63 | fmt.Println("=======", err) 64 | t.Error("GetPrototype error") 65 | return 66 | } 67 | fmt.Println("GetPrototype ret:", ret) 68 | 69 | // after set prototype string again ,for cover more branch of if&else 70 | nameStr = App.String() 71 | 72 | if len(nameStr) == 0 { 73 | t.Error("String() empty") 74 | return 75 | } 76 | 77 | // for cover branch of exception return 78 | strTest = App.injectName("") 79 | if len(strTest) != 0 { 80 | t.Error("injectName Exception branch Error") 81 | return 82 | } 83 | 84 | bool1 = App.isSingleton("prototype") 85 | if bool1 { 86 | t.Error("isSingleton Exception branch Error") 87 | return 88 | } 89 | 90 | bool2 = App.isPrototype("prototype") 91 | if !bool2 { 92 | t.Error("isPrototype Exception branch Error") 93 | return 94 | } 95 | 96 | } 97 | 98 | func factoryDemo() (i interface{}, err error) { 99 | return 100 | } 101 | -------------------------------------------------------------------------------- /kernel/provider/provider.go: -------------------------------------------------------------------------------- 1 | package 
provider 2 | 3 | type Provider interface { 4 | Register(args ...interface{}) error 5 | Provides() []string 6 | Close() error 7 | } 8 | -------------------------------------------------------------------------------- /kernel/server/command.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/qit-team/snow-core/command" 5 | ) 6 | 7 | // Execute one-time command 8 | func ExecuteCommand(name string, registerCommand func(*command.Command)) error { 9 | //注册并执行某个name对应的脚本 10 | c := command.New() 11 | registerCommand(c) 12 | err := c.Execute(name) 13 | return err 14 | } 15 | -------------------------------------------------------------------------------- /kernel/server/console.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "github.com/robfig/cron" 6 | "time" 7 | ) 8 | 9 | func waitConsoleStop(c *cron.Cron) { 10 | //等待结束 11 | WaitStop() 12 | 13 | //暂停新的Cron任务执行 14 | c.Stop() 15 | 16 | //等待执行中的cron任务结束,目前简单实现等待5s后结束 17 | if GetDebug() { 18 | fmt.Println("wait 5 sencods") 19 | } 20 | time.Sleep(time.Second * 5) 21 | 22 | CloseService() 23 | } 24 | 25 | // Start Cron Schedule 26 | func StartConsole(pidFile string, registerSchedule func(*cron.Cron)) error { 27 | //注册Cron执行计划 28 | cronEngine := cron.New() 29 | registerSchedule(cronEngine) 30 | cronEngine.Start() 31 | 32 | //写pid文件 33 | WritePidFile(pidFile) 34 | 35 | //注册信号量 36 | RegisterSignal() 37 | 38 | //等待停止信号 39 | waitConsoleStop(cronEngine) 40 | return nil 41 | } 42 | -------------------------------------------------------------------------------- /kernel/server/console_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "github.com/robfig/cron" 6 | "testing" 7 | ) 8 | 9 | var consolvechan chan int 10 | 11 | func TestStartConsole(t *testing.T) { 12 | pidFile := 
"../../.env_console_pid" 13 | consolvechan = make(chan int, 1) 14 | 15 | go func() { 16 | <-consolvechan 17 | stopServer(pidFile) 18 | }() 19 | StartConsole(pidFile, TempRegisterSchedule) 20 | } 21 | 22 | func TempRegisterSchedule(c *cron.Cron) { 23 | //c.AddFunc("0 30 * * * *", test) 24 | c.AddFunc("@every 1s", testConsole) 25 | } 26 | 27 | func testConsole() { 28 | fmt.Println("run test console") 29 | consolvechan <- 1 30 | } 31 | -------------------------------------------------------------------------------- /kernel/server/http.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "github.com/fvbock/endless" 6 | "github.com/gin-gonic/gin" 7 | "github.com/qit-team/snow-core/config" 8 | "strconv" 9 | "syscall" 10 | ) 11 | 12 | /** 13 | * 启动gin引擎 14 | * @wiki https://github.com/fvbock/endless#signals 15 | */ 16 | func runEngine(engine *gin.Engine, addr string, pidPath string) error { 17 | server := endless.NewServer(addr, engine) 18 | server.BeforeBegin = func(add string) { 19 | pid := syscall.Getpid() 20 | if gin.Mode() != gin.ReleaseMode { 21 | fmt.Printf("Actual pid is %d \n\r", pid) 22 | } 23 | WritePidFile(pidPath, pid) 24 | } 25 | err := server.ListenAndServe() 26 | return err 27 | } 28 | 29 | // Start proxy with config file 30 | func StartHttp(pidFile string, apiConf config.ApiConfig, registerRoute func(*gin.Engine)) error { 31 | //设置gin调试模式 32 | if !GetDebug() { 33 | gin.SetMode(gin.ReleaseMode) 34 | } 35 | //配置路由引擎 36 | engine := gin.New() 37 | registerRoute(engine) 38 | addr := apiConf.Host + ":" + strconv.Itoa(apiConf.Port) 39 | runEngine(engine, addr, pidFile) 40 | 41 | //因为信号处理由endless接管实现平滑重启和关闭,这里模拟通用的结束信号 42 | go func() { 43 | Stop() 44 | }() 45 | 46 | //等待停止信号 47 | WaitStop() 48 | return nil 49 | } 50 | -------------------------------------------------------------------------------- /kernel/server/http_test.go: 
-------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/gin-gonic/gin" 5 | "github.com/qit-team/snow-core/config" 6 | 7 | "fmt" 8 | "go/build" 9 | "net/http" 10 | "os" 11 | "os/exec" 12 | "path/filepath" 13 | "strconv" 14 | "strings" 15 | "testing" 16 | ) 17 | 18 | func TestStartHttp(t *testing.T) { 19 | pidFile := "../../.env_pid" 20 | var apiConf config.ApiConfig 21 | apiConf.Host = "127.0.0.1" 22 | apiConf.Port = 9000 23 | go func() { 24 | for i := 1; i < 100; i++ { 25 | // 进程启动后http服务没法自动停掉,需要借助os.exec执行自动停止 26 | stopServer(pidFile) 27 | } 28 | }() 29 | StartHttp(pidFile, apiConf, RegisterRoute) 30 | } 31 | 32 | //api路由配置 33 | func RegisterRoute(router *gin.Engine) { 34 | router.GET("/hello", HandleHello) 35 | } 36 | 37 | func HandleHello(c *gin.Context) { 38 | c.JSON(http.StatusOK, gin.H{ 39 | "code": 200, 40 | "message": "ok", 41 | "request_uri": c.Request.URL.Path, 42 | "data": "test", 43 | }) 44 | c.Abort() 45 | return 46 | } 47 | 48 | func stopServer(pidPath string) error { 49 | pid, _ := ReadPidFile(pidPath) 50 | pidStr := strconv.Itoa(pid) 51 | cmdName, cmdPath, command := "stop http", gopath(), "kill -TERM "+pidStr 52 | 53 | cmds := strings.Split(command, " ") 54 | err := runTool(cmdName, cmdPath, cmds[0], cmds[1:]) 55 | return err 56 | } 57 | 58 | // 封装os.exec 59 | func runTool(name, dir, cmd string, args []string) (err error) { 60 | toolCmd := &exec.Cmd{ 61 | Path: cmd, 62 | Args: append([]string{cmd}, args...), 63 | Dir: dir, 64 | Stdin: os.Stdin, 65 | Stdout: os.Stdout, 66 | Stderr: os.Stderr, 67 | Env: os.Environ(), 68 | } 69 | 70 | if filepath.Base(cmd) == cmd { 71 | var lp string 72 | if lp, err = exec.LookPath(cmd); err == nil { 73 | toolCmd.Path = lp 74 | } 75 | } 76 | if err = toolCmd.Run(); err != nil { 77 | if e, ok := err.(*exec.ExitError); !ok || !e.Exited() { 78 | fmt.Fprintf(os.Stderr, "运行 %s 出错: %v\n", name, err) 79 | } 80 | } 81 | return 82 | } 83 | 84 | // 
获取gopath路径 85 | func gopath() (gp string) { 86 | gopaths := strings.Split(os.Getenv("GOPATH"), ":") 87 | if len(gopaths) == 1 { 88 | return gopaths[0] 89 | } 90 | pwd, err := os.Getwd() 91 | if err != nil { 92 | return 93 | } 94 | abspwd, err := filepath.Abs(pwd) 95 | if err != nil { 96 | return 97 | } 98 | for _, gopath := range gopaths { 99 | absgp, err := filepath.Abs(gopath) 100 | if err != nil { 101 | return 102 | } 103 | if strings.HasPrefix(abspwd, absgp) { 104 | return absgp 105 | } 106 | } 107 | return build.Default.GOPATH 108 | } 109 | -------------------------------------------------------------------------------- /kernel/server/job.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "github.com/qit-team/work" 6 | "time" 7 | ) 8 | 9 | func waitJobStop(job *work.Job) { 10 | //等待结束 11 | WaitStop() 12 | 13 | //暂停新的job任务执行 14 | job.Stop() 15 | 16 | err := job.WaitStop(60 * time.Second) 17 | if err != nil { 18 | fmt.Println("wait stop error", err) 19 | } 20 | 21 | CloseService() 22 | } 23 | 24 | // Start Job Worker 25 | func StartJob(pidFile string, registerWorker func(*work.Job)) error { 26 | //注册Job Worker 27 | job := work.New() 28 | registerWorker(job) 29 | job.Start() 30 | 31 | //写pid文件 32 | WritePidFile(pidFile) 33 | 34 | //注册信号量 35 | RegisterSignal() 36 | 37 | //等待停止信号 38 | waitJobStop(job) 39 | return nil 40 | } 41 | -------------------------------------------------------------------------------- /kernel/server/job_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "github.com/qit-team/snow-core/config" 6 | "github.com/qit-team/snow-core/queue" 7 | "github.com/qit-team/snow-core/queue/redisqueue" 8 | "github.com/qit-team/snow-core/redis" 9 | "github.com/qit-team/work" 10 | "sync" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | var ( 16 | jb *work.Job 17 | register func(job *work.Job) 18 
| mu sync.RWMutex 19 | ) 20 | 21 | var q queue.Queue 22 | 23 | func init() { 24 | 25 | redisConf := config.RedisConfig{ 26 | Master: config.RedisBaseConfig{ 27 | Host: "127.0.0.1", 28 | Port: 6379, 29 | }, 30 | } 31 | 32 | // 注册redis类 33 | err := redis.Pr.Register("redis", redisConf, true) 34 | if err != nil { 35 | fmt.Println(err) 36 | } 37 | 38 | // 为了让redisqueue的driver先进行注册 39 | redisqueue.GetRedisQueue("redis") 40 | q = queue.GetQueue("redis", queue.DriverTypeRedis) 41 | } 42 | 43 | func TestStartJob(t *testing.T) { 44 | pidFile := "../../.env_pid" 45 | 46 | StartJob(pidFile, TempRegisterWorker) 47 | } 48 | 49 | func TempRegisterWorker(job *work.Job) { 50 | TempSetJob(job) 51 | 52 | //设置worker的任务投递回调函数 53 | job.AddFunc("topic-test", test) 54 | //设置worker的任务投递回调函数,和并发数 55 | job.AddFunc("topic-test1", test, 2) 56 | //使用worker结构进行注册 57 | job.AddWorker("topic-test2", &work.Worker{Call: work.MyWorkerFunc(test), MaxConcurrency: 1}) 58 | 59 | TempRegisterQueueDriver(job) 60 | } 61 | 62 | func TempSetJob(job *work.Job) { 63 | if jb == nil { 64 | jb = job 65 | } 66 | } 67 | 68 | func TempRegisterQueueDriver(job *work.Job) { 69 | q := queue.GetQueue(redis.SingletonMain, queue.DriverTypeRedis) 70 | job.AddQueue(q, "topic-test1", "topic-test2") 71 | job.AddQueue(q) 72 | } 73 | 74 | func test(task work.Task) work.TaskResult { 75 | time.Sleep(time.Millisecond * 5) 76 | s, err := work.JsonEncode(task) 77 | if err != nil { 78 | 79 | return work.TaskResult{Id: task.Id, State: work.StateFailedWithAck} 80 | } else { 81 | //work.StateSucceed 会进行ack确认 82 | fmt.Println("do task", s) 83 | return work.TaskResult{Id: task.Id, State: work.StateSucceed} 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /kernel/server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "github.com/qit-team/snow-core/kernel/close" 7 | "os" 8 | "os/signal" 9 
| "strconv" 10 | "strings" 11 | "syscall" 12 | ) 13 | 14 | const ( 15 | Version = "1.0" 16 | BuildCommit = "" 17 | BuildDate = "" 18 | ) 19 | 20 | type serverInfo struct { 21 | stop chan bool 22 | debug bool 23 | } 24 | 25 | var srv *serverInfo 26 | 27 | func init() { 28 | srv = new(serverInfo) 29 | srv.stop = make(chan bool, 0) 30 | } 31 | 32 | //将进程号写入文件 33 | func WritePidFile(path string, pidArgs ...int) error { 34 | fd, err := os.Create(path) 35 | if err != nil { 36 | return err 37 | } 38 | defer fd.Close() 39 | 40 | var pid int 41 | if len(pidArgs) > 0 { 42 | pid = pidArgs[0] 43 | } else { 44 | pid = os.Getpid() 45 | } 46 | _, err = fd.WriteString(fmt.Sprintf("%d\n", pid)) 47 | return err 48 | } 49 | 50 | //读取文件的进程号 51 | func ReadPidFile(path string) (int, error) { 52 | fd, err := os.Open(path) 53 | if err != nil { 54 | return -1, err 55 | } 56 | defer fd.Close() 57 | 58 | buf := bufio.NewReader(fd) 59 | line, err := buf.ReadString('\n') 60 | if err != nil { 61 | return -1, err 62 | } 63 | line = strings.TrimSpace(line) 64 | return strconv.Atoi(line) 65 | } 66 | 67 | //阻塞等待程序内部的Stop通道信号 68 | func WaitStop() { 69 | <-srv.stop 70 | } 71 | 72 | //关闭服务 73 | func CloseService() { 74 | if srv.debug { 75 | fmt.Println("close service") 76 | } 77 | close.Free() 78 | } 79 | 80 | //处理进程的信号量 81 | func HandleSignal(sig os.Signal) { 82 | switch sig { 83 | case syscall.SIGINT: 84 | fallthrough 85 | case syscall.SIGTERM: 86 | Stop() 87 | default: 88 | } 89 | } 90 | 91 | //监听信号量 92 | func RegisterSignal() { 93 | go func() { 94 | var sigs = []os.Signal{ 95 | syscall.SIGHUP, 96 | syscall.SIGUSR1, 97 | syscall.SIGUSR2, 98 | syscall.SIGINT, 99 | syscall.SIGTERM, 100 | } 101 | c := make(chan os.Signal) 102 | signal.Notify(c, sigs...) 
103 | for { 104 | sig := <-c //blocked 105 | HandleSignal(sig) 106 | } 107 | }() 108 | } 109 | 110 | // HandleUserCmd use to stop/reload the proxy service 111 | func HandleUserCmd(cmd string, pidFile string) error { 112 | var sig os.Signal 113 | 114 | switch cmd { 115 | case "stop": 116 | sig = syscall.SIGTERM 117 | case "restart": 118 | //目前api使用endless平滑重启,需要传递此信号,其他只需要平滑关闭就可以了 119 | sig = syscall.SIGHUP 120 | default: 121 | return fmt.Errorf("unknown user command %s", cmd) 122 | } 123 | 124 | pid, err := ReadPidFile(pidFile) 125 | if err != nil { 126 | return err 127 | } 128 | 129 | if srv.debug { 130 | fmt.Printf("send %v to pid %d \n", sig, pid) 131 | } 132 | 133 | proc := new(os.Process) 134 | proc.Pid = pid 135 | return proc.Signal(sig) 136 | } 137 | 138 | // Stop proxy 139 | func Stop() { 140 | srv.stop <- true 141 | } 142 | 143 | func SetDebug(debug bool) { 144 | srv.debug = debug 145 | return 146 | } 147 | 148 | func GetDebug() bool { 149 | return srv.debug 150 | } 151 | -------------------------------------------------------------------------------- /kernel/server/server_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "os" 5 | "syscall" 6 | "testing" 7 | ) 8 | 9 | const httpPid = 10001 10 | 11 | func TestGetDebug(t *testing.T) { 12 | debug := GetDebug() 13 | if debug != false { 14 | t.Error("debug status is error") 15 | return 16 | } 17 | SetDebug(true) 18 | debug = GetDebug() 19 | if debug != true { 20 | t.Error("debug status is error") 21 | return 22 | } 23 | } 24 | 25 | func TestSignel(t *testing.T) { 26 | // 这种函数只能通过单测跑一遍看是否有报错 没有返回数据orError类型可以判断 27 | RegisterSignal() 28 | 29 | go func() { 30 | var sigs = []os.Signal{ 31 | syscall.SIGHUP, 32 | syscall.SIGINT, 33 | syscall.SIGTERM, 34 | } 35 | 36 | HandleSignal(sigs[0]) 37 | HandleSignal(sigs[1]) 38 | HandleSignal(sigs[2]) 39 | }() 40 | } 41 | 42 | func TestPidFile(t *testing.T) { 43 | err := 
WritePidFile("../../.env_pid", httpPid) 44 | if err != nil { 45 | t.Error("WritePidFile error") 46 | return 47 | } 48 | 49 | pid, err := ReadPidFile("../../.env_pid") 50 | if err != nil { 51 | t.Error("ReadPidFile error") 52 | return 53 | } else if pid != 10001 { 54 | t.Error("ReadPidFile error result not right") 55 | return 56 | } 57 | } 58 | 59 | func TestStop(t *testing.T) { 60 | go func() { 61 | WaitStop() 62 | }() 63 | 64 | //time.Sleep(1) 65 | go func() { 66 | Stop() 67 | }() 68 | } 69 | 70 | func TestCloseService(t *testing.T) { 71 | CloseService() 72 | } 73 | 74 | func TestHandleUserCmd(t *testing.T) { 75 | err := HandleUserCmd("cmd", "../../.env_pid") 76 | if err == nil { 77 | t.Error("unknown cmd error") 78 | return 79 | } 80 | 81 | err = HandleUserCmd("stop", "../../.env_pid") 82 | // process already finished 83 | if err == nil { 84 | t.Error("stop cmd error") 85 | return 86 | } 87 | 88 | err = HandleUserCmd("restart", "../../.env_pid") 89 | // process already finished 90 | if err == nil { 91 | t.Error("restart cmd error") 92 | return 93 | } 94 | 95 | // todo construct more cases, for example the process is running 96 | } 97 | -------------------------------------------------------------------------------- /log/accesslogger/access_logger.go: -------------------------------------------------------------------------------- 1 | package accesslogger 2 | 3 | import ( 4 | "github.com/hetiansu5/accesslog" 5 | coresio "github.com/hetiansu5/cores/io" 6 | "github.com/qit-team/snow-core/log/logger" 7 | "io" 8 | ) 9 | 10 | func InitAccessLog(logHandler string, logDir string) (*accesslog.AccessLogger, error) { 11 | var writer io.Writer 12 | if logHandler == logger.HandlerStdout { 13 | writer = logger.GetStdOutWriter(logDir) 14 | } else { 15 | logFile := logDir + "/access.log" 16 | writerFile, err := coresio.NewRollingFileWriter(logFile, coresio.NewDailyRollingManager()) 17 | if err != nil { 18 | return nil, err 19 | } 20 | writer = writerFile 21 | } 22 | 23 | acl, 
err := accesslog.NewLogger(accesslog.Output(writer), accesslog.Pattern(accesslog.JSONPattern)) 24 | if err != nil { 25 | return nil, err 26 | } 27 | return acl, nil 28 | } 29 | -------------------------------------------------------------------------------- /log/accesslogger/provider.go: -------------------------------------------------------------------------------- 1 | package accesslogger 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/hetiansu5/accesslog" 7 | "github.com/qit-team/snow-core/config" 8 | "github.com/qit-team/snow-core/helper" 9 | "github.com/qit-team/snow-core/kernel/container" 10 | "sync" 11 | ) 12 | 13 | const SingletonMain = "access_logger" 14 | 15 | var Pr *provider 16 | 17 | func init() { 18 | Pr = new(provider) 19 | Pr.mp = make(map[string]interface{}) 20 | } 21 | 22 | type provider struct { 23 | mu sync.RWMutex 24 | mp map[string]interface{} //配置 25 | dn string //default name 26 | } 27 | 28 | /** 29 | * @param string 依赖注入别名 必选 30 | * @param config.LogConfig 配置 必选 31 | * @param bool 是否启用懒加载 可选 32 | */ 33 | func (p *provider) Register(args ...interface{}) (err error) { 34 | diName, lazy, err := helper.TransformArgs(args...) 
35 | if err != nil { 36 | return 37 | } 38 | 39 | conf, ok := args[1].(config.LogConfig) 40 | if !ok { 41 | return errors.New("args[1] is not config.LogConfig") 42 | } 43 | 44 | p.mu.Lock() 45 | p.mp[diName] = args[1] 46 | if len(p.mp) == 1 { 47 | p.dn = diName 48 | } 49 | p.mu.Unlock() 50 | 51 | if !lazy { 52 | _, err = setSingleton(diName, conf) 53 | } 54 | return 55 | } 56 | 57 | func (p *provider) Provides() []string { 58 | p.mu.RLock() 59 | defer p.mu.RUnlock() 60 | 61 | return helper.MapToArray(p.mp) 62 | } 63 | 64 | func (p *provider) Close() error { 65 | return nil 66 | } 67 | 68 | //注入单例 69 | func setSingleton(diName string, conf config.LogConfig) (ins *accesslog.AccessLogger, err error) { 70 | ins, err = InitAccessLog(conf.Handler, conf.Dir) 71 | if err == nil { 72 | container.App.SetSingleton(diName, ins) 73 | } 74 | return 75 | } 76 | 77 | //获取单例 78 | func getSingleton(diName string, lazy bool) *accesslog.AccessLogger { 79 | rc := container.App.GetSingleton(diName) 80 | if rc != nil { 81 | return rc.(*accesslog.AccessLogger) 82 | } 83 | if lazy == false { 84 | return nil 85 | } 86 | 87 | Pr.mu.RLock() 88 | conf, ok := Pr.mp[diName].(config.LogConfig) 89 | Pr.mu.RUnlock() 90 | if !ok { 91 | panic(fmt.Sprintf("access_logger di_name:%s not exist", diName)) 92 | } 93 | 94 | ins, err := setSingleton(diName, conf) 95 | if err != nil { 96 | panic(fmt.Sprintf("access_logger di_name:%s err:%s", diName, err.Error())) 97 | } 98 | return ins 99 | } 100 | 101 | //外部通过注入别名获取资源,解耦资源的关系 102 | func GetAccessLogger(args ...string) *accesslog.AccessLogger { 103 | diName := helper.GetDiName(Pr.dn, args...) 
104 | return getSingleton(diName, true) 105 | } 106 | -------------------------------------------------------------------------------- /log/accesslogger/provider_test.go: -------------------------------------------------------------------------------- 1 | package accesslogger 2 | 3 | import ( 4 | "github.com/qit-team/snow-core/config" 5 | "testing" 6 | ) 7 | 8 | func Test_getSingleton(t *testing.T) { 9 | c := getSingleton("", false) 10 | if c != nil { 11 | t.Error("client is not equal nil") 12 | return 13 | } 14 | } 15 | 16 | func TestProvider(t *testing.T) { 17 | err := Pr.Register("access_logger", config.LogConfig{}) 18 | if err == nil { 19 | t.Error(err) 20 | return 21 | } 22 | 23 | conf := config.LogConfig{ 24 | Handler: "file", 25 | Level: "info", 26 | Dir: "../../", 27 | } 28 | 29 | err = Pr.Register("access_logger", conf, true) 30 | if err != nil { 31 | t.Error(err) 32 | return 33 | } 34 | 35 | arr := Pr.Provides() 36 | if !(len(arr) == 1 && arr[0] == "access_logger") { 37 | t.Errorf("Provides is not match. %v", arr) 38 | return 39 | } 40 | 41 | err = Pr.Register("access_logger1", conf) 42 | if err != nil { 43 | t.Error(err) 44 | return 45 | } 46 | 47 | arr = Pr.Provides() 48 | if !(len(arr) == 2 && arr[1] == "access_logger1" || arr[1] == "access_logger") { 49 | t.Errorf("Provides is not match. 
%v", arr) 50 | return 51 | } 52 | 53 | c := GetAccessLogger() 54 | if c == nil { 55 | t.Error("client is equal nil") 56 | return 57 | } 58 | 59 | c1 := GetAccessLogger("access_logger1") 60 | if c1 == nil { 61 | t.Error("client is equal nil") 62 | return 63 | } 64 | 65 | defer func() { 66 | if e := recover(); e != "access_logger di_name:access_logger2 not exist" { 67 | t.Error("not panic") 68 | } 69 | }() 70 | GetAccessLogger("access_logger2") 71 | 72 | err = Pr.Close() 73 | if err != nil { 74 | t.Error(err) 75 | return 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /log/logger/http_logger.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "context" 5 | "github.com/qit-team/snow-core/http/ctxkit" 6 | "github.com/sirupsen/logrus" 7 | "os" 8 | ) 9 | 10 | var ( 11 | hostname string 12 | ) 13 | 14 | type withField struct { 15 | Key string 16 | Value interface{} 17 | } 18 | 19 | //此结构的数据将会在挂靠到日志的一级键中体现 20 | //demo: logger.Info(ctx, "curl", NewWithFiled("key1", "value1"), NewWithFiled("key2", "value2"), "msg1", "msg2") 21 | func NewWithField(key string, value interface{}) *withField { 22 | return &withField{Key: key, Value: value} 23 | } 24 | 25 | //批量 26 | func BatchNewWithField(data map[string]interface{}) (arr []*withField) { 27 | for k, v := range data { 28 | arr = append(arr, NewWithField(k, v)) 29 | } 30 | return arr 31 | } 32 | 33 | func GetHostName() string { 34 | if hostname == "" { 35 | hostname, _ = os.Hostname() 36 | if hostname == "" { 37 | hostname = "unknown" 38 | } 39 | } 40 | return hostname 41 | } 42 | 43 | func formatLog(c context.Context, t string, args ...*withField) logrus.Fields { 44 | data := logrus.Fields{ 45 | "type": t, 46 | "host": GetHostName(), 47 | } 48 | 49 | if c != nil { 50 | traceId := ctxkit.GetTraceId(c) 51 | if traceId != "" { 52 | data["trace_id"] = traceId 53 | } else { 54 | traceId, _ := 
ctxkit.GenerateTraceId(c) 55 | data["trace_id"] = traceId 56 | } 57 | 58 | domain := ctxkit.GetHost(c) 59 | if domain != "" { 60 | data["domain"] = domain 61 | } 62 | 63 | sip := ctxkit.GetServerId(c) 64 | if sip != "" { 65 | data["sip"] = sip 66 | } 67 | 68 | cip := ctxkit.GetClientId(c) 69 | if cip != "" { 70 | data["cip"] = cip 71 | } 72 | } 73 | 74 | for _, field := range args { 75 | if _, ok := data[field.Key]; !ok { 76 | data[field.Key] = field.Value 77 | } 78 | } 79 | 80 | return data 81 | } 82 | 83 | func Trace(c context.Context, logType string, msg ...interface{}) { 84 | withFields, newMsg := splitMsg(msg) 85 | data := formatLog(c, logType, withFields...) 86 | GetLogger().WithFields(data).Trace(newMsg...) 87 | } 88 | 89 | func Debug(c context.Context, logType string, msg ...interface{}) { 90 | withFields, newMsg := splitMsg(msg) 91 | data := formatLog(c, logType, withFields...) 92 | GetLogger().WithFields(data).Debug(newMsg...) 93 | } 94 | 95 | func Info(c context.Context, logType string, msg ...interface{}) { 96 | withFields, newMsg := splitMsg(msg) 97 | data := formatLog(c, logType, withFields...) 98 | GetLogger().WithFields(data).Info(newMsg...) 99 | } 100 | 101 | func Warn(c context.Context, logType string, msg ...interface{}) { 102 | withFields, newMsg := splitMsg(msg) 103 | data := formatLog(c, logType, withFields...) 104 | GetLogger().WithFields(data).Warn(newMsg...) 105 | } 106 | 107 | func Error(c context.Context, logType string, msg ...interface{}) { 108 | withFields, newMsg := splitMsg(msg) 109 | data := formatLog(c, logType, withFields...) 110 | GetLogger().WithFields(data).Error(newMsg...) 111 | } 112 | 113 | func Fatal(c context.Context, logType string, msg ...interface{}) { 114 | withFields, newMsg := splitMsg(msg) 115 | data := formatLog(c, logType, withFields...) 116 | GetLogger().WithFields(data).Fatal(newMsg...) 
117 | } 118 | 119 | func Panic(c context.Context, logType string, msg ...interface{}) { 120 | withFields, newMsg := splitMsg(msg) 121 | data := formatLog(c, logType, withFields...) 122 | GetLogger().WithFields(data).Panic(newMsg...) 123 | } 124 | 125 | //将日志消息分裂 126 | func splitMsg(msg []interface{}) (withFields []*withField, newMsg []interface{}) { 127 | for _, v := range msg { 128 | switch v.(type) { 129 | case *withField: 130 | withFields = append(withFields, v.(*withField)) 131 | case []*withField: 132 | // 如果是通过batchNewWithFields,需要做如下处理 133 | tempWithFieldsList := v.([]*withField) 134 | if len(tempWithFieldsList) != 0 { 135 | for _, tempWithField := range tempWithFieldsList { 136 | withFields = append(withFields, tempWithField) 137 | } 138 | } 139 | default: 140 | newMsg = append(newMsg, v) 141 | } 142 | } 143 | return 144 | } 145 | -------------------------------------------------------------------------------- /log/logger/logger.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "fmt" 5 | "github.com/sirupsen/logrus" 6 | "os" 7 | ) 8 | 9 | //app.log_handler为file时,日志格式为:[time(ISO8601)] [host] [type(service.module.function)] [req_id] [server_ip] [client_ip] [message(json:code,message,file,line,trace,biz_data)] 10 | //app.log_handler为stdout时,日志格式为:{"t": "time(ISO8601)", "lvl": "level", "h": "host", "type": "type(service.module.function)", "reqid": "req_id", "sip": "server_ip", "cip": "client_ip", "msg": {"code": 0, "message": "xxx", "file": "file", "line": 0}} 11 | 12 | const HandlerFile = "file" 13 | const HandlerStdout = "stdout" 14 | 15 | func GetStdOutWriter(path string) (writer *os.File) { 16 | //此处命名管道会阻塞,直到有进程读取了这个命名管道 17 | writer, err := os.OpenFile(path, os.O_WRONLY, 777) 18 | if err != nil { 19 | fmt.Fprintf(os.Stderr, "Failed to open file, %v\n", err) 20 | } 21 | return 22 | } 23 | 24 | func InitLog(logFileName, logHandler string, logDir string, logLevel string, segment bool) 
(*logrus.Logger, error) { 25 | logger := logrus.New() 26 | 27 | if len(logFileName) == 0 { 28 | logFileName = "snow" 29 | } 30 | 31 | //设置日志等级 32 | level, err := logrus.ParseLevel(logLevel) 33 | if err == nil { 34 | logger.SetLevel(level) 35 | } 36 | 37 | //设置日志输出格式 38 | logger.Formatter = &logrus.JSONFormatter{} 39 | 40 | //设置日志输出方式 标准输出或文件 41 | if logHandler == HandlerStdout { 42 | writer := os.Stdout 43 | logger.SetOutput(writer) 44 | return logger, nil 45 | } 46 | 47 | if segment { 48 | hook, err := NewLfsHook(logger, logDir, logFileName) 49 | if err != nil { 50 | return nil, err 51 | } 52 | logger.Hooks.Add(hook) 53 | } else { 54 | rollHook, err := NewRollHook(logger, logDir, logFileName) 55 | if err != nil { 56 | return nil, err 57 | } 58 | logger.Hooks.Add(rollHook) 59 | } 60 | 61 | return logger, nil 62 | } 63 | -------------------------------------------------------------------------------- /log/logger/provider.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/qit-team/snow-core/config" 7 | "github.com/qit-team/snow-core/helper" 8 | "github.com/qit-team/snow-core/kernel/container" 9 | "github.com/sirupsen/logrus" 10 | "os" 11 | "sync" 12 | ) 13 | 14 | const SingletonMain = "logger" 15 | 16 | var Pr *provider 17 | 18 | func init() { 19 | Pr = new(provider) 20 | Pr.mp = make(map[string]interface{}) 21 | } 22 | 23 | type provider struct { 24 | mu sync.RWMutex 25 | mp map[string]interface{} //配置 26 | dn string //default name 27 | } 28 | 29 | /** 30 | * @param string 依赖注入别名 必选 31 | * @param config.LogConfig 配置 必选 32 | * @param bool 是否启用懒加载 可选 33 | */ 34 | func (p *provider) Register(args ...interface{}) (err error) { 35 | diName, lazy, err := helper.TransformArgs(args...) 
36 | if err != nil { 37 | return 38 | } 39 | 40 | conf, ok := args[1].(config.LogConfig) 41 | if !ok { 42 | return errors.New("args[1] is not config.LogConfig") 43 | } 44 | 45 | p.mu.Lock() 46 | p.mp[diName] = args[1] 47 | if len(p.mp) == 1 { 48 | p.dn = diName 49 | } 50 | p.mu.Unlock() 51 | 52 | if !lazy { 53 | _, err = setSingleton(diName, conf) 54 | } 55 | return 56 | } 57 | 58 | // Provides returns the registered aliases (di names). 59 | func (p *provider) Provides() []string { 60 | p.mu.RLock() 61 | defer p.mu.RUnlock() 62 | 63 | return helper.MapToArray(p.mp) 64 | } 65 | 66 | // Close releases resources: every already-created logger whose output is an *os.File is synced and closed (sync/close errors are ignored). 67 | func (p *provider) Close() error { 68 | arr := p.Provides() 69 | for _, k := range arr { 70 | logger := getSingleton(k, false) 71 | if logger != nil { 72 | log, ok := logger.Out.(*os.File) 73 | if ok { 74 | log.Sync() 75 | log.Close() 76 | } 77 | } 78 | } 79 | return nil 80 | } 81 | 82 | // setSingleton builds a logger from conf via InitLog and, on success, stores it in the container under diName. 83 | func setSingleton(diName string, conf config.LogConfig) (ins *logrus.Logger, err error) { 84 | ins, err = InitLog(conf.FileName, conf.Handler, conf.Dir, conf.Level, conf.Segment) 85 | if err == nil { 86 | container.App.SetSingleton(diName, ins) 87 | } 88 | return 89 | } 90 | 91 | // getSingleton returns the logger stored under diName. If none exists it returns nil when lazy is false; otherwise it builds one from the registered config, panicking if diName is unregistered or construction fails. 92 | func getSingleton(diName string, lazy bool) *logrus.Logger { 93 | rc := container.App.GetSingleton(diName) 94 | if rc != nil { 95 | return rc.(*logrus.Logger) 96 | } 97 | if !lazy { 98 | return nil 99 | } 100 | 101 | Pr.mu.RLock() 102 | conf, ok := Pr.mp[diName].(config.LogConfig) 103 | Pr.mu.RUnlock() 104 | if !ok { 105 | panic(fmt.Sprintf("logger di_name:%s not exist", diName)) 106 | } 107 | 108 | ins, err := setSingleton(diName, conf) 109 | if err != nil { 110 | panic(fmt.Sprintf("logger di_name:%s err:%s", diName, err.Error())) 111 | } 112 | return ins 113 | } 114 | 115 | // GetLogger returns the logger for the given alias (helper.GetDiName presumably falls back to Pr.dn, the first registered alias, when none is given), so callers depend on the alias rather than on how the logger is built. 116 | func GetLogger(args ...string) *logrus.Logger { 117 | diName := helper.GetDiName(Pr.dn, args...)
118 | return getSingleton(diName, true) 119 | } 120 | -------------------------------------------------------------------------------- /log/logger/provider_test.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "fmt" 5 | "github.com/gin-gonic/gin" 6 | "github.com/qit-team/snow-core/config" 7 | "github.com/qit-team/snow-core/http/ctxkit" 8 | "testing" 9 | ) 10 | 11 | var contextTest, contextTest1 *gin.Context 12 | 13 | func init() { 14 | contextTest = &gin.Context{} 15 | contextTest1 = &gin.Context{} 16 | } 17 | 18 | func Test_getSingleton(t *testing.T) { 19 | c := getSingleton("", false) 20 | if c != nil { 21 | t.Error("client is not equal nil") 22 | return 23 | } 24 | } 25 | 26 | func TestProvider(t *testing.T) { 27 | err := Pr.Register("logger", config.LogConfig{}) 28 | if err == nil { 29 | t.Error("register with an empty config must return an error") 30 | return 31 | } 32 | 33 | conf := config.LogConfig{ 34 | Handler: "file", 35 | Level: "info", 36 | Dir: "../../", 37 | } 38 | 39 | err = Pr.Register("logger", conf, true) 40 | if err != nil { 41 | t.Error(err) 42 | return 43 | } 44 | 45 | // test generate trace id 46 | traceId, _ := ctxkit.GenerateTraceId(contextTest) 47 | 48 | // set the traceId on the context 49 | ctxkit.SetTraceId(contextTest, traceId) 50 | temp := ctxkit.GetTraceId(contextTest) 51 | fmt.Println("=======test_temp:", temp) 52 | Info(contextTest, "========testTraceId:levelInfo=====") 53 | Error(contextTest, "========testTraceId:levelError=====") 54 | Warn(contextTest, "========testTraceId:levelWarn=====") 55 | Debug(contextTest, "========testTraceId:levelDebug=====") 56 | Trace(contextTest, "========testTraceId:levelTrace=====") 57 | //Fatal(contextTest, "========testTraceId:levelFatal=====") 58 | 59 | Info(nil, "================") 60 | 61 | // fresh context: the first log call should plant a traceId in it 62 | Info(contextTest1, "========testTraceId111:levelInfo=====") 63 | Error(contextTest1, "========testTraceId111:levelError=====") 64 |
Warn(contextTest1, "========testTraceId111:levelWarn=====") 65 | Debug(contextTest1, "========testTraceId111:levelDebug=====") 66 | // calling Panic would make go test fail 67 | //Panic(contextTest1, "========testTraceId111:levelPanic=====") 68 | 69 | arr := Pr.Provides() 70 | if !(len(arr) == 1 && arr[0] == "logger") { 71 | t.Errorf("Provides is not match. %v", arr) 72 | return 73 | } 74 | 75 | err = Pr.Register("logger1", conf) 76 | if err != nil { 77 | t.Error(err) 78 | return 79 | } 80 | 81 | arr = Pr.Provides() 82 | if !(len(arr) == 2 && (arr[1] == "logger1" || arr[1] == "logger")) { 83 | t.Errorf("Provides is not match. %v", arr) 84 | return 85 | } 86 | 87 | c := GetLogger() 88 | if c == nil { 89 | t.Error("client is equal nil") 90 | return 91 | } 92 | 93 | c1 := GetLogger("logger1") 94 | if c1 == nil { 95 | t.Error("client is equal nil") 96 | return 97 | } 98 | 99 | defer func() { 100 | if e := recover(); e != "logger di_name:logger2 not exist" { 101 | t.Error("not panic") 102 | } 103 | }() 104 | GetLogger("logger2") 105 | 106 | err = Pr.Close() 107 | if err != nil { 108 | t.Error(err) 109 | return 110 | } 111 | 112 | } 113 | 114 | func TestNewWithField(t *testing.T) { 115 | // test the NewWithField && BatchNewWithField functions 116 | conf := config.LogConfig{ 117 | Handler: "file", 118 | Level: "info", 119 | Dir: "../../", 120 | } 121 | 122 | defer func() { 123 | if e := recover(); e != nil { 124 | t.Error("test NewWithField panic") 125 | } 126 | }() 127 | err := Pr.Register("logger", conf, true) 128 | if err != nil { 129 | t.Error(err) 130 | return 131 | } 132 | 133 | Info(nil, "===TestNewWithField", NewWithField("data", "snow")) 134 | 135 | logInfo := map[string]interface{}{ 136 | "url": "testUrl", 137 | "params": "snow", 138 | "num": 100, 139 | } 140 | 141 | Info(nil, "===TestBatchNewWithField", BatchNewWithField(logInfo)) 142 | } 143 | -------------------------------------------------------------------------------- /log/logger/roll_hook.go:
-------------------------------------------------------------------------------- 1 | package logger 2 | 3 | /** 4 | * log file rotation: switches the logger output to a new file when the day/hour time pattern changes 5 | */ 6 | import ( 7 | "fmt" 8 | "os" 9 | "sync" 10 | "time" 11 | 12 | "github.com/sirupsen/logrus" 13 | ) 14 | 15 | const ( 16 | rollDay = iota 17 | rollHour 18 | ) 19 | 20 | const ( 21 | defaultDayTimePattern = "20060102" 22 | defaultHourTimePattern = "20060102-15" 23 | ) 24 | 25 | type RollHook struct { 26 | dir string 27 | name string 28 | currFileTime string 29 | writer *os.File 30 | timePattern string 31 | lock sync.Mutex 32 | logger *logrus.Logger 33 | } 34 | 35 | func (rh *RollHook) openNewFile() (*os.File, error) { 36 | _, err := os.Stat(rh.dir) 37 | if os.IsNotExist(err) { 38 | err = os.MkdirAll(rh.dir, 0755) 39 | if err != nil { 40 | return nil, err 41 | } 42 | } 43 | 44 | newFileTime := time.Now().Format(rh.timePattern) 45 | newFileName := fmt.Sprintf("%s/%s.%s.log", rh.dir, rh.name, newFileTime) 46 | newWriter, err := os.OpenFile(newFileName, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) 47 | if err != nil { 48 | return nil, err 49 | } 50 | rh.currFileTime = newFileTime 51 | 52 | return newWriter, nil 53 | } 54 | 55 | func NewRollHook(logger *logrus.Logger, dir, name string) (*RollHook, error) { 56 | rh := new(RollHook) 57 | rh.name = name 58 | rh.timePattern = defaultDayTimePattern 59 | rh.logger = logger 60 | rh.dir = dir 61 | 62 | writer, err := rh.openNewFile() 63 | if err != nil { 64 | return nil, err 65 | } 66 | rh.writer = writer 67 | logger.Out = writer 68 | 69 | return rh, nil 70 | } 71 | 72 | func (rh *RollHook) needRoll() bool { 73 | return rh.currFileTime != time.Now().Format(rh.timePattern) 74 | } 75 | 76 | func (rh *RollHook) roll() error { 77 | rh.lock.Lock() 78 | defer rh.lock.Unlock() 79 | 80 | if !rh.needRoll() { 81 | return nil 82 | } 83 | 84 | oldWriter := rh.writer 85 | newWriter, err := rh.openNewFile() 86 | if err != nil { 87 | return err 88 | } 89 | 90 | rh.writer = newWriter 91 | rh.logger.Out = newWriter
92 | 93 | err = oldWriter.Close() 94 | if err != nil { 95 | return err 96 | } 97 | return nil 98 | } 99 | 100 | func (rh *RollHook) SetRollType(rType int) { 101 | switch rType { 102 | case rollDay: 103 | rh.timePattern = defaultDayTimePattern 104 | case rollHour: 105 | rh.timePattern = defaultHourTimePattern 106 | } 107 | } 108 | 109 | func (rh *RollHook) Fire(entry *logrus.Entry) error { 110 | defer func() { 111 | if err := recover(); err != nil { 112 | // panics raised while rolling are swallowed here 113 | } 114 | }() 115 | // roll() re-checks needRoll while holding rh.lock. Checking needRoll 116 | // here without the lock would race with a concurrent roll() that 117 | // updates currFileTime. 118 | return rh.roll() 119 | } 120 | 121 | func (rh *RollHook) Levels() []logrus.Level { 122 | return []logrus.Level{ 123 | logrus.DebugLevel, 124 | logrus.InfoLevel, 125 | logrus.WarnLevel, 126 | logrus.ErrorLevel, 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /log/logger/segment.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "fmt" 5 | rotatelogs "github.com/lestrrat-go/file-rotatelogs" 6 | "github.com/rifflock/lfshook" 7 | "github.com/sirupsen/logrus" 8 | "io" 9 | "time" 10 | ) 11 | 12 | func NewLfsHook(l *logrus.Logger, logDir, name string) (*lfshook.LfsHook, error) { 13 | var ( 14 | err error 15 | infoWriter io.Writer 16 | warnWriter io.Writer 17 | errorWriter io.Writer 18 | ) 19 | 20 | infoPath := fmt.Sprintf("%s/%s.%s", logDir, name, "INFO.%Y%m%d.log") 21 | linkInfoPath := fmt.Sprintf("%s/%s.%s", logDir, name, "INFO.log") 22 | infoWriter, err = rotatelogs.New( 23 | infoPath, 24 | rotatelogs.WithLinkName(linkInfoPath), 25 | rotatelogs.WithRotationTime(time.Duration(24)*time.Hour), 26 | ) 27 | if err != nil { 28 | return nil, err 29 | } 30 | 31 | warnPath := fmt.Sprintf("%s/%s.%s", logDir, name, "WARN.%Y%m%d.log") 32 | linkWarnPath := fmt.Sprintf("%s/%s.%s", logDir, name, "WARN.log") 33 | warnWriter, err = rotatelogs.New( 34 | warnPath, 35 | rotatelogs.WithLinkName(linkWarnPath), 36 |
rotatelogs.WithRotationTime(time.Duration(24)*time.Hour), 37 | ) 38 | if err != nil { 39 | return nil, err 40 | } 41 | 42 | errorPath := fmt.Sprintf("%s/%s.%s", logDir, name, "ERROR.%Y%m%d.log") 43 | linkErrorPath := fmt.Sprintf("%s/%s.%s", logDir, name, "ERROR.log") 44 | errorWriter, err = rotatelogs.New( 45 | errorPath, 46 | rotatelogs.WithLinkName(linkErrorPath), 47 | rotatelogs.WithRotationTime(time.Duration(24)*time.Hour), 48 | ) 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | writerMap := lfshook.WriterMap{} 54 | if infoWriter != nil { 55 | writerMap[logrus.TraceLevel] = infoWriter 56 | writerMap[logrus.DebugLevel] = infoWriter 57 | writerMap[logrus.InfoLevel] = infoWriter 58 | } 59 | if warnWriter != nil { 60 | writerMap[logrus.WarnLevel] = warnWriter 61 | } 62 | if errorWriter != nil { 63 | writerMap[logrus.ErrorLevel] = errorWriter 64 | writerMap[logrus.FatalLevel] = errorWriter 65 | writerMap[logrus.PanicLevel] = errorWriter 66 | } 67 | 68 | return lfshook.NewHook(writerMap, &logrus.JSONFormatter{}), nil 69 | } 70 | -------------------------------------------------------------------------------- /log/logger/source_hook.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | /** 4 | * extra handling for warn+ level logs: adds a "caller" (file:line:func) field to entries at levels <= sh.level 5 | */ 6 | import ( 7 | "fmt" 8 | "path/filepath" 9 | "runtime" 10 | "strings" 11 | 12 | "github.com/sirupsen/logrus" 13 | ) 14 | 15 | type SourceHook struct { 16 | level logrus.Level 17 | } 18 | 19 | func NewSourceHook(level logrus.Level) *SourceHook { 20 | return &SourceHook{ 21 | level: level, 22 | } 23 | } 24 | 25 | func (sh *SourceHook) Fire(entry *logrus.Entry) error { 26 | for skip := 5; skip < 9; skip++ { 27 | if pc, file, line, ok := runtime.Caller(skip); ok { 28 | arr := strings.Split(file, "/") 29 | n := len(arr) 30 | if n > 1 && (arr[n-2] == "logrus" || strings.HasPrefix(arr[n-2], "logrus@")) { // module-cache dirs are "logrus@vX.Y.Z" 31 | continue 32 | } 33 | funcName := runtime.FuncForPC(pc).Name() 34 | entry.Data["caller"] = fmt.Sprintf("%s:%d:%s",
filepath.Base(file), line, funcName) 35 | } 36 | break 37 | } 38 | return nil 39 | } 40 | func (sh *SourceHook) Levels() []logrus.Level { 41 | levels := make([]logrus.Level, 0, len(logrus.AllLevels)) // length must be 0: pre-filled zero values (PanicLevel) would register the hook several times 42 | for _, level := range logrus.AllLevels { 43 | if level <= sh.level { 44 | levels = append(levels, level) 45 | } 46 | } 47 | return levels 48 | } 49 | -------------------------------------------------------------------------------- /queue/alimnsqueue/alimns_queue.go: -------------------------------------------------------------------------------- 1 | package alimnsqueue 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "strings" 7 | "sync" 8 | 9 | ali_mns "github.com/aliyun/aliyun-mns-go-sdk" 10 | "github.com/qit-team/snow-core/alimns" 11 | "github.com/qit-team/snow-core/queue" 12 | ) 13 | 14 | const ( 15 | DefaultVisibilityTimeout = int64(60) 16 | ) 17 | 18 | var ( 19 | mp map[string]queue.Queue 20 | mu sync.RWMutex 21 | ) 22 | 23 | type MnsQueue struct { 24 | client ali_mns.MNSClient 25 | } 26 | 27 | // newMnsQueue creates a new instance bound to the mns client registered under diName 28 | func newMnsQueue(diName string) queue.Queue { 29 | m := new(MnsQueue) 30 | m.client = alimns.GetMns(diName) 31 | return m 32 | } 33 | 34 | // GetMnsQueue returns the cached instance for diName, creating it on first use 35 | func GetMnsQueue(diName string) queue.Queue { 36 | key := diName 37 | mu.RLock() 38 | q, ok := mp[key] 39 | mu.RUnlock() 40 | if ok { 41 | return q 42 | } 43 | 44 | q = newMnsQueue(diName) 45 | mu.Lock() 46 | mp[key] = q 47 | mu.Unlock() 48 | return q 49 | } 50 | 51 | /** 52 | * Enqueue a single message 53 | * args[0] delay: message delay in seconds (int64) 54 | * args[1] priority (int64) 55 | */ 56 | func (m *MnsQueue) Enqueue(ctx context.Context, key string, message string, args ...interface{}) (bool, error) { 57 | delay, priority := getOption(args...)
58 | 59 | // mns message; priority and delay can be set 60 | aliMsg := ali_mns.MessageSendRequest{ 61 | MessageBody: message, 62 | DelaySeconds: delay, 63 | Priority: priority, 64 | } 65 | 66 | queueClient := alimns.GetMnsBasicQueue(m.client, key) 67 | _, err := queueClient.SendMessage(aliMsg) 68 | 69 | if err != nil { 70 | return false, err 71 | } 72 | 73 | return true, nil 74 | } 75 | 76 | /** 77 | * Dequeue a single message 78 | * args[0] visibility timeout in seconds (int64, default DefaultVisibilityTimeout) 79 | * return: the message, an empty tag, the mns ReceiptHandle as token (pass it to AckMsg to delete the message), and the dequeue count 80 | */ 81 | func (m *MnsQueue) Dequeue(ctx context.Context, key string, args ...interface{}) (message string, tag string, token string, dequeueCount int64, err error) { 82 | respChan := make(chan ali_mns.MessageReceiveResponse) 83 | errChan := make(chan error) 84 | // only a single read is done here; long-running consumption is left to the job 85 | 86 | // receive a message from alimns via the channels 87 | queueClient := alimns.GetMnsBasicQueue(m.client, key) 88 | 89 | go func() { 90 | queueClient.ReceiveMessage(respChan, errChan) 91 | }() 92 | 93 | select { 94 | case resp := <-respChan: 95 | visibilityTimeout := DefaultVisibilityTimeout 96 | l := len(args) 97 | if l > 0 { 98 | vt, ok := args[0].(int64) 99 | if ok { 100 | visibilityTimeout = vt 101 | } 102 | } 103 | // hide the message from other concurrent consumers for visibilityTimeout seconds 104 | if ret, err1 := queueClient.ChangeMessageVisibility(resp.ReceiptHandle, visibilityTimeout); err1 != nil { 105 | err = err1 106 | } else { 107 | // the SDK barely documents these calls; this follows the calls used in its demo 108 | return resp.MessageBody, "", ret.ReceiptHandle, resp.DequeueCount, nil 109 | } 110 | case err2 := <-errChan: 111 | err = err2 112 | if strings.Contains(err2.Error(), "MessageNotExist") { 113 | // no message available: return an empty message and a nil error 114 | err = nil 115 | return 116 | } 117 | } 118 | return 119 | } 120 | 121 | /** 122 | * Enqueue a batch of messages 123 | * args[0] delay: message delay in seconds (int64) 124 | * args[1] priority (int64) 125 | */ 126 | func (m *MnsQueue) BatchEnqueue(ctx context.Context, key string, messages []string, args ...interface{}) (bool, error) { 127 | if len(messages) == 0 { 128 | return false,
errors.New("messages is empty") 129 | } 130 | 131 | delay, priority := getOption(args...) 132 | 133 | // mns messages; the same priority and delay apply to every message 134 | msgArr := make([]ali_mns.MessageSendRequest, len(messages)) 135 | for k, message := range messages { 136 | msgArr[k] = ali_mns.MessageSendRequest{ 137 | MessageBody: message, 138 | DelaySeconds: delay, 139 | Priority: priority, 140 | } 141 | } 142 | 143 | queueClient := alimns.GetMnsBasicQueue(m.client, key) 144 | _, err := queueClient.BatchSendMessage(msgArr...) 145 | 146 | if err != nil { 147 | return false, err 148 | } 149 | 150 | return true, nil 151 | } 152 | 153 | /** 154 | * Acknowledge a received message by deleting it with its token (ReceiptHandle) 155 | */ 156 | func (m *MnsQueue) AckMsg(ctx context.Context, key string, token string, args ...interface{}) (bool, error) { 157 | queueClient := alimns.GetMnsBasicQueue(m.client, key) 158 | if len(token) < 1 { 159 | return false, errors.New("token empty") 160 | } 161 | err := queueClient.DeleteMessage(token) 162 | if err != nil { 163 | return false, err 164 | } 165 | return true, nil 166 | } 167 | 168 | // getOption parses enqueue options: args[0] delay (default 0), args[1] priority (default 1); non-int64 values are ignored 169 | func getOption(args ...interface{}) (delay int64, priority int64) { 170 | delay = 0 171 | priority = 1 172 | 173 | l := len(args) 174 | if l > 0 { 175 | de, ok := args[0].(int64) 176 | if ok { 177 | delay = de 178 | } 179 | 180 | if l > 1 { 181 | pr, ok := args[1].(int64) 182 | if ok { 183 | priority = pr 184 | } 185 | } 186 | } 187 | return 188 | } 189 | 190 | func init() { 191 | mp = make(map[string]queue.Queue) 192 | queue.Register(queue.DriverTypeAliMns, GetMnsQueue) 193 | } 194 | -------------------------------------------------------------------------------- /queue/alimnsqueue/queue_test.go: -------------------------------------------------------------------------------- 1 | package alimnsqueue 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/qit-team/snow-core/alimns" 7 | "github.com/qit-team/snow-core/config" 8 | "github.com/qit-team/snow-core/queue" 9 | "github.com/qit-team/snow-core/utils" 10 | "io/ioutil" 11 | "strings"
12 | "testing" 13 | ) 14 | 15 | var q queue.Queue 16 | 17 | func init() { 18 | // fill in ../../.env.mns yourself: Url, AccessKeyId and AccessKeySecret on separate lines 19 | bs, err := ioutil.ReadFile("../../.env.mns") 20 | conf := config.MnsConfig{} 21 | if err == nil { 22 | str := string(bs) 23 | arr := strings.Split(str, "\n") 24 | if len(arr) >= 3 { 25 | conf.Url = arr[0] 26 | conf.AccessKeyId = arr[1] 27 | conf.AccessKeySecret = arr[2] 28 | } 29 | } 30 | 31 | // register the alimns provider 32 | err = alimns.Pr.Register("ali_mns", conf) 33 | if err != nil { 34 | fmt.Println(err) 35 | } 36 | 37 | q = queue.GetQueue("ali_mns", queue.DriverTypeAliMns) 38 | } 39 | 40 | func TestEnqueue(t *testing.T) { 41 | q := queue.GetQueue("ali_mns", queue.DriverTypeAliMns) 42 | topic := "snow-topic-one" + fmt.Sprint(utils.GetCurrentTime()) 43 | ctx := context.TODO() 44 | msg := "1" 45 | ok, err := q.Enqueue(ctx, topic, msg) 46 | if err != nil { 47 | t.Error(err) 48 | return 49 | } 50 | if !ok { 51 | t.Error("enqueue is not ok") 52 | return 53 | } 54 | 55 | message, _, token, dequeueCount, err := q.Dequeue(ctx, topic) 56 | fmt.Println("message dequeue num:", dequeueCount) 57 | if err != nil { 58 | t.Error(err) 59 | return 60 | } 61 | if message != msg { 62 | t.Errorf("message is not same %s", message) 63 | return 64 | } 65 | 66 | ok, err = q.AckMsg(ctx, topic, token) 67 | if err != nil { 68 | t.Error(err) 69 | return 70 | } 71 | if !ok { 72 | t.Error("ack is not ok") 73 | return 74 | } 75 | 76 | message, _, token, dequeueCount, err = q.Dequeue(ctx, topic) 77 | if err != nil { 78 | t.Error(err) 79 | return 80 | } else if message != "" { 81 | t.Error("message from blank queue must be empty") 82 | return 83 | } 84 | 85 | _, err = q.AckMsg(ctx, topic, token) 86 | if !(err != nil && err.Error() == "token empty") { 87 | t.Error("must return empty ack token error") 88 | } 89 | } 90 | 91 | func TestBatchEnqueue(t *testing.T) { 92 | ctx := context.TODO() 93 | topic := "snow-topic-batch" + fmt.Sprint(utils.GetCurrentTime()) 94 | messages := []string{"11", "21"} 95 | _, err :=
q.BatchEnqueue(ctx, topic, messages) 96 | if err != nil { 97 | t.Error("batch enqueue error", err) 98 | return 99 | } 100 | 101 | fmt.Println("batch enqueue", topic, messages) 102 | 103 | message1, _, token1, dequeueCount, err := q.Dequeue(ctx, topic) 104 | if err != nil { 105 | t.Error(err) 106 | return 107 | } 108 | 109 | message2, _, token2, dequeueCount, err := q.Dequeue(ctx, topic) 110 | if err != nil { 111 | t.Error(err) 112 | return 113 | } 114 | fmt.Println("TestBatchEnqueue:dequeueCount:", dequeueCount) 115 | 116 | if message1 == messages[0] { 117 | if message2 != messages[1] { 118 | t.Errorf("message2 is not same origin:%s real:%s", messages[1], message2) 119 | return 120 | } 121 | } else if message2 == messages[0] { 122 | if message1 != messages[1] { 123 | t.Errorf("message1 is not same origin:%s real:%s", messages[1], message1) 124 | return 125 | } 126 | } else { 127 | t.Errorf("messages are not same origin:%v real:%s,%s", messages, message1, message2) 128 | return 129 | } 130 | 131 | ok, err := q.AckMsg(ctx, topic, token1) 132 | if err != nil { 133 | t.Errorf("message1 ack err:%s", err.Error()) 134 | return 135 | } 136 | if !ok { 137 | t.Error("message1 ack is not ok") 138 | return 139 | } 140 | 141 | ok, err = q.AckMsg(ctx, topic, token2) 142 | if err != nil { 143 | t.Errorf("message2 ack err:%s", err.Error()) 144 | return 145 | } 146 | if !ok { 147 | t.Error("message2 ack is not ok") 148 | return 149 | } 150 | } 151 | 152 | func TestBatchEnqueueEmpty(t *testing.T) { 153 | ctx := context.TODO() 154 | topic := "snow-topic-batch" 155 | messages := make([]string, 0) 156 | _, err := q.BatchEnqueue(ctx, topic, messages) 157 | if err == nil { 158 | t.Error("empty message must return error") 159 | return 160 | } 161 | } 162 | 163 | func Test_getOption(t *testing.T) { 164 | delay, priority := getOption(int64(1), int64(10)) 165 | if delay != 1 { 166 | t.Errorf("delay is not equal 1. %d", delay) 167 | } else if priority != 10 { 168 | t.Errorf("priority is not equal 10.
%d", priority) 169 | } 170 | } 171 | -------------------------------------------------------------------------------- /queue/alirocketqueue/alirocket_queue.go: -------------------------------------------------------------------------------- 1 | package alirocketqueue 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | "sync" 8 | "time" 9 | 10 | mq_http_sdk "github.com/aliyunmq/mq-http-go-sdk" 11 | "github.com/gogap/errors" 12 | "github.com/qit-team/snow-core/aliyunmq" 13 | "github.com/qit-team/snow-core/queue" 14 | ) 15 | 16 | const ( 17 | DefaultVisibilityTimeout = int64(120) 18 | ) 19 | 20 | var ( 21 | mp map[string]queue.Queue 22 | mu sync.RWMutex 23 | ) 24 | 25 | type AliyunMq struct { 26 | client mq_http_sdk.MQClient 27 | } 28 | 29 | // newAliyunMq creates a new instance bound to the aliyunmq client registered under diName 30 | func newAliyunMq(diName string) queue.Queue { 31 | m := new(AliyunMq) 32 | m.client = aliyunmq.GetAliyunMq(diName) 33 | 34 | return m 35 | } 36 | 37 | // GetAliyunRocketQueue returns the cached instance for diName, creating it on first use 38 | func GetAliyunRocketQueue(diName string) queue.Queue { 39 | key := diName 40 | mu.RLock() 41 | q, ok := mp[key] 42 | mu.RUnlock() 43 | if ok { 44 | return q 45 | } 46 | 47 | q = newAliyunMq(diName) 48 | 49 | mu.Lock() 50 | mp[key] = q 51 | mu.Unlock() 52 | return q 53 | } 54 | 55 | /** 56 | * Enqueue a single message 57 | * args[0] instanceId 58 | */ 59 | func (m *AliyunMq) Enqueue(ctx context.Context, key string, message string, args ...interface{}) (bool, error) { 60 | instanceId, _, _ := getOption(args...)
61 | 62 | // get the rocketmq producer; unlike mns (one client for everything), rocketmq separates producers and consumers 63 | mqProducer := m.client.GetProducer(instanceId, key) 64 | 65 | // aliyunmq message; MessageTag, Properties etc. could be set, only MessageBody is used for now 66 | mqMsg := mq_http_sdk.PublishMessageRequest{ 67 | MessageBody: message, 68 | } 69 | _, err := mqProducer.PublishMessage(mqMsg) 70 | if err != nil { 71 | return false, err 72 | } 73 | 74 | return true, nil 75 | } 76 | 77 | /** 78 | * Dequeue a single message 79 | * param: key is the topic; args[0] is instanceId, args[1] is groupId (only rocketmq needs it), args[2] is messageTag 80 | * return: the message, its tag, the aliyunmq ReceiptHandle as token (pass it to AckMsg to remove the message), and the consumed times 81 | */ 82 | func (m *AliyunMq) Dequeue(ctx context.Context, key string, args ...interface{}) (message string, tag string, token string, dequeueCount int64, err error) { 83 | instanceId, groupId, messageTag := getOption(args...) 84 | 85 | // get the rocketmq consumer 86 | mqConsumer := m.client.GetConsumer(instanceId, key, groupId, messageTag) 87 | 88 | //endChan := make(chan int) 89 | respChan := make(chan mq_http_sdk.ConsumeMessageResponse, 1) // buffered so the consumer goroutine can always send and exit, even after the timeout below 90 | errChan := make(chan error, 1) 91 | 92 | go func() { 93 | // consume with long polling: 94 | // if the topic has no message the request is held server-side for up to 3s and returns as soon as one arrives 95 | mqConsumer.ConsumeMessage(respChan, errChan, 96 | 1, // max messages per call (up to 16) 97 | 3, // long-poll wait in seconds (up to 30) 98 | ) 99 | }() 100 | 101 | select { 102 | case resp := <-respChan: 103 | { 104 | // handle the response: only the first message is returned 105 | var handles []string 106 | respLen := len(resp.Messages) 107 | fmt.Printf("AliRocketMq Consume %d messages---->\n", respLen) 108 | if respLen != 1 { 109 | // not exactly one message: a warning or an error could be reported here 110 | } 111 | 112 | for _, v := range resp.Messages { 113 | handles = append(handles, v.ReceiptHandle) 114 | //fmt.Printf("\tMessageID: %s, PublishTime: %d, MessageTag: %s\n"+ 115 | // "\tConsumedTimes: %d, FirstConsumeTime: %d, NextConsumeTime: %d\n"+ 116 | // "\tBody: %s\n"+ 117 | // "\tProps: %s\n", 118 | // v.MessageId, v.PublishTime, v.MessageTag, v.ConsumedTimes, 119 | // v.FirstConsumeTime,
v.NextConsumeTime, v.MessageBody, v.Properties) 120 | return v.MessageBody, v.MessageTag, v.ReceiptHandle, v.ConsumedTimes, nil 121 | } 122 | 123 | } 124 | case errMsg := <-errChan: 125 | { 126 | // no message (MessageNotExist) or a consume error; errMsg is not necessarily an errors.ErrCode 127 | err = errMsg 128 | if strings.Contains(errMsg.Error(), "MessageNotExist") { 129 | err = nil 130 | // fmt.Println("\nNo new message, continue!") 131 | } else { 132 | fmt.Println("aliyunmq get msg error:", errMsg) 133 | time.Sleep(time.Duration(3) * time.Second) 134 | } 135 | } 136 | case <-time.After(35 * time.Second): 137 | { 138 | fmt.Println("Timeout of consumer message ??") 139 | err = errors.New("Timeout of consumer message") 140 | } 141 | } 142 | 143 | return 144 | } 145 | 146 | /** 147 | * Enqueue a batch of messages 148 | * args[0] instanceId 149 | * note: rocketmq has no batch publish, so Enqueue is called once per message, stopping at the first failure 150 | */ 151 | func (m *AliyunMq) BatchEnqueue(ctx context.Context, key string, messageList []string, args ...interface{}) (bool, error) { 152 | if len(messageList) == 0 { 153 | return false, errors.New("messageList is empty") 154 | } 155 | 156 | for _, message := range messageList { 157 | flag, err := m.Enqueue(ctx, key, message, args...) // spread args so Enqueue sees instanceId as args[0] 158 | if !flag || err != nil { 159 | return flag, err 160 | } 161 | } 162 | 163 | return true, nil 164 | } 165 | 166 | /** 167 | * Acknowledge a received message by its token (ReceiptHandle) 168 | * args[0] is instanceId, args[1] is groupId, args[2] is messageTag 169 | */ 170 | func (m *AliyunMq) AckMsg(ctx context.Context, key string, token string, args ...interface{}) (bool, error) { 171 | if len(token) < 1 { 172 | return false, errors.New("token empty") 173 | } 174 | 175 | instanceId, groupId, messageTag := getOption(args...)
176 | 177 | // get the rocketmq consumer 178 | mqConsumer := m.client.GetConsumer(instanceId, key, groupId, messageTag) 179 | 180 | var handles []string 181 | // rocketmq's AckMessage takes a slice of handles 182 | handles = append(handles, token) 183 | 184 | ackErr := mqConsumer.AckMessage(handles) 185 | if ackErr != nil { 186 | // some handles may have expired, which makes the ack fail 187 | fmt.Println("aliyunmq ack error token", token, ",err:", ackErr) 188 | 189 | for _, errAckItem := range ackErrItems(ackErr) { 190 | fmt.Printf("aliyunmq handle ack item: \tErrorHandle:%s, ErrorCode:%s, ErrorMsg:%s\n", 191 | errAckItem.ErrorHandle, errAckItem.ErrorCode, errAckItem.ErrorMsg) 192 | } 193 | return false, ackErr 194 | //time.Sleep(time.Duration(3) * time.Second) 195 | } else { 196 | fmt.Printf("aliyunmq Ack ---->\n\t%s\n", handles) 197 | } 198 | 199 | return true, nil 200 | } 201 | 202 | // getOption extracts the optional string arguments; missing or non-string values yield "" 203 | // args[0] is instanceId, args[1] is groupId, args[2] is messageTag 204 | func getOption(args ...interface{}) (instanceId, groupId, messageTag string) { 205 | instanceId = "" 206 | groupId = "" 207 | messageTag = "" 208 | 209 | l := len(args) 210 | if l > 0 { 211 | tempInstance, ok := args[0].(string) 212 | if ok { 213 | instanceId = tempInstance 214 | } 215 | if l > 1 { 216 | tempGroup, ok := args[1].(string) 217 | if ok { 218 | groupId = tempGroup 219 | } 220 | } 221 | if l > 2 { 222 | tempTag, ok := args[2].(string) 223 | if ok { 224 | messageTag = tempTag 225 | } 226 | } 227 | } 228 | return 229 | } 230 | 231 | func init() { 232 | mp = make(map[string]queue.Queue) 233 | queue.Register(queue.DriverTypeAliyunMq, GetAliyunRocketQueue) 234 | } 235 | 236 | // ackErrItems returns the per-handle failure details carried by an AckMessage error. 237 | // It returns nil, instead of panicking, when err is not an errors.ErrCode or its 238 | // "Detail" context entry is not a []mq_http_sdk.ErrAckItem. 239 | func ackErrItems(err error) []mq_http_sdk.ErrAckItem { 240 | errCode, ok := err.(errors.ErrCode) 241 | if !ok { 242 | return nil 243 | } 244 | items, _ := errCode.Context()["Detail"].([]mq_http_sdk.ErrAckItem) 245 | return items 246 | } 247 | -------------------------------------------------------------------------------- /queue/alirocketqueue/alirocketmq_test.go: -------------------------------------------------------------------------------- 1 | package alirocketqueue 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io/ioutil" 7 | "strings" 8 | "testing" 9 | 10 |
"github.com/qit-team/snow-core/aliyunmq" 11 | "github.com/qit-team/snow-core/config" 12 | "github.com/qit-team/snow-core/queue" 13 | ) 14 | 15 | var q queue.Queue 16 | 17 | func init() { 18 | // fill in ../../.env.aliyunmq yourself (see getConfig) 19 | // getConfig leaves conf empty when the file is missing or has fewer than three lines 20 | conf := getConfig() 21 | // register the aliyunmq provider 22 | err := aliyunmq.Pr.Register("aliyun_mq", conf) 23 | 24 | if err != nil { 25 | fmt.Println(err) 26 | } 27 | 28 | q = queue.GetQueue("aliyun_mq", queue.DriverTypeAliyunMq) 29 | } 30 | 31 | func TestEnqueue(t *testing.T) { 32 | q := queue.GetQueue("aliyun_mq", queue.DriverTypeAliyunMq) 33 | topic := "SNOW-TOPIC-TEST" 34 | groupId := "GID-SNOW-TOPIC-TEST" 35 | ctx := context.TODO() 36 | msg := "msg from snow core" 37 | ok, err := q.Enqueue(ctx, topic, msg) 38 | if err != nil { 39 | t.Error(err) 40 | return 41 | } 42 | if !ok { 43 | t.Error("enqueue is not ok") 44 | return 45 | } 46 | 47 | message, tag, token, dequeueCount, err := q.Dequeue(ctx, topic, "", groupId) 48 | fmt.Println("message content:", message) 49 | fmt.Println("message tag:", tag) 50 | fmt.Println("message dequeue num:", dequeueCount) 51 | fmt.Println("message token:", token) 52 | if err != nil { 53 | t.Error(err) 54 | return 55 | } 56 | if message != msg { 57 | t.Errorf("message is not same %s", message) 58 | return 59 | } 60 | 61 | ok, err = q.AckMsg(ctx, topic, token, "", groupId) 62 | 63 | fmt.Println("info:", ok, err) 64 | if err != nil { 65 | t.Error(err) 66 | return 67 | } 68 | if !ok { 69 | t.Error("ack is not ok") 70 | return 71 | } 72 | 73 | message, tag, token, dequeueCount, err = q.Dequeue(ctx, topic, "", groupId) 74 | fmt.Println("message content:", message) 75 | fmt.Println("message tag:", tag) 76 | fmt.Println("message dequeue num:", dequeueCount) 77 | fmt.Println("message token:", token) 78 | if err != nil { 79 | t.Error(err) 80 | return 81 | } else if message != "" { 82 | t.Error("message from blank queue must be empty") 83 | return 84 | } 85 | 86 | _, err = q.AckMsg(ctx, topic, token, "", groupId) 87
| fmt.Println("ackMsg,errInfo", err) 88 | if !(err != nil && err.Error() == "token empty") { 89 | t.Error("must return empty ack token error") 90 | } 91 | } 92 | 93 | func TestBatchEnqueue(t *testing.T) { 94 | ctx := context.TODO() 95 | topic := "SNOW-TOPIC-TEST" 96 | groupId := "GID-SNOW-TOPIC-TEST" 97 | messages := []string{"11", "21"} 98 | _, err := q.BatchEnqueue(ctx, topic, messages) 99 | if err != nil { 100 | t.Error("batch enqueue error", err) 101 | return 102 | } 103 | 104 | fmt.Println("batch enqueue", topic, messages) 105 | 106 | message1, _, token1, dequeueCount, err := q.Dequeue(ctx, topic, "", groupId) 107 | if err != nil { 108 | t.Error(err) 109 | return 110 | } 111 | 112 | message2, _, token2, dequeueCount, err := q.Dequeue(ctx, topic, "", groupId) 113 | if err != nil { 114 | t.Error(err) 115 | return 116 | } 117 | fmt.Println("TestBatchEnqueue:dequeueCount:", dequeueCount) 118 | 119 | if message1 == messages[0] { 120 | if message2 != messages[1] { 121 | t.Errorf("message2 is not same origin:%s real:%s", messages[1], message2) 122 | return 123 | } 124 | } else if message2 == messages[0] { 125 | if message1 != messages[1] { 126 | t.Errorf("message1 is not same origin:%s real:%s", messages[1], message1) 127 | return 128 | } 129 | } else { 130 | t.Errorf("messages are not same origin:%v real:%s,%s", messages, message1, message2) 131 | return 132 | } 133 | 134 | ok, err := q.AckMsg(ctx, topic, token1, "", groupId) 135 | if err != nil { 136 | t.Errorf("message1 ack err:%s", err.Error()) 137 | return 138 | } 139 | if !ok { 140 | t.Error("message1 ack is not ok") 141 | return 142 | } 143 | 144 | ok, err = q.AckMsg(ctx, topic, token2, "", groupId) 145 | if err != nil { 146 | t.Errorf("message2 ack err:%s", err.Error()) 147 | return 148 | } 149 | if !ok { 150 | t.Error("message2 ack is not ok") 151 | return 152 | } 153 | } 154 | 155 | func TestBatchEnqueueEmpty(t *testing.T) { 156 | ctx := context.TODO() 157 | topic := "SNOW-TOPIC-TEST" 158 | groupId := "GID-SNOW-TOPIC-TEST" 159 | messages :=
make([]string, 0) 160 | _, err := q.BatchEnqueue(ctx, topic, messages, "", groupId) 161 | fmt.Println("TestBatchEnqueueEmpty.Error", err) 162 | if err == nil { 163 | t.Error("empty message must return error") 164 | return 165 | } 166 | } 167 | 168 | func Test_getOption(t *testing.T) { 169 | instanceId, groupId, _ := getOption("", "GID-SNOW-TOPIC-TEST") 170 | if instanceId != "" { 171 | t.Errorf("instanceId is not empty. %s", instanceId) 172 | } else if groupId != "GID-SNOW-TOPIC-TEST" { 173 | t.Errorf("groupId is not GID-SNOW-TOPIC-TEST. %s", groupId) 174 | } 175 | } 176 | 177 | func getConfig() config.AliyunMqConfig { 178 | // fill in the file yourself: EndPoint, AccessKey and SecretKey on its first three lines 179 | bs, err := ioutil.ReadFile("../../.env.aliyunmq") 180 | 181 | conf := config.AliyunMqConfig{} 182 | if err == nil { 183 | str := string(bs) 184 | arr := strings.Split(str, "\n") 185 | if len(arr) >= 3 { 186 | conf.EndPoint = arr[0] 187 | conf.AccessKey = arr[1] 188 | conf.SecretKey = arr[2] 189 | } 190 | } 191 | return conf 192 | } 193 | -------------------------------------------------------------------------------- /queue/interface.go: -------------------------------------------------------------------------------- 1 | package queue 2 | 3 | import "context" 4 | 5 | // Queue is the queue driver interface; every queue driver must implement the methods below 6 | type Queue interface { 7 | // Enqueue enqueues a single message 8 | Enqueue(ctx context.Context, key string, message string, args ...interface{}) (ok bool, err error) 9 | // Dequeue dequeues a single message; an empty string is returned when no message exists 10 | // dequeueCount is how many times the message has been consumed; originally only alimns needed it, redis does not support it 11 | // driver-specific extra arguments are passed via args ...interface{} 12 | Dequeue(ctx context.Context, key string, args ...interface{}) (message string, tag string, token string, dequeueCount int64, err error) 13 | // AckMsg acknowledges a received message; redis does not need it, alimns does (kafka or rabbitmq could be added later) 14 | // args lets rocketmq pass groupId etc., which it needs to obtain the consumer 15 | AckMsg(ctx context.Context, key string, token string, args ...interface{}) (ok bool, err error) 16 | // BatchEnqueue enqueues several messages under a single key 17 | BatchEnqueue(ctx context.Context, key string, messages []string, args ...interface{}) (ok bool, err error) 18 | } 19 |
-------------------------------------------------------------------------------- /queue/queue.go: -------------------------------------------------------------------------------- 1 | package queue 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | ) 7 | 8 | const ( 9 | DriverTypeRedis = "redis" 10 | DriverTypeAliMns = "ali_mns" 11 | DriverTypeAliyunMq = "aliyun_mq" 12 | DriverTypeRocketMq = "rocket_mq" 13 | ) 14 | 15 | var ( 16 | drivers map[string]Instance 17 | mu sync.RWMutex 18 | ) 19 | 20 | type Instance func(diName string) Queue 21 | 22 | func Register(driverType string, driver Instance) { 23 | if driver == nil { 24 | panic("queue.Register driver is nil") 25 | } 26 | mu.Lock() 27 | defer mu.Unlock() 28 | 29 | if _, ok := drivers[driverType]; ok { 30 | panic("queue.Register called twice for driver " + driverType) 31 | } 32 | drivers[driverType] = driver 33 | } 34 | 35 | // GetQueue returns the Queue for diName built by the driver registered as driverType; it panics on an unknown driver or when the driver returns nil 36 | func GetQueue(diName string, driverType string) (q Queue) { 37 | mu.RLock() 38 | instanceFunc, ok := drivers[driverType] 39 | mu.RUnlock() 40 | if !ok { 41 | panic(fmt.Sprintf("queue.GetQueue unknown driver %s", driverType)) 42 | } 43 | q = instanceFunc(diName) 44 | if q == nil { 45 | panic(fmt.Sprintf("queue.GetQueue unknown diName %s", diName)) 46 | } 47 | return 48 | } 49 | 50 | func init() { 51 | drivers = make(map[string]Instance) 52 | } 53 | -------------------------------------------------------------------------------- /queue/queue_test.go: -------------------------------------------------------------------------------- 1 | package queue 2 | 3 | import ( 4 | "fmt" 5 | "github.com/qit-team/snow-core/config" 6 | "github.com/qit-team/snow-core/redis" 7 | "testing" 8 | ) 9 | 10 | func init() { 11 | redisConf := config.RedisConfig{ 12 | Master: config.RedisBaseConfig{ 13 | Host: "127.0.0.1", 14 | Port: 6379, 15 | }, 16 | } 17 | 18 | // register the redis provider 19 | err := redis.Pr.Register("redis", redisConf) 20 | if err != nil { 21 | fmt.Println(err) 22 | } 23 | 24 | Register("mock", getMockQueue) 25
| } 26 | 27 | func getMockQueue(diName string) Queue { 28 | return nil 29 | } 30 | 31 | func TestRegister(t *testing.T) { 32 | defer func() { 33 | if e := recover(); e == nil { 34 | t.Errorf("repeat register do not panic") 35 | } 36 | }() 37 | Register("mock", getMockQueue) 38 | } 39 | 40 | func TestRegister_EmptyDriver(t *testing.T) { 41 | defer func() { 42 | if e := recover(); e == nil { 43 | t.Errorf("nil driver do not panic") 44 | } 45 | }() 46 | Register("mock", nil) 47 | } 48 | 49 | func TestGetQueue_Empty(t *testing.T) { 50 | defer func() { 51 | if e := recover(); e == nil { 52 | t.Errorf("unknown driver do not panic") 53 | } 54 | }() 55 | GetQueue("redis", "empty") 56 | } 57 | 58 | func TestGetQueue_Nil(t *testing.T) { 59 | defer func() { 60 | if e := recover(); e == nil { 61 | t.Errorf("unknown diName do not panic") 62 | } 63 | }() 64 | GetQueue("unknown", "mock") 65 | } 66 | -------------------------------------------------------------------------------- /queue/redisqueue/redis_queue.go: -------------------------------------------------------------------------------- 1 | package redisqueue 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | goredis "github.com/go-redis/redis/v8" 7 | "github.com/qit-team/snow-core/queue" 8 | "github.com/qit-team/snow-core/redis" 9 | "sync" 10 | ) 11 | 12 | var ( 13 | mp map[string]queue.Queue 14 | mu sync.RWMutex 15 | ) 16 | 17 | type RedisQueue struct { 18 | client *goredis.Client 19 | } 20 | 21 | // newRedisQueue creates a new instance bound to the redis client registered under diName 22 | func newRedisQueue(diName string) queue.Queue { 23 | m := new(RedisQueue) 24 | m.client = redis.GetRedis(diName) 25 | return m 26 | } 27 | 28 | // GetRedisQueue returns the cached instance for diName, creating it on first use 29 | func GetRedisQueue(diName string) queue.Queue { 30 | key := diName 31 | mu.RLock() 32 | q, ok := mp[key] 33 | mu.RUnlock() 34 | if ok { 35 | return q 36 | } 37 | 38 | q = newRedisQueue(diName) 39 | mu.Lock() 40 | mp[key] = q 41 | mu.Unlock() 42 | return q 43 | } 44 | 45 | /** 46 | * Enqueue a single message (RPUSH) 47 | */ 48 | func (m *RedisQueue) Enqueue(ctx context.Context, key string,
message string, args ...interface{}) (bool, error) { 49 | // delay and priority are not supported for redis yet; args is ignored 50 | _, err := m.client.RPush(ctx, key, message).Result() 51 | if err != nil { 52 | return false, err 53 | } 54 | return true, nil 55 | } 56 | 57 | /** 58 | * Dequeue a single message (LPOP); an empty queue yields an empty message and a nil error 59 | */ 60 | func (m *RedisQueue) Dequeue(ctx context.Context, key string, args ...interface{}) (message string, tag string, token string, dequeueCount int64, err error) { 61 | // redis cannot track how often a message was dequeued, so dequeueCount is always 0 62 | dequeueCount = 0 63 | message, err = m.client.LPop(ctx, key).Result() 64 | if errors.Is(err, goredis.Nil) { 65 | err = nil 66 | message = "" 67 | } 68 | return 69 | } 70 | 71 | /** 72 | * Acknowledge a message: a no-op for redis that always succeeds 73 | */ 74 | func (m *RedisQueue) AckMsg(ctx context.Context, key string, token string, args ...interface{}) (bool, error) { 75 | return true, nil 76 | } 77 | 78 | /** 79 | * Enqueue a batch of messages with a single RPUSH 80 | */ 81 | func (m *RedisQueue) BatchEnqueue(ctx context.Context, key string, messages []string, args ...interface{}) (bool, error) { 82 | // delay and priority are not supported for redis yet; args is ignored 83 | if len(messages) == 0 { 84 | return false, errors.New("messages is empty") 85 | } 86 | _, err := m.client.RPush(ctx, key, arrayStringToInterface(messages)...).Result() 87 | if err != nil { 88 | return false, err 89 | } 90 | return true, nil 91 | } 92 | 93 | func arrayStringToInterface(arr []string) []interface{} { 94 | newArr := make([]interface{}, len(arr)) 95 | for k, v := range arr { 96 | newArr[k] = v 97 | } 98 | return newArr 99 | } 100 | 101 | func init() { 102 | mp = make(map[string]queue.Queue) 103 | queue.Register(queue.DriverTypeRedis, GetRedisQueue) 104 | } 105 | -------------------------------------------------------------------------------- /queue/redisqueue/redis_queue_test.go: -------------------------------------------------------------------------------- 1 | package redisqueue 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/qit-team/snow-core/config" 7 | "github.com/qit-team/snow-core/queue" 8 | "github.com/qit-team/snow-core/redis" 9 | "testing" 10 | ) 11 | 12 |
var q queue.Queue 13 | 14 | func init() { 15 | redisConf := config.RedisConfig{ 16 | Master: config.RedisBaseConfig{ 17 | Host: "127.0.0.1", 18 | Port: 6379, 19 | }, 20 | } 21 | 22 | //注册redis类 23 | err := redis.Pr.Register("redis", redisConf, true) 24 | if err != nil { 25 | fmt.Println(err) 26 | } 27 | 28 | q = queue.GetQueue("redis", queue.DriverTypeRedis) 29 | } 30 | 31 | func TestEnqueue(t *testing.T) { 32 | q := queue.GetQueue("redis", queue.DriverTypeRedis) 33 | topic := "snow-topic-one" 34 | ctx := context.TODO() 35 | msg := "1" 36 | ok, err := q.Enqueue(ctx, topic, msg) 37 | if err != nil { 38 | t.Error(err) 39 | return 40 | } 41 | if !ok { 42 | t.Error("enqueue is not ok") 43 | return 44 | } 45 | 46 | message, token, _, _, err := q.Dequeue(ctx, topic) 47 | if err != nil { 48 | t.Error(err) 49 | return 50 | } 51 | if message != msg { 52 | t.Errorf("message is not same %s", message) 53 | return 54 | } 55 | 56 | ok, err = q.AckMsg(ctx, topic, token) 57 | if err != nil { 58 | t.Error(err) 59 | return 60 | } 61 | if !ok { 62 | t.Error("ack is not ok") 63 | return 64 | } 65 | 66 | message, _, _, _, err = q.Dequeue(ctx, topic) 67 | if err != nil { 68 | t.Error(err) 69 | return 70 | } else if message != "" { 71 | t.Error("message must be empty") 72 | return 73 | } 74 | } 75 | 76 | func TestBatchEnqueue(t *testing.T) { 77 | ctx := context.TODO() 78 | topic := "snow-topic-batch" 79 | messages := []string{"11", "21"} 80 | _, err := q.BatchEnqueue(ctx, topic, messages) 81 | if err != nil { 82 | t.Error("batch enqueue error", err) 83 | return 84 | } 85 | 86 | fmt.Println("batch enqueue", topic, messages) 87 | 88 | message, _, _, _, err := q.Dequeue(ctx, topic) 89 | if err != nil { 90 | t.Error(err) 91 | return 92 | } 93 | if message != messages[0] { 94 | t.Errorf("message is not same origin:%s real:%s", messages[0], message) 95 | return 96 | } 97 | 98 | message, _, _, _, err = q.Dequeue(ctx, topic) 99 | if err != nil { 100 | t.Error(err) 101 | return 102 | } 103 | if 
message != messages[1] { 104 | t.Errorf("message is not same origin:%s real:%s", messages[1], message) 105 | return 106 | } 107 | } 108 | 109 | func TestBatchEnqueueEmpty(t *testing.T) { 110 | ctx := context.TODO() 111 | topic := "snow-topic-batch" 112 | messages := make([]string, 0) 113 | _, err := q.BatchEnqueue(ctx, topic, messages) 114 | if err == nil { 115 | t.Error("empty message must return error") 116 | return 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /queue/rocketqueue/rocket_queue.go: -------------------------------------------------------------------------------- 1 | package rocketqueue 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "os" 8 | "os/signal" 9 | "strings" 10 | "sync" 11 | "syscall" 12 | 13 | "github.com/apache/rocketmq-client-go/v2" 14 | "github.com/apache/rocketmq-client-go/v2/consumer" 15 | "github.com/apache/rocketmq-client-go/v2/primitive" 16 | "github.com/gogap/errors" 17 | "github.com/qit-team/snow-core/log/logger" 18 | "github.com/qit-team/snow-core/queue" 19 | rkmq "github.com/qit-team/snow-core/rocketmq" 20 | ) 21 | 22 | var ( 23 | mp map[string]queue.Queue 24 | mu sync.RWMutex 25 | ) 26 | 27 | type RocketQueue struct { 28 | Consumer rocketmq.PushConsumer 29 | Producer rocketmq.Producer 30 | 31 | consumerMessageChan chan *primitive.MessageExt 32 | //producerMessageChan chan *primitive.MessageExt 33 | 34 | consumerOnce sync.Once 35 | producerOnce sync.Once 36 | } 37 | 38 | func (m *RocketQueue) initProducer(ctx context.Context) error { 39 | var err error 40 | m.producerOnce.Do( 41 | func() { 42 | err = m.Producer.Start() 43 | if err != nil { 44 | logger.Fatal(ctx, "RocketQueue:Producer:Start", err.Error()) 45 | return 46 | } 47 | }) 48 | return err 49 | } 50 | 51 | func (m *RocketQueue) initConsumer(ctx context.Context, topic, messageTag string, num int) error { 52 | var err error 53 | m.consumerOnce.Do( 54 | func() { 55 | m.consumerMessageChan = make(chan 
*primitive.MessageExt, num) 56 | 57 | var selector consumer.MessageSelector 58 | if len(messageTag) > 0 { 59 | selector = consumer.MessageSelector{ 60 | Type: consumer.TAG, 61 | Expression: messageTag, 62 | } 63 | } 64 | err = m.Consumer.Subscribe(topic, selector, func(ctx context.Context, messages ...*primitive.MessageExt) (consumer.ConsumeResult, error) { 65 | // 取到的消息放入管道,交给下游处理 66 | for _, msg := range messages { 67 | m.consumerMessageChan <- msg 68 | } 69 | 70 | return consumer.ConsumeSuccess, nil 71 | }) 72 | if err != nil { 73 | logger.Error(ctx, "RocketQueue:Subscribe", err.Error()) 74 | return 75 | } 76 | 77 | err = m.Consumer.Start() 78 | if err != nil { 79 | logger.Fatal(ctx, "RocketQueue:Start", err.Error()) 80 | return 81 | } 82 | 83 | go func() { 84 | var sigs = []os.Signal{ 85 | syscall.SIGHUP, 86 | syscall.SIGUSR1, 87 | syscall.SIGUSR2, 88 | syscall.SIGINT, 89 | syscall.SIGTERM, 90 | syscall.SIGQUIT, 91 | } 92 | c := make(chan os.Signal) 93 | signal.Notify(c, sigs...) 94 | for { 95 | sig := <-c //blocked 96 | switch sig { 97 | case syscall.SIGINT, syscall.SIGTERM: 98 | close(m.consumerMessageChan) 99 | err = m.Consumer.Shutdown() 100 | if err != nil { 101 | logger.Error(ctx, "Shutdown.Failure", err.Error()) 102 | return 103 | } 104 | return 105 | default: 106 | } 107 | } 108 | fmt.Println("停止订阅消息") 109 | }() 110 | }) 111 | if err != nil { 112 | logger.Error(ctx, "RocketQueue:initConsumer", err.Error()) 113 | return err 114 | } 115 | 116 | return nil 117 | } 118 | 119 | // new实例 120 | func newRocketQueue(diName string) queue.Queue { 121 | m := new(RocketQueue) 122 | client := rkmq.GetRocketMq(diName) 123 | 124 | m.Producer = client.Producer 125 | m.Consumer = client.Consumer 126 | 127 | return m 128 | } 129 | 130 | // GetRocketQueue 131 | // 132 | // 单例模式 133 | func GetRocketQueue(diName string) queue.Queue { 134 | key := diName 135 | mu.RLock() 136 | q, ok := mp[key] 137 | mu.RUnlock() 138 | if ok { 139 | return q 140 | } 141 | 142 | q = 
newRocketQueue(diName) 143 | 144 | mu.Lock() 145 | mp[key] = q 146 | mu.Unlock() 147 | return q 148 | } 149 | 150 | // Enqueue 队列消息入队 151 | // 152 | // args[0] instanceId 153 | func (m *RocketQueue) Enqueue(ctx context.Context, key string, message string, args ...interface{}) (bool, error) { 154 | err := m.initProducer(ctx) 155 | if err != nil { 156 | return false, err 157 | } 158 | _, _, messageTag, timeLevel := getOption(args...) 159 | log.Printf("messageTag: %v", messageTag) 160 | if len(messageTag) > 0 { 161 | tags := strings.Split(messageTag, "||") 162 | for i := 0; i < len(tags); i++ { 163 | tag := strings.Trim(tags[i%3], " ") 164 | msg := &primitive.Message{ 165 | Topic: key, 166 | Body: []byte(message), 167 | } 168 | msg.WithTag(tag) 169 | // https://rocketmq.apache.org/docs/4.x/producer/04message3/ 170 | if timeLevel > 0 && timeLevel <= 18 { 171 | msg.WithDelayTimeLevel(timeLevel) 172 | } 173 | log.Printf("send for tag: %v", tag) 174 | res, err := m.Producer.SendSync(context.Background(), msg) 175 | if err != nil { 176 | return false, err 177 | } 178 | //logger.Info(ctx, "Enqueue", res.String()) 179 | log.Printf("Enqueue: %s %v", message, res.MsgID) 180 | } 181 | } else { 182 | msg := &primitive.Message{ 183 | Topic: key, 184 | Body: []byte(message), 185 | } 186 | // https://rocketmq.apache.org/docs/4.x/producer/04message3/ 187 | if timeLevel > 0 && timeLevel <= 18 { 188 | msg.WithDelayTimeLevel(timeLevel) 189 | } 190 | res, err := m.Producer.SendSync(ctx, msg) 191 | if err != nil { 192 | return false, err 193 | } 194 | //logger.Info(ctx, "Enqueue", res.String()) 195 | log.Printf("Enqueue: %s %v", message, res.MsgID) 196 | } 197 | 198 | return true, nil 199 | } 200 | 201 | // Dequeue 队列消息出队 202 | // 203 | // param 第二个参数是队列名称,args[0]是instanceId,args[1]是groupId,目前只有rocketmq需要groupId 204 | // 205 | // return 第一个参数是消息 第二个参数是aliyunmq的ReceiptHandle命名为token,通过token确定消息是否从队列删除,第三个参数为消费次数 206 | func (m *RocketQueue) Dequeue(ctx context.Context, key string, args 
...interface{}) (message string, tag string, token string, dequeueCount int64, err error) { 207 | _, _, messageTag, _ := getOption(args...) 208 | 209 | err = m.initConsumer(ctx, key, messageTag, 5) 210 | if err != nil { 211 | return 212 | } 213 | 214 | select { 215 | case msg, ok := <-m.consumerMessageChan: 216 | if !ok { 217 | return "", "", "", 0, nil 218 | } 219 | return string(msg.Body), msg.GetTags(), "", int64(msg.ReconsumeTimes), nil 220 | } 221 | } 222 | 223 | // BatchEnqueue 队列消息批量入队 224 | // args[0] instanceId 225 | // 注:rocket其实没有批量函数,所以循环调用publishMsg方法 226 | func (m *RocketQueue) BatchEnqueue(ctx context.Context, key string, messageList []string, args ...interface{}) (bool, error) { 227 | if len(messageList) == 0 { 228 | return false, errors.New("messageList is empty") 229 | } 230 | 231 | for _, message := range messageList { 232 | flag, err := m.Enqueue(ctx, key, message, args) 233 | if flag == false || err != nil { 234 | return flag, err 235 | } 236 | } 237 | 238 | return true, nil 239 | } 240 | 241 | // AckMsg 确认消息接收 242 | // args[0]是instanceId,args[1]是groupId,args[2]是messageTag, args[2]是delayTimeLevel 243 | func (m *RocketQueue) AckMsg(ctx context.Context, key string, token string, args ...interface{}) (bool, error) { 244 | return true, nil 245 | } 246 | 247 | // getOption 缺省参数统一获取 248 | // 249 | // args[0]是instanceId,args[1]是groupId,args[2]是messageTag, args[3]是delayTimeLevel 250 | func getOption(args ...interface{}) (instanceId, groupId, messageTag string, delayTimeLevel int) { 251 | instanceId = "" 252 | groupId = "" 253 | messageTag = "" 254 | delayTimeLevel = 0 255 | l := len(args) 256 | if l > 0 { 257 | tempInstance, ok := args[0].(string) 258 | if ok { 259 | instanceId = tempInstance 260 | } 261 | if l > 1 { 262 | tempGroup, ok := args[1].(string) 263 | if ok { 264 | groupId = tempGroup 265 | } 266 | } 267 | if l > 2 { 268 | tempTag, ok := args[2].(string) 269 | if ok { 270 | messageTag = tempTag 271 | } 272 | } 273 | if l > 3 { 274 | 
tempDelayTimeLevel, ok := args[3].(int) 275 | if ok { 276 | delayTimeLevel = tempDelayTimeLevel 277 | } 278 | } 279 | } 280 | return 281 | } 282 | 283 | func init() { 284 | mp = make(map[string]queue.Queue) 285 | queue.Register(queue.DriverTypeRocketMq, GetRocketQueue) 286 | } 287 | -------------------------------------------------------------------------------- /queue/rocketqueue/rocket_queue_test.go: -------------------------------------------------------------------------------- 1 | package rocketqueue 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "strings" 9 | "testing" 10 | 11 | "github.com/qit-team/snow-core/log/logger" 12 | 13 | "github.com/qit-team/snow-core/config" 14 | "github.com/qit-team/snow-core/queue" 15 | "github.com/qit-team/snow-core/rocketmq" 16 | ) 17 | 18 | var q queue.Queue 19 | 20 | func init() { 21 | // 需要自己在文件填好配置 22 | conf := config.RocketMqConfig{} 23 | conf = getConfig() 24 | //注册rocketmq类 25 | err := rocketmq.Pr.Register("rocket_mq", conf) 26 | logger.Pr.Register(logger.SingletonMain, config.LogConfig{ 27 | Handler: "stdout", 28 | Level: "debug", 29 | Segment: false, 30 | Dir: ".", 31 | FileName: "", 32 | }) 33 | 34 | if err != nil { 35 | fmt.Println(err) 36 | } 37 | 38 | q = queue.GetQueue("rocket_mq", queue.DriverTypeRocketMq) 39 | } 40 | 41 | func TestEnqueue(t *testing.T) { 42 | q := queue.GetQueue("rocket_mq", queue.DriverTypeRocketMq) 43 | topic := "SNOW-TOPIC-TEST" 44 | groupId := "GID-SNOW-TOPIC-TEST" 45 | ctx := context.TODO() 46 | msg := "msg from snow core" 47 | ok, err := q.Enqueue(ctx, topic, msg) 48 | if err != nil { 49 | t.Error(err) 50 | return 51 | } 52 | if !ok { 53 | t.Error("enqueue is not ok") 54 | return 55 | } 56 | 57 | message, tag, token, dequeueCount, err := q.Dequeue(ctx, topic, "", groupId) 58 | fmt.Println("message content:", message) 59 | fmt.Println("message tag:", tag) 60 | fmt.Println("message dequeue num:", dequeueCount) 61 | fmt.Println("message token:", token) 62 | if err != 
nil { 63 | t.Error(err) 64 | return 65 | } 66 | if message != msg { 67 | t.Errorf("message is not same %s", message) 68 | return 69 | } 70 | 71 | ok, err = q.AckMsg(ctx, topic, token, "", groupId) 72 | 73 | fmt.Println("info:", ok, err) 74 | if err != nil { 75 | t.Error(err) 76 | return 77 | } 78 | if !ok { 79 | t.Error("ack is not ok") 80 | return 81 | } 82 | 83 | message, tag, token, dequeueCount, err = q.Dequeue(ctx, topic, "", groupId) 84 | fmt.Println("message content:", message) 85 | fmt.Println("message tag:", tag) 86 | fmt.Println("message dequeue num:", dequeueCount) 87 | fmt.Println("message token:", token) 88 | if err != nil { 89 | t.Error(err) 90 | return 91 | } else if message != "" { 92 | t.Error("message from blank queue must be empty") 93 | return 94 | } 95 | 96 | _, err = q.AckMsg(ctx, topic, token, "", groupId) 97 | fmt.Println("ackMsg,errInfo", err) 98 | if !(err != nil && err.Error() == "token empty") { 99 | t.Error("must return empty ack token error") 100 | } 101 | } 102 | 103 | func TestBatchEnqueue(t *testing.T) { 104 | ctx := context.TODO() 105 | topic := "SNOW-TOPIC-TEST" 106 | groupId := "GID-SNOW-TOPIC-TEST" 107 | messages := []string{"11", "21"} 108 | _, err := q.BatchEnqueue(ctx, topic, messages) 109 | if err != nil { 110 | t.Error("batch enqueue error", err) 111 | return 112 | } 113 | 114 | fmt.Println("batch enqueue", topic, messages) 115 | 116 | message1, _, token1, dequeueCount, err := q.Dequeue(ctx, topic, "", groupId) 117 | if err != nil { 118 | t.Error(err) 119 | return 120 | } 121 | 122 | message2, _, token2, dequeueCount, err := q.Dequeue(ctx, topic, "", groupId) 123 | if err != nil { 124 | t.Error(err) 125 | return 126 | } 127 | fmt.Println("TestBatchEnqueue:dequeueCount:", dequeueCount) 128 | 129 | if message1 == messages[0] { 130 | if message2 != messages[1] { 131 | t.Errorf("message2 is not same origin:%s real:%s", messages[1], message2) 132 | return 133 | } 134 | } else if message2 == messages[0] { 135 | if message1 != 
messages[1] { 136 | t.Errorf("message1 is not same origin:%s real:%s", messages[1], message1) 137 | return 138 | } 139 | } else { 140 | t.Errorf("message is not same %s", messages[1]) 141 | return 142 | } 143 | 144 | ok, err := q.AckMsg(ctx, topic, token1, "", groupId) 145 | if err != nil { 146 | t.Errorf("message1 ack err:%s", err.Error()) 147 | return 148 | } 149 | if !ok { 150 | t.Error("message1 ack is not ok") 151 | return 152 | } 153 | 154 | ok, err = q.AckMsg(ctx, topic, token2, "", groupId) 155 | if err != nil { 156 | t.Errorf("message1 ack err:%s", err.Error()) 157 | return 158 | } 159 | if !ok { 160 | t.Error("message2 ack is not ok") 161 | return 162 | } 163 | } 164 | 165 | func TestBatchEnqueueEmpty(t *testing.T) { 166 | ctx := context.TODO() 167 | topic := "SNOW-TOPIC-TEST" 168 | groupId := "GID-SNOW-TOPIC-TEST" 169 | messages := make([]string, 0) 170 | _, err := q.BatchEnqueue(ctx, topic, messages, "", groupId) 171 | fmt.Println("TestBatchEnqueueEmpty.Error", err) 172 | if err == nil { 173 | t.Error("empty message must return error") 174 | return 175 | } 176 | } 177 | 178 | func Test_getOption(t *testing.T) { 179 | instanceId, groupId, _, _ := getOption("", "GID-SNOW-TOPIC-TEST") 180 | if instanceId != "" { 181 | t.Errorf("delay is not equal 1. %s", instanceId) 182 | } else if groupId != "GID-SNOW-TOPIC-TEST" { 183 | t.Errorf("priority is not equal 10. 
%s", groupId) 184 | } 185 | } 186 | 187 | func getConfig() config.RocketMqConfig { 188 | //需要自己在文件填好配置 189 | bs, err := ioutil.ReadFile("../../.env.rocketmq") 190 | 191 | conf := config.RocketMqConfig{} 192 | if err == nil { 193 | str := string(bs) 194 | arr := strings.Split(str, "\n") 195 | log.Print(arr) 196 | if len(arr) >= 3 { 197 | conf.EndPoint = arr[0] 198 | conf.AccessKey = arr[1] 199 | conf.SecretKey = arr[2] 200 | conf.InstanceId = arr[3] 201 | conf.GroupId = arr[4] 202 | } 203 | } 204 | return conf 205 | } 206 | -------------------------------------------------------------------------------- /redis/provider.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | goredis "github.com/go-redis/redis/v8" 7 | "github.com/qit-team/snow-core/config" 8 | "github.com/qit-team/snow-core/helper" 9 | "github.com/qit-team/snow-core/kernel/container" 10 | "sync" 11 | ) 12 | 13 | const ( 14 | SingletonMain = "redis" 15 | ) 16 | 17 | var Pr *provider 18 | 19 | func init() { 20 | Pr = new(provider) 21 | Pr.mp = make(map[string]interface{}) 22 | } 23 | 24 | type provider struct { 25 | mu sync.RWMutex 26 | mp map[string]interface{} //配置 27 | dn string //default name 28 | } 29 | 30 | /** 31 | * @param string 依赖注入别名 必选 32 | * @param config.RedisConfig 配置 必选 33 | * @param bool 是否启用懒加载 可选 34 | */ 35 | func (p *provider) Register(args ...interface{}) (err error) { 36 | diName, lazy, err := helper.TransformArgs(args...) 
37 | if err != nil { 38 | return 39 | } 40 | 41 | conf, ok := args[1].(config.RedisConfig) 42 | if !ok { 43 | return errors.New("args[1] is not config.RedisConfig") 44 | } 45 | 46 | p.mu.Lock() 47 | p.mp[diName] = args[1] 48 | if len(p.mp) == 1 { 49 | p.dn = diName 50 | } 51 | p.mu.Unlock() 52 | 53 | if !lazy { 54 | if len(conf.Slaves) == 0 { 55 | _, err = setSingleton(diName, conf) 56 | } else { 57 | _, err = setClusterSingleton(diName, conf) 58 | } 59 | } 60 | return 61 | } 62 | 63 | // 注册过的别名 64 | func (p *provider) Provides() []string { 65 | p.mu.RLock() 66 | defer p.mu.RUnlock() 67 | 68 | return helper.MapToArray(p.mp) 69 | } 70 | 71 | // 释放资源 72 | func (p *provider) Close() error { 73 | arr := p.Provides() 74 | for _, k := range arr { 75 | c := getSingleton(k, false) 76 | if c != nil { 77 | c.Close() 78 | } 79 | } 80 | return nil 81 | } 82 | 83 | // 注入单例 84 | func setSingleton(diName string, conf config.RedisConfig) (ins *goredis.Client, err error) { 85 | ins, err = NewRedisClient(conf) 86 | if err != nil { 87 | return 88 | } 89 | if ins != nil { 90 | container.App.SetSingleton(diName, ins) 91 | } 92 | return 93 | } 94 | 95 | // 获取单例 96 | func getSingleton(diName string, lazy bool) *goredis.Client { 97 | rc := container.App.GetSingleton(diName) 98 | if rc != nil { 99 | return rc.(*goredis.Client) 100 | } 101 | if lazy == false { 102 | return nil 103 | } 104 | 105 | Pr.mu.RLock() 106 | conf, ok := Pr.mp[diName].(config.RedisConfig) 107 | Pr.mu.RUnlock() 108 | if !ok { 109 | panic(fmt.Sprintf("redis di_name:%s not exist", diName)) 110 | } 111 | 112 | ins, err := setSingleton(diName, conf) 113 | if err != nil { 114 | panic(fmt.Sprintf("redis di_name:%s err:%s", diName, err.Error())) 115 | } 116 | return ins 117 | } 118 | 119 | // 外部通过注入别名获取资源,解耦资源的关系 120 | func GetRedis(args ...string) *goredis.Client { 121 | diName := helper.GetDiName(Pr.dn, args...) 
122 | return getSingleton(diName, true) 123 | } 124 | 125 | // 注入单例 126 | func setClusterSingleton(diName string, conf config.RedisConfig) (ins *goredis.ClusterClient, err error) { 127 | ins, err = NewClusterRedisClient(conf) 128 | if err == nil { 129 | container.App.SetSingleton(diName, ins) 130 | } 131 | return 132 | } 133 | 134 | // 获取单例 135 | func getClusterSingleton(diName string, lazy bool) *goredis.ClusterClient { 136 | rc := container.App.GetSingleton(diName) 137 | if rc != nil { 138 | return rc.(*goredis.ClusterClient) 139 | } 140 | if lazy == false { 141 | return nil 142 | } 143 | 144 | Pr.mu.RLock() 145 | conf, ok := Pr.mp[diName].(config.RedisConfig) 146 | Pr.mu.RUnlock() 147 | if !ok { 148 | panic(fmt.Sprintf("redis di_name:%s not exist", diName)) 149 | } 150 | 151 | ins, err := setClusterSingleton(diName, conf) 152 | if err != nil { 153 | panic(fmt.Sprintf("redis di_name:%s err:%s", diName, err.Error())) 154 | } 155 | return ins 156 | } 157 | 158 | // 获取集群模式redisClient 159 | func GetClusterRedis(args ...string) *goredis.ClusterClient { 160 | diName := helper.GetDiName(Pr.dn, args...) 
161 | return getClusterSingleton(diName, true) 162 | } 163 | -------------------------------------------------------------------------------- /redis/provider_test.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "github.com/qit-team/snow-core/config" 5 | "testing" 6 | ) 7 | 8 | func Test_getSingleton(t *testing.T) { 9 | c := getSingleton("", false) 10 | if c != nil { 11 | t.Error("client is not equal nil") 12 | return 13 | } 14 | } 15 | 16 | func TestProvider(t *testing.T) { 17 | err := Pr.Register("redis", config.RedisConfig{}) 18 | if err == nil { 19 | t.Error(err) 20 | return 21 | } 22 | 23 | conf := config.RedisConfig{ 24 | Master: config.RedisBaseConfig{ 25 | Host: "127.0.0.1", 26 | Port: 6379, 27 | }, 28 | } 29 | 30 | err = Pr.Register("redis", conf, true) 31 | if err != nil { 32 | t.Error(err) 33 | return 34 | } 35 | 36 | arr := Pr.Provides() 37 | if !(len(arr) == 1 && arr[0] == "redis") { 38 | t.Errorf("Provides is not match. %v", arr) 39 | return 40 | } 41 | 42 | err = Pr.Register("redis1", conf) 43 | if err != nil { 44 | t.Error(err) 45 | return 46 | } 47 | 48 | arr = Pr.Provides() 49 | if !(len(arr) == 2 && arr[1] == "redis1" || arr[1] == "redis") { 50 | t.Errorf("Provides is not match. 
%v", arr) 51 | return 52 | } 53 | 54 | c := GetRedis() 55 | if c == nil { 56 | t.Error("client is equal nil") 57 | return 58 | } 59 | 60 | c1 := GetRedis("redis1") 61 | if c1 == nil { 62 | t.Error("client is equal nil") 63 | return 64 | } 65 | 66 | defer func() { 67 | if e := recover(); e != "redis di_name:redis2 not exist" { 68 | t.Error("not panic") 69 | } 70 | }() 71 | GetRedis("redis2") 72 | 73 | err = Pr.Close() 74 | if err != nil { 75 | t.Error(err) 76 | return 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /redis/redis.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | goredis "github.com/go-redis/redis/v8" 7 | "github.com/qit-team/snow-core/config" 8 | "time" 9 | ) 10 | 11 | type RedisConfig struct { 12 | Host string 13 | Port int 14 | Password string 15 | DB int 16 | } 17 | type Options struct { 18 | MaxIdle int 19 | MaxActive int 20 | Wait bool 21 | IdleTimeout time.Duration 22 | ConnectTimeout time.Duration 23 | ReadTimeout time.Duration 24 | WriteTimeout time.Duration 25 | } 26 | 27 | //redis连接池实例,不对外暴露,通过redis_service_provider实现依赖注入和资源获取 28 | func NewRedisClient(redisConf config.RedisConfig) (*goredis.Client, error) { 29 | if redisConf.Master.Host == "" { 30 | return nil, errors.New("redis config is empty") 31 | } 32 | 33 | rdb := goredis.NewClient(&goredis.Options{ 34 | Addr: fmt.Sprintf("%s:%d", redisConf.Master.Host, redisConf.Master.Port), 35 | Password: redisConf.Master.Password, 36 | DB: redisConf.Master.DB, 37 | }) 38 | return rdb, nil 39 | } 40 | 41 | //redis连接池实例,不对外暴露,通过redis_service_provider实现依赖注入和资源获取 42 | func NewClusterRedisClient(redisConf config.RedisConfig) (*goredis.ClusterClient, error) { 43 | if redisConf.Master.Host == "" { 44 | return nil, errors.New("redis config is empty") 45 | } 46 | 47 | addrs := []string{} 48 | addrs = append(addrs, fmt.Sprintf("%s:%d", redisConf.Master.Host, 
redisConf.Master.Port)) 49 | for _, slave := range redisConf.Slaves { 50 | addrs = append(addrs, fmt.Sprintf("%s:%d", slave.Host, slave.Port)) 51 | } 52 | rdb := goredis.NewClusterClient(&goredis.ClusterOptions{ 53 | Addrs: addrs, 54 | Password: redisConf.Master.Password, 55 | }) 56 | return rdb, nil 57 | } 58 | 59 | func genRedisConfig(c config.RedisBaseConfig) RedisConfig { 60 | return RedisConfig{ 61 | Host: c.Host, 62 | Port: c.Port, 63 | Password: c.Password, 64 | DB: c.DB, 65 | } 66 | } 67 | 68 | func genOptions(c config.RedisOptionConfig) Options { 69 | return Options{ 70 | MaxIdle: c.MaxIdle, 71 | MaxActive: c.MaxConns, 72 | Wait: c.Wait, 73 | IdleTimeout: c.IdleTimeout * time.Second, 74 | ConnectTimeout: c.ConnectTimeout * time.Second, 75 | ReadTimeout: c.ReadTimeout * time.Second, 76 | WriteTimeout: c.WriteTimeout * time.Second, 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /redis/redis_test.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "context" 5 | "github.com/qit-team/snow-core/config" 6 | "reflect" 7 | "strconv" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | var conf config.RedisConfig 13 | 14 | func init() { 15 | conf = config.RedisConfig{ 16 | Master: config.RedisBaseConfig{ 17 | Host: "127.0.0.1", 18 | Port: 6379, 19 | }, 20 | } 21 | } 22 | 23 | func TestGetSet(t *testing.T) { 24 | _, err := NewRedisClient(config.RedisConfig{}) 25 | if err == nil { 26 | t.Error("redis config donot check") 27 | return 28 | } 29 | 30 | client, err := NewRedisClient(conf) 31 | if err != nil { 32 | t.Error("client init failed") 33 | return 34 | } 35 | 36 | value := 11 37 | res, _ := client.Set(context.TODO(), "hts", value, 100).Result() 38 | t.Log(res, reflect.TypeOf(res)) 39 | if res == "" { 40 | t.Error("set error") 41 | return 42 | } 43 | 44 | res1, _ := client.Get(context.TODO(), "hts").Result() 45 | t.Log(res1, reflect.TypeOf(res1)) 46 | 
if res1 == "" { 47 | t.Error("get error") 48 | return 49 | } else if res1 != strconv.Itoa(value) { 50 | t.Error("not same") 51 | return 52 | } 53 | } 54 | 55 | func Test_genRedisConfig(t *testing.T) { 56 | conf := config.RedisBaseConfig{ 57 | Host: "127.0.0.1", 58 | Port: 6379, 59 | } 60 | newConf := genRedisConfig(conf) 61 | if newConf.Host != conf.Host || newConf.Port != conf.Port || newConf.DB != conf.DB { 62 | t.Error("genRedisConfig failed") 63 | return 64 | } 65 | } 66 | 67 | func Test_genOptions(t *testing.T) { 68 | conf := config.RedisOptionConfig{ 69 | MaxConns: 64, 70 | Wait: true, 71 | IdleTimeout: 3, 72 | } 73 | newConf := genOptions(conf) 74 | if newConf.MaxIdle != 0 || newConf.Wait != conf.Wait || newConf.MaxActive != conf.MaxConns { 75 | t.Error("genOptions failed") 76 | return 77 | } else if newConf.IdleTimeout != 3*time.Second || newConf.ConnectTimeout != 0*time.Second { 78 | t.Error("genOptions failed") 79 | return 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /rocketmq/provider.go: -------------------------------------------------------------------------------- 1 | package rocketmq 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "sync" 7 | 8 | "github.com/qit-team/snow-core/config" 9 | "github.com/qit-team/snow-core/helper" 10 | "github.com/qit-team/snow-core/kernel/container" 11 | ) 12 | 13 | const ( 14 | SingletonMain = "rocketmq" 15 | ) 16 | 17 | var Pr *provider 18 | 19 | func init() { 20 | Pr = new(provider) 21 | Pr.mp = make(map[string]interface{}) 22 | } 23 | 24 | type provider struct { 25 | mu sync.RWMutex // guards mp and dn 26 | mp map[string]interface{} // registered configs, keyed by DI name 27 | dn string // default DI name: the first one registered 28 | } 29 | 30 | /** Register stores a config.RocketMqConfig under a DI alias and, unless lazy loading is requested, builds the client right away. 31 | * @param string DI alias (required) 32 | * @param config.RocketMqConfig configuration (required) 33 | * @param bool whether to load lazily (optional) 34 | */ 35 | func (p *provider) Register(args ...interface{}) (err error) { 36 | diName, lazy, err := helper.TransformArgs(args...)
37 | if err != nil { 38 | return 39 | } 40 | 41 | conf, ok := args[1].(config.RocketMqConfig) 42 | if !ok { 43 | return errors.New("args[1] is not config.RocketMqConfig") 44 | } 45 | 46 | p.mu.Lock() 47 | p.mp[diName] = args[1] 48 | if len(p.mp) == 1 { 49 | p.dn = diName 50 | } 51 | p.mu.Unlock() 52 | 53 | if !lazy { 54 | _, err = setSingleton(diName, conf) 55 | } 56 | return 57 | } 58 | 59 | // Provides returns the registered DI aliases. 60 | func (p *provider) Provides() []string { 61 | p.mu.RLock() 62 | defer p.mu.RUnlock() 63 | 64 | return helper.MapToArray(p.mp) 65 | } 66 | 67 | // Close is a no-op: clients created by this provider are not shut down here. 68 | func (p *provider) Close() error { 69 | return nil 70 | } 71 | 72 | // setSingleton builds a client from conf and, on success, stores it in the container under diName. 73 | func setSingleton(diName string, conf config.RocketMqConfig) (ins *RocketClient, err error) { 74 | ins, err = NewRocketMqClient(conf) 75 | if err == nil { 76 | container.App.SetSingleton(diName, ins) 77 | } 78 | return 79 | } 80 | 81 | // getSingleton returns the container's client for diName. When none exists it returns nil if lazy is false; otherwise it builds one from the registered config, panicking if diName was never registered or construction fails. 82 | func getSingleton(diName string, lazy bool) *RocketClient { 83 | rc := container.App.GetSingleton(diName) 84 | if rc != nil { 85 | return rc.(*RocketClient) 86 | } 87 | if !lazy { 88 | return nil 89 | } 90 | 91 | // Build under the write lock and re-check, so concurrent first callers do not each create a client. 92 | Pr.mu.Lock() 93 | defer Pr.mu.Unlock() 94 | if existing := container.App.GetSingleton(diName); existing != nil { 95 | return existing.(*RocketClient) 96 | } 97 | conf, ok := Pr.mp[diName].(config.RocketMqConfig) 98 | if !ok { 99 | panic(fmt.Sprintf("rocket_mq di_name:%s not exist", diName)) 100 | } 101 | 102 | ins, err := setSingleton(diName, conf) 103 | if err != nil { 104 | panic(fmt.Sprintf("rocket di_name: %s err: %s", diName, err.Error())) 105 | } 106 | return ins 107 | } 108 | 109 | // GetRocketMq returns the client for a DI alias, building it on first use. The alias is resolved by helper.GetDiName from args and the default name (presumably args[0] when given, else the default - confirm). 110 | func GetRocketMq(args ...string) *RocketClient { 111 | Pr.mu.RLock() 112 | dn := Pr.dn // read under the lock: Register writes dn while holding it 113 | Pr.mu.RUnlock() 114 | diName := helper.GetDiName(dn, args...)
115 | return getSingleton(diName, true) 116 | } 117 | -------------------------------------------------------------------------------- /rocketmq/rocketmq.go: -------------------------------------------------------------------------------- 1 | package rocketmq 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "github.com/apache/rocketmq-client-go/v2" 8 | "github.com/apache/rocketmq-client-go/v2/consumer" 9 | "github.com/apache/rocketmq-client-go/v2/primitive" 10 | "github.com/apache/rocketmq-client-go/v2/producer" 11 | "github.com/qit-team/snow-core/config" 12 | ) 13 | 14 | // NewRocketMqClient builds a push consumer and a producer for the RocketMQ name server at mqConfig.EndPoint; it is the constructor used for dependency injection. An empty EndPoint is reported as an error. 15 | func NewRocketMqClient(mqConfig config.RocketMqConfig) (client *RocketClient, err error) { 16 | // Turn a panic raised while constructing the clients into a returned error (and a nil client). 17 | defer func() { 18 | if e := recover(); e != nil { 19 | client = nil 20 | err = fmt.Errorf("rocketmq client init panic: %v", e) 21 | } 22 | }() 23 | 24 | if mqConfig.EndPoint != "" { 25 | client = new(RocketClient) 26 | 27 | // Defaults first; caller-supplied ConsumerOptions are appended after them. 28 | consumerOptions := []consumer.Option{ 29 | consumer.WithNameServer([]string{mqConfig.EndPoint}), 30 | consumer.WithCredentials(primitive.Credentials{ 31 | AccessKey: mqConfig.AccessKey, 32 | SecretKey: mqConfig.SecretKey, 33 | }), 34 | consumer.WithGroupName(mqConfig.GroupId), 35 | consumer.WithNamespace(mqConfig.InstanceId), 36 | consumer.WithConsumerModel(consumer.Clustering), 37 | consumer.WithConsumeFromWhere(consumer.ConsumeFromFirstOffset), 38 | consumer.WithAutoCommit(true), 39 | } 40 | if len(mqConfig.ConsumerOptions) > 0 { 41 | consumerOptions = append(consumerOptions, mqConfig.ConsumerOptions...) 42 | } 43 | 44 | client.Consumer, err = rocketmq.NewPushConsumer(consumerOptions...)
45 | if err != nil { 46 | return nil, err 47 | } 48 | 49 | // Defaults first; caller-supplied ProducerOptions are appended after them. 50 | producerOptions := []producer.Option{ 51 | producer.WithNameServer([]string{mqConfig.EndPoint}), 52 | producer.WithCredentials(primitive.Credentials{ 53 | AccessKey: mqConfig.AccessKey, 54 | SecretKey: mqConfig.SecretKey, 55 | }), 56 | producer.WithRetry(2), 57 | producer.WithGroupName(mqConfig.GroupId), 58 | producer.WithNamespace(mqConfig.InstanceId), 59 | } 60 | if len(mqConfig.ProducerOptions) > 0 { 61 | producerOptions = append(producerOptions, mqConfig.ProducerOptions...) 62 | } 63 | 64 | client.Producer, err = rocketmq.NewProducer(producerOptions...) 65 | if err != nil { 66 | return nil, err 67 | } 68 | 69 | } else { 70 | err = errors.New("EndPoint empty, can not get client") 71 | } 72 | return 73 | } 74 | 75 | // RocketClient bundles the push consumer and producer created by NewRocketMqClient. 76 | type RocketClient struct { 77 | Consumer rocketmq.PushConsumer 78 | Producer rocketmq.Producer 79 | } 80 | -------------------------------------------------------------------------------- /utils/base62.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | // TODO: to be optimized. 4 | // Digits are written least-significant first; Encode62 and Decode62 invert each other, but the strings are reversed relative to most-significant-first base62. 5 | 6 | import ( 7 | "strings" 8 | ) 9 | 10 | const ( 11 | code62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 12 | codeLen = 62 13 | ) 14 | 15 | var codeMap = map[string]int64{"0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "a": 10, "b": 11, "c": 12, "d": 13, "e": 14, "f": 15, "g": 16, "h": 17, "i": 18, "j": 19, "k": 20, "l": 21, "m": 22, "n": 23, "o": 24, "p": 25, "q": 26, "r": 27, "s": 28, "t": 29, "u": 30, "v": 31, "w": 32, "x": 33, "y": 34, "z": 35, "A": 36, "B": 37, "C": 38, "D": 39, "E": 40, "F": 41, "G": 42, "H": 43, "I": 44, "J": 45, "K": 46, "L": 47, "M": 48, "N": 49, "O": 50, "P": 51, "Q": 52, "R": 53, "S": 54, "T": 55, "U": 56, "V": 57, "W": 58, "X": 59, "Y": 60, "Z": 61} 16 | 17 | /** 18 | * Encode62 encodes number in base62, least-significant digit first. Zero encodes as "0"; negative numbers encode as "". 19 | */ 20 | func Encode62(number int64)
string { 21 | if number == 0 { 22 | return "0" 23 | } 24 | result := make([]byte, 0, 11) // 11 base62 digits cover any positive int64 25 | for number > 0 { 26 | round := number / codeLen 27 | remain := number % codeLen 28 | result = append(result, code62[remain]) 29 | number = round 30 | } 31 | return string(result) 32 | } 33 | 34 | /** 35 | * Decode62 inverts Encode62 for non-negative values: surrounding whitespace is trimmed, digits are read least-significant first, and bytes outside the base62 alphabet count as 0. 36 | */ 37 | func Decode62(str string) int64 { 38 | str = strings.TrimSpace(str) 39 | var result, weight int64 = 0, 1 40 | for _, char := range []byte(str) { 41 | result += codeMap[string(char)] * weight 42 | weight *= codeLen // exact integer powers of 62 instead of a float64 math.Pow round-trip 43 | } 44 | return result 45 | } 46 | -------------------------------------------------------------------------------- /utils/base62_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestEncode62(t *testing.T) { 8 | num := int64(1122) 9 | s := Encode62(num) 10 | num1 := Decode62(s) 11 | if num != num1 { 12 | t.Error("it is not reversible") 13 | return 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /utils/convert.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | func SliceStr2Interface(input []string) (output []interface{}) { 8 | for _, v := range input { 9 | output = append(output, v) 10 | } 11 | return 12 | } 13 | 14 | func MapStrInterface2MapStrStr(input map[string]interface{}) (output map[string]string) { 15 | output = make(map[string]string, len(input)) 16 | for k, v := range input { 17 | output[k] = Interface2Str(v) 18 | } 19 | return 20 | } 21 | 22 | // Interface2Str converts v to a string: a []rune is decoded as text, anything else goes through fmt.Sprint. 23 | func Interface2Str(v interface{}) string { 24 | switch s := v.(type) { 25 | case []rune: 26 | return string(s) 27 | default: 28 | return fmt.Sprint(v) 29 | } 30 | } 31 | 32 | func Num2Str(u interface{}) string { 33 | return fmt.Sprintf("%d", u) 34 | } 35 |
-------------------------------------------------------------------------------- /utils/convert_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestMapStrInterface2MapStrStr(t *testing.T) { 8 | mp := map[string]interface{}{ 9 | "a": 1, 10 | "b": "bb", 11 | "c": 3.2, 12 | "d": false, 13 | } 14 | mp1 := MapStrInterface2MapStrStr(mp) 15 | if mp1["a"] != "1" { 16 | t.Error("map a failed") 17 | return 18 | } else if mp1["b"] != "bb" { 19 | t.Error("map b failed") 20 | return 21 | } else if mp1["c"] != "3.2" { 22 | t.Error("map c failed") 23 | return 24 | } else if mp1["d"] != "false" { 25 | t.Error("map d failed") 26 | return 27 | } 28 | } 29 | 30 | func TestNum2Str(t *testing.T) { 31 | num := 2211 32 | s := Num2Str(num) 33 | if s != "2211" { 34 | t.Error("Num2Str failed") 35 | return 36 | } 37 | } 38 | 39 | func TestSliceStr2Interface(t *testing.T) { 40 | arr := []string{ 41 | "a", "b", 42 | } 43 | arr2 := SliceStr2Interface(arr) 44 | if !(arr2[0] == "a" && arr2[1] == "b") { 45 | t.Error("SliceStr2Interface failed") 46 | return 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /utils/hash.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "crypto/md5" 5 | "encoding/hex" 6 | ) 7 | 8 | // GetMd5Hash returns the hex-encoded MD5 digest of text (32 lowercase hex characters). 9 | func GetMd5Hash(text string) string { 10 | sum := md5.Sum([]byte(text)) 11 | return hex.EncodeToString(sum[:]) 12 | } 13 | -------------------------------------------------------------------------------- /utils/hash_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestGetMd5Hash(t *testing.T) { 8 | s := GetMd5Hash("ss") 9 | if len(s) != 32 { 10 | t.Errorf("length of md5 string is not equal 32. %s", s) 11 | return 12 | } 13 | } 14 |
-------------------------------------------------------------------------------- /utils/httputil/http.go: -------------------------------------------------------------------------------- 1 | package httputil 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "github.com/qit-team/snow-core/http/ctxkit" 8 | "github.com/qit-team/snow-core/utils" 9 | "io/ioutil" 10 | "net/http" 11 | "strings" 12 | "time" 13 | ) 14 | 15 | const ( 16 | ContentTypeJSON = "application/json" 17 | ContentTypeForm = "application/x-www-form-urlencoded" 18 | ) 19 | 20 | type myClient struct { 21 | cli *http.Client 22 | } 23 | 24 | type Client interface { 25 | // Do sends a single HTTP request. 26 | Do(ctx context.Context, req *http.Request) (*http.Response, error) 27 | } 28 | 29 | // NewClient creates a Client backed by an http.Client with the given timeout. 30 | func NewClient(timeout time.Duration) Client { 31 | return &myClient{ 32 | cli: &http.Client{ 33 | Timeout: timeout, 34 | }, 35 | } 36 | } 37 | 38 | // Do sends req bound to ctx. Anything other than HTTP 200 is returned as an error; when a response was received it is returned as well, and the caller must still close resp.Body. 39 | func (c *myClient) Do(ctx context.Context, req *http.Request) (resp *http.Response, err error) { 40 | // Forward the trace id carried by ctx to the downstream service in a request header. 41 | SetTraceIdInHeader(ctx, req) 42 | req = req.WithContext(ctx) 43 | resp, err = c.cli.Do(req) 44 | var httpCode int 45 | if err != nil { 46 | httpCode = http.StatusGatewayTimeout // every transport error is reported as 504 47 | } else { 48 | httpCode = resp.StatusCode 49 | } 50 | 51 | if httpCode != http.StatusOK { 52 | var errMsg string 53 | if err != nil { 54 | errMsg = err.Error() 55 | } 56 | msg := fmt.Sprintf("%s %s%s http_code(%d) err(%s)", 57 | req.Method, req.URL.Host, req.URL.Path, httpCode, errMsg) 58 | err = errors.New(msg) 59 | return 60 | } 61 | return 62 | } 63 | func SetTraceIdInHeader(ctx context.Context, req *http.Request) { 64 | traceId := ctxkit.GetTraceId(ctx) 65 | if len(traceId) > 0 { 66 | req.Header.Set("X-TRACE-ID", traceId) 67 | } 68 | } 69 | 70 | /** 71 | * NewGetRequest builds a GET request; non-empty params are appended to url as a query string. 72 | * @param url request URL 73 | * @param params query parameters (nil or empty for none)
74 | * @param headers optional: map[string]string, map[string]interface{} or []string, e.g. {"Token":"123"} or ["Token:123"] 75 | * @param options optional map[string]interface{}, e.g. {"timeout": 10}; read by Get/Post/PostJson/Request as the timeout in seconds, ignored here 76 | */ 77 | func NewGetRequest(url string, params map[string]interface{}, args ...interface{}) (req *http.Request, err error) { 78 | if len(params) > 0 { 79 | paramStr := utils.HttpBuildQuery(params) 80 | var op string 81 | if !strings.Contains(url, "?") { 82 | op = "?" 83 | } else { 84 | op = "&" 85 | } 86 | url = utils.Join(url, op, paramStr) 87 | } 88 | 89 | req, err = http.NewRequest("GET", url, nil) 90 | if err != nil { 91 | return 92 | } 93 | if len(args) > 0 { 94 | SetHeaders(req, args[0]) 95 | } 96 | return 97 | } 98 | 99 | // NewFormPostRequest builds a POST request whose body is params form-encoded (empty when params is nil); headers are taken from args[0] as for NewGetRequest. 100 | func NewFormPostRequest(url string, params map[string]interface{}, args ...interface{}) (req *http.Request, err error) { 101 | var paramStr string 102 | if params != nil { 103 | paramStr = utils.HttpBuildQuery(params) 104 | } else { 105 | paramStr = "" 106 | } 107 | 108 | req, err = http.NewRequest("POST", url, strings.NewReader(paramStr)) 109 | if err != nil { 110 | return 111 | } 112 | req.Header.Set("Content-Type", ContentTypeForm) 113 | if len(args) > 0 { 114 | SetHeaders(req, args[0]) 115 | } 116 | return 117 | } 118 | 119 | // NewJsonPostRequest builds a POST request whose body is params JSON-encoded (empty when params is nil); headers are taken from args[0] as for NewGetRequest. 120 | func NewJsonPostRequest(url string, params map[string]interface{}, args ...interface{}) (req *http.Request, err error) { 121 | var paramStr string 122 | if params != nil { 123 | paramStr, err = utils.JsonEncode(params) 124 | if err != nil { 125 | return 126 | } 127 | } else { 128 | paramStr = "" 129 | } 130 | 131 | req, err = http.NewRequest("POST", url, strings.NewReader(paramStr)) 132 | if err != nil { 133 | return 134 | } 135 | req.Header.Set("Content-Type", ContentTypeJSON) 136 | if len(args) > 0 { 137 | SetHeaders(req, args[0]) 138 | } 139 | return 140 | } 141 | 142 | func Get(ctx context.Context, url string, params map[string]interface{}, args ...interface{}) (resp *http.Response, err error) { 143 | timeout := getTimeout(args...)
144 | client := NewClient(timeout) 145 | req, err := NewGetRequest(url, params, args...) 146 | if err != nil { 147 | return 148 | } 149 | resp, err = client.Do(ctx, req) 150 | return 151 | } 152 | 153 | func Post(ctx context.Context, url string, params map[string]interface{}, args ...interface{}) (resp *http.Response, err error) { 154 | timeout := getTimeout(args...) 155 | client := NewClient(timeout) 156 | req, err := NewFormPostRequest(url, params, args...) 157 | if err != nil { 158 | return 159 | } 160 | resp, err = client.Do(ctx, req) 161 | return 162 | } 163 | 164 | func PostJson(ctx context.Context, url string, params map[string]interface{}, args ...interface{}) (resp *http.Response, err error) { 165 | timeout := getTimeout(args...) 166 | client := NewClient(timeout) 167 | req, err := NewJsonPostRequest(url, params, args...) 168 | if err != nil { 169 | return 170 | } 171 | resp, err = client.Do(ctx, req) 172 | return 173 | } 174 | 175 | func Request(ctx context.Context, method string, url string, params map[string]interface{}, args ...interface{}) (resp *http.Response, err error) { 176 | timeout := getTimeout(args...) 177 | client := NewClient(timeout) 178 | var req *http.Request 179 | if strings.ToUpper(method) == "POST" { 180 | req, err = NewFormPostRequest(url, params, args...) 181 | } else if strings.ToUpper(method) == "POST/JSON" { 182 | req, err = NewJsonPostRequest(url, params, args...) 183 | } else { 184 | req, err = NewGetRequest(url, params, args...)
185 | } 186 | if err != nil { 187 | return 188 | } 189 | resp, err = client.Do(ctx, req) 190 | return 191 | } 192 | 193 | // DealResponse reads the whole response body and closes it. 194 | func DealResponse(resp *http.Response) (body []byte, err error) { 195 | defer resp.Body.Close() 196 | body, err = ioutil.ReadAll(resp.Body) 197 | return 198 | } 199 | 200 | // SetHeaders copies headers onto req. Supported forms: map[string]string, map[string]interface{} (values stringified with utils.Interface2Str) and []string of "Key: Value" entries (entries without a colon are skipped); any other type is ignored. 201 | func SetHeaders(req *http.Request, headers interface{}) { 202 | switch hs := headers.(type) { 203 | case map[string]string: 204 | for k, v := range hs { 205 | req.Header.Set(k, v) 206 | } 207 | case map[string]interface{}: 208 | for k, v := range hs { 209 | req.Header.Set(k, utils.Interface2Str(v)) 210 | } 211 | case []string: 212 | for _, v := range hs { 213 | strArr := strings.SplitN(v, ":", 2) 214 | if len(strArr) >= 2 { 215 | req.Header.Set(strArr[0], strings.Trim(strArr[1], " ")) 216 | } 217 | } 218 | } 219 | } 220 | 221 | func StringListToMap(strArr []string) map[string]interface{} { 222 | m := make(map[string]interface{}) 223 | for _, v := range strArr { 224 | s := strings.SplitN(v, ":", 2) 225 | if len(s) >= 2 { 226 | m[s[0]] = strings.Trim(s[1], " ") 227 | } 228 | } 229 | return m 230 | } 231 | 232 | // getOptions returns args[1] when it is a map[string]interface{}, otherwise an empty map. 233 | func getOptions(args ...interface{}) (options map[string]interface{}) { 234 | if len(args) > 1 { 235 | options, _ = args[1].(map[string]interface{}) 236 | } 237 | if options == nil { 238 | options = make(map[string]interface{}) 239 | } 240 | return 241 | } 242 | 243 | // getTimeout reads options["timeout"] as whole seconds (int values only) and falls back to 30s when it is missing, not an int, or not positive. 244 | func getTimeout(args ...interface{}) time.Duration { 245 | options := getOptions(args...)
246 | var timeout int 247 | if t, ok := options["timeout"]; ok { 248 | timeout, _ = t.(int) 249 | } 250 | if timeout <= 0 { 251 | timeout = 30 252 | } 253 | return time.Second * time.Duration(timeout) 254 | } 255 | -------------------------------------------------------------------------------- /utils/httputil/http_test.go: -------------------------------------------------------------------------------- 1 | package httputil 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "github.com/gin-gonic/gin" 7 | "github.com/qit-team/snow-core/http/ctxkit" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | type responseData struct { 13 | Code int `json:"code"` 14 | Message string `json:"message"` 15 | Data map[string]interface{} `json:"data"` 16 | } 17 | 18 | var client Client 19 | var c *gin.Context 20 | 21 | func init() { 22 | client = NewClient(time.Second * 5) 23 | c = &gin.Context{} 24 | } 25 | 26 | /** 27 | <?php 28 | $data = [ 29 | "code" => 200, 30 | "message" => "ok", 31 | "data" => [ 32 | "type" => $_SERVER['CONTENT_TYPE'], 33 | "post" => $_POST, 34 | "get" => $_GET, 35 | "input" => file_get_contents("php://input") 36 | ] 37 | ]; 38 | echo json_encode($data); 39 | */ 40 | func TestDoGet(t *testing.T) { 41 | url := "http://localhost:8080/hello" 42 | req, _ := NewGetRequest(url, nil) 43 | ctxkit.GenerateTraceId(c) 44 | response, err := client.Do(c, req) 45 | if err != nil { 46 | t.Error(err) 47 | return 48 | } 49 | result, err := DealResponse(response) 50 | 51 | resp := new(responseData) 52 | json.Unmarshal(result, resp) 53 | if resp.Code != 200 { 54 | t.Error("get result is not ok") 55 | return 56 | } 57 | }
58 | 59 | func TestPost(t *testing.T) { 60 | url := "http://localhost:8080/testPost" 61 | // nil params 62 | req, err := NewFormPostRequest(url, nil) 63 | response, err := client.Do(context.TODO(), req) 64 | if err != nil { 65 | t.Error(err) 66 | return 67 | } 68 | result, err := DealResponse(response) 69 | resp := new(responseData) 70 | json.Unmarshal(result, resp) 71 | 72 | if resp.Code != 200 { 73 | t.Error("post result is not ok") 74 | return 75 | } else if resp.Data["type"] != ContentTypeForm { 76 | t.Error("post content-type is not equal " + ContentTypeForm) 77 | return 78 | } 79 | 80 | // empty params map 81 | req, err = NewFormPostRequest(url, make(map[string]interface{})) 82 | response, err = client.Do(context.TODO(), req) 83 | if err != nil { 84 | t.Error(err) 85 | return 86 | } 87 | result, err = DealResponse(response) 88 | resp = new(responseData) 89 | json.Unmarshal(result, resp) 90 | if resp.Code != 200 { 91 | t.Error("post result is not ok") 92 | return 93 | } 94 | 95 | // non-empty params map 96 | params := map[string]interface{}{ 97 | "name": "hts", 98 | } 99 | req, err = NewFormPostRequest(url, params) 100 | response, err = client.Do(context.TODO(), req) 101 | if err != nil { 102 | t.Error(err) 103 | return 104 | } 105 | result, err = DealResponse(response) 106 | resp = new(responseData) 107 | json.Unmarshal(result, resp) 108 | if resp.Code != 200 { 109 | t.Error("post result is not ok") 110 | return 111 | } 112 | }
113 | 114 | func TestPostJsonData(t *testing.T) { 115 | url := "http://localhost:8080/test" 116 | 117 | // nil params 118 | req, err := NewJsonPostRequest(url, nil) 119 | response, err := client.Do(context.TODO(), req) 120 | if err != nil { 121 | t.Error(err) 122 | return 123 | } 124 | result, err := DealResponse(response) 125 | resp := new(responseData) 126 | json.Unmarshal(result, resp) 127 | if resp.Code != 400 { 128 | t.Error("postJsonData result is not ok") 129 | return 130 | } 131 | 132 | // empty params map 133 | req, err = NewJsonPostRequest(url, make(map[string]interface{})) 134 | response, err = client.Do(context.TODO(), req) 135 | if err != nil { 136 | t.Error(err) 137 | return 138 | } 139 | result, err = DealResponse(response) 140 | resp = new(responseData) 141 | json.Unmarshal(result, resp) 142 | if resp.Code != 200 { 143 | t.Error("postJsonData result is not ok") 144 | return 145 | } 146 | 147 | // non-empty params map 148 | params := map[string]interface{}{ 149 | "name": "hts", 150 | } 151 | req, err = NewJsonPostRequest(url, params) 152 | response, err = client.Do(context.TODO(), req) 153 | if err != nil { 154 | t.Error(err) 155 | return 156 | } 157 | result, err = DealResponse(response) // read (and close) this response instead of re-checking the previous result 158 | resp = new(responseData) 159 | json.Unmarshal(result, resp) 160 | if resp.Code != 200 { 161 | t.Error("postJsonData result is not ok") 162 | return 163 | } 164 | } 165 | 166 | func TestStringListToMap(t *testing.T) { 167 | m := StringListToMap([]string{"hts:11 ", "name:", "key", "v: 1:2"}) 168 | _, ok := m["key"] 169 | if ok { 170 | t.Error("not right filter") 171 | return 172 | } 173 | 174 | val := m["hts"] 175 | if val != "11" { 176 | t.Error("not right trim") 177 | return 178 | } 179 | 180 | val = m["v"] 181 | if val != "1:2" { 182 | t.Error("not right split") 183 | return 184 | } 185 | } 186 | -------------------------------------------------------------------------------- /utils/iputil/ip.go: -------------------------------------------------------------------------------- 1 | package iputil 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | ) 7 | 8 | var ( 9 | ErrNotFound = errors.New("cannot find internal ip") 10 | ) 11 | 12 | // GetInternalIp returns the first non-loopback IPv4 address among this host's interface addresses, or ErrNotFound. NOTE(review): no private-range check is made, so the result may be a public address. 13 | func GetInternalIp() (string, error) { 14 | addrs, err := net.InterfaceAddrs() 15 | if err != nil { 16 | return "", err 17 | } 18 | 19 | for _, address := range addrs { 20 | // skip loopback addresses and anything that is not IPv4 21 | if ipNet, ok := address.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { 22 | if ipNet.IP.To4() != nil { 23 | return ipNet.IP.String(), nil 24 | } 25 | 26 | } 27 | } 28 | 29 | return "", ErrNotFound 30 | } 31 | -------------------------------------------------------------------------------- /utils/iputil/ip_test.go: -------------------------------------------------------------------------------- 1 | package iputil 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | // TestGetInternalIp needs the host to have at least one non-loopback IPv4 address. 8 | func TestGetInternalIp(t *testing.T) { 9 | str, err := GetInternalIp() 10 | if err != nil { 11 | t.Error(err) 12 | return 13 | } else if len(str) == 0 { 14 | t.Error("get internal ip failed") 15 | return 16 | } 17 | } 18 |
-------------------------------------------------------------------------------- /utils/json.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import "encoding/json" 4 | 5 | // JsonEncode marshals v with encoding/json and returns the JSON text as a string. 6 | func JsonEncode(v interface{}) (string, error) { 7 | b, err := json.Marshal(v) 8 | if err != nil { 9 | return "", err 10 | } 11 | return string(b), nil 12 | } 13 | -------------------------------------------------------------------------------- /utils/json_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestJsonEncode(t *testing.T) { 8 | // declared but not initialised: a nil map encodes as null 9 | var p1 map[string]interface{} 10 | s1, _ := JsonEncode(p1) 11 | if s1 != "null" { 12 | t.Error("nil map is not equal null", s1) 13 | } 14 | 15 | // initialised but empty 16 | p2 := make(map[string]interface{}) 17 | s2, _ := JsonEncode(p2) 18 | if s2 != "{}" { 19 | t.Error("blank map is not equal {}", s2) 20 | } 21 | 22 | // initialised with one entry 23 | p3 := map[string]interface{}{ 24 | "name": "hts", 25 | } 26 | s3, _ := JsonEncode(p3) 27 | if s3 != "{\"name\":\"hts\"}" { 28 | t.Error("map is not equal", s3) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /utils/string.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | // Substr returns length runes of str starting at rune index start. A negative start is taken as len-1+start (so -1 selects the second-to-last rune, unlike PHP substr); a negative length selects the runes just before start. Out-of-range bounds are clamped. 8 | func Substr(str string, start int, length int) string { 9 | rs := []rune(str) 10 | rl := len(rs) 11 | end := 0 12 | 13 | if start < 0 { 14 | start = rl - 1 + start 15 | } 16 | end = start + length 17 | 18 | if start > end { 19 | start, end = end, start 20 | } 21 | 22 | if start < 0 { 23 | start = 0 24 | } 25 | if start > rl { 26 | start = rl 27 | } 28 | if end < 0 { 29 | end = 0 30 | } 31 | if end > rl { 32 | end = rl 33 | } 34 | 35 | return string(rs[start:end]) 36 | } 37
| 38 | //字符串拼接 39 | func Join(s ...string) string { 40 | return strings.Join(s, "") 41 | } 42 | -------------------------------------------------------------------------------- /utils/string_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestSubstr(t *testing.T) { 8 | str := "1234567890" 9 | s := Substr(str, 1, 2) 10 | if s != "23" { 11 | t.Error("substr failed") 12 | return 13 | } 14 | 15 | s = Substr(str, -2, 1) 16 | if s != "8" { 17 | t.Error("substr failed") 18 | return 19 | } 20 | 21 | s = Substr(str, -1, 0) 22 | if s != "" { 23 | t.Error("substr failed") 24 | return 25 | } 26 | } 27 | 28 | func TestJoin(t *testing.T) { 29 | s := Join("1", "2", "a") 30 | if s != "12a" { 31 | t.Error("join failed") 32 | return 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /utils/time.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import "time" 4 | 5 | //获取当前的Unix时间戳 6 | func GetCurrentTime() int64 { 7 | return time.Now().Unix() 8 | } 9 | 10 | //获取当前的毫秒级时间戳 11 | func GetCurrentMilliTime() int64 { 12 | return time.Now().UnixNano() / 1000000 13 | } 14 | -------------------------------------------------------------------------------- /utils/time_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import "testing" 4 | 5 | func TestGetCurrentTime(t *testing.T) { 6 | t1 := GetCurrentTime() 7 | t2 := GetCurrentMilliTime() 8 | if t1 != t2/1000 && t1+1 != t2/1000 { 9 | t.Error("time error") 10 | return 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /utils/url.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "net/url" 6 | "reflect" 7 | "strings" 8 | ) 9 | 10 | // 
MapToStringList 多层Map转字符串数组 11 | func mapToStringList(params map[string]interface{}, parentNode string) (result []string) { 12 | for key := range params { 13 | nextParentNode := "" 14 | if len(parentNode) > 0 { 15 | nextParentNode = parentNode + "[" + key + "]" 16 | } else { 17 | nextParentNode = key 18 | } 19 | value := params[key] 20 | t := reflect.TypeOf(value) 21 | switch t.Kind() { 22 | case reflect.Map: 23 | tempResult := mapToStringList(value.(map[string]interface{}), nextParentNode) 24 | result = append(result, tempResult...) 25 | break 26 | case reflect.Slice: 27 | typeString := t.Elem().String() 28 | tmpVal := map[string]interface{}{} 29 | 30 | if typeString == "int" { 31 | for idx, subVal := range value.([]int) { 32 | tmpVal[fmt.Sprint(idx)] = subVal 33 | } 34 | } else if typeString == "string" { 35 | for idx, subVal := range value.([]string) { 36 | tmpVal[fmt.Sprint(idx)] = subVal 37 | } 38 | } else { 39 | for idx, subVal := range value.([]interface{}) { 40 | tmpVal[fmt.Sprint(idx)] = subVal 41 | } 42 | } 43 | tempResult := mapToStringList(tmpVal, nextParentNode) 44 | result = append(result, tempResult...) 
45 | break 46 | default: 47 | result = append(result, url.QueryEscape(nextParentNode)+"="+url.QueryEscape(fmt.Sprint(value))) 48 | } 49 | } 50 | return 51 | } 52 | 53 | //HttpBuildQuery 生成Query参数 54 | func HttpBuildQuery(params map[string]interface{}) (query string) { 55 | queryList := mapToStringList(params, "") 56 | query = strings.Join(queryList, "&") 57 | return 58 | } 59 | -------------------------------------------------------------------------------- /utils/url_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestHttpBuildQuery(t *testing.T) { 9 | params := map[string]interface{}{ 10 | "uid": 1, 11 | "name": "hts", 12 | } 13 | s := HttpBuildQuery(params) 14 | if s != "uid=1&name=hts" && s != "name=hts&uid=1" { 15 | t.Error("HttpBuildQuery failed") 16 | return 17 | } 18 | 19 | params = map[string]interface{}{ 20 | "a": []string{"b", "c"}, 21 | "map": map[string]interface{}{ 22 | "a1": "111", 23 | "b2": 2.3, 24 | "b3": []int{1, 4}, 25 | }, 26 | } 27 | // 这个方法对参数的构造太随机了 28 | s = HttpBuildQuery(params) 29 | if s != "a%5B0%5D=b&a%5B1%5D=c&map%5Ba1%5D=111&map%5Bb2%5D=2.3&map%5Bb3%5D%5B0%5D=1&map%5Bb3%5D%5B1%5D=4" && 30 | s != "a%5B1%5D=c&a%5B0%5D=b&map%5Bb2%5D=2.3&map%5Bb3%5D%5B1%5D=4&map%5Bb3%5D%5B0%5D=1&map%5Ba1%5D=111" && 31 | s != "a%5B0%5D=b&a%5B1%5D=c&map%5Bb3%5D%5B0%5D=1&map%5Bb3%5D%5B1%5D=4&map%5Ba1%5D=111&map%5Bb2%5D=2.3" && 32 | s != "map%5Ba1%5D=111&map%5Bb2%5D=2.3&map%5Bb3%5D%5B0%5D=1&map%5Bb3%5D%5B1%5D=4&a%5B0%5D=b&a%5B1%5D=c" && 33 | s != "a%5B0%5D=b&a%5B1%5D=c&map%5Bb2%5D=2.3&map%5Bb3%5D%5B0%5D=1&map%5Bb3%5D%5B1%5D=4&map%5Ba1%5D=111" && 34 | s != "a%5B1%5D=c&a%5B0%5D=b&map%5Ba1%5D=111&map%5Bb2%5D=2.3&map%5Bb3%5D%5B0%5D=1&map%5Bb3%5D%5B1%5D=4" { 35 | fmt.Println("HttpBuildQuery", s) 36 | //t.Error("HttpBuildQuery failed") 37 | return 38 | } 39 | } 40 | -------------------------------------------------------------------------------- 
/utils/uuid.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import "github.com/google/uuid" 4 | 5 | // GenUUID returns the string form of a UUID produced by uuid.NewRandom. 6 | // NOTE(review): the error from uuid.NewRandom is discarded; on failure the value returned alongside it is stringified. 7 | func GenUUID() string { 8 | id, _ := uuid.NewRandom() 9 | return id.String() 10 | } 11 | -------------------------------------------------------------------------------- /utils/uuid_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import "testing" 4 | 5 | // TestGenUUID only checks that a non-empty string is produced. 6 | func TestGenUUID(t *testing.T) { 7 | if got := GenUUID(); len(got) == 0 { 8 | t.Error("length of uuid is equal 0") 9 | } 10 | } 11 | --------------------------------------------------------------------------------