├── .gitignore ├── LICENSE ├── README.md ├── cache ├── cache.go ├── cache_test.go ├── codec.go ├── counter │ ├── counter.go │ ├── redis_counter.go │ ├── redis_counter_persist.go │ ├── redis_counter_scripts.go │ └── redis_counter_test.go ├── list │ ├── list.go │ ├── list_add.lua │ ├── list_del.lua │ ├── list_lua.go │ └── list_test.go ├── redis.go ├── redis_conf.go └── redis_test.go ├── common ├── aes.go ├── aes_test.go ├── concurrent_map.go ├── concurrent_map_test.go ├── config.go ├── config_test.go ├── config_yaml.go ├── constant.go ├── copy_on_write_map.go ├── io_util.go ├── io_util_test.go ├── json_util.go ├── json_util_test.go ├── linked_map.go ├── linked_map_test.go ├── logger.go ├── logger_std.go ├── logger_test.go ├── logger_zap.go ├── model.go ├── perm │ └── perm.go ├── pkg_test.go ├── service.go ├── service_test.go ├── testdata │ └── 1.txt ├── time.go ├── time_test.go ├── util.go ├── util_test.go ├── validates.go ├── validates_service.go └── validates_test.go ├── go.mod ├── go.sum ├── http ├── controller.go ├── controller_test.go ├── cookiejar.go ├── cookiejar_test.go ├── filesystem.go ├── http.go ├── http_config.go ├── http_context.go ├── http_util.go ├── http_util_test.go └── middleware.go ├── inject ├── Readme.md ├── inject.go ├── inject_helper.go ├── inject_helper_test.go ├── inject_test.go ├── module.go └── testdata │ └── conf │ └── common.yaml └── orm ├── comm.go ├── comm_test.go ├── config.go ├── mysql.go ├── mysql_test.go ├── op.go ├── orm_funcs.go ├── orm_meta.go ├── orm_meta_test.go ├── orm_test.go ├── service_shard.go ├── service_shard_test.go ├── service_simple.go ├── shard.go ├── shard_test.go └── testdata ├── setup.sql ├── shard.yaml └── teardown.sql /.gitignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | .DS_Store 3 | *.iml 4 | *.ipr 5 | *.iws 6 | .idea 7 | *.log 8 | Thumbs.db 9 | **/Godeps/_workspace 10 | .vscode 11 | *.rdb 12 | debug.test 13 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go 2 | Go Toolkit 3 | 4 | ## Inject example 5 | ```go 6 | import ( 7 | "flag" 8 | "fmt" 9 | "os" 10 | 11 | c "github.com/d0ngw/go/common" 12 | "github.com/d0ngw/go/inject" 13 | "github.com/d0ngw/go/orm" 14 | ) 15 | 16 | // WorkerApp worker服务,适用于后台执行的任务 17 | type WorkerApp struct { 18 | config c.Configurer 19 | modules []*inject.Module 20 | env string 21 | } 22 | 23 | // NewWorkerApp new worker app 24 | func NewWorkerApp(config c.Configurer, modules []*inject.Module) *WorkerApp { 25 | ret := &WorkerApp{ 26 | config: config, 27 | modules: modules, 28 | } 29 | return ret 30 | } 31 | 32 | // Usage flag.Usage 33 | func (p *WorkerApp) Usage() { 34 | fmt.Fprintf(os.Stderr, "\n") 35 | fmt.Fprintf(os.Stderr, "usage:%s args...\n", os.Args[0]) 36 | fmt.Fprintf(os.Stderr, "\n") 37 | flag.PrintDefaults() 38 | os.Exit(1) 39 | } 40 | 41 | // FlagInit flag的初始化 42 | func (p *WorkerApp) FlagInit() { 43 | flag.Usage = p.Usage 44 | flag.StringVar(&p.env, "env", "dev", "the environment") 45 | } 46 | 47 | // Run 运行,直到退出 48 | func (p *WorkerApp) Run() { 49 | flag.Parse() 50 | 51 | injector, err := inject.SetupInjector(p.config, "", p.env, p.modules...) 52 | if err != nil { 53 | c.Errorf("init injector fail,err:%s", err) 54 | p.Usage() 55 | } 56 | 57 | if err := injector.Start(nil); err != nil { 58 | c.Errorf("Start service fail,err:%s", err) 59 | os.Exit(1) 60 | } 61 | 62 | shutdownHook := c.NewShutdownhook(defaultShutdownHooks...)
63 | shutdownHook.AddHook(func() { 64 | injector.Stop(nil) 65 | }) 66 | 67 | shutdownHook.WaitShutdown() 68 | } 69 | 70 | var app *WorkerApp 71 | 72 | func init(){ 73 | var module = inject.NewModule() 74 | module.Bind(orm.NewSimpleShardDBService(orm.NewMySQLDBPool)) 75 | 76 | app = NewWorkerApp(xxxconfig,[]*inject.Module{module}) 77 | app.FlagInit() 78 | } 79 | 80 | func main(){ 81 | app.Run() 82 | } 83 | ``` 84 | -------------------------------------------------------------------------------- /cache/cache.go: -------------------------------------------------------------------------------- 1 | // Package cache provides cache-related services. 2 | package cache 3 | 4 | // Param describes a cache entry: its group, key and expiration. 5 | type Param interface { 6 | // Group returns the cache group id 7 | Group() string 8 | // Key returns the cache key 9 | Key() string 10 | // Expire returns the expiration time in seconds 11 | Expire() int 12 | } 13 | 14 | // ParamConf holds the cache group, key prefix and expiration (seconds) shared by the keys built from it. 15 | type ParamConf struct { 16 | group string 17 | keyPrefix string 18 | expire int 19 | } 20 | 21 | // NewParamConf creates a ParamConf. 22 | func NewParamConf(group, keyPrefix string, expire int) *ParamConf { 23 | return &ParamConf{ 24 | group: group, 25 | keyPrefix: keyPrefix, 26 | expire: expire, 27 | } 28 | } 29 | 30 | // Group returns the cache group. 31 | func (p *ParamConf) Group() string { 32 | return p.group 33 | } 34 | 35 | // Expire returns the expiration in seconds. 36 | func (p *ParamConf) Expire() int { 37 | return p.expire 38 | } 39 | 40 | // KeyPrefix returns the key prefix. 41 | func (p *ParamConf) KeyPrefix() string { 42 | return p.keyPrefix 43 | } 44 | 45 | // NewWithExpire returns a copy of p whose expire is replaced; p is left unchanged. 46 | func (p *ParamConf) NewWithExpire(expire int) *ParamConf { 47 | var param = *p 48 | param.expire = expire 49 | return &param 50 | } 51 | 52 | // NewWithKeyPrefix returns a copy of p with keyPrefix appended to the existing prefix; p is left unchanged. 53 | func (p *ParamConf) NewWithKeyPrefix(keyPrefix string) *ParamConf { 54 | var param = *p 55 | param.keyPrefix = p.keyPrefix +
keyPrefix 56 | return ¶m 57 | } 58 | 59 | // NewParamKey create new ParamKey with key 60 | func (p *ParamConf) NewParamKey(key string) *ParamKey { 61 | return &ParamKey{ 62 | ParamConf: *p, 63 | key: p.keyPrefix + key, 64 | } 65 | } 66 | 67 | // NewParamKeyWithoutPrefix create new ParamKey without p.keyPrefix 68 | func (p *ParamConf) NewParamKeyWithoutPrefix(key string) *ParamKey { 69 | return &ParamKey{ 70 | ParamConf: *p, 71 | key: key, 72 | } 73 | } 74 | 75 | // ParamKey is the cache param with key 76 | type ParamKey struct { 77 | ParamConf 78 | key string 79 | } 80 | 81 | // Key implements Param.Key() 82 | func (p *ParamKey) Key() string { 83 | return p.key 84 | } 85 | 86 | // NewWithExpire new key with expire 87 | func (p *ParamKey) NewWithExpire(expire int) *ParamKey { 88 | var k = *p 89 | k.expire = expire 90 | return &k 91 | } 92 | -------------------------------------------------------------------------------- /cache/cache_test.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestEncDec(t *testing.T) { 10 | redisServer := &RedisServer{ 11 | ID: "test", 12 | Host: "127.0.0.1", 13 | Port: 6379, 14 | } 15 | bytes, err := MsgPackEncodeBytes(redisServer) 16 | assert.Nil(t, err) 17 | 18 | server := &RedisServer{} 19 | err = MsgPackDecodeBytes(bytes, server) 20 | assert.Nil(t, err) 21 | assert.Equal(t, *redisServer, *server) 22 | 23 | bytes, err = MsgPackEncodeBytes(nil) 24 | assert.Nil(t, err) 25 | 26 | var v *int 27 | err = MsgPackDecodeBytes(bytes, &v) 28 | assert.Nil(t, err) 29 | assert.Nil(t, v) 30 | 31 | servers := []*RedisServer{server, server} 32 | bytes, err = MsgPackEncodeBytes(servers) 33 | assert.Nil(t, err) 34 | 35 | var v1 *int 36 | err = MsgPackDecodeBytes(nil, &v1) 37 | assert.NotNil(t, err) 38 | assert.Nil(t, v) 39 | } 40 | 
-------------------------------------------------------------------------------- /cache/codec.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | 7 | "github.com/ugorji/go/codec" 8 | ) 9 | // msgpackHandle is the msgpack handle shared by the encode/decode helpers in this file. 10 | var msgpackHandle = &codec.MsgpackHandle{} 11 | // init makes maps decoded without a concrete target type come back as map[string]interface{}. 12 | func init() { 13 | msgpackHandle.MapType = reflect.TypeOf(map[string]interface{}(nil)) 14 | } 15 | 16 | // MsgPackEncodeBytes encodes data to bytes using msgpack. 17 | func MsgPackEncodeBytes(data interface{}) (bytes []byte, err error) { 18 | enc := codec.NewEncoderBytes(&bytes, msgpackHandle) 19 | err = enc.Encode(data) 20 | return 21 | } 22 | 23 | // MsgPackDecodeBytes decodes msgpack-encoded bytes into dest; empty input is rejected with an error before any decoding. 24 | func MsgPackDecodeBytes(bytes []byte, dest interface{}) (err error) { 25 | if len(bytes) == 0 { 26 | return errors.New("nil bytes to decode") 27 | } 28 | dec := codec.NewDecoderBytes(bytes, msgpackHandle) 29 | err = dec.Decode(dest) 30 | return 31 | } 32 | -------------------------------------------------------------------------------- /cache/counter/counter.go: -------------------------------------------------------------------------------- 1 | // Package counter supplies counter services. 2 | package counter 3 | 4 | // Fields maps counter field names to their int64 values. 5 | type Fields map[string]int64 6 | 7 | // Counter is the counter service. 8 | type Counter interface { 9 | // GetName returns the counter name. 10 | GetName() string 11 | // Incr increases the fields of counterID by the deltas in fieldAndDelta. 12 | Incr(counterID string, fieldAndDelta Fields) error 13 | 14 | // Get returns the fields of counterID. 15 | Get(counterID string) (fields Fields, err error) 16 | 17 | // Del deletes the counter whose id is `counterID`. 18 | Del(counterID string) error 19 | } 20 | 21 | // Persist stores counter fields in persistent storage. 22 | type Persist interface { 23 | // Load loads the fields of counterID from persistent storage. 24 | Load(counterID string) (fields Fields, err error) 25 | 26 | // Del deletes the
counter whose id is `counterID`. 27 | Del(counterID string) (deleted bool, err error) 28 | 29 | // Store saves the values of fields for counterID. 30 | Store(counterID string, fields Fields) error 31 | } 32 | -------------------------------------------------------------------------------- /cache/counter/redis_counter_scripts.go: -------------------------------------------------------------------------------- 1 | package counter 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io/ioutil" 7 | 8 | "github.com/gomodule/redigo/redis" 9 | ) 10 | 11 | // Scripts defines the Lua scripts of the persist counter. Each exported field holds a file path, or inline Lua source when loadFromString is set. 12 | type Scripts struct { 13 | Update string `yaml:"update_lua"` 14 | SetSync string `yaml:"sync_set_lua"` 15 | Evict string `yaml:"evict_lua"` 16 | HgetAll string `yaml:"hgetall_lua"` 17 | Del string `yaml:"del_lua"` 18 | loadFromString bool 19 | update *redis.Script 20 | setSync *redis.Script 21 | evict *redis.Script 22 | hgetAll *redis.Script 23 | del *redis.Script 24 | } 25 | 26 | // NewScripts creates Scripts; loadFromString selects whether Init treats the script fields as inline Lua source (true) or as file paths (false). 27 | func NewScripts(loadFromString bool) *Scripts { 28 | return &Scripts{loadFromString: loadFromString} 29 | } 30 | 31 | // LUAFALSE and LUATRUE are integer booleans (0 and 1); presumably for values exchanged with the Lua scripts - TODO confirm. 32 | const ( 33 | LUAFALSE int = 0 34 | LUATRUE int = 1 35 | ) 36 | 37 | // Init loads all five scripts (update, sync set, evict, hgetall, del) from strings or files according to loadFromString and returns the first error. 38 | func (p *Scripts) Init() (err error) { 39 | scripts := []struct { 40 | path string 41 | dest **redis.Script 42 | }{ 43 | {p.Update, &p.update}, 44 | {p.SetSync, &p.setSync}, 45 | {p.Evict, &p.evict}, 46 | {p.HgetAll, &p.hgetAll}, 47 | {p.Del, &p.del}, 48 | } 49 | 50 | for _, v := range scripts { 51 | if p.loadFromString { 52 | if err := p.loadScriptFromString(v.path, v.dest); err != nil { 53 | return err 54 | } 55 | } else { 56 | if err := p.loadScriptFromFile(v.path, v.dest); err != nil { 57 | return err 58 | } 59 | } 60 | } 61 | return nil 62 | } 63 | // loadScriptFromFile reads the Lua file at luaPath into *dest as a redis.Script with one key; an empty file is an error. 64 | func (p *Scripts) loadScriptFromFile(luaPath string, dest **redis.Script) error { 65 | data, err := ioutil.ReadFile(luaPath) 66 | if err != nil { 67 | return err 68 | } 69 | if len(data) == 0 { 70 | return
fmt.Errorf("empty lua script in %s", luaPath) 71 | } 72 | script := redis.NewScript(1, string(data)) 73 | *dest = script 74 | return nil 75 | } 76 | // loadScriptFromString stores the inline Lua source luaData in *dest as a redis.Script with one key; empty source is an error. 77 | func (p *Scripts) loadScriptFromString(luaData string, dest **redis.Script) error { 78 | if len(luaData) == 0 { 79 | return errors.New("empty lua script") 80 | } 81 | script := redis.NewScript(1, luaData) 82 | *dest = script 83 | return nil 84 | } 85 | -------------------------------------------------------------------------------- /cache/counter/redis_counter_test.go: -------------------------------------------------------------------------------- 1 | package counter 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | "os/user" 9 | "path" 10 | 11 | "github.com/d0ngw/go/cache" 12 | c "github.com/d0ngw/go/common" 13 | "github.com/d0ngw/go/orm" 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | type V struct { 18 | BaseEntity 19 | Time int64 `column:"ut"` 20 | } 21 | 22 | func (p *V) TableName() string { 23 | return "v" 24 | } 25 | 26 | func (p *V) Entity(counterID string, fields Fields) (orm.Entity, error) { 27 | e, err := p.BaseEntity.ToBaseEntity(counterID, fields) 28 | if err != nil { 29 | return nil, err 30 | } 31 | return &V{ 32 | BaseEntity: *e, 33 | Time: c.UnixMills(time.Now()), 34 | }, nil 35 | } 36 | func (p *V) ZeroFields() Fields { 37 | return Fields{"a": int64(1), "b": int64(0)} 38 | } 39 | 40 | var r *cache.RedisClient 41 | var dbService orm.ShardDBService 42 | 43 | func init() { 44 | var err error 45 | config := &orm.DBShardConfig{ 46 | Shards: map[string]*orm.DBConfig{ 47 | "test": { 48 | User: "root", 49 | Pass: "123456", 50 | URL: "127.0.0.1:3306", 51 | Schema: "test", 52 | MaxConn: 100, 53 | MaxIdle: 10}, 54 | }, 55 | Default: "test", 56 | } 57 | 58 | shardDBService := orm.NewSimpleShardDBService(orm.NewMySQLDBPool) 59 | shardDBService.DBShardConfig = config 60 | dbService = shardDBService 61 | 62 | dbService.Init() 63 | 64 | redisServer := &cache.RedisServer{ 65 |
ID: "test", 66 | Host: "127.0.0.1", 67 | Port: 6379, 68 | } 69 | var redisConf = cache.RedisConf{ 70 | Servers: []*cache.RedisServer{redisServer}, 71 | Groups: map[string][]string{"test": {"test"}, "test.sync": {"test"}}, 72 | } 73 | 74 | err = redisConf.Parse() 75 | if err != nil { 76 | panic(err) 77 | } 78 | r = cache.NewRedisClientWithConf(&redisConf) 79 | orm.AddMeta(&V{}) 80 | } 81 | 82 | type persistMock struct { 83 | } 84 | 85 | func (p *persistMock) Load(counterID string) (fields Fields, err error) { 86 | fmt.Println("Load:" + counterID) 87 | return Fields{"a": int64(1), "b": int64(0)}, nil 88 | } 89 | 90 | func (p *persistMock) Del(counterID string) (deleted bool, err error) { 91 | fmt.Println("Del:" + counterID) 92 | return true, nil 93 | } 94 | 95 | func (p *persistMock) Store(counterID string, fieldAndDelta Fields) error { 96 | fmt.Printf("Store %s,v:%v:", counterID, fieldAndDelta) 97 | return nil 98 | } 99 | 100 | func TestPersistCounter(t *testing.T) { 101 | user, err := user.Current() 102 | assert.Nil(t, err) 103 | var cacheConf = cache.NewParamConf("test", "c_", 0) 104 | scripts := &Scripts{ 105 | Update: path.Join(user.HomeDir, "temp", "lua", "counter_update.lua"), 106 | SetSync: path.Join(user.HomeDir, "temp", "lua", "counter_update_sync.lua"), 107 | Evict: path.Join(user.HomeDir, "temp", "lua", "counter_evict.lua"), 108 | HgetAll: path.Join(user.HomeDir, "temp", "lua", "counter_getall.lua"), 109 | Del: path.Join(user.HomeDir, "temp", "lua", "counter_del.lua"), 110 | } 111 | 112 | //persist := &persistMock{} 113 | persist, err := NewDBPersist(func() orm.ShardDBService { return dbService }, &V{}) 114 | assert.Nil(t, err) 115 | 116 | counter := NewPersistRedisCounter("test", func() *cache.RedisClient { return r }, scripts, persist, cacheConf, 10) 117 | 118 | err = counter.scripts.Init() 119 | assert.Nil(t, err) 120 | 121 | err = counter.Init() 122 | assert.Nil(t, err) 123 | 124 | testCounter(t, counter) 125 | 126 | //counter.persist, err = 127 | 
//assert.Nil(t, err) 128 | testCounter(t, counter) 129 | 130 | redisCounterSync, err := NewRedisCounterSync(counter, 10, 1, 1, 1) 131 | assert.Nil(t, err) 132 | err = redisCounterSync.ScanAll() 133 | assert.Nil(t, err) 134 | 135 | syncSchedule, err := NewRedisCounterSyncSchedule("test", []*RedisCounterSync{redisCounterSync}, 5) 136 | assert.Nil(t, err) 137 | assert.Nil(t, syncSchedule.Init()) 138 | assert.True(t, syncSchedule.Start()) 139 | time.Sleep(time.Duration(5*syncSchedule.scanIntervalSecond) * time.Second) 140 | assert.True(t, syncSchedule.Stop()) 141 | } 142 | 143 | func testCounter(t *testing.T, counter *PersistRedisCounter) { 144 | var err error 145 | id := "1" 146 | err = counter.Del(id) 147 | assert.Nil(t, err) 148 | 149 | fields, err := counter.Get(id) 150 | assert.Nil(t, err) 151 | assert.EqualValues(t, 1, fields["a"]) 152 | assert.EqualValues(t, 0, fields["b"]) 153 | 154 | err = counter.Incr(id, Fields{"a": 1, "b": 2}) 155 | assert.Nil(t, err) 156 | 157 | fields, err = counter.Get(id) 158 | assert.Nil(t, err) 159 | assert.EqualValues(t, 2, fields["a"]) 160 | assert.EqualValues(t, 2, fields["b"]) 161 | 162 | err = counter.persist.Store(id, fields) 163 | assert.Nil(t, err) 164 | 165 | fields, err = counter.Get(id) 166 | assert.Nil(t, err) 167 | assert.EqualValues(t, 2, fields["a"]) 168 | assert.EqualValues(t, 2, fields["b"]) 169 | 170 | fields, err = counter.Get(id) 171 | assert.Nil(t, err) 172 | assert.EqualValues(t, 2, fields["a"]) 173 | assert.EqualValues(t, 2, fields["b"]) 174 | } 175 | 176 | func TestNoPersistCounter(t *testing.T) { 177 | var cacheConf = cache.NewParamConf("test", "np_c_", 30) 178 | counter, err := NewNoPersistRedisCounter("test", r, cacheConf) 179 | assert.Nil(t, err) 180 | 181 | id := "1" 182 | var fieldAndDelta = Fields{"a": 1, "b": 2} 183 | err = counter.Incr(id, fieldAndDelta) 184 | assert.Nil(t, err) 185 | 186 | reply, err := counter.Get(id) 187 | assert.Nil(t, err) 188 | assert.Equal(t, fieldAndDelta, reply) 189 | 190 | 
err = counter.DelFields(id, "a") 191 | assert.Nil(t, err) 192 | 193 | reply, err = counter.Get(id) 194 | assert.Nil(t, err) 195 | assert.Equal(t, Fields{"b": 2}, reply) 196 | 197 | err = counter.Del(id) 198 | assert.Nil(t, err) 199 | 200 | reply, err = counter.Get(id) 201 | assert.Nil(t, err) 202 | assert.Nil(t, reply) 203 | } 204 | -------------------------------------------------------------------------------- /cache/list/list_add.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | init or update list cache (sorted set) 3 | params: 4 | KEYS[1]=list_key; ARGV=max_count,must_exist_key,expire_seconds,score1,member1,score2,member2.... 5 | return: 6 | {exist,updated}; NOTE(review): ZREMRANGEBYRANK max_count -1 removes the highest-scored members past max_count (keeps the lowest scores) - confirm that is the intended retention 7 | -- ]] 8 | local list_key = KEYS[1] 9 | local max_count = tonumber(ARGV[1]) 10 | local must_exist_key = tonumber(ARGV[2]) 11 | local expire_seconds = tonumber(ARGV[3]) 12 | -- score/member pairs start at ARGV[4]: need at least one pair, and pairs must be complete (#ARGV odd) 13 | if #ARGV < 5 or (#ARGV - 5) % 2 ~= 0 then 14 | return redis.error_reply("Wrong score and member args numbers") 15 | end 16 | -- drop any TTL now; it is re-applied below only when expire_seconds > 0 17 | redis.call("PERSIST", list_key) 18 | local exist = redis.call("EXISTS", list_key) 19 | local updated = 0 20 | local need_update = false 21 | 22 | if exist == 1 then 23 | need_update = true 24 | else 25 | if must_exist_key == 1 then 26 | need_update = false 27 | else 28 | need_update = true 29 | end 30 | end 31 | 32 | if need_update then 33 | local score_members = {} 34 | for i = 4, #ARGV, 2 do 35 | score_members[#score_members + 1] = ARGV[i] 36 | score_members[#score_members + 1] = ARGV[i + 1] 37 | end 38 | redis.call("ZADD", list_key, unpack(score_members)) 39 | redis.call("ZREMRANGEBYRANK", list_key,max_count,-1 ) 40 | updated = 1 41 | exist = 1 42 | end 43 | 44 | if expire_seconds > 0 then 45 | redis.call("EXPIRE", list_key, expire_seconds) 46 | end 47 | 48 | return { exist, updated } 49 | -------------------------------------------------------------------------------- /cache/list/list_del.lua: -------------------------------------------------------------------------------- 1 |
--[[ 2 | delete members from list cache (sorted set) 3 | params: 4 | KEYS[1]=list_key; ARGV=expire_seconds,member1,member2.... 5 | return: 6 | {deleted,last_member,length}: removed count, highest-scored remaining member (0 if none), remaining size; all 0 when list_key does not exist 7 | -- ]] 8 | local list_key = KEYS[1] 9 | local expire_seconds = tonumber(ARGV[1]) 10 | 11 | if #ARGV < 2 then 12 | return redis.error_reply("Wrong args numbers") 13 | end 14 | -- drop any TTL now; it is re-applied below only when expire_seconds > 0 15 | redis.call("PERSIST", list_key) 16 | local exist = redis.call("EXISTS", list_key) 17 | local deleted = 0 18 | local last_member = 0 19 | local length = 0 20 | 21 | if exist == 1 then 22 | local members = {} 23 | for i = 2, #ARGV, 1 do 24 | members[#members + 1] = ARGV[i] 25 | end 26 | deleted=redis.call("ZREM", list_key, unpack(members)) 27 | local lastm = redis.call("ZRANGE",list_key,-1,-1) 28 | if #lastm > 0 then 29 | last_member = lastm[1] 30 | end 31 | length = redis.call("ZCARD",list_key) 32 | end 33 | 34 | if expire_seconds > 0 then 35 | redis.call("EXPIRE", list_key, expire_seconds) 36 | end 37 | 38 | return { deleted, last_member,length} 39 | -------------------------------------------------------------------------------- /cache/list/list_lua.go: -------------------------------------------------------------------------------- 1 | package list 2 | 3 | import "github.com/gomodule/redigo/redis" 4 | // addLua is the same Lua as list_add.lua, without its header comment. 5 | var addLua = ` 6 | local list_key = KEYS[1] 7 | local max_count = tonumber(ARGV[1]) 8 | local must_exist_key = tonumber(ARGV[2]) 9 | local expire_seconds = tonumber(ARGV[3]) 10 | 11 | if #ARGV < 5 or (#ARGV - 5) % 2 ~= 0 then 12 | return redis.error_reply("Wrong score and member args numbers") 13 | end 14 | 15 | redis.call("PERSIST", list_key) 16 | local exist = redis.call("EXISTS", list_key) 17 | local updated = 0 18 | local need_update = false 19 | 20 | if exist == 1 then 21 | need_update = true 22 | else 23 | if must_exist_key == 1 then 24 | need_update = false 25 | else 26 | need_update = true 27 | end 28 | end 29 | 30 | if need_update then 31 | local score_members = {} 32 | for i = 4, #ARGV, 2 do 33 | score_members[#score_members + 1] =
ARGV[i] 34 | score_members[#score_members + 1] = ARGV[i + 1] 35 | end 36 | redis.call("ZADD", list_key, unpack(score_members)) 37 | redis.call("ZREMRANGEBYRANK", list_key,max_count,-1 ) 38 | updated = 1 39 | exist = 1 40 | end 41 | 42 | if expire_seconds > 0 then 43 | redis.call("EXPIRE", list_key, expire_seconds) 44 | end 45 | 46 | return { exist, updated } 47 | ` 48 | var addScript = redis.NewScript(1, addLua) 49 | 50 | var delLua = ` 51 | local list_key = KEYS[1] 52 | local expire_seconds = tonumber(ARGV[1]) 53 | 54 | if #ARGV < 2 then 55 | return redis.error_reply("Wrong args numbers") 56 | end 57 | 58 | redis.call("PERSIST", list_key) 59 | local exist = redis.call("EXISTS", list_key) 60 | local deleted = 0 61 | local last_member = 0 62 | local length = 0 63 | 64 | if exist == 1 then 65 | local members = {} 66 | for i = 2, #ARGV, 1 do 67 | members[#members + 1] = ARGV[i] 68 | end 69 | deleted=redis.call("ZREM", list_key, unpack(members)) 70 | local lastm = redis.call("ZRANGE",list_key,-1,-1) 71 | if #lastm > 0 then 72 | last_member = lastm[1] 73 | end 74 | length = redis.call("ZCARD",list_key) 75 | end 76 | 77 | if expire_seconds > 0 then 78 | redis.call("EXPIRE", list_key, expire_seconds) 79 | end 80 | 81 | return { deleted, last_member, length } 82 | ` 83 | 84 | var delScript = redis.NewScript(1, delLua) 85 | -------------------------------------------------------------------------------- /cache/list/list_test.go: -------------------------------------------------------------------------------- 1 | package list 2 | 3 | import ( 4 | "os/user" 5 | "path" 6 | "testing" 7 | 8 | "github.com/d0ngw/go/cache" 9 | "github.com/d0ngw/go/cache/counter" 10 | "github.com/d0ngw/go/orm" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | var r *cache.RedisClient 15 | var dbService orm.ShardDBService 16 | 17 | type testCounterEntity struct { 18 | CounterEntity 19 | } 20 | 21 | func (p *testCounterEntity) TableName() string { 22 | return "v" 23 | } 24 | 25 | func (p 
*testCounterEntity) Entity(counterID string, fields counter.Fields) (orm.Entity, error) { 26 | e, err := p.CounterEntity.ToBaseEntity(counterID, fields) 27 | if err != nil { 28 | return nil, err 29 | } 30 | return &testCounterEntity{ 31 | CounterEntity: *e, 32 | }, nil 33 | } 34 | 35 | type testListEntity struct { 36 | BaseEntity 37 | } 38 | 39 | func (p *testListEntity) TableName() string { 40 | return "list" 41 | } 42 | 43 | func init() { 44 | var err error 45 | config := &orm.DBShardConfig{ 46 | Shards: map[string]*orm.DBConfig{ 47 | "test": { 48 | User: "root", 49 | Pass: "123456", 50 | URL: "127.0.0.1:3306", 51 | Schema: "test", 52 | MaxConn: 100, 53 | MaxIdle: 10}, 54 | }, 55 | Default: "test", 56 | } 57 | 58 | shardDBService := orm.NewSimpleShardDBService(orm.NewMySQLDBPool) 59 | shardDBService.DBShardConfig = config 60 | 61 | dbService = shardDBService 62 | dbService.Init() 63 | 64 | redisServer := &cache.RedisServer{ 65 | ID: "test", 66 | Host: "127.0.0.1", 67 | Port: 6379, 68 | } 69 | var redisConf = cache.RedisConf{ 70 | Servers: []*cache.RedisServer{redisServer}, 71 | Groups: map[string][]string{"test": {"test"}}, 72 | } 73 | 74 | err = redisConf.Parse() 75 | if err != nil { 76 | panic(err) 77 | } 78 | r = cache.NewRedisClientWithConf(&redisConf) 79 | orm.AddMeta(&testCounterEntity{}) 80 | orm.AddMeta(&testListEntity{}) 81 | } 82 | 83 | func TestList(t *testing.T) { 84 | listCacheParm := cache.NewParamConf("test", "list_", 30) 85 | counterCacheParam := cache.NewParamConf("test", "list_c_", 30) 86 | user, err := user.Current() 87 | scripts := &counter.Scripts{ 88 | Update: path.Join(user.HomeDir, "temp", "lua", "counter_update.lua"), 89 | SetSync: path.Join(user.HomeDir, "temp", "lua", "counter_update_sync.lua"), 90 | Evict: path.Join(user.HomeDir, "temp", "lua", "counter_evict.lua"), 91 | HgetAll: path.Join(user.HomeDir, "temp", "lua", "counter_getall.lua"), 92 | Del: path.Join(user.HomeDir, "temp", "lua", "counter_del.lua"), 93 | } 94 | err = 
scripts.Init() 95 | assert.Nil(t, err) 96 | 97 | persist, err := counter.NewDBPersist(func() orm.ShardDBService { return dbService }, &testCounterEntity{}) 98 | assert.Nil(t, err) 99 | counter := counter.NewPersistRedisCounter("test", func() *cache.RedisClient { return r }, scripts, persist, counterCacheParam, 10) 100 | err = counter.Init() 101 | assert.Nil(t, err) 102 | 103 | // id as score 104 | listCache, err := NewCache(&testListEntity{}, func() orm.ShardDBService { return dbService }, func() *cache.RedisClient { return r }, listCacheParm, 500, false, counter) 105 | assert.Nil(t, err) 106 | 107 | for i := 1; i <= 100; i++ { 108 | toAdd := &testListEntity{BaseEntity: BaseEntity{OwnerID: "d0ngw", TargetID: int64(i)}} 109 | succ, err := listCache.Add(toAdd) 110 | assert.NoError(t, err) 111 | assert.True(t, succ) 112 | succ, err = listCache.Add(&testListEntity{BaseEntity: BaseEntity{OwnerID: "d0ngw", TargetID: int64(i)}}) 113 | assert.Error(t, err) 114 | assert.False(t, succ) 115 | } 116 | 117 | total, err := listCache.GetCount("d0ngw") 118 | assert.Nil(t, err) 119 | assert.EqualValues(t, 100, total) 120 | 121 | total, ids, err := listCache.LoadList("d0ngw", 1, 10, 0) 122 | assert.Nil(t, err) 123 | assert.EqualValues(t, 100, total) 124 | assert.EqualValues(t, 10, len(ids)) 125 | for i, v := range ids { 126 | assert.EqualValues(t, 100-i, v) 127 | } 128 | 129 | total, targetScores, err := listCache.LoadListWithScore("d0ngw", 1, 10, 0) 130 | assert.Nil(t, err) 131 | assert.EqualValues(t, 100, total) 132 | assert.EqualValues(t, 10, len(targetScores)) 133 | for i, v := range targetScores { 134 | assert.EqualValues(t, 100-i, v[0]) 135 | t.Logf("tareget id:%d score id:%d", v[0], v[1]) 136 | } 137 | 138 | for i := 1; i <= 100; i++ { 139 | succ, err := listCache.Del("d0ngw", int64(i)) 140 | assert.Nil(t, err) 141 | assert.True(t, succ) 142 | } 143 | 144 | total, err = listCache.GetCount("d0ngw") 145 | assert.Nil(t, err) 146 | assert.EqualValues(t, 0, total) 147 | 148 | // 
target id as score 149 | listCache, err = NewCache(&testListEntity{}, func() orm.ShardDBService { return dbService }, func() *cache.RedisClient { return r }, listCacheParm, 5, true, counter) 150 | assert.Nil(t, err) 151 | 152 | for i := 1; i <= 100; i++ { 153 | toAdd := &testListEntity{BaseEntity: BaseEntity{OwnerID: "d0ngw-t", TargetID: int64(i)}} 154 | succ, err := listCache.Add(toAdd) 155 | assert.NoError(t, err) 156 | assert.True(t, succ) 157 | succ, err = listCache.Add(&testListEntity{BaseEntity: BaseEntity{OwnerID: "d0ngw-t", TargetID: int64(i)}}) 158 | assert.Error(t, err) 159 | assert.False(t, succ) 160 | } 161 | 162 | total, err = listCache.GetCount("d0ngw-t") 163 | assert.Nil(t, err) 164 | assert.EqualValues(t, 100, total) 165 | 166 | total, ids, err = listCache.LoadList("d0ngw-t", 1, 10, 0) 167 | assert.Nil(t, err) 168 | assert.EqualValues(t, 100, total) 169 | assert.EqualValues(t, 10, len(ids)) 170 | for i, v := range ids { 171 | assert.EqualValues(t, 100-i, v) 172 | } 173 | 174 | total, targetScores, err = listCache.LoadListWithScore("d0ngw-t", 1, 100, 0) 175 | assert.Nil(t, err) 176 | assert.EqualValues(t, 100, total) 177 | assert.EqualValues(t, 100, len(targetScores)) 178 | for i, v := range targetScores { 179 | assert.EqualValues(t, 100-i, v[0]) 180 | t.Logf("tareget id:%d score id:%d", v[0], v[1]) 181 | } 182 | 183 | for i := 1; i <= 100; i++ { 184 | succ, err := listCache.Del("d0ngw-t", int64(i)) 185 | assert.Nil(t, err) 186 | assert.True(t, succ) 187 | } 188 | 189 | total, err = listCache.GetCount("d0ngw-t") 190 | assert.Nil(t, err) 191 | assert.EqualValues(t, 0, total) 192 | } 193 | -------------------------------------------------------------------------------- /cache/redis_conf.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "fmt" 5 | "sort" 6 | "time" 7 | 8 | c "github.com/d0ngw/go/common" 9 | "github.com/gomodule/redigo/redis" 10 | ) 11 | 12 | // Redis连接池的默认参数 13 | 
const ( 14 | DefaultConnectTimout = 5 * 1000 15 | DefaultReadTimeout = 5 * 1000 16 | DefaultWriteTimeout = 5 * 1000 17 | DefaultMaxActive = 100 18 | DefaultMaxIdle = 2 19 | DefaultIdleTimeout = 60 * 1000 20 | ) 21 | 22 | // RedisConfigurer Redis配置器 23 | type RedisConfigurer interface { 24 | c.Configurer 25 | RedisConfig() *RedisConf 26 | } 27 | 28 | // RedisPoolConf Redis连接池配置 29 | type RedisPoolConf struct { 30 | ConnectTimeout int `yaml:"connect_timeout"` //连接超时时间,单位毫秒秒 31 | ReadTimeout int `yaml:"read_timeout"` //读取超时,单位毫秒 32 | WriteTimeout int `yaml:"write_timeout"` //写取超时,单位毫秒 33 | MaxIdle int `yaml:"max_idle"` //最大空闲连接 34 | MaxActive int `yaml:"max_active"` //最大活跃连接,0表示不限制 35 | IdleTimeout int `yaml:"idel_teimout"` //空闲连接的超时时间,单位毫秒 36 | } 37 | 38 | var defaultPool = &RedisPoolConf{ 39 | ConnectTimeout: DefaultConnectTimout, 40 | ReadTimeout: DefaultReadTimeout, 41 | WriteTimeout: DefaultWriteTimeout, 42 | MaxActive: DefaultMaxActive, 43 | MaxIdle: DefaultMaxIdle, 44 | IdleTimeout: DefaultIdleTimeout, 45 | } 46 | 47 | // RedisServer Redis实例的配置 48 | type RedisServer struct { 49 | ID string `yaml:"id"` //Redis实例的id 50 | Host string `yaml:"host"` //Redis主机地址 51 | Port int `yaml:"port"` //Redis的端口 52 | Auth string `yaml:"auth"` //Redis认证密码 53 | pool *redis.Pool //Redis实例的连接池 54 | } 55 | 56 | // initPool 使用指定的参数初始化pool 57 | func (p *RedisServer) initPool(poolConf *RedisPoolConf) error { 58 | if p.pool != nil { 59 | return fmt.Errorf("server %s already inited", p.ID) 60 | } 61 | options := []redis.DialOption{} 62 | options = append(options, redis.DialConnectTimeout(time.Duration(poolConf.ConnectTimeout)*time.Millisecond)) 63 | options = append(options, redis.DialReadTimeout(time.Duration(poolConf.ReadTimeout)*time.Millisecond)) 64 | options = append(options, redis.DialWriteTimeout(time.Duration(poolConf.WriteTimeout)*time.Millisecond)) 65 | if p.Auth != "" { 66 | options = append(options, redis.DialPassword(p.Auth)) 67 | } 68 | 69 | var addr = fmt.Sprintf("%s:%d", 
p.Host, p.Port) 70 | 71 | pool := &redis.Pool{ 72 | Dial: func() (redis.Conn, error) { 73 | return redis.Dial("tcp", addr, options...) 74 | }, 75 | MaxActive: poolConf.MaxActive, 76 | MaxIdle: poolConf.MaxIdle, 77 | IdleTimeout: time.Duration(poolConf.IdleTimeout) * time.Millisecond, 78 | Wait: true, 79 | } 80 | p.pool = pool 81 | return nil 82 | } 83 | 84 | // GetConn acquire redis conn 85 | func (p *RedisServer) GetConn() (redis.Conn, error) { 86 | if p.pool == nil { 87 | return nil, fmt.Errorf("no pool") 88 | } 89 | return p.pool.Get(), nil 90 | } 91 | 92 | // RedisConf redis config 93 | type RedisConf struct { 94 | Servers []*RedisServer `yaml:"servers"` //实例列表 95 | Groups map[string][]string `yaml:"groups"` //Redis组定义,key为组ID;value为Server的id列表 96 | Pool *RedisPoolConf `yaml:"pool"` //默认的链接池配置 97 | GroupPool map[string]*RedisPoolConf `yaml:"groups_pools"` //Redis组的连接池配置 98 | groups map[string][]*RedisServer 99 | } 100 | 101 | // Parse implements Configurer interface 102 | func (p *RedisConf) Parse() error { 103 | if p == nil { 104 | c.Warnf("no redis conf") 105 | return nil 106 | } 107 | groups := map[string][]*RedisServer{} 108 | servers := map[string]*RedisServer{} 109 | 110 | //解析,并检查server的配置 111 | var dupChekc = map[string]struct{}{} 112 | for _, server := range p.Servers { 113 | if c.IsEmpty(server.ID, server.Host) { 114 | return fmt.Errorf("invalid redis server conf,id and host must not be emtpy") 115 | } 116 | if server.Port <= 0 { 117 | return fmt.Errorf("invalid redis server conf,port %d ", server.Port) 118 | } 119 | 120 | id := "id " + server.ID 121 | if _, ok := dupChekc[id]; ok { 122 | return fmt.Errorf("duplicate server:%s", id) 123 | } 124 | dupChekc[id] = struct{}{} 125 | 126 | addr := fmt.Sprintf("%s:%d", server.Host, server.Port) 127 | if _, ok := dupChekc[addr]; ok { 128 | return fmt.Errorf("duplicate server: %s", addr) 129 | } 130 | dupChekc[addr] = struct{}{} 131 | servers[server.ID] = server 132 | } 133 | 134 | //解析并检查group 135 | for 
groupID, groupServers := range p.Groups { 136 | if groupID == "" { 137 | return fmt.Errorf("invalid redis group id") 138 | } 139 | if len(servers) == 0 { 140 | return fmt.Errorf("redis group id %s has no servers", groupID) 141 | } 142 | dupChekc = map[string]struct{}{} 143 | for _, serverID := range groupServers { 144 | if _, ok := dupChekc[serverID]; ok { 145 | return fmt.Errorf("duplicate server id %s in group %s", serverID, groupID) 146 | } 147 | dupChekc[serverID] = struct{}{} 148 | } 149 | 150 | poolConf := p.GroupPool[groupID] 151 | if poolConf == nil { 152 | poolConf = p.Pool 153 | } 154 | if poolConf == nil { 155 | poolConf = defaultPool 156 | } 157 | 158 | //对redis实例进行排序 159 | sort.Sort(sort.StringSlice(groupServers)) 160 | redisServers := make([]*RedisServer, 0, len(groupServers)) 161 | for _, serverID := range groupServers { 162 | server := servers[serverID] 163 | if server == nil { 164 | return fmt.Errorf("can't find server id %s", serverID) 165 | } 166 | groupServer := *server 167 | if err := groupServer.initPool(poolConf); err != nil { 168 | return err 169 | } 170 | redisServers = append(redisServers, &groupServer) 171 | } 172 | groups[groupID] = redisServers 173 | } 174 | p.groups = groups 175 | return nil 176 | } 177 | 178 | // RedisConfig implements RedisConfigurer 179 | func (p *RedisConf) RedisConfig() *RedisConf { 180 | return p 181 | } 182 | -------------------------------------------------------------------------------- /cache/redis_test.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strconv" 7 | "testing" 8 | 9 | c "github.com/d0ngw/go/common" 10 | "github.com/gomodule/redigo/redis" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | var ( 15 | redisServer = &RedisServer{ 16 | ID: "test", 17 | Host: "127.0.0.1", 18 | Port: 6379, 19 | } 20 | r *RedisClient 21 | ) 22 | 23 | func init() { 24 | err := redisServer.initPool(defaultPool) 25 | 
if err != nil { 26 | panic(err) 27 | } 28 | var groups = map[string][]*RedisServer{"test": {redisServer}, "example": {redisServer}} 29 | r = NewRedisClient(groups) 30 | } 31 | 32 | func TestRedis(t *testing.T) { 33 | 34 | param := NewParamConf("test", "test_", 0) 35 | testSetGet(t, r, param) 36 | 37 | expireParam := NewParamConf("test", "test_ex_", 20) 38 | testSetGet(t, r, expireParam) 39 | 40 | confKey := expireParam.NewParamKey("server") 41 | err := r.SetObject(confKey, redisServer) 42 | assert.Nil(t, err) 43 | 44 | server := RedisServer{} 45 | ok, err := r.GetObject(confKey, &server) 46 | assert.Nil(t, err) 47 | assert.True(t, ok) 48 | assert.EqualValues(t, server.ID, redisServer.ID) 49 | 50 | ageParam := param.NewParamKey("age") 51 | exist, err := r.Exists(ageParam) 52 | assert.Nil(t, err) 53 | assert.True(t, exist) 54 | deleted, err := r.Del(ageParam) 55 | assert.Nil(t, err) 56 | assert.True(t, deleted) 57 | exist, err = r.Exists(ageParam) 58 | assert.Nil(t, err) 59 | assert.False(t, exist) 60 | 61 | ageNotExistParam := expireParam.NewParamKey("age_not_exist") 62 | expired, err := r.Expire(ageNotExistParam) 63 | assert.False(t, expired) 64 | deleted, err = r.Del(ageNotExistParam) 65 | assert.Nil(t, err) 66 | assert.False(t, deleted) 67 | } 68 | 69 | func testSetGet(t *testing.T, r *RedisClient, param *ParamConf) { 70 | ageParam := param.NewParamKey("age") 71 | assert.Nil(t, r.Set(ageParam, 10)) 72 | reply, ok, err := r.Get(ageParam) 73 | assert.Nil(t, err) 74 | assert.True(t, ok) 75 | i, _ := redis.Int(reply, err) 76 | assert.EqualValues(t, 10, i) 77 | 78 | v, ok, err := r.GetInt(ageParam) 79 | assert.Nil(t, err) 80 | assert.EqualValues(t, 10, v) 81 | assert.True(t, ok) 82 | 83 | ageNotExistParam := param.NewParamKey("age_not_exist") 84 | ageNotExistParam.expire = 0 85 | v, ok, err = r.GetInt(ageNotExistParam) 86 | assert.Nil(t, err) 87 | assert.False(t, ok) 88 | 89 | v64, ok, err := r.GetInt64(ageParam) 90 | assert.Nil(t, err) 91 | assert.EqualValues(t, 10, 
v64) 92 | assert.True(t, ok) 93 | 94 | f64, ok, err := r.GetFloat64(ageParam) 95 | assert.Nil(t, err) 96 | assert.EqualValues(t, 10, f64) 97 | assert.True(t, ok) 98 | 99 | s, ok, err := r.GetString(ageParam) 100 | assert.Nil(t, err) 101 | assert.EqualValues(t, "10", s) 102 | assert.True(t, ok) 103 | } 104 | 105 | type TestUser struct { 106 | Name string 107 | Age int 108 | } 109 | 110 | func TestPipeline(t *testing.T) { 111 | pipeline, err := NewPipeline(r) 112 | if err != nil { 113 | panic(err) 114 | } 115 | defer pipeline.Close() 116 | paramOdd := NewParamConf("test", "u_odd_", 0) 117 | paramEven := NewParamConf("test", "u_even_", 0) 118 | 119 | // set user 120 | for i := 0; i < 10; i++ { 121 | var paramConf *ParamConf 122 | if i%2 == 0 { 123 | paramConf = paramEven 124 | } else { 125 | paramConf = paramOdd 126 | } 127 | user := &TestUser{Name: "user" + strconv.Itoa(i)} 128 | param := paramConf.NewParamKey(strconv.Itoa(i)) 129 | bytes, _ := MsgPackEncodeBytes(user) 130 | pipeline.Send(param, SET, param.Key(), bytes) 131 | } 132 | // get user 133 | for i := 0; i < 11; i++ { 134 | var paramConf *ParamConf 135 | if i%2 == 0 { 136 | paramConf = paramEven 137 | } else { 138 | paramConf = paramOdd 139 | } 140 | param := paramConf.NewParamKey(strconv.Itoa(i)) 141 | pipeline.Send(param, GET, param.Key()) 142 | } 143 | 144 | // del user 145 | for i := 0; i < 10; i++ { 146 | var paramConf *ParamConf 147 | if i%2 == 0 { 148 | paramConf = paramEven 149 | } else { 150 | paramConf = paramOdd 151 | } 152 | param := paramConf.NewParamKey(strconv.Itoa(i)) 153 | pipeline.Send(param, DEL, param.Key()) 154 | } 155 | 156 | replies, err := pipeline.Receive() 157 | assert.Nil(t, err) 158 | assert.Equal(t, 31, len(replies)) 159 | 160 | setReplies := replies[0:10] 161 | for _, v := range setReplies { 162 | seted, err := redis.String(v.Reply, v.Err) 163 | assert.Nil(t, err) 164 | assert.Equal(t, ReplyOK, seted) 165 | } 166 | getReplies := replies[10:20] 167 | for i, v := range getReplies 
{ 168 | bytes, err := redis.Bytes(v.Reply, v.Err) 169 | assert.Nil(t, err) 170 | user := &TestUser{} 171 | MsgPackDecodeBytes(bytes, user) 172 | assert.Equal(t, "user"+strconv.Itoa(i), user.Name) 173 | } 174 | 175 | getFailReply := replies[20] 176 | assert.Nil(t, getFailReply.Reply) 177 | assert.Nil(t, getFailReply.Err) 178 | 179 | delReplies := replies[21:31] 180 | for _, v := range delReplies { 181 | deleted, err := redis.Bool(v.Reply, v.Err) 182 | assert.Nil(t, err) 183 | assert.True(t, deleted) 184 | } 185 | } 186 | 187 | func TestGetObjects(t *testing.T) { 188 | paramConf := NewParamConf("test", "u__", 0) 189 | keys := []string{} 190 | 191 | // set user 192 | for i := 0; i < 10; i++ { 193 | k := strconv.Itoa(i) 194 | user := &TestUser{Name: "user" + strconv.Itoa(i), Age: i} 195 | param := paramConf.NewParamKey(k) 196 | bytes, _ := MsgPackEncodeBytes(user) 197 | r.Set(param, bytes) 198 | keys = append(keys, k) 199 | } 200 | 201 | keys = append([]string{"-1"}, keys...) 202 | var users = make([]interface{}, len(keys)) 203 | c.FillSlice(len(keys), func(index int) { users[index] = &TestUser{} }) 204 | 205 | assert.NotNil(t, users[0]) 206 | assert.True(t, users[0] != nil) 207 | err := r.GetObjects(paramConf, keys, users, nil) 208 | assert.Nil(t, err) 209 | assert.Equal(t, 11, len(users)) 210 | assert.Nil(t, users[0]) 211 | fmt.Println(reflect.TypeOf(users[0])) 212 | assert.True(t, users[0] == nil) 213 | for i, v := range users[1:] { 214 | vu := v.(*TestUser) 215 | assert.Equal(t, i, vu.Age) 216 | assert.Equal(t, "user"+strconv.Itoa(i), vu.Name) 217 | } 218 | } 219 | 220 | func TestIncr(t *testing.T) { 221 | paramConf := NewParamConf("test", "incu__", 10) 222 | 223 | key := paramConf.NewParamKey("expire") 224 | defer r.Del(key) 225 | // set user 226 | for i := 0; i < 10; i++ { 227 | val, err := r.IncrBy(key, 2) 228 | assert.NoError(t, err) 229 | assert.EqualValues(t, 2*(i+1), val) 230 | } 231 | 232 | ttl, err := r.TTL(key) 233 | assert.NoError(t, err) 234 | 
assert.True(t, ttl > 0) 235 | 236 | paramConf = NewParamConf("test", "incu__", 0) 237 | key = paramConf.NewParamKey("noexpire") 238 | defer r.Del(key) 239 | for i := 0; i < 10; i++ { 240 | val, err := r.IncrBy(key, 2) 241 | assert.NoError(t, err) 242 | assert.EqualValues(t, 2*(i+1), val) 243 | } 244 | 245 | ttl, err = r.TTL(key) 246 | assert.NoError(t, err) 247 | assert.Equal(t, -1, ttl) 248 | } 249 | -------------------------------------------------------------------------------- /common/aes.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "bytes" 5 | "crypto/aes" 6 | "crypto/cipher" 7 | "errors" 8 | "fmt" 9 | ) 10 | 11 | // PKCS5Padding pkcs5 padding 12 | func PKCS5Padding(ciphertext []byte, blockSize int) ([]byte, error) { 13 | padding := blockSize - len(ciphertext)%blockSize 14 | padtext := bytes.Repeat([]byte{byte(padding)}, padding) 15 | return append(ciphertext, padtext...), nil 16 | } 17 | 18 | // PKCS5UnPadding pkcs5 unpadding 19 | func PKCS5UnPadding(origData []byte) ([]byte, error) { 20 | length := len(origData) 21 | if length == 0 { 22 | return nil, errors.New("invalid length") 23 | } 24 | unpadding := int(origData[length-1]) 25 | if unpadding < 0 { 26 | return nil, errors.New("invalid unpadding") 27 | } 28 | end := length - unpadding 29 | if end < 0 || end > length { 30 | return nil, errors.New("invalid end padding") 31 | } 32 | return origData[:(length - unpadding)], nil 33 | } 34 | 35 | // AesEncrypt 对data用key加密,使用PKCS5 Padding 36 | func AesEncrypt(data, key []byte) (result []byte, err error) { 37 | return aesEncrypt(data, key, nil) 38 | } 39 | 40 | // AesEncryptWithIV 对data用key加密,使用PKCS5 Padding 41 | func AesEncryptWithIV(data, key, iv []byte) (result []byte, err error) { 42 | return aesEncrypt(data, key, iv) 43 | } 44 | 45 | // AesDecrypt 对data用key解密,使用PKCS5 Padding 46 | func AesDecrypt(data, key []byte) (result []byte, err error) { 47 | return aesDecrypt(data, key, nil) 48 
| } 49 | 50 | // AesDecryptWithIV 对data用key和iv解密加密,使用PKCS5 Padding 51 | func AesDecryptWithIV(data, key, iv []byte) (result []byte, err error) { 52 | return aesDecrypt(data, key, iv) 53 | } 54 | 55 | func aesEncrypt(data, key, iv []byte) (result []byte, err error) { 56 | defer func() { 57 | if reErr := recover(); reErr != nil { 58 | Errorf("AES Encrypt err:%s ", reErr) 59 | err = fmt.Errorf("AES Encrypt fail,%v", reErr) 60 | } 61 | }() 62 | block, err := aes.NewCipher(key) 63 | if err != nil { 64 | return nil, err 65 | } 66 | blockSize := block.BlockSize() 67 | data, err = PKCS5Padding(data, blockSize) 68 | if err != nil { 69 | return nil, err 70 | } 71 | if len(iv) == 0 { 72 | iv = key[:blockSize] 73 | } 74 | if len(iv) != block.BlockSize() { 75 | err = fmt.Errorf("invalid iv length %d", len(iv)) 76 | return 77 | } 78 | blockMode := cipher.NewCBCEncrypter(block, iv) 79 | crypted := make([]byte, len(data)) 80 | blockMode.CryptBlocks(crypted, data) 81 | return crypted, nil 82 | } 83 | 84 | func aesDecrypt(data, key, iv []byte) (result []byte, err error) { 85 | defer func() { 86 | if reErr := recover(); reErr != nil { 87 | Errorf("AES Decrypt err:%s", reErr) 88 | err = fmt.Errorf("AES Decrypt fail,%v", reErr) 89 | } 90 | }() 91 | 92 | block, err := aes.NewCipher(key) 93 | if err != nil { 94 | return nil, err 95 | } 96 | blockSize := block.BlockSize() 97 | if len(iv) == 0 { 98 | iv = key[:blockSize] 99 | } 100 | if len(iv) != block.BlockSize() { 101 | err = fmt.Errorf("invalid iv length %d", len(iv)) 102 | return 103 | } 104 | blockMode := cipher.NewCBCDecrypter(block, iv) 105 | origData := make([]byte, len(data)) 106 | blockMode.CryptBlocks(origData, data) 107 | return PKCS5UnPadding(origData) 108 | } 109 | -------------------------------------------------------------------------------- /common/aes_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "testing" 5 | 6 | 
"github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestAES(t *testing.T) { 10 | s := "abc123" 11 | key := []byte("123456789abcdefghijklnm"[0:16]) 12 | 13 | enc, err := AesEncrypt(nil, key) 14 | assert.Nil(t, err, "error:", err) 15 | dec, err := AesDecrypt(enc, key) 16 | assert.Nil(t, err, "error:", err) 17 | 18 | enc, err = AesEncrypt([]byte(s), key) 19 | assert.Nil(t, err, "error:", err) 20 | dec, err = AesDecrypt(enc, key) 21 | assert.EqualValues(t, []byte(s), dec) 22 | 23 | iv := []byte("1234567890123456") 24 | enc, err = AesEncryptWithIV([]byte(s), key, iv) 25 | assert.Nil(t, err, "error:", err) 26 | dec, err = AesDecryptWithIV(enc, key, iv) 27 | assert.Nil(t, err, "error:", err) 28 | assert.EqualValues(t, []byte(s), dec) 29 | } 30 | -------------------------------------------------------------------------------- /common/concurrent_map.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "fmt" 5 | "hash/fnv" 6 | "sync" 7 | ) 8 | 9 | const ( 10 | defaultCurrentMapShardCount = 32 11 | ) 12 | 13 | type mCurrentMapShared struct { 14 | items map[interface{}]interface{} 15 | sync.RWMutex 16 | } 17 | 18 | // CurrentMap A "thread" safe map of type interface{}:interface{},auto gen from concurrent_map_template.go 19 | type CurrentMap struct { 20 | shardCount uint 21 | shards []*mCurrentMapShared 22 | } 23 | 24 | // NewCurrentMap Create a new concurrent map with 32 shards 25 | func NewCurrentMap() *CurrentMap { 26 | return NewCurrentMapWithShard(defaultCurrentMapShardCount) 27 | } 28 | 29 | // NewCurrentMapWithShard Creates a new concurrent map. 
30 | func NewCurrentMapWithShard(shardCount uint) *CurrentMap { 31 | if shardCount == 0 { 32 | shardCount = defaultCurrentMapShardCount 33 | } 34 | 35 | shards := make([]*mCurrentMapShared, shardCount) 36 | 37 | var i uint 38 | for i = 0; i < shardCount; i++ { 39 | shards[i] = &mCurrentMapShared{items: make(map[interface{}]interface{})} 40 | } 41 | return &CurrentMap{shardCount: shardCount, shards: shards} 42 | } 43 | 44 | // Returns shard under given key 45 | func (m *CurrentMap) getShard(key interface{}) *mCurrentMapShared { 46 | strKey := fmt.Sprintf("%v", key) 47 | hasher := fnv.New32() 48 | hasher.Write([]byte(strKey)) 49 | return m.shards[uint(hasher.Sum32())%uint(m.shardCount)] 50 | } 51 | 52 | // Set the given value under the specified key. 53 | func (m *CurrentMap) Set(key interface{}, value interface{}) { 54 | shard := m.getShard(key) 55 | shard.Lock() 56 | defer shard.Unlock() 57 | shard.items[key] = value 58 | } 59 | 60 | // SetIfAbsent the given value under the specified key if no value was associated with it. 61 | func (m *CurrentMap) SetIfAbsent(key interface{}, value interface{}) (success bool, preVal interface{}) { 62 | shard := m.getShard(key) 63 | shard.Lock() 64 | v, ok := shard.items[key] 65 | if !ok { 66 | shard.items[key] = value 67 | } 68 | shard.Unlock() 69 | return !ok, v 70 | } 71 | 72 | // Get Retrieves an element from map under given key. 73 | func (m *CurrentMap) Get(key interface{}) (interface{}, bool) { 74 | shard := m.getShard(key) 75 | shard.RLock() 76 | defer shard.RUnlock() 77 | val, ok := shard.items[key] 78 | return val, ok 79 | } 80 | 81 | // Count Returns the number of elements within the map. 
82 | func (m *CurrentMap) Count() int { 83 | count := 0 84 | var i uint 85 | for i = 0; i < m.shardCount; i++ { 86 | shard := m.shards[i] 87 | shard.RLock() 88 | count += len(shard.items) 89 | shard.RUnlock() 90 | } 91 | return count 92 | } 93 | 94 | // Has Looks up an item under specified key 95 | func (m *CurrentMap) Has(key interface{}) bool { 96 | shard := m.getShard(key) 97 | shard.RLock() 98 | defer shard.RUnlock() 99 | _, ok := shard.items[key] 100 | return ok 101 | } 102 | 103 | // Remove an element from the map. 104 | func (m *CurrentMap) Remove(key interface{}) { 105 | shard := m.getShard(key) 106 | shard.Lock() 107 | defer shard.Unlock() 108 | delete(shard.items, key) 109 | } 110 | 111 | // IsEmpty Checks if map is empty. 112 | func (m *CurrentMap) IsEmpty() bool { 113 | return m.Count() == 0 114 | } 115 | 116 | // CurrentMapTuple Used by the Iter & IterBuffered functions to wrap two variables together over a channel, 117 | type CurrentMapTuple struct { 118 | Key interface{} 119 | Val interface{} 120 | } 121 | 122 | // Iter Returns an iterator which could be used in a for range loop. 123 | func (m CurrentMap) Iter() <-chan CurrentMapTuple { 124 | ch := make(chan CurrentMapTuple) 125 | go func() { 126 | // Foreach shard. 127 | for _, shard := range m.shards { 128 | // Foreach key, value pair. 129 | shard.RLock() 130 | for key, val := range shard.items { 131 | ch <- CurrentMapTuple{key, val} 132 | } 133 | shard.RUnlock() 134 | } 135 | close(ch) 136 | }() 137 | return ch 138 | } 139 | 140 | // IterBuffered Returns a buffered iterator which could be used in a for range loop. 141 | func (m CurrentMap) IterBuffered() <-chan CurrentMapTuple { 142 | ch := make(chan CurrentMapTuple, m.Count()) 143 | go func() { 144 | // Foreach shard. 145 | for _, shard := range m.shards { 146 | // Foreach key, value pair. 
147 | shard.RLock() 148 | for key, val := range shard.items { 149 | ch <- CurrentMapTuple{key, val} 150 | } 151 | shard.RUnlock() 152 | } 153 | close(ch) 154 | }() 155 | return ch 156 | } 157 | 158 | // Items Returns all items as map[interface{}]interface{} 159 | func (m CurrentMap) Items() map[interface{}]interface{} { 160 | tmp := make(map[interface{}]interface{}) 161 | 162 | // Insert items to temporary map. 163 | for item := range m.IterBuffered() { 164 | tmp[item.Key] = item.Val 165 | } 166 | 167 | return tmp 168 | } 169 | -------------------------------------------------------------------------------- /common/concurrent_map_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestConcurrentMap(t *testing.T) { 12 | currentMap := NewCurrentMap() 13 | wg := sync.WaitGroup{} 14 | count := 100 15 | wg.Add(count) 16 | for i := 0; i < count; i++ { 17 | go func() { 18 | for j := 0; j < 1000; j++ { 19 | currentMap.Set(1, 1) 20 | val, ok := currentMap.Get(1) 21 | assert.True(t, ok) 22 | assert.EqualValues(t, 1, val) 23 | assert.IsType(t, 1, val) 24 | //fmt.Printf("val %v,type:%T\n", val, val) 25 | 26 | currentMap.Set("1", 2) 27 | val, ok = currentMap.Get("1") 28 | assert.True(t, ok) 29 | assert.EqualValues(t, 2, val) 30 | assert.IsType(t, 1, val) 31 | //fmt.Printf("val %v,type:%T\n", val, val) 32 | 33 | val, ok = currentMap.Get(1) 34 | assert.True(t, ok) 35 | assert.EqualValues(t, 1, val) 36 | assert.IsType(t, 1, val) 37 | 38 | currentMap.Remove(3) 39 | 40 | //fmt.Printf("val %v,type:%T\n", val, val) 41 | assert.EqualValues(t, 2, currentMap.Count(), "") 42 | for range currentMap.IterBuffered() { 43 | //fmt.Println(v.Key, "=", v.Val) 44 | } 45 | } 46 | wg.Done() 47 | }() 48 | } 49 | wg.Wait() 50 | mapItems := currentMap.Items() 51 | assert.EqualValues(t, 2, len(mapItems)) 52 | for k, v := range 
mapItems { 53 | fmt.Println(k, "=", v) 54 | } 55 | fmt.Println("finish") 56 | } 57 | -------------------------------------------------------------------------------- /common/config.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "errors" 5 | "io/ioutil" 6 | "os" 7 | "reflect" 8 | "runtime" 9 | ) 10 | 11 | var ( 12 | errInvalidConf = errors.New("invalid conf") 13 | ) 14 | 15 | // ConfigLoader 配置内容加载器 16 | type ConfigLoader interface { 17 | Load(configPath string) (content []byte, err error) 18 | 19 | Exist(configPath string) (exist bool, err error) 20 | } 21 | 22 | // ConfigFileLoader 从本地文件中加载配置 23 | type ConfigFileLoader struct { 24 | } 25 | 26 | // Load impls ConfigLoader.Load 27 | func (p *ConfigFileLoader) Load(configPath string) (content []byte, err error) { 28 | content, err = ioutil.ReadFile(configPath) 29 | return 30 | } 31 | 32 | // Exist impls ConfigLoader.Exist 33 | func (p *ConfigFileLoader) Exist(configPath string) (exist bool, err error) { 34 | info, err := os.Stat(configPath) 35 | if os.IsNotExist(err) { 36 | err = nil 37 | return 38 | } 39 | if info != nil { 40 | exist = !info.IsDir() 41 | } 42 | return 43 | } 44 | 45 | var ( 46 | //FileLoader 默认加载 47 | FileLoader ConfigLoader = &ConfigFileLoader{} 48 | ) 49 | 50 | // Configurer 配置器 51 | type Configurer interface { 52 | //解析配置 53 | Parse() error 54 | } 55 | 56 | // LogConfig 日志配置 57 | type LogConfig struct { 58 | Env string `yaml:"env"` 59 | FileName string `yaml:"file_name"` 60 | MaxSize int `yaml:"max_size"` 61 | MaxBackups int `yaml:"max_backups"` 62 | MaxAge int `yaml:"max_age"` 63 | NoCaller bool `yaml:"no_caller"` 64 | Level string `yaml:"level"` 65 | } 66 | 67 | // LogConfiger the log configer 68 | type LogConfiger interface { 69 | GetLogConfig() *LogConfig 70 | } 71 | 72 | // Parse 解析日志配置 73 | func (p *LogConfig) Parse() error { 74 | return initLogger(p) 75 | } 76 | 77 | // GetOutputFile 返回日志的输出文件 78 | func (p 
*LogConfig) GetOutputFile() string { 79 | if p == nil { 80 | return "" 81 | } 82 | 83 | name := os.Getenv("LOG_FILE") 84 | if name != "" { 85 | return name 86 | } 87 | return p.FileName 88 | } 89 | 90 | // RuntimeConfig 运行期配置 91 | type RuntimeConfig struct { 92 | Maxprocs int //最大的PROCS个数 93 | } 94 | 95 | // Parse 解析运行期配置 96 | func (p *RuntimeConfig) Parse() error { 97 | if p.Maxprocs > 0 { 98 | preProcs := runtime.GOMAXPROCS(p.Maxprocs) 99 | Infof("Set runtime.MAXPROCS to %v,old is %v", p.Maxprocs, preProcs) 100 | } 101 | return nil 102 | } 103 | 104 | // AppConfig 基础的应用配置 105 | type AppConfig struct { 106 | *LogConfig `yaml:"log"` 107 | *RuntimeConfig `yaml:"runtime"` 108 | *ValidateRuleConfig `yaml:"validates"` 109 | } 110 | 111 | // Parse 解析基础的应用配置 112 | func (p *AppConfig) Parse() error { 113 | return Parse(p) 114 | } 115 | 116 | // GetValidateRuleConfig implements ValidateConfiguer 117 | func (p *AppConfig) GetValidateRuleConfig() *ValidateRuleConfig { 118 | return p.ValidateRuleConfig 119 | } 120 | 121 | // GetLogConfig impls LogConfiger 122 | func (p *AppConfig) GetLogConfig() *LogConfig { 123 | return p.LogConfig 124 | } 125 | 126 | // Parse 解析配置 127 | func Parse(conf interface{}) error { 128 | config := reflect.Indirect(reflect.ValueOf(conf)) 129 | fieldCount := config.NumField() 130 | 131 | for i := 0; i < fieldCount; i++ { 132 | val := reflect.Indirect(config.Field(i)) 133 | if !val.IsValid() { 134 | continue 135 | } 136 | 137 | if configFieldValue, ok := val.Addr().Interface().(Configurer); ok { 138 | if err := configFieldValue.Parse(); err != nil { 139 | return err 140 | } 141 | } 142 | } 143 | return nil 144 | } 145 | -------------------------------------------------------------------------------- /common/config_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "unsafe" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | var data = ` 12 | 
a: Easy! 13 | b: 14 | c: 2 15 | d: [3, 4] 16 | ` 17 | 18 | type conf struct { 19 | A string 20 | B struct { 21 | C int 22 | D []int 23 | } 24 | } 25 | 26 | func TestLoadYAML(t *testing.T) { 27 | config := conf{} 28 | LoadYAMl([]byte(data), &config) 29 | assert.Equal(t, "Easy!", config.A) 30 | assert.Equal(t, 2, config.B.C) 31 | assert.Equal(t, 2, len(config.B.D)) 32 | assert.Equal(t, []int{3, 4}, config.B.D) 33 | } 34 | 35 | var appConfigData = `validates: 36 | sname: v1 37 | rules: 38 | - name: minStr 39 | desc: "字符串不恩能够为空,且最小长度为1,最大长度为2" 40 | validators: 41 | - min: "1" 42 | name: strlen 43 | max: "2" 44 | - name: notempty 45 | validates2: 46 | sname: v2 47 | rules: 48 | - name: minStr 49 | desc: "字符串不恩能够为空,且最小长度为1,最大长度为5" 50 | validators: 51 | - min: "1" 52 | name: strlen 53 | max: "5" 54 | - name: notempty 55 | - name: allowempty 56 | desc: "字符串可以为空,如果不为空,则最小长度为2,最大长度为5" 57 | validators: 58 | - min: "0" 59 | name: strlen 60 | max: "5" 61 | ` 62 | 63 | type ConfigTest struct { 64 | AppConfig `yaml:",inline"` 65 | V2 *ValidateRuleConfig `yaml:"validates2"` 66 | } 67 | 68 | func TestAppConfig(t *testing.T) { 69 | var appConfig ConfigTest 70 | LoadYAMl([]byte(appConfigData), &appConfig) 71 | Parse(&appConfig) 72 | fmt.Println("validates:", appConfig.ValidateRuleConfig.parsed) 73 | fmt.Println("validates:", appConfig.V2.parsed) 74 | v2 := appConfig.V2.NewService().(ValidateService) 75 | v1 := appConfig.ValidateRuleConfig.NewService().(ValidateService) 76 | fmt.Printf("v1 name:%s,v2 name:%s\n", v1.Name(), v2.Name()) 77 | err := v1.Validate("minStr", "") 78 | assert.NotNil(t, err) 79 | err = v2.Validate("allowempty", "") 80 | assert.Nil(t, err) 81 | 82 | err = v1.Validate("minStr", "he") 83 | assert.Nil(t, err) 84 | 85 | err = v2.Validate("minStr", "hello") 86 | assert.Nil(t, err) 87 | 88 | var s1 ValidateService 89 | var s2 interface{} 90 | 91 | fmt.Printf("s1 type:%T,size:%d\n", s1, unsafe.Sizeof(s1)) 92 | fmt.Printf("s2 type:%T,size:%d\n", s2, unsafe.Sizeof(s2)) 93 
| } 94 | 95 | func TestConfigLoader(t *testing.T) { 96 | fileLoader := &ConfigFileLoader{} 97 | exist, err := fileLoader.Exist("") 98 | assert.False(t, exist) 99 | assert.NoError(t, err) 100 | 101 | exist, err = fileLoader.Exist("config_test.go") 102 | assert.True(t, exist) 103 | assert.NoError(t, err) 104 | } 105 | -------------------------------------------------------------------------------- /common/config_yaml.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "os" 7 | "path" 8 | 9 | yaml "gopkg.in/yaml.v3" 10 | ) 11 | 12 | // LoadYAMLFromPath 将YAML文件中的配置加载到到结构体target中 13 | func LoadYAMLFromPath(filename string, target interface{}) error { 14 | data, err := os.ReadFile(filename) 15 | if err != nil { 16 | return err 17 | } 18 | return LoadYAMl(data, target) 19 | } 20 | 21 | // LoadYAMl 将data中的YAML配置加载到到结构体target中 22 | func LoadYAMl(data []byte, target interface{}) error { 23 | if len(data) == 0 { 24 | return fmt.Errorf("can't load yaml config from empty data") 25 | } 26 | return yaml.Unmarshal([]byte(data), target) 27 | } 28 | 29 | // LoadConfig 从configDir目录下的多个path指定的YAML配置文件中加载配置 30 | func LoadConfig(config Configurer, addonConfig string, configDir string, pathes ...string) (err error) { 31 | return LoadConfigWithLoader(FileLoader, config, addonConfig, configDir, pathes...) 32 | } 33 | 34 | // LoadConfigWithLoader 使用指定的加载器加载配置 35 | func LoadConfigWithLoader(loader ConfigLoader, config Configurer, addonConfig string, configDir string, pathes ...string) (err error) { 36 | if loader == nil { 37 | err = errors.New("no loader") 38 | return 39 | } 40 | if len(pathes) == 0 && addonConfig == "" { 41 | return errInvalidConf 42 | } 43 | 44 | var content []byte 45 | if addonConfig != "" { 46 | content = append(content, addonConfig...) 47 | content = append(content, []byte("\n")...) 
48 | } 49 | for _, p := range pathes { 50 | p = path.Join(configDir, p) 51 | Infof("load conf from:%s", p) 52 | cnt, err := loader.Load(p) 53 | if err != nil { 54 | return err 55 | } 56 | if len(cnt) == 0 { 57 | Warnf("empty content in %s", p) 58 | continue 59 | } 60 | content = append(content, cnt...) 61 | content = append(content, []byte("\n")...) 62 | } 63 | err = LoadYAMl(content, config) 64 | if err != nil { 65 | return err 66 | } 67 | return 68 | } 69 | -------------------------------------------------------------------------------- /common/constant.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | // Status 定义状态 4 | type Status int8 5 | 6 | const ( 7 | // DISABLE 禁用 8 | DISABLE Status = 0 9 | // ENABLE 有效 10 | ENABLE Status = 1 11 | ) 12 | 13 | // IsValid 判断状态是否是有效的 14 | func (p Status) IsValid() bool { 15 | return p == DISABLE || p == ENABLE 16 | } 17 | 18 | // Value 状态的值 19 | func (p Status) Value() int8 { 20 | return int8(p) 21 | } 22 | 23 | // 定义环境变量 24 | const ( 25 | EnvWorkfDir = "work_dir" 26 | ) 27 | 28 | // 定义环境的常量 29 | const ( 30 | EnvDev = "dev" 31 | EnvTest = "test" 32 | EnvProduction = "prod" 33 | ) 34 | -------------------------------------------------------------------------------- /common/copy_on_write_map.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | "sync/atomic" 7 | ) 8 | 9 | type opType uint 10 | 11 | // opType的类型 12 | const ( 13 | opPut opType = iota //添加 14 | opPutOnlyAbsent //添加,如果指定的key已经存在,则返回error 15 | opDel //删除指定的key 16 | ) 17 | 18 | type cowMap map[interface{}]interface{} 19 | 20 | // CopyOnWriteMap copy on write map 21 | type CopyOnWriteMap struct { 22 | m atomic.Value 23 | mutex sync.Mutex 24 | } 25 | 26 | // NewCopyOnWriteMap 创建CopyOnWriteMap 27 | func NewCopyOnWriteMap() *CopyOnWriteMap { 28 | reg := &CopyOnWriteMap{} 29 | reg.m.Store(make(cowMap)) 30 | return reg 31 
| } 32 | 33 | // copyMap 复制src map 34 | func copyMap(src cowMap) cowMap { 35 | m := make(cowMap) 36 | for k, v := range src { 37 | m[k] = v 38 | } 39 | return m 40 | } 41 | 42 | // modify 根据opType的操作类型修改CopyOnWriteMap 43 | func (p *CopyOnWriteMap) modify(key interface{}, value interface{}, op opType) error { 44 | p.mutex.Lock() 45 | defer p.mutex.Unlock() 46 | m1 := p.m.Load().(cowMap) 47 | 48 | switch op { 49 | case opPutOnlyAbsent: 50 | if _, ok := m1[key]; ok { 51 | return fmt.Errorf("duplicate key:%v", key) 52 | } 53 | fallthrough 54 | case opPut: 55 | m2 := copyMap(m1) 56 | m2[key] = value 57 | p.m.Store(m2) 58 | case opDel: 59 | m2 := copyMap(m1) 60 | delete(m2, key) 61 | p.m.Store(m2) 62 | default: 63 | panic(fmt.Errorf("unsupported op type %#v", op)) 64 | } 65 | 66 | return nil 67 | } 68 | 69 | // Put key及对应的value,如果key已经存在,则进行替换 70 | func (p *CopyOnWriteMap) Put(key interface{}, value interface{}) { 71 | p.modify(key, value, opPut) 72 | } 73 | 74 | // PutIfAbsent put key及对应的value,如果key已经存在,不进行替换,并返回错误 75 | func (p *CopyOnWriteMap) PutIfAbsent(key interface{}, value interface{}) error { 76 | return p.modify(key, value, opPutOnlyAbsent) 77 | } 78 | 79 | // Delete 删除key 80 | func (p *CopyOnWriteMap) Delete(key interface{}) { 81 | p.modify(key, nil, opDel) 82 | } 83 | 84 | // Get 取得key对应的值 85 | func (p *CopyOnWriteMap) Get(key interface{}) interface{} { 86 | m1 := p.m.Load().(cowMap) 87 | if value, ok := m1[key]; ok { 88 | return value 89 | } 90 | return nil 91 | } 92 | 93 | // CopyOnWriteSlice 94 | type cowSlice []interface{} 95 | 96 | // CopyOnWriteSlice copy on write slice 97 | type CopyOnWriteSlice struct { 98 | m atomic.Value 99 | mutex sync.Mutex 100 | } 101 | 102 | // NewCopyOnWriteSlice 创建CopyOnWriteSlice 103 | func NewCopyOnWriteSlice() *CopyOnWriteSlice { 104 | reg := &CopyOnWriteSlice{} 105 | reg.m.Store(make(cowSlice, 0)) 106 | return reg 107 | } 108 | 109 | // modify 根据opType的类型,修改CopyOnWriteSlice 110 | func (p *CopyOnWriteSlice) modify(value 
interface{}, op opType) error { 111 | p.mutex.Lock() 112 | defer p.mutex.Unlock() 113 | m1 := p.m.Load().(cowSlice) 114 | 115 | switch op { 116 | case opPut: 117 | m2 := make(cowSlice, len(m1), len(m1)+1) 118 | copy(m2, m1) 119 | m2 = append(m2, value) 120 | p.m.Store(m2) 121 | case opDel: 122 | m2 := make(cowSlice, 0, len(m1)) 123 | for _, v := range m1 { 124 | if v != value { 125 | m2 = append(m2, v) 126 | } 127 | } 128 | p.m.Store(m2) 129 | default: 130 | panic(fmt.Errorf("unsupported op type %#v", op)) 131 | } 132 | 133 | return nil 134 | } 135 | 136 | // Add 添加 137 | func (p *CopyOnWriteSlice) Add(value interface{}) error { 138 | if value == nil { 139 | panic("Can't add nil value") 140 | } 141 | return p.modify(value, opPut) 142 | } 143 | 144 | // Delete 删除value 145 | func (p *CopyOnWriteSlice) Delete(value interface{}) error { 146 | if value == nil { 147 | panic("Can't delete nil value") 148 | } 149 | return p.modify(value, opDel) 150 | } 151 | 152 | // Get 取得Slice 153 | func (p *CopyOnWriteSlice) Get() []interface{} { 154 | return p.m.Load().(cowSlice) 155 | } 156 | -------------------------------------------------------------------------------- /common/io_util.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "os" 9 | "os/signal" 10 | "syscall" 11 | ) 12 | 13 | // ParseInputAndOutput 解析输入的文件 14 | func ParseInputAndOutput(input, output string) (inputReader, outWriter *os.File, err error) { 15 | if input == "-" { 16 | inputReader = os.Stdin 17 | fmt.Fprintln(os.Stderr, "Read data from stdin") 18 | } else if input != "" { 19 | fmt.Fprintln(os.Stderr, "Read data from "+input) 20 | fileInput, err := os.Open(input) 21 | if err == nil { 22 | inputReader = fileInput 23 | } 24 | } 25 | 26 | if output == "" { 27 | fmt.Fprintln(os.Stderr, "Write data to stdout") 28 | outWriter = os.Stdout 29 | } else { 30 | fmt.Fprintln(os.Stderr, "Write data to 
"+output) 31 | fileOut, err := os.OpenFile(output, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) 32 | if err == nil { 33 | outWriter = fileOut 34 | } 35 | } 36 | 37 | var errorMsg string 38 | if inputReader == nil { 39 | errorMsg = "Invalid input:" + input 40 | if outWriter != nil { 41 | defer outWriter.Close() 42 | } 43 | } 44 | 45 | if outWriter == nil { 46 | if len(errorMsg) > 0 { 47 | errorMsg += "Invalid output:" + output 48 | } 49 | if inputReader != nil { 50 | defer inputReader.Close() 51 | } 52 | } 53 | 54 | if errorMsg != "" { 55 | return nil, nil, errors.New(errorMsg) 56 | } 57 | 58 | return 59 | } 60 | 61 | // ParseInput 解析输入的文件 62 | func ParseInput(input string) (inputReader *os.File, err error) { 63 | if input == "-" { 64 | inputReader = os.Stdin 65 | fmt.Fprintln(os.Stderr, "Read data from stdin") 66 | } else if input != "" { 67 | fmt.Fprintln(os.Stderr, "Read data from "+input) 68 | fileInput, err := os.Open(input) 69 | if err == nil { 70 | inputReader = fileInput 71 | } 72 | } 73 | 74 | var errorMsg string 75 | if inputReader == nil { 76 | errorMsg = "Invalid input:" + input 77 | } 78 | 79 | if errorMsg != "" { 80 | return nil, errors.New(errorMsg) 81 | } 82 | 83 | return 84 | } 85 | 86 | // PrintErrorMsgAndExit 打印信息并退出 87 | func PrintErrorMsgAndExit(msg string, err error) { 88 | fmt.Fprintf(os.Stderr, "%s Error:%v\n", msg, err) 89 | os.Exit(1) 90 | } 91 | 92 | // LF `\n` 93 | const LF = '\n' 94 | 95 | // ProcessLineFunc 行处理函数 96 | type ProcessLineFunc func(data string, lineNum int, readErr error) (stop bool) 97 | 98 | // ProcessLines 按行从rd中读取数据,交由processFunc进行处理 99 | func ProcessLines(rd io.Reader, processFunc ProcessLineFunc) { 100 | scanner := bufio.NewReaderSize(rd, 4*1024) 101 | var readErr error 102 | var lineNum = 0 103 | var data string 104 | for readErr == nil { 105 | data, readErr = scanner.ReadString(LF) 106 | lineNum++ 107 | if readErr != nil && readErr != io.EOF { 108 | processFunc(data, lineNum, readErr) 109 | break 110 | } else if readErr 
!= nil && readErr == io.EOF { 111 | if len(data) > 0 { 112 | processFunc(data, lineNum, nil) 113 | } 114 | break 115 | } else { 116 | if len(data) > 0 { 117 | if processFunc(data, lineNum, nil) { 118 | break 119 | } 120 | } 121 | } 122 | } 123 | } 124 | 125 | // ProcessFileLines 按行处理文件 126 | func ProcessFileLines(file string, lineFunc ProcessLineFunc) { 127 | f, err := os.Open(file) 128 | if err != nil { 129 | return 130 | } 131 | defer f.Close() 132 | ProcessLines(f, lineFunc) 133 | } 134 | 135 | // WaitStop 等待退出信号 136 | func WaitStop() os.Signal { 137 | c := make(chan os.Signal, 1) 138 | signal.Notify(c, syscall.SIGINT, syscall.SIGKILL, syscall.SIGTERM, syscall.SIGQUIT) 139 | s := <-c 140 | return s 141 | } 142 | 143 | // CreateDirIfAbsent 当目录不存在时创建 144 | func CreateDirIfAbsent(dir string) error { 145 | info, err := os.Stat(dir) 146 | if err != nil { 147 | if os.IsNotExist(err) { 148 | err = os.MkdirAll(dir, os.ModePerm) 149 | if err != nil { 150 | return fmt.Errorf("can't create dir:%s,err:%s", dir, err) 151 | } 152 | } 153 | } else if !info.IsDir() { 154 | return fmt.Errorf("not a dir `%s`", dir) 155 | } 156 | return nil 157 | } 158 | -------------------------------------------------------------------------------- /common/io_util_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "fmt" 5 | "path" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestProcessFileLines(t *testing.T) { 13 | f := path.Join("testdata", "1.txt") 14 | var fp = func(line string, lineNum int, readErr error) (stop bool) { 15 | assert.NoError(t, readErr) 16 | line = strings.TrimSpace(line) 17 | fmt.Println(lineNum, line) 18 | return 19 | } 20 | ProcessFileLines(f, fp) 21 | } 22 | -------------------------------------------------------------------------------- /common/json_util.go: -------------------------------------------------------------------------------- 1 | 
package common 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | ) 7 | 8 | // UnmarshalUseNumber decodes data into v with Decoder.UseNumber, so numbers stored in interface{} values become json.Number instead of float64 (keeps large int64 values exact) 9 | func UnmarshalUseNumber(data []byte, v interface{}) error { 10 | dec := json.NewDecoder(bytes.NewBuffer(data)) 11 | dec.UseNumber() // numbers decoded into interface{} become json.Number rather than float64 12 | return dec.Decode(v) // NOTE(review): decodes only the first JSON value; trailing data is not rejected the way json.Unmarshal rejects it 13 | } 14 | -------------------------------------------------------------------------------- /common/json_util_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | gjson "encoding/json" 5 | "fmt" 6 | "testing" 7 | ) 8 | 9 | type TestJSON struct { 10 | ID int64 `json:"id"` 11 | Data map[string]interface{} `json:"data"` 12 | Fid float64 `json:"fid"` 13 | } 14 | 15 | func TestJSONUnmarshal(t *testing.T) { // compares json.Unmarshal with UnmarshalUseNumber by printing results; makes no assertions 16 | var ( 17 | packet = []byte(`{"id":102410241023,"fid":1.2,"data":{"b":12356332453}}`) 18 | err error 19 | ) 20 | 21 | fmt.Println("Use json") 22 | m := map[string]interface{}{} 23 | err = gjson.Unmarshal(packet, &m) 24 | fmt.Println("err:", err, "m:", m) 25 | mjsonStr, err := gjson.Marshal(m) 26 | fmt.Println("err:", err, "mjsonStr:", string(mjsonStr)) 27 | 28 | fmt.Println("Use json useNumber") 29 | m2 := map[string]interface{}{} 30 | err = UnmarshalUseNumber(packet, &m2) 31 | fmt.Println("err:", err, "m2:", m2) 32 | mjsonStr, err = gjson.Marshal(m2) 33 | fmt.Println("err:", err, "mjsonStr:", string(mjsonStr)) 34 | 35 | fmt.Println("tj") 36 | tj := &TestJSON{} 37 | 38 | err = gjson.Unmarshal(packet, tj) 39 | fmt.Println("err:", err, "use default unmarshal tj:", tj) 40 | mjsonStr, err = gjson.Marshal(tj) 41 | fmt.Println("err:", err, "mjsonStr:", string(mjsonStr)) 42 | 43 | err = UnmarshalUseNumber(packet, tj) 44 | fmt.Println("err:", err, "use usenumber unmarshal tj:", tj) 45 | } 46 | -------------------------------------------------------------------------------- /common/linked_map.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 |
import ( 4 | "container/list" 5 | "sync" 6 | ) 7 | 8 | type mapElement struct { 9 | val interface{} 10 | element *list.Element 11 | } 12 | 13 | // LinkedMap implements linked map 14 | type LinkedMap struct { 15 | mutex sync.RWMutex 16 | l *list.List 17 | m map[interface{}]*mapElement 18 | } 19 | 20 | // NewLinkedMap create linked map 21 | func NewLinkedMap() *LinkedMap { 22 | return &LinkedMap{ 23 | l: list.New(), 24 | m: map[interface{}]*mapElement{}, 25 | } 26 | } 27 | 28 | // Put put value with key 29 | func (p *LinkedMap) Put(key, value interface{}) { 30 | p.mutex.Lock() 31 | defer p.mutex.Unlock() 32 | var keyElem *list.Element 33 | if pre, ok := p.m[key]; !ok { 34 | keyElem = p.l.PushBack(key) 35 | p.m[key] = &mapElement{ 36 | val: value, 37 | element: keyElem, 38 | } 39 | } else { 40 | pre.val = value 41 | } 42 | } 43 | 44 | // Get value with key 45 | func (p *LinkedMap) Get(key interface{}) (val interface{}, ok bool) { 46 | p.mutex.RLock() 47 | defer p.mutex.RUnlock() 48 | if pre, ok := p.m[key]; ok { 49 | return pre.val, ok 50 | } 51 | return nil, false 52 | } 53 | 54 | // Remove value with key 55 | func (p *LinkedMap) Remove(key interface{}) (preVal interface{}) { 56 | p.mutex.Lock() 57 | defer p.mutex.Unlock() 58 | if pre, ok := p.m[key]; ok { 59 | delete(p.m, key) 60 | p.l.Remove(pre.element) 61 | return pre.val 62 | } 63 | return nil 64 | } 65 | 66 | // Len return the length of the map 67 | func (p *LinkedMap) Len() int { 68 | p.mutex.RLock() 69 | defer p.mutex.RUnlock() 70 | return p.l.Len() 71 | } 72 | 73 | // MapEntry define map entry with key and value 74 | type MapEntry struct { 75 | Key interface{} 76 | Value interface{} 77 | } 78 | 79 | // Entries return entry slice 80 | func (p *LinkedMap) Entries() []*MapEntry { 81 | p.mutex.RLock() 82 | defer p.mutex.RUnlock() 83 | entries := make([]*MapEntry, p.l.Len()) 84 | var i = 0 85 | for e := p.l.Front(); e != nil; e = e.Next() { 86 | key := e.Value 87 | value := p.m[key].val 88 | entries[i] = 
&MapEntry{Key: key, Value: value} 89 | i++ 90 | } 91 | return entries 92 | } 93 | -------------------------------------------------------------------------------- /common/linked_map_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import "testing" 4 | import "github.com/stretchr/testify/assert" 5 | 6 | func TestLinkedMap(t *testing.T) { 7 | lm := NewLinkedMap() 8 | for i := 0; i < 100; i++ { 9 | lm.Put(i, i+1) 10 | } 11 | assert.Equal(t, 100, lm.Len()) 12 | for i := 0; i < 100; i++ { 13 | val, ok := lm.Get(i) 14 | assert.True(t, ok) 15 | assert.Equal(t, i+1, val) 16 | } 17 | entries := lm.Entries() 18 | assert.Equal(t, 100, len(entries)) 19 | for i := 0; i < 100; i++ { 20 | e := entries[i] 21 | assert.Equal(t, i, e.Key) 22 | assert.Equal(t, i+1, e.Value) 23 | } 24 | for i := 0; i < 100; i++ { 25 | preVal := lm.Remove(i) 26 | assert.Equal(t, i+1, preVal) 27 | } 28 | assert.Equal(t, 0, lm.Len()) 29 | assert.Equal(t, 0, lm.l.Len()) 30 | assert.Equal(t, 0, len(lm.m)) 31 | } 32 | -------------------------------------------------------------------------------- /common/logger.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "sync" 5 | 6 | "go.uber.org/zap/zapcore" 7 | ) 8 | 9 | // LogLevel log level 10 | type LogLevel int8 11 | 12 | // log levels 13 | const ( 14 | Debug LogLevel = LogLevel(zapcore.DebugLevel) 15 | Info LogLevel = LogLevel(zapcore.InfoLevel) 16 | Warn LogLevel = LogLevel(zapcore.WarnLevel) 17 | Error LogLevel = LogLevel(zapcore.ErrorLevel) 18 | ) 19 | 20 | // LogLevelName LogLevel 21 | func LogLevelName(name string) LogLevel { 22 | switch name { 23 | case "debug": 24 | return Debug 25 | case "info": 26 | return Info 27 | case "warn": 28 | return Warn 29 | case "error": 30 | return Error 31 | } 32 | return Info 33 | } 34 | 35 | func (p LogLevel) zapLevel() (level zapcore.Level, ok bool) { 36 | switch p { 37 | case 
Debug: 38 | level = zapcore.DebugLevel 39 | ok = true 40 | case Info: 41 | level = zapcore.InfoLevel 42 | ok = true 43 | case Warn: 44 | level = zapcore.WarnLevel 45 | ok = true 46 | case Error: 47 | level = zapcore.ErrorLevel 48 | ok = true 49 | } 50 | return 51 | } 52 | 53 | // Logger 日志记录接口 54 | type Logger interface { 55 | Debugf(format string, params ...interface{}) 56 | DebugEnabled() bool 57 | Infof(format string, params ...interface{}) 58 | InfoEnabled() bool 59 | Warnf(format string, params ...interface{}) 60 | WarnEnabled() bool 61 | Errorf(format string, params ...interface{}) 62 | ErrorEnabled() bool 63 | SetLevel(level LogLevel) 64 | Sync() 65 | } 66 | 67 | // Debugf debug级别记录日志 68 | func Debugf(format string, params ...interface{}) { 69 | logger.Debugf(format, params...) 70 | } 71 | 72 | // DebugEnabled debug 73 | func DebugEnabled() bool { 74 | return logger.DebugEnabled() 75 | } 76 | 77 | // Infof info级别记录日志 78 | func Infof(format string, params ...interface{}) { 79 | logger.Infof(format, params...) 80 | } 81 | 82 | // InfoEnabled info 83 | func InfoEnabled() bool { 84 | return logger.InfoEnabled() 85 | } 86 | 87 | // Warnf warn级别记录日志 88 | func Warnf(format string, params ...interface{}) { 89 | logger.Warnf(format, params...) 90 | } 91 | 92 | // WarnEnabled warn 93 | func WarnEnabled() bool { 94 | return logger.WarnEnabled() 95 | } 96 | 97 | // Errorf error级别记录日志 98 | func Errorf(format string, params ...interface{}) { 99 | logger.Errorf(format, params...) 100 | } 101 | 102 | // ErrorEnabled error 103 | func ErrorEnabled() bool { 104 | return logger.ErrorEnabled() 105 | } 106 | 107 | // SetLogLevel set the log level 108 | func SetLogLevel(level LogLevel) { 109 | logger.SetLevel(level) 110 | } 111 | 112 | // Logf log 113 | func Logf(level LogLevel, foramt string, params ...interface{}) { 114 | if level == Debug { 115 | logger.Debugf(foramt, params...) 116 | } else if level == Info { 117 | logger.Infof(foramt, params...) 
118 | } else if level == Warn { 119 | logger.Warnf(foramt, params...) 120 | } else if level == Error { 121 | logger.Errorf(foramt, params...) 122 | } else { 123 | logger.Debugf(foramt, params...) 124 | } 125 | } 126 | 127 | // LoggerSync sync log 128 | func LoggerSync() { 129 | logger.Sync() 130 | } 131 | 132 | var ( 133 | logger Logger = NewStdLogger() 134 | ) 135 | 136 | var m sync.Mutex 137 | var loggerInitd bool 138 | 139 | // initLogger 初始化logger 140 | func initLogger(logConfig *LogConfig) (err error) { 141 | m.Lock() 142 | defer m.Unlock() 143 | 144 | if loggerInitd { 145 | Errorf("Logger has been already inited.") 146 | return 147 | } 148 | 149 | zapLogger := NewZapLogger(logConfig) 150 | logger = zapLogger 151 | return 152 | } 153 | -------------------------------------------------------------------------------- /common/logger_std.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "log" 5 | 6 | "go.uber.org/zap" 7 | "go.uber.org/zap/zapcore" 8 | ) 9 | 10 | // StdLogger 使用标准库封装的logger 11 | type StdLogger struct { 12 | logEnable zap.AtomicLevel 13 | } 14 | 15 | // NewStdLogger new info level logger 16 | func NewStdLogger() *StdLogger { 17 | logger := &StdLogger{ 18 | logEnable: zap.NewAtomicLevelAt(zapcore.InfoLevel), 19 | } 20 | return logger 21 | } 22 | 23 | // Debugf debug 24 | func (l *StdLogger) Debugf(format string, params ...interface{}) { 25 | if !l.logEnable.Enabled(zap.DebugLevel) { 26 | return 27 | } 28 | log.Printf(format, params...) 29 | } 30 | 31 | // DebugEnabled is debug enbale 32 | func (l *StdLogger) DebugEnabled() bool { 33 | return l.logEnable.Enabled(zap.DebugLevel) 34 | } 35 | 36 | // Infof info 37 | func (l *StdLogger) Infof(format string, params ...interface{}) { 38 | if !l.logEnable.Enabled(zap.InfoLevel) { 39 | return 40 | } 41 | log.Printf(format, params...) 
42 | } 43 | 44 | // InfoEnabled is info enable 45 | func (l *StdLogger) InfoEnabled() bool { 46 | return l.logEnable.Enabled(zap.InfoLevel) 47 | } 48 | 49 | // Warnf warn 50 | func (l *StdLogger) Warnf(format string, params ...interface{}) { 51 | if !l.logEnable.Enabled(zap.WarnLevel) { 52 | return 53 | } 54 | log.Printf(format, params...) 55 | } 56 | 57 | // WarnEnabled is warn enabled 58 | func (l *StdLogger) WarnEnabled() bool { 59 | return l.logEnable.Enabled(zap.WarnLevel) 60 | } 61 | 62 | // Errorf error 63 | func (l *StdLogger) Errorf(format string, params ...interface{}) { 64 | if !l.logEnable.Enabled(zap.ErrorLevel) { 65 | return 66 | } 67 | log.Printf(format, params...) 68 | } 69 | 70 | // ErrorEnabled error 71 | func (l *StdLogger) ErrorEnabled() bool { 72 | return l.logEnable.Enabled(zap.ErrorLevel) 73 | } 74 | 75 | // Sync sync 76 | func (l *StdLogger) Sync() { 77 | 78 | } 79 | 80 | // SetLevel set the log level 81 | func (l *StdLogger) SetLevel(level LogLevel) { 82 | zapl, ok := level.zapLevel() 83 | if ok { 84 | l.logEnable.SetLevel(zapl) 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /common/logger_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestLog(t *testing.T) { 10 | SetLogLevel(Debug) 11 | Debugf("this is a test") 12 | assert.True(t, DebugEnabled()) 13 | assert.True(t, InfoEnabled()) 14 | SetLogLevel(Info) 15 | assert.False(t, DebugEnabled()) 16 | assert.True(t, InfoEnabled()) 17 | Debugf("this is a test, no debug") 18 | Infof("this is a test, info") 19 | SetLogLevel(0) 20 | assert.False(t, DebugEnabled()) 21 | assert.True(t, InfoEnabled()) 22 | Infof("this is a test, no level") 23 | Logf(Warn, "The is a test, warn") 24 | SetLogLevel(Error) 25 | assert.False(t, DebugEnabled()) 26 | assert.False(t, InfoEnabled()) 27 | assert.True(t, 
ErrorEnabled()) 28 | Infof("this is a test, no error") 29 | Errorf("this is a test, error") 30 | } 31 | -------------------------------------------------------------------------------- /common/logger_zap.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "os" 5 | 6 | "go.uber.org/zap" 7 | "go.uber.org/zap/zapcore" 8 | "gopkg.in/natefinch/lumberjack.v2" 9 | ) 10 | 11 | // ZapLogger 使用zap封装的logger 12 | type ZapLogger struct { 13 | logEnable zap.AtomicLevel 14 | logger *zap.SugaredLogger 15 | } 16 | 17 | // Debugf debug 18 | func (l *ZapLogger) Debugf(format string, params ...interface{}) { 19 | l.logger.Debugf(format, params...) 20 | } 21 | 22 | // DebugEnabled is debug enbale 23 | func (l *ZapLogger) DebugEnabled() bool { 24 | return l.logEnable.Enabled(zap.DebugLevel) 25 | } 26 | 27 | // Infof info 28 | func (l *ZapLogger) Infof(format string, params ...interface{}) { 29 | l.logger.Infof(format, params...) 30 | } 31 | 32 | // InfoEnabled is info enbale 33 | func (l *ZapLogger) InfoEnabled() bool { 34 | return l.logEnable.Enabled(zap.InfoLevel) 35 | } 36 | 37 | // Warnf warn 38 | func (l *ZapLogger) Warnf(format string, params ...interface{}) { 39 | l.logger.Warnf(format, params...) 40 | } 41 | 42 | // WarnEnabled is info enbale 43 | func (l *ZapLogger) WarnEnabled() bool { 44 | return l.logEnable.Enabled(zap.WarnLevel) 45 | } 46 | 47 | // Errorf error 48 | func (l *ZapLogger) Errorf(format string, params ...interface{}) { 49 | l.logger.Errorf(format, params...) 
50 | } 51 | 52 | // ErrorEnabled is info enbale 53 | func (l *ZapLogger) ErrorEnabled() bool { 54 | return l.logEnable.Enabled(zap.ErrorLevel) 55 | } 56 | 57 | // Sync impls Logger.Sync 58 | func (l *ZapLogger) Sync() { 59 | l.logger.Sync() 60 | } 61 | 62 | // SetLevel set the log level 63 | func (l *ZapLogger) SetLevel(level LogLevel) { 64 | zapl, ok := level.zapLevel() 65 | if ok { 66 | l.logEnable.SetLevel(zapl) 67 | } 68 | } 69 | 70 | // NewZapLogger new zap logger 71 | func NewZapLogger(logConfig *LogConfig) *ZapLogger { 72 | var encoder zapcore.Encoder 73 | var writerSync zapcore.WriteSyncer 74 | var logEnable zap.AtomicLevel 75 | 76 | if logConfig.Env == EnvProduction { 77 | config := zap.NewProductionEncoderConfig() 78 | config.EncodeTime = zapcore.ISO8601TimeEncoder 79 | encoder = zapcore.NewConsoleEncoder(config) 80 | logEnable = zap.NewAtomicLevelAt(zapcore.InfoLevel) 81 | } else { 82 | config := zap.NewDevelopmentEncoderConfig() 83 | encoder = zapcore.NewConsoleEncoder(config) 84 | logEnable = zap.NewAtomicLevelAt(zapcore.DebugLevel) 85 | } 86 | 87 | if logConfig.Level != "" { 88 | zapl, ok := LogLevelName(logConfig.Level).zapLevel() 89 | if ok { 90 | logEnable = zap.NewAtomicLevelAt(zapl) 91 | } 92 | } 93 | 94 | outputFile := logConfig.GetOutputFile() 95 | if outputFile != "" { 96 | writerSync = zapcore.AddSync(&lumberjack.Logger{ 97 | Filename: outputFile, 98 | MaxSize: logConfig.MaxSize, 99 | MaxBackups: logConfig.MaxBackups, 100 | MaxAge: logConfig.MaxAge, 101 | LocalTime: true, 102 | }) 103 | } else { 104 | writerSync = zapcore.AddSync(os.Stderr) 105 | } 106 | 107 | core := zapcore.NewCore(encoder, writerSync, logEnable) 108 | logger := zap.New(core) 109 | if !logConfig.NoCaller { 110 | logger = logger.WithOptions(zap.AddCaller(), zap.AddCallerSkip(2)) 111 | } 112 | sugarLogger := logger.Sugar() 113 | return &ZapLogger{logger: sugarLogger, logEnable: logEnable} 114 | } 115 | 
-------------------------------------------------------------------------------- /common/model.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | //PageParam pagination parameters 4 | type PageParam struct { 5 | //page number, starting from 1 6 | Page int `json:"page"` 7 | //number of items per page, >0 8 | PageSize int `json:"page_size"` 9 | //cursor 10 | Cursor int64 `json:"cursor"` 11 | } 12 | 13 | //Limit clamps Page and PageSize: a non-positive Page becomes 1, a positive maxPage caps Page, a positive maxPageSize caps PageSize (and replaces a non-positive one), and a still non-positive PageSize falls back to 10 14 | func (p *PageParam) Limit(maxPage, maxPageSize int) { 15 | if p.Page <= 0 { 16 | p.Page = 1 17 | } 18 | if maxPage > 0 && p.Page > maxPage { 19 | p.Page = maxPage 20 | } 21 | if maxPageSize > 0 && (p.PageSize > maxPageSize || p.PageSize <= 0) { 22 | p.PageSize = maxPageSize 23 | } 24 | if p.PageSize <= 0 { 25 | p.PageSize = 10 26 | } 27 | } 28 | 29 | //StartIndex returns the 0-based index of the first item on the current page 30 | func (p *PageParam) StartIndex() int { 31 | return (p.Page - 1) * p.PageSize 32 | } 33 | 34 | //EndIndex returns the 0-based, inclusive index of the last item on the current page 35 | func (p *PageParam) EndIndex() int { 36 | return p.StartIndex() + p.PageSize - 1 37 | } 38 | 39 | //PageResult paginated result 40 | type PageResult[T any] struct { 41 | PageParam 42 | Total int64 `json:"total"` 43 | TotalPage int64 `json:"totalPage"` 44 | Items []T `json:"items"` 45 | } 46 | 47 | // SetTotal sets Total 48 | func (p *PageResult[T]) SetTotal(total int64) { 49 | p.Total = total 50 | } 51 | 52 | // SetData implements ResultSet.SetData; items are copied into a new slice so the caller's slice is not retained 53 | func (p *PageResult[T]) SetData(items []T) { 54 | p.Items = make([]T, 0, len(items)) 55 | p.Items = append(p.Items, items...)
56 | } 57 | 58 | // CalTotalPage computes TotalPage as Total / PageSize rounded up; TotalPage is left unchanged when PageSize <= 0 59 | func (p *PageResult[T]) CalTotalPage() { 60 | if p.PageSize > 0 { 61 | if p.Total%int64(p.PageSize) == 0 { 62 | p.TotalPage = p.Total / int64(p.PageSize) 63 | } else { 64 | p.TotalPage = p.Total/int64(p.PageSize) + 1 65 | } 66 | } 67 | } 68 | 69 | // CopyPageResult copies the paging fields (PageParam, Total, TotalPage) from src to dest; Items are not copied 70 | func CopyPageResult[S any, T any](src PageResult[S], dest *PageResult[T]) { 71 | dest.PageParam = src.PageParam 72 | dest.Total = src.Total 73 | dest.TotalPage = src.TotalPage 74 | } 75 | 76 | //Query basic query parameters 77 | type Query struct { 78 | PageParam 79 | //ID 80 | ID int64 `json:"id"` 81 | } 82 | -------------------------------------------------------------------------------- /common/perm/perm.go: -------------------------------------------------------------------------------- 1 | package perm 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strings" 7 | 8 | c "github.com/d0ngw/go/common" 9 | ) 10 | 11 | // Operation is a bit set of permitted operations 12 | type Operation int8 13 | 14 | // Operation flags; OPAll combines all four 15 | const ( 16 | OPRead Operation = 1 << iota 17 | OPInsert 18 | OPUpdate 19 | OPDelete 20 | OPAll = OPRead | OPInsert | OPUpdate | OPDelete 21 | ) 22 | 23 | // ParseOperation parses operation permissions from a string (case-insensitive): "all" yields OPAll; otherwise each of r, i, d, u sets the matching flag and other characters are ignored 24 | func ParseOperation(operation string) Operation { 25 | operation = strings.ToLower(operation) 26 | 27 | if operation == "all" { 28 | return OPAll 29 | } 30 | var op Operation 31 | for _, o := range operation { 32 | switch o { 33 | case 'r': 34 | op = op | OPRead 35 | case 'i': 36 | op = op | OPInsert 37 | case 'd': 38 | op = op | OPDelete 39 | case 'u': 40 | op = op | OPUpdate 41 | } 42 | } 43 | return op 44 | } 45 | 46 | // String renders the permission as its flag letters in r, i, d, u order 47 | func (p Operation) String() string { 48 | str := "" 49 | if p&OPRead != 0 { 50 | str += "r" 51 | } 52 | if p&OPInsert != 0 { 53 | str += "i" 54 | } 55 | if p&OPDelete != 0 { 56 | str += "d" 57 | } 58 | if p&OPUpdate != 0 { 59 | str += "u" 60 | } 61 | return str 62 | } 63 | 64 | // Resource defines a resource 65 | type Resource
struct { 66 | parent *Resource 67 | name string 68 | id string 69 | } 70 | 71 | // GetParent returns the parent resource, which may be nil 72 | func (p *Resource) GetParent() *Resource { 73 | return p.parent 74 | } 75 | 76 | // GetName returns the resource name 77 | func (p *Resource) GetName() string { 78 | return p.name 79 | } 80 | 81 | // GetID returns the resource id 82 | func (p *Resource) GetID() string { 83 | return p.id 84 | } 85 | 86 | // ResourceRegistry records all registered resources, keyed by resource id 87 | type ResourceRegistry struct { 88 | resouceReg *c.LinkedMap 89 | lastError error 90 | } 91 | 92 | // NewResourceRegistry creates an empty resource registry 93 | func NewResourceRegistry() *ResourceRegistry { 94 | return &ResourceRegistry{ 95 | resouceReg: c.NewLinkedMap(), 96 | lastError: nil, 97 | } 98 | } 99 | 100 | // Add registers a Resource; returns an error if resource is nil or a resource with the same id already exists in the registry 101 | func (p *ResourceRegistry) Add(resource *Resource) error { 102 | if resource == nil { 103 | return fmt.Errorf("Not allow nil resource") 104 | } 105 | rid := resource.GetID() 106 | if _, ok := p.resouceReg.Get(rid); ok { 107 | return fmt.Errorf("Duplicate resouce id:%s", rid) 108 | } 109 | p.resouceReg.Put(rid, resource) 110 | return nil 111 | } 112 | 113 | // IsExist reports whether a resource with the given id is registered 114 | func (p *ResourceRegistry) IsExist(resID string) bool { 115 | _, ok := p.resouceReg.Get(resID) 116 | return ok 117 | } 118 | 119 | // ResourceGroup a group of resources 120 | type ResourceGroup struct { 121 | Name string //group name 122 | Resources []*Resource //resources 123 | } 124 | 125 | // BuildResourceGroup builds the resource group list: each resource whose id has more than depth dot-separated segments is grouped under the registered resource whose id is its first depth segments; returns an error if that group resource is not registered. NOTE(review): a negative depth makes ids[0:depth] panic 126 | func (p *ResourceRegistry) BuildResourceGroup(depth int) (groups []*ResourceGroup, err error) { 127 | var result = c.NewLinkedMap() 128 | for _, v := range p.resouceReg.Entries() { 129 | id := v.Key.(string) 130 | resource := v.Value.(*Resource) 131 | ids := c.SplitTrimOmitEmpty(id, ".") 132 | if len(ids) > depth { 133 | groupID := strings.Join(ids[0:depth], ".") 134 | exist, ok := result.Get(groupID) 135 | if !ok { 136 | group := &ResourceGroup{} 137 | groupResource, ok := p.resouceReg.Get(groupID) 138 | if ok && groupResource != nil { 139
| group.Name = groupResource.(*Resource).GetName() 140 | } else { 141 | err = fmt.Errorf("can't find group id %s", groupID) 142 | return 143 | } 144 | result.Put(groupID, group) 145 | exist = group 146 | } 147 | group := exist.(*ResourceGroup) 148 | group.Resources = append(group.Resources, resource) 149 | } 150 | } 151 | var ret []*ResourceGroup 152 | for _, v := range result.Entries() { 153 | ret = append(ret, v.Value.(*ResourceGroup)) 154 | } 155 | return ret, nil 156 | } 157 | 158 | // NewResource creates a new resource; its id is the parent's id and id joined with "." (just id when parent is nil) 159 | func NewResource(name, id string, parent *Resource) *Resource { 160 | ids := []string{} 161 | if parent != nil { 162 | ids = append(ids, parent.GetID()) 163 | } 164 | ids = append(ids, id) 165 | return &Resource{ 166 | parent: parent, 167 | name: name, 168 | id: strings.Join(ids, "."), 169 | } 170 | } 171 | 172 | // NewResourceAndReg creates a resource and registers it; panics if a resource with the same id already exists in the registry 173 | func NewResourceAndReg(registry *ResourceRegistry, name, id string, parent *Resource) *Resource { 174 | res := NewResource(name, id, parent) 175 | if err := registry.Add(res); err != nil { 176 | panic(err) 177 | } 178 | return res 179 | } 180 | 181 | type permKey int 182 | 183 | const ( 184 | required permKey = 0 //required permissions 185 | user permKey = 1 //logged-in user 186 | ) 187 | 188 | //Perm defines a permission, made up of a resource and the operations on it 189 | type Perm struct { 190 | Res *Resource //resource 191 | Op Operation //operations 192 | } 193 | 194 | // NewPerm builds a Perm 195 | func NewPerm(res *Resource, op Operation) *Perm { 196 | return &Perm{Res: res, Op: op} 197 | } 198 | 199 | //Role defines a role 200 | type Role interface { 201 | //GetName returns the role name 202 | GetName() string 203 | //GetPerms returns the permissions the role owns, keyed by resource id 204 | GetPerms() map[string]Operation 205 | } 206 | 207 | // Principal defines the subject that holds permissions 208 | type Principal interface { 209 | // GetID returns the principal's id 210 | GetID() int64 211 | // GetName returns the principal's name 212 | GetName() string 213 | // GetRoles returns the roles the principal has 214 | GetRoles() []Role 215 | } 216 | 217 | // ReqPerm returns a ctx declaring that the permissions in perms are required, merged with any already required in ctx. NOTE(review): append(perms, existed...) can write into the caller's backing array when cap(perms) > len(perms) 218 | func ReqPerm(ctx
context.Context, perms []*Perm) (context.Context, error) { 219 | if ctx == nil || len(perms) == 0 { 220 | return ctx, fmt.Errorf("Ctx or resource must not be nil") 221 | } 222 | 223 | existed, ok := ctx.Value(required).([]*Perm) 224 | if ok { 225 | perms = append(perms, existed...) 226 | } 227 | 228 | ctx = context.WithValue(ctx, required, perms) 229 | return ctx, nil 230 | } 231 | 232 | // BindPrincipal returns a ctx with principal bound to it 233 | func BindPrincipal(ctx context.Context, principal Principal) (context.Context, error) { 234 | if ctx == nil || principal == nil { 235 | return ctx, fmt.Errorf("Ctx or principal must not be nil") 236 | } 237 | ctx = context.WithValue(ctx, user, principal) 238 | return ctx, nil 239 | } 240 | 241 | // GetPrincipal returns the principal bound in ctx, or (nil, nil) if none is bound 242 | func GetPrincipal(ctx context.Context) (Principal, error) { 243 | if ctx == nil { 244 | return nil, fmt.Errorf("Ctx must not be nil") 245 | } 246 | 247 | principal, ok := ctx.Value(user).(Principal) 248 | if ok { 249 | return principal, nil 250 | } 251 | return nil, nil 252 | } 253 | 254 | // GetRequiredPerm returns the permissions required in ctx, or (nil, nil) if none are declared 255 | func GetRequiredPerm(ctx context.Context) ([]*Perm, error) { 256 | if ctx == nil { 257 | return nil, fmt.Errorf("Ctx must not be nil") 258 | } 259 | 260 | reqPerms, ok := ctx.Value(required).([]*Perm) 261 | if !ok { 262 | return nil, nil 263 | } 264 | return reqPerms, nil 265 | } 266 | 267 | // HasPermWithPrinciapl reports whether principal has the permissions required in ctx; true when ctx declares none, false for a nil ctx 268 | func HasPermWithPrinciapl(ctx context.Context, principal Principal) bool { 269 | if ctx == nil { 270 | return false 271 | } 272 | 273 | reqPerms, ok := ctx.Value(required).([]*Perm) 274 | if !ok { 275 | return true 276 | } 277 | 278 | return HasPermWithPrincipalAndPerms(principal, reqPerms) 279 | } 280 | 281 | // HasPermWithPrincipalAndPerms reports whether principal has every permission in reqPerms; the operations required on one resource may be covered by the union of several roles 282 | func HasPermWithPrincipalAndPerms(principal Principal, reqPerms []*Perm) bool { 283 | if principal == nil { 284 | return false 285 | } 286 | 287 | if len(reqPerms) == 0 { 288 |
return true 289 | } 290 | 291 | roles := principal.GetRoles() 292 | if len(roles) == 0 { 293 | return false 294 | } 295 | 296 | for _, r := range reqPerms { 297 | resID := r.Res.GetID() 298 | mask := r.Op 299 | for _, role := range roles { 300 | opMask, ok := role.GetPerms()[resID] 301 | if !ok { 302 | continue 303 | } 304 | mask = mask & (mask ^ opMask) // clear the bits this role grants (same as mask &^ opMask) 305 | if mask == 0 { 306 | break 307 | } 308 | } 309 | 310 | if mask != 0 { 311 | return false 312 | } 313 | } 314 | return true 315 | } 316 | -------------------------------------------------------------------------------- /common/pkg_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | func init() { 4 | initLogger(&LogConfig{}) 5 | } 6 | -------------------------------------------------------------------------------- /common/service.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "fmt" 5 | "sort" 6 | "sync" 7 | ) 8 | 9 | // ServiceState represents the state of a service 10 | type ServiceState uint32 11 | 12 | const ( 13 | // NEW newly created 14 | NEW ServiceState = iota 15 | // INITED initialization finished 16 | INITED 17 | // STARTING starting 18 | STARTING 19 | // RUNNING running 20 | RUNNING 21 | // STOPPING stopping 22 | STOPPING 23 | // TERMINATED stopped 24 | TERMINATED 25 | // FAILED failed 26 | FAILED 27 | ) 28 | 29 | var serviceStateStrings = map[ServiceState]string{ 30 | NEW: "NEW", 31 | INITED: "INITED", 32 | STARTING: "STARTING", 33 | RUNNING: "RUNNING", 34 | STOPPING: "STOPPING", 35 | TERMINATED: "TERMINATED", 36 | FAILED: "FAILED"} 37 | 38 | func (p ServiceState) String() string { 39 | return serviceStateStrings[p] 40 | } 41 | 42 | var validStateState = map[ServiceState][]ServiceState{ 43 | NEW: {INITED, FAILED, TERMINATED}, 44 | INITED: {STARTING, FAILED, TERMINATED}, 45 | STARTING: {RUNNING, FAILED, TERMINATED}, 46 | RUNNING: {STOPPING, FAILED, TERMINATED}, 47 | STOPPING: {TERMINATED, FAILED}, 48 | TERMINATED: {}, 49 | FAILED:
{}, 50 | } 51 | 52 | // IsValidServiceState reports whether the transition oldState -> newState is allowed by validStateState 53 | func IsValidServiceState(oldState ServiceState, newState ServiceState) bool { 54 | if targetStates, ok := validStateState[oldState]; ok { 55 | for _, targetState := range targetStates { 56 | if targetState == newState { 57 | return true 58 | } 59 | } 60 | } 61 | return false 62 | } 63 | 64 | // Initable is implemented by types that need initialization 65 | type Initable interface { 66 | // Init performs initialization; on failure it returns the reason as an error 67 | Init() error 68 | } 69 | 70 | // Service is the unified service interface 71 | type Service interface { 72 | Initable 73 | // Name returns the service name 74 | Name() string 75 | // Start starts the service 76 | Start() bool 77 | // GetStartOrder start order; NewServices sorts ascending by it when start is true 78 | GetStartOrder() int 79 | // Stop stops the service 80 | Stop() bool 81 | // GetStopOrder stop order; NewServices sorts ascending by it when start is false 82 | GetStopOrder() int 83 | // State returns the service state 84 | State() ServiceState 85 | // setState sets the service state and reports whether it was applied 86 | setState(newState ServiceState) bool 87 | } 88 | 89 | // ServiceInit initializes the service (skipped if already INITED); on failure it moves the service to FAILED and returns false. NOTE(review): the skip log formats service itself rather than ServiceName(service) 90 | func ServiceInit(service Service) bool { 91 | if service.State() == INITED { 92 | Infof("%s has been inited,skip", service) 93 | return true 94 | } 95 | name := ServiceName(service) 96 | err := service.Init() 97 | if err == nil && service.setState(INITED) { 98 | return true 99 | } 100 | Errorf("init %s fail,err:%s", name, err) 101 | service.setState(FAILED) 102 | return false 103 | } 104 | 105 | // ServiceStart moves the service to STARTING, calls Start, then moves it to RUNNING; on failure it moves it to FAILED and returns false. NOTE(review): the result of setState(STARTING) is ignored 106 | func ServiceStart(service Service) bool { 107 | name := ServiceName(service) 108 | service.setState(STARTING) 109 | if service.Start() && service.setState(RUNNING) { 110 | return true 111 | } 112 | Errorf("start %s fail", name) 113 | service.setState(FAILED) 114 | return false 115 | } 116 | 117 | // ServiceStop moves the service to STOPPING, calls Stop, then moves it to TERMINATED; on failure it moves it to FAILED and returns false 118 | func ServiceStop(service Service) bool { 119 | name := ServiceName(service) 120 | service.setState(STOPPING) 121 | if service.Stop() && service.setState(TERMINATED) { 122 | return true 123 | } 124 | Errorf("stop %s fail", name) 125 | service.setState(FAILED) 126 | return false 127 | } 128 | 129 | // BaseService provides a basic implementation of the Service interface 130 | type
BaseService struct { 131 | SName string //service name 132 | Order int 133 | state ServiceState //service state, guarded by stateLock 134 | stateLock sync.RWMutex //read-write lock guarding state 135 | } 136 | 137 | // Name returns SName 138 | func (p *BaseService) Name() string { 139 | return p.SName 140 | } 141 | 142 | // Init is a no-op that always succeeds 143 | func (p *BaseService) Init() error { 144 | return nil 145 | } 146 | 147 | // Start is a no-op that always reports success 148 | func (p *BaseService) Start() bool { 149 | return true 150 | } 151 | 152 | // GetStartOrder returns Order 153 | func (p *BaseService) GetStartOrder() int { 154 | return p.Order 155 | } 156 | 157 | // Stop is a no-op that always reports success 158 | func (p *BaseService) Stop() bool { 159 | return true 160 | } 161 | 162 | // GetStopOrder returns -Order, so services stop in the reverse of their start order 163 | func (p *BaseService) GetStopOrder() int { 164 | return -p.GetStartOrder() 165 | } 166 | 167 | // State returns the current state under a read lock 168 | func (p *BaseService) State() ServiceState { 169 | p.stateLock.RLock() 170 | defer p.stateLock.RUnlock() 171 | return p.state 172 | } 173 | 174 | func (p *BaseService) setState(newState ServiceState) bool { // applies newState only if IsValidServiceState allows it; otherwise logs and returns false 175 | p.stateLock.Lock() 176 | defer p.stateLock.Unlock() 177 | if IsValidServiceState(p.state, newState) { 178 | p.state = newState 179 | return true 180 | } 181 | Errorf("Invalid state transfer %s->%s,%s", p.state, newState, p.Name()) 182 | return false 183 | } 184 | 185 | // ServiceName returns the service's dynamic type, suffixed with "#" and Name() when the name is non-empty 186 | func ServiceName(service Service) string { 187 | name := fmt.Sprintf("%T", service) 188 | if service.Name() != "" { 189 | name += "#" + service.Name() 190 | } 191 | return name 192 | } 193 | 194 | // Services is an ordered collection of Service 195 | type Services struct { 196 | sorted []Service //services in sorted order 197 | } 198 | 199 | // NewServices builds a Services collection from a sorted copy of services: by GetStartOrder when start is true, otherwise by GetStopOrder 200 | func NewServices(services []Service, start bool) *Services { 201 | //sort a copy so the caller's slice is left untouched 202 | var sorted = make([]Service, len(services)) 203 | copy(sorted, services) 204 | sort.Slice(sorted, func(i, j int) bool { 205 | if start { 206 | return sorted[i].GetStartOrder() < sorted[j].GetStartOrder() 207 | } 208 | return sorted[i].GetStopOrder() < sorted[j].GetStopOrder() 209
| }) 210 | return &Services{sorted: sorted} 211 | } 212 | 213 | // Init initializes the services in order, stopping at the first failure. NOTE(review): the failure log formats service itself, unlike Start/Stop which use ServiceName 214 | func (p *Services) Init() bool { 215 | for _, service := range p.sorted { 216 | if !ServiceInit(service) { 217 | Warnf("init %s fail", service) 218 | return false 219 | } 220 | } 221 | return true 222 | } 223 | 224 | // Start starts the services in order, stopping at the first failure 225 | func (p *Services) Start() bool { 226 | for _, service := range p.sorted { 227 | name := ServiceName(service) 228 | if !ServiceStart(service) { 229 | Warnf("start %s fail", name) 230 | return false 231 | } 232 | } 233 | return true 234 | } 235 | 236 | // Stop stops every service in order, logging failures; it always returns true 237 | func (p *Services) Stop() bool { 238 | for _, service := range p.sorted { 239 | name := ServiceName(service) 240 | if !ServiceStop(service) { 241 | Warnf("stop %s fail", name) 242 | } 243 | } 244 | return true 245 | } 246 | -------------------------------------------------------------------------------- /common/service_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | type aService struct { 10 | BaseService 11 | } 12 | 13 | type bService struct { 14 | BaseService 15 | } 16 | 17 | func TestServices(t *testing.T) { 18 | as := &aService{BaseService{SName: "a", Order: -1}} 19 | bs := &bService{BaseService{SName: "b", Order: 2}} 20 | s1 := NewServices([]Service{as, bs}, true) 21 | assert.Equal(t, 2, len(s1.sorted)) 22 | assert.Equal(t, true, s1.Init()) 23 | assert.Equal(t, true, s1.Start()) 24 | assert.Equal(t, true, s1.Stop()) 25 | 26 | as.state = NEW 27 | bs.state = NEW 28 | 29 | s2 := NewServices([]Service{bs, as}, true) 30 | assert.Equal(t, 2, len(s2.sorted)) 31 | requiredOrder := []string{"a", "b"} 32 | for i := 0; i < len(s2.sorted); i++ { 33 | assert.Equal(t, requiredOrder[i], s2.sorted[i].Name()) 34 | } 35 | 36 | s2 = NewServices([]Service{bs, as}, false) 37 | assert.Equal(t, 2, len(s2.sorted)) 38 | requiredOrder = []string{"b", "a"}
39 | for i := 0; i < len(s2.sorted); i++ { 40 | assert.Equal(t, requiredOrder[i], s2.sorted[i].Name()) 41 | } 42 | 43 | assert.Equal(t, true, s2.Init()) 44 | assert.Equal(t, true, s2.Start()) 45 | assert.Equal(t, true, s2.Stop()) 46 | } 47 | -------------------------------------------------------------------------------- /common/testdata/1.txt: -------------------------------------------------------------------------------- 1 | a 2 | b 3 | c 4 | d 5 | ef -------------------------------------------------------------------------------- /common/time.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | const ( 8 | // FormatDefault default date-time layout 9 | FormatDefault = "2006-01-02 15:04:05" 10 | // FormatYYYYMMDD compact date plus time layout 11 | FormatYYYYMMDD = "20060102 15:04:05" 12 | // FormatYMDH date-and-hour layout 13 | FormatYMDH = "2006010215" 14 | // FormatYMD date layout 15 | FormatYMD = "2006-01-02" 16 | ) 17 | 18 | // LocalLocation local time zone 19 | var LocalLocation = time.Now().Local().Location() 20 | 21 | // ParseLocalTime parses t with FormatDefault in LocalLocation 22 | func ParseLocalTime(t string) (time.Time, error) { 23 | return time.ParseInLocation(FormatDefault, t, LocalLocation) 24 | } 25 | 26 | // ParseLocatTimeWithFormat parses t with the given layout in LocalLocation 27 | func ParseLocatTimeWithFormat(format, t string) (time.Time, error) { 28 | return time.ParseInLocation(format, t, LocalLocation) 29 | } 30 | 31 | // UnixMills returns t as milliseconds since the Unix epoch. NOTE(review): UnixNano overflows for dates outside roughly 1678-2262; t.UnixMilli() avoids that 32 | func UnixMills(t time.Time) int64 { 33 | return t.UnixNano() / int64(time.Millisecond) 34 | } 35 | 36 | // UnixMillsTime converts milliseconds since the Unix epoch to a time.Time 37 | func UnixMillsTime(tmillis int64) time.Time { 38 | return time.Unix(tmillis/1000, (tmillis%1000)*int64(time.Millisecond)) 39 | } 40 | -------------------------------------------------------------------------------- /common/time_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestUnixMillisTime(t
*testing.T) { 10 | tt := UnixMillsTime(1453839313078) 11 | fmt.Println(tt) 12 | fmt.Println(tt.Year(), tt.Month(), tt.Day(), tt.Hour()) 13 | } 14 | 15 | func TestUnixMillis(t *testing.T) { 16 | now := time.Now() 17 | fmt.Println(now.UnixNano()) 18 | fmt.Println(UnixMills(now)) 19 | } 20 | -------------------------------------------------------------------------------- /common/util_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "reflect" 7 | "runtime" 8 | "syscall" 9 | "testing" 10 | "time" 11 | 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | type st struct { 16 | } 17 | 18 | func TestGetType(t *testing.T) { 19 | vt := GetFirstFieldType(struct{ st *interface{} }{}) 20 | assert.Equal(t, vt.Kind(), reflect.Ptr) 21 | assert.Equal(t, vt.Elem().Kind(), reflect.Interface) 22 | } 23 | 24 | func TestShutdownHook(t *testing.T) { 25 | shook := NewShutdownhook() 26 | shook.AddHook(func() { 27 | _, c1, c2, _ := runtime.Caller(1) 28 | fmt.Println("Call @", c1, c2) 29 | }) 30 | 31 | fmt.Println("First wait") 32 | go func() { 33 | time.Sleep(time.Duration(100) * time.Millisecond) 34 | shook.ch <- syscall.SIGINT 35 | }() 36 | shook.WaitShutdown() 37 | } 38 | 39 | func TestInt64(t *testing.T) { 40 | testInt64(t, 2, 2) 41 | testInt64(t, int8(2), 2) 42 | testInt64(t, int16(2), 2) 43 | testInt64(t, int32(2), 2) 44 | testInt64(t, int64(2), 2) 45 | testInt64(t, "2", 2) 46 | testInt64(t, float32(2.0), 2) 47 | testInt64(t, float64(2.0), 2) 48 | testInt64(t, json.Number("2"), 2) 49 | _, err := Int64(struct{}{}) 50 | assert.NotNil(t, err) 51 | } 52 | 53 | func testInt64(t *testing.T, v interface{}, av int64) { 54 | i, err := Int64(v) 55 | assert.Nil(t, err) 56 | assert.EqualValues(t, i, av) 57 | } 58 | 59 | func TestFloat64(t *testing.T) { 60 | testFloat64(t, 2, 2.0) 61 | testFloat64(t, int8(2), 2.0) 62 | testFloat64(t, int16(2), 2.0) 63 | testFloat64(t, int32(2), 2.0)
64 | testFloat64(t, int64(2), 2.0) 65 | testFloat64(t, "2", 2.0) 66 | testFloat64(t, float32(2.0), 2.0) 67 | testFloat64(t, float64(2.0), 2.0) 68 | testFloat64(t, json.Number("2"), 2.0) 69 | _, err := Float64(struct{}{}) 70 | assert.NotNil(t, err) 71 | } 72 | 73 | func testFloat64(t *testing.T, v interface{}, av float64) { 74 | i, err := Float64(v) 75 | assert.Nil(t, err) 76 | assert.EqualValues(t, i, av) 77 | } 78 | 79 | type TestStruct struct { 80 | ID int64 81 | } 82 | 83 | func TestByteSlice2String(t *testing.T) { 84 | var bs []byte 85 | t.Logf("bs:%v", bs) 86 | 87 | str0 := ByteSlice2String(bs) 88 | str1 := string(bs) 89 | assert.Equal(t, str0, str1) 90 | t.Logf("str0:`%s`,str1:`%s`,%p", str0, str1, &str1) 91 | 92 | var bss = []byte("abcdefgh") 93 | str1 = string(bss) 94 | str2 := "1" 95 | bs0 := String2ByteSlice(str1) 96 | bs0[0] = 'A' 97 | 98 | fmt.Printf("str1:%p,%s\n", &str1, str1) 99 | assert.EqualValues(t, len(str1), len(bs0)) 100 | t.Logf("%s,%s,%d,%p", str1, string(bs0), bs0[0], &bs0) 101 | bs0[0] = 1 102 | 103 | bs1 := String2ByteSlice(str2) 104 | t.Logf("str2:%p,bs1:%v,%p,cap:%d", &str2, bs1, &bs1, cap(bs1)) 105 | } 106 | 107 | type CopyBase0 struct { 108 | ID2 int32 109 | BaseCountry int8 110 | Haha string 111 | } 112 | 113 | type CopyBase struct { 114 | CopyBase0 115 | BaseName string 116 | } 117 | 118 | type CopyBase2 struct { 119 | CopyBase 120 | BaseName2 string 121 | } 122 | 123 | type Int32 int32 124 | 125 | func TestStructCopier(t *testing.T) { 126 | var from = &struct { 127 | ID int64 128 | Name string 129 | Age int32 130 | Address []string 131 | T *TestStruct 132 | CopyBase2 133 | }{ 134 | ID: 1, 135 | Name: "ok", 136 | Age: 32, 137 | Address: []string{"a", "b", "c"}, 138 | T: &TestStruct{ID: 100}, 139 | } 140 | from.BaseName = "b1" 141 | from.BaseName2 = "b2" 142 | from.BaseCountry = 10 143 | from.ID2 = 33 144 | 145 | var to = &struct { 146 | ID2 Int32 147 | Name string 148 | TestStruct 149 | Age int32 150 | Address []string 151 | T
*TestStruct 152 | BaseName string 153 | BaseName2 string 154 | BaseCountry int8 155 | Haha string 156 | }{} 157 | 158 | copier, err := NewStructCopier(from, to) 159 | assert.NoError(t, err) 160 | err = copier(from, to) 161 | assert.NoError(t, err) 162 | assert.Equal(t, from.ID, to.ID) 163 | assert.Equal(t, from.Name, to.Name) 164 | assert.Equal(t, from.Age, to.Age) 165 | assert.Equal(t, from.Address, to.Address) 166 | assert.Equal(t, from.T.ID, to.T.ID) 167 | assert.Equal(t, from.BaseName, to.BaseName) 168 | assert.Equal(t, from.BaseName2, to.BaseName2) 169 | assert.Equal(t, from.BaseCountry, to.BaseCountry) 170 | assert.EqualValues(t, from.ID2, to.ID2) 171 | t.Log(to.Address) 172 | 173 | var to2 int32 174 | err = copier(from, to2) 175 | assert.Error(t, err) 176 | t.Logf("err:%v", err) 177 | var to3 = *to 178 | err = copier(from, to3) 179 | assert.Error(t, err) 180 | t.Logf("err:%v", err) 181 | } 182 | 183 | func TestIsValNil(t *testing.T) { 184 | var i int 185 | assert.False(t, IsValNil(i)) 186 | 187 | var ip *int 188 | assert.True(t, IsValNil(ip)) 189 | 190 | var j float32 191 | assert.False(t, IsValNil(j)) 192 | 193 | var bs []byte 194 | var bsi interface{} = bs 195 | assert.True(t, bs == nil) 196 | assert.False(t, bsi == nil) 197 | assert.True(t, IsValNil(bs)) 198 | 199 | var m map[string]string 200 | assert.True(t, IsValNil(m)) 201 | assert.True(t, m == nil) 202 | 203 | m = map[string]string{"a": "b"} 204 | assert.False(t, m == nil) 205 | assert.False(t, IsValNil(m)) 206 | 207 | var s string 208 | assert.False(t, IsValNil(s)) 209 | 210 | assert.True(t, IsValNil(nil)) 211 | assert.False(t, IsValNil(struct{}{})) 212 | 213 | var f func() 214 | assert.True(t, IsValNil(f)) 215 | 216 | f = func() {} 217 | assert.False(t, IsValNil(f)) 218 | 219 | var c chan int 220 | assert.True(t, IsValNil(c)) 221 | c = make(chan int) 222 | assert.False(t, IsValNil(c)) 223 | 224 | var ii interface{} 225 | assert.True(t, IsValNil(ii)) 226 | ii = "" 227 | assert.False(t, IsValNil(ii))
228 | } 229 | 230 | func TestStringSliceToNumber(t *testing.T) { 231 | var s = []string{"-1", "0", "1", "2", "3"} 232 | ui8, err := StringSlice(s).ToUint8() 233 | assert.Error(t, err) 234 | 235 | s = []string{"355", "0", "1", "2", "3"} 236 | ui8, err = StringSlice(s).ToUint8() 237 | assert.NoError(t, err) 238 | assert.EqualValues(t, []uint8{355 & 0xFF, 0, 1, 2, 3}, ui8) 239 | } 240 | -------------------------------------------------------------------------------- /common/validates.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "fmt" 5 | "regexp" 6 | "strconv" 7 | "strings" 8 | "unicode/utf8" 9 | ) 10 | 11 | // StrValidator validates strings 12 | type StrValidator interface { 13 | //Validate reports whether the string parameter satisfies the rule 14 | Validate(param string) bool 15 | } 16 | 17 | // StringLenValidator validates string length, counted in runes 18 | type StringLenValidator struct { 19 | min int //minimum length 20 | max int //maximum length 21 | } 22 | 23 | // Validate reports whether the rune count of param is within [min, max] 24 | func (p *StringLenValidator) Validate(param string) bool { 25 | strLen := utf8.RuneCountInString(param) 26 | return p.min <= strLen && strLen <= p.max 27 | } 28 | 29 | // NotEmptyValidator requires a non-blank string 30 | type NotEmptyValidator struct { 31 | } 32 | 33 | // Validate reports whether param contains any non-whitespace character 34 | func (p *NotEmptyValidator) Validate(param string) bool { 35 | if len(param) == 0 { 36 | return false 37 | } 38 | return len(strings.TrimSpace(param)) > 0 39 | } 40 | 41 | // Int32Validator validates 32-bit integers 42 | type Int32Validator struct { 43 | min int32 //minimum value 44 | max int32 //maximum value 45 | } 46 | 47 | // Validate reports whether param is a base-10 int32 within [min, max]; an empty param is accepted 48 | func (p *Int32Validator) Validate(param string) bool { 49 | if len(param) == 0 { 50 | return true 51 | } 52 | if v, err := strconv.ParseInt(param, 10, 32); err == nil { 53 | vi32 := int32(v) 54 | return p.min <= vi32 && vi32 <= p.max 55 | } 56 | return false 57 | } 58 | 59 | // Int64Validator validates 64-bit integers 60 | type Int64Validator struct { 61 | min int64 //minimum value 62 | max int64 //maximum value 63 | } 64 | 65 | // Validate reports whether param is a base-10 int64 within [min, max]; an empty param is accepted 66 | func (p
*Int64Validator) Validate(param string) bool { 67 | if len(param) == 0 { 68 | return true 69 | } 70 | if v, err := strconv.ParseInt(param, 10, 64); err == nil { 71 | return p.min <= v && v <= p.max 72 | } 73 | return false 74 | } 75 | 76 | // Float32Validator validates 32-bit floats 77 | type Float32Validator struct { 78 | min float32 //minimum value 79 | max float32 //maximum value 80 | } 81 | 82 | // Validate reports whether param parses as a float32 within [min, max]; an empty param is accepted 83 | func (p *Float32Validator) Validate(param string) bool { 84 | if len(param) == 0 { 85 | return true 86 | } 87 | if v, err := strconv.ParseFloat(param, 32); err == nil { 88 | v32 := float32(v) 89 | return p.min <= v32 && v32 <= p.max 90 | } 91 | return false 92 | } 93 | 94 | // Float64Validator validates 64-bit floats 95 | type Float64Validator struct { 96 | min float64 //minimum value 97 | max float64 //maximum value 98 | } 99 | 100 | // Validate reports whether param parses as a float64 within [min, max]; an empty param is accepted 101 | func (p *Float64Validator) Validate(param string) bool { 102 | if len(param) == 0 { 103 | return true 104 | } 105 | if v, err := strconv.ParseFloat(param, 64); err == nil { 106 | return p.min <= v && v <= p.max 107 | } 108 | return false 109 | } 110 | 111 | // BoolValidator validates booleans 112 | type BoolValidator struct { 113 | } 114 | 115 | // Validate reports whether strconv.ParseBool accepts param; an empty param is accepted 116 | func (p *BoolValidator) Validate(param string) bool { 117 | if len(param) == 0 { 118 | return true 119 | } 120 | if _, err := strconv.ParseBool(param); err == nil { 121 | return true 122 | } 123 | return false 124 | } 125 | 126 | // RegExValidator validates with a regular expression 127 | type RegExValidator struct { 128 | pattern *regexp.Regexp //regular expression 129 | empty bool //whether an empty string is allowed 130 | } 131 | 132 | // Validate reports whether param matches pattern (anywhere in the string unless the pattern is anchored); an empty param is accepted when empty is true 133 | func (p *RegExValidator) Validate(param string) bool { 134 | if param == "" && p.empty { 135 | return true 136 | } 137 | return p.pattern.MatchString(param) 138 | } 139 | 140 | // ParseInt parses a base-10 int 141 | func ParseInt(param string) (v int, err error) { 142 | v, err = strconv.Atoi(param) 143 | return 144 | } 145 | 146 | // ParseInt32 parses a base-10 int32 147 | func ParseInt32(param string) (v int32, err error) { 148 | v64, err :=
strconv.ParseInt(param, 10, 32) 149 | if err == nil { 150 | v = int32(v64) 151 | } 152 | return 153 | } 154 | 155 | // ParseInt64 parses a base-10 int64 156 | func ParseInt64(param string) (v int64, err error) { 157 | v, err = strconv.ParseInt(param, 10, 64) 158 | return 159 | } 160 | 161 | // ParseFloat32 parses a float32 162 | func ParseFloat32(param string) (v float32, err error) { 163 | v64, err := strconv.ParseFloat(param, 32) 164 | if err == nil { 165 | v = float32(v64) 166 | } 167 | return 168 | } 169 | 170 | // ParseFloat64 parses a float64 171 | func ParseFloat64(param string) (v float64, err error) { 172 | v, err = strconv.ParseFloat(param, 64) 173 | return 174 | } 175 | 176 | // ValidatorNewer is the function type that builds a validator from its config 177 | type ValidatorNewer func(conf map[string]string) StrValidator 178 | 179 | // NewNotEmptyValidator returns the shared not-empty validator; conf is ignored 180 | func NewNotEmptyValidator(conf map[string]string) StrValidator { 181 | return vNOTEMPTY 182 | } 183 | 184 | // NewBoolValidator returns the shared bool validator; conf is ignored 185 | func NewBoolValidator(conf map[string]string) StrValidator { 186 | return vBOOL 187 | } 188 | 189 | // NewStrLenValidator builds a string length validator from conf["min"] (minimum) and conf["max"] (maximum); panics if either fails to parse as an int, is negative, or min > max 190 | func NewStrLenValidator(conf map[string]string) StrValidator { 191 | minLen, err := ParseInt(conf["min"]) 192 | if err != nil { 193 | panic(err) 194 | } 195 | maxLen, err := ParseInt(conf["max"]) 196 | if err != nil { 197 | panic(err) 198 | } 199 | if minLen < 0 || maxLen < 0 || minLen > maxLen { 200 | panic(fmt.Errorf("invalid str length,minLen:%v,maxLen:%v", minLen, maxLen)) 201 | } 202 | return &StringLenValidator{min: minLen, max: maxLen} 203 | } 204 | 205 | // NewInt32Validator builds an int32 validator from conf["min"] (minimum) and conf["max"] (maximum); panics on a parse error or min > max 206 | func NewInt32Validator(conf map[string]string) StrValidator { 207 | min, err := ParseInt32(conf["min"]) 208 | if err != nil { 209 | panic(err) 210 | } 211 | max, err := ParseInt32(conf["max"]) 212 | if err != nil { 213 | panic(err) 214 | } 215 | if min > max { 216 | panic(fmt.Errorf("invalid min %d,max %d", min, max)) 217 | } 218 | return
&Int32Validator{min: min, max: max} 219 | 220 | } 221 | 222 | // NewInt64Validator builds an int64 validator from conf["min"] (minimum) and conf["max"] (maximum); panics on a parse error or min > max 223 | func NewInt64Validator(conf map[string]string) StrValidator { 224 | min, err := ParseInt64(conf["min"]) 225 | if err != nil { 226 | panic(err) 227 | } 228 | max, err := ParseInt64(conf["max"]) 229 | if err != nil { 230 | panic(err) 231 | } 232 | if min > max { 233 | panic(fmt.Errorf("invalid min %d,max %d", min, max)) 234 | } 235 | return &Int64Validator{min: min, max: max} 236 | } 237 | 238 | // NewFloat32Validator builds a float32 validator from conf["min"] (minimum) and conf["max"] (maximum); panics on a parse error or min > max 239 | func NewFloat32Validator(conf map[string]string) StrValidator { 240 | min, err := ParseFloat32(conf["min"]) 241 | if err != nil { 242 | panic(err) 243 | } 244 | max, err := ParseFloat32(conf["max"]) 245 | if err != nil { 246 | panic(err) 247 | } 248 | if min > max { 249 | panic(fmt.Errorf("invalid min %f,max %f", min, max)) 250 | } 251 | return &Float32Validator{min: min, max: max} 252 | } 253 | 254 | // NewFloat64Validator builds a float64 validator from conf["min"] (minimum) and conf["max"] (maximum); panics on a parse error or min > max 255 | func NewFloat64Validator(conf map[string]string) StrValidator { 256 | min, err := ParseFloat64(conf["min"]) 257 | if err != nil { 258 | panic(err) 259 | } 260 | max, err := ParseFloat64(conf["max"]) 261 | if err != nil { 262 | panic(err) 263 | } 264 | if min > max { 265 | panic(fmt.Errorf("invalid min %f,max %f", min, max)) 266 | } 267 | return &Float64Validator{min: min, max: max} 268 | } 269 | 270 | // NewRegexValidator builds a regex validator from conf["pattern"]; conf["empty"] equal to "true" (case-insensitive) allows empty input; panics on an empty or invalid pattern 271 | func NewRegexValidator(conf map[string]string) StrValidator { 272 | pattern := conf["pattern"] 273 | allowEmpty := strings.ToLower(conf["empty"]) == "true" 274 | if len(pattern) == 0 { 275 | panic(fmt.Errorf("invalid pattern %s", pattern)) 276 | } 277 | return &RegExValidator{pattern: regexp.MustCompile(pattern), empty: allowEmpty} 278 | } 279 | 280 | // Names of the built-in validator builders 281 | const ( 282 | VNOTEMPTY = "notempty" //takes no build parameters 283 | VBOOL = "bool" 284 | VSTRLEN =
"strlen" 285 | VINT32 = "i32" 286 | VINT64 = "i64" 287 | VFLOAT32 = "f32" 288 | VFLOAT64 = "f64" 289 | VREGEX = "regex" 290 | ) 291 | 292 | var ( 293 | vNOTEMPTY = &NotEmptyValidator{} 294 | vBOOL = &BoolValidator{} 295 | validateRegister = NewCopyOnWriteMap() 296 | ) 297 | 298 | // RegValidatorNewer registers a validator builder under name; panics if PutIfAbsent rejects it (presumably because the name is already registered) 299 | func RegValidatorNewer(name string, validator ValidatorNewer) { 300 | if err := validateRegister.PutIfAbsent(name, validator); err != nil { 301 | panic("Duplicate validator " + err.Error()) 302 | } 303 | } 304 | 305 | // NewValidatorByConf builds a validator with the builder registered under conf["name"], passing conf as its parameters; panics if no builder is registered under that name 306 | func NewValidatorByConf(conf map[string]string) StrValidator { 307 | name := conf["name"] 308 | if f := validateRegister.Get(name); f != nil { 309 | return f.(ValidatorNewer)(conf) 310 | } 311 | panic("Can't find the validator name:" + name) 312 | } 313 | 314 | // init registers the built-in validator builders 315 | func init() { 316 | RegValidatorNewer(VNOTEMPTY, NewNotEmptyValidator) 317 | RegValidatorNewer(VBOOL, NewBoolValidator) 318 | RegValidatorNewer(VSTRLEN, NewStrLenValidator) 319 | RegValidatorNewer(VINT32, NewInt32Validator) 320 | RegValidatorNewer(VINT64, NewInt64Validator) 321 | RegValidatorNewer(VFLOAT32, NewFloat32Validator) 322 | RegValidatorNewer(VFLOAT64, NewFloat64Validator) 323 | RegValidatorNewer(VREGEX, NewRegexValidator) 324 | } 325 | -------------------------------------------------------------------------------- /common/validates_service.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | // RuleConfig configuration of one validation rule 9 | type RuleConfig struct { 10 | Name string 11 | Desc string //rule description, used as the error message when validation fails 12 | Validators []map[string]string //validator configs; each must have a name 13 | } 14 | 15 | // ValidateRuleConfig validation rule configuration 16 | type ValidateRuleConfig struct { 17 | Rules []RuleConfig 18 | SName string //service name 19 | parsed validateRuleMap //result of Parse 20 | } 21 | 22 | // Parse parses the rule configs into validators (a nil receiver only logs a warning); it panics rather than returning an error on an empty rule name or a validator config that cannot be built 23 | func (p *ValidateRuleConfig) Parse() error { 24 |
if p == nil { 25 | Warnf("no validate conf") 26 | return nil 27 | } 28 | rules := make(validateRuleMap) 29 | for _, ruleConfig := range p.Rules { 30 | ruleName := strings.TrimSpace(ruleConfig.Name) 31 | if len(ruleName) == 0 { 32 | panic("The rule name must not be empty") 33 | } 34 | validators := make([]StrValidator, 0, len(ruleConfig.Validators)) 35 | for _, validatorConf := range ruleConfig.Validators { 36 | validators = append(validators, NewValidatorByConf(validatorConf)) 37 | } 38 | rule := &ValidateRule{ 39 | desc: ruleConfig.Desc, 40 | validators: validators} 41 | rules[ruleName] = rule 42 | } 43 | p.parsed = rules 44 | return nil 45 | } 46 | 47 | // NewService creates a RuleValidateService from the parsed rules; panics if Parse has not produced any 48 | func (p *ValidateRuleConfig) NewService() Service { 49 | if p.parsed == nil { 50 | panic("Can't create ValidateService from nil") 51 | } 52 | svr := RuleValidateService{} 53 | svr.SName = p.SName 54 | svr.rules = p.parsed 55 | return &svr 56 | } 57 | 58 | //ValidateRule defines a validation rule 59 | type ValidateRule struct { 60 | desc string //rule description, returned as the error message on failure 61 | validators []StrValidator //validators built from the rule config 62 | } 63 | 64 | type validateRuleMap map[string]*ValidateRule 65 | 66 | // ValidateConfigurer provides the validation rule config 67 | type ValidateConfigurer interface { 68 | GetValidateRuleConfig() *ValidateRuleConfig 69 | } 70 | 71 | // ValidateService validation service 72 | type ValidateService interface { 73 | Service 74 | //Validate validates value with the rule named name; returns nil on success, otherwise the reason 75 | Validate(name string, value string) error 76 | } 77 | 78 | // RuleValidateService validation service driven by configured rules 79 | type RuleValidateService struct { 80 | BaseService 81 | Config ValidateConfigurer `inject:"_"` 82 | rules validateRuleMap 83 | } 84 | 85 | // Init implements Initable; it adopts the service name and parsed rules from Config. NOTE(review): rules stay nil if Parse was never called on the config 86 | func (p *RuleValidateService) Init() error { 87 | if p.Config == nil { 88 | Warnf("no validate config") 89 | return nil 90 | } 91 | if p.Config.GetValidateRuleConfig() == nil { 92 | Warnf("no validate config rule") 93 | return nil 94 | } 95 | config := p.Config.GetValidateRuleConfig() 96
| p.SName = config.SName 97 | p.rules = config.parsed 98 | return nil 99 | } 100 | 101 | // Validate checks s against every validator of the named rule; returns a ValidateError carrying the rule description on the first failure, or a not-found ValidateError for an unknown rule 102 | func (p *RuleValidateService) Validate(ruleName string, s string) error { 103 | rule := p.rules[ruleName] 104 | if rule == nil { 105 | err := NewValidateError(fmt.Sprintf("can't find validate rule %s", ruleName)) 106 | err.notFoundRule = true 107 | return err 108 | } 109 | 110 | for _, v := range rule.validators { 111 | if !v.Validate(s) { 112 | return NewValidateError(rule.desc) 113 | } 114 | } 115 | return nil 116 | } 117 | 118 | // ValidatePair pairs a validation rule name with the value to validate and an optional override message 119 | type ValidatePair struct { 120 | Name string 121 | Value string 122 | Msg string 123 | } 124 | 125 | // NewValidatePair create ValidatePair 126 | func NewValidatePair(name, value string) *ValidatePair { 127 | return &ValidatePair{Name: name, Value: value} 128 | } 129 | 130 | // NewValidatePairMsg create ValidatePair with msg 131 | func NewValidatePairMsg(name, value, msg string) *ValidatePair { 132 | return &ValidatePair{Name: name, Value: value, Msg: msg} 133 | } 134 | 135 | // ValidateAll validates each pair in order and returns the first error; a non-empty Msg replaces the error message except when the rule itself is missing 136 | func ValidateAll(validateService ValidateService, nameAndValues ...*ValidatePair) error { 137 | for _, nv := range nameAndValues { 138 | if err := validateService.Validate(nv.Name, nv.Value); err != nil { 139 | if verr, ok := err.(*ValidateError); ok && verr.notFoundRule { 140 | return err 141 | } 142 | if nv.Msg != "" { 143 | return NewValidateError(nv.Msg) 144 | } 145 | return err 146 | } 147 | } 148 | return nil 149 | } 150 | 151 | // ValidateError is the error returned by rule validation 152 | type ValidateError struct { 153 | //error message 154 | msg string 155 | //whether the rule was not found 156 | notFoundRule bool 157 | } 158 | 159 | // NewValidateError creates a ValidateError with msg 160 | func NewValidateError(msg string) *ValidateError { 161 | return &ValidateError{msg: msg} 162 | } 163 | 164 | func (p *ValidateError) Error() string { 165 | if p == nil { 166 | return "" 167 | } 168 | return p.msg 169 | } 170 | 171 | // Human implements HumanError.Human; reports true for any non-nil *ValidateError 172 | func (p
*ValidateError) Human() bool { 173 | if p == nil { 174 | return false 175 | } 176 | return true 177 | } 178 | -------------------------------------------------------------------------------- /common/validates_test.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "regexp" 6 | "testing" 7 | ) 8 | 9 | func TestNotEmpty(t *testing.T) { 10 | va := &NotEmptyValidator{} 11 | assert.False(t, va.Validate("")) 12 | assert.False(t, va.Validate(" ")) 13 | assert.False(t, va.Validate(" ")) 14 | assert.True(t, va.Validate(" abc ")) 15 | assert.True(t, va.Validate(" a ")) 16 | } 17 | 18 | func TestInteger(t *testing.T) { 19 | va32 := &Int32Validator{ 20 | min: -3, 21 | max: 10} 22 | assert.False(t, va32.Validate("a")) 23 | assert.False(t, va32.Validate("11")) 24 | assert.True(t, va32.Validate("10")) 25 | assert.True(t, va32.Validate("-3")) 26 | 27 | va64 := &Int32Validator{ // NOTE(review): named va64 but builds an Int32Validator; Int64Validator was probably intended 28 | min: -3, 29 | max: 10} 30 | assert.False(t, va64.Validate("a")) 31 | assert.False(t, va64.Validate("11")) 32 | assert.True(t, va64.Validate("10")) 33 | assert.True(t, va64.Validate("-3")) 34 | } 35 | 36 | func TestFloat(t *testing.T) { 37 | va32 := &Float32Validator{ 38 | min: 0.1, 39 | max: 10} 40 | assert.False(t, va32.Validate("a")) 41 | assert.False(t, va32.Validate("11")) 42 | assert.True(t, va32.Validate("10")) 43 | assert.False(t, va32.Validate("-3")) 44 | 45 | va64 := &Float64Validator{ 46 | min: -3, 47 | max: 10.5} 48 | assert.False(t, va64.Validate("a")) 49 | assert.False(t, va64.Validate("11")) 50 | assert.True(t, va64.Validate("10.4")) 51 | assert.True(t, va64.Validate("-3")) 52 | } 53 | 54 | func TestRegex(t *testing.T) { 55 | rv := &RegExValidator{ 56 | pattern: regexp.MustCompile("^a+")} 57 | 58 | assert.True(t, rv.Validate("a")) 59 | assert.False(t, rv.Validate("1a")) 60 | } 61 | -------------------------------------------------------------------------------- /go.mod:
-------------------------------------------------------------------------------- 1 | module github.com/d0ngw/go 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.2 6 | 7 | require ( 8 | github.com/go-sql-driver/mysql v1.6.0 9 | github.com/gomodule/redigo v1.8.8 10 | github.com/json-iterator/go v1.1.12 11 | github.com/stretchr/testify v1.7.0 12 | github.com/ugorji/go/codec v1.2.7 13 | go.uber.org/zap v1.21.0 14 | golang.org/x/net v0.38.0 15 | gopkg.in/natefinch/lumberjack.v2 v2.0.0 16 | gopkg.in/yaml.v3 v3.0.1 17 | ) 18 | 19 | require ( 20 | github.com/BurntSushi/toml v1.0.0 // indirect 21 | github.com/davecgh/go-spew v1.1.1 // indirect 22 | github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect 23 | github.com/modern-go/reflect2 v1.0.2 // indirect 24 | github.com/pmezard/go-difflib v1.0.0 // indirect 25 | go.uber.org/atomic v1.7.0 // indirect 26 | go.uber.org/multierr v1.6.0 // indirect 27 | ) 28 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/BurntSushi/toml v1.0.0 h1:dtDWrepsVPfW9H/4y7dDgFc2MBUSeJhlaDtK13CxFlU= 2 | github.com/BurntSushi/toml v1.0.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= 3 | github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= 4 | github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= 5 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 7 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 8 | github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= 9 | github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= 10 | github.com/gomodule/redigo v1.8.8 
h1:f6cXq6RRfiyrOJEV7p3JhLDlmawGBVBBP1MggY8Mo4E= 11 | github.com/gomodule/redigo v1.8.8/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= 12 | github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= 13 | github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= 14 | github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= 15 | github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= 16 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 17 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 18 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 19 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 20 | github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= 21 | github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 22 | github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= 23 | github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= 24 | github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= 25 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 26 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 27 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 28 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 29 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 30 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 31 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 
32 | github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= 33 | github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= 34 | github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= 35 | github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= 36 | go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= 37 | go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= 38 | go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= 39 | go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= 40 | go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= 41 | go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= 42 | go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8= 43 | go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= 44 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 45 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 46 | golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= 47 | golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 48 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 49 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 50 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 51 | golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= 52 | golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= 53 | golang.org/x/net 
v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= 54 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 55 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 56 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 57 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 58 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 59 | golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 60 | golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 61 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 62 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 63 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 64 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 65 | golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 66 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 67 | golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= 68 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 69 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 70 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 71 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod 
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 72 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= 73 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 74 | gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= 75 | gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= 76 | gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= 77 | gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 78 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 79 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 80 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 81 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 82 | -------------------------------------------------------------------------------- /http/controller.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "reflect" 7 | "unicode" 8 | ) 9 | 10 | // Controller 接口定义http处理器 11 | type Controller interface { 12 | // GetName 控制器的名称 13 | GetName() string 14 | // GetPath 路径前缀,以'/'结束,同一个控制下的http.Handler都 15 | GetPath() string 16 | // GetHandlerMiddlewares 返回controller的处理方法中,需要增加middleware封装的方法,key是controller中的方法名 17 | GetHandlerMiddlewares() map[string][]Middleware 18 | } 19 | 20 | // BaseController 表示一个控制器 21 | type BaseController struct { 22 | Name string // Controller的名称 23 | Path string // Controller的路径 24 | HandlerMiddlewares map[string][]Middleware // Controller中需要使用middleware封装的方法 25 | } 26 | 27 | // GetName controller的名称 28 | func (p *BaseController) GetName() string { 29 | return p.Name 30 | } 31 | 32 | // GetPath controller的path 33 | 
func (p *BaseController) GetPath() string { 34 | return p.Path 35 | } 36 | 37 | // GetHandlerMiddlewares handler的middleware 38 | func (p *BaseController) GetHandlerMiddlewares() map[string][]Middleware { 39 | return p.HandlerMiddlewares 40 | } 41 | 42 | var ( 43 | m http.HandlerFunc 44 | t = reflect.TypeOf(m) 45 | ) 46 | 47 | type handlerWithMiddleware struct { 48 | handlerFunc http.HandlerFunc 49 | middlewares []Middleware 50 | } 51 | 52 | // ReflectHandlers 查找controller中类型为http.HandlerFunc的可导出方法,并将驼峰命名改为下划线分隔的路径 53 | // 例如Index -> index,GetUser -> get_user 54 | func reflectHandlers(controller Controller) (handlers map[string]*handlerWithMiddleware, err error) { 55 | val := reflect.ValueOf(controller) 56 | if !val.IsValid() || val.Kind() != reflect.Ptr { 57 | return nil, fmt.Errorf("controller must be a valid pointer") 58 | } 59 | 60 | // 检查方法是否存在 61 | hm := controller.GetHandlerMiddlewares() 62 | if len(hm) > 0 { 63 | for name := range hm { 64 | if found := val.MethodByName(name); !found.IsValid() { 65 | return nil, fmt.Errorf("Can't find method name %s for middlewares", name) 66 | } 67 | } 68 | } 69 | 70 | handlers = map[string]*handlerWithMiddleware{} 71 | methodCount := val.NumMethod() 72 | controllerType := val.Type() 73 | for i := 0; i < methodCount; i++ { 74 | methodVal := val.Method(i) 75 | methodValType := methodVal.Type() 76 | method := controllerType.Method(i) 77 | 78 | if methodValType.AssignableTo(t) { 79 | var fn http.HandlerFunc = methodVal.Interface().(func(http.ResponseWriter, *http.Request)) 80 | hmiddle := &handlerWithMiddleware{handlerFunc: fn} 81 | if middlewares, ok := hm[method.Name]; ok { 82 | hmiddle.middlewares = middlewares 83 | } 84 | handlers[ToUnderlineName(method.Name)] = hmiddle 85 | } 86 | } 87 | return handlers, nil 88 | } 89 | 90 | // ToUnderlineName 将驼峰命名改为小写的下划线命名 91 | func ToUnderlineName(camelName string) string { 92 | nameRune := []rune(camelName) 93 | normalizeName := make([]rune, 0, len(nameRune)) 94 | 95 | for ni := 0; ni 
< len(nameRune); ni++ { 96 | if ni != 0 && unicode.IsUpper(nameRune[ni]) && !unicode.IsUpper(nameRune[ni-1]) { 97 | normalizeName = append(normalizeName, '_') 98 | } 99 | 100 | r := nameRune[ni] 101 | if unicode.IsUpper(nameRune[ni]) { 102 | r = unicode.ToLower(r) 103 | } 104 | normalizeName = append(normalizeName, r) 105 | } 106 | return string(normalizeName) 107 | } 108 | -------------------------------------------------------------------------------- /http/controller_test.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type DemoController struct { 12 | BaseController 13 | } 14 | 15 | func (p *DemoController) Index(w http.ResponseWriter, r *http.Request) { 16 | fmt.Println("Index:", p.Path, p.Name) 17 | } 18 | 19 | func (p *DemoController) Second(w http.ResponseWriter, r *http.Request) { 20 | fmt.Println("Second:", p.Path, p.Name) 21 | } 22 | 23 | func TestReflectHandlers(t *testing.T) { 24 | testReflectHandlers(t, "demo1") 25 | testReflectHandlers(t, "demo2") 26 | } 27 | 28 | type LogMiddleware struct { 29 | Order int 30 | } 31 | 32 | func NewLogMiddleware(order int) *LogMiddleware { 33 | return &LogMiddleware{ 34 | Order: order, 35 | } 36 | } 37 | 38 | func (p *LogMiddleware) Handle(next MiddlewareFunc) MiddlewareFunc { 39 | return func(w http.ResponseWriter, r *http.Request) { 40 | fmt.Println("Begin process,order", p.Order) 41 | next(w, r) 42 | fmt.Println("Finish process,order", p.Order) 43 | } 44 | } 45 | 46 | func testReflectHandlers(t *testing.T, name string) { 47 | controller := &DemoController{ 48 | BaseController: BaseController{ 49 | Name: name, 50 | Path: "/" + name, 51 | HandlerMiddlewares: map[string][]Middleware{ 52 | "Index": { 53 | &LogMiddleware{Order: 0}, 54 | &LogMiddleware{Order: 1}, 55 | &LogMiddleware{Order: 2}, 56 | }, 57 | }, 58 | }, 59 | } 60 | 61 | mapping, err := 
reflectHandlers(controller) 62 | assert.Nil(t, err, "err") 63 | assert.EqualValues(t, 2, len(mapping)) 64 | 65 | mapping["index"].handlerFunc(nil, nil) 66 | mapping["second"].handlerFunc(nil, nil) 67 | 68 | mapping, err = reflectHandlers(controller) 69 | assert.Nil(t, err, "err") 70 | assert.EqualValues(t, 2, len(mapping)) 71 | mapping["second"].handlerFunc(nil, nil) 72 | mapping["index"].handlerFunc(nil, nil) 73 | } 74 | 75 | func TestToUnderlineName(t *testing.T) { 76 | assert.EqualValues(t, "index", ToUnderlineName("index")) 77 | assert.EqualValues(t, "index", ToUnderlineName("INDEX")) 78 | assert.EqualValues(t, "index", ToUnderlineName("Index")) 79 | assert.EqualValues(t, "in_dex", ToUnderlineName("InDex")) 80 | assert.EqualValues(t, "in_dex", ToUnderlineName("InDEX")) 81 | assert.EqualValues(t, "in_dex", ToUnderlineName("InDEx")) 82 | assert.EqualValues(t, "in_de_x", ToUnderlineName("InDeX")) 83 | assert.EqualValues(t, "in语言_de_x", ToUnderlineName("In语言DeX")) 84 | assert.EqualValues(t, "", ToUnderlineName("")) 85 | assert.EqualValues(t, "h5_ware_path", ToUnderlineName("H5WarePath")) 86 | } 87 | -------------------------------------------------------------------------------- /http/cookiejar.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "net/http" 5 | "net/http/cookiejar" 6 | "net/url" 7 | "sync" 8 | 9 | c "github.com/d0ngw/go/common" 10 | ) 11 | 12 | // RetrivedCookieJar 可持久化的Cookie 13 | type RetrivedCookieJar struct { 14 | jar *cookiejar.Jar 15 | urlCookies map[string][]*http.Cookie 16 | mu sync.Mutex 17 | } 18 | 19 | // NewRetrivedCookieJar 构建PersistCookieJar 20 | func NewRetrivedCookieJar(o *cookiejar.Options) *RetrivedCookieJar { 21 | jar, _ := cookiejar.New(o) 22 | return &RetrivedCookieJar{ 23 | jar: jar, 24 | urlCookies: map[string][]*http.Cookie{}, 25 | } 26 | } 27 | 28 | // SetCookies implements CookieJar.SetCookies 29 | func (p *RetrivedCookieJar) SetCookies(u *url.URL, 
cookies []*http.Cookie) { 30 | p.jar.SetCookies(u, cookies) 31 | cookieURL := u.String() 32 | if u != nil && cookies != nil { 33 | p.mu.Lock() 34 | p.urlCookies[cookieURL] = append(p.urlCookies[cookieURL], cookies...) 35 | p.mu.Unlock() 36 | } 37 | } 38 | 39 | // Cookies implements CookeJar.Cookies 40 | func (p *RetrivedCookieJar) Cookies(u *url.URL) []*http.Cookie { 41 | return p.jar.Cookies(u) 42 | } 43 | 44 | // URLAndCookies 取得所有的URL和Cookie 45 | func (p *RetrivedCookieJar) URLAndCookies() map[string][]*http.Cookie { 46 | all := map[string][]*http.Cookie{} 47 | p.mu.Lock() 48 | defer p.mu.Unlock() 49 | for k, v := range p.urlCookies { 50 | all[k] = v 51 | } 52 | return all 53 | } 54 | 55 | // SetURLAndCookies 设置所有的URL和Cookie 56 | func (p *RetrivedCookieJar) SetURLAndCookies(all map[string][]*http.Cookie) error { 57 | if all == nil { 58 | return nil 59 | } 60 | 61 | for u, cookies := range all { 62 | if cookies == nil { 63 | continue 64 | } 65 | 66 | cookieURL, err := url.Parse(u) 67 | if err != nil { 68 | c.Errorf("parse %s fail,err:%s", u, err) 69 | continue 70 | } 71 | p.SetCookies(cookieURL, cookies) 72 | } 73 | 74 | return nil 75 | } 76 | -------------------------------------------------------------------------------- /http/cookiejar_test.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/url" 7 | "testing" 8 | ) 9 | 10 | func TestCookieJar(t *testing.T) { 11 | jar := NewRetrivedCookieJar(nil) 12 | u, _ := url.Parse("https://www.google.com/tt") 13 | cookie1 := &http.Cookie{Name: "c", Value: "d"} 14 | cookie2 := &http.Cookie{Name: "e", Value: "f"} 15 | jar.SetCookies(u, []*http.Cookie{cookie1, cookie2}) 16 | 17 | u2, _ := url.Parse("https://www.google.com/t2") 18 | jar.SetCookies(u2, []*http.Cookie{cookie1, cookie2}) 19 | 20 | all := jar.URLAndCookies() 21 | for u, cookies := range all { 22 | fmt.Printf("u:%s", u) 23 | for _, c := range cookies { 24 | 
fmt.Printf(" cookie:%s", c) 25 | } 26 | fmt.Println() 27 | } 28 | 29 | jar = NewRetrivedCookieJar(nil) 30 | jar.SetURLAndCookies(all) 31 | all = jar.URLAndCookies() 32 | for u, cookies := range all { 33 | fmt.Printf("u:%s", u) 34 | for _, c := range cookies { 35 | fmt.Printf(" cookie:%s", c) 36 | } 37 | fmt.Println() 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /http/filesystem.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "net/http" 5 | "os" 6 | ) 7 | 8 | // NoDirFS 不输出目录列表的FS 9 | type NoDirFS struct { 10 | Fs http.FileSystem 11 | } 12 | 13 | // Open 取得指定的文件,如果name指向的是个目录,且目录下没有index.html,返回os.ErrPermission 14 | func (fs NoDirFS) Open(name string) (http.File, error) { 15 | f, err := fs.Fs.Open(name) 16 | if err != nil { 17 | return nil, err 18 | } 19 | stat, err := f.Stat() 20 | if err != nil { 21 | return nil, err 22 | } 23 | if stat.IsDir() { 24 | index, err := fs.Fs.Open(name + "/index.html") 25 | if err == nil { 26 | index.Close() 27 | return f, nil 28 | } 29 | f.Close() 30 | return nil, os.ErrPermission 31 | } 32 | return f, nil 33 | } 34 | -------------------------------------------------------------------------------- /http/http.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "net" 5 | "net/http" 6 | "strings" 7 | "sync" 8 | "time" 9 | 10 | c "github.com/d0ngw/go/common" 11 | "golang.org/x/net/netutil" 12 | ) 13 | 14 | type tcpKeepAliveListener struct { 15 | *net.TCPListener 16 | } 17 | 18 | // Accept接受连接 19 | func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { 20 | tc, err := ln.AcceptTCP() 21 | if err != nil { 22 | return nil, err 23 | } 24 | if err = tc.SetKeepAlive(true); err != nil { 25 | return 26 | } 27 | if err = tc.SetKeepAlivePeriod(3 * time.Minute); err != nil { 28 | return 29 | } 30 | return tc, nil 31 | } 32 | 33 | // 
GraceableHandler 安全地关闭的处理器 34 | type GraceableHandler struct { 35 | handler http.Handler 36 | waitGroup *sync.WaitGroup 37 | } 38 | 39 | func (p *GraceableHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 40 | p.waitGroup.Add(1) 41 | defer p.waitGroup.Done() 42 | 43 | p.handler.ServeHTTP(w, r) 44 | } 45 | 46 | // Service Http服务 47 | type Service struct { 48 | c.BaseService 49 | Conf *Config 50 | listener net.Listener 51 | serveMux *http.ServeMux 52 | graceHandler *GraceableHandler 53 | server *http.Server 54 | lock sync.Mutex 55 | } 56 | 57 | // Init 初始化Http服务 58 | func (p *Service) Init() bool { 59 | p.lock.Lock() 60 | defer p.lock.Unlock() 61 | 62 | serveMux := http.NewServeMux() 63 | 64 | for pattern, handler := range p.Conf.handles { 65 | if handler == nil { 66 | c.Errorf("Can't bind nil handlerFunc to path %s", pattern) 67 | return false 68 | } 69 | serveMux.Handle(pattern, p.handleWithMiddleware(handler)) 70 | } 71 | 72 | graceHandler := &GraceableHandler{ 73 | handler: serveMux, 74 | waitGroup: &sync.WaitGroup{}} 75 | 76 | server := &http.Server{ 77 | Addr: p.Conf.Addr, 78 | ReadTimeout: p.Conf.ReadTimeout * time.Second, 79 | WriteTimeout: p.Conf.WriteTimeout * time.Second, 80 | Handler: graceHandler} 81 | 82 | if p.Conf.Addr == "" { 83 | p.Conf.Addr = ":http" 84 | } 85 | 86 | p.graceHandler = graceHandler 87 | p.server = server 88 | p.serveMux = serveMux 89 | return true 90 | } 91 | 92 | // handleWithMiddleware 依次调用各个middleware 93 | func (p *Service) handleWithMiddleware(handler *handlerWithMiddleware) http.HandlerFunc { 94 | originHandler := func(w http.ResponseWriter, r *http.Request) { 95 | if ok, err := ErrorFromRequestContext(r); ok { 96 | c.Errorf("stop handle %s,cause by error:%s", r.RequestURI, err) 97 | } else { 98 | handler.handlerFunc(w, r) 99 | } 100 | } 101 | 102 | var middlewares = append(handler.middlewares, p.Conf.middlewares...) 
103 | var middlewareCount = len(middlewares) 104 | 105 | h := originHandler 106 | for i := middlewareCount - 1; i >= 0; i-- { 107 | m := middlewares[i] 108 | h0 := h 109 | h = func(w http.ResponseWriter, r *http.Request) { 110 | m.Handle(h0)(w, r) 111 | } 112 | } 113 | 114 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 115 | h(w, r) 116 | }) 117 | } 118 | 119 | // Start 启动Http服务,开始端口监听和服务处理 120 | func (p *Service) Start() bool { 121 | p.lock.Lock() 122 | defer p.lock.Unlock() 123 | 124 | c.Infof("Listen at %s", p.Conf.Addr) 125 | ln, err := net.Listen("tcp", p.Conf.Addr) 126 | if err != nil { 127 | c.Errorf("Listen at %s fail,error:%v", p.Conf.Addr, err) 128 | return false 129 | } 130 | 131 | tcpListener := tcpKeepAliveListener{ln.(*net.TCPListener)} 132 | if p.Conf.MaxConns > 0 { 133 | p.listener = netutil.LimitListener(tcpListener, p.Conf.MaxConns) 134 | } else { 135 | p.listener = tcpListener 136 | } 137 | 138 | p.graceHandler.waitGroup.Add(1) 139 | 140 | go func() { 141 | defer p.graceHandler.waitGroup.Done() 142 | err := p.server.Serve(p.listener) 143 | if err != nil { 144 | var errLevel = c.Error 145 | if strings.Contains(err.Error(), "use of closed network connection") { 146 | errLevel = c.Warn 147 | } 148 | c.Logf(errLevel, "server.Serve return with %v", err) 149 | } 150 | }() 151 | return true 152 | } 153 | 154 | // Stop 停止Http服务,关闭端口监听和服务处理 155 | func (p *Service) Stop() bool { 156 | p.lock.Lock() 157 | defer p.lock.Unlock() 158 | 159 | if p.listener != nil { 160 | if err := p.listener.Close(); err != nil { 161 | c.Errorf("Close listener error:%v", err) 162 | } 163 | } 164 | 165 | //等待所有的服务 166 | c.Infof("Waiting shutdown") 167 | p.graceHandler.waitGroup.Wait() 168 | c.Infof("Finish shutdown") 169 | 170 | p.listener = nil 171 | p.graceHandler = nil 172 | p.server = nil 173 | p.serveMux = nil 174 | return true 175 | } 176 | -------------------------------------------------------------------------------- /http/http_config.go: 
-------------------------------------------------------------------------------- 1 | // Package http 提供基本的http服务 2 | package http 3 | 4 | import ( 5 | "fmt" 6 | "net/http" 7 | "strings" 8 | "sync" 9 | "time" 10 | 11 | c "github.com/d0ngw/go/common" 12 | "github.com/d0ngw/go/inject" 13 | ) 14 | 15 | // Config Http配置 16 | type Config struct { 17 | Addr string //Http监听地址 18 | ReadTimeout time.Duration //读超时,单位秒 19 | WriteTimeout time.Duration //写超时,单位秒 20 | MaxConns int //最大的并发连接数 21 | middlewares []Middleware //过滤操作 22 | controllers []Controller //controller 23 | handles map[string]*handlerWithMiddleware //handles 24 | controllerMux sync.RWMutex 25 | } 26 | 27 | // NewConfig 创建配置 28 | func NewConfig(addr string) *Config { 29 | return &Config{ 30 | Addr: addr, 31 | handles: map[string]*handlerWithMiddleware{}, 32 | middlewares: []Middleware{}, 33 | controllers: []Controller{}, 34 | } 35 | } 36 | 37 | // RegController 注册controller中的所有处理函数 38 | func (p *Config) RegController(controller Controller) error { 39 | if controller == nil { 40 | return fmt.Errorf("Can't reg nil contriller") 41 | } 42 | 43 | p.controllers = append(p.controllers, controller) 44 | 45 | var path = controller.GetPath() 46 | if !strings.HasSuffix(path, "/") { 47 | path += "/" 48 | } 49 | 50 | p.controllerMux.Lock() 51 | defer p.controllerMux.Unlock() 52 | 53 | handlers, err := reflectHandlers(controller) 54 | if err != nil { 55 | return err 56 | } 57 | 58 | if len(handlers) == 0 { 59 | c.Warnf("Can't find handler in %#v", controller) 60 | return nil 61 | } 62 | 63 | for handlerPath, h := range handlers { 64 | if strings.HasPrefix(handlerPath, "/") { 65 | handlerPath = handlerPath[1:] 66 | } 67 | 68 | patternPath := path + handlerPath 69 | if err := p.regHandleFunc(patternPath, h); err != nil { 70 | return err 71 | } 72 | } 73 | return nil 74 | } 75 | 76 | // RegHandleFunc 注册patternPath的处理函数handlerFunc 77 | func (p *Config) regHandleFunc(patternPath string, handle *handlerWithMiddleware) error { 78 | 
if _, ok := p.handles[patternPath]; ok { 79 | return fmt.Errorf("Duplicate ,path:%s", patternPath) 80 | } 81 | p.handles[patternPath] = handle 82 | return nil 83 | } 84 | 85 | // RegHandleFunc 注册patternPath的处理函数handlerFunc 86 | func (p *Config) RegHandleFunc(patternPath string, handlerFunc http.HandlerFunc) error { 87 | if _, ok := p.handles[patternPath]; ok { 88 | return fmt.Errorf("Duplicate ,path:%s", patternPath) 89 | } 90 | p.handles[patternPath] = &handlerWithMiddleware{handlerFunc, nil} 91 | return nil 92 | } 93 | 94 | // RegStaticFunc 注册静态资源patternPath的处理函数handlerFunc 95 | func (p *Config) RegStaticFunc(patternAndPath map[string]string) error { 96 | if patternAndPath == nil { 97 | return nil 98 | } 99 | for pattern, path := range patternAndPath { 100 | httpDir := http.Dir(path) 101 | fs := http.FileServer(NoDirFS{Fs: httpDir}) 102 | c.Infof("add static %s to %s", pattern, httpDir) 103 | err := p.RegHandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) { 104 | fs.ServeHTTP(w, r) 105 | }) 106 | if err != nil { 107 | return err 108 | } 109 | } 110 | return nil 111 | } 112 | 113 | // RegMiddleware 注册middleware,middleware的注册需要在RegController和RegHandleFunc之前完成 114 | func (p *Config) RegMiddleware(middleware Middleware) error { 115 | if middleware == nil { 116 | return fmt.Errorf("invalid middleware") 117 | } 118 | p.middlewares = append(p.middlewares, middleware) 119 | return nil 120 | } 121 | 122 | // InitWithInjector 初始化操作 123 | func (p *Config) InitWithInjector(injector *inject.Injector) error { 124 | for _, c := range p.controllers { 125 | injector.RequireInject(c) 126 | } 127 | for _, m := range p.middlewares { 128 | injector.RequireInject(m) 129 | } 130 | 131 | controllers := injector.GetInstancesByPrototype(struct{ s Controller }{}) 132 | for _, cer := range controllers { 133 | controller := cer.(Controller) 134 | if err := p.RegController(controller); err != nil { 135 | return fmt.Errorf("Reg controller %s fail,err:%s", controller.GetName(), 
err) 136 | } 137 | for _, middlewares := range controller.GetHandlerMiddlewares() { 138 | for _, m := range middlewares { 139 | injector.RequireInject(m) 140 | } 141 | } 142 | } 143 | return nil 144 | } 145 | -------------------------------------------------------------------------------- /http/http_context.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | ) 7 | 8 | type key int 9 | 10 | const ( 11 | errorKey key = 0 // 处理错误的key 12 | ) 13 | 14 | // RequestWithContext 向req的context中设置key = val,返回新的request 15 | func RequestWithContext(req *http.Request, key, val interface{}) *http.Request { 16 | ctx := req.Context() 17 | ctx = context.WithValue(ctx, key, val) 18 | return req.WithContext(ctx) 19 | } 20 | 21 | // FromRequestContext 从req的context中取得key值 22 | func FromRequestContext(req *http.Request, key interface{}) interface{} { 23 | return req.Context().Value(key) 24 | } 25 | 26 | // RequestWithError 向req中设置当前处理的错误,返回新的request 27 | func RequestWithError(req *http.Request, err error) *http.Request { 28 | return RequestWithContext(req, errorKey, err) 29 | } 30 | 31 | // ErrorFromRequestContext 从req的context取得错误值 32 | func ErrorFromRequestContext(req *http.Request) (bool, error) { 33 | err, ok := req.Context().Value(errorKey).(error) 34 | return ok, err 35 | } 36 | -------------------------------------------------------------------------------- /http/http_util_test.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "net/url" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | //PageParam 分野参数 11 | type PageParam struct { 12 | Page int 13 | PageSize int 14 | } 15 | 16 | type Int int32 17 | 18 | type params struct { 19 | PageParam 20 | ID int64 21 | Name string 22 | Weight float32 23 | Height float32 `pname:"h"` 24 | Ok bool 25 | Children []string `pname:"_"` 26 | Ages []int32 
`psep:","` 27 | FriendsNames []string 28 | FriendsBooks []int64 `psep:","` 29 | FriendsWeights []float32 `psep:","` 30 | Years Int 31 | H5WarePath string 32 | Ints []int `psep:","` 33 | Int8s []int8 `psep:","` 34 | Int16s []int16 `psep:","` 35 | Errs []int `psep:","` 36 | } 37 | 38 | func TestParseParams(t *testing.T) { 39 | form := url.Values{} 40 | form.Set("id", "10") 41 | form.Set("name", "golang") 42 | form.Set("weight", "1.230") 43 | form.Set("h", "1.01") 44 | form.Set("ok", "true") 45 | form.Set("ages", "1,2,3") 46 | form["friends_names"] = []string{"tom", "jerry"} 47 | form.Set("friends_books", "1,2") 48 | form.Set("friends_weights", "0.1,0,-0.3") 49 | form.Set("page", "1") 50 | form.Set("page_size", "5") 51 | form.Set("years", "10") 52 | form.Set("h5_ware_path", "https://example.com") 53 | form.Set("ints", "1,2,3") 54 | form.Set("int8s", "4,5,6") 55 | form.Set("int16s", "7,8,9") 56 | form.Set("errs", "") 57 | 58 | p := ¶ms{} 59 | err := ParseParams(form, p) 60 | assert.Nil(t, err) 61 | t.Logf("%#v\n", p) 62 | assert.EqualValues(t, 10, p.ID) 63 | assert.EqualValues(t, "golang", p.Name) 64 | assert.EqualValues(t, 1.23, p.Weight) 65 | assert.EqualValues(t, 1.01, p.Height) 66 | assert.EqualValues(t, []int32{1, 2, 3}, p.Ages) 67 | assert.EqualValues(t, []string{"tom", "jerry"}, p.FriendsNames) 68 | assert.EqualValues(t, []float32{0.1, 0.0, -0.3}, p.FriendsWeights) 69 | assert.EqualValues(t, 1, p.Page) 70 | assert.EqualValues(t, 5, p.PageSize) 71 | assert.EqualValues(t, 10, p.Years) 72 | t.Logf("years %d", p.Years) 73 | assert.EqualValues(t, "https://example.com", p.H5WarePath) 74 | assert.EqualValues(t, []int{1, 2, 3}, p.Ints) 75 | assert.EqualValues(t, []int8{4, 5, 6}, p.Int8s) 76 | assert.EqualValues(t, []int16{7, 8, 9}, p.Int16s) 77 | assert.EqualValues(t, []int{}, p.Errs) 78 | } 79 | -------------------------------------------------------------------------------- /http/middleware.go: 
-------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | ) 7 | 8 | // MiddlewareFunc is the handler function type wrapped by a Middleware 9 | type MiddlewareFunc func(http.ResponseWriter, *http.Request) 10 | 11 | // Middleware wraps a MiddlewareFunc to produce another MiddlewareFunc 12 | type Middleware interface { 13 | // Handle returns a MiddlewareFunc that wraps next 14 | Handle(next MiddlewareFunc) MiddlewareFunc 15 | } 16 | 17 | // RequestMehotdMiddleware restricts which HTTP request methods are accepted 18 | type RequestMehotdMiddleware struct { 19 | // AllowsMethods is the set of accepted methods, matched against r.Method as-is 20 | AllowsMethods map[string]struct{} 21 | } 22 | 23 | // Handle rejects requests whose method is not a key of AllowsMethods with 405 Method Not Allowed; otherwise it calls next 24 | func (p *RequestMehotdMiddleware) Handle(next MiddlewareFunc) MiddlewareFunc { 25 | return func(w http.ResponseWriter, r *http.Request) { 26 | if _, ok := p.AllowsMethods[r.Method]; !ok { 27 | http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) 28 | return 29 | } 30 | next(w, r) 31 | } 32 | } 33 | 34 | // NewRequestMetodMiddleware builds the middleware from methods, upper-casing each one 35 | func NewRequestMetodMiddleware(methods ...string) *RequestMehotdMiddleware { 36 | m := &RequestMehotdMiddleware{ 37 | AllowsMethods: map[string]struct{}{}, 38 | } 39 | for _, method := range methods { 40 | m.AllowsMethods[strings.ToUpper(method)] = struct{}{} 41 | } 42 | return m 43 | } 44 | 45 | // Merge concatenates the given middleware slices in order 46 | func Merge(middlewares ...[]Middleware) []Middleware { 47 | var merged []Middleware 48 | for _, v := range middlewares { 49 | merged = append(merged, v...)
50 | } 51 | return merged 52 | } 53 | -------------------------------------------------------------------------------- /inject/Readme.md: -------------------------------------------------------------------------------- 1 | 2 | Dependency injection for Go, similar to [Guice](https://github.com/google/guice). 3 | -------------------------------------------------------------------------------- /inject/inject_helper.go: -------------------------------------------------------------------------------- 1 | package inject 2 | 3 | import ( 4 | "os" 5 | "path" 6 | 7 | c "github.com/d0ngw/go/common" 8 | ) 9 | 10 | // ConfigModuler provides a module whose bindings depend on the config 11 | type ConfigModuler interface { 12 | ConfModule() (module *Module, err error) 13 | } 14 | 15 | // SetupInjector 从env指定的环境配置初始化配置,构建Injector 16 | func SetupInjector(config c.Configurer, addonConfig string, env string, modules ...*Module) (*Injector, error) { 17 | return SetupInjectorWithLoader(c.FileLoader, config, addonConfig, env, modules...)
18 | } 19 | 20 | // EnvConfRoot conf root 21 | const EnvConfRoot = "conf_root" 22 | 23 | // SetupInjectorWithLoader 从env指定的环境配置初始化配置,构建Injector 24 | func SetupInjectorWithLoader(loader c.ConfigLoader, config c.Configurer, addonConfig string, env string, modules ...*Module) (*Injector, error) { 25 | confDir := "conf" 26 | if os.Getenv(EnvConfRoot) != "" { 27 | confDir = os.Getenv(EnvConfRoot) 28 | } 29 | var ( 30 | confs = []string{"common.yaml"} 31 | ) 32 | 33 | if env != "" { 34 | confs = append(confs, "conf_"+env+".yaml") 35 | } else { 36 | confs = append(confs, "conf.yaml") 37 | } 38 | 39 | for _, f := range confs { 40 | conf := path.Join(confDir, f) 41 | if exist, err := loader.Exist(conf); err != nil { 42 | c.Errorf("check %s fail, err:%v", conf, err) 43 | return nil, err 44 | } else if !exist { 45 | c.Warnf("%s doesn't exist, skip", conf) 46 | } else { 47 | c.Infof("load conf from %s ", conf) 48 | if content, err := loader.Load(conf); err != nil { 49 | c.Errorf("load %s fail,err:%v", conf, err) 50 | return nil, err 51 | } else if len(content) > 0 { 52 | addonConfig += "\n" + string(content) + "\n" 53 | } 54 | } 55 | } 56 | 57 | err := c.LoadConfigWithLoader(loader, config, addonConfig, confDir) 58 | if err != nil { 59 | return nil, err 60 | } 61 | err = config.Parse() 62 | if err != nil { 63 | return nil, err 64 | } 65 | 66 | var confModule *Module 67 | if configModuler, ok := config.(ConfigModuler); ok { 68 | if confModule, err = configModuler.ConfModule(); err != nil { 69 | return nil, err 70 | } 71 | } 72 | 73 | // 绑定核心的服务 74 | module := NewModule() 75 | module.Bind(config) 76 | var allModuls []*Module 77 | allModuls = append(allModuls, module) 78 | allModuls = append(allModuls, modules...) 
79 | if confModule != nil { 80 | allModuls = append(allModuls, confModule) 81 | } 82 | injector := NewInjector(allModuls) 83 | err = injector.Initialize() 84 | if err != nil { 85 | return nil, err 86 | } 87 | return injector, nil 88 | } 89 | 90 | // Injected is implemented by types that can report whether their injection has completed 91 | type Injected interface { 92 | // IsInjected reports whether injection has completed 93 | IsInjected() bool 94 | } 95 | 96 | // IsInjected checks whether i implements the Injected interface. 97 | // ok is true when i implements Injected, and injected is then the result of i.IsInjected(); a nil i yields (false, false) 98 | func IsInjected(i interface{}) (ok bool, injected bool) { 99 | if i == nil { 100 | return 101 | } 102 | injectedi, ok := i.(Injected) 103 | if ok { 104 | injected = injectedi.IsInjected() 105 | } 106 | return 107 | } 108 | -------------------------------------------------------------------------------- /inject/inject_helper_test.go: -------------------------------------------------------------------------------- 1 | package inject 2 | 3 | import ( 4 | "errors" 5 | "os" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | type config struct { 12 | Name string `yaml:"name"` 13 | 14 | needModule bool 15 | } 16 | 17 | func (p *config) Parse() error { 18 | return nil 19 | } 20 | 21 | func (p *config) ConfModule() (module *Module, err error) { 22 | if !p.needModule { 23 | return 24 | } 25 | module = NewModule() 26 | module.BindWithName("hello", 2021) 27 | return 28 | } 29 | 30 | type userService struct { 31 | Injector *Injector `inject:"_"` 32 | Hello int `inject:"hello,optional"` 33 | } 34 | 35 | func (p *userService) Init() error { 36 | if p.Injector == nil { 37 | return errors.New("invalid Injector") 38 | } 39 | return nil 40 | } 41 | 42 | func TestSetupInjector(t *testing.T) { 43 | conf := &config{} 44 | module := NewModule() 45 | svc := &userService{} 46 | module.Bind(svc) 47 | 48 | assert.NoError(t, os.Chdir("testdata")) 49 | injector, err := SetupInjector(conf, "", "dev", module) 50 | assert.NoError(t, err) 51 | assert.NotNil(t, injector) 52 | assert.NotNil(t, svc.Injector) 53 | assert.EqualValues(t, 0,
svc.Hello) 54 | 55 | conf.needModule = true 56 | svc = &userService{} 57 | module = NewModule() 58 | module.Bind(svc) 59 | injector, err = SetupInjector(conf, "", "dev", module) 60 | assert.NoError(t, err) 61 | assert.NotNil(t, injector) 62 | assert.NotNil(t, svc.Injector) 63 | assert.EqualValues(t, 2021, svc.Hello) 64 | } 65 | -------------------------------------------------------------------------------- /inject/inject_test.go: -------------------------------------------------------------------------------- 1 | package inject 2 | 3 | import ( 4 | _ "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type accountService interface { 11 | Name() string 12 | } 13 | 14 | type injectedService struct { 15 | Injector *Injector `inject:"_"` 16 | } 17 | 18 | func (p *injectedService) IsInjected() bool { 19 | return p.Injector != nil 20 | } 21 | 22 | type AreaService struct { 23 | Injector *Injector `inject:"_"` 24 | } 25 | 26 | type AgeService struct { 27 | LdapAccountServicePtr *ldapAccount `inject:"ldap"` 28 | } 29 | 30 | type userRegService struct { 31 | AreaService 32 | *AgeService 33 | LdapImpl accountService `inject:"ldap"` 34 | DbImpl accountService `inject:"db"` 35 | ID string 36 | Injector *Injector `inject:"_"` 37 | LdapAccountServicePtr *ldapAccount `inject:"ldap"` 38 | Impls []accountService `inject:"_,optional"` 39 | } 40 | 41 | type ldapAccount struct { 42 | n string 43 | } 44 | 45 | func (p *ldapAccount) Name() string { 46 | return p.n + "@ldap" 47 | } 48 | 49 | type dbAccount struct { 50 | n string 51 | } 52 | 53 | func (p *dbAccount) Name() string { 54 | return p.n + "@db" 55 | } 56 | 57 | func TestInject(t *testing.T) { 58 | ldapImplA := ldapAccount{n: "a"} 59 | dbImplA := dbAccount{n: "b"} 60 | 61 | user := "user_name" 62 | 63 | mod := NewModule() 64 | mod.BindWithName("ldap", &ldapImplA) 65 | mod.BindWithName("db", &dbImplA) 66 | mod.Bind(user) 67 | mod.Bind(&user) 68 | 69 | injector := NewInjector([]*Module{mod}) 70 | 71 | 
regService := &userRegService{AgeService: &AgeService{}} 72 | injector.RequireInject(regService) 73 | assert.Equal(t, "a@ldap", regService.LdapImpl.Name()) 74 | assert.Equal(t, "b@db", regService.DbImpl.Name()) 75 | assert.NotNil(t, regService.Injector) 76 | assert.True(t, injector == regService.Injector) 77 | assert.NotNil(t, regService.AreaService.Injector) 78 | assert.NotNil(t, regService.AgeService.LdapAccountServicePtr) 79 | assert.EqualValues(t, regService.AgeService.LdapAccountServicePtr.Name(), regService.LdapImpl.Name()) 80 | assert.NotNil(t, regService.Impls) 81 | assert.EqualValues(t, 2, len(regService.Impls)) 82 | _, ok := regService.Impls[0].(*dbAccount) 83 | if !ok { 84 | _, _ = regService.Impls[0].(*ldapAccount) 85 | _, _ = regService.Impls[1].(*dbAccount) 86 | } else { 87 | _, _ = regService.Impls[1].(*ldapAccount) 88 | } 89 | t.Logf("Impls %#v", regService.Impls) 90 | 91 | injector.RequireInjectWithOverrideTags(regService, map[string]string{"DbImpl": "ldap", "LdapImpl": "db"}) 92 | assert.Equal(t, "b@db", regService.LdapImpl.Name()) 93 | assert.Equal(t, "a@ldap", regService.DbImpl.Name()) 94 | 95 | injector.RequireInject(regService) 96 | assert.Equal(t, "a@ldap", regService.LdapImpl.Name()) 97 | assert.Equal(t, "b@db", regService.DbImpl.Name()) 98 | 99 | //根据名称查找 100 | ldapImplGet := injector.GetInstanceByPrototype("ldap", struct{ s accountService }{}).(accountService) 101 | assert.NotNil(t, ldapImplGet) 102 | assert.Equal(t, &ldapImplA, ldapImplGet) 103 | assert.Equal(t, "a@ldap", ldapImplGet.Name()) 104 | 105 | ldapImplGet, ok = injector.GetInstanceByPrototype("", struct{ s accountService }{}).(accountService) 106 | assert.False(t, ok) 107 | assert.Nil(t, ldapImplGet) 108 | 109 | injectedSvc := &injectedService{} 110 | injector.RequireInject(injectedSvc) 111 | injector.RequireInject(injectedSvc) 112 | } 113 | 114 | func TestInjectInModule(t *testing.T) { 115 | ldapImplA := ldapAccount{n: "a"} 116 | dbImplA := dbAccount{n: "b"} 117 | regService := 
&userRegService{} 118 | 119 | user := "user_name" 120 | mod := NewModule() 121 | mod.BindWithName("ldap", &ldapImplA) 122 | mod.BindWithName("db", &dbImplA) 123 | mod.Bind(user) 124 | mod.Bind(regService) 125 | 126 | _ = NewInjector([]*Module{mod}) 127 | 128 | assert.Equal(t, "a@ldap", regService.LdapImpl.Name()) 129 | assert.Equal(t, "b@db", regService.DbImpl.Name()) 130 | } 131 | 132 | func TestInjectInModuleWithTag(t *testing.T) { 133 | ldapImplA := ldapAccount{n: "a"} 134 | dbImplA := dbAccount{n: "b"} 135 | regService := &userRegService{} 136 | 137 | user := "user_name" 138 | mod := NewModule() 139 | mod.BindWithName("ldap", &ldapImplA) 140 | mod.BindWithName("db", &dbImplA) 141 | mod.Bind(user) 142 | mod.BindWithNameOverrideTags("", regService, map[string]string{"DbImpl": "ldap"}) 143 | 144 | _ = NewInjector([]*Module{mod}) 145 | 146 | assert.Equal(t, "a@ldap", regService.LdapImpl.Name()) 147 | assert.Equal(t, "a@ldap", regService.DbImpl.Name()) 148 | } 149 | 150 | func TestInjectInModuleWithProviderFunc(t *testing.T) { 151 | ldapImplA := ldapAccount{n: "a"} 152 | dbImplA := dbAccount{n: "b"} 153 | regService := &userRegService{} 154 | 155 | user := "user_name" 156 | mod := NewModule() 157 | mod.BindWithProviderFunc("ldap", func() interface{} { 158 | return &ldapImplA 159 | }) 160 | mod.BindWithProviderFunc("db", func() interface{} { 161 | return &dbImplA 162 | }) 163 | mod.Bind(user) 164 | mod.BindWithNameOverrideTags("", regService, map[string]string{"DbImpl": "ldap"}) 165 | 166 | _ = NewInjector([]*Module{mod}) 167 | 168 | assert.Equal(t, "a@ldap", regService.LdapImpl.Name()) 169 | assert.Equal(t, "a@ldap", regService.DbImpl.Name()) 170 | } 171 | 172 | type ldapProvider struct{} 173 | 174 | func (p ldapProvider) GetInstance() interface{} { 175 | return &ldapAccount{n: "a"} 176 | } 177 | 178 | type dbProvider struct{} 179 | 180 | func (p dbProvider) GetInstance() interface{} { 181 | return &dbAccount{n: "b"} 182 | } 183 | 184 | func 
TestInjectInModuleWithProvider(t *testing.T) { 185 | regService := &userRegService{} 186 | 187 | user := "user_name" 188 | mod := NewModule() 189 | mod.BindWithProvider("ldap", ldapProvider{}) 190 | mod.BindWithProvider("db", dbProvider{}) 191 | mod.Bind(user) 192 | mod.BindWithNameOverrideTags("", regService, map[string]string{"DbImpl": "ldap"}) 193 | 194 | _ = NewInjector([]*Module{mod}) 195 | 196 | assert.Equal(t, "a@ldap", regService.LdapImpl.Name()) 197 | assert.Equal(t, "a@ldap", regService.DbImpl.Name()) 198 | 199 | } 200 | 201 | func TestGetInstancesByPrototype(t *testing.T) { 202 | ldapImplA := ldapAccount{n: "a"} 203 | dbImplA := dbAccount{n: "b"} 204 | regService := &userRegService{} 205 | 206 | user := "user_name" 207 | mod := NewModule() 208 | mod.BindWithName("ldap", &ldapImplA) 209 | mod.BindWithName("db", &dbImplA) 210 | mod.Bind(user) 211 | mod.Bind(regService) 212 | 213 | injector := NewInjector([]*Module{mod}) 214 | 215 | assert.Equal(t, "a@ldap", regService.LdapImpl.Name()) 216 | assert.Equal(t, "b@db", regService.DbImpl.Name()) 217 | 218 | allAccountServices := injector.GetInstancesByPrototype(struct{ a accountService }{}) 219 | assert.Equal(t, 2, len(allAccountServices)) 220 | 221 | for _, v := range allAccountServices { 222 | _ = v.(accountService) 223 | } 224 | } 225 | 226 | type userRegServiceBad struct { 227 | LdapAccountService ldapAccount `inject:"_"` 228 | LdapAccountServicePtr *ldapAccount `inject:"_"` 229 | } 230 | 231 | func TestInjectNoEmbed(t *testing.T) { 232 | ldapImplA := ldapAccount{n: "a"} 233 | ldapImplPtr := ldapAccount{n: "ptr"} 234 | 235 | mod := NewModule() 236 | mod.Bind(ldapImplA) 237 | mod.Bind(&ldapImplPtr) 238 | 239 | injector := NewInjector([]*Module{mod}) 240 | 241 | regServiceBad := &userRegServiceBad{} 242 | injector.RequireInject(regServiceBad) 243 | assert.Equal(t, "a@ldap", regServiceBad.LdapAccountService.Name()) 244 | assert.Equal(t, "ptr@ldap", regServiceBad.LdapAccountServicePtr.Name()) 245 | 
t.Logf("LdapAccountService %s", regServiceBad.LdapAccountService.Name()) 246 | t.Logf("LdapAccountServicePtr %s", regServiceBad.LdapAccountServicePtr.Name()) 247 | } 248 | 249 | func TestIsBind(t *testing.T) { 250 | ldapImplA := ldapAccount{n: "a"} 251 | ldapImplPtr := ldapAccount{n: "ptr"} 252 | ldapImplB := ldapAccount{n: "b"} 253 | 254 | mod := NewModule() 255 | mod.Bind(ldapImplA) 256 | mod.Bind(&ldapImplPtr) 257 | 258 | assert.True(t, mod.IsBind(&ldapImplPtr, "")) 259 | assert.True(t, mod.IsBind(ldapImplA, "")) 260 | assert.False(t, mod.IsBind(ldapImplB, "")) 261 | } 262 | -------------------------------------------------------------------------------- /inject/module.go: -------------------------------------------------------------------------------- 1 | package inject 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | 7 | c "github.com/d0ngw/go/common" 8 | ) 9 | 10 | // 内部的绑定实例 11 | type internalBind struct { 12 | name string //绑定的名称 13 | instance interface{} //绑定的实例 14 | injectType reflect.Type //注入的类型 15 | injectValue reflect.Value //注入的值,相当于reflect.ValueOf(instance) 16 | injectTags map[string]string //用于注入的tag,用于覆盖struct field中定义的tag 17 | } 18 | 19 | func (p internalBind) String() string { 20 | return fmt.Sprintf("%v#%s", p.injectType, p.name) 21 | } 22 | 23 | // bindKey 用于绑定的key 24 | type bindKey struct { 25 | bindName string 26 | bindType reflect.Type 27 | } 28 | 29 | func (p bindKey) String() string { 30 | return fmt.Sprintf("%v#%s", p.bindType, p.bindName) 31 | } 32 | 33 | // Provider 提供类似Guice Provider的功能,用于创建一个对象 34 | type Provider interface { 35 | // GetInstance 用于创建一个实例 36 | GetInstance() interface{} 37 | } 38 | 39 | // ProviderFunc 定义用于创建一个对象的函数类型 40 | type ProviderFunc func() interface{} 41 | 42 | // Module 提供Guice Module的功能 43 | type Module struct { 44 | binds []*internalBind 45 | } 46 | 47 | // NewModule 创建新的Module 48 | func NewModule() *Module { 49 | return &Module{binds: []*internalBind{}} 50 | } 51 | 52 | // BindWithNameOverrideTags 
添加带名称的绑定,injectTags用于覆盖instance中struct field中field中定义的inject tag 53 | func (p *Module) BindWithNameOverrideTags(name string, instance interface{}, injectTags map[string]string) { 54 | if instance == nil { 55 | panic("Can't bind nil instance") 56 | } 57 | b := &internalBind{name, instance, injectType(instance), reflect.ValueOf(instance), injectTags} 58 | p.binds = append(p.binds, b) 59 | } 60 | 61 | // BindWithName 添加带名称的绑定 62 | func (p *Module) BindWithName(name string, instance interface{}) { 63 | p.BindWithNameOverrideTags(name, instance, map[string]string{}) 64 | } 65 | 66 | // Bind 添加不带名称的绑定 67 | func (p *Module) Bind(instance interface{}) { 68 | p.BindWithName("", instance) 69 | } 70 | 71 | // BindWithProvider 通过Provider提供带名称的绑定功能 72 | func (p *Module) BindWithProvider(name string, provider Provider) { 73 | if instance := provider.GetInstance(); instance != nil { 74 | p.BindWithName(name, instance) 75 | return 76 | } 77 | err := fmt.Errorf("Cant't bind nil instalce with name:%s,provider:%v", name, provider) 78 | panic(err) 79 | } 80 | 81 | // BindWithProviderFunc 通过Provider提供带名称的绑定功能 82 | func (p *Module) BindWithProviderFunc(name string, providerFunc ProviderFunc) { 83 | if instance := providerFunc(); instance != nil { 84 | p.BindWithName(name, instance) 85 | return 86 | } 87 | err := fmt.Errorf("Cant't bind nil instalce with name:%s,providerFunc:%v", name, providerFunc) 88 | panic(err) 89 | } 90 | 91 | // Append 将src中的绑定追加到本module 92 | func (p *Module) Append(src *Module) { 93 | if src == nil || len(src.binds) == 0 { 94 | return 95 | } 96 | p.binds = append(p.binds, src.binds...) 
97 | } 98 | 99 | // IsBind check if the instance has been binded 100 | func (p *Module) IsBind(instance interface{}, name string) bool { 101 | for _, internal := range p.binds { 102 | if internal.instance == instance && internal.name == name { 103 | return true 104 | } 105 | } 106 | return false 107 | } 108 | 109 | func checkIsInterface(typ reflect.Type) bool { 110 | isInterface := false 111 | if typ.Kind() == reflect.Ptr { 112 | if typ.Elem().Kind() == reflect.Interface { 113 | isInterface = true 114 | } 115 | } else if typ.Kind() == reflect.Interface { 116 | isInterface = true 117 | } 118 | return isInterface 119 | } 120 | 121 | // injectType 取得注入的类型,如果实例不能被注入,会抛出一个panic 122 | func injectType(instance interface{}) reflect.Type { 123 | val := reflect.ValueOf(instance) 124 | typ := val.Type() 125 | 126 | //确保typ的类型不是interface{} 127 | if checkIsInterface(typ) { 128 | panic(fmt.Errorf("The type of instance `%#v` is interface,can't find it's exact type", val.Interface())) 129 | } 130 | if typ.Kind() != reflect.Ptr && reflect.Indirect(val).Kind() == reflect.Struct { 131 | c.Errorf("struct %T is not pointer but it's will be injected, please make sure it's expected.", instance) 132 | } 133 | return typ 134 | } 135 | 136 | // mergeBinds 合并多个模块的绑定,返回未命名的绑定和命名绑定 137 | func mergeBinds(modules []*Module) (unnamed []*internalBind, named map[string][]*internalBind, all []*internalBind) { 138 | all = []*internalBind{} 139 | unnamed = []*internalBind{} 140 | named = map[string][]*internalBind{} 141 | 142 | uniqBindMap := map[bindKey]struct{}{} 143 | 144 | for _, module := range modules { 145 | for _, bind := range module.binds { 146 | bindkey := bindKey{bind.name, bind.injectType} 147 | if _, ok := uniqBindMap[bindkey]; ok { 148 | panic(fmt.Errorf("Duplicate bind %s", bindkey)) 149 | } else { 150 | uniqBindMap[bindkey] = struct{}{} 151 | if len(bind.name) == 0 { 152 | unnamed = append(unnamed, bind) 153 | all = append(all, bind) 154 | } else { 155 | namedBinds := named[bind.name] 
156 | namedBinds = append(namedBinds, bind) 157 | named[bind.name] = namedBinds 158 | all = append(all, bind) 159 | } 160 | } 161 | } 162 | } 163 | return 164 | } 165 | 166 | // Merge 合并Module 167 | func Merge(modules ...[]*Module) []*Module { 168 | var merged []*Module 169 | for _, v := range modules { 170 | merged = append(merged, v...) 171 | } 172 | return merged 173 | } 174 | -------------------------------------------------------------------------------- /inject/testdata/conf/common.yaml: -------------------------------------------------------------------------------- 1 | name: d0ngw -------------------------------------------------------------------------------- /orm/comm.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | 7 | // import mysql 8 | _ "github.com/go-sql-driver/mysql" 9 | ) 10 | 11 | // NullTime null time 12 | type NullTime = sql.NullTime 13 | 14 | // DBError 数据库操作错误 15 | type DBError struct { 16 | Msg string 17 | Err error 18 | } 19 | 20 | func (e *DBError) Error() string { 21 | return fmt.Sprintf("DBError msg:%s,err:%v", e.Msg, e.Err) 22 | } 23 | 24 | // NewDBError 构建数据库操作错误 25 | func NewDBError(err error, msg string) *DBError { 26 | return &DBError{Msg: msg, Err: err} 27 | } 28 | 29 | // NewDBErrorf 使用fmt.Sprintf构建 30 | func NewDBErrorf(err error, msgFormat string, args ...interface{}) *DBError { 31 | return &DBError{Msg: fmt.Sprintf(msgFormat, args...), Err: err} 32 | } 33 | 34 | // Entity 实体接口 35 | type Entity interface { 36 | TableName() string 37 | } 38 | 39 | // ShardHandler 分片处理 40 | type ShardHandler func() (shardName string, err error) 41 | 42 | // ShardEntity 支持按表分片实体的接口 43 | type ShardEntity interface { 44 | Entity 45 | //TableShardFunc table分片函数 46 | TableShardFunc() ShardHandler 47 | //SetTableShardFunc table设置分片函数 48 | SetTableShardFunc(ShardHandler) 49 | } 50 | 51 | // BaseShardEntity 基础的分片实体 52 | type BaseShardEntity struct { 53 | 
tblShardFunc ShardHandler 54 | } 55 | 56 | // TableShardFunc implements ShardEntity.TableShardFunc 57 | func (p *BaseShardEntity) TableShardFunc() ShardHandler { 58 | return p.tblShardFunc 59 | } 60 | 61 | // SetTableShardFunc implements ShardEntity.SetTableShardFunc 62 | func (p *BaseShardEntity) SetTableShardFunc(f ShardHandler) { 63 | p.tblShardFunc = f 64 | } 65 | 66 | // EntitySlice type for slice of EntityInterface 67 | type EntitySlice []Entity 68 | 69 | // ToInterface convert EntitySlice to []interface{} 70 | func (p EntitySlice) ToInterface() []interface{} { 71 | if p == nil { 72 | return nil 73 | } 74 | ret := make([]interface{}, len(p)) 75 | for i := range p { 76 | ret[i] = p[i] 77 | } 78 | return ret 79 | } 80 | 81 | // Pool 数据库连接池 82 | type Pool struct { 83 | db *sql.DB 84 | name string 85 | } 86 | 87 | //NewOp 创建DBOper 88 | func (p *Pool) NewOp() *Op { 89 | return &Op{pool: p} 90 | } 91 | 92 | // Name pool name 93 | func (p *Pool) Name() string { 94 | return p.name 95 | } 96 | 97 | // PoolFunc the func to crate db pool 98 | type PoolFunc func(config *DBConfig) (pool *Pool, err error) 99 | -------------------------------------------------------------------------------- /orm/comm_test.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "encoding/json" 7 | "fmt" 8 | "io/ioutil" 9 | "os" 10 | "path" 11 | "strings" 12 | "testing" 13 | ) 14 | 15 | var ( 16 | config = DBConfig{ 17 | User: "root", 18 | Pass: "123456", 19 | URL: "127.0.0.1:3306", 20 | Schema: "test", 21 | MaxConn: 100, 22 | MaxIdle: 10, 23 | } 24 | dbpool, err = NewMySQLDBPool(&config) 25 | 26 | setupSQL, _ = ioutil.ReadFile(path.Join("testdata", "setup.sql")) 27 | teardownSQL, _ = ioutil.ReadFile(path.Join("testdata", "teardown.sql")) 28 | ) 29 | 30 | // Conf 配置 31 | type Conf struct { 32 | IDs []int64 `json:"ids"` 33 | } 34 | 35 | // Value impls driver.Valuer for Range 36 | func 
(p *Conf) Value() (driver.Value, error) { 37 | if p == nil { 38 | return nil, nil 39 | } 40 | v, err := json.Marshal(p) 41 | if err != nil { 42 | return nil, err 43 | } 44 | return string(v), nil 45 | } 46 | 47 | // Scan impls sql.Scanner for Range,src只支持string 48 | func (p *Conf) Scan(src interface{}) error { 49 | if src == nil { 50 | return nil 51 | } 52 | var source []byte 53 | switch val := src.(type) { 54 | case string: 55 | source = []byte(val) 56 | case []byte: 57 | source = val 58 | default: 59 | return fmt.Errorf("Incompatible range %T,value:%v", val, val) 60 | } 61 | if len(source) > 0 { 62 | pv := &Conf{} 63 | err := json.Unmarshal(source, pv) 64 | if err != nil { 65 | return err 66 | } 67 | *p = *pv 68 | } 69 | return nil 70 | } 71 | 72 | // Conf2 配置 73 | type Conf2 struct { 74 | IDs []int64 `json:"ids"` 75 | } 76 | 77 | // Value impls driver.Valuer for Range 78 | func (p Conf2) Value() (driver.Value, error) { 79 | v, err := json.Marshal(p) 80 | if err != nil { 81 | return nil, err 82 | } 83 | return string(v), nil 84 | } 85 | 86 | // Scan impls sql.Scanner for Range,src只支持string 87 | func (p *Conf2) Scan(src interface{}) error { 88 | if src == nil { 89 | return nil 90 | } 91 | var source []byte 92 | switch val := src.(type) { 93 | case string: 94 | source = []byte(val) 95 | case []byte: 96 | source = val 97 | default: 98 | return fmt.Errorf("Incompatible range %T,value:%v", val, val) 99 | } 100 | if len(source) > 0 { 101 | pv := &Conf2{} 102 | err := json.Unmarshal(source, pv) 103 | if err != nil { 104 | return err 105 | } 106 | *p = *pv 107 | } 108 | return nil 109 | } 110 | 111 | type AutoID struct { 112 | ID int64 `column:"id" pk:"Y"` 113 | Name2 sql.NullString `column:"name2"` 114 | } 115 | 116 | type tmodel struct { 117 | AutoID 118 | BaseShardEntity 119 | Name sql.NullString `column:"name"` 120 | Time sql.NullInt64 `column:"create_time"` 121 | F64 sql.NullFloat64 `column:"f64"` 122 | Conf *Conf `column:"conf"` 123 | Conf2 Conf2 `column:"conf2"` 
124 | Ver int64 `column:"ver"` 125 | Age int64 `column:"age"` 126 | } 127 | 128 | func (tm *tmodel) TableName() string { 129 | return "tt" 130 | } 131 | 132 | type User struct { 133 | BaseShardEntity 134 | ID int64 `column:"id" pk:"Y"` 135 | Name sql.NullString `column:"name"` 136 | Age int64 `column:"age"` 137 | Birthday NullTime `column:"birthday"` 138 | } 139 | 140 | func (p *User) TableName() string { 141 | return "user" 142 | } 143 | 144 | func TestMain(m *testing.M) { 145 | setUp() 146 | code := m.Run() 147 | teardown() 148 | os.Exit(code) 149 | } 150 | 151 | func setUp() { 152 | sqls := strings.Split(string(setupSQL), "--") 153 | for _, s := range sqls { 154 | s = strings.TrimSpace(s) 155 | if s == "" { 156 | continue 157 | } 158 | _, err := dbpool.db.Exec(s) 159 | if err != nil { 160 | fmt.Println(s) 161 | panic(err) 162 | } 163 | } 164 | } 165 | 166 | func teardown() { 167 | sqls := strings.Split(string(teardownSQL), "--") 168 | for _, s := range sqls { 169 | s = strings.TrimSpace(s) 170 | if s == "" { 171 | continue 172 | } 173 | _, err := dbpool.db.Exec(s) 174 | if err != nil { 175 | panic(err) 176 | } 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /orm/config.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "fmt" 5 | 6 | c "github.com/d0ngw/go/common" 7 | ) 8 | 9 | // DBConfigurer DB配置器 10 | type DBConfigurer interface { 11 | c.Configurer 12 | DBConfig() *DBConfig 13 | } 14 | 15 | // DBShardConfigurer db shard configurer 16 | type DBShardConfigurer interface { 17 | c.Configurer 18 | DBShardConfig() *DBShardConfig 19 | } 20 | 21 | // EntityShardConfigurer entity shard configurer 22 | type EntityShardConfigurer interface { 23 | c.Configurer 24 | EntityShardConfig() *EntityShardConfig 25 | } 26 | 27 | //DBConfig 数据库配置 28 | type DBConfig struct { 29 | User string `yaml:"user"` 30 | Pass string `yaml:"pass"` 31 | URL string 
`yaml:"url"` 32 | Schema string `yaml:"schema"` 33 | MaxConn int `yaml:"maxConn"` 34 | MaxIdle int `yaml:"maxIdle"` 35 | MaxTimeSecond int `yaml:"maxTimeSecond"` 36 | Charset string `yaml:"charset"` 37 | Ext map[string]string `yaml:"ext"` 38 | } 39 | 40 | // Parse implements DBConfigurer 41 | func (p *DBConfig) Parse() error { 42 | if p.URL == "" { 43 | return fmt.Errorf("need url") 44 | } 45 | if p.Schema == "" { 46 | return fmt.Errorf("need schema") 47 | } 48 | return nil 49 | } 50 | 51 | // DBConfig implements DBConfigurer 52 | func (p *DBConfig) DBConfig() *DBConfig { 53 | return p 54 | } 55 | 56 | // DBShardConfig db shard config 57 | type DBShardConfig struct { 58 | Shards map[string]*DBConfig `yaml:"shards"` 59 | Default string `yaml:"default"` 60 | } 61 | 62 | // Parse implements Configurer.Parse 63 | func (p *DBShardConfig) Parse() error { 64 | if p == nil { 65 | c.Warnf("no db config") 66 | return nil 67 | } 68 | 69 | c.Infof("db shards count:%d", len(p.Shards)) 70 | for k, v := range p.Shards { 71 | if v == nil { 72 | return fmt.Errorf("no db config for %s", k) 73 | } 74 | if err := v.Parse(); err != nil { 75 | return err 76 | } 77 | } 78 | 79 | if p.Default != "" { 80 | if p.Shards[p.Default] == nil { 81 | return fmt.Errorf("can't find default shard %s", p.Default) 82 | } 83 | } 84 | return nil 85 | } 86 | 87 | // DBShardConfig implements DBShardConfigurer 88 | func (p *DBShardConfig) DBShardConfig() *DBShardConfig { 89 | return p 90 | } 91 | 92 | // EntityShardRuleConfig 实体的shard规则 93 | type EntityShardRuleConfig struct { 94 | Name string `yaml:"name"` //名称 95 | DBShard *OneRule `yaml:"db_shard"` //数据库实例的配置 96 | TableShard *OneRule `yaml:"table_shard"` //数据库表的配置 97 | Default bool `yaml:"default"` //是否是默认规则 98 | meta Meta 99 | } 100 | 101 | // Parse implements Configurer.Parse 102 | func (p *EntityShardRuleConfig) Parse() error { 103 | if p.Name == "" { 104 | return fmt.Errorf("invalid name") 105 | } 106 | if p.DBShard != nil { 107 | if err := 
p.DBShard.Parse(); err != nil { 108 | return fmt.Errorf("parse db_shard fail,name:%s,err:%v", p.Name, err) 109 | } 110 | } 111 | if p.TableShard != nil { 112 | if err := p.TableShard.Parse(); err != nil { 113 | return fmt.Errorf("parse table_shard fail,name:%s,err:%v", p.Name, err) 114 | } 115 | } 116 | return nil 117 | } 118 | 119 | type entityRule struct { 120 | meta Meta 121 | defaultRule *EntityShardRuleConfig 122 | rules map[string]*EntityShardRuleConfig 123 | } 124 | 125 | // EntityShardConfig entity shard config: shard rules keyed by package path, then entity name 126 | type EntityShardConfig struct { 127 | // pkgPath -> entity name -> rules 128 | Entities map[string]map[string][]*EntityShardRuleConfig `yaml:"entities"` 129 | entities map[string]map[string]*entityRule 130 | } 131 | 132 | // Parse implements Configurer.Parse 133 | func (p *EntityShardConfig) Parse() error { 134 | if len(p.Entities) == 0 { 135 | c.Infof("no entities") 136 | return nil 137 | } 138 | 139 | entities := map[string]map[string]*entityRule{} 140 | 141 | for pkgPath, pkgEntities := range p.Entities { 142 | pkg := map[string]*entityRule{} 143 | entities[pkgPath] = pkg 144 | for entityName, rules := range pkgEntities { 145 | meta := findMetaWithPkgAndName(pkgPath, entityName) 146 | if meta == nil { 147 | return fmt.Errorf("can't find meta for %s.%s", pkgPath, entityName) 148 | } 149 | 150 | entity := &entityRule{rules: map[string]*EntityShardRuleConfig{}, meta: meta} 151 | pkg[entityName] = entity 152 | 153 | for _, rule := range rules { 154 | if err := rule.Parse(); err != nil { 155 | return fmt.Errorf("parse %s/%s %s fail,err:%v", pkgPath, entityName, rule.Name, err) 156 | } 157 | rule.meta = meta 158 | if rule.Default { 159 | if entity.defaultRule == nil { 160 | entity.defaultRule = rule 161 | } else { 162 | return fmt.Errorf("duplicate default rule for %s/%s", pkgPath, entityName) 163 | } 164 | } 165 | if entity.rules[rule.Name] != nil { 166 | return fmt.Errorf("duplicate rule %s for %s/%s", rule.Name, pkgPath, entityName) 167 | } 168 |
entity.rules[rule.Name] = rule 169 | } 170 | } 171 | } 172 | p.entities = entities 173 | return nil 174 | } 175 | 176 | // EntityShardConfig implements EntityShardConfigurer 177 | func (p *EntityShardConfig) EntityShardConfig() *EntityShardConfig { 178 | return p 179 | } 180 | -------------------------------------------------------------------------------- /orm/mysql.go: -------------------------------------------------------------------------------- 1 | package orm 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "log" 7 | "net/url" 8 | "time" 9 | 10 | c "github.com/d0ngw/go/common" 11 | ) 12 | 13 | // NewMySQLDBPool build mysql db pool from config 14 | func NewMySQLDBPool(config *DBConfig) (*Pool, error) { 15 | if config == nil { 16 | return nil, &DBError{"Not found config", nil} 17 | } 18 | 19 | if len(config.User) == 0 || len(config.URL) == 0 || len(config.Schema) == 0 { 20 | return nil, &DBError{"Invalid config", nil} 21 | } 22 | 23 | charset := config.Charset 24 | if charset == "" { 25 | charset = "utf8" 26 | } 27 | 28 | //设置时间为本地时间,并解析时间 29 | loc, err := time.LoadLocation("Local") 30 | connectURL := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=%s&loc=%s&parseTime=true", config.User, config.Pass, config.URL, config.Schema, charset, url.QueryEscape(loc.String())) 31 | for k, v := range config.Ext { 32 | if k != "" { 33 | connectURL = connectURL + "&" + k + "=" + url.QueryEscape(v) 34 | } 35 | } 36 | 37 | db, err := sql.Open("mysql", connectURL) 38 | if err != nil { 39 | log.Println("Error on initializing database connection,", err.Error()) 40 | return nil, &DBError{"Can't open connection", err} 41 | } 42 | 43 | c.Infof("db max idle connections:%d,max open connections:%d,charset:%s,ext:%v", config.MaxIdle, config.MaxConn, charset, config.Ext) 44 | db.SetMaxIdleConns(config.MaxIdle) 45 | db.SetMaxOpenConns(config.MaxConn) 46 | db.SetConnMaxLifetime(time.Duration(config.MaxTimeSecond) * time.Second) 47 | return &Pool{db: db}, nil 48 | } 49 | 
--------------------------------------------------------------------------------
/orm/mysql_test.go:
--------------------------------------------------------------------------------
1 | package orm
2 | 
3 | import (
4 | 	_ "fmt"
5 | 	"testing"
6 | )
7 | 
8 | func TestMysqlCreateor(t *testing.T) {
9 | 	dbp, err := NewMySQLDBPool(&config)
10 | 	if err != nil {
11 | 		t.Errorf("Create fail %s", err.Error())
12 | 		return
13 | 	}
14 | 	if dbp == nil {
15 | 		t.Error("Create fail", err)
16 | 		return
17 | 	}
18 | 
19 | 	defer dbp.db.Close()
20 | 	err = dbp.db.Ping()
21 | 	if err != nil {
22 | 		t.Errorf("Ping db fail %s", err)
23 | 	}
24 | }
25 | 
--------------------------------------------------------------------------------
/orm/op.go:
--------------------------------------------------------------------------------
1 | package orm
2 | 
3 | import (
4 | 	"database/sql"
5 | 	"errors"
6 | 	"fmt"
7 | 	"reflect"
8 | 
9 | 	c "github.com/d0ngw/go/common"
10 | )
11 | 
12 | // OpTxFunc is a function run inside a transaction (see DoInTrans).
13 | type OpTxFunc func(tx *sql.Tx) (interface{}, error)
14 | 
15 | // OpCreator creates Op values.
16 | type OpCreator interface {
17 | 	//NewOp create a new Op
18 | 	NewOp() (*Op, error)
19 | }
20 | 
21 | // Op wraps a Pool (its sql.DB) and adds nestable transaction handling.
22 | type Op struct {
23 | 	pool           *Pool          // connection pool
24 | 	tx             *sql.Tx        // open transaction, nil when there is none
25 | 	txDone         bool           // whether the last transaction was committed or rolled back
26 | 	rollbackOnly   bool           // roll back instead of committing when the transaction finishes
27 | 	transDepth     int            // nesting depth of BeginTx calls
28 | 	sharDBSerevcie ShardDBService // sharding service, required by SetupTableShard
29 | }
30 | 
31 | // DB returns the sql.DB of the pool.
32 | func (p *Op) DB() *sql.DB {
33 | 	return p.pool.db
34 | }
35 | 
36 | // Pool returns the pool this Op works on.
37 | func (p *Op) Pool() *Pool {
38 | 	return p.pool
39 | }
40 | 
41 | // PoolName returns the name of the pool.
42 | func (p *Op) PoolName() string {
43 | 	return p.pool.name
44 | }
45 | 
46 | // SetupTableShard sets up entity's table shard by ruleName (the default rule when empty) and fails if the entity maps to another pool.
47 | func (p *Op) SetupTableShard(entity Entity, ruleName string) error {
48 | 	if p.sharDBSerevcie == nil {
49 | 		return errors.New("no shard db service")
50 | 	}
51 | 	poolName, err := p.sharDBSerevcie.setupTableShard(entity, ruleName)
52 | 	if err != nil {
53 | 		return err
54 | 	}
55 | 	if poolName != p.PoolName() {
56 | 		return fmt.Errorf("op.PoolName %s != entity.PoolName %s", p.PoolName(), poolName)
57 | 	}
58 | 	return nil
59 | }
60 | 
61 | func (p *Op) close() {
62 | 	p.tx = nil
63 | 	p.rollbackOnly = false
64 | 	p.transDepth = 0
65 | }
66 | 
67 | // checkTransStatus fails unless a transaction is open and not yet finished.
68 | func (p *Op) checkTransStatus() error {
69 | 	if p.txDone {
70 | 		return sql.ErrTxDone
71 | 	}
72 | 	if p.tx == nil {
73 | 		return NewDBError(nil, "Not begin transaction")
74 | 	}
75 | 	return nil
76 | }
77 | 
78 | func (p *Op) incrTransDepth() {
79 | 	p.transDepth = p.transDepth + 1
80 | }
81 | 
82 | func (p *Op) decrTransDepth() error {
83 | 	p.transDepth = p.transDepth - 1
84 | 	if p.transDepth < 0 {
85 | 		return NewDBError(nil, "Too many invoke commit or rollback")
86 | 	}
87 | 	return nil
88 | }
89 | 
90 | // finishTrans ends one nesting level; the outermost level commits, or rolls back when rollbackOnly is set.
91 | func (p *Op) finishTrans() error {
92 | 	if err := p.checkTransStatus(); err != nil {
93 | 		return err
94 | 	}
95 | 	if err := p.decrTransDepth(); err != nil {
96 | 		return err
97 | 	}
98 | 	if p.transDepth > 0 {
99 | 		return nil
100 | 	}
101 | 	defer p.close()
102 | 	p.txDone = true
103 | 	if p.rollbackOnly {
104 | 		return p.tx.Rollback()
105 | 	}
106 | 	return p.tx.Commit()
107 | }
108 | 
109 | // BeginTx starts a transaction. Nested calls are supported: when a transaction is already open it only increments the nesting depth and returns nil.
110 | func (p *Op) BeginTx() (err error) {
111 | 	p.incrTransDepth()
112 | 	if p.tx != nil {
113 | 		return nil // transaction already open
114 | 	}
115 | 	if p.tx, err = p.DB().Begin(); err != nil { // plain assignment: `tx, err :=` here used to shadow err, so failures returned nil
116 | 		p.transDepth-- // nothing was opened; undo the depth increment
117 | 		return err
118 | 	}
119 | 	p.txDone = false
120 | 	return nil
121 | }
122 | 
123 | // Commit finishes one nesting level; only the outermost call commits (or rolls back when marked rollback-only).
124 | func (p *Op) Commit() error {
125 | 	return p.finishTrans()
126 | }
127 | 
128 | // Rollback marks the transaction rollback-only and finishes one nesting level.
129 | func (p *Op) Rollback() error {
130 | 	p.SetRollbackOnly(true)
131 | 	return p.finishTrans()
132 | }
133 | 
134 | // SetRollbackOnly sets whether the transaction must be rolled back when it finishes.
135 | func (p *Op) SetRollbackOnly(rollback bool) {
136 | 	p.rollbackOnly = rollback
137 | }
138 | 
139 | // IsRollbackOnly reports whether the transaction is marked rollback-only.
140 | func (p *Op) IsRollbackOnly() bool {
141 | 	return p.rollbackOnly
142 | }
143 | 
144 | // DoInTrans runs peration inside a transaction, joining one already open on p; a failing peration marks the transaction rollback-only.
145 | func (p *Op) DoInTrans(peration OpTxFunc) (rt interface{}, err error) {
146 | 	// Joins the transaction already open on p, if any; only the outermost
147 | 	// level commits or rolls back (see BeginTx and finishTrans).
148 | 	if err := p.BeginTx(); err != nil {
149 | 		return nil, err
150 | 	}
151 | 	succ := false
152 | 	// Finish the transaction even when peration panics; any failure (error
153 | 	// or panic) marks it rollback-only.
154 | 	defer func() {
155 | 		if !succ {
156 | 			p.SetRollbackOnly(true)
157 | 		}
158 | 		if transErr := p.finishTrans(); transErr != nil {
159 | 			c.Errorf("Finish transaction err:%v", transErr)
160 | 			// keep peration's own error, the root cause, when it has one
161 | 			if err == nil {
162 | 				rt = nil
163 | 				err = transErr
164 | 			}
165 | 		}
166 | 	}()
167 | 	rt, err = peration(p.tx)
168 | 	if err != nil {
169 | 		c.Errorf("Operation fail:%v", err)
170 | 	}
171 | 	succ = err == nil
172 | 	return rt, err
173 | }
174 | 
175 | // findEntityMeta returns the registered meta of entity's type and panics
176 | // with a DBError when none is registered.
177 | func findEntityMeta(entity Entity) *meta {
178 | 	_, _, typ := extract(entity)
179 | 	modelMeta := findMeta(typ)
180 | 	if modelMeta == nil {
181 | 		panic(NewDBErrorf(nil, "Can't find modelMeta for:%v ", typ))
182 | 	}
183 | 	return modelMeta
184 | }
185 | 
186 | // Add inserts entity, inside op's open transaction when there is one.
187 | func Add(op *Op, entity Entity) error {
188 | 	modelMeta := findEntityMeta(entity)
189 | 	if op.tx != nil {
190 | 		return modelMeta.insertFunc(op.tx, entity)
191 | 	}
192 | 	return modelMeta.insertFunc(op.DB(), entity)
193 | }
194 | 
195 | // Update updates entity, inside op's open transaction when there is one.
196 | func Update(op *Op, entity Entity) (bool, error) {
197 | 	modelMeta := findEntityMeta(entity)
198 | 	if op.tx != nil {
199 | 		bvalue, err := modelMeta.updateFunc(op.tx, entity)
200 | 		if err != nil {
201 | 			return false, err
202 | 		}
203 | 		return reflect.ValueOf(bvalue).Bool(), nil
204 | 	}
205 | 	return modelMeta.updateFunc(op.DB(), entity)
206 | }
207 | 
208 | // UpdateReplace updates entity via updateReplaceFunc, passing replColumns and
209 | // excludeColumns through, inside op's open transaction when there is one.
210 | func UpdateReplace(op *Op, entity Entity, replColumns map[string]ReplColumn, excludeColumns map[string]struct{}) (bool, error) {
211 | 	modelMeta := findEntityMeta(entity)
212 | 	if op.tx != nil {
213 | 		bvalue, err := modelMeta.updateReplaceFunc(op.tx, entity, replColumns, excludeColumns)
214 | 		if err != nil {
215 | 			return false, err
216 | 		}
217 | 		return reflect.ValueOf(bvalue).Bool(), nil
218 | 	}
219 | 	return modelMeta.updateReplaceFunc(op.DB(), entity, replColumns, excludeColumns)
220 | }
221 | 
222 | // UpdateExcludeColumns updates the fields of entity except columns, inside
223 | // op's open transaction when there is one.
224 | func UpdateExcludeColumns(op *Op, entity Entity, columns ...string) (bool, error) {
225 | 	modelMeta := findEntityMeta(entity)
226 | 	if op.tx != nil {
227 | 		bvalue, err := modelMeta.updateExcludeColumnsFunc(op.tx, entity, columns...)
228 | 		if err != nil {
229 | 			return false, err
230 | 		}
231 | 		return reflect.ValueOf(bvalue).Bool(), nil
232 | 	}
233 | 	return modelMeta.updateExcludeColumnsFunc(op.DB(), entity, columns...)
234 | }
235 | 
236 | // UpdateColumns updates columns of the rows matching condition, inside op's
237 | // open transaction when there is one, and returns updateColumnsFunc's count
238 | // (presumably affected rows). columns and condition look like raw SQL
239 | // fragments (values go in params): never build them from untrusted input.
240 | func UpdateColumns(op *Op, entity Entity, columns string, condition string, params ...interface{}) (int64, error) {
241 | 	modelMeta := findEntityMeta(entity)
242 | 	if op.tx != nil {
243 | 		return modelMeta.updateColumnsFunc(op.tx, entity, columns, condition, params)
244 | 	}
245 | 	return modelMeta.updateColumnsFunc(op.DB(), entity, columns, condition, params)
246 | }
247 | 
248 | // Get loads an entity by id, inside op's open transaction when there is one.
249 | func Get(op *Op, entity Entity, id interface{}) (Entity, error) {
250 | 	modelMeta := findEntityMeta(entity)
251 | 	if op.tx != nil {
252 | 		e, err := modelMeta.getFunc(op.tx, entity, id)
253 | 		if e == nil || err != nil {
254 | 			return nil, err
255 | 		}
256 | 		return e.(Entity), nil
257 | 	}
258 | 	return modelMeta.getFunc(op.DB(), entity, id)
259 | }
260 | 
261 | // Query returns the entities matching condition, inside op's open
262 | // transaction when there is one.
263 | func Query(op *Op, entity Entity, condition string, params ...interface{}) ([]Entity, error) {
264 | 	modelMeta := findEntityMeta(entity)
265 | 	if op.tx != nil {
266 | 		return modelMeta.entityQueryFunc(op.tx, entity, condition, params)
267 | 	}
268 | 	return modelMeta.entityQueryFunc(op.DB(), entity, condition, params)
269 | }
270 | 
271 | // QueryColumns is like Query but only loads the given columns.
272 | func QueryColumns(op *Op, entity Entity, columns []string, condition string, params ...interface{}) ([]Entity, error) {
273 | 	modelMeta := findEntityMeta(entity)
274 | 	if op.tx != nil {
275 | 		return modelMeta.entityQueryColumnFunc(op.tx, entity, columns, condition, params)
276 | 	}
277 | 	return modelMeta.entityQueryColumnFunc(op.DB(), entity, columns, condition, params)
278 | }
279 | 
280 | // count receives the single column selected by QueryCount.
281 | type count struct {
282 | 	Count int64
283 | }
284 | 
285 | // QueryCount returns count(column) over the rows matching condition, or 0
286 | // when no row comes back. column is embedded verbatim in the selected
287 | // expression: never pass untrusted input.
288 | func QueryCount(op *Op, entity Entity, column string, condition string, params ...interface{}) (num int64, err error) {
289 | 	modelMeta := findEntityMeta(entity)
290 | 	columns := []string{"count(" + column + ")"}
291 | 	var counts []*count
292 | 	if op.tx != nil {
293 | 		err = modelMeta.clumnsQueryFunc(op.tx, entity, &counts, columns, condition, params)
294 | 	} else {
295 | 		err = modelMeta.clumnsQueryFunc(op.DB(), entity, &counts, columns, condition, params)
296 | 	}
297 | 	if err != nil {
298 | 		return 0, err
299 | 	}
300 | 	if len(counts) > 0 {
301 | 		num = counts[0].Count
302 | 	}
303 | 	return num, nil
304 | }
305 | 
306 | // QueryColumnsForDestSlice queries columns of the rows matching condition
307 | // into destSlicePtr (a pointer to a slice, as QueryCount passes *[]*count),
308 | // inside op's open transaction when there is one.
309 | func QueryColumnsForDestSlice(op *Op, entity Entity, destSlicePtr interface{}, columns []string, condition string, params ...interface{}) (err error) {
310 | 	modelMeta := findEntityMeta(entity)
311 | 	if op.tx != nil {
312 | 		return modelMeta.clumnsQueryFunc(op.tx, entity, destSlicePtr, columns, condition, params)
313 | 	}
314 | 	return modelMeta.clumnsQueryFunc(op.DB(), entity, destSlicePtr, columns, condition, params)
315 | }
316 | 
317 | // Del deletes the entity with the given id, inside op's open transaction when there is one.
318 | func Del(op *Op, entity Entity, id interface{}) (bool, error) {
319 | 	modelMeta := findEntityMeta(entity)
320 | 	if op.tx != nil {
321 | 		return modelMeta.delEFunc(op.tx, entity, id)
322 | 	}
323 | 	return modelMeta.delEFunc(op.DB(), entity, id)
324 | }
325 | 
326 | // DelByCondition deletes the rows matching condition, inside op's open
327 | // transaction when there is one.
328 | func DelByCondition(op *Op, entity Entity, condition string, params ...interface{}) (int64, error) {
329 | 	modelMeta := findEntityMeta(entity)
330 | 	if op.tx != nil {
331 | 		return modelMeta.delFunc(op.tx, entity, condition, params)
332 | 	}
333 | 	return modelMeta.delFunc(op.DB(), entity, condition, params)
334 | }
335 | 
336 | // AddOrUpdate inserts entity, or updates it when its id already exists
337 | // (MySQL only), inside op's open transaction when there is one.
338 | func AddOrUpdate(op *Op, entity Entity) (int64, error) {
339 | 	modelMeta := findEntityMeta(entity)
340 | 	if op.tx != nil {
341 | 		return modelMeta.insertOrUpdateFunc(op.tx, entity)
342 | 	}
343 | 	return modelMeta.insertOrUpdateFunc(op.DB(), entity)
344 | }
345 | 
--------------------------------------------------------------------------------
/orm/orm_meta_test.go:
--------------------------------------------------------------------------------
1 | package orm
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	"github.com/stretchr/testify/assert"
7 | )
8 | 
9 | type SecondID struct {
10 | 	AutoID
11 | 	ID2   int64  `column:"id2"`
12 | 	Name3 string `column:"name3"`
13 | }
14 | 
15 | func (p *SecondID) TableName() string {
16 | 	return "second_test"
17 | }
18 | 
19 | type ThirdID struct {
20 | 	ID3 int64 `column:"id3"`
21 | 	SecondID
22 | 	Name4 string `column:"name4"`
23 | }
24 | 
25 | type FourthID struct {
26 | 	ID4 int64 `column:"id4"`
27 | 	ThirdID
28 | 	Name5  string `column:"name5"`
29 | 	Name6  string `column:"name6"`
30 | 	Name7  string `column:"name7"`
31 | 	Name8  string `column:"name8"`
32 | 	Name9  string `column:"name9"`
33 | 	Name10 string `column:"name10"`
34 | 	Name11 string `column:"name11"`
35 | 	Conf   *Conf  `column:"conf"`
36 | }
37 | 
38 | func TestPaseMeta(t *testing.T) {
39 | 	meta, err := parseMeta(&FourthID{})
40 | 	assert.NoError(t, err)
41 | 	assert.NotNil(t, meta)
42 | 	assert.EqualValues(t, 15, len(meta.fields))
43 | 	// expected struct-field index path of every column (embedded structs add a level)
44 | 	expectIndexs := map[string][]int{
45 | 		"id4":    {0},
46 | 		"id3":    {1, 0},
47 | 		"id":     {1, 1, 0, 0},
48 | 		"name2":  {1, 1, 0, 1},
49 | 		"id2":    {1, 1, 1},
50 | 		"name3":  {1, 1, 2},
51 | 		"name4":  {1, 2},
52 | 		"name5":  {2},
53 | 		"name6":  {3},
54 | 		"name7":  {4},
55 | 		"name8":  {5},
56 | 		"name9":  {6},
57 | 		"name10": {7},
58 | 		"name11": {8},
59 | 		"conf":   {9},
60 | 	}
61 | 	for _, field := range meta.fields {
62 | 		t.Logf("field:%v,%v", field, expectIndexs[field.column])
63 | 		assert.EqualValues(t, expectIndexs[field.column], field.index)
64 | 	}
65 | }
66 | 
--------------------------------------------------------------------------------
/orm/service_shard.go:
--------------------------------------------------------------------------------
1 | package orm
2 | 
3 | import (
4 | 	"fmt"
5 | 
6 | 	c "github.com/d0ngw/go/common"
7 | )
8 | 
9 | // ShardDBService is a DBService that also supports database (pool) and table sharding.
10 | type ShardDBService interface {
11 | 	DBService
12 | 	// NewOpByShardName create op by shard name
13 | 	NewOpByShardName(name string) (op *Op, err error)
14 | 	// NewOpByEntity create Op for entity with rule name,if rule name is empty use default rule
15 | 	NewOpByEntity(entity Entity, ruleName string) (op *Op, err error)
16 | 	// setupTableShard setup ShardEntity.TableShardFunc with rule name,if rule name is empty use default rule
17 | 	setupTableShard(entity Entity, ruleName string) (poolName string, err error)
18 | }
19 | 
20 | // SimpleShardDBService implements ShardDBService with one pool per configured
21 | // db shard and optional per-entity sharding rules.
22 | type SimpleShardDBService struct {
23 | 	DBShardConfig     DBShardConfigurer     `inject:"_"`
24 | 	EntityShardConfig EntityShardConfigurer `inject:"_,optional"`
25 | 	poolFunc          PoolFunc
26 | 	pools             map[string]*Pool // shard name -> pool; set by Init
27 | 	defaultPool       *Pool
28 | }
29 | 
30 | // NewSimpleShardDBService creates a service whose pools are built by poolFunc.
31 | func NewSimpleShardDBService(poolFunc PoolFunc) *SimpleShardDBService {
32 | 	return &SimpleShardDBService{poolFunc: poolFunc}
33 | }
34 | 
35 | // Init implements Initable.Init(). It opens one pool per configured shard and
36 | // requires a default shard. On any failure the pools opened so far are closed
37 | // and the service stays uninitialized. A nil DBShardConfig() only logs a
38 | // warning.
39 | func (p *SimpleShardDBService) Init() error {
40 | 	if p.poolFunc == nil {
41 | 		return fmt.Errorf("no pool func")
42 | 	}
43 | 	if p.pools != nil {
44 | 		return fmt.Errorf("Inited")
45 | 	}
46 | 	if p.DBShardConfig == nil {
47 | 		return fmt.Errorf("no db shard config")
48 | 	}
49 | 
50 | 	config := p.DBShardConfig.DBShardConfig()
51 | 	if config == nil {
52 | 		c.Warnf("no db shard config")
53 | 		return nil
54 | 	}
55 | 
56 | 	pools := make(map[string]*Pool, len(config.Shards))
57 | 	// closeOpened releases the pools opened so far; on failure nothing else
58 | 	// references them, so each *sql.DB would otherwise leak.
59 | 	closeOpened := func() {
60 | 		for name, opened := range pools {
61 | 			if opened.db == nil {
62 | 				continue
63 | 			}
64 | 			if err := opened.db.Close(); err != nil {
65 | 				c.Warnf("close pool %s fail,err:%v", name, err)
66 | 			}
67 | 		}
68 | 	}
69 | 
70 | 	for k, v := range config.Shards {
71 | 		pool, err := p.poolFunc(v)
72 | 		if err != nil {
73 | 			closeOpened()
74 | 			return err
75 | 		}
76 | 		pool.name = k
77 | 		pools[k] = pool
78 | 	}
79 | 
80 | 	var defaultPool *Pool
81 | 	if config.Default != "" {
82 | 		if defaultPool = pools[config.Default]; defaultPool == nil {
83 | 			closeOpened()
84 | 			return fmt.Errorf("can't find default pool for %s", config.Default)
85 | 		}
86 | 	}
87 | 	if defaultPool == nil {
88 | 		closeOpened()
89 | 		return fmt.Errorf("no default pool")
90 | 	}
91 | 
92 | 	p.pools = pools
93 | 	p.defaultPool = defaultPool
94 | 	return nil
95 | }
96 | 
97 | // NewOp creates an Op on the default pool.
98 | func (p *SimpleShardDBService) NewOp() (op *Op, err error) {
99 | 	pool, err := p.getDefaultPool()
100 | 	if err != nil {
101 | 		return nil, err
102 | 	}
103 | 	op = pool.NewOp()
104 | 	op.sharDBSerevcie = p
105 | 	return op, nil
106 | }
107 | 
108 | // NewOpByShardName creates an Op on the pool of the shard named poolName.
109 | func (p *SimpleShardDBService) NewOpByShardName(poolName string) (op *Op, err error) {
110 | 	pool := p.pools[poolName]
111 | 	if pool == nil {
112 | 		return nil, fmt.Errorf("can't find pool by name %s", poolName)
113 | 	}
114 | 	op = pool.NewOp()
115 | 	op.sharDBSerevcie = p
116 | 	return op, nil
117 | }
118 | 
119 | // NewOpByEntity creates an Op on the pool entity maps to under the rule named
120 | // ruleName (the entity's default rule when ruleName is empty) and, for a
121 | // ShardEntity with a table shard rule, sets its table shard.
122 | func (p *SimpleShardDBService) NewOpByEntity(entity Entity, ruleName string) (op *Op, err error) {
123 | 	pool, err := p.matchPoolAndSetupTblShard(entity, ruleName)
124 | 	if err != nil {
125 | 		return nil, err
126 | 	}
127 | 	op = pool.NewOp()
128 | 	op.sharDBSerevcie = p
129 | 	return op, nil
130 | }
131 | 
132 | // setupTableShard does NewOpByEntity's rule matching and returns the name of
133 | // the pool entity maps to instead of creating an Op.
134 | func (p *SimpleShardDBService) setupTableShard(entity Entity, ruleName string) (poolName string, err error) {
135 | 	pool, err := p.matchPoolAndSetupTblShard(entity, ruleName)
136 | 	if err != nil {
137 | 		return "", err
138 | 	}
139 | 	return pool.Name(), nil
140 | }
141 | 
142 | // findShardRule returns the rule named ruleName (the default rule when
143 | // ruleName is empty) configured for entity's type. It returns (nil, nil) when
144 | // no entity sharding is configured for the type, and an error when a
145 | // non-empty ruleName is unknown.
146 | func (p *SimpleShardDBService) findShardRule(entity Entity, ruleName string) (rule *EntityShardRuleConfig, err error) {
147 | 	if p.EntityShardConfig == nil {
148 | 		return nil, nil
149 | 	}
150 | 	shardConf := p.EntityShardConfig.EntityShardConfig()
151 | 	if shardConf == nil {
152 | 		return nil, nil
153 | 	}
154 | 
155 | 	_, _, typ := extract(entity)
156 | 	pkgPath := typ.PkgPath()
157 | 	name := typ.Name()
158 | 
159 | 	// a missing package yields a nil inner map, which is safe to index
160 | 	entityRules := shardConf.entities[pkgPath][name]
161 | 	if entityRules == nil {
162 | 		return nil, nil
163 | 	}
164 | 
165 | 	if ruleName == "" {
166 | 		return entityRules.defaultRule, nil
167 | 	}
168 | 	rule = entityRules.rules[ruleName]
169 | 	if rule == nil {
170 | 		return nil, fmt.Errorf("can't find rule for %s.%s for rule %s", pkgPath, name, ruleName)
171 | 	}
172 | 	return rule, nil
173 | }
174 | 
175 | // findShardPool returns the pool rule.DBShard selects for entity, or the
176 | // default pool when there is no db shard rule.
177 | func (p *SimpleShardDBService) findShardPool(entity Entity, rule *EntityShardRuleConfig) (pool *Pool, err error) {
178 | 	if rule == nil || rule.DBShard == nil {
179 | 		return p.getDefaultPool()
180 | 	}
181 | 	shardName, err := p.evalShardRule(entity, rule, rule.DBShard)
182 | 	if err != nil {
183 | 		return nil, err
184 | 	}
185 | 	pool = p.pools[shardName]
186 | 	if pool == nil {
187 | 		return nil, fmt.Errorf("can't find pool by name %s", shardName)
188 | 	}
189 | 	return pool, nil
190 | }
191 | 
192 | // findTableShardHandler returns a ShardHandler yielding the table shard that
193 | // rule.TableShard selects for entity, or nil when there is no table shard rule.
194 | func (p *SimpleShardDBService) findTableShardHandler(entity Entity, rule *EntityShardRuleConfig) (handler ShardHandler, err error) {
195 | 	if rule == nil || rule.TableShard == nil {
196 | 		return nil, nil
197 | 	}
198 | 	shardName, err := p.evalShardRule(entity, rule, rule.TableShard)
199 | 	if err != nil {
200 | 		return nil, err
201 | 	}
202 | 	// the shard is computed once, from the entity's field value at this point
203 | 	return func() (string, error) {
204 | 		return shardName, nil
205 | 	}, nil
206 | }
207 | 
208 | // evalShardRule computes the shard name of entity under shardRule, reading
209 | // the rule's shard field (when it needs one) from entity through rule.meta.
210 | func (p *SimpleShardDBService) evalShardRule(entity Entity, rule *EntityShardRuleConfig, shardRule *OneRule) (string, error) {
211 | 	var fieldVal interface{}
212 | 	if fieldName := shardRule.ShardFieldName(); fieldName != "" {
213 | 		val, err := rule.meta.FieldValue(entity, fieldName)
214 | 		if err != nil {
215 | 			return "", err
216 | 		}
217 | 		fieldVal = val
218 | 	}
219 | 	return shardRule.Shard(fieldVal)
220 | }
221 | 
222 | // matchPoolAndSetupTblShard resolves entity's rule, selects its pool and, when
223 | // entity is a ShardEntity with a table shard rule, installs the table shard
224 | // func on it.
225 | func (p *SimpleShardDBService) matchPoolAndSetupTblShard(entity Entity, ruleName string) (pool *Pool, err error) {
226 | 	if entity == nil {
227 | 		return nil, fmt.Errorf("invalid meta")
228 | 	}
229 | 
230 | 	rule, err := p.findShardRule(entity, ruleName)
231 | 	if err != nil {
232 | 		return nil, err
233 | 	}
234 | 
235 | 	if pool, err = p.findShardPool(entity, rule); err != nil {
236 | 		return nil, err
237 | 	}
238 | 
239 | 	if shardEntity, ok := entity.(ShardEntity); ok {
240 | 		handler, err := p.findTableShardHandler(entity, rule)
241 | 		if err != nil {
242 | 			return nil, err
243 | 		}
244 | 		if handler != nil {
245 | 			shardEntity.SetTableShardFunc(handler)
246 | 		}
247 | 	}
248 | 	return pool, nil
249 | }
250 | 
251 | // getDefaultPool returns the default pool, or an error until Init has succeeded.
252 | func (p *SimpleShardDBService) getDefaultPool() (pool *Pool, err error) {
253 | 	if p.defaultPool != nil {
254 | 		return p.defaultPool, nil
255 | 	}
256 | 	return nil, fmt.Errorf("no default pool")
257 | }
258 | 
--------------------------------------------------------------------------------
/orm/service_shard_test.go:
--------------------------------------------------------------------------------
1 | package orm
2 | 
3 | import (
4 | 	"database/sql"
5 | 	"testing"
6 | 
7 | 	c "github.com/d0ngw/go/common"
8 | 	"github.com/stretchr/testify/assert"
9 | )
10 | 
11 | func TestShardDBServcie(t *testing.T) {
12 | 	defaultMetaReg.clean()
13 | 	tm := tmodel{}
14 | 	AddMeta(&tm)
15 | 	user := &User{}
16 | 	AddMeta(user)
17 | 
18 | 	conf := &shardConf{}
19 | 	err := c.LoadYAMLFromPath("testdata/shard.yaml", conf)
20 | 	assert.NoError(t, err)
21 | 
22 | 	err = conf.Parse()
23 | 	assert.NoError(t, err)
24 | 
25 | 	shardServcie := NewSimpleShardDBService(NewMySQLDBPool)
26 | 	shardServcie.DBShardConfig = conf
27 | 	shardServcie.EntityShardConfig = conf
28 | 
29 | 	err = shardServcie.Init()
30 | 	assert.Nil(t, err)
31 | 
32 | 	op, err := shardServcie.NewOp()
33 | 	// stop instead of dereferencing a nil op (e.g. when Init failed)
34 | 	if !assert.NoError(t, err) {
35 | 		return
36 | 	}
37 | 	assert.Equal(t, "test0", op.PoolName())
38 | 
39 | 	var shardSvr ShardDBService = shardServcie
40 | 	assert.NotNil(t, shardSvr)
41 | 
42 | 	op, err = shardServcie.NewOpByShardName("no exist")
43 | 	assert.NotNil(t, err)
44 | 	assert.Nil(t, op)
45 | 
46 | 	op, err = shardServcie.NewOpByShardName("test0")
47 | 	assert.NotNil(t, op)
48 | 	assert.Nil(t, err)
49 | 
50 | 	op, err = shardServcie.NewOpByEntity(&tm, "")
51 | 	assert.NotNil(t, op)
52 | 	assert.Nil(t, err)
53 | 	assert.Nil(t, tm.tblShardFunc)
54 | 	assert.Equal(t, "test0", op.PoolName())
55 | 
56 | 	tm.ID = 2
57 | 	op, err = shardServcie.NewOpByEntity(&tm, "test_db_shard_hash")
58 | 	assert.NotNil(t, op)
59 | 	assert.Nil(t, err)
60 | 	assert.NotNil(t, tm.tblShardFunc)
61 | 	assert.Equal(t, "test_2", op.PoolName())
62 | 
63 | 	tblShard, err := tm.tblShardFunc()
64 | 	assert.Nil(t, err)
65 | 	assert.Equal(t, "tt_2", tblShard)
66 | 
67 | 	tm.Name = sql.NullString{String: "ok", Valid: true}
68 | 	err = Add(op, &tm)
69 | 	assert.Nil(t, err)
70 | 
71 | 	ret, err := Get(op, &tm, tm.ID)
72 | 	assert.Nil(t, err)
73 | 	assert.NotNil(t, ret)
74 | 
75 | 	del, err := Del(op, &tm, tm.ID)
76 | 	assert.Nil(t, err)
77 | 	assert.True(t, del)
78 | 
79 | 	user0 := &User{
80 | 		Name: sql.NullString{String: "u0", Valid: true},
81 | 		Age:  1,
82 | 	}
83 | 	user1 := &User{
84 | 		Name: sql.NullString{String: "u1", Valid: true},
85 | 		Age:  2,
86 | 	}
87 | 	user2 := &User{
88 | 		Name: sql.NullString{String: "u2", Valid: true},
89 | 		Age:  3,
90 | 	}
91 | 
92 | 	err = op.SetupTableShard(user0, "")
93 | 	assert.NoError(t, err)
94 | 
95 | 	err = op.SetupTableShard(user1, "")
96 | 	assert.NoError(t, err)
97 | 
98 | 	err = op.SetupTableShard(user2, "")
99 | 	assert.NoError(t, err)
100 | 
101 | 	_, err = op.DoInTrans(func(tx *sql.Tx) (interface{}, error) {
102 | 		e := Add(op, &tm)
103 | 		assert.NoError(t, e)
104 | 
105 | 		e = Add(op, user0)
106 | 		assert.NoError(t, e)
107 | 
108 | 		e = Add(op, user1)
109 | 		assert.NoError(t, e)
110 | 
111 | 		e = Add(op, user2)
112 | 		assert.NoError(t, e)
113 | 		return nil, nil
114 | 	})
115 | 	assert.NoError(t, err)
116 | }
117 | 
--------------------------------------------------------------------------------
/orm/service_simple.go:
--------------------------------------------------------------------------------
1 | package orm
2 | 
3 | import (
4 | 	"fmt"
5 | 
6 | 	c "github.com/d0ngw/go/common"
7 | )
8 | 
9 | // DBService is the service that supplies Op values (see OpCreator); it must be initialized before use.
10 | type DBService interface {
11 | 	c.Initable
12 | 	OpCreator
13 | }
14 | 
15 | // SimpleDBService implements DBService with a single pool built from Config.
16 | type SimpleDBService struct {
17 | 	Config   DBConfigurer `inject:"_"`
18 | 	poolFunc PoolFunc
19 | 	pool     *Pool
20 | }
21 | 
22 | // NewSimpleDBService builds a SimpleDBService whose pool is created by poolFunc.
23 | func NewSimpleDBService(poolFunc PoolFunc) *SimpleDBService {
24 | 	return &SimpleDBService{poolFunc: poolFunc}
25 | }
26 | 
27 | // Init implements Initable.Init(): it builds the pool from Config and fails when called twice.
28 | func (p *SimpleDBService) Init() error {
29 | 	if p.poolFunc == nil {
30 | 		return fmt.Errorf("no pool func")
31 | 	}
32 | 	if p.pool != nil {
33 | 		return fmt.Errorf("Inited")
34 | 	}
35 | 	if p.Config == nil {
36 | 		return fmt.Errorf("No db config")
37 | 	}
38 | 
39 | 	pool, err := p.poolFunc(p.Config.DBConfig())
40 | 	if err != nil {
41 | 		return err
42 | 	}
43 | 	p.pool = pool
44 | 	return nil
45 | }
46 | 
47 | // NewOp implements DBService.NewOp(); Init must have succeeded first.
48 | func (p *SimpleDBService) NewOp() (*Op, error) {
49 | 	if p.pool == nil {
50 | 		return nil, fmt.Errorf("please init db pool")
51 | 	}
52 | 	return p.pool.NewOp(), nil
53 | }
54 | 
--------------------------------------------------------------------------------
/orm/shard.go:
--------------------------------------------------------------------------------
1 | package orm
2 | 
3 | import (
4 | 	"fmt"
5 | 	"reflect"
6 | 	"sort"
7 | 	"strconv"
8 | 
9 | 	c "github.com/d0ngw/go/common"
10 | )
11 | 
12 | // ShardPolicy names a sharding policy.
13 | type ShardPolicy string
14 | 
15 | // ShardRule computes shard names from a field value.
16 | type ShardRule interface {
17 | 	c.Configurer
18 | 	// Policy returns the policy name
19 | 	Policy() ShardPolicy
20 | 	// Shard computes the shard name for val
21 | 	Shard(val interface{}) (shardName string, err error)
22 | 	// ShardFieldName returns the field whose value is passed to Shard ("" when none is needed)
23 | 	ShardFieldName() string
24 | }
25 | 
26 | const (
27 | 	//Hash hash shard
28 | 	Hash ShardPolicy = "hash"
29 | 	//Named named shard
30 | 	Named ShardPolicy = "named"
31 | 	//NumRange number range shard
32 | 	NumRange ShardPolicy = "num_range"
33 | )
34 | 
35 | // IsValid reports whether p is one of the known policies.
36 | func (p ShardPolicy) IsValid() bool {
37 | 	return p == Hash || p == Named || p == NumRange
38 | }
39 | 
40 | // HashRule maps a non-negative integer value v to NamePrefix + (v % Count).
41 | type HashRule struct {
42 | 	Count      int64  `yaml:"count"`       // number of shards
43 | 	NamePrefix string `yaml:"name_prefix"` // prefix of the shard names
44 | 	FieldName  string `yaml:"field_name"`  // field whose value is hashed
45 | }
46 | 
47 | // Policy implements ShardRule
48 | func (p *HashRule) Policy() ShardPolicy {
49 | 	return Hash
50 | }
51 | 
52 | // Parse implements Configurer; Count must be positive and NamePrefix and FieldName non-empty.
53 | func (p *HashRule) Parse() error {
54 | 	if p.Count <= 0 {
55 | 		return fmt.Errorf("invalid count")
56 | 	}
57 | 	if p.NamePrefix == "" {
58 | 		return fmt.Errorf("invalid name_prefix")
59 | 	}
60 | 	if p.FieldName == "" {
61 | 		return fmt.Errorf("invalid field_name")
62 | 	}
63 | 	return nil
64 | }
65 | 
66 | // Shard implements ShardRule.Shard. val must be convertible by c.Int64 and
67 | // non-negative. Count is re-checked so that a rule that was never parsed
68 | // returns an error instead of panicking with an integer division by zero.
69 | func (p *HashRule) Shard(val interface{}) (shardName string, err error) {
70 | 	if p.Count <= 0 {
71 | 		return "", fmt.Errorf("invalid count")
72 | 	}
73 | 	valInt64, err := c.Int64(val)
74 | 	if err != nil {
75 | 		return "", err
76 | 	}
77 | 	if valInt64 < 0 {
78 | 		return "", fmt.Errorf("invalid hash val %v", val)
79 | 	}
80 | 	return p.NamePrefix + strconv.FormatInt(valInt64%p.Count, 10), nil
81 | }
82 | 
83 | // ShardFieldName implements ShardRule.ShardFieldName
84 | func (p *HashRule) ShardFieldName() string {
85 | 	return p.FieldName
86 | }
87 | 
88 | // NamedRule always selects the fixed shard Name.
89 | type NamedRule struct {
90 | 	Name string `yaml:"name"`
91 | }
92 | 
93 | // Policy implements ShardRule
94 | func (p *NamedRule) Policy() ShardPolicy {
95 | 	return Named
96 | }
97 | 
98 | // Parse implements Configurer; Name must be non-empty.
99 | func (p *NamedRule) Parse() error {
100 | 	if p.Name == "" {
101 | 		return fmt.Errorf("invalid name")
102 | 	}
103 | 	return nil
104 | }
105 | 
106 | // Shard implements ShardRule.Shard; val is ignored.
107 | func (p *NamedRule) Shard(val interface{}) (shardName string, err error) {
108 | 	return p.Name, nil
109 | }
110 | 
111 | // ShardFieldName implements ShardRule.ShardFieldName; a NamedRule reads no field.
112 | func (p *NamedRule) ShardFieldName() string {
113 | 	return ""
114 | }
115 | 
116 | // NumRangeRule maps an integer value to the Name of the range [Begin, End]
117 | // that contains it, or to DefaultName when no range does.
118 | type NumRangeRule struct {
119 | 	FieldName   string `yaml:"field_name"`   // field whose value is looked up
120 | 	DefaultName string `yaml:"default_name"` // used when no range matches; optional
121 | 	Ranges      []*struct {
122 | 		Begin int64  `yaml:"begin"`
123 | 		End   int64  `yaml:"end"`
124 | 		Name  string `yaml:"name"`
125 | 	} `yaml:"ranges"`
126 | }
127 | 
128 | // Policy implements ShardRule
129 | func (p *NumRangeRule) Policy() ShardPolicy {
130 | 	return NumRange
131 | }
132 | 
133 | // Parse implements Configurer. It rejects an empty range list and null,
134 | // inverted (Begin > End) or overlapping ranges, and sorts Ranges by Begin,
135 | // which Shard relies on.
136 | func (p *NumRangeRule) Parse() error {
137 | 	if len(p.Ranges) == 0 {
138 | 		return fmt.Errorf("invalid ranges")
139 | 	}
140 | 
141 | 	for i, v := range p.Ranges {
142 | 		if v == nil {
143 | 			return fmt.Errorf("invalid range[%d]: empty", i)
144 | 		}
145 | 		if v.Begin > v.End {
146 | 			return fmt.Errorf("invalid range begin:%d > end:%d", v.Begin, v.End)
147 | 		}
148 | 	}
149 | 	sort.Slice(p.Ranges, func(i, j int) bool {
150 | 		return p.Ranges[i].Begin < p.Ranges[j].Begin
151 | 	})
152 | 
153 | 	for i := 1; i < len(p.Ranges); i++ {
154 | 		if p.Ranges[i].Begin <= p.Ranges[i-1].End {
155 | 			return fmt.Errorf("invalid range[%d].Begin %d <= range[%d].End %d", i, p.Ranges[i].Begin, i-1, p.Ranges[i-1].End)
156 | 		}
157 | 	}
158 | 	return nil
159 | }
160 | 
161 | // Shard implements ShardRule.Shard with a binary search over the sorted,
162 | // non-overlapping ranges; Parse must have been called first.
163 | func (p *NumRangeRule) Shard(val interface{}) (shardName string, err error) {
164 | 	valInt64, err := c.Int64(val)
165 | 	if err != nil {
166 | 		return "", err
167 | 	}
168 | 
169 | 	lo, hi := 0, len(p.Ranges)
170 | 	for lo < hi {
171 | 		mid := int(uint(lo+hi) >> 1)
172 | 		r := p.Ranges[mid]
173 | 		switch {
174 | 		case valInt64 < r.Begin:
175 | 			hi = mid
176 | 		case valInt64 > r.End:
177 | 			lo = mid + 1
178 | 		default:
179 | 			return r.Name, nil
180 | 		}
181 | 	}
182 | 	if p.DefaultName != "" {
183 | 		return p.DefaultName, nil
184 | 	}
185 | 	// format the converted value: val itself need not be an integer type
186 | 	return "", fmt.Errorf("can't find name for val %d", valInt64)
187 | }
188 | 
189 | // ShardFieldName implements ShardRule.ShardFieldName
190 | func (p *NumRangeRule) ShardFieldName() string {
191 | 	return p.FieldName
192 | }
193 | 
194 | // OneRule holds exactly one of the Hash, Named or NumRange rules. Parse
195 | // selects it; the other methods delegate to the selected rule.
196 | type OneRule struct {
197 | 	Hash     *HashRule     `yaml:"hash"`
198 | 	Named    *NamedRule    `yaml:"named"`
199 | 	NumRange *NumRangeRule `yaml:"num_range"`
200 | 	policy   ShardPolicy
201 | 	rule     ShardRule
202 | }
203 | 
204 | // Parse implements Configurer. Exactly one rule must be set, and it must parse.
205 | func (p *OneRule) Parse() error {
206 | 	var rules = []ShardRule{p.Hash, p.Named, p.NumRange}
207 | 	for _, v := range rules {
208 | 		// an unset field is a typed nil pointer, so v == nil alone is not enough
209 | 		if v == nil || reflect.ValueOf(v).IsNil() {
210 | 			continue
211 | 		}
212 | 		if err := v.Parse(); err != nil {
213 | 			return err
214 | 		}
215 | 		if p.policy != "" {
216 | 			return fmt.Errorf("only allow one rule")
217 | 		}
218 | 		p.policy = v.Policy()
219 | 		p.rule = v
220 | 	}
221 | 
222 | 	if p.policy == "" || p.rule == nil {
223 | 		return fmt.Errorf("no rule")
224 | 	}
225 | 	return nil
226 | }
227 | 
228 | // Policy implements ShardRule.Policy; it is empty until Parse succeeds.
229 | func (p *OneRule) Policy() ShardPolicy {
230 | 	return p.policy
231 | }
232 | 
233 | // Shard implements ShardRule.Shard by delegating to the selected rule.
234 | func (p *OneRule) Shard(val interface{}) (shardName string, err error) {
235 | 	if p.rule == nil {
236 | 		return "", fmt.Errorf("no rule")
237 | 	}
238 | 	return p.rule.Shard(val)
239 | }
240 | 
241 | // ShardFieldName implements ShardRule.ShardFieldName; "" when no rule is selected.
242 | func (p *OneRule) ShardFieldName() string {
243 | 	if p.rule == nil {
244 | 		return ""
245 | 	}
246 | 	return p.rule.ShardFieldName()
247 | }
248 | 
--------------------------------------------------------------------------------
/orm/shard_test.go:
--------------------------------------------------------------------------------
1 | package orm
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	"github.com/stretchr/testify/assert"
7 | 
8 | 	c "github.com/d0ngw/go/common"
9 | )
10 | 
11 | type shardConf struct {
12 | 	c.AppConfig
13 | 	DBShards     *DBShardConfig     `yaml:"db_shards"`
14 | 	EntityShards *EntityShardConfig `yaml:"entity_shards"`
15 | }
16 | 
17 | func (p *shardConf) Parse() error {
18 | 	if err := p.AppConfig.Parse(); err != nil {
19 | 		return err
20 | 	}
21 | 
22 | 	if err := p.DBShards.Parse(); err != nil {
23 | 		return err
24 | 	}
25 | 
26 | 	if err := p.EntityShards.Parse(); err != nil {
27 | 		return err
28 | 	}
29 | 	return nil
30 | }
31 | 
32 | func (p *shardConf) DBShardConfig() *DBShardConfig {
33 | 	return p.DBShards
34 | }
35 | 
36 | func (p *shardConf) EntityShardConfig() *EntityShardConfig {
37 | 	return p.EntityShards
38 | }
39 | 
40 | func TestShardConfig(t *testing.T) {
41 | 	defaultMetaReg.clean()
42 | 	tm := tmodel{}
43 | 	meta := AddMeta(&tm)
44 | 	AddMeta(&User{})
45 | 
46 | 	conf := &shardConf{}
47 | 	err := c.LoadYAMLFromPath("testdata/shard.yaml", conf)
48 | 	assert.NoError(t, err)
49 | 
50 | 	err = conf.Parse()
51 | 	assert.NoError(t, err)
52 | 
53 | 	assert.NotNil(t, conf.DBShards)
54 | 	assert.NotNil(t, conf.EntityShards)
55 | 
56 | 	testShard := conf.DBShards.Shards["test0"]
57 | 	defaultShard := conf.DBShards.Default
58 | 	assert.NotNil(t, testShard)
59 | 	assert.Equal(t, "test0", defaultShard)
60 | 
61 | 	entityRule := conf.EntityShards.entities[meta.Type().PkgPath()][meta.Type().Name()]
62 | 	assert.NotNil(t, entityRule)
63 | 	assert.NotNil(t, entityRule.meta)
64 | 	assert.True(t, meta == entityRule.meta)
65 | 
66 | 	defaultRule := entityRule.defaultRule
67 | 	assert.NotNil(t, defaultRule)
68 | 	assert.Equal(t, "default", defaultRule.Name)
69 | 	assert.Equal(t, defaultRule, entityRule.rules["default"])
70 | 
71 | 	testDBShardHash := entityRule.rules["test_db_shard_hash"]
72 | 	assert.NotNil(t, testDBShardHash)
73 | 
74 | 	dbHashRule := testDBShardHash.DBShard.Hash
75 | 	tableHasRule := testDBShardHash.TableShard.Hash
76 | 	assert.NotNil(t, dbHashRule)
77 | 	assert.NotNil(t, tableHasRule)
78 | 
79 | 	var hashTest = func(hRule *HashRule, namePrefix string) {
80 | 		assert.NotNil(t, hRule)
81 | 		assert.EqualValues(t, 100, hRule.Count)
82 | 		assert.Equal(t, namePrefix, hRule.NamePrefix)
83 | 		assert.Equal(t, "id", hRule.FieldName)
84 | 		name, err := hRule.Shard(0)
85 | 		assert.Nil(t, err)
86 | 		assert.Equal(t, namePrefix+"0", name)
87 | 		name, err = hRule.Shard(1)
88 | 		assert.Nil(t, err)
89 | 		assert.Equal(t, namePrefix+"1", name)
90 | 		name, err = hRule.Shard(99)
91 | 		assert.Nil(t, err)
92 | 		assert.Equal(t, namePrefix+"99", name)
93 | 	}
94 | 	hashTest(dbHashRule, "test_")
95 | 	hashTest(tableHasRule, "tt_")
96 | 
97 | 	testDBShardNamed := entityRule.rules["test_db_shard_named"]
98 | 	assert.NotNil(t, testDBShardNamed)
99 | 
100 | 	dbNamedRule := testDBShardNamed.DBShard.Named
101 | 	tableNamedRule := testDBShardNamed.TableShard.Named
102 | 	assert.NotNil(t, dbNamedRule)
103 | 	assert.NotNil(t, tableNamedRule)
104 | 
105 | 	var namedTest = func(nRule *NamedRule, name string) {
106 | 		assert.Equal(t, name, nRule.Name)
107 | 		// use a separate variable: re-declaring name here shadowed the
108 | 		// parameter and made the final check compare name with itself
109 | 		got, err := nRule.Shard(0)
110 | 		assert.Nil(t, err)
111 | 		assert.Equal(t, name, got)
112 | 	}
113 | 	namedTest(dbNamedRule, "test0")
114 | 	namedTest(tableNamedRule, "tt")
115 | 
116 | 	testDBShardNumRange := entityRule.rules["test_db_shard_num_range"]
117 | 	assert.NotNil(t, testDBShardNumRange)
118 | 
119 | 	dbNumRangeRule := testDBShardNumRange.DBShard.NumRange
120 | 	tableNumRangeRule := testDBShardNumRange.TableShard.NumRange
121 | 	assert.NotNil(t, dbNumRangeRule)
122 | 	assert.NotNil(t, tableNumRangeRule)
123 | 
124 | 	var numRangeTest = func(nrRule *NumRangeRule, defaultName string, namePrefix string) {
125 | 		assert.NotNil(t, nrRule)
126 | 		assert.Equal(t, "id", nrRule.FieldName)
127 | 		assert.Equal(t, defaultName, nrRule.DefaultName)
128 | 		assert.EqualValues(t, 3, len(nrRule.Ranges))
129 | 		assert.EqualValues(t, 0, nrRule.Ranges[0].Begin)
130 | 		assert.EqualValues(t, 100, nrRule.Ranges[0].End)
131 | 		assert.EqualValues(t, 101, nrRule.Ranges[1].Begin)
132 | 		assert.EqualValues(t, 200, nrRule.Ranges[1].End)
133 | 		assert.EqualValues(t, 500, nrRule.Ranges[2].Begin)
134 | 		assert.EqualValues(t, 1000, nrRule.Ranges[2].End)
135 | 
136 | 		var name string
137 | 
138 | 		for i := -100; i <= -1; i++ {
139 | 			name, err = nrRule.Shard(i)
140 | 			assert.Nil(t, err)
141 | 			assert.Equal(t, defaultName, name, "i=%d", i)
142 | 		}
143 | 
144 | 		for i := 0; i <= 100; i++ {
145 | 			name, err = nrRule.Shard(i)
146 | 			assert.Nil(t, err)
147 | 			assert.Equal(t, namePrefix+"100", name, "i=%d", i)
148 | 		}
149 | 
150 | 		for i := 101; i <= 200; i++ {
151 | 			name, err = nrRule.Shard(i)
152 | 			assert.Nil(t, err)
153 | 			assert.Equal(t, namePrefix+"200", name, "i=%d", i)
154 | 		}
155 | 
156 | 		for i := 201; i <= 499; i++ {
157 | 			name, err = nrRule.Shard(i)
158 | 			assert.Nil(t, err)
159 | 			assert.Equal(t, defaultName, name, "i=%d", i)
160 | 		}
161 | 
162 | 		for i := 500; i <= 1000; i++ {
163 | 			name, err = nrRule.Shard(i)
164 | 			assert.Nil(t, err)
165 | 			assert.Equal(t, namePrefix+"1000", name, "i=%d", i)
166 | 		}
167 | 
168 | 		for i := 1001; i <= 10000; i++ {
169 | 			name, err = nrRule.Shard(i)
170 | 			assert.Nil(t, err)
171 | 			assert.Equal(t, defaultName, name, "i=%d", i)
172 | 		}
173 | 
174 | 		for i := 100001; i <= 100005; i++ {
175 | 			name, err = nrRule.Shard(i)
176 | 			assert.Nil(t, err)
177 | 			assert.Equal(t, defaultName, name, "i=%d", i)
178 | 		}
179 | 	}
180 | 
181 | 	numRangeTest(dbNumRangeRule, "test0", "test_")
182 | 	numRangeTest(tableNumRangeRule, "tt", "tt_")
183 | 	defaultMetaReg.clean()
184 | }
185 | 
--------------------------------------------------------------------------------
/orm/testdata/setup.sql:
--------------------------------------------------------------------------------
1 | CREATE TABLE IF NOT EXISTS `tt` (
2 | `id` bigint(20) NOT NULL AUTO_INCREMENT,
3 | `name` varchar(64) DEFAULT NULL,
4 | `name2` varchar(64) DEFAULT "",
5 | `create_time` bigint(20) DEFAULT NULL,
6 | `f64` double DEFAULT NULL,
7 | `conf` varchar(64),
8 | `conf2` varchar(64) DEFAULT "",
9 | `ver` bigint NOT NULL DEFAULT 0,
10 | `age` bigint NOT NULL DEFAULT 0,
11 | PRIMARY KEY (`id`))
12 | ENGINE=InnoDB DEFAULT CHARSET=utf8;
13 | --
14 | CREATE TABLE IF NOT EXISTS `tt_2` (
15 | `id` bigint(20) NOT NULL AUTO_INCREMENT,
16 | 
`name` varchar(64) DEFAULT NULL, 17 | `name2` varchar(64) DEFAULT "", 18 | `create_time` bigint(20) DEFAULT NULL, 19 | `f64` double DEFAULT NULL, 20 | `conf` varchar(64), 21 | `conf2` varchar(64) DEFAULT "", 22 | `ver` bigint NOT NULL DEFAULT 0, 23 | `age` bigint NOT NULL DEFAULT 0, 24 | PRIMARY KEY (`id`)) 25 | ENGINE=InnoDB DEFAULT CHARSET=utf8; 26 | -- 27 | CREATE TABLE IF NOT EXISTS `user_0` ( 28 | `id` bigint(20) NOT NULL AUTO_INCREMENT, 29 | `name` varchar(64) DEFAULT NULL, 30 | `age` bigint(20) NOT NULL DEFAULT 0, 31 | `birthday` DATE COMMENT '生日', 32 | PRIMARY KEY (`id`)) 33 | ENGINE=InnoDB DEFAULT CHARSET=utf8; 34 | -- 35 | CREATE TABLE IF NOT EXISTS `user_1` ( 36 | `id` bigint(20) NOT NULL AUTO_INCREMENT, 37 | `name` varchar(64) DEFAULT NULL, 38 | `age` bigint(20) NOT NULL DEFAULT 0, 39 | `birthday` DATE COMMENT '生日', 40 | PRIMARY KEY (`id`)) 41 | ENGINE=InnoDB DEFAULT CHARSET=utf8; 42 | -- 43 | CREATE TABLE IF NOT EXISTS `user_2` ( 44 | `id` bigint(20) NOT NULL AUTO_INCREMENT, 45 | `name` varchar(64) DEFAULT NULL, 46 | `age` bigint(20) NOT NULL DEFAULT 0, 47 | `birthday` DATE COMMENT '生日', 48 | PRIMARY KEY (`id`)) 49 | ENGINE=InnoDB DEFAULT CHARSET=utf8; -------------------------------------------------------------------------------- /orm/testdata/shard.yaml: -------------------------------------------------------------------------------- 1 | db_shards: 2 | default: "test0" 3 | shards: 4 | test0: 5 | user: "root" 6 | pass: "123456" 7 | url: "127.0.0.1:3306" 8 | schema: "test" 9 | charset: "utf8mb4" 10 | maxConn: 100 11 | maxIdle: 1 12 | test_2: 13 | user: "root" 14 | pass: "123456" 15 | url: "127.0.0.1:3306" 16 | schema: "test" 17 | charset: "utf8mb4" 18 | maxConn: 100 19 | maxIdle: 1 20 | 21 | entity_shards: 22 | entities: 23 | github.com/d0ngw/go/orm: 24 | tmodel: 25 | - name: default 26 | default: true 27 | 28 | - name: test_db_shard_hash 29 | default: false 30 | db_shard: 31 | hash: 32 | count: 100 33 | name_prefix: "test_" 34 | field_name: "id" 35 | 
table_shard: 36 | hash: 37 | count: 100 38 | name_prefix: "tt_" 39 | field_name: "id" 40 | 41 | - name: test_db_shard_named 42 | default: false 43 | db_shard: 44 | named: 45 | name: "test0" 46 | table_shard: 47 | named: 48 | name: "tt" 49 | 50 | - name: test_db_shard_num_range 51 | default: false 52 | db_shard: 53 | num_range: 54 | field_name: "id" 55 | default_name: "test0" 56 | ranges: 57 | - begin: 101 58 | end: 200 59 | name: test_200 60 | - begin: 0 61 | end: 100 62 | name: test_100 63 | - begin: 500 64 | end: 1000 65 | name: test_1000 66 | table_shard: 67 | num_range: 68 | field_name: "id" 69 | default_name: "tt" 70 | ranges: 71 | - begin: 101 72 | end: 200 73 | name: tt_200 74 | - begin: 0 75 | end: 100 76 | name: tt_100 77 | - begin: 500 78 | end: 1000 79 | name: tt_1000 80 | 81 | User: 82 | - name: default 83 | default: true 84 | db_shard: 85 | named: 86 | name: "test_2" 87 | table_shard: 88 | hash: 89 | count: 3 90 | name_prefix: "user_" 91 | field_name: "age" 92 | 93 | 94 | -------------------------------------------------------------------------------- /orm/testdata/teardown.sql: -------------------------------------------------------------------------------- 1 | DROP TABLE IF EXISTS `tt`; 2 | -- 3 | DROP TABLE IF EXISTS `tt_2`; 4 | -- 5 | DROP TABLE IF EXISTS `user_0`; 6 | -- 7 | DROP TABLE IF EXISTS `user_1`; 8 | -- 9 | DROP TABLE IF EXISTS `user_2`; --------------------------------------------------------------------------------