├── .gitignore ├── throttle ├── rate_limiter.go ├── rate_limiter_memory_test.go ├── rate_limiter_redis.go └── rate_limiter_memory.go ├── .travis.yml ├── server ├── init.go ├── route-test.yml ├── route │ ├── trie_tree_test.go │ ├── matcher.go │ ├── router_test.go │ ├── trie_tree.go │ └── router.go ├── lb │ ├── loadbalancer_test.go │ └── loadbalancer.go ├── filter_serv_match.go ├── filter_url_rewrite.go ├── syncmap_rate_limiter.go ├── filter_rate_limit.go ├── filter_def.go ├── gogate-test.yml ├── statistics │ ├── stat_test.go │ ├── store_csv_file.go │ └── stat.go ├── server_context.go ├── server_response.go ├── server_test.go ├── server_send.go ├── server_handler.go ├── server_filter.go └── server.go ├── utils ├── net_test.go ├── idgen.go ├── stopwatch.go ├── net.go ├── rand_test.go ├── rand.go ├── collection_test.go └── collection.go ├── eureka.json ├── discovery ├── client_test.go ├── client.go ├── empty_client.go ├── syncmap_ins_meta_lbclient.go ├── syncmap_ins_info_arr.go ├── eureka_client_test.go ├── consul_client.go ├── refresh.go └── eureka_client.go ├── redis ├── redis_test.go └── redis.go ├── main.go ├── gogate.yml ├── route.yml ├── LICENSE ├── go.mod ├── lua └── rate_limiter.lua ├── conf ├── log.go └── config.go ├── examples └── usage.go ├── perr └── error.go └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.iml 3 | vendor/ 4 | _vendor* 5 | logs/ 6 | bin/ 7 | out/ 8 | go.sum -------------------------------------------------------------------------------- /throttle/rate_limiter.go: -------------------------------------------------------------------------------- 1 | package throttle 2 | // RateLimiter is the throttling abstraction the gateway stores per service (RateLimiterSyncMap) and consults in RateLimitPreFilter. 3 | type RateLimiter interface { 4 | Acquire() // NOTE(review): presumably blocks until a permit is available (the memory limiter test calls it in a loop); confirm in the implementations. 5 | TryAcquire() bool // reports whether a permit was granted; RateLimitPreFilter rejects the request when this returns false. 6 | } 7 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - "1.13" 5 | 6 |
7 | install: 8 | - GO111MODULE=on 9 | - go mod tidy 10 | 11 | script: 12 | - go build -------------------------------------------------------------------------------- /server/init.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | . "github.com/wanghongfei/gogate/conf" 5 | ) 6 | // InitGogate loads the configuration file at gogateConfigFile (LoadConfig) and then initialises logging (InitLog); main.go calls it before building the server. 7 | func InitGogate(gogateConfigFile string) { 8 | LoadConfig(gogateConfigFile) 9 | InitLog() 10 | } 11 | -------------------------------------------------------------------------------- /utils/net_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | // Smoke test: prints the detected address and fails only if GetFirstNoneLoopIp returns an error. 8 | func TestGetFirstNoneLoopIp(t *testing.T) { 9 | ip, err := GetFirstNoneLoopIp() 10 | if nil != err { 11 | t.Error(err) 12 | return 13 | } 14 | 15 | fmt.Println(ip) 16 | } 17 | -------------------------------------------------------------------------------- /utils/idgen.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import "github.com/bwmarrin/snowflake" 4 | // generator is the process-wide snowflake node (node ID 1); init panics if it cannot be created. 5 | var generator *snowflake.Node 6 | 7 | func init() { 8 | node, err := snowflake.NewNode(1) 9 | if nil != err { 10 | panic(err) 11 | } 12 | 13 | generator = node 14 | } 15 | // GenerateUuid returns a new snowflake ID as an int64 (a numeric ID, not an RFC 4122 UUID despite the name). 16 | func GenerateUuid() int64 { 17 | return generator.Generate().Int64() 18 | } 19 | 20 | -------------------------------------------------------------------------------- /eureka.json: -------------------------------------------------------------------------------- 1 | { 2 | "config": { 3 | "certFile": "", 4 | "keyFile": "", 5 | "caCertFiles": null, 6 | "timeout": 1000000000, 7 | "consistency": "" 8 | }, 9 | "cluster": { 10 | "leader": "http://127.0.0.1:8761/eureka", 11 | "machines": [ 12 | "http://127.0.0.1:8761/eureka" 13 | ] 14 | } 15 | } -------------------------------------------------------------------------------- /server/route-test.yml:
-------------------------------------------------------------------------------- 1 | services: 2 | service1: 3 | host: localhost:8080 4 | name: service1 5 | prefix: /service1 6 | strip-prefix: true 7 | service2: 8 | host: localhost:8081 9 | name: service2 10 | prefix: /service2 11 | strip-prefix: true 12 | service3: 13 | host: localhost:8082 14 | name: service3 15 | prefix: /service3 16 | strip-prefix: true 17 | -------------------------------------------------------------------------------- /utils/stopwatch.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import "time" 4 | 5 | type Stopwatch struct { 6 | start time.Time 7 | } 8 | 9 | // NewStopwatch creates a stopwatch that starts measuring now. 10 | func NewStopwatch() *Stopwatch { 11 | return &Stopwatch{time.Now()} 12 | } 13 | 14 | // Record returns the whole milliseconds elapsed since the previous Record call (or since NewStopwatch for the first call) and restarts the measurement. 15 | func (st *Stopwatch) Record() int64 { 16 | now := time.Now() 17 | diff := now.Sub(st.start).Milliseconds() 18 | 19 | st.start = now 20 | 21 | return diff 22 | } 23 | -------------------------------------------------------------------------------- /server/route/trie_tree_test.go: -------------------------------------------------------------------------------- 1 | package route 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestTrieTree_Search(t *testing.T) { 9 | tree := NewTrieTree() 10 | 11 | 12 | data := &ServiceInfo{} 13 | data.Id = "serviceA" 14 | tree.PutString("/", data) 15 | tree.PutString("abcde", data) 16 | fmt.Println(tree.SearchFirst("/abc")) 17 | 18 | data = tree.Search("abcde") 19 | fmt.Println(data.Id) 20 | 21 | data = tree.SearchFirst("abcdefgasdf") 22 | fmt.Println(data.Id) 23 | 24 | } 25 | -------------------------------------------------------------------------------- /utils/net.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | ) 7 | 8 | func GetFirstNoneLoopIp() (string, error) { 9 | addrs, err :=
net.InterfaceAddrs() 10 | if nil != err { 11 | return "", fmt.Errorf("failed to fetch interfaces => %w", err) 12 | } 13 | 14 | for _, addr := range addrs { 15 | if ip, ok := addr.(*net.IPNet); ok && !ip.IP.IsLoopback() { 16 | if ip.IP.To4() != nil { 17 | return ip.IP.String(), nil 18 | } 19 | } 20 | } 21 | 22 | return "", fmt.Errorf("no first-none-loop ip found") 23 | } 24 | -------------------------------------------------------------------------------- /utils/rand_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestRandomByWeight(t *testing.T) { 9 | var weights = []int { 10 | //80, 20, 10, 11 | 1, 1, 3, 12 | } 13 | 14 | stoneMap := make(map[int]int) 15 | for ix := 0; ix < 100; ix++ { 16 | next := RandomByWeight(weights) 17 | stoneMap[next]++ 18 | } 19 | 20 | fmt.Println(stoneMap) 21 | } 22 | 23 | func BenchmarkRandomByWeight(b *testing.B) { 24 | var weights = []int { 25 | //80, 20, 10, 26 | 1, 1, 3, 27 | } 28 | 29 | for i := 0; i < b.N; i++ { 30 | RandomByWeight(weights) 31 | } 32 | } -------------------------------------------------------------------------------- /discovery/client_test.go: -------------------------------------------------------------------------------- 1 | package discovery 2 | 3 | import ( 4 | "fmt" 5 | "github.com/hashicorp/consul/api" 6 | "testing" 7 | ) 8 | 9 | func TestQueryAll(t *testing.T) { 10 | // QueryAll() 11 | } 12 | 13 | func TestQueryConsul(t *testing.T) { 14 | cfg := &api.Config{} 15 | cfg.Address = "127.0.0.1:8500" 16 | cfg.Scheme = "http" 17 | 18 | client, err := api.NewClient(cfg) 19 | if nil != err { 20 | // The client is unusable when NewClient fails; without this check the next := would overwrite err. 21 | t.Error(err) 22 | return 23 | } 24 | 25 | checkList, _, err := client.Health().State("passing", &api.QueryOptions{}) 26 | if nil != err { 27 | t.Error(err) 28 | return 29 | } 30 | 31 | for _, info := range checkList { 32 | fmt.Println(info.ServiceName) 33 | } 34 | } 35 | --------------------------------------------------------------------------------
/server/lb/loadbalancer_test.go: -------------------------------------------------------------------------------- 1 | package lb 2 | 3 | import ( 4 | "fmt" 5 | "github.com/wanghongfei/gogate/discovery" 6 | "testing" 7 | ) 8 | 9 | // TestRoundRobinLoadBalancer_Choose checks that a fresh balancer hands out instances in strict rotation. 10 | func TestRoundRobinLoadBalancer_Choose(t *testing.T) { 11 | lb := &RoundRobinLoadBalancer{} 12 | 13 | instances := make([]*discovery.InstanceInfo, 0) 14 | instances = append(instances, &discovery.InstanceInfo{ 15 | Addr: "1", 16 | }) 17 | instances = append(instances, &discovery.InstanceInfo{ 18 | Addr: "2", 19 | }) 20 | instances = append(instances, &discovery.InstanceInfo{ 21 | Addr: "3", 22 | }) 23 | 24 | for ix := 0; ix < 10; ix++ { 25 | target := lb.Choose(instances) 26 | fmt.Println(target.Addr) 27 | // A fresh balancer starts at slot 0, so call ix must land on instance ix % len. 28 | if want := instances[ix%len(instances)].Addr; target.Addr != want { 29 | t.Errorf("call %d: got %s, want %s", ix, target.Addr, want) 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /server/filter_serv_match.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/valyala/fasthttp" 5 | .
"github.com/wanghongfei/gogate/conf" 6 | ) 7 | 8 | // ServiceMatchPreFilter resolves the request path to a route and stores it in the request context (ROUTE_INFO, SERVICE_NAME). When no route matches it answers 404 and returns false, rejecting the request. 9 | func ServiceMatchPreFilter(s *Server, ctx *fasthttp.RequestCtx, newRequest *fasthttp.Request) bool { 10 | uri := GetStringFromUserValue(ctx, REQUEST_PATH) 11 | 12 | servInfo := s.Router.Match(uri) 13 | if nil == servInfo { 14 | // No route matched this path. 15 | ctx.Response.SetStatusCode(404) 16 | // Reuse uri: asserting ctx.UserValue(REQUEST_PATH).(string) directly would panic if the value were absent. 17 | NewResponse(uri, "no match").Send(ctx) 18 | return false 19 | } 20 | ctx.SetUserValue(ROUTE_INFO, servInfo) 21 | ctx.SetUserValue(SERVICE_NAME, servInfo.Id) 22 | 23 | Log.Debugf("%s matched to %s", uri, servInfo.Id) 24 | 25 | return true 26 | } 27 | -------------------------------------------------------------------------------- /discovery/client.go: -------------------------------------------------------------------------------- 1 | package discovery 2 | 3 | 4 | // InstanceInfo describes one registered service instance. 5 | type InstanceInfo struct { 6 | ServiceName string 7 | 8 | // Address in host:port form. 9 | Addr string 10 | // Extra metadata attached to this instance. 11 | Meta map[string]string 12 | } 13 | 14 | // Client is the service-discovery client interface. 15 | type Client interface { 16 | // Queries the remote registry directly for all service instances. 17 | QueryServices() ([]*InstanceInfo, error) 18 | 19 | // Registers this gateway itself. 20 | Register() error 21 | 22 | // Deregisters this gateway itself. 23 | UnRegister() error 24 | 25 | // Looks up all instances of the named service in the local cache. 26 | Get(string) []*InstanceInfo 27 | 28 | // Starts the periodic registry refresh. 29 | StartPeriodicalRefresh() error 30 | 31 | // Returns the internally held registry. 32 | GetInternalRegistryStore() *InsInfoArrSyncMap 33 | 34 | // Replaces the internally held registry. 35 | SetInternalRegistryStore(*InsInfoArrSyncMap) 36 | } 37 | 38 | -------------------------------------------------------------------------------- /server/filter_url_rewrite.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/valyala/fasthttp" 5 | .
"github.com/wanghongfei/gogate/conf" 6 | ) 7 | 8 | // UrlRewritePreFilter removes the matched route prefix from the upstream request path when the route has strip-prefix enabled (an emptied path becomes "/"). It never rejects a request: it always returns true. 9 | func UrlRewritePreFilter(s *Server, ctx *fasthttp.RequestCtx, newRequest *fasthttp.Request) bool { 10 | info, ok := GetServiceInfoFromUserValue(ctx, ROUTE_INFO) 11 | if !ok { 12 | return true 13 | } 14 | 15 | if info.StripPrefix { 16 | // Drop the route prefix from the path. 17 | original := string(newRequest.URI().Path()) 18 | posToStrip := len(info.Prefix) 19 | if posToStrip > len(original) { 20 | // Slicing past the end would panic; leave the path untouched instead. 21 | Log.Errorf("cannot strip prefix %s from shorter path %s", info.Prefix, original) 22 | return true 23 | } 24 | 25 | newPath := original[posToStrip:] 26 | if newPath == "" { 27 | newPath = "/" 28 | } 29 | newRequest.URI().SetPath(newPath) 30 | 31 | Log.Debugf("rewrite path from %s to %s", original, newPath) 32 | } 33 | 34 | return true 35 | } 36 | -------------------------------------------------------------------------------- /discovery/empty_client.go: -------------------------------------------------------------------------------- 1 | package discovery 2 | 3 | var DoNothingClient = new(EmptyClient) 4 | 5 | type EmptyClient struct{ 6 | 7 | } 8 | 9 | func (e EmptyClient) QueryServices() ([]*InstanceInfo, error) { 10 | return nil, nil 11 | } 12 | 13 | func (e EmptyClient) Register() error { 14 | return nil 15 | } 16 | 17 | func (e EmptyClient) UnRegister() error { 18 | return nil 19 | } 20 | 21 | func (e EmptyClient) Get(string) []*InstanceInfo { 22 | return nil 23 | } 24 | 25 | func (e EmptyClient) StartPeriodicalRefresh() error { 26 | return nil 27 | } 28 | 29 | func (e EmptyClient) GetInternalRegistryStore() *InsInfoArrSyncMap { 30 | return nil 31 | } 32 | 33 | func (e EmptyClient) SetInternalRegistryStore(*InsInfoArrSyncMap) { 34 | } 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /server/syncmap_rate_limiter.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/wanghongfei/gogate/throttle" 7 | ) 8 | 9 | // RateLimiterSyncMap wraps sync.Map with type-safe accessors for throttle.RateLimiter values. 10 | type RateLimiterSyncMap struct { 11 | rlMap *sync.Map 12 | } 13 | 14 | func
NewRateLimiterSyncMap() *RateLimiterSyncMap { 15 | return &RateLimiterSyncMap{ 16 | rlMap: new(sync.Map), 17 | } 18 | } 19 | 20 | func (rsm *RateLimiterSyncMap) Get(key string) (throttle.RateLimiter, bool) { 21 | val, exist := rsm.rlMap.Load(key) 22 | if !exist { 23 | return nil, false 24 | } 25 | 26 | rl, ok := val.(throttle.RateLimiter) 27 | if !ok { 28 | return nil, false 29 | } 30 | 31 | return rl, true 32 | } 33 | 34 | func (rsm *RateLimiterSyncMap) Put(key string, val throttle.RateLimiter) { 35 | rsm.rlMap.Store(key, val) 36 | } 37 | -------------------------------------------------------------------------------- /throttle/rate_limiter_memory_test.go: -------------------------------------------------------------------------------- 1 | package throttle 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | ) 9 | 10 | func TestNewMemoryRateLimiter(t *testing.T) { 11 | rl := NewMemoryRateLimiter(1000) 12 | fmt.Println(rl) 13 | 14 | } 15 | 16 | func TestRateLimiter_Acquire(t *testing.T) { 17 | rl := NewMemoryRateLimiter(1) 18 | 19 | count := 0 20 | for { 21 | rl.Acquire() 22 | fmt.Println(time.Now()) 23 | 24 | count++ 25 | if count >= 10 { 26 | break 27 | } 28 | } 29 | } 30 | 31 | func TestRateLimiter_TryAcquire(t *testing.T) { 32 | count := 0 33 | 34 | rl := NewMemoryRateLimiter(1) 35 | for { 36 | pass := rl.TryAcquire() 37 | if pass { 38 | fmt.Println(time.Now()) 39 | } 40 | 41 | count++ 42 | if count >= 10 { 43 | break 44 | } 45 | } 46 | 47 | } 48 | -------------------------------------------------------------------------------- /discovery/syncmap_ins_meta_lbclient.go: -------------------------------------------------------------------------------- 1 | package discovery 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/valyala/fasthttp" 7 | ) 8 | 9 | // InsLbClientSyncMap wraps sync.Map with type-safe accessors for *fasthttp.LBClient values. 10 | type InsLbClientSyncMap struct { 11 | mcMap *sync.Map 12 | } 13 | 14 | func NewInsMetaLbClientSyncMap() *InsLbClientSyncMap { 15 | return &InsLbClientSyncMap{ 16 | mcMap:
new(sync.Map), 17 | } 18 | } 19 | 20 | func (ism *InsLbClientSyncMap) Get(key string) (*fasthttp.LBClient, bool) { 21 | val, exist := ism.mcMap.Load(key) 22 | if !exist { 23 | return nil, false 24 | } 25 | 26 | client, ok := val.(*fasthttp.LBClient) 27 | if !ok { 28 | return nil, false 29 | } 30 | 31 | return client, true 32 | } 33 | 34 | func (ism *InsLbClientSyncMap) Put(key string, val *fasthttp.LBClient) { 35 | ism.mcMap.Store(key, val) 36 | } 37 | -------------------------------------------------------------------------------- /utils/rand.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "math/rand" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | var ( 10 | // globalRand is shared by every caller; *rand.Rand is not safe for concurrent use, so all access goes through randMu. 11 | globalRand *rand.Rand 12 | randMu sync.Mutex 13 | ) 14 | 15 | func init() { 16 | globalRand = rand.New(rand.NewSource(time.Now().UnixNano())) 17 | } 18 | 19 | // RandomByWeight picks a random index into weight with probability proportional to weight[i]; 20 | // e.g. weights 1,2,3 are chosen in a 1:2:3 ratio. Entries <= 0 are never chosen. 21 | // It returns -1 when no entry is positive, and it is safe for concurrent use. 22 | func RandomByWeight(weight []int) int { 23 | // The positive weights partition [0, total) into consecutive intervals. 24 | total := 0 25 | for _, w := range weight { 26 | if w > 0 { 27 | total += w 28 | } 29 | } 30 | if total <= 0 { 31 | // rand.Intn panics for n <= 0, so report "no choice" instead. 32 | return -1 33 | } 34 | 35 | randMu.Lock() 36 | stone := globalRand.Intn(total) 37 | randMu.Unlock() 38 | 39 | // Return the index whose interval [sum-w, sum) contains stone; every point of [0, total) is equally likely, so no entry is favoured. 40 | sum := 0 41 | for ix, w := range weight { 42 | if w <= 0 { 43 | continue 44 | } 45 | sum += w 46 | if stone < sum { 47 | return ix 48 | } 49 | } 50 | 51 | // Not reached: stone < total guarantees a match above. 52 | return -1 53 | } 54 | -------------------------------------------------------------------------------- /utils/collection_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestSlice(t *testing.T) { 9 | a := make([]int, 0, 10) 10 | fmt.Printf("cap(a) = %d, len(a) = %d\n", cap(a), len(a)) 11 | 12 | b := append(a, 1) 13 | fmt.Printf("cap(a) = %d, len(a) = %d, cap(b) = %d, len(b) = %d\n", cap(a), len(a), cap(b), len(b)) 14 | 15 | _ = append(a, 2) 16 | fmt.Printf("cap(a) = %d, len(a)
= %d, cap(b) = %d, len(b) = %d\n", cap(a), len(a), cap(b), len(b)) 17 | 18 | println(b[0]) 19 | } 20 | 21 | func TestCopy(t *testing.T) { 22 | arr := make([]int, 0, 6) 23 | for ix := 0; ix < cap(arr); ix++ { 24 | arr = append(arr, ix + 1) 25 | } 26 | fmt.Println(arr) 27 | 28 | targetIx := 1 29 | copy(arr[targetIx + 2:], arr[targetIx + 1:]) 30 | arr[targetIx + 1] = 999 31 | fmt.Println(arr) 32 | 33 | } 34 | -------------------------------------------------------------------------------- /redis/redis_test.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "os" 7 | "testing" 8 | 9 | ) 10 | 11 | func TestGetString(t *testing.T) { 12 | c := NewRedisClient("127.0.0.1:6379", 1) 13 | err := c.Connect() 14 | if nil != err { 15 | t.Error(err) 16 | return 17 | } 18 | 19 | 20 | str, err := c.GetString("abc") 21 | fmt.Println(str) 22 | c.Close() 23 | } 24 | 25 | func TestRedisClient_ExeLuaInt(t *testing.T) { 26 | c := NewRedisClient("127.0.0.1:6379", 1) 27 | err := c.Connect() 28 | if nil != err { 29 | t.Error(err) 30 | return 31 | } 32 | defer c.Close() 33 | 34 | luaFile, err := os.Open("../lua/rate_limiter.lua") 35 | if nil != err { 36 | t.Error(err) 37 | return 38 | } 39 | defer luaFile.Close() 40 | 41 | luaBuf, _ := ioutil.ReadAll(luaFile) 42 | resp, err := c.ExeLuaInt(string(luaBuf), nil, []string{"10"}) 43 | if err != nil { 44 | t.Error(err) 45 | return 46 | } 47 | 48 | fmt.Println(resp) 49 | } 50 | -------------------------------------------------------------------------------- /server/filter_rate_limit.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/valyala/fasthttp" 5 | .
"github.com/wanghongfei/gogate/conf" 6 | ) 7 | 8 | // RateLimitPreFilter enforces the per-service QPS limit of the matched route. Requests with no route info, no QPS setting, or no registered limiter pass through; a request that cannot obtain a token is answered with 429 and rejected (returns false). 9 | func RateLimitPreFilter(s *Server, ctx *fasthttp.RequestCtx, newRequest *fasthttp.Request) bool { 10 | // Fetch the routing result. 11 | info, ok := GetServiceInfoFromUserValue(ctx, ROUTE_INFO) 12 | if !ok { 13 | return true 14 | } 15 | 16 | // Qps == 0 means this service is not rate limited. 17 | if 0 == info.Qps { 18 | Log.Debugf("no limiter for service %s", info.Id) 19 | return true 20 | } 21 | 22 | // Look up the limiter registered for this service. 23 | rl, ok := s.rateLimiterMap.Get(info.Id) 24 | if !ok { 25 | Log.Errorf("lack rate limiter for %s", info.Id) 26 | return true 27 | } 28 | 29 | pass := rl.TryAcquire() 30 | if !pass { 31 | // Out of tokens. Like ServiceMatchPreFilter, set the status explicitly before Send so throttled clients see 429. 32 | ctx.Response.SetStatusCode(fasthttp.StatusTooManyRequests) 33 | NewResponse(GetStringFromUserValue(ctx, REQUEST_PATH), "reach QPS limitation").Send(ctx) 34 | Log.Infof("drop request for %s due to rate limitation", info.Id) 35 | } 36 | 37 | return pass 38 | } 39 | -------------------------------------------------------------------------------- /server/filter_def.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/valyala/fasthttp" 5 | ) 6 | 7 | // PreFilterFunc is the signature of a pre-filter. 8 | type PreFilterFunc func(server *Server, ctx *fasthttp.RequestCtx, newRequest *fasthttp.Request) bool 9 | // PreFilter pairs a pre-filter function with its name. 10 | type PreFilter struct { 11 | FilterFunc PreFilterFunc 12 | Name string 13 | } 14 | 15 | func NewPreFilter(name string, filter PreFilterFunc) *PreFilter { 16 | return &PreFilter{ 17 | FilterFunc: filter, 18 | Name: name, 19 | } 20 | } 21 | 22 | func (pf *PreFilter) String() string { 23 | return pf.Name 24 | } 25 | 26 | 27 | // PostFilterFunc is the signature of a post-filter. 28 | type PostFilterFunc func(req *fasthttp.Request, resp *fasthttp.Response) bool 29 | // PostFilter pairs a post-filter function with its name. 30 | type PostFilter struct { 31 | FilterFunc PostFilterFunc 32 | Name string 33 | } 34 | 35 | func NewPostFilter(name string, filter PostFilterFunc) *PostFilter { 36 | return &PostFilter{ 37 | FilterFunc: filter, 38 | Name: name, 39 | } 40 | } 41 | 42 | func (pf *PostFilter) String() string { 43 |
return pf.Name 44 | } 45 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | . "github.com/wanghongfei/gogate/conf" 5 | "github.com/wanghongfei/gogate/perr" 6 | serv "github.com/wanghongfei/gogate/server" 7 | "os" 8 | ) 9 | 10 | func main() { 11 | // Initialise configuration and logging. 12 | serv.InitGogate("gogate.yml") 13 | 14 | // Build the gateway server. 15 | server, err := serv.NewGatewayServer( 16 | App.ServerConfig.Host, 17 | App.ServerConfig.Port, 18 | App.EurekaConfig.RouteFile, 19 | App.ServerConfig.MaxConnection, 20 | ) 21 | checkErrorExit(err, true) 22 | 23 | Log.Infof("pre filters: %v", server.ExportAllPreFilters()) 24 | Log.Infof("post filters: %v", server.ExportAllPostFilters()) 25 | 26 | 27 | // Start serving; Start returns once the listener has closed. 28 | err = server.Start() 29 | checkErrorExit(err, true) 30 | Log.Info("listener has been closed") 31 | 32 | // Wait for graceful shutdown. 33 | err = server.Shutdown() 34 | checkErrorExit(err, false) 35 | } 36 | 37 | func checkErrorExit(err error, exit bool) { 38 | if nil != err { 39 | Log.Error(perr.EnvMsg(err)) 40 | 41 | if exit { 42 | os.Exit(1) 43 | } 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /gogate.yml: -------------------------------------------------------------------------------- 1 | version: 1.0 2 | 3 | server: 4 | # Service name used when registering this gateway with eureka 5 | appName: gogate 6 | host: 0.0.0.0 7 | port: 8001 8 | # Maximum number of gateway connections 9 | maxConnection: 2000 10 | # Timeout for gateway requests to backend services, in milliseconds 11 | timeout: 3000 12 | 13 | # If neither eureka nor consul is enabled, the gateway runs in static mode and never contacts a registry 14 | eureka: 15 | enable: true 16 | # eureka config file name 17 | configFile: eureka.json 18 | # Route config file name 19 | routeFile: route.yml 20 | # Maximum time before eureka evicts a service, in seconds 21 | evictionDuration: 30 22 | # Heartbeat interval, in seconds 23 | heartbeatInterval: 20 24 | 25 | consul: 26 | enable: false 27 | address: 127.0.0.1:8500 28 | 29 | 30 | traffic: 31 | # Whether to enable traffic recording 32 | enableTrafficRecord: false 33 | # Directory for traffic log files
34 | trafficLogDir: /tmp 35 | 36 | redis: 37 | # Whether to use redis for the rate limiter 38 | enabled: false 39 | # Only a single instance is supported for now, not cluster 40 | addr: 127.0.0.1:6379 41 | # Lua script file for the rate limiter 42 | rateLimiterLua: lua/rate_limiter.lua 43 | 44 | log: 45 | console-only: true 46 | directory: "logs" 47 | file-pattern: "logs/gogate.log.%Y-%m-%d" 48 | file-link: "logs/gogate.log" 49 | -------------------------------------------------------------------------------- /server/gogate-test.yml: -------------------------------------------------------------------------------- 1 | version: 1.0 2 | 3 | server: 4 | # Service name used when registering this gateway with eureka 5 | appName: gogate 6 | host: 0.0.0.0 7 | port: 8001 8 | # Maximum number of gateway connections 9 | maxConnection: 2000 10 | # Timeout for gateway requests to backend services, in milliseconds 11 | timeout: 3000 12 | 13 | # If neither eureka nor consul is enabled, the gateway runs in static mode and never contacts a registry 14 | eureka: 15 | enable: false 16 | # eureka config file name 17 | configFile: eureka.json 18 | # Route config file name 19 | routeFile: route.yml 20 | # Maximum time before eureka evicts a service, in seconds 21 | evictionDuration: 30 22 | # Heartbeat interval, in seconds 23 | heartbeatInterval: 20 24 | 25 | consul: 26 | enable: false 27 | address: 127.0.0.1:8500 28 | 29 | 30 | traffic: 31 | # Whether to enable traffic recording 32 | enableTrafficRecord: false 33 | # Directory for traffic log files 34 | trafficLogDir: /tmp 35 | 36 | redis: 37 | # Whether to use redis for the rate limiter 38 | enabled: false 39 | # Only a single instance is supported for now, not cluster 40 | addr: 127.0.0.1:6379 41 | # Lua script file for the rate limiter 42 | rateLimiterLua: lua/rate_limiter.lua 43 | 44 | log: 45 | console-only: true 46 | directory: "/tmp/logs" 47 | file-pattern: "/tmp/logs/gogate.log.%Y-%m-%d" 48 | file-link: "/tmp/logs/gogate.log" 49 | -------------------------------------------------------------------------------- /server/lb/loadbalancer.go: -------------------------------------------------------------------------------- 1 | package lb 2 | 3 | import ( 4 | "github.com/wanghongfei/gogate/discovery" 5 | "sync/atomic" 6 | ) 7 | 8 | // LoadBalancer picks one target out of a set of candidates. 9 | type LoadBalancer interface { 10 | // Choose returns one of instances. 11 | Choose(instances []*discovery.InstanceInfo) *discovery.InstanceInfo 12 | 13 | // ChooseByAddresses returns one of addrs.
14 | ChooseByAddresses(addrs []string) string 15 | } 16 | 17 | // RoundRobinLoadBalancer hands out candidates in rotation. It is safe for concurrent use. 18 | type RoundRobinLoadBalancer struct { 19 | // index counts picks made so far; it is only accessed atomically. 20 | index int64 21 | } 22 | 23 | // Choose returns the next instance in rotation, or nil when instances is empty. 24 | func (lb *RoundRobinLoadBalancer) Choose(instances []*discovery.InstanceInfo) *discovery.InstanceInfo { 25 | total := len(instances) 26 | if total == 0 { 27 | // Avoid the integer divide-by-zero panic in nextIndex. 28 | return nil 29 | } 30 | 31 | return instances[lb.nextIndex(total)] 32 | } 33 | 34 | // ChooseByAddresses returns the next address in rotation, or "" when addrs is empty. 35 | func (lb *RoundRobinLoadBalancer) ChooseByAddresses(addrs []string) string { 36 | total := len(addrs) 37 | if total == 0 { 38 | return "" 39 | } 40 | 41 | return addrs[lb.nextIndex(total)] 42 | } 43 | 44 | // nextIndex atomically claims the next slot in [0, total); total must be > 0. 45 | func (lb *RoundRobinLoadBalancer) nextIndex(total int) int64 { 46 | // One atomic add hands every concurrent caller a distinct ticket; reading the counter separately would let callers collide on one slot. 47 | ticket := atomic.AddInt64(&lb.index, 1) - 1 48 | next := ticket % int64(total) 49 | if next < 0 { 50 | // The counter turns negative after int64 overflow; shift back into [0, total). 51 | next += int64(total) 52 | } 53 | 54 | return next 55 | } 56 | -------------------------------------------------------------------------------- /discovery/syncmap_ins_info_arr.go: -------------------------------------------------------------------------------- 1 | package discovery 2 | 3 | import ( 4 | "sync" 5 | 6 | ) 7 | 8 | // InsInfoArrSyncMap wraps sync.Map with type-safe accessors for []*InstanceInfo values. 9 | type InsInfoArrSyncMap struct { 10 | dataMap *sync.Map 11 | } 12 | 13 | func NewInsInfoArrSyncMap() *InsInfoArrSyncMap { 14 | return &InsInfoArrSyncMap{ 15 | dataMap: new(sync.Map), 16 | } 17 | } 18 | 19 | func (ism *InsInfoArrSyncMap) Get(key string) ([]*InstanceInfo, bool) { 20 | val, exist := ism.dataMap.Load(key) 21 | if !exist { 22 | return nil, false 23 | } 24 | 25 | info, ok := val.([]*InstanceInfo) 26 | if !ok { 27 | return nil, false 28 | } 29 | 30 | return info, true 31 | } 32 | 33 | func (ism *InsInfoArrSyncMap) Put(key string, val []*InstanceInfo) { 34 | ism.dataMap.Store(key, val) 35 | } 36 | 37 | func (ism *InsInfoArrSyncMap) Each(eachFunc func(key string, val []*InstanceInfo) bool) { 38 | ism.dataMap.Range(func(key, value interface{}) bool { 39 | return eachFunc(key.(string), value.([]*InstanceInfo)) 40 | }) 41 | } 42 | 43 | func (ism *InsInfoArrSyncMap) GetMap() *sync.Map { 44 | return ism.dataMap 45 | } 46 |
-------------------------------------------------------------------------------- /server/statistics/stat_test.go: -------------------------------------------------------------------------------- 1 | package stat 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestNewCsvFileTraficInfoStore(t *testing.T) { 10 | info := &TraficInfo{ 11 | ServiceId: "user-service", 12 | SuccessCount: 10, 13 | FailedCount: 1, 14 | timestamp: time.Now().UnixNano() / 10e6, 15 | } 16 | 17 | cf := NewCsvFileTraficInfoStore("/tmp") 18 | 19 | 20 | err := cf.Send(info) 21 | if nil != err { 22 | t.Error(err) 23 | return 24 | } 25 | 26 | err = cf.Close() 27 | if nil != err { 28 | t.Error(err) 29 | } 30 | } 31 | 32 | func TestNewTraficStat(t *testing.T) { 33 | stat := NewTrafficStat(10, 1, NewCsvFileTraficInfoStore("/tmp")) 34 | stat.StartRecordTrafic() 35 | 36 | ticker := time.NewTicker(time.Millisecond * 400) 37 | count := 0 38 | for { 39 | <- ticker.C 40 | 41 | info := &TraficInfo{ 42 | ServiceId: "dog-service", 43 | SuccessCount: 1, 44 | FailedCount: 0, 45 | } 46 | 47 | stat.RecordTrafic(info) 48 | fmt.Println("put") 49 | 50 | count ++ 51 | if count > 10 { 52 | break 53 | } 54 | } 55 | } -------------------------------------------------------------------------------- /route.yml: -------------------------------------------------------------------------------- 1 | services: 2 | user-service: 3 | # eureka中的服务名 4 | id: user-service 5 | # 以/user开头的请求, 会被转发到user-service服务中 6 | prefix: /user 7 | # 转发时是否去掉请求前缀, 即/user 8 | strip-prefix: true 9 | # 设置qps限制, 每秒最多请求数 10 | qps: 1 11 | # 灰度配置 12 | canary: 13 | - 14 | # 对应eurekai注册信息中元数据(metadata map)中key=version的值 15 | meta: "1.0" 16 | # 流量比重 17 | weight: 3 18 | - 19 | meta: "2.0" 20 | weight: 4 21 | - 22 | meta: "" 23 | weight: 10 24 | 25 | trends-service: 26 | id: trends-service 27 | # 请求路径当匹配多个prefix时, 长的获胜 28 | prefix: /trends 29 | strip-prefix: false 30 | # 设置qps限制, 每秒最多请求数 31 | qps: 1 32 | 33 | order-service: 34 | id: 
order-service 35 | prefix: /order 36 | strip-prefix: false 37 | 38 | img-service: 39 | # 如果有host, 则不查注册中心直接使用此地址, 多个地址逗号分隔, 不能有空格 40 | host: localhost:8081,localhost:8080 41 | name: img-service 42 | prefix: /img 43 | strip-prefix: true 44 | 45 | engine-service: 46 | id: engine-service 47 | prefix: /engine 48 | strip-prefix: true 49 | 50 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 wanghongfei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/wanghongfei/gogate 2 | 3 | go 1.13 4 | 5 | replace golang.org/x/net => github.com/golang/net v0.0.0-20190404232315-eb5bcb51f2a3 6 | 7 | replace golang.org/x/text => github.com/golang/text v0.3.0 8 | 9 | replace golang.org/x/crypto => github.com/golang/crypto v0.0.0-20190411191339-88737f569e3a 10 | 11 | replace golang.org/x/sys => github.com/golang/sys v0.0.0-20190412213103-97732733099d 12 | 13 | require go.uber.org/zap v1.13.0 14 | 15 | require ( 16 | github.com/bwmarrin/snowflake v0.3.0 17 | github.com/fastly/go-utils v0.0.0-20180712184237-d95a45783239 // indirect 18 | github.com/hashicorp/consul/api v1.0.1 19 | github.com/jehiah/go-strftime v0.0.0-20171201141054-1d33003b3869 // indirect 20 | github.com/jonboulle/clockwork v0.1.0 // indirect 21 | github.com/lestrrat-go/file-rotatelogs v2.3.0+incompatible 22 | github.com/lestrrat-go/strftime v1.0.1 // indirect 23 | github.com/mediocregopher/radix.v2 v0.0.0-20181115013041-b67df6e626f9 24 | github.com/tebeka/strftime v0.1.3 // indirect 25 | github.com/valyala/fasthttp v1.9.0 26 | github.com/wanghongfei/go-eureka-client v1.1.0 27 | gopkg.in/yaml.v2 v2.2.2 28 | ) 29 | -------------------------------------------------------------------------------- /discovery/eureka_client_test.go: -------------------------------------------------------------------------------- 1 | package discovery 2 | 3 | import ( 4 | "fmt" 5 | "github.com/hashicorp/consul/api" 6 | "testing" 7 | ) 8 | 9 | func TestStartRegister(t *testing.T) { 10 | // StartRegister() 11 | // time.Sleep(time.Second * 60) 12 | } 13 | 14 | func TestRegisterToConsul(t *testing.T) { 15 | client, err := api.NewClient(api.DefaultConfig()) 16 | if nil != err { 17 | t.Error(err) 18 | return 19 | } 20 | 21 | reg := &api.AgentServiceRegistration{} 22 | reg.ID = "id" 23 | reg.Name =
"go-unit-test" 24 | reg.Address = "127.0.0.1" 25 | reg.Port = 8080 26 | reg.Check = &api.AgentServiceCheck{} 27 | reg.Check.HTTP = "http://127.0.0.1:9000" 28 | reg.Check.Method = "GET" 29 | reg.Check.Interval = "10s" 30 | reg.Check.Timeout = "1s" 31 | // reg.Check.DeregisterCriticalServiceAfter = "2s" 32 | 33 | err = client.Agent().ServiceRegister(reg) 34 | if nil != err { 35 | t.Error(err) 36 | return 37 | } 38 | 39 | 40 | servMap, err := client.Agent().Services() 41 | if nil != err { 42 | t.Error(err) 43 | return 44 | } 45 | 46 | for name, serv := range servMap { 47 | fmt.Println(name) 48 | fmt.Println(serv) 49 | } 50 | } -------------------------------------------------------------------------------- /lua/rate_limiter.lua: -------------------------------------------------------------------------------- 1 | -- Rate-limiting script: maintains a HASH with a 1-second TTL holding two entries, current (requests counted so far) and max (maximum allowed requests); 2 | -- returns 0 when current + 1 > max, and 1 when the HASH does not exist or current + 1 <= max. 3 | 4 | -- Arguments: ARGV[1]: service name (defaults to "global"); ARGV[2]: qps (defaults to 100) 5 | -- Returns 0 when the limit is reached, 1 when the request is allowed 6 | 7 | local servId = ARGV[1] 8 | if nil == servId then 9 | servId = "global" 10 | end 11 | 12 | local HASH_KEY = "goate:ratelimiter:" ..
servId 13 | -- Build the key from servId (not ARGV[1]) so the "global" fallback applies; concatenating a nil ARGV[1] raises an error. The "goate" prefix is kept so existing keys stay valid. 14 | local mapResult = redis.call("hgetall", HASH_KEY) 15 | -- An empty hgetall reply means no limiter window is active 16 | if nil == next(mapResult) then 17 | local max = ARGV[2] 18 | if nil == max then 19 | max = 100 20 | end 21 | 22 | redis.call("hmset", HASH_KEY, "current", 1, "max", max) 23 | redis.call("expire", HASH_KEY, 1) 24 | 25 | return 1 26 | end 27 | 28 | -- Convert the flat hgetall reply into a table (hashKey -> val) 29 | local limiterMap = {} 30 | local nextkey 31 | for i, v in ipairs(mapResult) do 32 | if i % 2 == 1 then 33 | nextkey = v 34 | else 35 | limiterMap[nextkey] = v 36 | end 37 | end 38 | 39 | 40 | local current = tonumber(limiterMap["current"]) 41 | local max = tonumber(limiterMap["max"]) 42 | if current + 1 > max then 43 | return 0 44 | end 45 | 46 | redis.call("hincrby", HASH_KEY, "current", 1) 47 | return 1 48 | 49 | -------------------------------------------------------------------------------- /server/server_context.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/valyala/fasthttp" 5 | "github.com/wanghongfei/gogate/server/route" 6 | "github.com/wanghongfei/gogate/utils" 7 | ) 8 | 9 | // GetServiceInfoFromUserValue returns the *route.ServiceInfo stored under key in the request context; ok is false when it is missing or has another type. 10 | func GetServiceInfoFromUserValue(ctx *fasthttp.RequestCtx, key string) (*route.ServiceInfo, bool) { 11 | val := ctx.UserValue(key) 12 | if nil == val { 13 | return nil, false 14 | } 15 | 16 | info, ok := val.(*route.ServiceInfo) 17 | if !ok { 18 | return nil, false 19 | } 20 | 21 | return info, true 22 | } 23 | 24 | // GetStringFromUserValue returns the string stored under key in the request context, or "" when it is missing or not a string. 25 | func GetStringFromUserValue(ctx *fasthttp.RequestCtx, key string) string { 26 | val := ctx.UserValue(key) 27 | if nil == val { 28 | return "" 29 | } 30 | 31 | str, ok := val.(string) 32 | if !ok { 33 | return "" 34 | } 35 | 36 | return str 37 | } 38 | 39 | func GetInt64FromUserValue(ctx *fasthttp.RequestCtx, key string) int64 { 40 | val := ctx.UserValue(key) 41 | if nil == val { 42 | return -1 43 | } 44 | 45 | num, ok := val.(int64) 46 | if !ok { 47 |
return -1 48 | } 49 | 50 | return num 51 | } 52 | 53 | func GetStopWatchFromUserValue(ctx *fasthttp.RequestCtx) *utils.Stopwatch { 54 | return ctx.UserValue(STOPWATCH).(*utils.Stopwatch) 55 | } -------------------------------------------------------------------------------- /server/route/matcher.go: -------------------------------------------------------------------------------- 1 | package route 2 | 3 | import "strings" 4 | 5 | type PathMatcher struct { 6 | routeMap map[string]*ServiceInfo 7 | routeTrieTree *TrieTree 8 | } 9 | 10 | func (matcher *PathMatcher) Match(path string) *ServiceInfo { 11 | // 如果大于3个token则使用TrieTree匹配提高性能 12 | if strings.Count(path, "/") >= 3 { 13 | return matcher.matchByTree(path) 14 | } 15 | 16 | // 使用切token的方式匹配 17 | return matcher.matchByToken(path) 18 | } 19 | 20 | func (matcher *PathMatcher) matchByTree(path string) *ServiceInfo { 21 | return matcher.routeTrieTree.SearchFirst(path) 22 | } 23 | 24 | func (matcher *PathMatcher) matchByToken(path string) *ServiceInfo { 25 | if !strings.HasSuffix(path, "/") { 26 | path = path + "/" 27 | } 28 | 29 | if "/" == path { 30 | path = "//" 31 | } 32 | 33 | // 以/为分隔符, 从后向前匹配 34 | // 每次循环都去掉最后一个/XXXX节点 35 | term := path 36 | for { 37 | lastSlash := strings.LastIndex(term, "/") 38 | if -1 == lastSlash { 39 | break 40 | } 41 | 42 | matchTerm := term[0:lastSlash] 43 | term = matchTerm 44 | 45 | if "" == matchTerm { 46 | matchTerm = "/" 47 | } 48 | 49 | appId, exist := matcher.routeMap[matchTerm] 50 | if exist { 51 | return appId 52 | } 53 | } 54 | 55 | return nil 56 | 57 | } 58 | -------------------------------------------------------------------------------- /server/server_response.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | "github.com/valyala/fasthttp" 7 | . 
"github.com/wanghongfei/gogate/conf" 8 | ) 9 | 10 | type GogateResponse struct { 11 | RequestId int64`json:"requestId"` 12 | Path string`json:"path"` 13 | Error string`json:"error"` 14 | } 15 | 16 | func NewResponse(path, msg string) *GogateResponse { 17 | return &GogateResponse{ 18 | Path: path, 19 | Error: msg, 20 | } 21 | } 22 | 23 | func (resp *GogateResponse) ToJson() string { 24 | return string(resp.ToJsonBytes()) 25 | } 26 | 27 | func (resp *GogateResponse) ToJsonBytes() []byte { 28 | buf, _ := json.Marshal(resp) 29 | 30 | return buf 31 | } 32 | 33 | func (resp *GogateResponse) Send(ctx *fasthttp.RequestCtx) { 34 | ctx.Response.Header.Set("Content-Type", "application/json;charset=utf8") 35 | resp.RequestId = GetInt64FromUserValue(ctx, REQUEST_ID) 36 | timer := GetStopWatchFromUserValue(ctx) 37 | 38 | responseBody := resp.ToJson() 39 | Log.Infof("request %d finished, cost = %dms, statusCode = %d, response = %s", resp.RequestId, timer.Record(), ctx.Response.StatusCode(), responseBody) 40 | ctx.WriteString(responseBody) 41 | } 42 | 43 | func (resp *GogateResponse) SendWithStatus(ctx *fasthttp.RequestCtx, statusCode int) { 44 | ctx.SetStatusCode(statusCode) 45 | resp.Send(ctx) 46 | } 47 | -------------------------------------------------------------------------------- /redis/redis.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "github.com/mediocregopher/radix.v2/pool" 5 | "github.com/wanghongfei/gogate/perr" 6 | ) 7 | 8 | // Redis Client, 只能连接一个redis实例, 有连接池 9 | type RedisClient struct { 10 | addr string 11 | poolSize int 12 | connPool *pool.Pool 13 | 14 | isConnected bool 15 | } 16 | 17 | func NewRedisClient(addr string, poolSize int) *RedisClient { 18 | if poolSize < 1 { 19 | poolSize = 1 20 | } 21 | 22 | return &RedisClient{ 23 | addr: addr, 24 | poolSize: poolSize, 25 | } 26 | } 27 | 28 | func (crd *RedisClient) GetString(key string) (string, error) { 29 | resp := 
crd.connPool.Cmd("get", key) 30 | if nil != resp.Err { 31 | return "", perr.WrapSystemErrorf(resp.Err, "failed to GetString") 32 | } 33 | 34 | return resp.Str() 35 | } 36 | 37 | func (crd *RedisClient) ExeLuaInt(lua string, keys []string, args []string) (int, error) { 38 | resp := crd.connPool.Cmd("eval", lua, len(keys), keys, args) 39 | if nil != resp.Err { 40 | return 0, resp.Err 41 | } 42 | 43 | return resp.Int() 44 | } 45 | 46 | func (crd *RedisClient) Close() { 47 | crd.connPool.Empty() 48 | crd.isConnected = false 49 | } 50 | 51 | func (crd *RedisClient) IsConnected() bool { 52 | return crd.isConnected 53 | } 54 | 55 | func (crd *RedisClient) Connect() error { 56 | conn, err := pool.New("tcp", crd.addr, crd.poolSize) 57 | if err != nil { 58 | return perr.WrapSystemErrorf(err, "failed to connect to redis") 59 | } 60 | 61 | crd.connPool = conn 62 | crd.isConnected = true 63 | return nil 64 | } 65 | -------------------------------------------------------------------------------- /utils/collection.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "errors" 5 | "github.com/wanghongfei/gogate/perr" 6 | "sync" 7 | ) 8 | 9 | /* 10 | * 从map中删除指定的key 11 | 12 | * PARAMS: 13 | * - baseMap: 要删除key的map 14 | * - keys: 要删除的key数组 15 | */ 16 | func DelKeys(baseMap *sync.Map, keys []interface{}) error { 17 | if nil == baseMap { 18 | return perr.WrapSystemErrorf(nil, "baseMap cannot be null") 19 | } 20 | 21 | for _, key := range keys { 22 | baseMap.Delete(key) 23 | } 24 | 25 | return nil 26 | } 27 | 28 | /* 29 | * 两个map取并集 30 | * 31 | * PARAMS: 32 | * - fromMap: 源map 33 | * - toMap: 合并后的map 34 | * 35 | */ 36 | func MergeSyncMap(fromMap, toMap *sync.Map) error { 37 | if nil == fromMap || nil == toMap { 38 | return perr.WrapSystemErrorf(nil, "fromMap or toMap cannot be null") 39 | } 40 | 41 | fromMap.Range(func(key, value interface{}) bool { 42 | toMap.Store(key, value) 43 | return true 44 | }) 45 | 46 | 
return nil 47 | } 48 | 49 | /* 50 | * 找出在baseMap中存在但yMap中不存在的元素 51 | * 52 | * PARAMS: 53 | * - baseMap: 独有元素所在的map 54 | * - yMap: 对比map 55 | * 56 | * RETURNS: 57 | * baseMap中独有元素的key的数组 58 | */ 59 | func FindExclusiveKey(baseMap, yMap *sync.Map) ([]interface{}, error) { 60 | if nil == baseMap || nil == yMap { 61 | return nil, errors.New("fromMap or toMap cannot be null") 62 | } 63 | 64 | var keys []interface{} 65 | baseMap.Range(func(key, value interface{}) bool { 66 | _, exist := yMap.Load(key) 67 | if !exist { 68 | keys = append(keys, key) 69 | } 70 | 71 | return true 72 | }) 73 | 74 | return keys, nil 75 | } 76 | -------------------------------------------------------------------------------- /server/server_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/valyala/fasthttp" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestIntegration(t *testing.T) { 10 | // 模拟两个下游服务 11 | server1 := &fasthttp.Server{ 12 | Handler: func(ctx *fasthttp.RequestCtx) { 13 | ctx.WriteString("server1 at 8080") 14 | ctx.SetConnectionClose() 15 | }, 16 | } 17 | server2 := &fasthttp.Server{ 18 | Handler: func(ctx *fasthttp.RequestCtx) { 19 | ctx.WriteString("server2 at 8081") 20 | ctx.SetConnectionClose() 21 | }, 22 | } 23 | 24 | // 启动服务 25 | go func() { 26 | server1.ListenAndServe("0.0.0.0:8080") 27 | }() 28 | 29 | go func() { 30 | server2.ListenAndServe("0.0.0.0:8081") 31 | }() 32 | 33 | 34 | // 启动gogate 35 | InitGogate("gogate-test.yml") 36 | gogate, err := NewGatewayServer("localhost", 7000, "route-test.yml", 10) 37 | if nil != err { 38 | t.Fatal(err) 39 | } 40 | go func() { 41 | err := gogate.Start() 42 | if nil != err { 43 | t.Fatal(err) 44 | } 45 | }() 46 | 47 | time.Sleep(time.Second) 48 | 49 | // 发请求 50 | _, buf, err := fasthttp.Get(make([]byte, 0, 20), "http://localhost:7000/service1/info") 51 | if nil != err { 52 | t.Fatal(err) 53 | } 54 | if "server1 at 8080" != string(buf) { 55 | 
t.Error("service 1 failed") 56 | } 57 | 58 | // 发请求 59 | _, buf, err = fasthttp.Get(make([]byte, 0, 20), "http://localhost:7000/service2/info") 60 | if nil != err { 61 | t.Fatal(err) 62 | } 63 | if "server2 at 8081" != string(buf) { 64 | t.Error("service 2 failed") 65 | } 66 | 67 | 68 | server1.Shutdown() 69 | server2.Shutdown() 70 | gogate.Shutdown() 71 | } 72 | 73 | 74 | -------------------------------------------------------------------------------- /throttle/rate_limiter_redis.go: -------------------------------------------------------------------------------- 1 | package throttle 2 | 3 | import ( 4 | "github.com/wanghongfei/gogate/perr" 5 | "io/ioutil" 6 | "os" 7 | "strconv" 8 | "time" 9 | 10 | "github.com/wanghongfei/gogate/redis" 11 | ) 12 | 13 | type RedisRateLimiter struct { 14 | qps string 15 | client *redis.RedisClient 16 | luaCode string 17 | 18 | serviceId string 19 | luaArgs []string 20 | } 21 | 22 | func NewRedisRateLimiter(client *redis.RedisClient, luaPath string, qps int, serviceId string) (*RedisRateLimiter, error) { 23 | if nil == client { 24 | return nil, perr.WrapSystemErrorf(nil, "redis client cannot be nil") 25 | } 26 | 27 | if qps < 1 { 28 | qps = 1 29 | } 30 | 31 | if !client.IsConnected() { 32 | err := client.Connect() 33 | if nil != err { 34 | return nil, perr.WrapSystemErrorf(err, "failed to connect to redis") 35 | } 36 | } 37 | 38 | luaF, err := os.Open(luaPath) 39 | if nil != err { 40 | return nil, err 41 | } 42 | defer luaF.Close() 43 | 44 | luaBuf, _ := ioutil.ReadAll(luaF) 45 | luaCode := string(luaBuf) 46 | 47 | qpsStr := strconv.Itoa(qps) 48 | 49 | return &RedisRateLimiter{ 50 | client: client, 51 | luaCode: luaCode, 52 | qps: qpsStr, 53 | serviceId: serviceId, 54 | luaArgs: []string{serviceId, qpsStr}, 55 | }, nil 56 | } 57 | 58 | func (rrl *RedisRateLimiter) Acquire() { 59 | for { 60 | ok := rrl.TryAcquire() 61 | if ok { 62 | break 63 | } 64 | 65 | time.Sleep(time.Millisecond * 100) 66 | } 67 | } 68 | 69 | func (rrl 
*RedisRateLimiter) TryAcquire() bool { 70 | resp, _ := rrl.client.ExeLuaInt(rrl.luaCode, nil, rrl.luaArgs) 71 | // resp, _ := rrl.client.ExeLuaInt(rrl.luaCode, nil, []string{rrl.serviceId, rrl.qps}) 72 | return resp == 1 73 | } 74 | 75 | -------------------------------------------------------------------------------- /server/route/router_test.go: -------------------------------------------------------------------------------- 1 | package route 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "os" 7 | "testing" 8 | 9 | "gopkg.in/yaml.v2" 10 | ) 11 | 12 | func TestLoadRoute(t *testing.T) { 13 | //routeMap, _, err := loadRoute("../route.yml") 14 | //if nil != err { 15 | // t.Error(err) 16 | //} 17 | // 18 | //for _, servInfo := range routeMap { 19 | // fmt.Printf("path = %v, id = %s\n", servInfo.Prefix, servInfo.Id) 20 | //} 21 | } 22 | 23 | func TestRouter_Match(t *testing.T) { 24 | r, err := NewRouter("../../route.yml") 25 | if nil != err { 26 | t.Fatal(err) 27 | } 28 | 29 | result := r.Match("/user") 30 | fmt.Println(result) 31 | if "user-service" != result.Id { 32 | t.Errorf("/user mismatch, %s\n", result) 33 | } 34 | 35 | result = r.Match("/order") 36 | fmt.Println(result) 37 | if "order-service" != result.Id { 38 | t.Errorf("/order mismatch, %s\n", result) 39 | } 40 | 41 | result = r.Match("/aaaa") 42 | if nil != result { 43 | t.Errorf("/aaaa mismatch, %s\n", result) 44 | } 45 | fmt.Println(result) 46 | 47 | result = r.Match("/img") 48 | if "localhost:8080,localhost:8080" != result.Host { 49 | t.Errorf("/img mismatch, %s\n", result) 50 | } 51 | fmt.Println(result) 52 | } 53 | 54 | func BenchmarkRouter_MatchLong(b *testing.B) { 55 | r, err := NewRouter("../../route.yml") 56 | if nil != err { 57 | b.Fatal(err) 58 | } 59 | 60 | 61 | for ix := 0; ix < b.N; ix++ { 62 | r.Match("/order/a/b/c/d/e/f/g") 63 | } 64 | } 65 | 66 | func TestYaml(t *testing.T) { 67 | f, err := os.Open("../route.yml") 68 | if nil != err { 69 | t.Error(err) 70 | return 71 | } 72 | defer f.Close() 
73 | 74 | buf, _ := ioutil.ReadAll(f) 75 | 76 | yamlMap := make(map[string]interface{}) 77 | yaml.Unmarshal(buf, &yamlMap) 78 | 79 | fmt.Println(yamlMap["services"]) 80 | } 81 | -------------------------------------------------------------------------------- /conf/log.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "fmt" 5 | "github.com/lestrrat-go/file-rotatelogs" 6 | "go.uber.org/zap" 7 | "go.uber.org/zap/zapcore" 8 | "os" 9 | "time" 10 | ) 11 | 12 | var Log *zap.SugaredLogger 13 | 14 | // 初始化日志库; 15 | // dependsOn: 配置文件加载 16 | func initRotateLog() { 17 | logConfig := App.Log 18 | 19 | encoderConfig := zap.NewProductionEncoderConfig() 20 | encoderConfig.EncodeTime = timeEncodeFunc 21 | encoderConfig.TimeKey = "time" 22 | encoder := zapcore.NewConsoleEncoder(encoderConfig) 23 | 24 | var writer zapcore.WriteSyncer 25 | var logLevel zapcore.Level 26 | if !logConfig.ConsoleOnly { 27 | // 获取当前工作目录 28 | pwd, err := os.Getwd() 29 | fmt.Println(pwd) 30 | if nil != err { 31 | panic(err) 32 | } 33 | 34 | // 创建日志目录 35 | if !checkPathExist(logConfig.Directory) { 36 | fmt.Printf("log dir %s does not exist, create\n", logConfig.Directory) 37 | os.Mkdir(logConfig.Directory, os.ModePerm) 38 | } 39 | 40 | routateWriter, err := rotatelogs.New( 41 | pwd + "/" + logConfig.FilePattern, 42 | rotatelogs.WithLinkName(logConfig.FileLink), 43 | rotatelogs.WithMaxAge(24 * time.Hour * 30), 44 | rotatelogs.WithRotationTime(24 * time.Hour), 45 | ) 46 | 47 | if nil != err { 48 | panic(err) 49 | } 50 | 51 | logLevel = zapcore.InfoLevel 52 | writer = zapcore.AddSync(routateWriter) 53 | 54 | } else { 55 | logLevel = zapcore.DebugLevel 56 | writer = zapcore.AddSync(os.Stdout) 57 | } 58 | 59 | logCore := zapcore.NewCore(encoder, writer, logLevel) 60 | // logger := zap.New(logCore, zap.AddCaller()) 61 | logger := zap.New(logCore) 62 | Log = logger.Sugar() 63 | 64 | Log.Info("log initialized") 65 | } 66 | 67 | func 
checkPathExist(path string) bool { 68 | _, err := os.Stat(path) 69 | if nil == err { 70 | return true 71 | } 72 | 73 | return false 74 | } 75 | 76 | func timeEncodeFunc(t time.Time, enc zapcore.PrimitiveArrayEncoder) { 77 | enc.AppendString(t.Format("2006-01-02 15:04:05.000")) 78 | } 79 | -------------------------------------------------------------------------------- /examples/usage.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "github.com/valyala/fasthttp" 6 | "github.com/wanghongfei/gogate/conf" 7 | serv "github.com/wanghongfei/gogate/server" 8 | ) 9 | 10 | func main() { 11 | // 初始化 12 | serv.InitGogate("gogate.yml") 13 | 14 | // 创建Server 15 | server, err := serv.NewGatewayServer( 16 | conf.App.ServerConfig.Host, 17 | conf.App.ServerConfig.Port, 18 | conf.App.EurekaConfig.RouteFile, 19 | conf.App.ServerConfig.MaxConnection, 20 | ) 21 | if nil != err { 22 | fmt.Println(err) 23 | return 24 | } 25 | 26 | // ******************* 非必须 ************************* 27 | // 注册自定义过虑器, 在转发请求之前调用 28 | customPreFilter := serv.NewPreFilter("pre-log-filter1", PreLogFilter) 29 | server.AppendPreFilter(customPreFilter) 30 | // 在指定filter后面添加指定filter 31 | server.InsertPreFilterBehind("pre-log-filter1", customPreFilter) 32 | fmt.Printf("pre filters: %v\n", server.ExportAllPreFilters()) 33 | 34 | // optional 35 | // 注册自定义过虑器, 在转发请求之后调用 36 | customPostFilter := serv.NewPostFilter("post-log-filter1", PostLogFilter) 37 | server.AppendPostFilter(customPostFilter) 38 | // 在指定filter后面添加指定filter 39 | server.InsertPostFilterBehind("pre-log-filter1", customPostFilter) 40 | fmt.Printf("post filters: %v\n", server.ExportAllPostFilters()) 41 | 42 | // 自定义过虑器的添加方法必须在server启动之前调用, 启动后调用无效 43 | // ******************* 非必须 ************************* 44 | 45 | // 启动Server 46 | err = server.Start() 47 | if nil != err { 48 | fmt.Println(err) 49 | return 50 | } 51 | 52 | // 等待优雅关闭 53 | server.Shutdown() 54 | } 55 | 56 | // 
此方法会在gogate转发请求之前调用 57 | // server: gogate服务器对象 58 | // ctx: fasthttp请求上下文 59 | // newRequest: gogate在转发请求时使用的请求对象指针, 可以做一些修改, 比如改请求参数,添加请求头之类 60 | // return: 返回true则会继续执行后面的过虑器(如果有的话), 返回false则不会执行 61 | func PreLogFilter(server *serv.Server, ctx *fasthttp.RequestCtx, newRequest *fasthttp.Request) bool { 62 | fmt.Println("request path: " + ctx.URI().String()) 63 | 64 | return true 65 | } 66 | 67 | // 此方法会在gogate转发请求之后调用 68 | // req: 转发给上游服务的HTTP请求 69 | // resp: 上游服务的响应 70 | // return: 返回true则会继续执行后面的过虑器(如果有的话), 返回false则不会执行 71 | func PostLogFilter(req *fasthttp.Request, resp *fasthttp.Response) bool { 72 | fmt.Println("response: " + resp.String()) 73 | 74 | return true 75 | } 76 | -------------------------------------------------------------------------------- /server/statistics/store_csv_file.go: -------------------------------------------------------------------------------- 1 | package stat 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "github.com/wanghongfei/gogate/perr" 7 | "os" 8 | "strconv" 9 | "time" 10 | ) 11 | 12 | // 文件流量存储器 13 | type CsvFileTraficInfoStore struct { 14 | // 流量日志文件所在目录 15 | logDir string 16 | 17 | // serviceId(string) -> *File 18 | fileMap map[string]*os.File 19 | } 20 | 21 | func NewCsvFileTraficInfoStore(logDir string) *CsvFileTraficInfoStore { 22 | return &CsvFileTraficInfoStore{ 23 | logDir: logDir, 24 | fileMap: make(map[string]*os.File), 25 | } 26 | } 27 | 28 | func (fs *CsvFileTraficInfoStore) Send(info *TraficInfo) error { 29 | buf := fs.ToCsv(info) 30 | f, err := fs.getFile(info.ServiceId) 31 | if nil != err { 32 | return perr.WrapSystemErrorf(nil, "failed to getFile => %w") 33 | } 34 | 35 | buf.WriteTo(f) 36 | 37 | return nil 38 | } 39 | 40 | func (fs *CsvFileTraficInfoStore) Close() error { 41 | errMsg := "" 42 | 43 | for _, file := range fs.fileMap { 44 | closeErr := file.Close() 45 | if nil != closeErr { 46 | errMsg = fmt.Sprintf("%s%s;", errMsg, closeErr.Error()) 47 | } 48 | } 49 | 50 | if "" != errMsg { 51 | return 
perr.WrapSystemErrorf(nil, errMsg) 52 | } 53 | 54 | return nil 55 | } 56 | 57 | // 从map中取出日志文件, 如果没有则打开 58 | func (fs *CsvFileTraficInfoStore) getFile(servId string) (*os.File, error) { 59 | logFile, exist := fs.fileMap[servId] 60 | if !exist { 61 | // 不存在则创建 62 | fName := fs.genFileName(servId) 63 | f, err := os.OpenFile(fName, os.O_CREATE | os.O_APPEND | os.O_RDWR, 0644) 64 | if nil != err { 65 | return nil, perr.WrapSystemErrorf(err, "failed to open file") 66 | } 67 | 68 | logFile = f 69 | fs.fileMap[servId] = f 70 | } 71 | 72 | return logFile, nil 73 | } 74 | 75 | func (fs *CsvFileTraficInfoStore) genFileName(servId string) string { 76 | now := time.Now() 77 | today := now.Format("20060102") 78 | 79 | return fs.logDir + "/" + servId + "_" + today + ".log" 80 | } 81 | 82 | func (fs *CsvFileTraficInfoStore) ToCsv(info *TraficInfo) *bytes.Buffer { 83 | var buf bytes.Buffer 84 | buf.WriteString(strconv.FormatInt(info.timestamp, 10)) 85 | buf.WriteString(",") 86 | 87 | buf.WriteString(strconv.Itoa(info.SuccessCount)) 88 | buf.WriteString(",") 89 | 90 | buf.WriteString(strconv.Itoa(info.FailedCount)) 91 | buf.WriteString(",") 92 | 93 | buf.WriteString(info.ServiceId) 94 | buf.WriteString("\n") 95 | 96 | return &buf 97 | } 98 | 99 | 100 | -------------------------------------------------------------------------------- /throttle/rate_limiter_memory.go: -------------------------------------------------------------------------------- 1 | package throttle 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | // 基于内存存储的限速器 11 | type MemoryRateLimiter struct { 12 | // 每秒生成的token数 13 | tokenPerSecond int 14 | 15 | // 生成一个token需要的micro second 16 | tokenGenMicro int64 17 | // 上次生成token的时间 18 | lastGenMicro int64 19 | 20 | // 当前桶内token数量 21 | tokenCount int 22 | mutex *sync.Mutex 23 | } 24 | 25 | // 创建限速器 26 | // qps: 每秒最大请求数 27 | func NewMemoryRateLimiter(qps int) *MemoryRateLimiter { 28 | if qps < 1 { 29 | qps = 1 30 | } 31 | 32 | rl := 
new(MemoryRateLimiter) 33 | rl.tokenPerSecond = qps 34 | rl.mutex = new(sync.Mutex) 35 | rl.tokenGenMicro = int64(int64(1000 * 1000) / int64(qps)) 36 | 37 | return rl 38 | } 39 | 40 | // 获取token, 如果没有则block 41 | func (rl *MemoryRateLimiter) Acquire() { 42 | rl.mutex.Lock() 43 | rl.consumeToken(true) 44 | rl.mutex.Unlock() 45 | } 46 | 47 | // 获取token, 成功返回true, 没有则返回false 48 | func (rl *MemoryRateLimiter) TryAcquire() bool { 49 | rl.mutex.Lock() 50 | got := rl.consumeToken(false) 51 | rl.mutex.Unlock() 52 | 53 | return got 54 | } 55 | 56 | func (rl *MemoryRateLimiter) fillBucket() { 57 | nowMicro := time.Now().UnixNano() / 1000 58 | // 如果是第一次获取, 直接填满1s的token 59 | if 0 == rl.lastGenMicro { 60 | rl.tokenCount = rl.tokenPerSecond 61 | rl.lastGenMicro = nowMicro 62 | return 63 | } 64 | 65 | // 计算上次生成token时的时间差 66 | microSecondDiff := nowMicro - rl.lastGenMicro 67 | // 计算应当生成的新token数 68 | newTokens := microSecondDiff / rl.tokenGenMicro 69 | if newTokens < 1 { 70 | return 71 | } 72 | rl.tokenCount += int(newTokens) 73 | // token总数不能超过qps值 74 | if rl.tokenCount > rl.tokenPerSecond { 75 | rl.tokenCount = rl.tokenPerSecond 76 | } 77 | 78 | rl.lastGenMicro = nowMicro 79 | 80 | // fmt.Printf("timeDiff = %v\n", microSecondDiff) 81 | } 82 | 83 | func (rl *MemoryRateLimiter) consumeToken(canSleep bool) bool { 84 | for rl.tokenCount == 0 { 85 | rl.fillBucket() 86 | if rl.tokenCount == 0 { 87 | if canSleep { 88 | time.Sleep(time.Microsecond * time.Duration(rl.tokenGenMicro)) 89 | } else { 90 | return false 91 | } 92 | } 93 | } 94 | 95 | rl.tokenCount-- 96 | return true 97 | } 98 | 99 | func (rl *MemoryRateLimiter) String() string { 100 | return "qps = " + strconv.Itoa(rl.tokenPerSecond) + 101 | ",tokenGenMicro = " + strconv.FormatInt(rl.tokenGenMicro, 10) + 102 | ",lastGenMicro = " + fmt.Sprintf("%v", rl.lastGenMicro) + 103 | ",tokenCount = " + strconv.Itoa(rl.tokenCount) 104 | } 105 | 106 | -------------------------------------------------------------------------------- 
/server/statistics/stat.go: -------------------------------------------------------------------------------- 1 | package stat 2 | 3 | import ( 4 | "time" 5 | 6 | . "github.com/wanghongfei/gogate/conf" 7 | ) 8 | 9 | // 流量记录器 10 | type TraficStat struct { 11 | store TraficInfoStore 12 | 13 | // trafic信息缓冲channel 14 | bufferChan chan *TraficInfo 15 | writeChan chan map[string]*TraficInfo 16 | // 发送间隔, 秒 17 | writeInterval int 18 | } 19 | 20 | // 创建统计器 21 | // queueSize: 最大可以放多少条流量对象不会block 22 | // interval: 每多少秒调用一次存储对象发送这期间积累的数据 23 | // traficStore: 数据保存逻辑的实现, 如CsvFileTrafficStore 24 | func NewTrafficStat(queueSize, interval int, traficStore TraficInfoStore) *TraficStat { 25 | if interval < 1 { 26 | interval = 1 27 | } 28 | 29 | return &TraficStat{ 30 | bufferChan: make(chan *TraficInfo, queueSize), 31 | writeChan: make(chan map[string]*TraficInfo, interval + 1), 32 | writeInterval: interval, 33 | 34 | store: traficStore, 35 | } 36 | } 37 | 38 | // 启动流量记录routine 39 | func (ts *TraficStat) StartRecordTrafic() { 40 | // 启动统计routine 41 | go ts.traficAggregateRoutine() 42 | // 启动写日志任务routine 43 | go ts.traficLogRoutine() 44 | } 45 | 46 | // 记录流量 47 | func (ts *TraficStat) RecordTrafic(info *TraficInfo) { 48 | // 验证 49 | if nil == info || info.SuccessCount < 0 || info.FailedCount < 0 { 50 | // 无效数据丢弃 51 | return 52 | } 53 | 54 | ts.bufferChan <- info 55 | } 56 | 57 | // 每ts.writeInterval秒累计一次此时间段内的流量信息, 封装成写任务扔到writeChan中 58 | func (ts *TraficStat) traficAggregateRoutine() { 59 | ticker := time.NewTicker(time.Second * time.Duration(ts.writeInterval)) 60 | 61 | for { 62 | <- ticker.C 63 | 64 | // 取出当前channel全部元素 65 | size := len(ts.bufferChan) 66 | if 0 == size { 67 | // 上一个时间周期内没有元素 68 | // skip 69 | continue 70 | } 71 | 72 | // 统计在此时间周期里的数据之和 73 | sumMap := make(map[string]*TraficInfo) 74 | for ix := 0; ix < size; ix++ { 75 | elem := <- ts.bufferChan 76 | 77 | targetInfo, exist := sumMap[elem.ServiceId] 78 | if !exist { 79 | targetInfo = new(TraficInfo) 80 | 
targetInfo.timestamp = time.Now().UnixNano() / 1e6 81 | targetInfo.ServiceId = elem.ServiceId 82 | sumMap[elem.ServiceId] = targetInfo 83 | } 84 | 85 | targetInfo.FailedCount += elem.FailedCount 86 | targetInfo.SuccessCount += elem.SuccessCount 87 | } 88 | 89 | ts.writeChan <- sumMap 90 | } 91 | } 92 | 93 | func (ts *TraficStat) traficLogRoutine() { 94 | for servMap := range ts.writeChan { 95 | for _, traffic := range servMap { 96 | err := ts.store.Send(traffic) 97 | if nil != err { 98 | Log.Error(err) 99 | } 100 | } 101 | 102 | } 103 | 104 | } 105 | 106 | // 定义流量信息 107 | type TraficInfo struct { 108 | ServiceId string 109 | SuccessCount int 110 | FailedCount int 111 | 112 | // unix毫秒数 113 | timestamp int64 114 | } 115 | 116 | // 定义流量数据存储方式 117 | type TraficInfoStore interface { 118 | // 发送trafic数据 119 | Send(info *TraficInfo) error 120 | // 清理资源 121 | Close() error 122 | } 123 | 124 | 125 | -------------------------------------------------------------------------------- /discovery/consul_client.go: -------------------------------------------------------------------------------- 1 | package discovery 2 | 3 | import ( 4 | "github.com/hashicorp/consul/api" 5 | "github.com/wanghongfei/gogate/conf" 6 | . 
"github.com/wanghongfei/gogate/conf" 7 | "github.com/wanghongfei/gogate/perr" 8 | "strconv" 9 | "strings" 10 | ) 11 | 12 | type ConsulClient struct { 13 | // 继承方法 14 | *periodicalRefreshClient 15 | 16 | client *api.Client 17 | 18 | // 保存服务地址 19 | // key: 服务名:版本号, 版本号为eureka注册信息中的metadata[version]值 20 | // val: []*InstanceInfo 21 | registryMap *InsInfoArrSyncMap 22 | } 23 | 24 | 25 | func NewConsulClient() (Client, error) { 26 | cfg := &api.Config{} 27 | cfg.Address = conf.App.ConsulConfig.Address 28 | cfg.Scheme = "http" 29 | 30 | c, err := api.NewClient(cfg) 31 | if nil != err { 32 | return nil, perr.WrapSystemErrorf(err, "failed to init consule client") 33 | } 34 | 35 | consuleClient := &ConsulClient{client:c} 36 | consuleClient.periodicalRefreshClient = newPeriodicalRefresh(consuleClient) 37 | 38 | return consuleClient, nil 39 | } 40 | 41 | func (c *ConsulClient) GetInternalRegistryStore() *InsInfoArrSyncMap { 42 | return c.registryMap 43 | } 44 | 45 | func (c *ConsulClient) SetInternalRegistryStore(registry *InsInfoArrSyncMap) { 46 | c.registryMap = registry 47 | } 48 | 49 | func (c *ConsulClient) Get(serviceId string) []*InstanceInfo { 50 | instance, exist := c.registryMap.Get(serviceId) 51 | if !exist { 52 | return nil 53 | } 54 | 55 | return instance 56 | } 57 | 58 | 59 | func (c *ConsulClient) QueryServices() ([]*InstanceInfo, error) { 60 | servMap, err := c.client.Agent().Services() 61 | if nil != err { 62 | return nil, err 63 | } 64 | 65 | // 查出所有健康实例 66 | healthList, _, err := c.client.Health().State("passing", &api.QueryOptions{}) 67 | if nil != err { 68 | return nil, perr.WrapSystemErrorf(err, "failed to query consul") 69 | } 70 | 71 | instances := make([]*InstanceInfo, 0, 10) 72 | for _, servInfo := range servMap { 73 | servName := servInfo.Service 74 | servId := servInfo.ID 75 | 76 | // 查查在healthList中有没有 77 | isHealth := false 78 | for _, healthInfo := range healthList { 79 | if healthInfo.ServiceName == servName && healthInfo.ServiceID == servId { 
80 | isHealth = true 81 | break 82 | } 83 | } 84 | 85 | if !isHealth { 86 | Log.Warn("following instance is not health, skip; service name: %v, service id: %v", servName, servId) 87 | continue 88 | } 89 | 90 | instances = append( 91 | instances, 92 | &InstanceInfo{ 93 | ServiceName: strings.ToUpper(servInfo.Service), 94 | Addr: servInfo.Address + ":" + strconv.Itoa(servInfo.Port), 95 | Meta: servInfo.Meta, 96 | }, 97 | ) 98 | } 99 | 100 | return instances, nil 101 | } 102 | 103 | func (c *ConsulClient) Register() error { 104 | return perr.WrapSystemErrorf(nil, "not implement yet") 105 | } 106 | 107 | func (c *ConsulClient) UnRegister() error { 108 | return perr.WrapSystemErrorf(nil, "not implement yet") 109 | } 110 | -------------------------------------------------------------------------------- /server/route/trie_tree.go: -------------------------------------------------------------------------------- 1 | package route 2 | 3 | var charPosMap = make(map[rune]int) 4 | var urlCharCount int 5 | 6 | func init() { 7 | var urlCharArray = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~:/?#[]@!$&'()*+,;=%" 8 | for ix, char := range urlCharArray { 9 | charPosMap[char] = ix 10 | } 11 | 12 | urlCharCount = len(urlCharArray) 13 | } 14 | 15 | // 字典树 16 | type TrieTree struct { 17 | root *treeNode 18 | } 19 | 20 | // 创建空字典树 21 | func NewTrieTree() *TrieTree { 22 | return &TrieTree{ 23 | root: newTreeNode(0, nil), 24 | } 25 | } 26 | 27 | // 查找path对应的服务信息 28 | // path: 请求路径 29 | func (tree *TrieTree) Search(path string) *ServiceInfo { 30 | node := tree.root 31 | for _, char := range path { 32 | node = node.findSubNode(char) 33 | if nil == node { 34 | return nil 35 | } 36 | } 37 | 38 | return node.Data 39 | } 40 | 41 | // 搜索路径上遇到的第一个字符串 42 | // path: 请求路径 43 | func (tree *TrieTree) SearchFirst(path string) *ServiceInfo { 44 | node := tree.root 45 | for _, char := range path { 46 | node = node.findSubNode(char) 47 | if nil == node { 48 | return nil 49 | } 50 | 
51 | if nil != node.Data { 52 | return node.Data 53 | } 54 | } 55 | 56 | return nil 57 | } 58 | 59 | // 添加一条path->serviceInfo映射 60 | func (tree *TrieTree) PutString(path string, data *ServiceInfo) { 61 | pathRunes := []rune(path) 62 | LEN := len(pathRunes) 63 | 64 | node := tree.root 65 | for ix, char := range pathRunes { 66 | subNode := findNode(char, node.SubNodes) 67 | if nil == subNode { 68 | var newNode *treeNode 69 | // 是最后一个字符 70 | if ix == LEN - 1 { 71 | newNode = newTreeNode(char, data) 72 | } else { 73 | newNode = newTreeNode(char, nil) 74 | } 75 | 76 | node.addSubNode(newNode) 77 | node = newNode 78 | 79 | } else if ix == LEN - 1 { 80 | subNode.Data = data 81 | 82 | } else { 83 | node = subNode 84 | } 85 | } 86 | } 87 | 88 | func findNode(char rune, nodeList []*treeNode) *treeNode { 89 | if nil == nodeList { 90 | return nil 91 | } 92 | 93 | pos := mapPosition(char) 94 | return nodeList[pos] 95 | } 96 | 97 | type treeNode struct { 98 | Char rune 99 | Data *ServiceInfo 100 | 101 | SubNodes []*treeNode 102 | } 103 | 104 | func newTreeNode(char rune, data *ServiceInfo) *treeNode { 105 | node := &treeNode{ 106 | Char: char, 107 | Data: data, 108 | SubNodes: nil, 109 | } 110 | 111 | return node 112 | } 113 | 114 | func mapPosition(char rune) int { 115 | return charPosMap[char] 116 | } 117 | 118 | // 添加子节点 119 | func (node *treeNode) addSubNode(newNode *treeNode) { 120 | if nil == node.SubNodes { 121 | node.SubNodes = make([]*treeNode, urlCharCount) 122 | } 123 | 124 | position := mapPosition(newNode.Char) 125 | node.SubNodes[position] = newNode 126 | } 127 | 128 | func (node *treeNode) findSubNode(target rune) *treeNode { 129 | if nil == node.SubNodes { 130 | return nil 131 | } 132 | 133 | pos := mapPosition(target) 134 | return node.SubNodes[pos] 135 | } 136 | 137 | -------------------------------------------------------------------------------- /server/server_send.go: -------------------------------------------------------------------------------- 1 | 
package server 2 | 3 | import ( 4 | . "github.com/wanghongfei/gogate/conf" 5 | "github.com/wanghongfei/gogate/discovery" 6 | "github.com/wanghongfei/gogate/perr" 7 | "github.com/wanghongfei/gogate/server/route" 8 | "strings" 9 | 10 | "github.com/valyala/fasthttp" 11 | "github.com/wanghongfei/gogate/utils" 12 | ) 13 | 14 | const META_VERSION = "version" 15 | 16 | // 转发请求到指定微服务 17 | // return: 18 | // Response: 响应对象; 19 | // string: 下游服务名 20 | // error: 错误 21 | func (serv *Server) sendRequest(ctx *fasthttp.RequestCtx, req *fasthttp.Request) (*fasthttp.Response, string, error) { 22 | // 获取服务信息 23 | info := ctx.UserValue(ROUTE_INFO).(*route.ServiceInfo) 24 | 25 | var logRecordName string 26 | // 需要从注册列表中查询地址 27 | if info.Id != "" { 28 | if serv.IsInStaticMode() { 29 | return nil, "", perr.WrapBizErrorf(nil, "no static address found for this service") 30 | } 31 | 32 | logRecordName = info.Id 33 | 34 | // 获取Client 35 | appId := strings.ToUpper(info.Id) 36 | 37 | // 灰度, 选择版本 38 | version := chooseVersion(info.Canary) 39 | 40 | // 取出指定服务的所有实例 41 | serviceInstances := serv.discoveryClient.Get(appId) 42 | if nil == serviceInstances { 43 | return nil, "", perr.WrapBizErrorf(nil, "no instance %s for service (service is offline)", appId) 44 | } 45 | 46 | // 按version过滤 47 | if "" != version { 48 | serviceInstances = filterWithVersion(serviceInstances, version) 49 | if 0 == len(serviceInstances) { 50 | // 此version下没有实例 51 | return nil, "", perr.WrapBizErrorf(nil, "no instance %s:%s for service", appId, version) 52 | } 53 | } 54 | 55 | // 负载均衡 56 | targetInstance := serv.lb.Choose(serviceInstances) 57 | // 修改请求的host为目标主机地址 58 | req.URI().SetHost(targetInstance.Addr) 59 | 60 | } else { 61 | logRecordName = info.Name 62 | 63 | // 直接使用后面的地址 64 | hostList := strings.Split(info.Host, ",") 65 | 66 | targetAddr := serv.lb.ChooseByAddresses(hostList) 67 | req.URI().SetHost(targetAddr) 68 | } 69 | 70 | // 发请求 71 | resp := new(fasthttp.Response) 72 | err := serv.fastClient.Do(req, resp) 73 
| if nil != err { 74 | return nil, "", perr.WrapSystemErrorf(nil, "failed to send request to downstream service") 75 | } 76 | 77 | return resp, logRecordName, nil 78 | } 79 | 80 | // 过滤出meta里version字段为指定值的实例 81 | func filterWithVersion(instances []*discovery.InstanceInfo, targetVersion string) []*discovery.InstanceInfo { 82 | result := make([]*discovery.InstanceInfo, 0, 5) 83 | 84 | for _, ins := range instances { 85 | if ins.Meta[META_VERSION] == targetVersion { 86 | result = append(result, ins) 87 | } 88 | } 89 | 90 | return result 91 | } 92 | 93 | func chooseVersion(canaryInfos []*route.CanaryInfo) string { 94 | if nil == canaryInfos || len(canaryInfos) == 0 { 95 | return "" 96 | } 97 | 98 | var weights []int 99 | for _, info := range canaryInfos { 100 | weights = append(weights, info.Weight) 101 | } 102 | 103 | index := utils.RandomByWeight(weights) 104 | if -1 == index { 105 | Log.Warn("random interval returned -1") 106 | return "" 107 | } 108 | 109 | return canaryInfos[index].Meta 110 | } 111 | -------------------------------------------------------------------------------- /discovery/refresh.go: -------------------------------------------------------------------------------- 1 | package discovery 2 | 3 | import ( 4 | . 
"github.com/wanghongfei/gogate/conf" 5 | "github.com/wanghongfei/gogate/perr" 6 | "github.com/wanghongfei/gogate/utils" 7 | "sync" 8 | "time" 9 | ) 10 | const REGISTRY_REFRESH_INTERVAL = 30 11 | 12 | type periodicalRefreshClient struct { 13 | client Client 14 | } 15 | 16 | func newPeriodicalRefresh(c Client) *periodicalRefreshClient { 17 | return &periodicalRefreshClient{c} 18 | } 19 | 20 | // 向eureka查询注册列表, 刷新本地列表 21 | func (r *periodicalRefreshClient) StartPeriodicalRefresh() error { 22 | Log.Infof("refresh registry every %d sec", REGISTRY_REFRESH_INTERVAL) 23 | 24 | refreshRegistryChan := make(chan error) 25 | 26 | isBootstrap := true 27 | go func() { 28 | ticker := time.NewTicker(REGISTRY_REFRESH_INTERVAL * time.Second) 29 | 30 | for { 31 | Log.Info("registry refresh started") 32 | err := r.doRefresh() 33 | if nil != err { 34 | // 如果是第一次查询失败, 退出程序 35 | if isBootstrap { 36 | refreshRegistryChan <- perr.WrapSystemErrorf(err, "failed to refresh registry") 37 | return 38 | 39 | } else { 40 | Log.Error(err) 41 | } 42 | 43 | } 44 | Log.Info("done refreshing registry") 45 | 46 | if isBootstrap { 47 | isBootstrap = false 48 | close(refreshRegistryChan) 49 | } 50 | 51 | <-ticker.C 52 | } 53 | }() 54 | 55 | return <- refreshRegistryChan 56 | } 57 | 58 | func (r *periodicalRefreshClient) doRefresh() error { 59 | instances, err := r.client.QueryServices() 60 | 61 | if nil != err { 62 | return perr.WrapSystemErrorf(err, "failed to query all services") 63 | } 64 | 65 | if nil == instances { 66 | Log.Info("no instance found") 67 | return nil 68 | } 69 | 70 | Log.Infof("total app count: %d", len(instances)) 71 | 72 | newRegistryMap := r.groupByService(instances) 73 | 74 | r.refreshRegistryMap(newRegistryMap) 75 | 76 | return nil 77 | 78 | } 79 | 80 | 81 | // 将所有实例按服务名进行分组 82 | func (r *periodicalRefreshClient) groupByService(instances []*InstanceInfo) *sync.Map { 83 | servMap := new(sync.Map) 84 | for _, ins := range instances { 85 | infosGeneric, exist := 
servMap.Load(ins.ServiceName) 86 | if !exist { 87 | infosGeneric = make([]*InstanceInfo, 0, 5) 88 | infosGeneric = append(infosGeneric.([]*InstanceInfo), ins) 89 | 90 | } else { 91 | infosGeneric = append(infosGeneric.([]*InstanceInfo), ins) 92 | } 93 | servMap.Store(ins.ServiceName, infosGeneric) 94 | } 95 | return servMap 96 | } 97 | 98 | 99 | // 更新本地注册列表 100 | // s: gogate server对象 101 | // newRegistry: 刚从eureka查出的最新服务列表 102 | func (r *periodicalRefreshClient) refreshRegistryMap(newRegistry *sync.Map) { 103 | if nil == r.client.GetInternalRegistryStore() { 104 | r.client.SetInternalRegistryStore(NewInsInfoArrSyncMap()) 105 | } 106 | 107 | // 找出本地列表存在, 但新列表中不存在的服务 108 | exclusiveKeys, _ := utils.FindExclusiveKey(r.client.GetInternalRegistryStore().GetMap(), newRegistry) 109 | // 删除本地多余的服务 110 | utils.DelKeys(r.client.GetInternalRegistryStore().GetMap(), exclusiveKeys) 111 | // 将新列表中的服务合并到本地列表中 112 | utils.MergeSyncMap(newRegistry, r.client.GetInternalRegistryStore().GetMap()) 113 | } 114 | -------------------------------------------------------------------------------- /perr/error.go: -------------------------------------------------------------------------------- 1 | package perr 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "runtime" 7 | ) 8 | 9 | // 增强内置error接口 10 | type EnvError interface { 11 | error 12 | 13 | LineNumber() int 14 | SrcName() string 15 | } 16 | 17 | type WithBottomMsg interface { 18 | BottomMsg() string 19 | } 20 | 21 | // 不能暴露给用户的系统级错误 22 | type SystemError struct { 23 | srcName string 24 | lineNumber int 25 | 26 | Msg string 27 | 28 | Cause error 29 | } 30 | 31 | func (s *SystemError) LineNumber() int { 32 | return s.lineNumber 33 | } 34 | 35 | func (s *SystemError) SrcName() string { 36 | return s.srcName 37 | } 38 | 39 | func (s *SystemError) Error() string { 40 | return s.Msg 41 | } 42 | 43 | func (s *SystemError) Unwrap() error { 44 | return s.Cause 45 | } 46 | 47 | // 可以给用户看的业务错误 48 | type BizError struct { 49 | srcName string 50 | 
lineNumber int 51 | 52 | Msg string 53 | bottomMsg string 54 | 55 | Cause error 56 | } 57 | 58 | func (b *BizError) Error() string { 59 | return b.Msg 60 | } 61 | 62 | func (b *BizError) LineNumber() int { 63 | return b.lineNumber 64 | } 65 | 66 | func (b *BizError) SrcName() string { 67 | return b.srcName 68 | } 69 | 70 | func (b *BizError) BottomMsg() string { 71 | return b.bottomMsg 72 | } 73 | 74 | func (b *BizError) Unwrap() error { 75 | return b.Cause 76 | } 77 | 78 | // 创建一个携带环境信息的error, 包含文件名 + 行号 79 | func WrapBizErrorf(cause error, format string, args ...interface{}) EnvError { 80 | _, srcName, line, _ := runtime.Caller(1) 81 | msg := fmt.Sprintf(format, args...) 82 | 83 | bottomMsg := msg 84 | if nil != cause { 85 | if withBottomMsg, ok := cause.(WithBottomMsg); ok { 86 | bottomMsg = withBottomMsg.BottomMsg() 87 | } 88 | } 89 | 90 | return &BizError{ 91 | srcName: srcName, 92 | lineNumber: line, 93 | Msg: msg, 94 | Cause: cause, 95 | bottomMsg: bottomMsg, 96 | } 97 | } 98 | 99 | // 创建一个携带环境信息的error, 包含文件名 + 行号 100 | func WrapSystemErrorf(cause error, format string, args ...interface{}) error { 101 | _, srcName, line, _ := runtime.Caller(1) 102 | fmtErr := fmt.Errorf(format, args...) 
103 | 104 | return &SystemError{ 105 | srcName: srcName, 106 | lineNumber: line, 107 | Msg: fmtErr.Error(), 108 | Cause: cause, 109 | } 110 | } 111 | 112 | func EnvMsg(err error) string { 113 | msg := err.Error() 114 | if envError, ok := err.(EnvError); ok { 115 | msg = fmt.Sprintf("[%s:%d]%s", envError.SrcName(), envError.LineNumber(), msg) 116 | } 117 | 118 | cause := err 119 | for { 120 | cause = errors.Unwrap(cause) 121 | if nil == cause { 122 | break 123 | } 124 | 125 | if envError, ok := cause.(EnvError); ok { 126 | msg = fmt.Sprintf("%s => [%s:%d]%s", msg, envError.SrcName(), envError.LineNumber(), cause.Error()) 127 | } 128 | 129 | } 130 | 131 | return msg 132 | } 133 | 134 | 135 | func ParseError(err error) (*BizError, *SystemError, error) { 136 | // 是否包含系统错误 137 | var sysErr *SystemError 138 | if errors.As(err, &sysErr) { 139 | return nil, sysErr, nil 140 | } 141 | 142 | // 是否包含业务错误 143 | var bizErr *BizError 144 | if errors.As(err, &bizErr) { 145 | return bizErr, nil, nil 146 | } 147 | 148 | return nil, nil, err 149 | } 150 | -------------------------------------------------------------------------------- /conf/config.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "errors" 5 | "io/ioutil" 6 | "os" 7 | 8 | "gopkg.in/yaml.v2" 9 | ) 10 | 11 | type GateConfig struct { 12 | Version string`yaml:"version"` 13 | 14 | ServerConfig *ServerConfig`yaml:"server"` 15 | RedisConfig *RedisConfig`yaml:"redis"` 16 | 17 | EurekaConfig *EurekaConfig`yaml:"eureka"` 18 | ConsulConfig *ConsulConfig`yaml:"consul"` 19 | 20 | Traffic *TrafficConfig`yaml:"traffic"` 21 | 22 | Log struct { 23 | ConsoleOnly bool`yaml:"console-only"` 24 | FilePattern string`yaml:"file-pattern"` 25 | FileLink string`yaml:"file-link"` 26 | Directory string`yaml:"directory"` 27 | }`yaml:"log"` 28 | } 29 | 30 | type ServerConfig struct { 31 | AppName string`yaml:"appName"` 32 | Host string`yaml:"host"` 33 | Port int`yaml:"port"` 
34 | MaxConnection int`yaml:"maxConnection"` 35 | // 请求超时时间, ms 36 | Timeout int`yaml:"timeout"` 37 | 38 | } 39 | 40 | type EurekaConfig struct { 41 | Enable bool`yaml:"enable"` 42 | ConfigFile string`yaml:"configFile"` 43 | RouteFile string`yaml:"routeFile"` 44 | EvictionDuration uint`yaml:"evictionDuration"` 45 | HeartbeatInterval int`yaml:"heartbeatInterval"` 46 | } 47 | 48 | type ConsulConfig struct { 49 | Enable bool`yaml:"enable"` 50 | Address string`yaml:"address"` 51 | } 52 | 53 | type TrafficConfig struct { 54 | EnableTrafficRecord bool`yaml:"enableTrafficRecord"` 55 | TrafficLogDir string`yaml:"trafficLogDir"` 56 | 57 | } 58 | 59 | type RedisConfig struct { 60 | Enabled bool 61 | Addr string 62 | RateLimiterLua string`yaml:"rateLimiterLua"` 63 | } 64 | 65 | var App *GateConfig 66 | 67 | func LoadConfig(filename string) { 68 | f, err := os.Open(filename) 69 | if nil != err { 70 | Log.Error(err) 71 | panic(err) 72 | } 73 | defer f.Close() 74 | 75 | buf, _ := ioutil.ReadAll(f) 76 | 77 | config := new(GateConfig) 78 | err = yaml.Unmarshal(buf, config) 79 | if nil != err { 80 | Log.Error(err) 81 | panic(err) 82 | } 83 | 84 | validateGogateConfig(config) 85 | } 86 | 87 | func InitLog() { 88 | initRotateLog() 89 | } 90 | 91 | func validateGogateConfig(config *GateConfig) error { 92 | if nil == config { 93 | return errors.New("config is nil") 94 | } 95 | 96 | // 检查eureka配置 97 | euConfig := config.EurekaConfig 98 | if nil == euConfig { 99 | return errors.New("eureka config cannot be empty") 100 | } 101 | if euConfig.ConfigFile == "" || euConfig.RouteFile == "" { 102 | return errors.New("eureka or route config file cannot be empty") 103 | } 104 | 105 | servCfg := config.ServerConfig 106 | if servCfg.AppName == "" { 107 | servCfg.AppName = "gogate" 108 | } 109 | 110 | if servCfg.Host == "" { 111 | servCfg.Host = "127.0.0.1" 112 | } 113 | 114 | if servCfg.Port == 0 { 115 | servCfg.Port = 8080 116 | } 117 | 118 | if servCfg.MaxConnection == 0 { 119 | 
servCfg.MaxConnection = 1000 120 | } 121 | 122 | if servCfg.Timeout == 0 { 123 | servCfg.Timeout = 3000 124 | } 125 | 126 | 127 | trafficCfg := config.Traffic 128 | if trafficCfg.EnableTrafficRecord { 129 | if trafficCfg.TrafficLogDir == "" { 130 | trafficCfg.TrafficLogDir = "/tmp" 131 | } 132 | } 133 | 134 | rdConfig := config.RedisConfig 135 | if rdConfig.Enabled { 136 | if rdConfig.Addr == "" { 137 | rdConfig.Addr = "127.0.0.1:6379" 138 | } 139 | } 140 | 141 | App = config 142 | 143 | return nil 144 | } 145 | 146 | -------------------------------------------------------------------------------- /server/route/router.go: -------------------------------------------------------------------------------- 1 | package route 2 | 3 | import ( 4 | "github.com/wanghongfei/gogate/perr" 5 | "gopkg.in/yaml.v2" 6 | "io/ioutil" 7 | "os" 8 | ) 9 | 10 | type Router struct { 11 | // 配置文件路径 12 | cfgPath string 13 | 14 | // path(string) -> *ServiceInfo 15 | pathMatcher *PathMatcher 16 | 17 | ServInfos []*ServiceInfo 18 | } 19 | 20 | type ServiceInfo struct { 21 | Id string 22 | Prefix string 23 | Host string 24 | Name string 25 | StripPrefix bool`yaml:"strip-prefix"` 26 | Qps int 27 | 28 | Canary []*CanaryInfo 29 | } 30 | 31 | type CanaryInfo struct { 32 | Meta string 33 | Weight int 34 | } 35 | 36 | func (info *ServiceInfo) String() string { 37 | return "prefix = " + info.Prefix + ", id = " + info.Id + ", host = " + info.Host 38 | } 39 | 40 | /* 41 | * 创建路由器 42 | * 43 | * PARAMS: 44 | * - path: 路由配置文件路径 45 | * 46 | */ 47 | func NewRouter(path string) (*Router, error) { 48 | matcher, servInfos, err := loadRoute(path) 49 | if nil != err { 50 | return nil, perr.WrapSystemErrorf(err, "failed to load route info") 51 | } 52 | 53 | 54 | return &Router{ 55 | pathMatcher: matcher, 56 | cfgPath: path, 57 | ServInfos: servInfos, 58 | }, nil 59 | } 60 | 61 | /* 62 | * 重新加载路由器 63 | */ 64 | func (r *Router) ReloadRoute() error { 65 | matcher, servInfos, err := loadRoute(r.cfgPath) 66 | if nil != 
err { 67 | return perr.WrapSystemErrorf(err, "failed to load route info") 68 | } 69 | 70 | r.ServInfos = servInfos 71 | r.pathMatcher = matcher 72 | 73 | return nil 74 | } 75 | 76 | /* 77 | * 根据uri选择一个最匹配的appId 78 | * 79 | * RETURNS: 80 | * 返回最匹配的ServiceInfo 81 | */ 82 | func (r *Router) Match(reqPath string) *ServiceInfo { 83 | 84 | return r.pathMatcher.Match(reqPath) 85 | } 86 | 87 | func loadRoute(path string) (*PathMatcher, []*ServiceInfo, error) { 88 | // 打开配置文件 89 | routeFile, err := os.Open(path) 90 | if nil != err { 91 | return nil, nil, perr.WrapSystemErrorf(err, "failed to open file") 92 | } 93 | defer routeFile.Close() 94 | 95 | // 读取 96 | buf, err := ioutil.ReadAll(routeFile) 97 | if nil != err { 98 | return nil, nil, err 99 | } 100 | 101 | // 解析yml 102 | // ymlMap := make(map[string]*ServiceInfo) 103 | ymlMap := make(map[string]map[string]*ServiceInfo) 104 | err = yaml.UnmarshalStrict(buf, &ymlMap) 105 | if nil != err { 106 | return nil, nil, err 107 | } 108 | 109 | servInfos := make([]*ServiceInfo, 0, 10) 110 | 111 | // 构造 path->serviceId 映射 112 | // 保存到字典树中 113 | tree := NewTrieTree() 114 | // 保存到map中 115 | routeMap := make(map[string]*ServiceInfo) 116 | for name, info := range ymlMap["services"] { 117 | // 验证 118 | err = validateServiceInfo(info) 119 | if nil != err { 120 | return nil, nil, perr.WrapSystemErrorf(err, "invalid config for %s", name) 121 | } 122 | 123 | tree.PutString(info.Prefix, info) 124 | routeMap[info.Prefix] = info 125 | 126 | servInfos = append(servInfos, info) 127 | } 128 | 129 | 130 | matcher := &PathMatcher{ 131 | routeMap: routeMap, 132 | routeTrieTree: tree, 133 | } 134 | return matcher, servInfos, nil 135 | } 136 | 137 | func validateServiceInfo(info *ServiceInfo) error { 138 | if nil == info { 139 | return perr.WrapSystemErrorf(nil, "info is empty") 140 | } 141 | 142 | if "" == info.Id && "" == info.Host { 143 | return perr.WrapSystemErrorf(nil, "id and host are both empty") 144 | } 145 | 146 | if "" == info.Prefix { 147 
| return perr.WrapSystemErrorf(nil, "path is empty") 148 | } 149 | 150 | return nil 151 | } -------------------------------------------------------------------------------- /server/server_handler.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/valyala/fasthttp" 5 | . "github.com/wanghongfei/gogate/conf" 6 | "github.com/wanghongfei/gogate/perr" 7 | "github.com/wanghongfei/gogate/utils" 8 | "runtime" 9 | "strconv" 10 | ) 11 | 12 | const ( 13 | SERVICE_NAME = "key_service_name" 14 | REQUEST_PATH = "key_request_path" 15 | ROUTE_INFO = "key_route_info" 16 | REQUEST_ID = "key_req_id" 17 | STOPWATCH = "key_stopwatch" 18 | 19 | RELOAD_PATH = "/_mgr/reload" 20 | ) 21 | 22 | // HTTP请求处理方法. 23 | func (serv *Server) HandleRequest(ctx *fasthttp.RequestCtx) { 24 | defer recoverPanic(ctx, serv) 25 | 26 | // 计时器 27 | sw := utils.NewStopwatch() 28 | ctx.SetUserValue(STOPWATCH, sw) 29 | 30 | 31 | // 取出请求path 32 | path := string(ctx.Path()) 33 | ctx.SetUserValue(REQUEST_PATH, path) 34 | 35 | // 生成唯一id 36 | reqId := utils.GenerateUuid() 37 | ctx.SetUserValue(REQUEST_ID, reqId) 38 | 39 | Log.Infof("request %d received, method = %s, path = %s, body = %s", reqId, string(ctx.Method()), path, string(ctx.Request.Body())) 40 | 41 | // 处理reload请求 42 | if path == RELOAD_PATH { 43 | err := serv.ReloadRoute() 44 | if nil != err { 45 | Log.Error(err) 46 | NewResponse(path, err.Error()).Send(ctx) 47 | return 48 | } 49 | 50 | // ctx.WriteString(serv.ExtractRoute()) 51 | ctx.WriteString("ok") 52 | return 53 | } 54 | 55 | newReq := new(fasthttp.Request) 56 | ctx.Request.CopyTo(newReq) 57 | 58 | // 调用Pre过虑器 59 | ok := invokePreFilters(serv, ctx, newReq) 60 | if !ok { 61 | return 62 | } 63 | 64 | // 发请求 65 | resp, logRecordName, err := serv.sendRequest(ctx, newReq) 66 | // 错误处理 67 | if nil != err { 68 | err = perr.WrapBizErrorf(err, "anther layer") 69 | Log.Errorf("request %d, %s", reqId, perr.EnvMsg(err)) 70 | 71 
| // 解析错误类型 72 | bizErr, sysErr, _ := perr.ParseError(err) 73 | var responseMessage string 74 | if nil != bizErr { 75 | // 业务错误 76 | responseMessage = bizErr.BottomMsg() 77 | 78 | } else if nil != sysErr { 79 | // 系统错误 80 | responseMessage = "system error" 81 | ctx.SetStatusCode(fasthttp.StatusInternalServerError) 82 | 83 | } else { 84 | responseMessage = err.Error() 85 | } 86 | 87 | NewResponse(path, responseMessage).Send(ctx) 88 | 89 | serv.recordTraffic(logRecordName, false) 90 | return 91 | } 92 | serv.recordTraffic(logRecordName, true) 93 | 94 | 95 | // 调用Post过虑器 96 | ok = invokePostFilters(serv, newReq, resp) 97 | if !ok { 98 | return 99 | } 100 | 101 | // 返回响应 102 | sendResponse(ctx, resp, reqId, sw) 103 | 104 | } 105 | 106 | func sendResponse(ctx *fasthttp.RequestCtx, resp *fasthttp.Response, reqId int64, timer *utils.Stopwatch) { 107 | // copy header 108 | ctx.Response.Header = resp.Header 109 | ctx.Response.Header.Add("proxy", "gogate") 110 | 111 | timeCost := timer.Record() 112 | resp.Header.Add("Time", strconv.FormatInt(timeCost, 10)) 113 | resp.Header.Set("Server", "gogate") 114 | 115 | Log.Infof("request %d finished, cost = %dms, statusCode = %d, response = %s", reqId, timeCost, ctx.Response.StatusCode(), string(resp.Body())) 116 | ctx.Write(resp.Body()) 117 | } 118 | 119 | func invokePreFilters(s *Server, ctx *fasthttp.RequestCtx, newReq *fasthttp.Request) bool { 120 | for _, f := range s.preFilters { 121 | next := f.FilterFunc(s, ctx, newReq) 122 | if !next { 123 | return false 124 | } 125 | } 126 | 127 | return true 128 | } 129 | 130 | func invokePostFilters(s *Server, newReq *fasthttp.Request, resp *fasthttp.Response) bool { 131 | for _, f := range s.postFilters { 132 | next := f.FilterFunc(newReq, resp) 133 | if !next { 134 | return false 135 | } 136 | } 137 | 138 | return true 139 | } 140 | 141 | func processPanic(ctx *fasthttp.RequestCtx, serv *Server) { 142 | path := string(ctx.Path()) 143 | NewResponse(path, "system 
error").SendWithStatus(ctx, 500) 144 | 145 | // 记录流量 146 | serv.recordTraffic(GetStringFromUserValue(ctx, SERVICE_NAME), false) 147 | 148 | } 149 | 150 | func recoverPanic(ctx *fasthttp.RequestCtx, serv *Server) { 151 | if r := recover(); r != nil { 152 | // 日志记录调用栈 153 | stackBuf := make([]byte, 1024) 154 | bufLen := runtime.Stack(stackBuf, false) 155 | Log.Errorf("panic: %s", string(stackBuf[0:bufLen])) 156 | 157 | processPanic(ctx, serv) 158 | } 159 | } 160 | 161 | -------------------------------------------------------------------------------- /server/server_filter.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | . "github.com/wanghongfei/gogate/conf" 5 | ) 6 | 7 | // 注册过滤器, 追加到末尾 8 | func (serv *Server) AppendPreFilter(pre *PreFilter) { 9 | if serv.isStarted { 10 | Log.Warn("cannot change filters after server started") 11 | } 12 | 13 | // asynclog.Info("append pre filter: %s", pre.Name) 14 | serv.preFilters = append(serv.preFilters, pre) 15 | } 16 | 17 | // 注册过滤器, 追加到末尾 18 | func (serv *Server) AppendPostFilter(post *PostFilter) { 19 | if serv.isStarted { 20 | Log.Warn("cannot change filters after server started") 21 | } 22 | 23 | // asynceog.Info("append post filter: %s", post.Name) 24 | serv.postFilters = append(serv.postFilters, post) 25 | } 26 | 27 | // 导出所有前置过滤器 28 | func (serv *Server) ExportAllPreFilters() []*PreFilter { 29 | result := make([]*PreFilter, len(serv.preFilters)) 30 | copy(result, serv.preFilters) 31 | 32 | return result 33 | } 34 | 35 | // 导出所有后置过滤器 36 | func (serv *Server) ExportAllPostFilters() []*PostFilter { 37 | result := make([]*PostFilter, len(serv.postFilters)) 38 | copy(result, serv.postFilters) 39 | 40 | return result 41 | } 42 | 43 | // 在指定前置过滤器的后面添加 44 | func (serv *Server) InsertPreFilterBehind(filterName string, filter *PreFilter) bool { 45 | if serv.isStarted { 46 | Log.Warn("cannot change filters after server started") 47 | } 48 | 49 | // 
asynclog.Info("insert pre filter: %s", filter.Name) 50 | 51 | targetIdx := serv.getPreFilterIndex(filterName) 52 | if -1 == targetIdx { 53 | return false 54 | } 55 | 56 | rearIdx := targetIdx + 1 57 | rear := append([]*PreFilter{}, serv.preFilters[rearIdx:]...) 58 | serv.preFilters = append(serv.preFilters[0:rearIdx], filter) 59 | serv.preFilters = append(serv.preFilters, rear...) 60 | 61 | 62 | return true 63 | } 64 | 65 | // 在指定后置过滤器的后面添加; 66 | // filterName: 在此过滤器后面添加filter, 如果要在队头添加, 则使用空字符串 67 | // filter: 过滤器对象 68 | func (serv *Server) InsertPostFilterBehind(filterName string, filter *PostFilter) bool { 69 | if serv.isStarted { 70 | Log.Warn("cannot change filters after server started") 71 | } 72 | 73 | // asynclog.Info("insert post filter: %s", filter.Name) 74 | 75 | targetIdx := serv.getPostFilterIndex(filterName) 76 | if -1 == targetIdx { 77 | return false 78 | } 79 | 80 | rearIdx := targetIdx + 1 81 | rear := append([]*PostFilter{}, serv.postFilters[rearIdx:]...) 82 | serv.postFilters = append(serv.postFilters[0:rearIdx], filter) 83 | serv.postFilters = append(serv.postFilters, rear...) 84 | 85 | return true 86 | } 87 | 88 | // 在最头部添加前置过滤器 89 | func (serv *Server) InsertPreFilterAhead(filter *PreFilter) { 90 | if serv.isStarted { 91 | Log.Warn("cannot change filters after server started") 92 | } 93 | 94 | // asynclog.Info("insert pre filter: %s", filter.Name) 95 | 96 | newFilterSlice := make([]*PreFilter, 0, 1 + len(serv.preFilters)) 97 | newFilterSlice = append(newFilterSlice, filter) 98 | newFilterSlice = append(newFilterSlice, serv.preFilters...) 
99 | 100 | serv.preFilters = newFilterSlice 101 | } 102 | 103 | // 在最头部添加后置过滤器 104 | func (serv *Server) InsertPostFilterAhead(filter *PostFilter) { 105 | if serv.isStarted { 106 | Log.Warn("cannot change filters after server started") 107 | } 108 | 109 | // asynclog.Info("insert post filter: %s", filter.Name) 110 | 111 | newFilterSlice := make([]*PostFilter, 0, 1 + len(serv.postFilters)) 112 | newFilterSlice = append(newFilterSlice, filter) 113 | newFilterSlice = append(newFilterSlice, serv.postFilters...) 114 | 115 | serv.postFilters = newFilterSlice 116 | } 117 | 118 | func (serv *Server) ensurePreFilterCap(neededSpace int) { 119 | currentCap := cap(serv.preFilters) 120 | currentLen := len(serv.preFilters) 121 | leftSpace := currentCap - currentLen 122 | 123 | if leftSpace < neededSpace { 124 | newCap := currentCap + (neededSpace - leftSpace) + 3 125 | 126 | oldFilters := serv.preFilters 127 | serv.preFilters = make([]*PreFilter, 0, newCap) 128 | copy(serv.preFilters, oldFilters) 129 | } 130 | } 131 | 132 | func (serv *Server) getPreFilterIndex(name string) int { 133 | if nil == serv.preFilters { 134 | return -1 135 | } 136 | 137 | for ix, f := range serv.preFilters { 138 | if f.Name == name { 139 | return ix 140 | } 141 | } 142 | 143 | return -1 144 | } 145 | 146 | func (serv *Server) getPostFilterIndex(name string) int { 147 | if nil == serv.preFilters { 148 | return -1 149 | } 150 | 151 | for ix, f := range serv.postFilters { 152 | if f.Name == name { 153 | return ix 154 | } 155 | } 156 | 157 | return -1 158 | } 159 | 160 | -------------------------------------------------------------------------------- /discovery/eureka_client.go: -------------------------------------------------------------------------------- 1 | package discovery 2 | 3 | import ( 4 | "github.com/wanghongfei/gogate/perr" 5 | "strconv" 6 | "time" 7 | 8 | "github.com/wanghongfei/go-eureka-client/eureka" 9 | "github.com/wanghongfei/gogate/conf" 10 | . 
"github.com/wanghongfei/gogate/conf" 11 | "github.com/wanghongfei/gogate/utils" 12 | ) 13 | 14 | // var euClient *eureka.Client 15 | var gogateApp *eureka.InstanceInfo 16 | var instanceId = "" 17 | 18 | var ticker *time.Ticker 19 | var tickerCloseChan chan struct{} 20 | 21 | type EurekaClient struct { 22 | // 继承方法 23 | *periodicalRefreshClient 24 | 25 | client *eureka.Client 26 | 27 | // 保存服务地址 28 | // key: 服务名:版本号, 版本号为eureka注册信息中的metadata[version]值 29 | // val: []*InstanceInfo 30 | registryMap *InsInfoArrSyncMap 31 | } 32 | 33 | func NewEurekaClient(confFile string) (Client, error) { 34 | c, err := eureka.NewClientFromFile(confFile) 35 | if nil != err { 36 | return nil, perr.WrapSystemErrorf(err, "failed to init eureka client") 37 | } 38 | 39 | euClient := &EurekaClient{client:c} 40 | euClient.periodicalRefreshClient = newPeriodicalRefresh(euClient) 41 | 42 | return euClient, nil 43 | } 44 | 45 | func (c *EurekaClient) Get(serviceId string) []*InstanceInfo { 46 | instance, exist := c.registryMap.Get(serviceId) 47 | if !exist { 48 | return nil 49 | } 50 | 51 | return instance 52 | } 53 | 54 | func (c *EurekaClient) GetInternalRegistryStore() *InsInfoArrSyncMap { 55 | return c.registryMap 56 | } 57 | 58 | func (c *EurekaClient) SetInternalRegistryStore(registry *InsInfoArrSyncMap) { 59 | c.registryMap = registry 60 | } 61 | 62 | 63 | // 查询所有服务 64 | func (c *EurekaClient) QueryServices() ([]*InstanceInfo, error) { 65 | apps, err := c.client.GetApplications() 66 | if nil != err { 67 | return nil, perr.WrapSystemErrorf(err, "faield to query eureka") 68 | } 69 | 70 | var instances []*InstanceInfo 71 | for _, app := range apps.Applications { 72 | // 服务名 73 | servName := app.Name 74 | 75 | // 遍历每一个实例 76 | for _, ins := range app.Instances { 77 | // 跳过无效实例 78 | if nil == ins.Port || ins.Status != "UP" { 79 | continue 80 | } 81 | 82 | addr := ins.HostName + ":" + strconv.Itoa(ins.Port.Port) 83 | var meta map[string]string 84 | if nil != ins.Metadata { 85 | meta = 
ins.Metadata.Map 86 | } 87 | 88 | instances = append( 89 | instances, 90 | &InstanceInfo{ 91 | ServiceName: servName, 92 | Addr: addr, 93 | Meta: meta, 94 | }, 95 | ) 96 | } 97 | } 98 | 99 | return instances, nil 100 | } 101 | // Register registers this gateway with eureka under conf.App.ServerConfig.AppName and starts a goroutine that sends a heartbeat every HeartbeatInterval seconds. 102 | func (c *EurekaClient) Register() error { 103 | ip, err := utils.GetFirstNoneLoopIp() 104 | if nil != err { 105 | return perr.WrapSystemErrorf(err, "failed to get first none loop ip") 106 | } 107 | 108 | 109 | instanceId = ip + ":" + strconv.Itoa(conf.App.ServerConfig.Port) 110 | 111 | // register this instance (id = ip:port) with eureka 112 | Log.Infof("register to eureka as %s", instanceId) 113 | gogateApp = eureka.NewInstanceInfo( 114 | instanceId, 115 | conf.App.ServerConfig.AppName, 116 | ip, 117 | conf.App.ServerConfig.Port, 118 | conf.App.EurekaConfig.EvictionDuration, 119 | false, 120 | ) 121 | gogateApp.Metadata = &eureka.MetaData{ 122 | Class: "", 123 | Map: map[string]string {"version": conf.App.Version}, 124 | } 125 | 126 | err = c.client.RegisterInstance(conf.App.ServerConfig.AppName, gogateApp) 127 | if nil != err { 128 | return perr.WrapSystemErrorf(err, "failed to register to eureka") 129 | } 130 | 131 | // heartbeat: the ticker and stop channel are created before the goroutine starts and handed to it, so stopHeartbeat never observes them unset 132 | ticker = time.NewTicker(time.Second * time.Duration(conf.App.EurekaConfig.HeartbeatInterval)) 133 | tickerCloseChan = make(chan struct{}) 134 | go func(t *time.Ticker, done <-chan struct{}) { 135 | 136 | for { 137 | select { 138 | case <-t.C: 139 | c.heartbeat() 140 | 141 | case <-done: 142 | Log.Info("heartbeat stopped") 143 | return 144 | 145 | } 146 | } 147 | }(ticker, tickerCloseChan) 148 | 149 | return nil 150 | } 151 | // UnRegister stops the heartbeat and removes this instance from eureka. 152 | func (c *EurekaClient) UnRegister() error { 153 | c.stopHeartbeat() 154 | 155 | Log.Infof("unregistering %s", instanceId) 156 | err := c.client.UnregisterInstance(conf.App.ServerConfig.AppName, instanceId) // same app name Register used 157 | 158 | if nil != err { 159 | return perr.WrapSystemErrorf(err, "failed to unregister") 160 | } 161 | 162 | Log.Info("done unregistration") 163 | return nil 164 | } 165 | // stopHeartbeat stops the heartbeat goroutine started by Register; it is a no-op when no heartbeat is running, so it is safe if Register was never called or on repeated calls. 166 | func (c *EurekaClient) stopHeartbeat() { 167 | if nil != ticker { ticker.Stop(); ticker = nil } 168 | if nil != tickerCloseChan { close(tickerCloseChan); tickerCloseChan = nil } 169 | } 170 | // heartbeat sends one heartbeat for this instance to eureka; a failure is logged, not returned. 171 | 
func (c *EurekaClient) heartbeat() { 172 | err := c.client.SendHeartbeat(gogateApp.App, instanceId) 173 | if nil != err { 174 | Log.Warnf("failed to send heartbeat, %v", err) 175 | return 176 | } 177 | 178 | Log.Info("heartbeat sent") 179 | } 180 | 181 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![build](https://api.travis-ci.org/wanghongfei/gogate.svg?branch=master) 2 | 3 | # GoGate 4 | 5 | Go语言实现的Spring Cloud网关,目标是性能,即使用更少的资源达到更高的QPS。 6 | 7 | GoGate使用以高性能著称的`FastHttp`库收发HTTP请求,且会为每个host单独创建一个`HostClient`以减少锁竞争。 8 | 9 | 10 | 11 | 目前已经实现的功能有: 12 | 13 | - 基于Eureka(或Consul)服务发现的请求路由 14 | - 基于配置文件的请求路由 15 | - 请求路由、路由配置热更新 16 | - 负载均衡 17 | - 灰度发布(基于Eureka meta信息里的version字段分配流量) 18 | - 微服务粒度的QPS控制(有基于内存的令牌桶算法限流和Redis + Lua限流两种可选) 19 | - 微服务粒度的流量统计(暂时实现为记录日志到/tmp目录下) 20 | 21 | 初步测试了一下性能,结论如下: 22 | 23 | 相同的硬件环境、Zuul充分预热且关闭Hystrix的前提下,Go版的网关QPS为Zuul的2.3倍,同时内存占用仅为Zuul的十分之一(600M vs 50M)。而且Go基本上第一波请求就能达到最大QPS, zuul要预热几次才会稳定。 24 | 25 | 如果按消耗相同资源的前提下算的话,go一定要比zuul节省多的多的多的机器。同时,Gogate可以做到秒级启动, SpringBoot要等半天。 26 | 27 | 28 | 29 | ## 什么情况下可以考虑使用非Java语言的网关 30 | 31 | - 系统使用Spring Cloud全家桶 32 | - 对Zuul 1性能不满意 33 | - 对Cloud官方已经明确不会整合Zuul 2的行为不爽 34 | - 认为Spring Cloud Gateway不够成熟(相比Zuul 2.0) 35 | - 对网关的CPU/内存资源使用非常敏感 36 | 37 | 那么就可以考虑!
38 | 39 | 40 | 41 | ## 流程 42 | 43 | ![arc](https://s2.ax1x.com/2019/06/28/ZMJXDI.png) 44 | 45 | 服务路由`service-match-pre-filter`: 根据URL匹配后端微服务 46 | 47 | 流量控制`rate-limit-pre-filter`: 令牌桶算法控制qps 48 | 49 | URL重写`url-rewrite-pre-filter`: 调整向后端服务发请求的URL 50 | 51 | 转发请求: 负载均衡、按比例分配流量 52 | 53 | 54 | 55 | gogate没有提供默认的Post Filter,可根据需要自己实现相应函数。 56 | 57 | 58 | 59 | ## 构建 60 | 61 | go版本>=1.13。 62 | 63 | 在$GOPATH之外的任意目录下clone项目: 64 | 65 | ```shell 66 | git clone https://github.com/wanghongfei/gogate 67 | ``` 68 | 69 | 安装依赖: 70 | 71 | ```shell 72 | go mod tidy 73 | ``` 74 | 75 | 最后构建: 76 | 77 | ```shell 78 | go build 79 | ``` 80 | 81 | 82 | 83 | 84 | 85 | ## 使用 86 | 87 | 可以编译`main.go`直接生成可执行文件,也可以当一个库来使用。 88 | 89 | 可以在转发请求之前和之后添加自定义Filter来添加自定义逻辑。 90 | 91 | 详见`examples/usage.go` 92 | 93 | 94 | 95 | ## 路由配置 96 | 97 | 路由匹配规则: 98 | 99 | - 当`id`不为空时,会使用eureka的注册信息查询此服务的地址 100 | - 当`host`不为空时, 会优先使用此字段指定的服务地址, 多个地址用逗号分隔 101 | - 使用`Trie Tree`进行前缀匹配。 102 | 103 | 当路由配置文件发生变动时,访问 104 | 105 | ``` 106 | GET /_mgr/reload 107 | ``` 108 | 109 | 即可应用新配置。 110 | 111 | 112 | 113 | 示例配置: 114 | 115 | ```yaml 116 | services: 117 | user-service: 118 | # eureka中的服务名 119 | id: user-service 120 | # 以/user开头的请求, 会被转发到user-service服务中 121 | prefix: /user 122 | # 转发时是否去掉请求前缀, 即/user 123 | strip-prefix: true 124 | # 灰度配置 125 | canary: 126 | - 127 | # 对应eureka注册信息中元数据(metadata map)中key=version的值 128 | meta: "1.0" 129 | # 流量比重 130 | weight: 3 131 | - 132 | meta: "2.0" 133 | weight: 4 134 | - 135 | # 对应没有metadata的服务 136 | meta: "" 137 | weight: 1 138 | 139 | trends-service: 140 | id: trends-service 141 | prefix: /trends 142 | strip-prefix: false 143 | # 设置qps限制, 每秒最多请求数 144 | qps: 1 145 | 146 | order-service: 147 | id: order-service 148 | prefix: /order 149 | strip-prefix: false 150 | 151 | img-service: 152 | # 如果有host, 则不查注册中心直接使用此地址, 多个地址逗号分隔 153 | host: localhost:4444,localhost:5555 154 | prefix: /img 155 | strip-prefix: false 156 | 157 | log: 158 | console-only: true 159 | directory: "logs" 160 | 
file-pattern: "logs/gogate.log.%Y-%m-%d" 161 | file-link: "logs/gogate.log" 162 | ``` 163 | 164 | 165 | 166 | ## 自定义过滤器 167 | 168 | 前置filter和后置filter都可以在任意位置添加自定义过滤器以实现定制化的功能。 169 | 170 | 171 | 172 | - 前置过滤器 173 | 174 | 函数签名为: 175 | 176 | ```go 177 | type PreFilterFunc func(server *Server, ctx *fasthttp.RequestCtx, newRequest *fasthttp.Request) bool 178 | ``` 179 | 180 | `server`: gogate server对象的指针 181 | 182 | `ctx`: 请求上下文对象指针 183 | 184 | `newRequest`: 要转发给下游微服务的请求对象的指针,可以对相关参数进行修改,如header, body, method等 185 | 186 | 返回`true`时gogate会继续触发下一个过滤器,返回`false`则表示请求到此为止, 不会执行后续过滤器,也不会转发请求。 187 | 188 | 189 | 190 | - 后置过滤器 191 | 192 | 函数签名为: 193 | 194 | ```go 195 | type PostFilterFunc func(req *fasthttp.Request, resp *fasthttp.Response) bool 196 | ``` 197 | 198 | `req`: 已经转发给微服务的请求对象指针 199 | 200 | `resp`: 微服务返回的响应对象指针, 可进行修改 201 | 202 | 返回`true`时gogate会继续触发下一个过滤器,返回`false`则不会再执行后续的后置过滤器(此时请求已经转发给微服务)。 203 | 204 | 205 | 206 | - 添加过滤器 207 | 208 | `Server.AppendPreFilter`: 在末尾追加前置过滤器 209 | 210 | `Server.AppendPostFilter`: 在末尾追加后置过滤器 211 | 212 | `Server.InsertPreFilter`: 在指定过滤器的后面插入前置过滤器 213 | 214 | `Server.InsertPostFilter`: 在指定过滤器的后面插入后置过滤器 215 | 216 | `Server.InsertPreFilterAhead`: 插入前置过滤器到最头部 217 | 218 | `Server.InsertPostFilterAhead`: 插入后置过滤器到最头部 219 | 220 | ## Eureka配置 221 | 222 | `eureka.json`文件 223 | 224 | 225 | 226 | ## gogate配置 227 | 228 | `gogate.yml`文件: 229 | 230 | ```yaml 231 | version: 1.0 232 | 233 | server: 234 | # 向eureka注册自己时使用的服务名 235 | appName: gogate 236 | host: 0.0.0.0 237 | port: 8080 238 | # gateway最大连接数 239 | maxConnection: 1000 240 | # gateway请求后端服务超时时间, 毫秒 241 | timeout: 3000 242 | 243 | # 如果eureka, consul都没有启动, 则进入静态模式, 不访问注册中心 244 | eureka: 245 | enable: false 246 | # eureka配置文件名 247 | configFile: eureka.json 248 | # 路由配置文件名 249 | routeFile: route.yml 250 | # eureka剔除服务的最大时间限值, 秒 251 | evictionDuration: 30 252 | # 心跳间隔, 秒 253 | heartbeatInterval: 20 254 | 255 | consul: 256 | enable: true 257 | # agent地址 258 | address: 127.0.0.1:8500 259 | 260 | 261 | 
traffic: 262 | # 是否开启流量记录功能 263 | enableTrafficRecord: true 264 | # 流量日志文件所在目录 265 | trafficLogDir: /tmp 266 | 267 | redis: 268 | # 是否使用redis做限速器 269 | enabled: false 270 | # 目前只支持单实例, 不支持cluster 271 | addr: 127.0.0.1:6379 272 | # 限速器lua代码文件 273 | rateLimiterLua: lua/rate_limiter.lua 274 | ``` 275 | 276 | 277 | 278 | ## 限流器 279 | 280 | gogate有两个限流器实现, `MemoryRateLimiter`和`RedisRateLimiter`,通过`gogate.yml`配置文件里的`redis.enabled`控制。前者使用令牌桶算法实现,适用于单实例部署的场景;后者基于 Redis + Lua 实现,适用于多实例部署。但有一个限制是目前Redis只支持连接单个实例,不支持cluster。 281 | 282 | 283 | 284 | ## 流量日志 285 | 286 | gogate会记录过去`1s`内各个微服务的请求数据,包括成功请求数和失败请求数,然后写入`/tmp/{service-id}_yyyyMMdd.log`文件中: 287 | 288 | ``` 289 | 1527580599228,2,1,user-service 290 | 1527580600230,4,1,user-service 291 | 1527580601228,1,1,user-service 292 | ``` 293 | 294 | 即`毫秒时间戳,成功请求数,失败请求数,服务名`。 295 | 296 | 如果在过去的1s内没有请求, 则不会向日志中写入任何数据。 297 | 298 | -------------------------------------------------------------------------------- /server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/valyala/fasthttp" 5 | "github.com/wanghongfei/gogate/conf" 6 | . 
"github.com/wanghongfei/gogate/conf" 7 | "github.com/wanghongfei/gogate/discovery" 8 | "github.com/wanghongfei/gogate/perr" 9 | "github.com/wanghongfei/gogate/redis" 10 | "github.com/wanghongfei/gogate/server/lb" 11 | "github.com/wanghongfei/gogate/server/route" 12 | "github.com/wanghongfei/gogate/server/statistics" 13 | "github.com/wanghongfei/gogate/throttle" 14 | "net" 15 | "os" 16 | "strconv" 17 | "time" 18 | ) 19 | 20 | type Server struct { 21 | host string 22 | port int 23 | 24 | // 负载均衡组件 25 | lb lb.LoadBalancer 26 | 27 | //// 保存listener引用, 用于关闭server 28 | //listen net.Listener 29 | //// 是否启用优雅关闭 30 | //graceShutdown bool 31 | //// 优雅关闭最大等待时间 32 | //maxWait time.Duration 33 | //wg *sync.WaitGroup 34 | 35 | // URI路由组件 36 | Router *route.Router 37 | 38 | // 过滤器 39 | preFilters []*PreFilter 40 | postFilters []*PostFilter 41 | 42 | // fasthttp对象 43 | fastServ *fasthttp.Server 44 | fastClient *fasthttp.Client 45 | 46 | isStarted bool 47 | 48 | discoveryClient discovery.Client 49 | 50 | // 服务id(string) -> 此服务的限速器对象(*MemoryRateLimiter) 51 | rateLimiterMap *RateLimiterSyncMap 52 | 53 | trafficStat *stat.TraficStat 54 | } 55 | 56 | const ( 57 | // 默认最大连接数 58 | MAX_CONNECTION = 5000 59 | ) 60 | 61 | /* 62 | * 创建网关服务对象 63 | * 64 | * PARAMS: 65 | * - host: 主机名(ip) 66 | * - port: 端口 67 | * - routePath: 路由配置文件路径 68 | * - maxConn: 最大连接数, 0表示使用默认值 69 | * 70 | */ 71 | func NewGatewayServer(host string, port int, routePath string, maxConn int) (*Server, error) { 72 | if "" == host { 73 | return nil, perr.WrapSystemErrorf(nil, "invalid host %s", host) 74 | } 75 | 76 | if port <= 0 || port > 65535 { 77 | return nil, perr.WrapSystemErrorf(nil, "invalid port %d", port) 78 | } 79 | 80 | if maxConn <= 0 { 81 | maxConn = MAX_CONNECTION 82 | } 83 | 84 | // 创建router 85 | router, err := route.NewRouter(routePath) 86 | if nil != err { 87 | return nil, perr.WrapSystemErrorf(err, "failed to create router") 88 | } 89 | 90 | // 创建Server对象 91 | serv := &Server{ 92 | host: host, 93 | port: 
port, 94 | 95 | lb: &lb.RoundRobinLoadBalancer{}, 96 | 97 | Router: router, 98 | 99 | preFilters: make([]*PreFilter, 0, 3), 100 | postFilters: make([]*PostFilter, 0, 3), 101 | 102 | //graceShutdown: useGracefullyShutdown, 103 | //maxWait: maxWait, 104 | } 105 | 106 | // 创建FastServer对象 107 | fastServ := &fasthttp.Server{ 108 | Concurrency: maxConn, 109 | Handler: serv.HandleRequest, 110 | LogAllErrors: true, 111 | } 112 | serv.fastServ = fastServ 113 | 114 | // 创建http client 115 | serv.fastClient = &fasthttp.Client{ 116 | MaxConnsPerHost: maxConn, 117 | ReadTimeout: time.Duration(conf.App.ServerConfig.Timeout) * time.Millisecond, 118 | WriteTimeout: time.Duration(conf.App.ServerConfig.Timeout) * time.Millisecond, 119 | } 120 | 121 | // 创建每个服务的限速器 122 | serv.rebuildRateLimiter() 123 | 124 | // 注册过虑器 125 | serv.AppendPreFilter(NewPreFilter("service-match-pre-filter", ServiceMatchPreFilter)) 126 | serv.InsertPreFilterBehind("service-match-pre-filter", NewPreFilter("rate-limit-pre-filter", RateLimitPreFilter)) 127 | serv.InsertPreFilterBehind("rate-limit-pre-filter", NewPreFilter("url-rewrite-pre-filter", UrlRewritePreFilter)) 128 | 129 | return serv, nil 130 | 131 | } 132 | 133 | // 启动服务器 134 | func (serv *Server) Start() error { 135 | if conf.App.Traffic.EnableTrafficRecord { 136 | serv.trafficStat = stat.NewTrafficStat(1000, 1, stat.NewCsvFileTraficInfoStore(conf.App.Traffic.TrafficLogDir)) 137 | serv.trafficStat.StartRecordTrafic() 138 | } 139 | 140 | serv.isStarted = true 141 | 142 | // 监听端口 143 | listen, err := net.Listen("tcp", serv.host + ":" + strconv.Itoa(serv.port)) 144 | if nil != err { 145 | return perr.WrapSystemErrorf(nil, "failed to listen at %s:%d => %w", serv.host, serv.port, err) 146 | } 147 | 148 | // 是否启用优雅关闭功能 149 | //if serv.graceShutdown { 150 | // serv.wg = new(sync.WaitGroup) 151 | //} 152 | 153 | // 保存Listener指针 154 | //serv.listen = listen 155 | 156 | bothEnabled := conf.App.EurekaConfig.Enable && conf.App.ConsulConfig.Enable 157 | if 
bothEnabled { 158 | return perr.WrapSystemErrorf(nil, "eureka and consul are both enabled") 159 | } 160 | 161 | // 初始化服务注册模块 162 | if conf.App.EurekaConfig.Enable { 163 | Log.Info("eureka enabled") 164 | serv.discoveryClient, err = discovery.NewEurekaClient(conf.App.EurekaConfig.ConfigFile) 165 | if nil != err { 166 | return err 167 | } 168 | 169 | // 注册自己, 启动心跳 170 | // discovery.StartRegister() 171 | 172 | } else if conf.App.ConsulConfig.Enable { 173 | Log.Info("consul enabled") 174 | // 初始化consul 175 | serv.discoveryClient, err = discovery.NewConsulClient() 176 | if nil != err { 177 | return err 178 | } 179 | 180 | } else { 181 | Log.Infof("no register center enabled, use static mode") 182 | serv.discoveryClient = discovery.DoNothingClient 183 | } 184 | 185 | // 启动注册表定时更新 186 | err = serv.discoveryClient.StartPeriodicalRefresh() 187 | if nil != err { 188 | return perr.WrapSystemErrorf(err, "failed to start discovery module") 189 | } 190 | 191 | // 启动http server 192 | Log.Infof("start Gogate at %s:%d, pid: %d", serv.host, serv.port, os.Getpid()) 193 | return serv.fastServ.Serve(listen) 194 | } 195 | 196 | // 关闭server 197 | func (serv *Server) Shutdown() error { 198 | serv.isStarted = false 199 | serv.discoveryClient.UnRegister() 200 | 201 | err := serv.fastServ.Shutdown() 202 | if nil != err { 203 | return perr.WrapSystemErrorf(err, "failed to shutdown server") 204 | } 205 | 206 | return nil 207 | } 208 | 209 | // 更新路由配置文件 210 | func (serv *Server) ReloadRoute() error { 211 | Log.Info("start reloading route info") 212 | err := serv.Router.ReloadRoute() 213 | serv.rebuildRateLimiter() 214 | Log.Info("route info reloaded") 215 | 216 | if nil != err { 217 | return perr.WrapSystemErrorf(err, "failed to reload route") 218 | } 219 | 220 | return nil 221 | } 222 | 223 | func (serv *Server) IsInStaticMode() bool { 224 | return serv.discoveryClient == discovery.DoNothingClient 225 | } 226 | 227 | func (serv *Server) recordTraffic(servName string, success bool) { 228 | if 
nil != serv.trafficStat { 229 | Log.Debug("log traffic for %s", servName) 230 | 231 | info := &stat.TraficInfo{ 232 | ServiceId: servName, 233 | } 234 | if success { 235 | info.SuccessCount = 1 236 | } else { 237 | info.FailedCount = 1 238 | } 239 | 240 | serv.trafficStat.RecordTrafic(info) 241 | } 242 | 243 | } 244 | 245 | // 给路由表中的每个服务重新创建限速器; 246 | // 在更新过route.yml配置文件时调用 247 | func (serv *Server) rebuildRateLimiter() { 248 | serv.rateLimiterMap = NewRateLimiterSyncMap() 249 | 250 | // 创建每个服务的限速器 251 | for _, info := range serv.Router.ServInfos { 252 | if 0 == info.Qps { 253 | continue 254 | } 255 | 256 | rl := serv.createRateLimiter(info) 257 | if nil != rl { 258 | serv.rateLimiterMap.Put(info.Id, rl) 259 | Log.Debugf("done building rateLimiter for %s", info.Id) 260 | } 261 | } 262 | } 263 | 264 | // 创建限速器对象 265 | // 如果配置文件中设置了使用redis, 则创建RedisRateLimiter, 否则创建MemoryRateLimiter 266 | func (serv *Server) createRateLimiter(info *route.ServiceInfo) throttle.RateLimiter { 267 | enableRedis := conf.App.RedisConfig.Enabled 268 | if !enableRedis { 269 | return throttle.NewMemoryRateLimiter(info.Qps) 270 | } 271 | 272 | client := redis.NewRedisClient(conf.App.RedisConfig.Addr, 5) 273 | err := client.Connect() 274 | if nil != err { 275 | Log.Warn("failed to create ratelimiter, err = %v", err) 276 | return nil 277 | } 278 | 279 | rl, err := throttle.NewRedisRateLimiter(client, conf.App.RedisConfig.RateLimiterLua, info.Qps, info.Id) 280 | if nil != err { 281 | Log.Warn("failed to create ratelimiter, err = %v", err) 282 | return nil 283 | } 284 | 285 | return rl 286 | } 287 | --------------------------------------------------------------------------------