├── .gitignore ├── Makefile ├── README.md ├── code ├── balancer │ ├── Makefile │ ├── balancer_test.go │ ├── client.go │ ├── proto │ │ ├── balancer.pb.go │ │ └── balancer.proto │ ├── server.go │ └── zrpc │ │ ├── client.go │ │ ├── config.go │ │ ├── internal │ │ ├── auth │ │ │ ├── auth.go │ │ │ ├── credential.go │ │ │ └── vars.go │ │ ├── chainclientinterceptors.go │ │ ├── chainserverinterceptors.go │ │ ├── client.go │ │ ├── clientinterceptors │ │ │ ├── breakerinterceptor.go │ │ │ ├── durationinterceptor.go │ │ │ ├── prometheusinterceptor.go │ │ │ ├── timeoutinterceptor.go │ │ │ └── tracinginterceptor.go │ │ ├── codes │ │ │ └── accept.go │ │ ├── rpclogger.go │ │ ├── rpcpubserver.go │ │ ├── rpcserver.go │ │ ├── server.go │ │ └── serverinterceptors │ │ │ ├── authinterceptor.go │ │ │ ├── breakerinterceptor.go │ │ │ ├── crashinterceptor.go │ │ │ ├── prometheusinterceptor.go │ │ │ ├── sheddinginterceptor.go │ │ │ ├── statinterceptor.go │ │ │ ├── timeoutinterceptor.go │ │ │ └── tracinginterceptor.go │ │ ├── p2c │ │ └── p2c.go │ │ ├── resolver │ │ ├── internal │ │ │ ├── directbuilder.go │ │ │ ├── discovbuilder.go │ │ │ ├── etcdbuilder.go │ │ │ ├── kube │ │ │ │ ├── eventhandler.go │ │ │ │ └── targetparser.go │ │ │ ├── kubebuilder.go │ │ │ ├── resolver.go │ │ │ └── subset.go │ │ ├── register.go │ │ └── target.go │ │ └── server.go ├── breaker │ ├── breaker_test.go │ ├── googlebreaker_test.go │ └── rollingwindow_test.go ├── core │ ├── breaker │ │ ├── breaker.go │ │ └── googlebreaker.go │ ├── collection │ │ └── rollingwindow.go │ ├── limit │ │ ├── periodlimit.go │ │ └── tokenlimit.go │ ├── load │ │ ├── adaptiveshedder.go │ │ └── nopshedder.go │ ├── stat │ │ ├── internal │ │ │ ├── cgroup_linux.go │ │ │ ├── cpu_linux.go │ │ │ └── cpu_other.go │ │ └── usage.go │ └── syncx │ │ ├── atomicbool.go │ │ ├── atomicduration.go │ │ ├── limit.go │ │ └── spinlock.go ├── limit │ ├── limit_test.go │ ├── periodlimit_test.go │ └── tokenlimit_test.go ├── rest │ ├── rest │ │ ├── config.go │ │ ├── 
engine.go │ │ ├── handler │ │ │ ├── authhandler.go │ │ │ ├── breakerhandler.go │ │ │ ├── contentsecurityhandler.go │ │ │ ├── cryptionhandler.go │ │ │ ├── gunziphandler.go │ │ │ ├── loghandler.go │ │ │ ├── maxbyteshandler.go │ │ │ ├── maxconnshandler.go │ │ │ ├── metrichandler.go │ │ │ ├── prometheushandler.go │ │ │ ├── recoverhandler.go │ │ │ ├── sheddinghandler.go │ │ │ ├── timeouthandler.go │ │ │ └── tracinghandler.go │ │ ├── httpx │ │ │ ├── requests.go │ │ │ ├── responses.go │ │ │ ├── router.go │ │ │ ├── util.go │ │ │ └── vars.go │ │ ├── internal │ │ │ ├── cors │ │ │ │ └── handlers.go │ │ │ ├── log.go │ │ │ ├── response │ │ │ │ ├── headeronceresponsewriter.go │ │ │ │ └── withcoderesponsewriter.go │ │ │ ├── security │ │ │ │ └── contentsecurity.go │ │ │ └── starter.go │ │ ├── pathvar │ │ │ └── params.go │ │ ├── router │ │ │ └── patrouter.go │ │ ├── server.go │ │ ├── token │ │ │ └── tokenparser.go │ │ └── types.go │ ├── rest_server.go │ └── rest_test.go └── shedding │ └── shedding_test.go ├── go.mod ├── go.sum └── images └── go-zero-rest-start.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | .DS_Store 8 | *.DS_Store 9 | .log 10 | *.log 11 | 12 | # Test binary, build with `go test -c` 13 | *.test 14 | 15 | # Output of the go coverage tool, specifically when used with LiteIDE 16 | *.out 17 | 18 | # ide cache files 19 | .idea 20 | # .vscode 21 | 22 | # go mod vendor files 23 | /vendor/ 24 | 25 | # go build files 26 | main 27 | 28 | # tools 29 | 30 | # docker 31 | /data/ 32 | */tmp -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: build clean tool lint help 2 | 3 | all: build 4 | 5 | # analyse 6 | vet: 7 | @go vet ./code/...; true 8 | 9 | # format 10 | # go install mvdan.cc/gofumpt@latest 11 | 
fmt: 12 | @gofumpt -l -w ./code/; true 13 | 14 | clean: 15 | @go clean -i . 16 | 17 | help: 18 | @echo "make vet: run specified go vet" 19 | @echo "make fmt: gofumpt -l -w ./code/" 20 | @echo "make clean: remove object files and cached files" -------------------------------------------------------------------------------- /code/balancer/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all proto-gen 2 | 3 | all: proto-gen 4 | 5 | # proto gen 6 | proto-gen: 7 | @cd ./proto && protoc --go_out=plugins=grpc:. balancer.proto -------------------------------------------------------------------------------- /code/balancer/balancer_test.go: -------------------------------------------------------------------------------- 1 | package balancer_test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "gozerosource/code/balancer" 8 | ) 9 | 10 | // TestP2C runs a balancer server and client side by side for a fixed period 11 | // so their request logs can be observed, then returns. 12 | func TestP2C(t *testing.T) { 13 | const seconds = 10 14 | 15 | go balancer.NewServer() 16 | go balancer.NewClient() 17 | 18 | // Block until the period elapses. The previous version spun on a select 19 | // with a default branch, which kept a CPU core busy for the whole test. 20 | time.Sleep(time.Second * seconds) 21 | } 22 | -------------------------------------------------------------------------------- /code/balancer/client.go: -------------------------------------------------------------------------------- 1 | package balancer 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "time" 8 | 9 | "gozerosource/code/balancer/proto" 10 | "gozerosource/code/balancer/zrpc" 11 | 12 | "github.com/zeromicro/go-zero/core/discov" 13 | ) 14 | 15 | const timeFormat = "15:04:05" // clock layout for the timestamps printed by client and server 16 | 17 | // NewClient connects to the "balancer.rpc" service discovered through etcd at 18 | // 127.0.0.1:2379 and calls Hello once per second. It loops forever. 19 | func NewClient() { 20 | flag.Parse() 21 | 22 | c := zrpc.RpcClientConf{ 23 | Etcd: discov.EtcdConf{ 24 | Hosts: []string{"127.0.0.1:2379"}, 25 | Key: "balancer.rpc", 26 | }, 27 | } 28 | 29 | client := zrpc.MustNewClient(c)
30 | ticker := time.NewTicker(time.Second) 31 | defer ticker.Stop() 32 | for { 33 | <-ticker.C 34 | fmt.Println("") 35 | fmt.Println("---------------------------------------------") 36 | fmt.Println("") 37 | conn := client.Conn() 38 | balancer := proto.NewBalancerClient(conn) 39 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 40 | resp, err := balancer.Hello(ctx, &proto.Request{ 41 | Msg: "I'm balancer client", 42 | }) 43 | if err != nil { 44 | fmt.Printf("warning ⛔ %s X %s\n", time.Now().Format(timeFormat), err.Error()) 45 | } else { 46 | fmt.Println("") 47 | fmt.Printf("I'm client 👉 %s => %s\n\n", time.Now().Format(timeFormat), resp.Data) 48 | } 49 | cancel() // released per iteration; a defer would never run because the loop never exits 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /code/balancer/proto/balancer.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package proto; 4 | 5 | option go_package = "./;proto"; 6 | 7 | message Request { 8 | string msg = 1; 9 | } 10 | 11 | message Response { 12 | string data = 1; 13 | } 14 | 15 | service Balancer { 16 | rpc Hello(Request) returns (Response); 17 | } 18 | -------------------------------------------------------------------------------- /code/balancer/server.go: -------------------------------------------------------------------------------- 1 | package balancer 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "os" 8 | "sync" 9 | "time" 10 | 11 | "gozerosource/code/balancer/proto" 12 | "gozerosource/code/balancer/zrpc" 13 | 14 | "github.com/zeromicro/go-zero/core/discov" 15 | "github.com/zeromicro/go-zero/core/logx" 16 | "github.com/zeromicro/go-zero/core/service" 17 | "google.golang.org/grpc" 18 | ) 19 | 20 | // BalancerServer implements the Hello RPC of the Balancer service. 21 | // NOTE(review): lock, alive and downTime are not read anywhere in this file. 22 | type BalancerServer struct { 23 | lock sync.Mutex 24 | alive bool 25 | downTime time.Time 26 | } 27 | 28 | // NewBalancerServer returns a BalancerServer with alive set to true. 29 | func NewBalancerServer() *BalancerServer { 30 | return &BalancerServer{ 31 | alive: true, 32 | } 33 | } 34 | 35 | // Hello prints the incoming request and replies with this host's name. 36 | func (gs
*BalancerServer) Hello(ctx context.Context, req *proto.Request) (*proto.Response, error) { 37 | fmt.Printf("I'm balancer server 👉 %s => %s\n\n", time.Now().Format(timeFormat), req) 38 | 39 | hostname, err := os.Hostname() 40 | if err != nil { 41 | return nil, err 42 | } 43 | 44 | return &proto.Response{ 45 | Data: "hello from " + hostname, 46 | }, nil 47 | } 48 | 49 | // NewServer builds the balancer RPC server on 127.0.0.1:3456 with etcd key "balancer.rpc" 50 | // (the key NewClient resolves), adds a logging interceptor and calls Start. 51 | func NewServer() { 52 | c := zrpc.RpcServerConf{ 53 | ServiceConf: service.ServiceConf{ 54 | Name: "rpc.balancer", 55 | Log: logx.LogConf{ 56 | Mode: "console", 57 | }, 58 | }, 59 | ListenOn: "127.0.0.1:3456", 60 | Etcd: discov.EtcdConf{ 61 | Hosts: []string{"127.0.0.1:2379"}, 62 | Key: "balancer.rpc", 63 | }, 64 | } 65 | 66 | server := zrpc.MustNewServer(c, func(grpcServer *grpc.Server) { 67 | proto.RegisterBalancerServer(grpcServer, NewBalancerServer()) 68 | }) 69 | 70 | // Unary interceptor that logs each call's full method name and elapsed time. 71 | interceptor := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { 72 | st := time.Now() 73 | resp, err = handler(ctx, req) 74 | log.Printf("👋 method: %s time: %v\n\n", info.FullMethod, time.Since(st)) 75 | return resp, err 76 | } 77 | 78 | server.AddUnaryInterceptors(interceptor) 79 | server.Start() 80 | } 81 | -------------------------------------------------------------------------------- /code/balancer/zrpc/client.go: -------------------------------------------------------------------------------- 1 | package zrpc 2 | 3 | import ( 4 | "log" 5 | "time" 6 | 7 | "gozerosource/code/balancer/zrpc/internal" 8 | "gozerosource/code/balancer/zrpc/internal/auth" 9 | "gozerosource/code/balancer/zrpc/internal/clientinterceptors" 10 | 11 | "google.golang.org/grpc" 12 | ) 13 | 14 | var ( 15 | // WithDialOption is an alias of internal.WithDialOption. 16 | WithDialOption = internal.WithDialOption 17 | // WithNonBlock sets the dialing to be nonblock. 18 | WithNonBlock = internal.WithNonBlock 19 | // WithTimeout is an alias of internal.WithTimeout.
20 | WithTimeout = internal.WithTimeout 21 | // WithTransportCredentials returns a func to make the gRPC calls secured with given credentials. 22 | WithTransportCredentials = internal.WithTransportCredentials 23 | // WithUnaryClientInterceptor is an alias of internal.WithUnaryClientInterceptor. 24 | WithUnaryClientInterceptor = internal.WithUnaryClientInterceptor 25 | ) 26 | 27 | type ( 28 | // Client is an alias of internal.Client. 29 | Client = internal.Client 30 | // ClientOption is an alias of internal.ClientOption. 31 | ClientOption = internal.ClientOption 32 | 33 | // A RpcClient is a rpc client that wraps an internal Client. 34 | RpcClient struct { 35 | client Client 36 | } 37 | ) 38 | 39 | // MustNewClient returns a Client, exiting via log.Fatal on any error. 40 | func MustNewClient(c RpcClientConf, options ...ClientOption) Client { 41 | cli, err := NewClient(c, options...) 42 | if err != nil { 43 | log.Fatal(err) 44 | } 45 | 46 | return cli 47 | } 48 | 49 | // NewClient returns a Client built from c; caller options are applied after those derived from c. 50 | func NewClient(c RpcClientConf, options ...ClientOption) (Client, error) { 51 | var opts []ClientOption 52 | if c.HasCredential() { 53 | opts = append(opts, WithDialOption(grpc.WithPerRPCCredentials(&auth.Credential{ 54 | App: c.App, 55 | Token: c.Token, 56 | }))) 57 | } 58 | if c.NonBlock { 59 | opts = append(opts, WithNonBlock()) 60 | } 61 | if c.Timeout > 0 { 62 | opts = append(opts, WithTimeout(time.Duration(c.Timeout)*time.Millisecond)) 63 | } 64 | 65 | opts = append(opts, options...) 66 | 67 | target, err := c.BuildTarget() 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | client, err := internal.NewClient(target, opts...) 73 | if err != nil { 74 | return nil, err 75 | } 76 | 77 | return &RpcClient{ 78 | client: client, 79 | }, nil 80 | } 81 | 82 | // NewClientWithTarget returns a Client dialed to the given target; no RpcClientConf settings are applied. 83 | func NewClientWithTarget(target string, opts ...ClientOption) (Client, error) { 84 | return internal.NewClient(target, opts...)
85 | } 86 | 87 | // Conn returns the underlying grpc.ClientConn. 88 | func (rc *RpcClient) Conn() *grpc.ClientConn { 89 | return rc.client.Conn() 90 | } 91 | 92 | // SetClientSlowThreshold sets the slow threshold on client side. 93 | func SetClientSlowThreshold(threshold time.Duration) { 94 | clientinterceptors.SetSlowThreshold(threshold) 95 | } 96 | -------------------------------------------------------------------------------- /code/balancer/zrpc/config.go: -------------------------------------------------------------------------------- 1 | package zrpc 2 | 3 | import ( 4 | "gozerosource/code/balancer/zrpc/resolver" 5 | 6 | "github.com/zeromicro/go-zero/core/discov" 7 | "github.com/zeromicro/go-zero/core/service" 8 | "github.com/zeromicro/go-zero/core/stores/redis" 9 | ) 10 | 11 | type ( 12 | // A RpcServerConf is a rpc server config. 13 | RpcServerConf struct { 14 | service.ServiceConf 15 | ListenOn string 16 | Etcd discov.EtcdConf `json:",optional"` 17 | Auth bool `json:",optional"` 18 | Redis redis.RedisKeyConf `json:",optional"` 19 | StrictControl bool `json:",optional"` 20 | // setting 0 means no timeout 21 | Timeout int64 `json:",default=2000"` 22 | CpuThreshold int64 `json:",default=900,range=[0:1000]"` 23 | } 24 | 25 | // A RpcClientConf is a rpc client config. 26 | RpcClientConf struct { 27 | Etcd discov.EtcdConf `json:",optional"` 28 | Endpoints []string `json:",optional"` 29 | Target string `json:",optional"` 30 | App string `json:",optional"` 31 | Token string `json:",optional"` 32 | NonBlock bool `json:",optional"` 33 | Timeout int64 `json:",default=2000"` // milliseconds; NewClient multiplies it by time.Millisecond 34 | } 35 | ) 36 | 37 | // NewDirectClientConf returns a RpcClientConf. 38 | // The client connects straight to the given endpoints, without service discovery. 39 | func NewDirectClientConf(endpoints []string, app, token string) RpcClientConf { 40 | return RpcClientConf{ 41 | Endpoints: endpoints, 42 | App: app, 43 | Token: token, 44 | } 45 | } 46 | 47 | // NewEtcdClientConf returns a RpcClientConf.
48 | // The client resolves its target through etcd using hosts and key. 49 | func NewEtcdClientConf(hosts []string, key, app, token string) RpcClientConf { 50 | return RpcClientConf{ 51 | Etcd: discov.EtcdConf{ 52 | Hosts: hosts, 53 | Key: key, 54 | }, 55 | App: app, 56 | Token: token, 57 | } 58 | } 59 | 60 | // HasEtcd checks if there are etcd settings in config. 61 | // Both Etcd.Hosts and Etcd.Key must be non-empty. 62 | func (sc RpcServerConf) HasEtcd() bool { 63 | return len(sc.Etcd.Hosts) > 0 && len(sc.Etcd.Key) > 0 64 | } 65 | 66 | // Validate validates the config. 67 | // Only the Redis settings are checked, and only when Auth is enabled. 68 | func (sc RpcServerConf) Validate() error { 69 | if !sc.Auth { 70 | return nil 71 | } 72 | 73 | return sc.Redis.Validate() 74 | } 75 | 76 | // BuildTarget builds the rpc target from the given config. 77 | // Precedence: Endpoints (direct target), then Target as-is, then etcd discovery. 78 | func (cc RpcClientConf) BuildTarget() (string, error) { 79 | if len(cc.Endpoints) > 0 { 80 | return resolver.BuildDirectTarget(cc.Endpoints), nil 81 | } else if len(cc.Target) > 0 { 82 | return cc.Target, nil 83 | } 84 | 85 | if err := cc.Etcd.Validate(); err != nil { 86 | return "", err 87 | } 88 | 89 | if cc.Etcd.HasAccount() { // etcd account and TLS settings are registered for these hosts before building the discovery target 90 | discov.RegisterAccount(cc.Etcd.Hosts, cc.Etcd.User, cc.Etcd.Pass) 91 | } 92 | if cc.Etcd.HasTLS() { 93 | if err := discov.RegisterTLS(cc.Etcd.Hosts, cc.Etcd.CertFile, cc.Etcd.CertKeyFile, 94 | cc.Etcd.CACertFile, cc.Etcd.InsecureSkipVerify); err != nil { 95 | return "", err 96 | } 97 | } 98 | 99 | return resolver.BuildDiscovTarget(cc.Etcd.Hosts, cc.Etcd.Key), nil 100 | } 101 | 102 | // HasCredential checks if there is a credential in config.
103 | // Both App and Token must be non-empty. 104 | func (cc RpcClientConf) HasCredential() bool { 105 | return len(cc.App) > 0 && len(cc.Token) > 0 106 | } 107 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/auth/auth.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/zeromicro/go-zero/core/collection" 8 | "github.com/zeromicro/go-zero/core/stores/redis" 9 | "google.golang.org/grpc/codes" 10 | "google.golang.org/grpc/metadata" 11 | "google.golang.org/grpc/status" 12 | ) 13 | 14 | const defaultExpiration = 5 * time.Minute // expiration handed to collection.NewCache for looked-up tokens 15 | 16 | // An Authenticator is used to authenticate the rpc requests. 17 | type Authenticator struct { 18 | store *redis.Redis // holds the expected tokens, read with Hget(key, app) 19 | key string 20 | cache *collection.Cache // caches store lookups, keyed by app 21 | strict bool // when true, a failed token lookup rejects the request 22 | } 23 | 24 | // NewAuthenticator returns an Authenticator. 25 | // Expected tokens are read from the redis hash key (field = app) and cached. 26 | func NewAuthenticator(store *redis.Redis, key string, strict bool) (*Authenticator, error) { 27 | cache, err := collection.NewCache(defaultExpiration) 28 | if err != nil { 29 | return nil, err 30 | } 31 | 32 | return &Authenticator{ 33 | store: store, 34 | key: key, 35 | cache: cache, 36 | strict: strict, 37 | }, nil 38 | } 39 | 40 | // Authenticate authenticates the given ctx.
41 | // The incoming metadata must carry non-empty app and token values; the token is then checked by validate. 42 | func (a *Authenticator) Authenticate(ctx context.Context) error { 43 | md, ok := metadata.FromIncomingContext(ctx) 44 | if !ok { 45 | return status.Error(codes.Unauthenticated, missingMetadata) 46 | } 47 | 48 | apps, tokens := md[appKey], md[tokenKey] 49 | if len(apps) == 0 || len(tokens) == 0 { 50 | return status.Error(codes.Unauthenticated, missingMetadata) 51 | } 52 | 53 | app, token := apps[0], tokens[0] // only the first value of each key is considered 54 | if len(app) == 0 || len(token) == 0 { 55 | return status.Error(codes.Unauthenticated, missingMetadata) 56 | } 57 | 58 | return a.validate(app, token) 59 | } 60 | 61 | func (a *Authenticator) validate(app, token string) error { // compares token with the stored token for app 62 | expect, err := a.cache.Take(app, func() (interface{}, error) { 63 | return a.store.Hget(a.key, app) 64 | }) 65 | if err != nil { 66 | if a.strict { 67 | return status.Error(codes.Internal, err.Error()) 68 | } 69 | 70 | return nil // non-strict mode: a failed lookup lets the request through 71 | } 72 | 73 | if token != expect { // expect is an interface{}: equal only when it holds this exact string 74 | return status.Error(codes.Unauthenticated, accessDenied) 75 | } 76 | 77 | return nil 78 | } 79 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/auth/credential.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "context" 5 | 6 | "google.golang.org/grpc/metadata" 7 | ) 8 | 9 | // A Credential carries the app/token pair sent as per-RPC request metadata. 10 | // It is used to authenticate. 11 | type Credential struct { 12 | App string 13 | Token string 14 | } 15 | 16 | // GetRequestMetadata returns the app and token under the appKey and tokenKey metadata keys. 17 | func (c *Credential) GetRequestMetadata(context.Context, ...string) (map[string]string, error) { 18 | return map[string]string{ 19 | appKey: c.App, 20 | tokenKey: c.Token, 21 | }, nil 22 | } 23 | 24 | // RequireTransportSecurity always returns false. 25 | func (c *Credential) RequireTransportSecurity() bool { 26 | return false 27 | } 28 | 29 | // ParseCredential parses credential from given ctx. 30 | // It returns a zero Credential unless both app and token metadata are present and non-empty.
31 | func ParseCredential(ctx context.Context) Credential { 32 | var credential Credential 33 | 34 | md, ok := metadata.FromIncomingContext(ctx) 35 | if !ok { 36 | return credential 37 | } 38 | 39 | apps, tokens := md[appKey], md[tokenKey] 40 | if len(apps) == 0 || len(tokens) == 0 { 41 | return credential 42 | } 43 | 44 | app, token := apps[0], tokens[0] // only the first value of each key is used 45 | if len(app) == 0 || len(token) == 0 { 46 | return credential 47 | } 48 | 49 | credential.App = app 50 | credential.Token = token 51 | 52 | return credential 53 | } 54 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/auth/vars.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | const ( 4 | appKey = "app" // metadata key carrying the app name 5 | tokenKey = "token" // metadata key carrying the app's token 6 | 7 | accessDenied = "access denied" 8 | missingMetadata = "app/token required" 9 | ) 10 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/chainclientinterceptors.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import "google.golang.org/grpc" 4 | 5 | // WithStreamClientInterceptors uses given client stream interceptors. 6 | // They are combined into a single dial option via grpc.WithChainStreamInterceptor. 7 | func WithStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.DialOption { 8 | return grpc.WithChainStreamInterceptor(interceptors...) 9 | } 10 | 11 | // WithUnaryClientInterceptors uses given client unary interceptors. 12 | // They are combined into a single dial option via grpc.WithChainUnaryInterceptor. 13 | func WithUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.DialOption { 14 | return grpc.WithChainUnaryInterceptor(interceptors...)
15 | } 16 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/chainserverinterceptors.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import "google.golang.org/grpc" 4 | 5 | // WithStreamServerInterceptors uses given server stream interceptors. 6 | // They are combined into a single server option via grpc.ChainStreamInterceptor. 7 | func WithStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption { 8 | return grpc.ChainStreamInterceptor(interceptors...) 9 | } 10 | 11 | // WithUnaryServerInterceptors uses given server unary interceptors. 12 | // They are combined into a single server option via grpc.ChainUnaryInterceptor. 13 | func WithUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption { 14 | return grpc.ChainUnaryInterceptor(interceptors...) 15 | } 16 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/client.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "strings" 8 | "time" 9 | 10 | "gozerosource/code/balancer/zrpc/internal/clientinterceptors" 11 | "gozerosource/code/balancer/zrpc/p2c" 12 | "gozerosource/code/balancer/zrpc/resolver" 13 | 14 | "google.golang.org/grpc" 15 | "google.golang.org/grpc/credentials" 16 | ) 17 | 18 | const ( 19 | dialTimeout = time.Second * 3 // upper bound of the dial context 20 | separator = '/' // used to pick the service name out of a target after a dial timeout 21 | ) 22 | 23 | func init() { 24 | resolver.Register() 25 | } 26 | 27 | type ( 28 | // Client interface wraps the Conn method. 29 | // It is implemented by the unexported client type below. 30 | Client interface { 31 | Conn() *grpc.ClientConn 32 | } 33 | 34 | // ClientOptions collects the settings applied by ClientOption funcs 35 | // before dialing. 36 | ClientOptions struct { 37 | NonBlock bool 38 | Timeout time.Duration 39 | Secure bool 40 | DialOptions []grpc.DialOption 41 | } 42 | 43 | // ClientOption defines the method to customize a ClientOptions.
44 | ClientOption func(options *ClientOptions) 45 | 46 | client struct { 47 | conn *grpc.ClientConn 48 | } 49 | ) 50 | 51 | // NewClient returns a Client dialed to target. 52 | // The p2c balancer option is prepended, so caller options are applied after it. 53 | func NewClient(target string, opts ...ClientOption) (Client, error) { 54 | var cli client 55 | opts = append([]ClientOption{WithDialOption(grpc.WithBalancerName(p2c.Name))}, opts...) 56 | if err := cli.dial(target, opts...); err != nil { 57 | return nil, err 58 | } 59 | 60 | return &cli, nil 61 | } 62 | 63 | // Conn returns the underlying grpc.ClientConn. 64 | func (c *client) Conn() *grpc.ClientConn { 65 | return c.conn 66 | } 67 | 68 | // buildDialOptions applies opts and assembles the grpc dial options, placing the built-in interceptors before any caller-supplied DialOptions. 69 | func (c *client) buildDialOptions(opts ...ClientOption) []grpc.DialOption { 70 | var cliOpts ClientOptions 71 | for _, opt := range opts { 72 | opt(&cliOpts) 73 | } 74 | 75 | var options []grpc.DialOption 76 | if !cliOpts.Secure { 77 | options = append(options, grpc.WithInsecure()) 78 | } 79 | 80 | if !cliOpts.NonBlock { 81 | options = append(options, grpc.WithBlock()) 82 | } 83 | 84 | options = append(options, 85 | WithUnaryClientInterceptors( 86 | clientinterceptors.UnaryTracingInterceptor, 87 | clientinterceptors.DurationInterceptor, 88 | clientinterceptors.PrometheusInterceptor, 89 | clientinterceptors.BreakerInterceptor, 90 | clientinterceptors.TimeoutInterceptor(cliOpts.Timeout), 91 | ), 92 | WithStreamClientInterceptors( 93 | clientinterceptors.StreamTracingInterceptor, 94 | ), 95 | ) 96 | 97 | return append(options, cliOpts.DialOptions...) 98 | } 99 | 100 | // dial connects to server under a context bounded by dialTimeout and stores the connection. 101 | func (c *client) dial(server string, opts ...ClientOption) error { 102 | options := c.buildDialOptions(opts...) 103 | timeCtx, cancel := context.WithTimeout(context.Background(), dialTimeout) 104 | defer cancel() 105 | conn, err := grpc.DialContext(timeCtx, server, options...)
106 | if err != nil { 107 | service := server 108 | if errors.Is(err, context.DeadlineExceeded) { // on timeout, report just the last path segment as the service name 109 | pos := strings.LastIndexByte(server, separator) 110 | // len(server) - 1 is the index of last char 111 | if 0 < pos && pos < len(server)-1 { 112 | service = server[pos+1:] 113 | } 114 | } 115 | return fmt.Errorf("rpc dial: %s, error: %w, make sure rpc service %q is already started", 116 | server, err, service) // %w (instead of err.Error()) keeps err matchable with errors.Is 117 | } 118 | 119 | c.conn = conn 120 | return nil 121 | } 122 | 123 | // WithDialOption returns a func to customize a ClientOptions with given dial option. 124 | // Such options are appended after the built-in ones in buildDialOptions. 125 | func WithDialOption(opt grpc.DialOption) ClientOption { 126 | return func(options *ClientOptions) { 127 | options.DialOptions = append(options.DialOptions, opt) 128 | } 129 | } 130 | 131 | // WithNonBlock sets the dialing to be nonblock. 132 | // Without it, buildDialOptions adds grpc.WithBlock(). 133 | func WithNonBlock() ClientOption { 134 | return func(options *ClientOptions) { 135 | options.NonBlock = true 136 | } 137 | } 138 | 139 | // WithTimeout returns a func to customize a ClientOptions with given timeout. 140 | // The timeout is enforced per unary call by clientinterceptors.TimeoutInterceptor. 141 | func WithTimeout(timeout time.Duration) ClientOption { 142 | return func(options *ClientOptions) { 143 | options.Timeout = timeout 144 | } 145 | } 146 | 147 | // WithTransportCredentials returns a func to make the gRPC calls secured with given credentials. 148 | // It also sets Secure, so buildDialOptions does not add grpc.WithInsecure(). 149 | func WithTransportCredentials(creds credentials.TransportCredentials) ClientOption { 150 | return func(options *ClientOptions) { 151 | options.Secure = true 152 | options.DialOptions = append(options.DialOptions, grpc.WithTransportCredentials(creds)) 153 | } 154 | } 155 | 156 | // WithUnaryClientInterceptor returns a func to customize a ClientOptions with given interceptor. 157 | // The interceptor is added as a dial option after the built-in interceptor chain.
158 | func WithUnaryClientInterceptor(interceptor grpc.UnaryClientInterceptor) ClientOption { 159 | return func(options *ClientOptions) { 160 | options.DialOptions = append(options.DialOptions, WithUnaryClientInterceptors(interceptor)) 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/clientinterceptors/breakerinterceptor.go: -------------------------------------------------------------------------------- 1 | package clientinterceptors 2 | 3 | import ( 4 | "context" 5 | "path" 6 | 7 | "gozerosource/code/balancer/zrpc/internal/codes" 8 | 9 | "github.com/zeromicro/go-zero/core/breaker" 10 | 11 | "google.golang.org/grpc" 12 | ) 13 | 14 | // BreakerInterceptor is an interceptor that acts as a circuit breaker. 15 | // The breaker is named path.Join(target, method), and codes.Acceptable classifies the call's error. 16 | func BreakerInterceptor(ctx context.Context, method string, req, reply interface{}, 17 | cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption, 18 | ) error { 19 | breakerName := path.Join(cc.Target(), method) 20 | return breaker.DoWithAcceptable(breakerName, func() error { 21 | return invoker(ctx, method, req, reply, cc, opts...) 22 | }, codes.Acceptable) 23 | } 24 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/clientinterceptors/durationinterceptor.go: -------------------------------------------------------------------------------- 1 | package clientinterceptors 2 | 3 | import ( 4 | "context" 5 | "path" 6 | "time" 7 | 8 | "github.com/zeromicro/go-zero/core/logx" 9 | "github.com/zeromicro/go-zero/core/syncx" 10 | "github.com/zeromicro/go-zero/core/timex" 11 | 12 | "google.golang.org/grpc" 13 | ) 14 | 15 | const defaultSlowThreshold = time.Millisecond * 500 16 | 17 | var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold) // changed at runtime via SetSlowThreshold 18 | 19 | // DurationInterceptor is an interceptor that logs the processing time. 20 | // Failed calls are always logged; successful ones only when slower than slowThreshold.
21 | func DurationInterceptor(ctx context.Context, method string, req, reply interface{}, 22 | cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption, 23 | ) error { 24 | serverName := path.Join(cc.Target(), method) 25 | start := timex.Now() 26 | err := invoker(ctx, method, req, reply, cc, opts...) 27 | if err != nil { 28 | logx.WithContext(ctx).WithDuration(timex.Since(start)).Infof("fail - %s - %v - %s", 29 | serverName, req, err.Error()) 30 | } else { 31 | elapsed := timex.Since(start) 32 | if elapsed > slowThreshold.Load() { 33 | logx.WithContext(ctx).WithDuration(elapsed).Slowf("[RPC] ok - slowcall - %s - %v - %v", 34 | serverName, req, reply) 35 | } 36 | } 37 | 38 | return err 39 | } 40 | 41 | // SetSlowThreshold sets the slow threshold. 42 | // Successful calls that take longer than threshold are logged as slow by DurationInterceptor. 43 | func SetSlowThreshold(threshold time.Duration) { 44 | slowThreshold.Set(threshold) 45 | } 46 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/clientinterceptors/prometheusinterceptor.go: -------------------------------------------------------------------------------- 1 | package clientinterceptors 2 | 3 | import ( 4 | "context" 5 | "strconv" 6 | "time" 7 | 8 | "github.com/zeromicro/go-zero/core/metric" 9 | "github.com/zeromicro/go-zero/core/prometheus" 10 | "github.com/zeromicro/go-zero/core/timex" 11 | 12 | "google.golang.org/grpc" 13 | "google.golang.org/grpc/status" 14 | ) 15 | 16 | const clientNamespace = "rpc_client" 17 | 18 | var ( 19 | metricClientReqDur = metric.NewHistogramVec(&metric.HistogramVecOpts{ 20 | Namespace: clientNamespace, 21 | Subsystem: "requests", 22 | Name: "duration_ms", 23 | Help: "rpc client requests duration(ms).", 24 | Labels: []string{"method"}, 25 | Buckets: []float64{5, 10, 25, 50, 100, 250, 500, 1000}, 26 | }) 27 | 28 | metricClientReqCodeTotal = metric.NewCounterVec(&metric.CounterVecOpts{ 29 | Namespace: clientNamespace, 30 | Subsystem: "requests", 31 | Name: "code_total", 32 | Help: "rpc client requests code count.",
33 | Labels: []string{"method", "code"}, 34 | }) 35 | ) 36 | 37 | // PrometheusInterceptor is an interceptor that reports to prometheus server. 38 | // When prometheus is enabled it records each call's duration (ms) and status code, labeled by method. 39 | func PrometheusInterceptor(ctx context.Context, method string, req, reply interface{}, 40 | cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption, 41 | ) error { 42 | if !prometheus.Enabled() { 43 | return invoker(ctx, method, req, reply, cc, opts...) 44 | } 45 | 46 | startTime := timex.Now() 47 | err := invoker(ctx, method, req, reply, cc, opts...) 48 | metricClientReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), method) 49 | metricClientReqCodeTotal.Inc(method, strconv.Itoa(int(status.Code(err)))) 50 | return err 51 | } 52 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/clientinterceptors/timeoutinterceptor.go: -------------------------------------------------------------------------------- 1 | package clientinterceptors 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "google.golang.org/grpc" 8 | ) 9 | 10 | // TimeoutInterceptor is an interceptor that controls timeout. 11 | // A non-positive timeout passes the caller's context through unchanged. 12 | func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor { 13 | return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, 14 | invoker grpc.UnaryInvoker, opts ...grpc.CallOption, 15 | ) error { 16 | if timeout <= 0 { 17 | return invoker(ctx, method, req, reply, cc, opts...) 18 | } 19 | 20 | ctx, cancel := context.WithTimeout(ctx, timeout) 21 | defer cancel() 22 | 23 | return invoker(ctx, method, req, reply, cc, opts...)
24 | } 25 | } 26 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/clientinterceptors/tracinginterceptor.go: -------------------------------------------------------------------------------- 1 | package clientinterceptors 2 | 3 | import ( 4 | "context" 5 | "io" 6 | 7 | ztrace "github.com/zeromicro/go-zero/core/trace" 8 | 9 | "go.opentelemetry.io/otel" 10 | "go.opentelemetry.io/otel/codes" 11 | "go.opentelemetry.io/otel/trace" 12 | "google.golang.org/grpc" 13 | gcodes "google.golang.org/grpc/codes" 14 | "google.golang.org/grpc/metadata" 15 | "google.golang.org/grpc/status" 16 | ) 17 | 18 | const ( 19 | receiveEndEvent streamEventType = iota 20 | errorEvent 21 | ) 22 | 23 | // UnaryTracingInterceptor is a grpc.UnaryClientInterceptor for opentelemetry. 24 | // It wraps each call in a client span and records the sent and received messages. 25 | func UnaryTracingInterceptor(ctx context.Context, method string, req, reply interface{}, 26 | cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption, 27 | ) error { 28 | ctx, span := startSpan(ctx, method, cc.Target()) 29 | defer span.End() 30 | 31 | ztrace.MessageSent.Event(ctx, 1, req) 32 | err := invoker(ctx, method, req, reply, cc, opts...) 33 | // Record the reply only after the call returns; recording it before invoking logged an unpopulated message. 34 | ztrace.MessageReceived.Event(ctx, 1, reply) 35 | if err != nil { 36 | s, ok := status.FromError(err) 37 | if ok { 38 | span.SetStatus(codes.Error, s.Message()) 39 | span.SetAttributes(ztrace.StatusCodeAttr(s.Code())) 40 | } else { 41 | span.SetStatus(codes.Error, err.Error()) 42 | } 43 | return err 44 | } 45 | span.SetAttributes(ztrace.StatusCodeAttr(gcodes.OK)) 46 | return nil 47 | } 48 | 49 | // StreamTracingInterceptor is a grpc.StreamClientInterceptor for opentelemetry. 50 | // Its span ends when the wrapped stream reports completion on Finished, or at once if the stream cannot be created.
51 | func StreamTracingInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, 52 | method string, streamer grpc.Streamer, opts ...grpc.CallOption, 53 | ) (grpc.ClientStream, error) { 54 | ctx, span := startSpan(ctx, method, cc.Target()) 55 | s, err := streamer(ctx, desc, cc, method, opts...) 56 | if err != nil { 57 | st, ok := status.FromError(err) 58 | if ok { 59 | span.SetStatus(codes.Error, st.Message()) 60 | span.SetAttributes(ztrace.StatusCodeAttr(st.Code())) 61 | } else { 62 | span.SetStatus(codes.Error, err.Error()) 63 | } 64 | span.End() 65 | return s, err 66 | } 67 | 68 | stream := wrapClientStream(ctx, s, desc) 69 | 70 | go func() { 71 | if err := <-stream.Finished; err != nil { 72 | s, ok := status.FromError(err) 73 | if ok { 74 | span.SetStatus(codes.Error, s.Message()) 75 | span.SetAttributes(ztrace.StatusCodeAttr(s.Code())) 76 | } else { 77 | span.SetStatus(codes.Error, err.Error()) 78 | } 79 | } else { 80 | span.SetAttributes(ztrace.StatusCodeAttr(gcodes.OK)) 81 | } 82 | 83 | span.End() 84 | }() 85 | 86 | return stream, nil 87 | } 88 | 89 | type ( 90 | streamEventType int 91 | 92 | streamEvent struct { 93 | Type streamEventType 94 | Err error 95 | } 96 | 97 | clientStream struct { 98 | grpc.ClientStream 99 | Finished chan error 100 | desc *grpc.StreamDesc 101 | events chan streamEvent 102 | eventsDone chan struct{} 103 | receivedMessageID int 104 | sentMessageID int 105 | } 106 | ) 107 | 108 | func (w *clientStream) CloseSend() error { 109 | err := w.ClientStream.CloseSend() 110 | if err != nil { 111 | w.sendStreamEvent(errorEvent, err) 112 | } 113 | 114 | return err 115 | } 116 | 117 | func (w *clientStream) Header() (metadata.MD, error) { 118 | md, err := w.ClientStream.Header() 119 | if err != nil { 120 | w.sendStreamEvent(errorEvent, err) 121 | } 122 | 123 | return md, err 124 | } 125 | 126 | func (w *clientStream) RecvMsg(m interface{}) error { 127 | err := w.ClientStream.RecvMsg(m) 128 | if err == nil && 
!w.desc.ServerStreams { 129 | w.sendStreamEvent(receiveEndEvent, nil) 130 | } else if err == io.EOF { 131 | w.sendStreamEvent(receiveEndEvent, nil) 132 | } else if err != nil { 133 | w.sendStreamEvent(errorEvent, err) 134 | } else { 135 | w.receivedMessageID++ 136 | ztrace.MessageReceived.Event(w.Context(), w.receivedMessageID, m) 137 | } 138 | 139 | return err 140 | } 141 | 142 | func (w *clientStream) SendMsg(m interface{}) error { 143 | err := w.ClientStream.SendMsg(m) 144 | w.sentMessageID++ 145 | ztrace.MessageSent.Event(w.Context(), w.sentMessageID, m) 146 | if err != nil { 147 | w.sendStreamEvent(errorEvent, err) 148 | } 149 | 150 | return err 151 | } 152 | 153 | func (w *clientStream) sendStreamEvent(eventType streamEventType, err error) { 154 | select { 155 | case <-w.eventsDone: 156 | case w.events <- streamEvent{Type: eventType, Err: err}: 157 | } 158 | } 159 | 160 | func startSpan(ctx context.Context, method, target string) (context.Context, trace.Span) { 161 | var md metadata.MD 162 | requestMetadata, ok := metadata.FromOutgoingContext(ctx) 163 | if ok { 164 | md = requestMetadata.Copy() 165 | } else { 166 | md = metadata.MD{} 167 | } 168 | tr := otel.Tracer(ztrace.TraceName) 169 | name, attr := ztrace.SpanInfo(method, target) 170 | ctx, span := tr.Start(ctx, name, trace.WithSpanKind(trace.SpanKindClient), 171 | trace.WithAttributes(attr...)) 172 | ztrace.Inject(ctx, otel.GetTextMapPropagator(), &md) 173 | ctx = metadata.NewOutgoingContext(ctx, md) 174 | 175 | return ctx, span 176 | } 177 | 178 | // 流包装器 179 | // wrapClientStream wraps s with given ctx and desc. 
180 | func wrapClientStream(ctx context.Context, s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream { 181 | events := make(chan streamEvent) 182 | eventsDone := make(chan struct{}) 183 | finished := make(chan error) // unbuffered: each send blocks until a reader of Finished receives it 184 | 185 | go func() { 186 | defer close(eventsDone) // once closed, sendStreamEvent returns instead of blocking on events 187 | 188 | for { 189 | select { 190 | case event := <-events: 191 | switch event.Type { 192 | case receiveEndEvent: 193 | finished <- nil 194 | return 195 | case errorEvent: 196 | finished <- event.Err 197 | return 198 | } 199 | case <-ctx.Done(): 200 | finished <- ctx.Err() 201 | return 202 | } 203 | } 204 | }() 205 | 206 | return &clientStream{ 207 | ClientStream: s, 208 | desc: desc, 209 | events: events, 210 | eventsDone: eventsDone, 211 | Finished: finished, 212 | } 213 | } 214 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/codes/accept.go: -------------------------------------------------------------------------------- 1 | package codes 2 | 3 | import ( 4 | "google.golang.org/grpc/codes" 5 | "google.golang.org/grpc/status" 6 | ) 7 | 8 | // Acceptable reports false for DeadlineExceeded, Internal, Unavailable and DataLoss status codes, and true for every other code. 9 | // Acceptable checks if given error is acceptable. 10 | func Acceptable(err error) bool { 11 | switch status.Code(err) { 12 | case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss: 13 | return false 14 | default: 15 | return true 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/rpclogger.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/zeromicro/go-zero/core/logx" 7 | "google.golang.org/grpc/grpclog" 8 | ) 9 | 10 | // because grpclog.errorLog is not exported, we need to define our own. 11 | const errorLevel = 2 12 | 13 | var once sync.Once 14 | 15 | // A Logger is a rpc logger. 16 | type Logger struct{} 17 | 18 | // InitLogger initializes the rpc logger.
19 | func InitLogger() { 20 | once.Do(func() { // only the first InitLogger call installs the logger 21 | grpclog.SetLoggerV2(new(Logger)) 22 | }) 23 | } 24 | 25 | // Error logs the given args into error log. 26 | func (l *Logger) Error(args ...interface{}) { 27 | logx.Error(args...) 28 | } 29 | 30 | // Errorf logs the given args with format into error log. 31 | func (l *Logger) Errorf(format string, args ...interface{}) { 32 | logx.Errorf(format, args...) 33 | } 34 | 35 | // Errorln logs the given args into error log; it delegates to logx.Error exactly like Error. 36 | func (l *Logger) Errorln(args ...interface{}) { 37 | logx.Error(args...) 38 | } 39 | 40 | // Fatal logs the given args into error log; it does not exit the process itself. 41 | func (l *Logger) Fatal(args ...interface{}) { 42 | logx.Error(args...) 43 | } 44 | 45 | // Fatalf logs the given args with format into error log; it does not exit the process itself. 46 | func (l *Logger) Fatalf(format string, args ...interface{}) { 47 | logx.Errorf(format, args...) 48 | } 49 | 50 | // Fatalln logs args into error log like Fatal; it does not exit the process itself. 51 | func (l *Logger) Fatalln(args ...interface{}) { 52 | logx.Error(args...) 53 | } 54 | 55 | // Info ignores the grpc info logs. 56 | func (l *Logger) Info(args ...interface{}) { 57 | // ignore builtin grpc info 58 | } 59 | 60 | // Infoln ignores the grpc info logs. 61 | func (l *Logger) Infoln(args ...interface{}) { 62 | // ignore builtin grpc info 63 | } 64 | 65 | // Infof ignores the grpc info logs. 66 | func (l *Logger) Infof(format string, args ...interface{}) { 67 | // ignore builtin grpc info 68 | } 69 | 70 | // V reports whether verbosity level v is at or above errorLevel. 71 | func (l *Logger) V(v int) bool { 72 | return v >= errorLevel 73 | } 74 | 75 | // Warning ignores the grpc warning logs. 76 | func (l *Logger) Warning(args ...interface{}) { 77 | // ignore builtin grpc warning 78 | } 79 | 80 | // Warningf ignores the grpc warning logs. 81 | func (l *Logger) Warningf(format string, args ...interface{}) { 82 | // ignore builtin grpc warning 83 | } 84 | 85 | // Warningln ignores the grpc warning logs.
86 | func (l *Logger) Warningln(args ...interface{}) { 87 | // ignore builtin grpc warning 88 | } 89 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/rpcpubserver.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "net" 5 | "os" 6 | 7 | "github.com/zeromicro/go-zero/core/discov" 8 | "github.com/zeromicro/go-zero/core/netx" 9 | ) 10 | 11 | const ( 12 | allEths = "0.0.0.0" 13 | envPodIp = "POD_IP" 14 | ) 15 | 16 | // NewRpcPubServer returns a Server. 17 | // The returned Server publishes its address to etcd (discov publisher with keep-alive) before it starts serving. 18 | func NewRpcPubServer(etcd discov.EtcdConf, listenOn string, opts ...ServerOption) (Server, error) { 19 | registerEtcd := func() error { 20 | pubListenOn := figureOutListenOn(listenOn) 21 | var pubOpts []discov.PubOption 22 | if etcd.HasAccount() { 23 | pubOpts = append(pubOpts, discov.WithPubEtcdAccount(etcd.User, etcd.Pass)) 24 | } 25 | if etcd.HasTLS() { 26 | pubOpts = append(pubOpts, discov.WithPubEtcdTLS(etcd.CertFile, etcd.CertKeyFile, 27 | etcd.CACertFile, etcd.InsecureSkipVerify)) 28 | } 29 | pubClient := discov.NewPublisher(etcd.Hosts, etcd.Key, pubListenOn, pubOpts...)
30 | return pubClient.KeepAlive() 31 | } 32 | server := keepAliveServer{ 33 | registerEtcd: registerEtcd, 34 | Server: NewRpcServer(listenOn, opts...), 35 | } 36 | 37 | return server, nil 38 | } 39 | 40 | // keepAliveServer runs registerEtcd before delegating Start to the embedded Server. 41 | type keepAliveServer struct { 42 | registerEtcd func() error 43 | Server 44 | } 45 | 46 | func (ags keepAliveServer) Start(fn RegisterFn) error { 47 | if err := ags.registerEtcd(); err != nil { 48 | return err 49 | } 50 | 51 | return ags.Server.Start(fn) 52 | } 53 | 54 | // figureOutListenOn returns the address to publish for listenOn. When listenOn binds 55 | // every interface (empty host, 0.0.0.0 or ::), the host is replaced with $POD_IP or, 56 | // failing that, the internal IP; otherwise, or when listenOn is not a valid host:port, 57 | // listenOn is returned unchanged. net.JoinHostPort brackets IPv6 IPs so the result parses. 58 | func figureOutListenOn(listenOn string) string { 59 | host, port, err := net.SplitHostPort(listenOn) 60 | if err != nil { 61 | return listenOn 62 | } 63 | if len(host) > 0 && host != allEths && host != "::" { 64 | return listenOn 65 | } 66 | 67 | ip := os.Getenv(envPodIp) 68 | if len(ip) == 0 { 69 | ip = netx.InternalIp() 70 | } 71 | if len(ip) == 0 { 72 | return listenOn 73 | } 74 | 75 | return net.JoinHostPort(ip, port) 76 | } 77 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/rpcserver.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "net" 5 | 6 | "gozerosource/code/balancer/zrpc/internal/serverinterceptors" 7 | 8 | "github.com/zeromicro/go-zero/core/proc" 9 | "github.com/zeromicro/go-zero/core/stat" 10 | "google.golang.org/grpc" 11 | ) 12 | 13 | type ( 14 | // ServerOption defines the method to customize a rpcServerOptions. 15 | ServerOption func(options *rpcServerOptions) 16 | 17 | rpcServerOptions struct { 18 | metrics *stat.Metrics 19 | } 20 | 21 | rpcServer struct { 22 | name string 23 | *baseRpcServer 24 | } 25 | ) 26 | 27 | func init() { 28 | InitLogger() 29 | } 30 | 31 | // NewRpcServer returns a Server.
32 | // Options are applied in order; metrics default to stat.NewMetrics(address) when none is set. 33 | func NewRpcServer(address string, opts ...ServerOption) Server { 34 | var options rpcServerOptions 35 | for _, opt := range opts { 36 | opt(&options) 37 | } 38 | if options.metrics == nil { 39 | options.metrics = stat.NewMetrics(address) 40 | } 41 | 42 | return &rpcServer{ 43 | baseRpcServer: newBaseRpcServer(address, &options), 44 | } 45 | } 46 | 47 | // SetName stores the name and forwards it to the base server, which names its metrics. 48 | func (s *rpcServer) SetName(name string) { 49 | s.name = name 50 | s.baseRpcServer.SetName(name) 51 | } 52 | 53 | // Start listens on TCP s.address, builds the grpc server with built-in plus user interceptors, registers services and blocks in Serve. 54 | func (s *rpcServer) Start(register RegisterFn) error { 55 | lis, err := net.Listen("tcp", s.address) 56 | if err != nil { 57 | return err 58 | } 59 | 60 | // Built-in unary interceptors, placed ahead of user-added ones. 61 | unaryInterceptors := []grpc.UnaryServerInterceptor{ 62 | serverinterceptors.UnaryTracingInterceptor, 63 | serverinterceptors.UnaryCrashInterceptor, 64 | serverinterceptors.UnaryStatInterceptor(s.metrics), 65 | serverinterceptors.UnaryPrometheusInterceptor, 66 | serverinterceptors.UnaryBreakerInterceptor, 67 | } 68 | unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...) 69 | // Built-in stream interceptors, placed ahead of user-added ones. 70 | streamInterceptors := []grpc.StreamServerInterceptor{ 71 | serverinterceptors.StreamTracingInterceptor, 72 | serverinterceptors.StreamCrashInterceptor, 73 | serverinterceptors.StreamBreakerInterceptor, 74 | } 75 | streamInterceptors = append(streamInterceptors, s.streamInterceptors...) 76 | options := append(s.options, WithUnaryServerInterceptors(unaryInterceptors...), 77 | WithStreamServerInterceptors(streamInterceptors...)) 78 | server := grpc.NewServer(options...)
79 | register(server) // register the business services on the grpc server 80 | // we need to make sure all others are wrapped up 81 | // so we do graceful stop at shutdown phase instead of wrap up phase 82 | // GracefulStop is registered as a wrap-up listener; the returned 83 | // waitForCalled is deferred, so it runs once Serve returns. 84 | waitForCalled := proc.AddWrapUpListener(func() { 85 | server.GracefulStop() 86 | }) 87 | defer waitForCalled() 88 | 89 | return server.Serve(lis) 90 | } 91 | 92 | // WithMetrics returns a func that sets metrics to a Server. 93 | // The metrics set here are passed to UnaryStatInterceptor by Start. 94 | func WithMetrics(metrics *stat.Metrics) ServerOption { 95 | return func(options *rpcServerOptions) { 96 | options.metrics = metrics 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/server.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "github.com/zeromicro/go-zero/core/stat" 5 | "google.golang.org/grpc" 6 | ) 7 | 8 | type ( 9 | // RegisterFn defines the method to register a server. 10 | // Start calls it with the grpc server so business services can be registered. 11 | RegisterFn func(*grpc.Server) 12 | 13 | // Server interface represents a rpc server.
14 | // It is implemented by rpcServer (via baseRpcServer) and wrapped by keepAliveServer. 15 | Server interface { 16 | // AddOptions adds grpc server options. 17 | AddOptions(options ...grpc.ServerOption) 18 | // AddStreamInterceptors adds stream server interceptors. 19 | AddStreamInterceptors(interceptors ...grpc.StreamServerInterceptor) 20 | // AddUnaryInterceptors adds unary server interceptors. 21 | AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor) 22 | // SetName sets the server name. 23 | SetName(string) 24 | // Start registers services via register and serves. 25 | Start(register RegisterFn) error 26 | } 27 | 28 | // baseRpcServer holds the state shared by the Server implementations. 29 | baseRpcServer struct { 30 | address string 31 | metrics *stat.Metrics 32 | options []grpc.ServerOption 33 | streamInterceptors []grpc.StreamServerInterceptor 34 | unaryInterceptors []grpc.UnaryServerInterceptor 35 | } 36 | ) 37 | 38 | // newBaseRpcServer creates a baseRpcServer for address, taking metrics from rpcServerOpts. 39 | func newBaseRpcServer(address string, rpcServerOpts *rpcServerOptions) *baseRpcServer { 40 | return &baseRpcServer{ 41 | address: address, 42 | metrics: rpcServerOpts.metrics, 43 | } 44 | } 45 | 46 | func (s *baseRpcServer) AddOptions(options ...grpc.ServerOption) { 47 | s.options = append(s.options, options...) 48 | } 49 | 50 | func (s *baseRpcServer) AddStreamInterceptors(interceptors ...grpc.StreamServerInterceptor) { 51 | s.streamInterceptors = append(s.streamInterceptors, interceptors...) 52 | } 53 | 54 | func (s *baseRpcServer) AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor) { 55 | s.unaryInterceptors = append(s.unaryInterceptors, interceptors...) 56 | } 57 | 58 | func (s *baseRpcServer) SetName(name string) { 59 | s.metrics.SetName(name) 60 | } 61 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/serverinterceptors/authinterceptor.go: -------------------------------------------------------------------------------- 1 | package serverinterceptors 2 | 3 | import ( 4 | "context" 5 | 6 | "gozerosource/code/balancer/zrpc/internal/auth" 7 | 8 | "google.golang.org/grpc" 9 | ) 10 | 11 | // Stream variant: the stream is rejected with the authentication error before handler runs. 12 | // StreamAuthorizeInterceptor returns a func that uses given authenticator in processing stream requests.
13 | func StreamAuthorizeInterceptor(authenticator *auth.Authenticator) grpc.StreamServerInterceptor { 14 | return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, 15 | handler grpc.StreamHandler, 16 | ) error { 17 | if err := authenticator.Authenticate(stream.Context()); err != nil { 18 | return err 19 | } 20 | 21 | return handler(srv, stream) 22 | } 23 | } 24 | 25 | // Unary variant: returns (nil, err) from Authenticate without calling handler. 26 | // UnaryAuthorizeInterceptor returns a func that uses given authenticator in processing unary requests. 27 | func UnaryAuthorizeInterceptor(authenticator *auth.Authenticator) grpc.UnaryServerInterceptor { 28 | return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, 29 | handler grpc.UnaryHandler, 30 | ) (interface{}, error) { 31 | if err := authenticator.Authenticate(ctx); err != nil { 32 | return nil, err 33 | } 34 | 35 | return handler(ctx, req) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/serverinterceptors/breakerinterceptor.go: -------------------------------------------------------------------------------- 1 | package serverinterceptors 2 | 3 | import ( 4 | "context" 5 | 6 | "gozerosource/code/balancer/zrpc/internal/codes" 7 | 8 | "github.com/zeromicro/go-zero/core/breaker" 9 | "google.golang.org/grpc" 10 | ) 11 | 12 | // Stream variant: the breaker is keyed by info.FullMethod and errors are classified by codes.Acceptable. 13 | // StreamBreakerInterceptor is an interceptor that acts as a circuit breaker. 14 | func StreamBreakerInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, 15 | handler grpc.StreamHandler, 16 | ) (err error) { 17 | breakerName := info.FullMethod 18 | return breaker.DoWithAcceptable(breakerName, func() error { 19 | return handler(srv, stream) 20 | }, codes.Acceptable) 21 | } 22 | 23 | // Unary variant: the breaker is keyed by info.FullMethod and errors are classified by codes.Acceptable. 24 | // UnaryBreakerInterceptor is an interceptor that acts as a circuit breaker.
25 | func UnaryBreakerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, 26 | handler grpc.UnaryHandler, 27 | ) (resp interface{}, err error) { 28 | breakerName := info.FullMethod 29 | err = breaker.DoWithAcceptable(breakerName, func() error { 30 | var err error // local to the closure; the named result err is assigned from DoWithAcceptable 31 | resp, err = handler(ctx, req) 32 | return err 33 | }, codes.Acceptable) 34 | 35 | return resp, err 36 | } 37 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/serverinterceptors/crashinterceptor.go: -------------------------------------------------------------------------------- 1 | package serverinterceptors 2 | 3 | import ( 4 | "context" 5 | "runtime/debug" 6 | 7 | "github.com/zeromicro/go-zero/core/logx" 8 | "google.golang.org/grpc" 9 | "google.golang.org/grpc/codes" 10 | "google.golang.org/grpc/status" 11 | ) 12 | 13 | // Crash interceptor (stream): a handler panic is recovered, logged with its stack and returned as a codes.Internal error. 14 | // This keeps a single panicking request from taking down the server. 15 | // StreamCrashInterceptor catches panics in processing stream requests and recovers. 16 | func StreamCrashInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, 17 | handler grpc.StreamHandler, 18 | ) (err error) { 19 | defer handleCrash(func(r interface{}) { 20 | err = toPanicError(r) 21 | }) 22 | 23 | return handler(srv, stream) 24 | } 25 | 26 | // Crash interceptor (unary): a handler panic is recovered, logged with its stack and returned as a codes.Internal error. 27 | // This keeps a single panicking request from taking down the server. 28 | // UnaryCrashInterceptor catches panics in processing unary requests and recovers.
29 | func UnaryCrashInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, 30 | handler grpc.UnaryHandler, 31 | ) (resp interface{}, err error) { 32 | defer handleCrash(func(r interface{}) { 33 | err = toPanicError(r) 34 | }) 35 | 36 | return handler(ctx, req) 37 | } 38 | 39 | func handleCrash(handler func(interface{})) { 40 | if r := recover(); r != nil { // effective because callers invoke handleCrash directly via defer 41 | handler(r) 42 | } 43 | } 44 | 45 | func toPanicError(r interface{}) error { 46 | logx.Errorf("%+v\n\n%s", r, debug.Stack()) 47 | return status.Errorf(codes.Internal, "panic: %v", r) 48 | } 49 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/serverinterceptors/prometheusinterceptor.go: -------------------------------------------------------------------------------- 1 | package serverinterceptors 2 | 3 | import ( 4 | "context" 5 | "strconv" 6 | "time" 7 | 8 | "github.com/zeromicro/go-zero/core/metric" 9 | "github.com/zeromicro/go-zero/core/prometheus" 10 | "github.com/zeromicro/go-zero/core/timex" 11 | "google.golang.org/grpc" 12 | "google.golang.org/grpc/status" 13 | ) 14 | 15 | const serverNamespace = "rpc_server" 16 | 17 | var ( 18 | metricServerReqDur = metric.NewHistogramVec(&metric.HistogramVecOpts{ 19 | Namespace: serverNamespace, 20 | Subsystem: "requests", 21 | Name: "duration_ms", 22 | Help: "rpc server requests duration(ms).", 23 | Labels: []string{"method"}, 24 | Buckets: []float64{5, 10, 25, 50, 100, 250, 500, 1000}, 25 | }) 26 | 27 | metricServerReqCodeTotal = metric.NewCounterVec(&metric.CounterVecOpts{ 28 | Namespace: serverNamespace, 29 | Subsystem: "requests", 30 | Name: "code_total", 31 | Help: "rpc server requests code count.", 32 | Labels: []string{"method", "code"}, 33 | }) 34 | ) 35 | 36 | // Server-side Prometheus interceptor: when enabled, records per-method latency (ms) and status codes. 37 | // UnaryPrometheusInterceptor reports the statistics to the prometheus server.
38 | func UnaryPrometheusInterceptor(ctx context.Context, req interface{}, 39 | info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, 40 | ) (interface{}, error) { 41 | if !prometheus.Enabled() { 42 | return handler(ctx, req) 43 | } 44 | 45 | startTime := timex.Now() 46 | resp, err := handler(ctx, req) 47 | metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), info.FullMethod) 48 | metricServerReqCodeTotal.Inc(info.FullMethod, strconv.Itoa(int(status.Code(err)))) 49 | return resp, err 50 | } 51 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/serverinterceptors/sheddinginterceptor.go: -------------------------------------------------------------------------------- 1 | package serverinterceptors 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "sync" 7 | 8 | "github.com/zeromicro/go-zero/core/load" 9 | "github.com/zeromicro/go-zero/core/stat" 10 | "google.golang.org/grpc" 11 | "google.golang.org/grpc/codes" 12 | "google.golang.org/grpc/status" 13 | ) 14 | 15 | const serviceType = "rpc" 16 | 17 | var ( 18 | sheddingStat *load.SheddingStat 19 | lock sync.Mutex 20 | ) 21 | 22 | // Load-shedding interceptor: rejected requests are counted as drops, admitted ones report their outcome to the shedder. 23 | // UnarySheddingInterceptor returns a func that does load shedding on processing unary requests.
24 | | func UnarySheddingInterceptor(shedder load.Shedder, metrics *stat.Metrics) grpc.UnaryServerInterceptor { 25 | ensureSheddingStat() 26 | 27 | return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, 28 | handler grpc.UnaryHandler, 29 | ) (val interface{}, err error) { 30 | sheddingStat.IncrementTotal() 31 | var promise load.Promise 32 | // Ask the shedder whether this request may proceed. 33 | promise, err = shedder.Allow() 34 | // Shed: record the drop and return the shedder's error. 35 | if err != nil { 36 | metrics.AddDrop() 37 | sheddingStat.IncrementDrop() 38 | return 39 | } 40 | // Report the handler's outcome back to the shedder. 41 | defer func() { 42 | // A missed deadline counts as a failure whether it surfaces as the raw context 43 | // error or as a gRPC DeadlineExceeded status (e.g. from UnaryTimeoutInterceptor). 44 | if isDeadlineExceeded(err) { 45 | promise.Fail() 46 | } else { 47 | // Any other outcome, including other errors, counts as a pass. 48 | sheddingStat.IncrementPass() 49 | promise.Pass() 50 | } 51 | }() 52 | // Run the business handler. 53 | return handler(ctx, req) 54 | } 55 | } 56 | 57 | // isDeadlineExceeded reports whether err is context.DeadlineExceeded (possibly 58 | // wrapped) or a gRPC status error carrying codes.DeadlineExceeded. 59 | func isDeadlineExceeded(err error) bool { 60 | return errors.Is(err, context.DeadlineExceeded) || status.Code(err) == codes.DeadlineExceeded 61 | } 62 | 63 | // ensureSheddingStat lazily creates the package-level sheddingStat under lock. 64 | func ensureSheddingStat() { 65 | lock.Lock() 66 | if sheddingStat == nil { 67 | sheddingStat = load.NewSheddingStat(serviceType) 68 | } 69 | lock.Unlock() 70 | } 71 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/serverinterceptors/statinterceptor.go: -------------------------------------------------------------------------------- 1 | package serverinterceptors 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "time" 7 | 8 | "github.com/zeromicro/go-zero/core/logx" 9 | "github.com/zeromicro/go-zero/core/stat" 10 | "github.com/zeromicro/go-zero/core/syncx" 11 | "github.com/zeromicro/go-zero/core/timex" 12 | "google.golang.org/grpc" 13 | "google.golang.org/grpc/peer" 14 | ) 15 | 16 | const defaultSlowThreshold = time.Millisecond * 500 17 | 18 | var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold) 19 | 20 | // SetSlowThreshold sets the slow threshold. 21 | func SetSlowThreshold(threshold time.Duration) { 22 | slowThreshold.Set(threshold) 23 | } 24 | 25 | // Stat interceptor: adds each call's duration to metrics and logs the call (slow calls via Slowf). 26 | // UnaryStatInterceptor returns a func that uses given metrics to report stats.
27 | func UnaryStatInterceptor(metrics *stat.Metrics) grpc.UnaryServerInterceptor { 28 | return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, 29 | handler grpc.UnaryHandler, 30 | ) (resp interface{}, err error) { 31 | defer handleCrash(func(r interface{}) { 32 | err = toPanicError(r) 33 | }) 34 | 35 | startTime := timex.Now() 36 | defer func() { // deferred after handleCrash, so it runs first and records the duration even on panic 37 | duration := timex.Since(startTime) 38 | metrics.Add(stat.Task{ 39 | Duration: duration, 40 | }) 41 | logDuration(ctx, info.FullMethod, req, duration) 42 | }() 43 | 44 | return handler(ctx, req) 45 | } 46 | } 47 | 48 | func logDuration(ctx context.Context, method string, req interface{}, duration time.Duration) { // Errorf on marshal failure, Slowf above slowThreshold, Infof otherwise 49 | var addr string 50 | client, ok := peer.FromContext(ctx) 51 | if ok { 52 | addr = client.Addr.String() 53 | } 54 | content, err := json.Marshal(req) 55 | if err != nil { 56 | logx.WithContext(ctx).Errorf("%s - %s", addr, err.Error()) 57 | } else if duration > slowThreshold.Load() { 58 | logx.WithContext(ctx).WithDuration(duration).Slowf("[RPC] slowcall - %s - %s - %s", 59 | addr, method, string(content)) 60 | } else { 61 | logx.WithContext(ctx).WithDuration(duration).Infof("%s - %s - %s", addr, method, string(content)) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/serverinterceptors/timeoutinterceptor.go: -------------------------------------------------------------------------------- 1 | package serverinterceptors 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "runtime/debug" 7 | "strings" 8 | "sync" 9 | "time" 10 | 11 | "google.golang.org/grpc" 12 | "google.golang.org/grpc/codes" 13 | "google.golang.org/grpc/status" 14 | ) 15 | 16 | // Server-side timeout interceptor for unary calls. 17 | // UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
18 | func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor { 19 | return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, 20 | handler grpc.UnaryHandler, 21 | ) (interface{}, error) { 22 | // A non-positive timeout disables the deadline, as the client-side TimeoutInterceptor does; 23 | // context.WithTimeout would otherwise hand the handler an already-expired context. 24 | if timeout <= 0 { 25 | return handler(ctx, req) 26 | } 27 | 28 | ctx, cancel := context.WithTimeout(ctx, timeout) 29 | defer cancel() 30 | 31 | var resp interface{} 32 | var err error 33 | var lock sync.Mutex 34 | done := make(chan struct{}) 35 | // create channel with buffer size 1 to avoid goroutine leak 36 | panicChan := make(chan interface{}, 1) 37 | go func() { 38 | defer func() { 39 | if p := recover(); p != nil { 40 | // attach call stack to avoid missing in different goroutine 41 | panicChan <- fmt.Sprintf("%+v\n\n%s", p, strings.TrimSpace(string(debug.Stack()))) 42 | } 43 | }() 44 | 45 | lock.Lock() 46 | defer lock.Unlock() 47 | resp, err = handler(ctx, req) 48 | close(done) 49 | }() 50 | 51 | select { 52 | case p := <-panicChan: 53 | panic(p) 54 | case <-done: 55 | lock.Lock() 56 | defer lock.Unlock() 57 | return resp, err 58 | case <-ctx.Done(): 59 | err := ctx.Err() 60 | 61 | if err == context.Canceled { 62 | err = status.Error(codes.Canceled, err.Error()) 63 | } else if err == context.DeadlineExceeded { 64 | err = status.Error(codes.DeadlineExceeded, err.Error()) 65 | } 66 | return nil, err 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /code/balancer/zrpc/internal/serverinterceptors/tracinginterceptor.go: -------------------------------------------------------------------------------- 1 | package serverinterceptors 2 | 3 | import ( 4 | "context" 5 | 6 | ztrace "github.com/zeromicro/go-zero/core/trace" 7 | "go.opentelemetry.io/otel" 8 | "go.opentelemetry.io/otel/baggage" 9 | "go.opentelemetry.io/otel/codes" 10 | "go.opentelemetry.io/otel/trace" 11 | "google.golang.org/grpc" 12 | gcodes "google.golang.org/grpc/codes" 13 | "google.golang.org/grpc/metadata" 14 | "google.golang.org/grpc/status" 15 | ) 16 | 17 | // Server-side unary tracing interceptor. 18 |
// UnaryTracingInterceptor is a grpc.UnaryServerInterceptor for opentelemetry. 19 | func UnaryTracingInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, 20 | handler grpc.UnaryHandler, 21 | ) (interface{}, error) { 22 | ctx, span := startSpan(ctx, info.FullMethod) 23 | defer span.End() 24 | 25 | ztrace.MessageReceived.Event(ctx, 1, req) 26 | resp, err := handler(ctx, req) 27 | if err != nil { 28 | s, ok := status.FromError(err) 29 | if ok { 30 | span.SetStatus(codes.Error, s.Message()) 31 | span.SetAttributes(ztrace.StatusCodeAttr(s.Code())) 32 | ztrace.MessageSent.Event(ctx, 1, s.Proto()) 33 | } else { 34 | span.SetStatus(codes.Error, err.Error()) 35 | } 36 | return nil, err 37 | } 38 | 39 | span.SetAttributes(ztrace.StatusCodeAttr(gcodes.OK)) 40 | ztrace.MessageSent.Event(ctx, 1, resp) 41 | 42 | return resp, nil 43 | } 44 | 45 | // 链路追踪拦截器(流) 46 | // StreamTracingInterceptor returns a grpc.StreamServerInterceptor for opentelemetry. 47 | func StreamTracingInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, 48 | handler grpc.StreamHandler, 49 | ) error { 50 | ctx, span := startSpan(ss.Context(), info.FullMethod) 51 | defer span.End() 52 | 53 | if err := handler(srv, wrapServerStream(ctx, ss)); err != nil { 54 | s, ok := status.FromError(err) 55 | if ok { 56 | span.SetStatus(codes.Error, s.Message()) 57 | span.SetAttributes(ztrace.StatusCodeAttr(s.Code())) 58 | } else { 59 | span.SetStatus(codes.Error, err.Error()) 60 | } 61 | return err 62 | } 63 | 64 | span.SetAttributes(ztrace.StatusCodeAttr(gcodes.OK)) 65 | return nil 66 | } 67 | 68 | // serverStream wraps around the embedded grpc.ServerStream, 69 | // and intercepts the RecvMsg and SendMsg method call. 
70 | type serverStream struct { 71 | grpc.ServerStream 72 | ctx context.Context 73 | receivedMessageID int 74 | sentMessageID int 75 | } 76 | 77 | func (w *serverStream) Context() context.Context { 78 | return w.ctx 79 | } 80 | 81 | func (w *serverStream) RecvMsg(m interface{}) error { 82 | err := w.ServerStream.RecvMsg(m) 83 | if err == nil { 84 | w.receivedMessageID++ 85 | ztrace.MessageReceived.Event(w.Context(), w.receivedMessageID, m) 86 | } 87 | 88 | return err 89 | } 90 | 91 | func (w *serverStream) SendMsg(m interface{}) error { 92 | err := w.ServerStream.SendMsg(m) 93 | w.sentMessageID++ 94 | ztrace.MessageSent.Event(w.Context(), w.sentMessageID, m) 95 | 96 | return err 97 | } 98 | 99 | func startSpan(ctx context.Context, method string) (context.Context, trace.Span) { 100 | var md metadata.MD 101 | requestMetadata, ok := metadata.FromIncomingContext(ctx) 102 | if ok { 103 | md = requestMetadata.Copy() 104 | } else { 105 | md = metadata.MD{} 106 | } 107 | bags, spanCtx := ztrace.Extract(ctx, otel.GetTextMapPropagator(), &md) 108 | ctx = baggage.ContextWithBaggage(ctx, bags) 109 | tr := otel.Tracer(ztrace.TraceName) 110 | name, attr := ztrace.SpanInfo(method, ztrace.PeerFromCtx(ctx)) 111 | 112 | return tr.Start(trace.ContextWithRemoteSpanContext(ctx, spanCtx), name, 113 | trace.WithSpanKind(trace.SpanKindServer), trace.WithAttributes(attr...)) 114 | } 115 | 116 | // wrapServerStream wraps the given grpc.ServerStream with the given context. 
117 | func wrapServerStream(ctx context.Context, ss grpc.ServerStream) *serverStream { 118 | return &serverStream{ 119 | ServerStream: ss, 120 | ctx: ctx, 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /code/balancer/zrpc/p2c/p2c.go: -------------------------------------------------------------------------------- 1 | package p2c 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "math/rand" 7 | "strings" 8 | "sync" 9 | "sync/atomic" 10 | "time" 11 | 12 | "gozerosource/code/balancer/zrpc/internal/codes" 13 | 14 | "github.com/zeromicro/go-zero/core/syncx" 15 | "github.com/zeromicro/go-zero/core/timex" 16 | 17 | "google.golang.org/grpc/balancer" 18 | "google.golang.org/grpc/balancer/base" 19 | "google.golang.org/grpc/resolver" 20 | ) 21 | 22 | const ( 23 | // Name is the name of p2c balancer. 24 | Name = "p2c_ewma" 25 | 26 | decayTime = int64(time.Second * 10) // default value from finagle(衰退时间) 27 | forcePick = int64(time.Second) // 强制节点选取时间间隔 28 | initSuccess = 1000 // 初始连接健康值 29 | throttleSuccess = initSuccess / 2 // 连接非健康临界值 30 | penalty = int64(math.MaxInt32) // 负载状态最大值 31 | pickTimes = 3 // 随机选取节点次数 32 | logInterval = time.Minute // 输出节点状态间隔时间 33 | ) 34 | 35 | var emptyPickResult balancer.PickResult 36 | 37 | func init() { 38 | balancer.Register(newBuilder()) 39 | } 40 | 41 | type p2cPickerBuilder struct{} 42 | 43 | // gRPC 在节点有更新的时候会调用 Build 方法,传入所有节点信息, 44 | // 我们在这里把每个节点信息用 subConn 结构保存起来。 45 | // 并归并到一起用 p2cPicker 结构保存起来 46 | func (b *p2cPickerBuilder) Build(info base.PickerBuildInfo) balancer.Picker { 47 | readySCs := info.ReadySCs 48 | if len(readySCs) == 0 { 49 | return base.NewErrPicker(balancer.ErrNoSubConnAvailable) 50 | } 51 | 52 | var conns []*subConn 53 | for conn, connInfo := range readySCs { 54 | conns = append(conns, &subConn{ 55 | addr: connInfo.Address, 56 | conn: conn, 57 | success: initSuccess, 58 | }) 59 | } 60 | 61 | return &p2cPicker{ 62 | conns: conns, 63 | r: 
rand.New(rand.NewSource(time.Now().UnixNano())), 64 | stamp: syncx.NewAtomicDuration(), 65 | } 66 | } 67 | 68 | func newBuilder() balancer.Builder { 69 | return base.NewBalancerBuilder(Name, new(p2cPickerBuilder), base.Config{HealthCheck: true}) 70 | } 71 | 72 | type p2cPicker struct { 73 | conns []*subConn // 保存所有节点的信息 74 | r *rand.Rand 75 | stamp *syncx.AtomicDuration 76 | lock sync.Mutex 77 | } 78 | 79 | // 选取节点算法(grpc 自定义负载均衡算法) 80 | func (p *p2cPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { 81 | p.lock.Lock() 82 | defer p.lock.Unlock() 83 | 84 | var chosen *subConn 85 | switch len(p.conns) { 86 | case 0: // 没有节点,返回错误 87 | return emptyPickResult, balancer.ErrNoSubConnAvailable 88 | case 1: // 有一个节点,直接返回这个节点 89 | chosen = p.choose(p.conns[0], nil) 90 | case 2: // 有两个节点,计算负载,返回负载低的节点 91 | chosen = p.choose(p.conns[0], p.conns[1]) 92 | default: // 有多个节点,p2c 挑选两个节点,比较这两个节点的负载,返回负载低的节点 93 | var node1, node2 *subConn 94 | // 3次随机选择两个节点 95 | for i := 0; i < pickTimes; i++ { 96 | a := p.r.Intn(len(p.conns)) 97 | b := p.r.Intn(len(p.conns) - 1) 98 | // 防止选择同一个 99 | if b >= a { 100 | b++ 101 | } 102 | node1 = p.conns[a] 103 | node2 = p.conns[b] 104 | // 如果这次选择的节点达到了健康要求, 就中断选择 105 | if node1.healthy() && node2.healthy() { 106 | break 107 | } 108 | } 109 | 110 | chosen = p.choose(node1, node2) 111 | } 112 | 113 | atomic.AddInt64(&chosen.inflight, 1) 114 | atomic.AddInt64(&chosen.requests, 1) 115 | 116 | return balancer.PickResult{ 117 | SubConn: chosen.conn, 118 | Done: p.buildDoneFunc(chosen), 119 | }, nil 120 | } 121 | 122 | // grpc 请求结束时调用 123 | // 存储本次请求耗时等信息,并计算出 EWMA值 保存起来,供下次请求时计算负载等情况的使用 124 | func (p *p2cPicker) buildDoneFunc(c *subConn) func(info balancer.DoneInfo) { 125 | start := int64(timex.Now()) 126 | return func(info balancer.DoneInfo) { 127 | // 正在处理的请求数减 1 128 | atomic.AddInt64(&c.inflight, -1) 129 | now := timex.Now() 130 | // 保存本次请求结束时的时间点,并取出上次请求时的时间点 131 | last := atomic.SwapInt64(&c.last, int64(now)) 132 | td := int64(now) - 
last 133 | if td < 0 { 134 | td = 0 135 | } 136 | // 用牛顿冷却定律中的衰减函数模型计算EWMA算法中的β值 137 | // 牛顿冷却算法 https://www.ruanyifeng.com/blog/2012/03/ranking_algorithm_newton_s_law_of_cooling.html 138 | w := math.Exp(float64(-td) / float64(decayTime)) 139 | // 保存本次请求的耗时 140 | lag := int64(now) - start 141 | if lag < 0 { 142 | lag = 0 143 | } 144 | olag := atomic.LoadUint64(&c.lag) 145 | if olag == 0 { 146 | w = 0 147 | } 148 | // 计算 EWMA 值 149 | // EWMA(指数加权移动平均算法) https://blog.csdn.net/mzpmzk/article/details/80085929 150 | atomic.StoreUint64(&c.lag, uint64(float64(olag)*w+float64(lag)*(1-w))) 151 | success := initSuccess 152 | if info.Err != nil && !codes.Acceptable(info.Err) { 153 | success = 0 154 | } 155 | // 健康状态 156 | osucc := atomic.LoadUint64(&c.success) 157 | atomic.StoreUint64(&c.success, uint64(float64(osucc)*w+float64(success)*(1-w))) 158 | 159 | stamp := p.stamp.Load() 160 | if now-stamp >= logInterval { 161 | if p.stamp.CompareAndSwap(stamp, now) { 162 | p.logStats() 163 | } 164 | } 165 | } 166 | } 167 | 168 | // 比较两个节点的负载情况,选择负载低的 169 | func (p *p2cPicker) choose(c1, c2 *subConn) *subConn { 170 | start := int64(timex.Now()) 171 | if c2 == nil { 172 | atomic.StoreInt64(&c1.pick, start) 173 | return c1 174 | } 175 | 176 | // c2 一直是负载较大的 node 177 | if c1.load() > c2.load() { 178 | c1, c2 = c2, c1 // 交换变量,方便判断 179 | } 180 | 181 | pick := atomic.LoadInt64(&c2.pick) 182 | // 如果(本次被选中的时间 - 上次被选中的时间 > forcePick && 本次与上次时间点不同) 183 | // return 负载较大 node 184 | if start-pick > forcePick && atomic.CompareAndSwapInt64(&c2.pick, pick, start) { 185 | return c2 186 | } 187 | 188 | // 返回负载较小 node 189 | atomic.StoreInt64(&c1.pick, start) 190 | return c1 191 | } 192 | 193 | // 输出所有节点状态信息 194 | func (p *p2cPicker) logStats() { 195 | var stats []string 196 | 197 | p.lock.Lock() 198 | defer p.lock.Unlock() 199 | 200 | for _, conn := range p.conns { 201 | stats = append(stats, fmt.Sprintf("conn: %s, load: %d, reqs: %d", 202 | conn.addr.Addr, conn.load(), atomic.SwapInt64(&conn.requests, 
0))) 203 | } 204 | 205 | fmt.Printf("p2c - %s", strings.Join(stats, "; ")) 206 | } 207 | 208 | type subConn struct { 209 | lag uint64 // 用来保存 ewma 值 210 | inflight int64 // 用在保存当前节点正在处理的请求总数 211 | success uint64 // 用来标识一段时间内此连接的健康状态 212 | requests int64 // 用来保存请求总数 213 | last int64 // 用来保存上一次请求耗时, 用于计算 ewma 值 214 | pick int64 // 保存上一次被选中的时间点 215 | addr resolver.Address 216 | conn balancer.SubConn 217 | } 218 | 219 | // 节点健康情况 220 | func (c *subConn) healthy() bool { 221 | return atomic.LoadUint64(&c.success) > throttleSuccess 222 | } 223 | 224 | // 节点负载情况 225 | func (c *subConn) load() int64 { 226 | // plus one to avoid multiply zero 227 | // 通过 EWMA 计算节点的负载情况; 加 1 是为了避免为 0 的情况 228 | lag := int64(math.Sqrt(float64(atomic.LoadUint64(&c.lag) + 1))) 229 | load := lag * (atomic.LoadInt64(&c.inflight) + 1) 230 | if load == 0 { 231 | return penalty 232 | } 233 | 234 | return load 235 | } 236 | -------------------------------------------------------------------------------- /code/balancer/zrpc/resolver/internal/directbuilder.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "strings" 5 | 6 | "google.golang.org/grpc/resolver" 7 | ) 8 | 9 | type directBuilder struct{} 10 | 11 | func (d *directBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) ( 12 | resolver.Resolver, error, 13 | ) { 14 | var addrs []resolver.Address 15 | endpoints := strings.FieldsFunc(target.Endpoint, func(r rune) bool { 16 | return r == EndpointSepChar 17 | }) 18 | 19 | for _, val := range subset(endpoints, subsetSize) { 20 | addrs = append(addrs, resolver.Address{ 21 | Addr: val, 22 | }) 23 | } 24 | if err := cc.UpdateState(resolver.State{ 25 | Addresses: addrs, 26 | }); err != nil { 27 | return nil, err 28 | } 29 | 30 | return &nopResolver{cc: cc}, nil 31 | } 32 | 33 | func (d *directBuilder) Scheme() string { 34 | return DirectScheme 35 | } 36 | 
-------------------------------------------------------------------------------- /code/balancer/zrpc/resolver/internal/discovbuilder.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "strings" 5 | 6 | "github.com/zeromicro/go-zero/core/discov" 7 | "github.com/zeromicro/go-zero/core/logx" 8 | 9 | "google.golang.org/grpc/resolver" 10 | ) 11 | 12 | type discovBuilder struct{} 13 | 14 | func (b *discovBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) ( 15 | resolver.Resolver, error, 16 | ) { 17 | hosts := strings.FieldsFunc(target.Authority, func(r rune) bool { 18 | return r == EndpointSepChar 19 | }) 20 | // 获取服务列表 21 | sub, err := discov.NewSubscriber(hosts, target.Endpoint) 22 | if err != nil { 23 | return nil, err 24 | } 25 | 26 | update := func() { 27 | var addrs []resolver.Address 28 | for _, val := range subset(sub.Values(), subsetSize) { 29 | addrs = append(addrs, resolver.Address{ 30 | Addr: val, 31 | }) 32 | } 33 | // 调用UpdateState方法更新 34 | if err := cc.UpdateState(resolver.State{ 35 | Addresses: addrs, 36 | }); err != nil { 37 | logx.Errorf("%s", err) 38 | } 39 | } 40 | // 添加监听,当服务地址发生变化会触发更新 41 | sub.AddListener(update) 42 | // 更新服务列表 43 | update() 44 | 45 | return &nopResolver{cc: cc}, nil 46 | } 47 | 48 | func (b *discovBuilder) Scheme() string { 49 | return DiscovScheme 50 | } 51 | -------------------------------------------------------------------------------- /code/balancer/zrpc/resolver/internal/etcdbuilder.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | type etcdBuilder struct { 4 | discovBuilder 5 | } 6 | 7 | func (b *etcdBuilder) Scheme() string { 8 | return EtcdScheme 9 | } 10 | -------------------------------------------------------------------------------- /code/balancer/zrpc/resolver/internal/kube/eventhandler.go: 
-------------------------------------------------------------------------------- 1 | package kube 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/zeromicro/go-zero/core/lang" 7 | "github.com/zeromicro/go-zero/core/logx" 8 | v1 "k8s.io/api/core/v1" 9 | ) 10 | 11 | // EventHandler is ResourceEventHandler implementation. 12 | type EventHandler struct { 13 | update func([]string) 14 | endpoints map[string]lang.PlaceholderType 15 | lock sync.Mutex 16 | } 17 | 18 | // NewEventHandler returns an EventHandler. 19 | func NewEventHandler(update func([]string)) *EventHandler { 20 | return &EventHandler{ 21 | update: update, 22 | endpoints: make(map[string]lang.PlaceholderType), 23 | } 24 | } 25 | 26 | // OnAdd handles the endpoints add events. 27 | func (h *EventHandler) OnAdd(obj interface{}) { 28 | endpoints, ok := obj.(*v1.Endpoints) 29 | if !ok { 30 | logx.Errorf("%v is not an object with type *v1.Endpoints", obj) 31 | return 32 | } 33 | 34 | h.lock.Lock() 35 | defer h.lock.Unlock() 36 | 37 | var changed bool 38 | for _, sub := range endpoints.Subsets { 39 | for _, point := range sub.Addresses { 40 | if _, ok := h.endpoints[point.IP]; !ok { 41 | h.endpoints[point.IP] = lang.Placeholder 42 | changed = true 43 | } 44 | } 45 | } 46 | 47 | if changed { 48 | h.notify() 49 | } 50 | } 51 | 52 | // OnDelete handles the endpoints delete events. 53 | func (h *EventHandler) OnDelete(obj interface{}) { 54 | endpoints, ok := obj.(*v1.Endpoints) 55 | if !ok { 56 | logx.Errorf("%v is not an object with type *v1.Endpoints", obj) 57 | return 58 | } 59 | 60 | h.lock.Lock() 61 | defer h.lock.Unlock() 62 | 63 | var changed bool 64 | for _, sub := range endpoints.Subsets { 65 | for _, point := range sub.Addresses { 66 | if _, ok := h.endpoints[point.IP]; ok { 67 | delete(h.endpoints, point.IP) 68 | changed = true 69 | } 70 | } 71 | } 72 | 73 | if changed { 74 | h.notify() 75 | } 76 | } 77 | 78 | // OnUpdate handles the endpoints update events. 
79 | func (h *EventHandler) OnUpdate(oldObj, newObj interface{}) { 80 | oldEndpoints, ok := oldObj.(*v1.Endpoints) 81 | if !ok { 82 | logx.Errorf("%v is not an object with type *v1.Endpoints", oldObj) 83 | return 84 | } 85 | 86 | newEndpoints, ok := newObj.(*v1.Endpoints) 87 | if !ok { 88 | logx.Errorf("%v is not an object with type *v1.Endpoints", newObj) 89 | return 90 | } 91 | 92 | if oldEndpoints.ResourceVersion == newEndpoints.ResourceVersion { 93 | return 94 | } 95 | 96 | h.Update(newEndpoints) 97 | } 98 | 99 | // Update updates the endpoints. 100 | func (h *EventHandler) Update(endpoints *v1.Endpoints) { 101 | h.lock.Lock() 102 | defer h.lock.Unlock() 103 | 104 | old := h.endpoints 105 | h.endpoints = make(map[string]lang.PlaceholderType) 106 | for _, sub := range endpoints.Subsets { 107 | for _, point := range sub.Addresses { 108 | h.endpoints[point.IP] = lang.Placeholder 109 | } 110 | } 111 | 112 | if diff(old, h.endpoints) { 113 | h.notify() 114 | } 115 | } 116 | 117 | func (h *EventHandler) notify() { 118 | var targets []string 119 | 120 | for k := range h.endpoints { 121 | targets = append(targets, k) 122 | } 123 | 124 | h.update(targets) 125 | } 126 | 127 | func diff(o, n map[string]lang.PlaceholderType) bool { 128 | if len(o) != len(n) { 129 | return true 130 | } 131 | 132 | for k := range o { 133 | if _, ok := n[k]; !ok { 134 | return true 135 | } 136 | } 137 | 138 | return false 139 | } 140 | -------------------------------------------------------------------------------- /code/balancer/zrpc/resolver/internal/kube/targetparser.go: -------------------------------------------------------------------------------- 1 | package kube 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | 8 | "google.golang.org/grpc/resolver" 9 | ) 10 | 11 | const ( 12 | colon = ":" 13 | defaultNamespace = "default" 14 | ) 15 | 16 | var emptyService Service 17 | 18 | // Service represents a service with namespace, name and port. 
19 | type Service struct { 20 | Namespace string 21 | Name string 22 | Port int 23 | } 24 | 25 | // ParseTarget parses the resolver.Target. 26 | func ParseTarget(target resolver.Target) (Service, error) { 27 | var service Service 28 | service.Namespace = target.Authority 29 | if len(service.Namespace) == 0 { 30 | service.Namespace = defaultNamespace 31 | } 32 | 33 | segs := strings.SplitN(target.Endpoint, colon, 2) 34 | if len(segs) < 2 { 35 | return emptyService, fmt.Errorf("bad endpoint: %s", target.Endpoint) 36 | } 37 | 38 | service.Name = segs[0] 39 | port, err := strconv.Atoi(segs[1]) 40 | if err != nil { 41 | return emptyService, err 42 | } 43 | 44 | service.Port = port 45 | 46 | return service, nil 47 | } 48 | -------------------------------------------------------------------------------- /code/balancer/zrpc/resolver/internal/kubebuilder.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "gozerosource/code/balancer/zrpc/resolver/internal/kube" 9 | 10 | "github.com/zeromicro/go-zero/core/logx" 11 | "github.com/zeromicro/go-zero/core/proc" 12 | "github.com/zeromicro/go-zero/core/threading" 13 | 14 | "google.golang.org/grpc/resolver" 15 | v1 "k8s.io/apimachinery/pkg/apis/meta/v1" 16 | "k8s.io/client-go/informers" 17 | "k8s.io/client-go/kubernetes" 18 | "k8s.io/client-go/rest" 19 | ) 20 | 21 | const ( 22 | resyncInterval = 5 * time.Minute 23 | nameSelector = "metadata.name=" 24 | ) 25 | 26 | type kubeBuilder struct{} 27 | 28 | func (b *kubeBuilder) Build(target resolver.Target, cc resolver.ClientConn, 29 | opts resolver.BuildOptions, 30 | ) (resolver.Resolver, error) { 31 | svc, err := kube.ParseTarget(target) 32 | if err != nil { 33 | return nil, err 34 | } 35 | 36 | config, err := rest.InClusterConfig() 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | cs, err := kubernetes.NewForConfig(config) 42 | if err != nil { 43 | return nil, err 
44 | } 45 | 46 | handler := kube.NewEventHandler(func(endpoints []string) { 47 | var addrs []resolver.Address 48 | for _, val := range subset(endpoints, subsetSize) { 49 | addrs = append(addrs, resolver.Address{ 50 | Addr: fmt.Sprintf("%s:%d", val, svc.Port), 51 | }) 52 | } 53 | 54 | if err := cc.UpdateState(resolver.State{ 55 | Addresses: addrs, 56 | }); err != nil { 57 | logx.Errorf("%s", err) 58 | } 59 | }) 60 | inf := informers.NewSharedInformerFactoryWithOptions(cs, resyncInterval, 61 | informers.WithNamespace(svc.Namespace), 62 | informers.WithTweakListOptions(func(options *v1.ListOptions) { 63 | options.FieldSelector = nameSelector + svc.Name 64 | })) 65 | in := inf.Core().V1().Endpoints() 66 | in.Informer().AddEventHandler(handler) 67 | threading.GoSafe(func() { 68 | inf.Start(proc.Done()) 69 | }) 70 | 71 | endpoints, err := cs.CoreV1().Endpoints(svc.Namespace).Get(context.Background(), svc.Name, v1.GetOptions{}) 72 | if err != nil { 73 | return nil, err 74 | } 75 | 76 | handler.Update(endpoints) 77 | 78 | return &nopResolver{cc: cc}, nil 79 | } 80 | 81 | func (b *kubeBuilder) Scheme() string { 82 | return KubernetesScheme 83 | } 84 | -------------------------------------------------------------------------------- /code/balancer/zrpc/resolver/internal/resolver.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "fmt" 5 | 6 | "google.golang.org/grpc/resolver" 7 | ) 8 | 9 | const ( 10 | // DirectScheme stands for direct scheme. 11 | DirectScheme = "direct" 12 | // DiscovScheme stands for discov scheme. 13 | DiscovScheme = "discov" 14 | // EtcdScheme stands for etcd scheme. 15 | EtcdScheme = "etcd" 16 | // KubernetesScheme stands for k8s scheme. 17 | KubernetesScheme = "k8s" 18 | // EndpointSepChar is the separator cha in endpoints. 19 | EndpointSepChar = ',' 20 | 21 | subsetSize = 32 22 | ) 23 | 24 | var ( 25 | // EndpointSep is the separator string in endpoints. 
26 | EndpointSep = fmt.Sprintf("%c", EndpointSepChar) 27 | 28 | directResolverBuilder directBuilder 29 | discovResolverBuilder discovBuilder 30 | etcdResolverBuilder etcdBuilder 31 | k8sResolverBuilder kubeBuilder 32 | ) 33 | 34 | // RegisterResolver registers the direct and discov schemes to the resolver. 35 | // RegisterResolver 注册自定义的Resolver 36 | func RegisterResolver() { 37 | resolver.Register(&directResolverBuilder) 38 | resolver.Register(&discovResolverBuilder) 39 | resolver.Register(&etcdResolverBuilder) 40 | resolver.Register(&k8sResolverBuilder) 41 | } 42 | 43 | type nopResolver struct { 44 | cc resolver.ClientConn 45 | } 46 | 47 | func (r *nopResolver) Close() { 48 | } 49 | 50 | func (r *nopResolver) ResolveNow(options resolver.ResolveNowOptions) { 51 | } 52 | -------------------------------------------------------------------------------- /code/balancer/zrpc/resolver/internal/subset.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import "math/rand" 4 | 5 | func subset(set []string, sub int) []string { 6 | rand.Shuffle(len(set), func(i, j int) { 7 | set[i], set[j] = set[j], set[i] 8 | }) 9 | if len(set) <= sub { 10 | return set 11 | } 12 | 13 | return set[:sub] 14 | } 15 | -------------------------------------------------------------------------------- /code/balancer/zrpc/resolver/register.go: -------------------------------------------------------------------------------- 1 | package resolver 2 | 3 | import ( 4 | "gozerosource/code/balancer/zrpc/resolver/internal" 5 | ) 6 | 7 | // Register registers schemes defined zrpc. 8 | // Keep it in a separated package to let third party register manually. 
9 | func Register() { 10 | internal.RegisterResolver() 11 | } 12 | -------------------------------------------------------------------------------- /code/balancer/zrpc/resolver/target.go: -------------------------------------------------------------------------------- 1 | package resolver 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "gozerosource/code/balancer/zrpc/resolver/internal" 8 | ) 9 | 10 | // BuildDirectTarget returns a string that represents the given endpoints with direct schema. 11 | // BuildDiscovTarget 构建target 12 | func BuildDirectTarget(endpoints []string) string { 13 | return fmt.Sprintf("%s:///%s", internal.DirectScheme, 14 | strings.Join(endpoints, internal.EndpointSep)) 15 | } 16 | 17 | // BuildDiscovTarget returns a string that represents the given endpoints with discov schema. 18 | func BuildDiscovTarget(endpoints []string, key string) string { 19 | return fmt.Sprintf("%s://%s/%s", internal.DiscovScheme, 20 | strings.Join(endpoints, internal.EndpointSep), key) 21 | } 22 | -------------------------------------------------------------------------------- /code/balancer/zrpc/server.go: -------------------------------------------------------------------------------- 1 | package zrpc 2 | 3 | import ( 4 | "log" 5 | "time" 6 | 7 | "gozerosource/code/balancer/zrpc/internal" 8 | "gozerosource/code/balancer/zrpc/internal/auth" 9 | "gozerosource/code/balancer/zrpc/internal/serverinterceptors" 10 | 11 | "github.com/zeromicro/go-zero/core/load" 12 | "github.com/zeromicro/go-zero/core/logx" 13 | "github.com/zeromicro/go-zero/core/stat" 14 | "google.golang.org/grpc" 15 | ) 16 | 17 | // A RpcServer is a rpc server. 18 | type RpcServer struct { 19 | server internal.Server 20 | register internal.RegisterFn 21 | } 22 | 23 | // MustNewServer returns a RpcSever, exits on any error. 
24 | func MustNewServer(c RpcServerConf, register internal.RegisterFn) *RpcServer { 25 | server, err := NewServer(c, register) 26 | if err != nil { 27 | log.Fatal(err) 28 | } 29 | 30 | return server 31 | } 32 | 33 | // NewServer returns a RpcServer. 34 | func NewServer(c RpcServerConf, register internal.RegisterFn) (*RpcServer, error) { 35 | var err error 36 | if err = c.Validate(); err != nil { 37 | return nil, err 38 | } 39 | 40 | var server internal.Server 41 | // 初始化服务指标监听器 42 | metrics := stat.NewMetrics(c.ListenOn) 43 | serverOptions := []internal.ServerOption{ 44 | internal.WithMetrics(metrics), 45 | } 46 | 47 | if c.HasEtcd() { 48 | // 如果配置 etcd 服务则加载 rpc 发布服务 用于服务发现 49 | server, err = internal.NewRpcPubServer(c.Etcd, c.ListenOn, serverOptions...) 50 | if err != nil { 51 | return nil, err 52 | } 53 | } else { 54 | // 直接使用 rpc 服务 55 | server = internal.NewRpcServer(c.ListenOn, serverOptions...) 56 | } 57 | 58 | server.SetName(c.Name) 59 | if err = setupInterceptors(server, c, metrics); err != nil { 60 | return nil, err 61 | } 62 | 63 | rpcServer := &RpcServer{ 64 | server: server, 65 | register: register, 66 | } 67 | if err = c.SetUp(); err != nil { 68 | return nil, err 69 | } 70 | 71 | return rpcServer, nil 72 | } 73 | 74 | // AddOptions adds given options. 75 | // 添加配置 76 | func (rs *RpcServer) AddOptions(options ...grpc.ServerOption) { 77 | rs.server.AddOptions(options...) 78 | } 79 | 80 | // AddStreamInterceptors adds given stream interceptors. 81 | // 添加 rpc 数据流拦截器 82 | func (rs *RpcServer) AddStreamInterceptors(interceptors ...grpc.StreamServerInterceptor) { 83 | rs.server.AddStreamInterceptors(interceptors...) 84 | } 85 | 86 | // AddUnaryInterceptors adds given unary interceptors. 87 | // 添加拦截器 88 | func (rs *RpcServer) AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor) { 89 | rs.server.AddUnaryInterceptors(interceptors...) 90 | } 91 | 92 | // Start starts the RpcServer. 93 | // Graceful shutdown is enabled by default. 
94 | // Use proc.SetTimeToForceQuit to customize the graceful shutdown period. 95 | // 启动 Rpc 服务 96 | // 已默认开启服务优雅关闭项 97 | // 可使用 proc.SetTimeToForceQuit 配置项来自定义服务优雅关闭项 98 | func (rs *RpcServer) Start() { 99 | if err := rs.server.Start(rs.register); err != nil { 100 | logx.Error(err) 101 | panic(err) 102 | } 103 | } 104 | 105 | // Stop stops the RpcServer. 106 | func (rs *RpcServer) Stop() { 107 | logx.Close() 108 | } 109 | 110 | // SetServerSlowThreshold sets the slow threshold on server side. 111 | // 设置服务器端的慢阈值 112 | func SetServerSlowThreshold(threshold time.Duration) { 113 | serverinterceptors.SetSlowThreshold(threshold) 114 | } 115 | 116 | // 设置拦截器 117 | func setupInterceptors(server internal.Server, c RpcServerConf, metrics *stat.Metrics) error { 118 | if c.CpuThreshold > 0 { 119 | // 添加服务降级拦截器 120 | shedder := load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold)) 121 | server.AddUnaryInterceptors(serverinterceptors.UnarySheddingInterceptor(shedder, metrics)) 122 | } 123 | 124 | if c.Timeout > 0 { 125 | // 添加服务超时拦截器 126 | server.AddUnaryInterceptors(serverinterceptors.UnaryTimeoutInterceptor( 127 | time.Duration(c.Timeout) * time.Millisecond)) 128 | } 129 | 130 | if c.Auth { 131 | // 初始化权限验证服务 132 | authenticator, err := auth.NewAuthenticator(c.Redis.NewRedis(), c.Redis.Key, c.StrictControl) 133 | if err != nil { 134 | return err 135 | } 136 | 137 | // 添加权限验证服务(rpc 数据流) 138 | server.AddStreamInterceptors(serverinterceptors.StreamAuthorizeInterceptor(authenticator)) 139 | // 添加权限验证服务 140 | server.AddUnaryInterceptors(serverinterceptors.UnaryAuthorizeInterceptor(authenticator)) 141 | } 142 | 143 | return nil 144 | } 145 | -------------------------------------------------------------------------------- /code/breaker/breaker_test.go: -------------------------------------------------------------------------------- 1 | package breaker 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | "gozerosource/code/core/breaker" 9 | ) 10 | 11 | func 
Test_Breaker(t *testing.T) { 12 | b := breaker.NewBreaker() 13 | for i := 0; i < 100; i++ { 14 | allow, err := b.Allow() 15 | if err != nil { 16 | fmt.Println("err", err) 17 | break 18 | } 19 | if i < 10 { 20 | allow.Reject() 21 | // time.Sleep(2000 * time.Millisecond) 22 | time.Sleep(20 * time.Millisecond) 23 | } else { 24 | allow.Accept() 25 | } 26 | } 27 | fmt.Println(b.GB.History()) 28 | } 29 | 30 | func Test_Beaker2(t *testing.T) { 31 | b := breaker.NewBreaker() 32 | for i := 0; i < 100; i++ { 33 | err := b.DoWithAcceptable( 34 | func() error { 35 | if i < 10 { 36 | time.Sleep(20 * time.Millisecond) 37 | // return errors.New(">>>>>>>>>") 38 | } 39 | return nil 40 | }, 41 | func(err error) bool { 42 | // fmt.Println("err", err) 43 | return i >= 8 44 | }, 45 | ) 46 | if err != nil { 47 | fmt.Println("err", err) 48 | break 49 | } 50 | } 51 | fmt.Println(b.GB.History()) 52 | } 53 | 54 | func Benchmark_BrewkerSerial(b *testing.B) { 55 | bk := breaker.NewBreaker() 56 | for i := 0; i < b.N; i++ { 57 | allow, err := bk.Allow() 58 | if err != nil { 59 | fmt.Println("err", err) 60 | break 61 | } 62 | if i%2 == 0 { 63 | allow.Accept() 64 | } else { 65 | allow.Reject() 66 | } 67 | } 68 | fmt.Println(bk.GB.History()) 69 | } 70 | 71 | func Benchmark_BrewkerParallel(b *testing.B) { 72 | // 测试一个对象或者函数在多线程的场景下面是否安全 73 | b.RunParallel(func(pb *testing.PB) { 74 | bk := breaker.NewBreaker() 75 | i := 0 76 | for pb.Next() { 77 | i++ 78 | allow, err := bk.Allow() 79 | if err != nil { 80 | fmt.Println("err", err) 81 | break 82 | } 83 | if i%100 == 0 { 84 | allow.Accept() 85 | } else { 86 | allow.Reject() 87 | } 88 | // time.Sleep(20 * time.Millisecond) 89 | } 90 | fmt.Println(bk.GB.History()) 91 | }) 92 | } 93 | -------------------------------------------------------------------------------- /code/breaker/googlebreaker_test.go: -------------------------------------------------------------------------------- 1 | package breaker 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 
8 | "gozerosource/code/core/breaker" 9 | ) 10 | 11 | // 简单场景直接判断对象是否被熔断,执行请求后必须需手动上报执行结果至熔断器。 12 | func Test_GoogleBreaker(t *testing.T) { 13 | gb := breaker.NewGoogleBreaker() 14 | for i := 0; i < 100; i++ { 15 | allow, err := gb.Allow() 16 | if err != nil { 17 | fmt.Println("err", err) 18 | break 19 | } 20 | if i < 10 { 21 | allow.Reject() 22 | // time.Sleep(2000 * time.Millisecond) 23 | time.Sleep(20 * time.Millisecond) 24 | } else { 25 | allow.Accept() 26 | } 27 | } 28 | fmt.Println(gb.History()) 29 | } 30 | 31 | // 复杂场景下支持自定义快速失败,自定义判定请求是否成功的熔断方法,自动上报执行结果至熔断器。 32 | func Test_GoogleBreaker2(t *testing.T) { 33 | gb := breaker.NewGoogleBreaker() 34 | for i := 0; i < 100; i++ { 35 | err := gb.DoReq( 36 | func() error { 37 | if i < 10 { 38 | time.Sleep(20 * time.Millisecond) 39 | // return errors.New(">>>>>>>>>") 40 | } 41 | return nil 42 | }, 43 | func(err error) error { 44 | fmt.Println("err", err) 45 | return nil 46 | }, 47 | func(err error) bool { 48 | fmt.Println("err", err) 49 | return i >= 8 50 | }, 51 | ) 52 | if err != nil { 53 | fmt.Println("err", err) 54 | break 55 | } 56 | } 57 | fmt.Println(gb.History()) 58 | } 59 | -------------------------------------------------------------------------------- /code/breaker/rollingwindow_test.go: -------------------------------------------------------------------------------- 1 | package breaker 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | "gozerosource/code/core/collection" 9 | ) 10 | 11 | const ( 12 | // 250ms for bucket duration 13 | windowSec = time.Second // 窗口时间 14 | buckets = 40 // bucket 数量 15 | k = 1.5 // 倍值(越小越敏感) 16 | ) 17 | 18 | func Test_RollingWindow(t *testing.T) { 19 | bucketDuration := time.Duration(int64(windowSec) / int64(buckets)) 20 | // st := collection.NewRollingWindow(buckets, bucketDuration, collection.IgnoreCurrentBucket()) 21 | st := collection.NewRollingWindow(buckets, bucketDuration) 22 | for i := 0; i < 100; i++ { 23 | time.Sleep(25 * time.Millisecond) 24 | 
st.Add(float64(i)) 25 | // if i < 50 { 26 | // st.Add(float64(i)) 27 | // } else { 28 | // st.Add(float64(0)) 29 | // } 30 | } 31 | var accepts int64 32 | var total int64 33 | st.Reduce(func(b *collection.Bucket) { 34 | accepts += int64(b.Sum) 35 | total += b.Count 36 | }) 37 | fmt.Println("accepts", accepts) 38 | fmt.Println("total", total) 39 | } 40 | -------------------------------------------------------------------------------- /code/core/breaker/breaker.go: -------------------------------------------------------------------------------- 1 | package breaker 2 | 3 | type ( 4 | // 自定义判定执行结果 5 | Acceptable func(err error) bool 6 | // 手动回调 7 | Promise interface { 8 | // Accept tells the Breaker that the call is successful. 9 | // 请求成功 10 | Accept() 11 | // Reject tells the Breaker that the call is failed. 12 | // 请求失败 13 | Reject(reason string) 14 | } 15 | Breaker interface { 16 | // 熔断器名称 17 | Name() string 18 | 19 | // 熔断方法,执行请求时必须手动上报执行结果 20 | // 适用于简单无需自定义快速失败,无需自定义判定请求结果的场景 21 | // 相当于手动挡。。。 22 | Allow() (Promise, error) 23 | 24 | // 熔断方法,自动上报执行结果 25 | // 自动挡。。。 26 | Do(req func() error) error 27 | 28 | // 熔断方法 29 | // acceptable - 支持自定义判定执行结果 30 | DoWithAcceptable(req func() error, acceptable Acceptable) error 31 | 32 | // 熔断方法 33 | // fallback - 支持自定义快速失败 34 | DoWithFallback(req func() error, fallback func(err error) error) error 35 | 36 | // 熔断方法 37 | // fallback - 支持自定义快速失败 38 | // acceptable - 支持自定义判定执行结果 39 | DoWithFallbackAcceptable(req func() error, fallback func(err error) error, acceptable Acceptable) error 40 | } 41 | 42 | internalPromise interface { 43 | Accept() 44 | Reject() 45 | } 46 | ) 47 | 48 | func defaultAcceptable(err error) bool { 49 | return err == nil 50 | } 51 | 52 | type breaker struct { 53 | GB *googleBreaker 54 | } 55 | 56 | func NewBreaker() *breaker { 57 | return &breaker{ 58 | GB: NewGoogleBreaker(), 59 | } 60 | } 61 | 62 | func (b *breaker) Name() string { 63 | return "" 64 | } 65 | 66 | func (b *breaker) Allow() 
(internalPromise, error) { 67 | return b.GB.Allow() 68 | } 69 | 70 | func (b *breaker) Do(req func() error) error { 71 | return b.GB.DoReq(req, nil, defaultAcceptable) 72 | } 73 | 74 | func (b *breaker) DoWithAcceptable(req func() error, acceptable Acceptable) error { 75 | return b.GB.DoReq(req, nil, acceptable) 76 | } 77 | 78 | func (b *breaker) DoWithFallback(req func() error, fallback func(err error) error) error { 79 | return b.GB.DoReq(req, fallback, defaultAcceptable) 80 | } 81 | 82 | func (b *breaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error, acceptable Acceptable) error { 83 | return b.GB.DoReq(req, fallback, acceptable) 84 | } 85 | -------------------------------------------------------------------------------- /code/core/breaker/googlebreaker.go: -------------------------------------------------------------------------------- 1 | package breaker 2 | 3 | import ( 4 | "errors" 5 | "math" 6 | "math/rand" 7 | "sync" 8 | "time" 9 | 10 | "gozerosource/code/core/collection" 11 | ) 12 | 13 | const ( 14 | // 250ms for bucket duration 15 | windowSec = time.Second * 10 // 窗口时间 16 | buckets = 40 // bucket 数量 17 | k = 1.5 // 倍值(越小越敏感) 18 | protection = 5 19 | ) 20 | 21 | // A Proba is used to test if true on given probability. 22 | type Proba struct { 23 | // rand.New(...) returns a non thread safe object 24 | r *rand.Rand 25 | lock sync.Mutex 26 | } 27 | 28 | // NewProba returns a Proba. 29 | func NewProba() *Proba { 30 | return &Proba{ 31 | r: rand.New(rand.NewSource(time.Now().UnixNano())), 32 | } 33 | } 34 | 35 | // 检查给定概率是否为真 36 | // TrueOnProba checks if true on given probability. 37 | func (p *Proba) TrueOnProba(proba float64) (truth bool) { 38 | p.lock.Lock() 39 | truth = p.r.Float64() < proba 40 | p.lock.Unlock() 41 | return 42 | } 43 | 44 | // ErrServiceUnavailable is returned when the Breaker state is open. 
45 | var ErrServiceUnavailable = errors.New("circuit breaker is open") 46 | 47 | // googleBreaker is a netflixBreaker pattern from google. 48 | // see Client-Side Throttling section in https://landing.google.com/sre/sre-book/chapters/handling-overload/ 49 | type googleBreaker struct { 50 | k float64 51 | stat *collection.RollingWindow 52 | proba *Proba 53 | } 54 | 55 | func NewGoogleBreaker() *googleBreaker { 56 | bucketDuration := time.Duration(int64(windowSec) / int64(buckets)) 57 | st := collection.NewRollingWindow(buckets, bucketDuration) 58 | return &googleBreaker{ 59 | stat: st, 60 | k: k, 61 | proba: NewProba(), 62 | } 63 | } 64 | 65 | // 判断是否触发熔断 66 | func (b *googleBreaker) accept() error { 67 | // 获取最近一段时间的统计数据 68 | accepts, total := b.History() 69 | // 计算动态熔断概率 70 | weightedAccepts := b.k * float64(accepts) 71 | // Google Sre过载保护算法 https://landing.google.com/sre/sre-book/chapters/handling-overload/#eq2101 72 | dropRatio := math.Max(0, (float64(total-protection)-weightedAccepts)/float64(total+1)) 73 | if dropRatio <= 0 { 74 | return nil 75 | } 76 | // 随机产生0.0-1.0之间的随机数与上面计算出来的熔断概率相比较 77 | // 如果随机数比熔断概率小则进行熔断 78 | if b.proba.TrueOnProba(dropRatio) { 79 | return ErrServiceUnavailable 80 | } 81 | 82 | return nil 83 | } 84 | 85 | // 熔断方法,执行请求时必须手动上报执行结果 86 | // 适用于简单无需自定义快速失败,无需自定义判定请求结果的场景 87 | // 相当于手动挡。。。 88 | // 返回一个promise异步回调对象,可由开发者自行决定是否上报结果到熔断器 89 | func (b *googleBreaker) Allow() (internalPromise, error) { 90 | if err := b.accept(); err != nil { 91 | return nil, err 92 | } 93 | 94 | return googlePromise{ 95 | b: b, 96 | }, nil 97 | } 98 | 99 | // 熔断方法,自动上报执行结果 100 | // 自动挡。。。 101 | // req 熔断对象方法 102 | // fallback 自定义快速失败函数,可对熔断产生的err进行包装后返回 103 | // acceptable 对本次未熔断时执行请求的结果进行自定义的判定,比如可以针对http.code,rpc.code,body.code 104 | func (b *googleBreaker) DoReq(req func() error, fallback func(err error) error, acceptable Acceptable) error { 105 | // 判定是否熔断 106 | if err := b.accept(); err != nil { 107 | // 熔断中,如果有自定义的fallback则执行 108 | if fallback != nil { 109 
| return fallback(err) 110 | } 111 | 112 | return err 113 | } 114 | // 如果执行req()过程发生了panic,依然判定本次执行失败上报至熔断器 115 | defer func() { 116 | if e := recover(); e != nil { 117 | b.markFailure() 118 | panic(e) 119 | } 120 | }() 121 | // 执行请求 122 | err := req() 123 | // 判定请求成功 124 | if acceptable(err) { 125 | b.markSuccess() 126 | } else { 127 | b.markFailure() 128 | } 129 | 130 | return err 131 | } 132 | 133 | // 上报成功 134 | func (b *googleBreaker) markSuccess() { 135 | b.stat.Add(1) 136 | } 137 | 138 | // 上报失败 139 | func (b *googleBreaker) markFailure() { 140 | b.stat.Add(0) 141 | } 142 | 143 | // 统计数据 144 | // accepts 成功次数 145 | // total 总次数 146 | func (b *googleBreaker) History() (accepts, total int64) { 147 | b.stat.Reduce(func(b *collection.Bucket) { 148 | accepts += int64(b.Sum) 149 | total += b.Count 150 | }) 151 | 152 | return 153 | } 154 | 155 | type googlePromise struct { 156 | b *googleBreaker 157 | } 158 | 159 | // 正常请求计数 160 | func (p googlePromise) Accept() { 161 | p.b.markSuccess() 162 | } 163 | 164 | // 异常请求计数 165 | func (p googlePromise) Reject() { 166 | p.b.markFailure() 167 | } 168 | -------------------------------------------------------------------------------- /code/core/collection/rollingwindow.go: -------------------------------------------------------------------------------- 1 | package collection 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | 7 | "github.com/zeromicro/go-zero/core/timex" 8 | ) 9 | 10 | type ( 11 | // RollingWindowOption let callers customize the RollingWindow. 12 | RollingWindowOption func(rollingWindow *RollingWindow) 13 | 14 | // RollingWindow defines a rolling window to calculate the events in buckets with time interval. 
15 | RollingWindow struct { 16 | lock sync.RWMutex 17 | size int // 窗口数量 18 | win *window // 窗口数据容器 19 | interval time.Duration // 窗口间隔周期 20 | offset int // 窗口游标 21 | // 汇总数据时,是否忽略当前正在写入桶的数据 22 | // 某些场景下因为当前正在写入的桶数据并没有经过完整的窗口时间间隔 23 | // 可能导致当前桶的统计并不准确 24 | ignoreCurrent bool 25 | // 最后写入桶的时间 26 | // 用于计算下一次写入数据间隔最后一次写入数据的之间 27 | // 经过了多少个时间间隔 28 | lastTime time.Duration // start time of the last bucket 29 | } 30 | ) 31 | 32 | // 初始化滑动窗口 33 | // NewRollingWindow returns a RollingWindow that with size buckets and time interval, 34 | // use opts to customize the RollingWindow. 35 | func NewRollingWindow(size int, interval time.Duration, opts ...RollingWindowOption) *RollingWindow { 36 | if size < 1 { 37 | panic("size must be greater than 0") 38 | } 39 | w := &RollingWindow{ 40 | size: size, 41 | win: newWindow(size), 42 | interval: interval, 43 | lastTime: timex.Now(), 44 | } 45 | for _, opt := range opts { 46 | opt(w) 47 | } 48 | return w 49 | } 50 | 51 | // Add adds value to current bucket. 52 | // 添加数据 53 | func (rw *RollingWindow) Add(v float64) { 54 | rw.lock.Lock() 55 | defer rw.lock.Unlock() 56 | 57 | rw.updateOffset() // 获取当前写入的下标,滑动的动作发生在此 58 | rw.win.add(rw.offset, v) // 添加数据 59 | } 60 | 61 | // Reduce runs fn on all buckets, ignore current bucket if ignoreCurrent was set. 
// Reduce aggregates, under the read lock, the buckets still inside the window.
func (rw *RollingWindow) Reduce(fn func(b *Bucket)) {
	rw.lock.RLock()
	defer rw.lock.RUnlock()

	var diff int
	span := rw.span()
	// ignore current bucket, because of partial data
	// diff is the number of buckets that have not expired as of now.
	if span == 0 && rw.ignoreCurrent {
		diff = rw.size - 1
	} else {
		diff = rw.size - span
	}
	if diff > 0 {
		// Buckets rw.offset+1 .. rw.offset+span (mod size) have expired but
		// are only reset by the next Add (via updateOffset), so start just
		// past them and visit the remaining diff buckets.
		offset := (rw.offset + span + 1) % rw.size
		rw.win.reduce(offset, diff, fn)
	}
}

// span returns how many whole intervals have elapsed since lastTime, i.e. how
// many buckets the window has slid past since the last write.
func (rw *RollingWindow) span() int {
	offset := int(timex.Since(rw.lastTime) / rw.interval)
	if 0 <= offset && offset < rw.size {
		return offset
	}
	// a full window or more (or a negative value): report the window size
	return rw.size
}

// updateOffset slides the window so rw.offset points at the bucket for the
// current time, clearing the buckets that were skipped.
func (rw *RollingWindow) updateOffset() {
	// number of bucket intervals elapsed since the last write
	span := rw.span()
	// still within the same interval: nothing to update
	if span <= 0 {
		return
	}

	offset := rw.offset
	// reset expired buckets
	// No data was written during the span elapsed intervals, so the buckets
	// they map to hold stale data and are cleared. All indices are taken
	// modulo rw.size, so the buckets behave as a ring buffer.
	for i := 0; i < span; i++ {
		rw.win.resetBucket((offset + i + 1) % rw.size)
	}
	rw.offset = (offset + span) % rw.size
	now := timex.Now()
	// align to interval time boundary
	// lastTime = now minus the part of the elapsed time that does not fill a
	// whole interval.
	rw.lastTime = now - (now-rw.lastTime)%rw.interval
}

// Bucket defines the bucket that holds sum and num of additions.
// Each Bucket covers one interval of a RollingWindow.
type Bucket struct {
	// Sum is the total of all values added to this bucket.
	Sum float64
	// Count is how many times add was called on this bucket.
	Count int64
}

// add records one observation v.
func (b *Bucket) add(v float64) {
	b.Count++
	b.Sum += v
}

// reset clears the bucket so it can be reused for a new interval.
func (b *Bucket) reset() {
	*b = Bucket{}
}

// window is a fixed-size ring of buckets addressed modulo its size.
type window struct {
	buckets []*Bucket // ring buffer of buckets
	size    int
}

// newWindow allocates a window holding size empty buckets.
func newWindow(size int) *window {
	w := &window{
		size:    size,
		buckets: make([]*Bucket, size),
	}
	for idx := range w.buckets {
		w.buckets[idx] = &Bucket{}
	}
	return w
}

// add records v in the bucket at offset, wrapping around the ring.
func (w *window) add(offset int, v float64) {
	bucket := w.buckets[offset%w.size]
	bucket.add(v)
}

// reduce calls fn on count consecutive buckets starting at start, wrapping
// around the ring.
func (w *window) reduce(start, count int, fn func(b *Bucket)) {
	for pos := start; pos < start+count; pos++ {
		fn(w.buckets[pos%w.size])
	}
}

// resetBucket clears the bucket at offset, wrapping around the ring.
func (w *window) resetBucket(offset int) {
	target := w.buckets[offset%w.size]
	target.reset()
}

// IgnoreCurrentBucket lets the Reduce call ignore current bucket.
// The returned option sets ignoreCurrent on the RollingWindow.
func IgnoreCurrentBucket() RollingWindowOption {
	return func(w *RollingWindow) {
		w.ignoreCurrent = true
	}
}

// ===== /code/core/limit/periodlimit.go =====
package limit

import (
	"errors"
	"strconv"
	"time"

	"github.com/zeromicro/go-zero/core/stores/redis"
)

/*

Annotated version of periodScript below:

-- KEYS[1]: limiter key
-- ARGV[1]: quota, the maximum number of requests per window
-- ARGV[2]: window length in seconds (the key expiration)
-- maximum number of requests, equals PeriodLimit.quota
local limit = tonumber(ARGV[1])
-- one limiting window, emulated with key expiration; equals calcExpireSeconds()
local window = tonumber(ARGV[2])
-- increment the request counter and get the current total
local current = redis.call("INCRBY", KEYS[1], 1)
-- first request of the window: set the expiration and allow
if current == 1 then
    redis.call("expire", KEYS[1], window)
    return 1
-- below the limit: allow
elseif current < limit then
    return 1
-- exactly at the limit: allow, and report that the quota was hit
elseif current == limit then
    return 2
-- over the limit: reject
else
    return 0
end

*/

// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
const periodScript = `local limit = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local current = redis.call("INCRBY", KEYS[1], 1)
if current == 1 then
    redis.call("expire", KEYS[1], window)
    return 1
elseif current < limit then
    return 1
elseif current == limit then
    return 2
else
    return 0
end`

const (
	// Unknown means not initialized state.
	Unknown = iota
	// Allowed means allowed state.
	Allowed
	// HitQuota means this request exactly hit the quota.
	HitQuota
	// OverQuota means passed the quota.
	OverQuota

	// Raw codes returned by periodScript; Take maps them to the states above.
	internalOverQuota = 0
	internalAllowed   = 1
	internalHitQuota  = 2
)

// ErrUnknownCode is an error that represents unknown status code.
70 | var ErrUnknownCode = errors.New("unknown status code") 71 | 72 | type ( 73 | // PeriodOption defines the method to customize a PeriodLimit. 74 | // go中常见的option参数模式 75 | // 如果参数非常多,推荐使用此模式来设置参数 76 | PeriodOption func(l *PeriodLimit) 77 | 78 | // A PeriodLimit is used to limit requests during a period of time. 79 | // 固定时间窗口限流器 80 | PeriodLimit struct { 81 | period int // 窗口大小,单位s 82 | quota int // 请求上限 83 | limitStore *redis.Redis // 存储 84 | keyPrefix string // key前缀 85 | // 线性限流,开启此选项后可以实现周期性的限流 86 | // 比如quota=5时,quota实际值可能会是5.4.3.2.1呈现出周期性变化 87 | align bool 88 | } 89 | ) 90 | 91 | // NewPeriodLimit returns a PeriodLimit with given parameters. 92 | func NewPeriodLimit(period, quota int, limitStore *redis.Redis, keyPrefix string, 93 | opts ...PeriodOption, 94 | ) *PeriodLimit { 95 | limiter := &PeriodLimit{ 96 | period: period, 97 | quota: quota, 98 | limitStore: limitStore, 99 | keyPrefix: keyPrefix, 100 | } 101 | 102 | for _, opt := range opts { 103 | opt(limiter) 104 | } 105 | 106 | return limiter 107 | } 108 | 109 | // Take requests a permit, it returns the permit state. 
110 | // 执行限流 111 | // 注意一下返回值: 112 | // 0:表示错误,比如可能是redis故障、过载 113 | // 1:允许 114 | // 2:允许但是当前窗口内已到达上限 115 | // 3:拒绝 116 | func (h *PeriodLimit) Take(key string) (int, error) { 117 | // 执行lua脚本 118 | resp, err := h.limitStore.Eval(periodScript, []string{h.keyPrefix + key}, []string{ 119 | strconv.Itoa(h.quota), 120 | strconv.Itoa(h.calcExpireSeconds()), 121 | }) 122 | if err != nil { 123 | return Unknown, err 124 | } 125 | 126 | code, ok := resp.(int64) 127 | if !ok { 128 | return Unknown, ErrUnknownCode 129 | } 130 | 131 | switch code { 132 | case internalOverQuota: 133 | return OverQuota, nil 134 | case internalAllowed: 135 | return Allowed, nil 136 | case internalHitQuota: 137 | return HitQuota, nil 138 | default: 139 | return Unknown, ErrUnknownCode 140 | } 141 | } 142 | 143 | // 计算过期时间也就是窗口时间大小 144 | // 如果align==true 145 | // 线性限流,开启此选项后可以实现周期性的限流 146 | // 比如quota=5时,quota实际值可能会是5.4.3.2.1呈现出周期性变化 147 | func (h *PeriodLimit) calcExpireSeconds() int { 148 | if h.align { 149 | now := time.Now() 150 | _, offset := now.Zone() 151 | unix := now.Unix() + int64(offset) 152 | return h.period - int(unix%int64(h.period)) 153 | } 154 | 155 | return h.period 156 | } 157 | 158 | // Align returns a func to customize a PeriodLimit with alignment. 159 | // For example, if we want to limit end users with 5 sms verification messages every day, 160 | // we need to align with the local timezone and the start of the day. 
func Align() PeriodOption {
	return func(l *PeriodLimit) {
		l.align = true
	}
}

// ===== /code/core/limit/tokenlimit.go =====
package limit

import (
	"fmt"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/zeromicro/go-zero/core/logx"
	"github.com/zeromicro/go-zero/core/stores/redis"
	xrate "golang.org/x/time/rate"
)

/*

Annotated version of the token bucket script below.
Returns whether the requested number of tokens can be acquired.

-- ARGV[1]: tokens generated per second, i.e. the refill rate
local rate = tonumber(ARGV[1])
-- ARGV[2]: bucket capacity (burst)
local capacity = tonumber(ARGV[2])
-- ARGV[3]: current timestamp (unix seconds)
local now = tonumber(ARGV[3])
-- ARGV[4]: number of tokens requested
local requested = tonumber(ARGV[4])

-- fill_time: how long it takes to fill an empty bucket
local fill_time = capacity/rate
-- ttl is twice the fill time, rounded down
local ttl = math.floor(fill_time*2)
-- tokens left in the bucket; nil means the key does not exist (first
-- request, or the key expired), so start with a full bucket
local last_tokens = tonumber(redis.call("get", KEYS[1]))
if last_tokens == nil then
    last_tokens = capacity
end

-- time of the last bucket refresh; a missing key is treated as time 0
local last_refreshed = tonumber(redis.call("get", KEYS[2]))
if last_refreshed == nil then
    last_refreshed = 0
end

-- time elapsed since the last refresh
local delta = math.max(0, now-last_refreshed)
-- refill tokens for the elapsed time at the given rate; anything above
-- capacity (burst) is discarded
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
-- whether there are enough tokens for this request
local allowed = filled_tokens >= requested
-- tokens remaining after this request
local new_tokens = filled_tokens
if allowed then
    new_tokens = filled_tokens - requested
end

-- store the remaining tokens and the refresh time
redis.call("setex", KEYS[1], ttl, new_tokens)
redis.call("setex", KEYS[2], ttl, now)

return allowed

*/

const (
	// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key
	// KEYS[1] as tokens_key
	// KEYS[2] as timestamp_key
	script = `local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
local fill_time = capacity/rate
local ttl = math.floor(fill_time*2)
local last_tokens = tonumber(redis.call("get", KEYS[1]))
if last_tokens == nil then
    last_tokens = capacity
end

local last_refreshed = tonumber(redis.call("get", KEYS[2]))
if last_refreshed == nil then
    last_refreshed = 0
end

local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
if allowed then
    new_tokens = filled_tokens - requested
end

redis.call("setex", KEYS[1], ttl, new_tokens)
redis.call("setex", KEYS[2], ttl, now)

return allowed`
	// {} is a redis hash tag, presumably so both keys land in the same cluster slot — verify
	tokenFormat     = "{%s}.tokens"
	timestampFormat = "{%s}.ts"
	pingInterval    = time.Millisecond * 100 // redis health probe period
)

// A TokenLimiter controls how frequently events are allowed to happen within one second.
type TokenLimiter struct {
	rate           int            // tokens produced per second
	burst          int            // bucket capacity
	store          *redis.Redis   // backing store
	tokenKey       string         // redis key holding the remaining tokens
	timestampKey   string         // redis key holding the last refresh time
	rescueLock     sync.Mutex     // guards monitorStarted
	redisAlive     uint32         // 1 while redis is considered healthy, 0 otherwise (accessed atomically)
	rescueLimiter  *xrate.Limiter // in-process token bucket used while redis is unavailable
	monitorStarted bool           // whether the redis health probe goroutine is running
}

// NewTokenLimiter returns a new TokenLimiter that allows events up to rate and permits
// bursts of at most burst tokens.
123 | func NewTokenLimiter(rate, burst int, store *redis.Redis, key string) *TokenLimiter { 124 | tokenKey := fmt.Sprintf(tokenFormat, key) 125 | timestampKey := fmt.Sprintf(timestampFormat, key) 126 | 127 | return &TokenLimiter{ 128 | rate: rate, 129 | burst: burst, 130 | store: store, 131 | tokenKey: tokenKey, 132 | timestampKey: timestampKey, 133 | redisAlive: 1, 134 | rescueLimiter: xrate.NewLimiter(xrate.Every(time.Second/time.Duration(rate)), burst), 135 | } 136 | } 137 | 138 | // Allow is shorthand for AllowN(time.Now(), 1). 139 | func (lim *TokenLimiter) Allow() bool { 140 | return lim.AllowN(time.Now(), 1) 141 | } 142 | 143 | // AllowN reports whether n events may happen at time now. 144 | // Use this method if you intend to drop / skip events that exceed the rate. 145 | // Otherwise, use Reserve or Wait. 146 | func (lim *TokenLimiter) AllowN(now time.Time, n int) bool { 147 | return lim.reserveN(now, n) 148 | } 149 | 150 | func (lim *TokenLimiter) reserveN(now time.Time, n int) bool { 151 | // 判断redis是否健康 152 | // redis故障时采用进程内限流器 153 | // 兜底保障 154 | if atomic.LoadUint32(&lim.redisAlive) == 0 { 155 | return lim.rescueLimiter.AllowN(now, n) 156 | } 157 | // 执行脚本获取令牌 158 | resp, err := lim.store.Eval( 159 | script, 160 | []string{ 161 | lim.tokenKey, 162 | lim.timestampKey, 163 | }, 164 | []string{ 165 | strconv.Itoa(lim.rate), 166 | strconv.Itoa(lim.burst), 167 | strconv.FormatInt(now.Unix(), 10), 168 | strconv.Itoa(n), 169 | }) 170 | // redis allowed == false 171 | // Lua boolean false -> r Nil bulk reply 172 | // 特殊处理key不存在的情况 173 | if err == redis.Nil { 174 | return false 175 | } 176 | if err != nil { 177 | logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err) 178 | // 执行异常,开启redis健康探测任务 179 | // 同时采用进程内限流器作为兜底 180 | lim.startMonitor() 181 | return lim.rescueLimiter.AllowN(now, n) 182 | } 183 | 184 | code, ok := resp.(int64) 185 | if !ok { 186 | logx.Errorf("fail to eval redis script: %v, use in-process limiter for rescue", 
resp) 187 | lim.startMonitor() 188 | return lim.rescueLimiter.AllowN(now, n) 189 | } 190 | 191 | // redis allowed == true 192 | // Lua boolean true -> r integer reply with value of 1 193 | return code == 1 194 | } 195 | 196 | // 开启redis健康探测 197 | func (lim *TokenLimiter) startMonitor() { 198 | lim.rescueLock.Lock() 199 | defer lim.rescueLock.Unlock() 200 | // 防止重复开启 201 | if lim.monitorStarted { 202 | return 203 | } 204 | // 设置任务和健康标识 205 | lim.monitorStarted = true 206 | atomic.StoreUint32(&lim.redisAlive, 0) 207 | // 健康探测 208 | go lim.waitForRedis() 209 | } 210 | 211 | // redis健康探测定时任务 212 | func (lim *TokenLimiter) waitForRedis() { 213 | // 健康探测成功时回调此函数 214 | ticker := time.NewTicker(pingInterval) 215 | defer func() { 216 | ticker.Stop() 217 | lim.rescueLock.Lock() 218 | lim.monitorStarted = false 219 | lim.rescueLock.Unlock() 220 | }() 221 | 222 | for range ticker.C { 223 | // ping属于redis内置健康探测命令 224 | if lim.store.Ping() { 225 | // 健康探测成功,设置健康标识 226 | atomic.StoreUint32(&lim.redisAlive, 1) 227 | return 228 | } 229 | } 230 | } 231 | -------------------------------------------------------------------------------- /code/core/load/nopshedder.go: -------------------------------------------------------------------------------- 1 | package load 2 | 3 | type nopShedder struct{} 4 | 5 | func newNopShedder() Shedder { 6 | return nopShedder{} 7 | } 8 | 9 | func (s nopShedder) Allow() (Promise, error) { 10 | return nopPromise{}, nil 11 | } 12 | 13 | type nopPromise struct{} 14 | 15 | func (p nopPromise) Pass() { 16 | } 17 | 18 | func (p nopPromise) Fail() { 19 | } 20 | -------------------------------------------------------------------------------- /code/core/stat/internal/cgroup_linux.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path" 7 | "strconv" 8 | "strings" 9 | 10 | "github.com/zeromicro/go-zero/core/iox" 11 | "github.com/zeromicro/go-zero/core/lang" 12 | ) 13 | 14 
| const cgroupDir = "/sys/fs/cgroup" 15 | 16 | type cgroup struct { 17 | cgroups map[string]string 18 | } 19 | 20 | func (c *cgroup) acctUsageAllCpus() (uint64, error) { 21 | data, err := iox.ReadText(path.Join(c.cgroups["cpuacct"], "cpuacct.usage")) 22 | if err != nil { 23 | return 0, err 24 | } 25 | 26 | return parseUint(string(data)) 27 | } 28 | 29 | func (c *cgroup) acctUsagePerCpu() ([]uint64, error) { 30 | data, err := iox.ReadText(path.Join(c.cgroups["cpuacct"], "cpuacct.usage_percpu")) 31 | if err != nil { 32 | return nil, err 33 | } 34 | 35 | var usage []uint64 36 | for _, v := range strings.Fields(string(data)) { 37 | u, err := parseUint(v) 38 | if err != nil { 39 | return nil, err 40 | } 41 | 42 | usage = append(usage, u) 43 | } 44 | 45 | return usage, nil 46 | } 47 | 48 | func (c *cgroup) cpuQuotaUs() (int64, error) { 49 | data, err := iox.ReadText(path.Join(c.cgroups["cpu"], "cpu.cfs_quota_us")) 50 | if err != nil { 51 | return 0, err 52 | } 53 | 54 | return strconv.ParseInt(string(data), 10, 64) 55 | } 56 | 57 | func (c *cgroup) cpuPeriodUs() (uint64, error) { 58 | data, err := iox.ReadText(path.Join(c.cgroups["cpu"], "cpu.cfs_period_us")) 59 | if err != nil { 60 | return 0, err 61 | } 62 | 63 | return parseUint(string(data)) 64 | } 65 | 66 | func (c *cgroup) cpus() ([]uint64, error) { 67 | data, err := iox.ReadText(path.Join(c.cgroups["cpuset"], "cpuset.cpus")) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | return parseUints(string(data)) 73 | } 74 | 75 | func currentCgroup() (*cgroup, error) { 76 | cgroupFile := fmt.Sprintf("/proc/%d/cgroup", os.Getpid()) 77 | lines, err := iox.ReadTextLines(cgroupFile, iox.WithoutBlank()) 78 | if err != nil { 79 | return nil, err 80 | } 81 | 82 | cgroups := make(map[string]string) 83 | for _, line := range lines { 84 | cols := strings.Split(line, ":") 85 | if len(cols) != 3 { 86 | return nil, fmt.Errorf("invalid cgroup line: %s", line) 87 | } 88 | 89 | subsys := cols[1] 90 | // only read cpu staff 91 | 
if !strings.HasPrefix(subsys, "cpu") { 92 | continue 93 | } 94 | 95 | // https://man7.org/linux/man-pages/man7/cgroups.7.html 96 | // comma-separated list of controllers for cgroup version 1 97 | fields := strings.Split(subsys, ",") 98 | for _, val := range fields { 99 | cgroups[val] = path.Join(cgroupDir, val) 100 | } 101 | } 102 | 103 | return &cgroup{ 104 | cgroups: cgroups, 105 | }, nil 106 | } 107 | 108 | func parseUint(s string) (uint64, error) { 109 | v, err := strconv.ParseInt(s, 10, 64) 110 | if err != nil { 111 | if err.(*strconv.NumError).Err == strconv.ErrRange { 112 | return 0, nil 113 | } 114 | 115 | return 0, fmt.Errorf("cgroup: bad int format: %s", s) 116 | } 117 | 118 | if v < 0 { 119 | return 0, nil 120 | } 121 | 122 | return uint64(v), nil 123 | } 124 | 125 | func parseUints(val string) ([]uint64, error) { 126 | if val == "" { 127 | return nil, nil 128 | } 129 | 130 | ints := make(map[uint64]lang.PlaceholderType) 131 | cols := strings.Split(val, ",") 132 | for _, r := range cols { 133 | if strings.Contains(r, "-") { 134 | fields := strings.SplitN(r, "-", 2) 135 | min, err := parseUint(fields[0]) 136 | if err != nil { 137 | return nil, fmt.Errorf("cgroup: bad int list format: %s", val) 138 | } 139 | 140 | max, err := parseUint(fields[1]) 141 | if err != nil { 142 | return nil, fmt.Errorf("cgroup: bad int list format: %s", val) 143 | } 144 | 145 | if max < min { 146 | return nil, fmt.Errorf("cgroup: bad int list format: %s", val) 147 | } 148 | 149 | for i := min; i <= max; i++ { 150 | ints[i] = lang.Placeholder 151 | } 152 | } else { 153 | v, err := parseUint(r) 154 | if err != nil { 155 | return nil, err 156 | } 157 | 158 | ints[v] = lang.Placeholder 159 | } 160 | } 161 | 162 | var sets []uint64 163 | for k := range ints { 164 | sets = append(sets, k) 165 | } 166 | 167 | return sets, nil 168 | } 169 | -------------------------------------------------------------------------------- /code/core/stat/internal/cpu_linux.go: 
-------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "log" 7 | "strings" 8 | "time" 9 | 10 | "github.com/zeromicro/go-zero/core/iox" 11 | ) 12 | 13 | const ( 14 | cpuTicks = 100 15 | cpuFields = 8 16 | ) 17 | 18 | var ( 19 | preSystem uint64 20 | preTotal uint64 21 | quota float64 22 | cores uint64 23 | ) 24 | 25 | // if /proc not present, ignore the cpu calculation, like wsl linux 26 | func init() { 27 | cpus, err := perCpuUsage() 28 | if err != nil { 29 | log.Println(err) 30 | return 31 | } 32 | 33 | cores = uint64(len(cpus)) 34 | sets, err := cpuSets() 35 | if err != nil { 36 | log.Println(err) 37 | return 38 | } 39 | 40 | quota = float64(len(sets)) 41 | cq, err := cpuQuota() 42 | if err == nil { 43 | if cq != -1 { 44 | period, err := cpuPeriod() 45 | if err != nil { 46 | log.Println(err) 47 | return 48 | } 49 | 50 | limit := float64(cq) / float64(period) 51 | if limit < quota { 52 | quota = limit 53 | } 54 | } 55 | } 56 | 57 | preSystem, err = systemCpuUsage() 58 | if err != nil { 59 | log.Println(err) 60 | return 61 | } 62 | 63 | preTotal, err = totalCpuUsage() 64 | if err != nil { 65 | log.Println(err) 66 | return 67 | } 68 | } 69 | 70 | // RefreshCpu refreshes cpu usage and returns. 
71 | func RefreshCpu() uint64 { 72 | total, err := totalCpuUsage() 73 | if err != nil { 74 | return 0 75 | } 76 | system, err := systemCpuUsage() 77 | if err != nil { 78 | return 0 79 | } 80 | 81 | var usage uint64 82 | cpuDelta := total - preTotal 83 | systemDelta := system - preSystem 84 | if cpuDelta > 0 && systemDelta > 0 { 85 | usage = uint64(float64(cpuDelta*cores*1e3) / (float64(systemDelta) * quota)) 86 | } 87 | preSystem = system 88 | preTotal = total 89 | 90 | return usage 91 | } 92 | 93 | func cpuQuota() (int64, error) { 94 | cg, err := currentCgroup() 95 | if err != nil { 96 | return 0, err 97 | } 98 | 99 | return cg.cpuQuotaUs() 100 | } 101 | 102 | func cpuPeriod() (uint64, error) { 103 | cg, err := currentCgroup() 104 | if err != nil { 105 | return 0, err 106 | } 107 | 108 | return cg.cpuPeriodUs() 109 | } 110 | 111 | func cpuSets() ([]uint64, error) { 112 | cg, err := currentCgroup() 113 | if err != nil { 114 | return nil, err 115 | } 116 | 117 | return cg.cpus() 118 | } 119 | 120 | func perCpuUsage() ([]uint64, error) { 121 | cg, err := currentCgroup() 122 | if err != nil { 123 | return nil, err 124 | } 125 | 126 | return cg.acctUsagePerCpu() 127 | } 128 | 129 | func systemCpuUsage() (uint64, error) { 130 | lines, err := iox.ReadTextLines("/proc/stat", iox.WithoutBlank()) 131 | if err != nil { 132 | return 0, err 133 | } 134 | 135 | for _, line := range lines { 136 | fields := strings.Fields(line) 137 | if fields[0] == "cpu" { 138 | if len(fields) < cpuFields { 139 | return 0, fmt.Errorf("bad format of cpu stats") 140 | } 141 | 142 | var totalClockTicks uint64 143 | for _, i := range fields[1:cpuFields] { 144 | v, err := parseUint(i) 145 | if err != nil { 146 | return 0, err 147 | } 148 | 149 | totalClockTicks += v 150 | } 151 | 152 | return (totalClockTicks * uint64(time.Second)) / cpuTicks, nil 153 | } 154 | } 155 | 156 | return 0, errors.New("bad stats format") 157 | } 158 | 159 | func totalCpuUsage() (usage uint64, err error) { 160 | var cg 
*cgroup 161 | if cg, err = currentCgroup(); err != nil { 162 | return 163 | } 164 | 165 | return cg.acctUsageAllCpus() 166 | } 167 | -------------------------------------------------------------------------------- /code/core/stat/internal/cpu_other.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | // +build !linux 3 | 4 | package internal 5 | 6 | // RefreshCpu returns cpu usage, always returns 0 on systems other than linux. 7 | func RefreshCpu() uint64 { 8 | return 0 9 | } 10 | -------------------------------------------------------------------------------- /code/core/stat/usage.go: -------------------------------------------------------------------------------- 1 | package stat 2 | 3 | import ( 4 | "log" 5 | "runtime" 6 | "sync/atomic" 7 | "time" 8 | 9 | "gozerosource/code/core/stat/internal" 10 | "gozerosource/code/core/syncx" 11 | 12 | "github.com/zeromicro/go-zero/core/threading" 13 | ) 14 | 15 | const ( 16 | // 250ms and 0.95 as beta will count the average cpu load for past 5 seconds 17 | cpuRefreshInterval = time.Millisecond * 250 18 | allRefreshInterval = time.Minute 19 | // moving average beta hyperparameter 20 | beta = 0.95 21 | ) 22 | 23 | var logEnabled = syncx.ForAtomicBool(true) 24 | 25 | var cpuUsage int64 26 | 27 | func init() { // samples cpu every cpuRefreshInterval and logs usage every allRefreshInterval 28 | go func() { 29 | cpuTicker := time.NewTicker(cpuRefreshInterval) 30 | defer cpuTicker.Stop() 31 | allTicker := time.NewTicker(allRefreshInterval) 32 | defer allTicker.Stop() 33 | 34 | for { 35 | select { 36 | case <-cpuTicker.C: 37 | threading.RunSafe(func() { 38 | // update the moving average of cpu usage 39 | curUsage := internal.RefreshCpu() 40 | prevUsage := atomic.LoadInt64(&cpuUsage) 41 | // cpu = cpuᵗ⁻¹ * beta + cpuᵗ * (1 - beta) 42 | // (exponential moving average: the previous value keeps weight beta) 43 | usage := int64(float64(prevUsage)*beta + float64(curUsage)*(1-beta)) 44 | atomic.StoreInt64(&cpuUsage, usage) 45 | }) 46 | case <-allTicker.C: 47 | if logEnabled.True() { 48 | printUsage() 49 | } 50 | } 51 | } 52 | }() 53 | } 54 | 55 | //
CpuUsage returns current cpu usage, smoothed by the moving average updated in init. 56 | func CpuUsage() int64 { 57 | return atomic.LoadInt64(&cpuUsage) 58 | } 59 | 60 | func bToMb(b uint64) float32 { // converts bytes to MiB 61 | return float32(b) / 1024 / 1024 62 | } 63 | 64 | func printUsage() { // logs the smoothed cpu usage and runtime memory stats 65 | var m runtime.MemStats 66 | runtime.ReadMemStats(&m) 67 | log.Printf("CPU: %dm, MEMORY: Alloc=%.1fMi, TotalAlloc=%.1fMi, Sys=%.1fMi, NumGC=%d", 68 | CpuUsage(), bToMb(m.Alloc), bToMb(m.TotalAlloc), bToMb(m.Sys), m.NumGC) 69 | } 70 | -------------------------------------------------------------------------------- /code/core/syncx/atomicbool.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import "sync/atomic" 4 | 5 | // An AtomicBool is a boolean stored as a uint32 (1 for true, 0 for false) and accessed atomically. 6 | type AtomicBool uint32 7 | 8 | // NewAtomicBool returns a new AtomicBool set to false. 9 | func NewAtomicBool() *AtomicBool { 10 | return new(AtomicBool) 11 | } 12 | 13 | // ForAtomicBool returns an AtomicBool with given val. 14 | func ForAtomicBool(val bool) *AtomicBool { 15 | b := NewAtomicBool() 16 | b.Set(val) 17 | return b 18 | } 19 | 20 | // CompareAndSwap sets the value to val if it currently equals old, and reports whether the swap happened. 21 | func (b *AtomicBool) CompareAndSwap(old, val bool) bool { 22 | var ov, nv uint32 23 | if old { 24 | ov = 1 25 | } 26 | if val { 27 | nv = 1 28 | } 29 | return atomic.CompareAndSwapUint32((*uint32)(b), ov, nv) 30 | } 31 | 32 | // Set sets the value to v. 33 | func (b *AtomicBool) Set(v bool) { 34 | if v { 35 | atomic.StoreUint32((*uint32)(b), 1) 36 | } else { 37 | atomic.StoreUint32((*uint32)(b), 0) 38 | } 39 | } 40 | 41 | // True reports whether the current value is true.
42 | func (b *AtomicBool) True() bool { 43 | return atomic.LoadUint32((*uint32)(b)) == 1 44 | } 45 | -------------------------------------------------------------------------------- /code/core/syncx/atomicduration.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "sync/atomic" 5 | "time" 6 | ) 7 | 8 | // An AtomicDuration is a time.Duration stored as an int64 and accessed atomically. 9 | type AtomicDuration int64 10 | 11 | // NewAtomicDuration returns a new AtomicDuration holding zero. 12 | func NewAtomicDuration() *AtomicDuration { 13 | return new(AtomicDuration) 14 | } 15 | 16 | // ForAtomicDuration returns an AtomicDuration with given value. 17 | func ForAtomicDuration(val time.Duration) *AtomicDuration { 18 | d := NewAtomicDuration() 19 | d.Set(val) 20 | return d 21 | } 22 | 23 | // CompareAndSwap sets the value to val if it currently equals old, and reports whether the swap happened. 24 | func (d *AtomicDuration) CompareAndSwap(old, val time.Duration) bool { 25 | return atomic.CompareAndSwapInt64((*int64)(d), int64(old), int64(val)) 26 | } 27 | 28 | // Load loads the current duration. 29 | func (d *AtomicDuration) Load() time.Duration { 30 | return time.Duration(atomic.LoadInt64((*int64)(d))) 31 | } 32 | 33 | // Set sets the value to val. 34 | func (d *AtomicDuration) Set(val time.Duration) { 35 | atomic.StoreInt64((*int64)(d), int64(val)) 36 | } 37 | -------------------------------------------------------------------------------- /code/core/syncx/limit.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/zeromicro/go-zero/core/lang" 7 | ) 8 | 9 | // ErrLimitReturn indicates that more elements were returned than were borrowed. 10 | var ErrLimitReturn = errors.New("discarding limited token, resource pool is full, someone returned multiple times") 11 | 12 | // Limit controls the concurrent requests; it is a counting semaphore backed by a buffered channel.
13 | type Limit struct { 14 | pool chan lang.PlaceholderType // each buffered element is one borrowed slot 15 | } 16 | 17 | // NewLimit creates a Limit that can borrow n elements from it concurrently. 18 | func NewLimit(n int) Limit { 19 | return Limit{ 20 | pool: make(chan lang.PlaceholderType, n), 21 | } 22 | } 23 | 24 | // Borrow borrows an element from Limit, blocking while all n elements are borrowed. 25 | func (l Limit) Borrow() { 26 | l.pool <- lang.Placeholder 27 | } 28 | 29 | // Return returns a borrowed element; it returns ErrLimitReturn if no element is currently borrowed. 30 | func (l Limit) Return() error { 31 | select { 32 | case <-l.pool: 33 | return nil 34 | default: 35 | return ErrLimitReturn 36 | } 37 | } 38 | 39 | // TryBorrow tries to borrow an element from Limit, in non-blocking mode. 40 | // It reports whether an element was borrowed. 41 | func (l Limit) TryBorrow() bool { 42 | select { 43 | case l.pool <- lang.Placeholder: 44 | return true 45 | default: 46 | return false 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /code/core/syncx/spinlock.go: -------------------------------------------------------------------------------- 1 | package syncx 2 | 3 | import ( 4 | "runtime" 5 | "sync/atomic" 6 | ) 7 | 8 | // A SpinLock is a lock for short critical sections: Lock retries TryLock and yields with runtime.Gosched between attempts. 9 | type SpinLock struct { 10 | lock uint32 // 1 while locked, 0 otherwise 11 | } 12 | 13 | // Lock locks the SpinLock. 14 | func (sl *SpinLock) Lock() { 15 | for !sl.TryLock() { 16 | runtime.Gosched() 17 | } 18 | } 19 | 20 | // TryLock tries to lock the SpinLock without waiting and reports whether it succeeded. 21 | func (sl *SpinLock) TryLock() bool { 22 | return atomic.CompareAndSwapUint32(&sl.lock, 0, 1) 23 | } 24 | 25 | // Unlock unlocks the SpinLock.
26 | func (sl *SpinLock) Unlock() { 27 | atomic.StoreUint32(&sl.lock, 0) 28 | } 29 | -------------------------------------------------------------------------------- /code/limit/limit_test.go: -------------------------------------------------------------------------------- 1 | package limit_test 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | "sync/atomic" 7 | "testing" 8 | "time" 9 | 10 | "gozerosource/code/core/syncx" 11 | ) 12 | 13 | func Test_limit(t *testing.T) { 14 | const ( 15 | seconds = 5 16 | threads = 100 17 | ) 18 | timer := time.NewTimer(time.Second * seconds) 19 | quit := make(chan struct{}) 20 | defer timer.Stop() 21 | go func() { 22 | <-timer.C 23 | close(quit) 24 | }() 25 | 26 | latch := syncx.NewLimit(20) 27 | 28 | var allowed, denied int32 29 | var wait sync.WaitGroup 30 | for i := 0; i < threads; i++ { 31 | wait.Add(1) 32 | go func() { 33 | for { 34 | select { 35 | case <-quit: 36 | wait.Done() 37 | return 38 | default: 39 | if latch.TryBorrow() { 40 | atomic.AddInt32(&allowed, 1) 41 | // Return the token right away: a deferred Return would only run when 42 | // this goroutine exits, so the pool would drain after 20 borrows. 43 | if err := latch.Return(); err != nil { 44 | fmt.Println(err) 45 | } 46 | } else { 47 | atomic.AddInt32(&denied, 1) 48 | } 49 | } 50 | } 51 | }() 52 | } 53 | 54 | wait.Wait() 55 | fmt.Printf("allowed: %d, denied: %d, qps: %d\n", allowed, denied, (allowed+denied)/seconds) 56 | } 57 | 58 | func Benchmark_limit(b *testing.B) { 59 | // Checks that Limit is safe under concurrent use, so a single latch is 60 | // shared by every goroutine that RunParallel starts. 61 | latch := syncx.NewLimit(10) 62 | b.RunParallel(func(pb *testing.PB) { 63 | for pb.Next() { 64 | if latch.TryBorrow() { 65 | // Return right away; deferring would pile up one closure per iteration. 66 | if err := latch.Return(); err != nil { 67 | fmt.Println(err) 68 | } 69 | } 70 | } 71 | }) 72 | } 73 | -------------------------------------------------------------------------------- /code/limit/periodlimit_test.go: -------------------------------------------------------------------------------- 1 | package limit_test 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "strconv" 7 | "sync" 8 | "sync/atomic" 9 | "testing" 10 | "time" 11 | 12 |
"gozerosource/code/core/limit" 13 | 14 | "github.com/zeromicro/go-zero/core/stores/redis" 15 | ) 16 | 17 | func Test_PeriodLimit(t *testing.T) { 18 | const ( 19 | burst = 100 20 | rate = 100 21 | seconds = 5 22 | threads = 2 23 | ) 24 | store := redis.New("127.0.0.1:6379") 25 | fmt.Println(store.Ping()) 26 | lmt := limit.NewPeriodLimit(seconds, 5, store, "period-limit") 27 | timer := time.NewTimer(time.Second * seconds) 28 | quit := make(chan struct{}) 29 | defer timer.Stop() 30 | go func() { 31 | <-timer.C 32 | close(quit) 33 | }() 34 | 35 | var allowed, denied int32 36 | var wait sync.WaitGroup 37 | for i := 0; i < threads; i++ { 38 | i := i 39 | wait.Add(1) 40 | go func() { 41 | for { 42 | select { 43 | case <-quit: 44 | wait.Done() 45 | return 46 | default: 47 | if v, err := lmt.Take(strconv.FormatInt(int64(i), 10)); err == nil && v == limit.Allowed { 48 | atomic.AddInt32(&allowed, 1) 49 | } else if err != nil { 50 | log.Fatal(err) // NOTE(review): log.Fatal calls os.Exit from a worker goroutine, skipping deferred cleanup and test reporting. 51 | } else { 52 | atomic.AddInt32(&denied, 1) 53 | } 54 | } 55 | } 56 | }() 57 | } 58 | 59 | wait.Wait() 60 | fmt.Printf("allowed: %d, denied: %d, qps: %d\n", allowed, denied, (allowed+denied)/seconds) 61 | } 62 | -------------------------------------------------------------------------------- /code/limit/tokenlimit_test.go: -------------------------------------------------------------------------------- 1 | package limit_test 2 | 3 | import ( 4 | "fmt" 5 | "runtime" 6 | "sync" 7 | "sync/atomic" 8 | "testing" 9 | "time" 10 | 11 | "gozerosource/code/core/limit" 12 | 13 | "github.com/zeromicro/go-zero/core/stores/redis" 14 | ) 15 | 16 | func Test_TokenLimiter(t *testing.T) { 17 | const ( 18 | burst = 100 19 | rate = 100 20 | seconds = 5 21 | ) 22 | store := redis.New("127.0.0.1:6379") 23 | fmt.Println(store.Ping()) 24 | // New tokenLimiter 25 | limiter := limit.NewTokenLimiter(rate, burst, store, "token-limiter") 26 | timer := time.NewTimer(time.Second * seconds) 27 | quit := make(chan struct{}) 28 | defer timer.Stop() 29 | go func()
{ 30 | <-timer.C 31 | close(quit) 32 | }() 33 | 34 | var allowed, denied int32 35 | var wait sync.WaitGroup 36 | for i := 0; i < runtime.NumCPU(); i++ { 37 | wait.Add(1) 38 | go func() { 39 | for { 40 | select { 41 | case <-quit: 42 | wait.Done() 43 | return 44 | default: 45 | if limiter.Allow() { 46 | atomic.AddInt32(&allowed, 1) 47 | } else { 48 | atomic.AddInt32(&denied, 1) 49 | } 50 | } 51 | } 52 | }() 53 | } 54 | 55 | wait.Wait() 56 | fmt.Printf("allowed: %d, denied: %d, qps: %d\n", allowed, denied, (allowed+denied)/seconds) 57 | } 58 | -------------------------------------------------------------------------------- /code/rest/rest/config.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/zeromicro/go-zero/core/service" 7 | ) 8 | 9 | type ( 10 | // A PrivateKeyConf is a private key config. 11 | PrivateKeyConf struct { 12 | Fingerprint string // fingerprint 13 | KeyFile string // key file 14 | } 15 | 16 | // A SignatureConf is a signature config. 17 | SignatureConf struct { 18 | Strict bool `json:",default=false"` // whether verification is strict 19 | Expiry time.Duration `json:",default=1h"` // expiry time 20 | PrivateKeys []PrivateKeyConf // private keys 21 | } 22 | 23 | // A RestConf is a http service config. 24 | // Why not name it as Conf, because we need to consider usage like: 25 | // type Config struct { 26 | // zrpc.RpcConf 27 | // rest.RestConf 28 | // } 29 | // if with the name Conf, there will be two Conf inside Config.
30 | // rest service config 31 | RestConf struct { 32 | service.ServiceConf // base service config 33 | Host string `json:",default=0.0.0.0"` // host 34 | Port int // port 35 | CertFile string `json:",optional"` // cert file 36 | KeyFile string `json:",optional"` // key file 37 | Verbose bool `json:",optional"` // use DetailedLogHandler instead of LogHandler 38 | MaxConns int `json:",default=10000"` // max concurrent connections, passed to handler.MaxConns 39 | MaxBytes int64 `json:",default=1048576"` // max bytes accepted per request 40 | // milliseconds 41 | Timeout int64 `json:",default=3000"` // request timeout 42 | CpuThreshold int64 `json:",default=900,range=[0:1000]"` // cpu threshold for load shedding, 1000 means 100% 43 | Signature SignatureConf `json:",optional"` // signature config 44 | } 45 | ) 46 | -------------------------------------------------------------------------------- /code/rest/rest/engine.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "crypto/tls" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | "time" 9 | 10 | "gozerosource/code/rest/rest/internal/response" 11 | 12 | "gozerosource/code/rest/rest/handler" 13 | "gozerosource/code/rest/rest/httpx" 14 | "gozerosource/code/rest/rest/internal" 15 | 16 | "github.com/justinas/alice" 17 | "github.com/zeromicro/go-zero/core/codec" 18 | "github.com/zeromicro/go-zero/core/load" 19 | "github.com/zeromicro/go-zero/core/stat" 20 | ) 21 | 22 | // use 1000m to represent 100% 23 | const topCpuUsage = 1000 24 | 25 | // ErrSignatureConfig is an error that indicates bad config for signature.
26 | var ErrSignatureConfig = errors.New("bad config for Signature") 27 | 28 | type engine struct { 29 | conf RestConf // rest config 30 | routes []featuredRoutes // routes 31 | unauthorizedCallback handler.UnauthorizedCallback // callback on authorization failure 32 | unsignedCallback handler.UnsignedCallback // callback on signature verification failure 33 | middlewares []Middleware // middlewares 34 | shedder load.Shedder // load shedder 35 | priorityShedder load.Shedder // load shedder with a higher cpu threshold, used for priority routes 36 | tlsConfig *tls.Config // tls config 37 | } 38 | 39 | // newEngine creates an engine; the shedders are only created when CpuThreshold > 0. 40 | func newEngine(c RestConf) *engine { 41 | srv := &engine{ 42 | conf: c, 43 | } 44 | if c.CpuThreshold > 0 { 45 | // shedder at the configured cpu threshold 46 | srv.shedder = load.NewAdaptiveShedder(load.WithCpuThreshold(c.CpuThreshold)) 47 | // priority shedder at the midpoint between the threshold and topCpuUsage 48 | srv.priorityShedder = load.NewAdaptiveShedder(load.WithCpuThreshold( 49 | (c.CpuThreshold + topCpuUsage) >> 1)) 50 | } 51 | 52 | return srv 53 | } 54 | 55 | // addRoutes registers a group of featured routes. 56 | func (ng *engine) addRoutes(r featuredRoutes) { 57 | ng.routes = append(ng.routes, r) 58 | } 59 | 60 | // appendAuthHandler appends JWT authorization when enabled, then applies the signature verifier. 61 | func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain, 62 | verifier func(alice.Chain) alice.Chain, 63 | ) alice.Chain { 64 | if fr.jwt.enabled { 65 | if len(fr.jwt.prevSecret) == 0 { 66 | chain = chain.Append(handler.Authorize( 67 | fr.jwt.secret, 68 | handler.WithUnauthorizedCallback(ng.unauthorizedCallback), 69 | )) 70 | } else { 71 | chain = chain.Append(handler.Authorize( 72 | fr.jwt.secret, 73 | handler.WithPrevSecret(fr.jwt.prevSecret), 74 | handler.WithUnauthorizedCallback(ng.unauthorizedCallback)), 75 | ) 76 | } 77 | } 78 | 79 | return verifier(chain) 80 | } 81 | 82 | // bindFeaturedRoutes binds every route of a featured route group. 83 | func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error { 84 | verifier, err := ng.signatureVerifier(fr.signature) 85 | if err != nil { 86 | return err 87 | } 88 | 89 | for _, route := range fr.routes { 90 | if err := ng.bindRoute(fr, router, metrics, route, verifier); err != nil { 91 | return err 92 | } 93 | } 94 | 95 | return
nil 96 | } 97 | 98 | // bindRoute builds the middleware chain for a single route and registers it on the router. 99 | func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics, 100 | route Route, verifier func(chain alice.Chain) alice.Chain, 101 | ) error { 102 | chain := alice.New( 103 | handler.TracingHandler(ng.conf.Name, route.Path), 104 | ng.getLogHandler(), 105 | handler.PrometheusHandler(route.Path), 106 | handler.MaxConns(ng.conf.MaxConns), 107 | handler.BreakerHandler(route.Method, route.Path, metrics), 108 | handler.SheddingHandler(ng.getShedder(fr.priority), metrics), 109 | handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)), 110 | handler.RecoverHandler, 111 | handler.MetricHandler(metrics), 112 | handler.MaxBytesHandler(ng.conf.MaxBytes), 113 | handler.GunzipHandler, 114 | ) 115 | chain = ng.appendAuthHandler(fr, chain, verifier) 116 | 117 | for _, middleware := range ng.middlewares { 118 | chain = chain.Append(convertMiddleware(middleware)) 119 | } 120 | handle := chain.ThenFunc(route.Handler) 121 | 122 | return router.Handle(route.Method, route.Path, handle) 123 | } 124 | 125 | // bindRoutes binds all route groups, sharing one metrics instance. 126 | func (ng *engine) bindRoutes(router httpx.Router) error { 127 | metrics := ng.createMetrics() 128 | 129 | for _, fr := range ng.routes { 130 | if err := ng.bindFeaturedRoutes(router, fr, metrics); err != nil { 131 | return err 132 | } 133 | } 134 | 135 | return nil 136 | } 137 | 138 | func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration { 139 | if timeout > 0 { 140 | return timeout 141 | } 142 | 143 | return time.Duration(ng.conf.Timeout) * time.Millisecond 144 | } 145 | 146 | // createMetrics names the metrics after the service name, or host:port when no name is set. 147 | func (ng *engine) createMetrics() *stat.Metrics { 148 | var metrics *stat.Metrics 149 | 150 | if len(ng.conf.Name) > 0 { 151 | metrics = stat.NewMetrics(ng.conf.Name) 152 | } else { 153 | metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port)) 154 | } 155 | 156 | return metrics 157 | } 158 | 159 | func (ng *engine) getLogHandler() func(http.Handler) http.Handler { 160 | if
ng.conf.Verbose { 161 | return handler.DetailedLogHandler 162 | } 163 | 164 | return handler.LogHandler 165 | } 166 | 167 | func (ng *engine) getShedder(priority bool) load.Shedder { 168 | if priority && ng.priorityShedder != nil { 169 | return ng.priorityShedder 170 | } 171 | 172 | return ng.shedder 173 | } 174 | 175 | // notFoundHandler returns a middleware that handles 404 not found requests. 176 | // It wraps next (or http.NotFoundHandler when next is nil) with tracing and logging, then writes 404. 177 | func (ng *engine) notFoundHandler(next http.Handler) http.Handler { 178 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 179 | chain := alice.New( 180 | handler.TracingHandler(ng.conf.Name, ""), 181 | ng.getLogHandler(), 182 | ) 183 | 184 | var h http.Handler 185 | if next != nil { 186 | h = chain.Then(next) 187 | } else { 188 | h = chain.Then(http.NotFoundHandler()) 189 | } 190 | 191 | cw := response.NewHeaderOnceResponseWriter(w) 192 | h.ServeHTTP(cw, r) 193 | cw.WriteHeader(http.StatusNotFound) 194 | }) 195 | } 196 | 197 | func (ng *engine) setTlsConfig(cfg *tls.Config) { 198 | ng.tlsConfig = cfg 199 | } 200 | 201 | // setUnauthorizedCallback sets the callback invoked when authorization fails. 202 | func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) { 203 | ng.unauthorizedCallback = callback 204 | } 205 | 206 | // setUnsignedCallback sets the callback invoked when signature verification fails. 207 | func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) { 208 | ng.unsignedCallback = callback 209 | } 210 | 211 | // signatureVerifier returns a chain decorator that appends content-security verification; it is a no-op when signatures are disabled, or enabled without keys in non-strict mode. 212 | func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) { 213 | if !signature.enabled { 214 | return func(chain alice.Chain) alice.Chain { 215 | return chain 216 | }, nil 217 | } 218 | 219 | if len(signature.PrivateKeys) == 0 { 220 | if signature.Strict { 221 | return nil, ErrSignatureConfig 222 | } 223 | 224 | return func(chain alice.Chain) alice.Chain { 225 | return chain 226 | }, nil 227 | } 228 | 229 | decrypters := make(map[string]codec.RsaDecrypter) 230 | for _, key := range signature.PrivateKeys { 231 | fingerprint :=
key.Fingerprint 232 | file := key.KeyFile 233 | decrypter, err := codec.NewRsaDecrypter(file) 234 | if err != nil { 235 | return nil, err 236 | } 237 | 238 | decrypters[fingerprint] = decrypter 239 | } 240 | 241 | return func(chain alice.Chain) alice.Chain { 242 | if ng.unsignedCallback != nil { 243 | return chain.Append(handler.ContentSecurityHandler( 244 | decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback)) 245 | } 246 | 247 | return chain.Append(handler.ContentSecurityHandler( 248 | decrypters, signature.Expiry, signature.Strict)) 249 | }, nil 250 | } 251 | 252 | func (ng *engine) start(router httpx.Router) error { 253 | if err := ng.bindRoutes(router); err != nil { 254 | return err 255 | } 256 | 257 | if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 { 258 | return internal.StartHttp(ng.conf.Host, ng.conf.Port, router) 259 | } 260 | 261 | return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile, 262 | ng.conf.KeyFile, router, func(srv *http.Server) { 263 | if ng.tlsConfig != nil { 264 | srv.TLSConfig = ng.tlsConfig 265 | } 266 | }) 267 | } 268 | 269 | func (ng *engine) use(middleware Middleware) { 270 | ng.middlewares = append(ng.middlewares, middleware) 271 | } 272 | 273 | func convertMiddleware(ware Middleware) func(http.Handler) http.Handler { 274 | return func(next http.Handler) http.Handler { 275 | return ware(next.ServeHTTP) 276 | } 277 | } 278 | -------------------------------------------------------------------------------- /code/rest/rest/handler/authhandler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "net/http" 7 | "net/http/httputil" 8 | 9 | "gozerosource/code/rest/rest/internal/response" 10 | "gozerosource/code/rest/rest/token" 11 | 12 | "github.com/golang-jwt/jwt/v4" 13 | "github.com/zeromicro/go-zero/core/logx" 14 | ) 15 | 16 | const ( 17 | jwtAudience = "aud" 18 | jwtExpire = "exp" 19 | jwtId =
"jti" 20 | jwtIssueAt = "iat" 21 | jwtIssuer = "iss" 22 | jwtNotBefore = "nbf" 23 | jwtSubject = "sub" 24 | noDetailReason = "no detail reason" 25 | ) 26 | 27 | var ( 28 | errInvalidToken = errors.New("invalid auth token") 29 | errNoClaims = errors.New("no auth params") 30 | ) 31 | 32 | type ( 33 | // An AuthorizeOptions holds the options of the Authorize middleware. 34 | AuthorizeOptions struct { 35 | PrevSecret string // previous secret, passed to ParseToken together with the current secret 36 | Callback UnauthorizedCallback // callback on authorization failure 37 | } 38 | 39 | // UnauthorizedCallback defines the method of unauthorized callback. 40 | // (signature of the callback invoked on authorization failure) 41 | UnauthorizedCallback func(w http.ResponseWriter, r *http.Request, err error) 42 | // AuthorizeOption defines the method to customize an AuthorizeOptions. 43 | // (functional option for Authorize) 44 | AuthorizeOption func(opts *AuthorizeOptions) 45 | ) 46 | 47 | // Authorize returns an authorize middleware. 48 | // Claims other than the standard JWT claims are copied into the request context. 49 | func Authorize(secret string, opts ...AuthorizeOption) func(http.Handler) http.Handler { 50 | var authOpts AuthorizeOptions 51 | for _, opt := range opts { 52 | opt(&authOpts) 53 | } 54 | 55 | parser := token.NewTokenParser() // token parser 56 | return func(next http.Handler) http.Handler { 57 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 58 | // parse the token 59 | tok, err := parser.ParseToken(r, secret, authOpts.PrevSecret) 60 | if err != nil { 61 | unauthorized(w, r, err, authOpts.Callback) 62 | return 63 | } 64 | 65 | if !tok.Valid { 66 | unauthorized(w, r, errInvalidToken, authOpts.Callback) 67 | return 68 | } 69 | 70 | claims, ok := tok.Claims.(jwt.MapClaims) 71 | if !ok { 72 | unauthorized(w, r, errNoClaims, authOpts.Callback) 73 | return 74 | } 75 | 76 | ctx := r.Context() // request context 77 | for k, v := range claims { 78 | switch k { 79 | case jwtAudience, jwtExpire, jwtId, jwtIssueAt, jwtIssuer, jwtNotBefore, jwtSubject: 80 | // ignore the standard claims 81 | // (aud, exp, jti, iat, iss, nbf and sub are not copied into the context) 82 | default: 83 | ctx = context.WithValue(ctx, k, v) // copy the custom claim into the context; NOTE(review): plain string keys are flagged by staticcheck SA1029 84 | } 85 | } 86 | 87 | next.ServeHTTP(w,
r.WithContext(ctx)) 88 | }) 89 | } 90 | } 91 | 92 | // WithPrevSecret returns an AuthorizeOption with setting previous secret. 93 | // (sets AuthorizeOptions.PrevSecret) 94 | func WithPrevSecret(secret string) AuthorizeOption { 95 | return func(opts *AuthorizeOptions) { 96 | opts.PrevSecret = secret 97 | } 98 | } 99 | 100 | // WithUnauthorizedCallback returns an AuthorizeOption with setting unauthorized callback. 101 | // The callback runs before the default 401 status is written. 102 | func WithUnauthorizedCallback(callback UnauthorizedCallback) AuthorizeOption { 103 | return func(opts *AuthorizeOptions) { 104 | opts.Callback = callback 105 | } 106 | } 107 | 108 | // detailAuthLog logs the failure reason together with a dump of the request. 109 | func detailAuthLog(r *http.Request, reason string) { 110 | // discard dump error, only for debug purpose 111 | details, _ := httputil.DumpRequest(r, true) 112 | logx.Errorf("authorize failed: %s\n=> %+v", reason, string(details)) 113 | } 114 | 115 | // unauthorized logs the failure, runs the callback (if any), then writes 401 unless the callback already set a status. 116 | func unauthorized(w http.ResponseWriter, r *http.Request, err error, callback UnauthorizedCallback) { 117 | writer := response.NewHeaderOnceResponseWriter(w) 118 | 119 | if err != nil { 120 | detailAuthLog(r, err.Error()) 121 | } else { 122 | detailAuthLog(r, noDetailReason) 123 | } 124 | 125 | // let callback go first, to make sure we respond with user-defined HTTP header 126 | if callback != nil { 127 | callback(writer, r, err) 128 | } 129 | 130 | // if user not setting HTTP header, we set header with 401 131 | writer.WriteHeader(http.StatusUnauthorized) 132 | } 133 | -------------------------------------------------------------------------------- /code/rest/rest/handler/breakerhandler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "strings" 7 | 8 | "gozerosource/code/rest/rest/httpx" 9 | "gozerosource/code/rest/rest/internal/response" 10 | 11 | "github.com/zeromicro/go-zero/core/breaker" 12 | "github.com/zeromicro/go-zero/core/logx" 13 | "github.com/zeromicro/go-zero/core/stat" 14 | ) 15
| 16 | const breakerSeparator = "://" 17 | 18 | // BreakerHandler returns a circuit breaker middleware. 19 | // Responses with a status code >= 500 are reported to the breaker as failures. 20 | func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handler) http.Handler { 21 | brk := breaker.NewBreaker(breaker.WithName(strings.Join([]string{method, path}, breakerSeparator))) 22 | return func(next http.Handler) http.Handler { 23 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 24 | promise, err := brk.Allow() 25 | if err != nil { 26 | metrics.AddDrop() 27 | logx.Errorf("[http] dropped, %s - %s - %s", 28 | r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent()) 29 | w.WriteHeader(http.StatusServiceUnavailable) 30 | return 31 | } 32 | 33 | cw := &response.WithCodeResponseWriter{Writer: w} 34 | defer func() { 35 | if cw.Code < http.StatusInternalServerError { 36 | promise.Accept() 37 | } else { 38 | promise.Reject(fmt.Sprintf("%d %s", cw.Code, http.StatusText(cw.Code))) 39 | } 40 | }() 41 | next.ServeHTTP(cw, r) 42 | }) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /code/rest/rest/handler/contentsecurityhandler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "net/http" 5 | "time" 6 | 7 | "gozerosource/code/rest/rest/httpx" 8 | "gozerosource/code/rest/rest/internal/security" 9 | 10 | "github.com/zeromicro/go-zero/core/codec" 11 | "github.com/zeromicro/go-zero/core/logx" 12 | ) 13 | 14 | const contentSecurity = "X-Content-Security" 15 | 16 | // UnsignedCallback defines the method of the unsigned callback. 17 | type UnsignedCallback func(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) 18 | 19 | // ContentSecurityHandler returns a middleware to verify content security.
20 | // Only DELETE, GET, POST and PUT requests are verified; other methods pass straight through. 21 | func ContentSecurityHandler(decrypters map[string]codec.RsaDecrypter, tolerance time.Duration, 22 | strict bool, callbacks ...UnsignedCallback, 23 | ) func(http.Handler) http.Handler { 24 | if len(callbacks) == 0 { 25 | callbacks = append(callbacks, handleVerificationFailure) // default: 403 when strict, otherwise pass through 26 | } 27 | 28 | return func(next http.Handler) http.Handler { 29 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 30 | switch r.Method { 31 | case http.MethodDelete, http.MethodGet, http.MethodPost, http.MethodPut: 32 | header, err := security.ParseContentSecurity(decrypters, r) 33 | if err != nil { 34 | logx.Errorf("Signature parse failed, X-Content-Security: %s, error: %s", 35 | r.Header.Get(contentSecurity), err.Error()) 36 | executeCallbacks(w, r, next, strict, httpx.CodeSignatureInvalidHeader, callbacks) 37 | } else if code := security.VerifySignature(r, header, tolerance); code != httpx.CodeSignaturePass { 38 | logx.Errorf("Signature verification failed, X-Content-Security: %s", 39 | r.Header.Get(contentSecurity)) 40 | executeCallbacks(w, r, next, strict, code, callbacks) 41 | } else if r.ContentLength > 0 && header.Encrypted() { 42 | CryptionHandler(header.Key)(next).ServeHTTP(w, r) 43 | } else { 44 | next.ServeHTTP(w, r) 45 | } 46 | default: 47 | next.ServeHTTP(w, r) 48 | } 49 | }) 50 | } 51 | } 52 | 53 | func executeCallbacks(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, 54 | code int, callbacks []UnsignedCallback, 55 | ) { 56 | for _, callback := range callbacks { 57 | callback(w, r, next, strict, code) 58 | } 59 | } 60 | 61 | func handleVerificationFailure(w http.ResponseWriter, r *http.Request, next http.Handler, strict bool, code int) { 62 | if strict { 63 | w.WriteHeader(http.StatusForbidden) 64 | } else { 65 | next.ServeHTTP(w, r) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /code/rest/rest/handler/cryptionhandler.go:
-------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "encoding/base64" 7 | "errors" 8 | "io" 9 | "io/ioutil" 10 | "net" 11 | "net/http" 12 | 13 | "github.com/zeromicro/go-zero/core/codec" 14 | "github.com/zeromicro/go-zero/core/logx" 15 | ) 16 | 17 | const maxBytes = 1 << 20 // 1 MiB 18 | 19 | var errContentLengthExceeded = errors.New("content length exceeded") 20 | 21 | // CryptionHandler returns a middleware to handle cryption. 22 | // It decrypts the request body and encrypts the buffered response body with key. 23 | func CryptionHandler(key []byte) func(http.Handler) http.Handler { 24 | return func(next http.Handler) http.Handler { 25 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 26 | cw := newCryptionResponseWriter(w) 27 | defer cw.flush(key) 28 | 29 | if r.ContentLength <= 0 { 30 | next.ServeHTTP(cw, r) 31 | return 32 | } 33 | 34 | if err := decryptBody(key, r); err != nil { 35 | w.WriteHeader(http.StatusBadRequest) 36 | return 37 | } 38 | 39 | next.ServeHTTP(cw, r) 40 | }) 41 | } 42 | } 43 | 44 | func decryptBody(key []byte, r *http.Request) error { 45 | if r.ContentLength > maxBytes { 46 | return errContentLengthExceeded 47 | } 48 | 49 | var content []byte 50 | var err error 51 | if r.ContentLength > 0 { 52 | content = make([]byte, r.ContentLength) 53 | _, err = io.ReadFull(r.Body, content) 54 | } else { 55 | content, err = ioutil.ReadAll(io.LimitReader(r.Body, maxBytes)) 56 | } 57 | if err != nil { 58 | return err 59 | } 60 | 61 | content, err = base64.StdEncoding.DecodeString(string(content)) 62 | if err != nil { 63 | return err 64 | } 65 | 66 | output, err := codec.EcbDecrypt(key, content) 67 | if err != nil { 68 | return err 69 | } 70 | 71 | var buf bytes.Buffer 72 | buf.Write(output) 73 | r.Body = ioutil.NopCloser(&buf) 74 | 75 | return nil 76 | } 77 | 78 | type cryptionResponseWriter struct { 79 | http.ResponseWriter 80 | buf *bytes.Buffer 81 | } 82 | 83 | func newCryptionResponseWriter(w
http.ResponseWriter) *cryptionResponseWriter { 84 | return &cryptionResponseWriter{ 85 | ResponseWriter: w, 86 | buf: new(bytes.Buffer), 87 | } 88 | } 89 | 90 | func (w *cryptionResponseWriter) Flush() { 91 | if flusher, ok := w.ResponseWriter.(http.Flusher); ok { 92 | flusher.Flush() 93 | } 94 | } 95 | 96 | func (w *cryptionResponseWriter) Header() http.Header { 97 | return w.ResponseWriter.Header() 98 | } 99 | 100 | // Hijack implements the http.Hijacker interface. 101 | // This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it. 102 | func (w *cryptionResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 103 | if hijacked, ok := w.ResponseWriter.(http.Hijacker); ok { 104 | return hijacked.Hijack() 105 | } 106 | 107 | return nil, nil, errors.New("server doesn't support hijacking") 108 | } 109 | 110 | func (w *cryptionResponseWriter) Write(p []byte) (int, error) { // buffered here; encrypted and written out by flush 111 | return w.buf.Write(p) 112 | } 113 | 114 | func (w *cryptionResponseWriter) WriteHeader(statusCode int) { 115 | w.ResponseWriter.WriteHeader(statusCode) 116 | } 117 | 118 | func (w *cryptionResponseWriter) flush(key []byte) { 119 | if w.buf.Len() == 0 { 120 | return 121 | } 122 | 123 | content, err := codec.EcbEncrypt(key, w.buf.Bytes()) 124 | if err != nil { 125 | w.WriteHeader(http.StatusInternalServerError) 126 | return 127 | } 128 | 129 | body := base64.StdEncoding.EncodeToString(content) 130 | if n, err := io.WriteString(w.ResponseWriter, body); err != nil { 131 | logx.Errorf("write response failed, error: %s", err) 132 | } else if n < len(body) { // n counts bytes of the base64 body, not of the raw ciphertext 133 | logx.Errorf("actual bytes: %d, written bytes: %d", len(body), n) 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /code/rest/rest/handler/gunziphandler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "compress/gzip" 5 | "net/http" 6 | "strings" 7 | 8 |
"github.com/zeromicro/go-zero/rest/httpx" // NOTE(review): the other handlers import gozerosource/code/rest/rest/httpx 9 | ) 10 | 11 | const gzipEncoding = "gzip" 12 | 13 | // GunzipHandler returns a middleware to gunzip http request body. 14 | // Request bodies whose Content-Encoding contains gzip are decompressed before next is called. 15 | func GunzipHandler(next http.Handler) http.Handler { 16 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 17 | if strings.Contains(r.Header.Get(httpx.ContentEncoding), gzipEncoding) { 18 | reader, err := gzip.NewReader(r.Body) 19 | if err != nil { 20 | w.WriteHeader(http.StatusBadRequest) 21 | return 22 | } 23 | 24 | r.Body = reader 25 | } 26 | 27 | next.ServeHTTP(w, r) 28 | }) 29 | } 30 | -------------------------------------------------------------------------------- /code/rest/rest/handler/loghandler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "io/ioutil" 11 | "net" 12 | "net/http" 13 | "net/http/httputil" 14 | "strings" 15 | "time" 16 | 17 | "gozerosource/code/rest/rest/httpx" 18 | "gozerosource/code/rest/rest/internal" 19 | 20 | "github.com/zeromicro/go-zero/core/iox" 21 | "github.com/zeromicro/go-zero/core/logx" 22 | "github.com/zeromicro/go-zero/core/syncx" 23 | "github.com/zeromicro/go-zero/core/timex" 24 | "github.com/zeromicro/go-zero/core/utils" 25 | ) 26 | 27 | const ( 28 | limitBodyBytes = 1024 29 | defaultSlowThreshold = time.Millisecond * 500 30 | ) 31 | 32 | var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold) 33 | 34 | type loggedResponseWriter struct { 35 | w http.ResponseWriter 36 | r *http.Request 37 | code int 38 | } 39 | 40 | func (w *loggedResponseWriter) Flush() { 41 | if flusher, ok := w.w.(http.Flusher); ok { 42 | flusher.Flush() 43 | } 44 | } 45 | 46 | func (w *loggedResponseWriter) Header() http.Header { 47 | return w.w.Header() 48 | } 49 | 50 | // Hijack implements the http.Hijacker interface.
51 | // This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it. 52 | func (w *loggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 53 | if hijacked, ok := w.w.(http.Hijacker); ok { 54 | return hijacked.Hijack() 55 | } 56 | 57 | return nil, nil, errors.New("server doesn't support hijacking") 58 | } 59 | 60 | func (w *loggedResponseWriter) Write(bytes []byte) (int, error) { 61 | return w.w.Write(bytes) 62 | } 63 | 64 | func (w *loggedResponseWriter) WriteHeader(code int) { 65 | w.w.WriteHeader(code) 66 | w.code = code 67 | } 68 | 69 | // LogHandler returns a middleware that logs http request and response. 70 | // The request body is duplicated so the logging code can still read it after next has consumed it. 71 | func LogHandler(next http.Handler) http.Handler { 72 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 73 | timer := utils.NewElapsedTimer() 74 | logs := new(internal.LogCollector) 75 | lrw := loggedResponseWriter{ 76 | w: w, 77 | r: r, 78 | code: http.StatusOK, 79 | } 80 | 81 | var dup io.ReadCloser 82 | r.Body, dup = iox.DupReadCloser(r.Body) 83 | next.ServeHTTP(&lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs))) 84 | r.Body = dup 85 | logBrief(r, lrw.code, timer, logs) 86 | }) 87 | } 88 | 89 | type detailLoggedResponseWriter struct { 90 | writer *loggedResponseWriter 91 | buf *bytes.Buffer 92 | } 93 | 94 | func newDetailLoggedResponseWriter(writer *loggedResponseWriter, buf *bytes.Buffer) *detailLoggedResponseWriter { 95 | return &detailLoggedResponseWriter{ 96 | writer: writer, 97 | buf: buf, 98 | } 99 | } 100 | 101 | func (w *detailLoggedResponseWriter) Flush() { 102 | w.writer.Flush() 103 | } 104 | 105 | func (w *detailLoggedResponseWriter) Header() http.Header { 106 | return w.writer.Header() 107 | } 108 | 109 | // Hijack implements the http.Hijacker interface. 110 | // This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
111 | func (w *detailLoggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 112 | if hijacked, ok := w.writer.w.(http.Hijacker); ok { 113 | return hijacked.Hijack() 114 | } 115 | 116 | return nil, nil, errors.New("server doesn't support hijacking") 117 | } 118 | 119 | func (w *detailLoggedResponseWriter) Write(bs []byte) (int, error) { 120 | w.buf.Write(bs) 121 | return w.writer.Write(bs) 122 | } 123 | 124 | func (w *detailLoggedResponseWriter) WriteHeader(code int) { 125 | w.writer.WriteHeader(code) 126 | } 127 | 128 | // DetailedLogHandler returns a middleware that logs http request and response in detail, including the full request dump and the response body. 129 | func DetailedLogHandler(next http.Handler) http.Handler { 130 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 131 | timer := utils.NewElapsedTimer() 132 | var buf bytes.Buffer 133 | lrw := newDetailLoggedResponseWriter(&loggedResponseWriter{ 134 | w: w, 135 | r: r, 136 | code: http.StatusOK, 137 | }, &buf) 138 | 139 | var dup io.ReadCloser 140 | r.Body, dup = iox.DupReadCloser(r.Body) 141 | logs := new(internal.LogCollector) 142 | next.ServeHTTP(lrw, r.WithContext(context.WithValue(r.Context(), internal.LogContext, logs))) 143 | r.Body = dup 144 | logDetails(r, lrw, timer, logs) 145 | }) 146 | } 147 | 148 | // SetSlowThreshold sets the duration above which LogHandler reports a request as a slow call.
149 | func SetSlowThreshold(threshold time.Duration) { 150 | slowThreshold.Set(threshold) 151 | } 152 | 153 | func dumpRequest(r *http.Request) string { 154 | reqContent, err := httputil.DumpRequest(r, true) 155 | if err != nil { 156 | return err.Error() 157 | } 158 | 159 | return string(reqContent) 160 | } 161 | 162 | func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *internal.LogCollector) { 163 | var buf bytes.Buffer 164 | duration := timer.Duration() 165 | logger := logx.WithContext(r.Context()).WithDuration(duration) 166 | buf.WriteString(fmt.Sprintf("[HTTP] %s - %d - %s - %s - %s", 167 | r.Method, code, r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent())) 168 | if duration > slowThreshold.Load() { 169 | logger.Slowf("[HTTP] %s - %d - %s - %s - %s - slowcall(%s)", 170 | r.Method, code, r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent(), timex.ReprOfDuration(duration)) 171 | } 172 | 173 | ok := isOkResponse(code) 174 | if !ok { 175 | fullReq := dumpRequest(r) 176 | limitReader := io.LimitReader(strings.NewReader(fullReq), limitBodyBytes) 177 | body, err := ioutil.ReadAll(limitReader) 178 | if err != nil { 179 | buf.WriteString(fmt.Sprintf("\n%s", fullReq)) 180 | } else { 181 | buf.WriteString(fmt.Sprintf("\n%s", string(body))) 182 | } 183 | } 184 | 185 | body := logs.Flush() 186 | if len(body) > 0 { 187 | buf.WriteString(fmt.Sprintf("\n%s", body)) 188 | } 189 | 190 | if ok { 191 | logger.Info(buf.String()) 192 | } else { 193 | logger.Error(buf.String()) 194 | } 195 | } 196 | 197 | func logDetails(r *http.Request, response *detailLoggedResponseWriter, timer *utils.ElapsedTimer, 198 | logs *internal.LogCollector, 199 | ) { 200 | var buf bytes.Buffer 201 | duration := timer.Duration() 202 | code := response.writer.code 203 | logger := logx.WithContext(r.Context()) 204 | buf.WriteString(fmt.Sprintf("[HTTP] %s - %d - %s - %s\n=> %s\n", 205 | r.Method, code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r))) 206 | if duration >
slowThreshold.Load() { // honor SetSlowThreshold, exactly as logBrief does 207 | logger.Slowf("[HTTP] %s - %d - %s - slowcall(%s)\n=> %s\n", 208 | r.Method, code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r)) 209 | } 210 | 211 | body := logs.Flush() 212 | if len(body) > 0 { 213 | buf.WriteString(fmt.Sprintf("%s\n", body)) 214 | } 215 | 216 | respBuf := response.buf.Bytes() 217 | if len(respBuf) > 0 { 218 | buf.WriteString(fmt.Sprintf("<= %s", respBuf)) 219 | } 220 | 221 | if isOkResponse(code) { 222 | logger.Info(buf.String()) 223 | } else { 224 | logger.Error(buf.String()) 225 | } 226 | } 227 | 228 | func isOkResponse(code int) bool { 229 | // not server error 230 | return code < http.StatusInternalServerError 231 | } 232 | -------------------------------------------------------------------------------- /code/rest/rest/handler/maxbyteshandler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "net/http" 5 | 6 | "gozerosource/code/rest/rest/internal" 7 | ) 8 | 9 | // MaxBytesHandler returns a middleware that limits reading of http request body.
10 | // Requests whose declared Content-Length exceeds n are rejected with 413. The body of 11 | // every other request is capped with http.MaxBytesReader, so bodies of unknown length 12 | // (e.g. chunked) cannot exceed n either. A non-positive n disables the limit. 13 | func MaxBytesHandler(n int64) func(http.Handler) http.Handler { 14 | if n <= 0 { 15 | return func(next http.Handler) http.Handler { 16 | return next 17 | } 18 | } 19 | 20 | return func(next http.Handler) http.Handler { 21 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 22 | if r.ContentLength > n { 23 | internal.Errorf(r, "request entity too large, limit is %d, but got %d, rejected with code %d", 24 | n, r.ContentLength, http.StatusRequestEntityTooLarge) 25 | w.WriteHeader(http.StatusRequestEntityTooLarge) 26 | return 27 | } 28 | 29 | r.Body = http.MaxBytesReader(w, r.Body, n) 30 | next.ServeHTTP(w, r) 31 | }) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /code/rest/rest/handler/maxconnshandler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "net/http" 5 | 6 | "gozerosource/code/rest/rest/internal" 7 | 8 | "github.com/zeromicro/go-zero/core/logx" 9 | "github.com/zeromicro/go-zero/core/syncx" 10 | ) 11 | 12 | // MaxConns returns a middleware that limits the concurrent connections.
13 | // Requests that cannot borrow one of the n slots get 503; n <= 0 disables the limit. The latch is created per wrapped next, so the limit applies per handler, not process-wide. 14 | func MaxConns(n int) func(http.Handler) http.Handler { 15 | if n <= 0 { 16 | return func(next http.Handler) http.Handler { 17 | return next 18 | } 19 | } 20 | 21 | return func(next http.Handler) http.Handler { 22 | latch := syncx.NewLimit(n) 23 | 24 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 25 | if latch.TryBorrow() { 26 | defer func() { 27 | if err := latch.Return(); err != nil { 28 | logx.Error(err) 29 | } 30 | }() 31 | 32 | next.ServeHTTP(w, r) 33 | } else { 34 | internal.Errorf(r, "concurrent connections over %d, rejected with code %d", 35 | n, http.StatusServiceUnavailable) 36 | w.WriteHeader(http.StatusServiceUnavailable) 37 | } 38 | }) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /code/rest/rest/handler/metrichandler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/zeromicro/go-zero/core/stat" 7 | "github.com/zeromicro/go-zero/core/timex" 8 | ) 9 | 10 | // MetricHandler returns a middleware that records each request's duration into metrics.
11 | // The duration is recorded in a defer, so it is added to metrics even if next panics. 12 | func MetricHandler(metrics *stat.Metrics) func(http.Handler) http.Handler { 13 | return func(next http.Handler) http.Handler { 14 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 15 | startTime := timex.Now() 16 | defer func() { 17 | metrics.Add(stat.Task{ 18 | Duration: timex.Since(startTime), 19 | }) 20 | }() 21 | 22 | next.ServeHTTP(w, r) 23 | }) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /code/rest/rest/handler/prometheushandler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "net/http" 5 | "strconv" 6 | "time" 7 | 8 | "gozerosource/code/rest/rest/internal/response" 9 | 10 | "github.com/zeromicro/go-zero/core/metric" 11 | "github.com/zeromicro/go-zero/core/prometheus" 12 | "github.com/zeromicro/go-zero/core/timex" 13 | ) 14 | 15 | const serverNamespace = "http_server" 16 | 17 | var ( 18 | metricServerReqDur = metric.NewHistogramVec(&metric.HistogramVecOpts{ 19 | Namespace: serverNamespace, 20 | Subsystem: "requests", 21 | Name: "duration_ms", 22 | Help: "http server requests duration(ms).", 23 | Labels: []string{"path"}, 24 | Buckets: []float64{5, 10, 25, 50, 100, 250, 500, 1000}, 25 | }) 26 | 27 | metricServerReqCodeTotal = metric.NewCounterVec(&metric.CounterVecOpts{ 28 | Namespace: serverNamespace, 29 | Subsystem: "requests", 30 | Name: "code_total", 31 | Help: "http server requests error count.", 32 | Labels: []string{"path", "code"}, 33 | }) 34 | ) 35 | 36 | // PrometheusHandler returns a middleware that reports stats to prometheus.
37 | // It observes the duration in milliseconds and counts response codes, both labeled by path; if prometheus is not enabled when wrapping, next is returned unchanged. 38 | func PrometheusHandler(path string) func(http.Handler) http.Handler { 39 | return func(next http.Handler) http.Handler { 40 | if !prometheus.Enabled() { 41 | return next 42 | } 43 | 44 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 45 | startTime := timex.Now() 46 | cw := &response.WithCodeResponseWriter{Writer: w} 47 | defer func() { 48 | metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), path) 49 | metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code)) 50 | }() 51 | 52 | next.ServeHTTP(cw, r) 53 | }) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /code/rest/rest/handler/recoverhandler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "runtime/debug" 7 | 8 | "gozerosource/code/rest/rest/internal" 9 | ) 10 | 11 | // RecoverHandler returns a middleware that recovers from panics raised by next.
12 | // The panic value is logged together with a stack trace, and a 500 status is written. 13 | func RecoverHandler(next http.Handler) http.Handler { 14 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 15 | defer func() { 16 | if result := recover(); result != nil { 17 | internal.Error(r, fmt.Sprintf("%v\n%s", result, debug.Stack())) 18 | w.WriteHeader(http.StatusInternalServerError) 19 | } 20 | }() 21 | 22 | next.ServeHTTP(w, r) 23 | }) 24 | } 25 | -------------------------------------------------------------------------------- /code/rest/rest/handler/sheddinghandler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "net/http" 5 | "sync" 6 | 7 | "gozerosource/code/rest/rest/httpx" 8 | "gozerosource/code/rest/rest/internal/response" 9 | 10 | "github.com/zeromicro/go-zero/core/load" 11 | "github.com/zeromicro/go-zero/core/logx" 12 | "github.com/zeromicro/go-zero/core/stat" 13 | ) 14 | 15 | const serviceType = "api" 16 | 17 | var ( 18 | sheddingStat *load.SheddingStat 19 | lock sync.Mutex 20 | ) 21 | 22 | // SheddingHandler returns a middleware that does load shedding.
23 | // A nil shedder disables shedding. Requests the shedder rejects get 503, and a 503 written by next is reported to the shedder as a failure. 24 | func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Handler) http.Handler { 25 | if shedder == nil { 26 | return func(next http.Handler) http.Handler { 27 | return next 28 | } 29 | } 30 | 31 | ensureSheddingStat() 32 | 33 | return func(next http.Handler) http.Handler { 34 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 35 | sheddingStat.IncrementTotal() 36 | promise, err := shedder.Allow() 37 | if err != nil { 38 | metrics.AddDrop() 39 | sheddingStat.IncrementDrop() 40 | logx.Errorf("[http] dropped, %s - %s - %s", 41 | r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent()) 42 | w.WriteHeader(http.StatusServiceUnavailable) 43 | return 44 | } 45 | 46 | cw := &response.WithCodeResponseWriter{Writer: w} 47 | defer func() { 48 | if cw.Code == http.StatusServiceUnavailable { 49 | promise.Fail() 50 | } else { 51 | sheddingStat.IncrementPass() 52 | promise.Pass() 53 | } 54 | }() 55 | next.ServeHTTP(cw, r) 56 | }) 57 | } 58 | } 59 | 60 | func ensureSheddingStat() { 61 | lock.Lock() 62 | if sheddingStat == nil { 63 | sheddingStat = load.NewSheddingStat(serviceType) 64 | } 65 | lock.Unlock() 66 | } 67 | -------------------------------------------------------------------------------- /code/rest/rest/handler/timeouthandler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "path" 11 | "runtime" 12 | "strings" 13 | "sync" 14 | "time" 15 | 16 | "gozerosource/code/rest/rest/httpx" 17 | "gozerosource/code/rest/rest/internal" 18 | ) 19 | 20 | const ( 21 | statusClientClosedRequest = 499 22 | reason = "Request Timeout" 23 | ) 24 | 25 | // TimeoutHandler returns the handler with given timeout. 26 | // If client closed request, code 499 will be logged. 27 | // Notice: even if canceled in server side, 499 will be logged as well.
28 | // A non-positive duration disables the timeout and returns next unchanged. 29 | func TimeoutHandler(duration time.Duration) func(http.Handler) http.Handler { 30 | return func(next http.Handler) http.Handler { 31 | if duration > 0 { 32 | return &timeoutHandler{ 33 | handler: next, 34 | dt: duration, 35 | } 36 | } 37 | 38 | return next 39 | } 40 | } 41 | 42 | // timeoutHandler is the handler that controls the request timeout. 43 | // We implement it on our own because the stdlib implementation 44 | // treats the ClientClosedRequest as http.StatusServiceUnavailable. 45 | // And we write the codes in logs as code 499, which is defined by nginx. 46 | type timeoutHandler struct { 47 | handler http.Handler 48 | dt time.Duration 49 | } 50 | 51 | func (h *timeoutHandler) errorBody() string { 52 | return reason 53 | } 54 | 55 | func (h *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 56 | ctx, cancelCtx := context.WithTimeout(r.Context(), h.dt) 57 | defer cancelCtx() 58 | 59 | r = r.WithContext(ctx) 60 | done := make(chan struct{}) 61 | tw := &timeoutWriter{ 62 | w: w, 63 | h: make(http.Header), 64 | req: r, 65 | } 66 | panicChan := make(chan interface{}, 1) 67 | go func() { 68 | defer func() { 69 | if p := recover(); p != nil { 70 | panicChan <- p 71 | } 72 | }() 73 | h.handler.ServeHTTP(tw, r) 74 | close(done) 75 | }() 76 | select { 77 | case p := <-panicChan: 78 | panic(p) 79 | case <-done: 80 | tw.mu.Lock() 81 | defer tw.mu.Unlock() 82 | dst := w.Header() 83 | for k, vv := range tw.h { 84 | dst[k] = vv 85 | } 86 | if !tw.wroteHeader { 87 | tw.code = http.StatusOK 88 | } 89 | w.WriteHeader(tw.code) 90 | w.Write(tw.wbuf.Bytes()) 91 | case <-ctx.Done(): 92 | tw.mu.Lock() 93 | defer tw.mu.Unlock() 94 | // there isn't any user-defined middleware before TimeoutHandler, 95 | // so we can guarantee that cancellation in biz related code won't come here.
96 | httpx.Error(w, ctx.Err(), func(w http.ResponseWriter, err error) { 97 | if errors.Is(err, context.Canceled) { 98 | w.WriteHeader(statusClientClosedRequest) 99 | } else { 100 | w.WriteHeader(http.StatusServiceUnavailable) 101 | } 102 | io.WriteString(w, h.errorBody()) 103 | }) 104 | tw.timedOut = true 105 | } 106 | } 107 | 108 | type timeoutWriter struct { 109 | w http.ResponseWriter 110 | h http.Header 111 | wbuf bytes.Buffer 112 | req *http.Request 113 | 114 | mu sync.Mutex 115 | timedOut bool 116 | wroteHeader bool 117 | code int 118 | } 119 | 120 | var _ http.Pusher = (*timeoutWriter)(nil) 121 | 122 | // Header returns the temporary http.Header, which is copied to the real writer once the handler completes in time. 123 | func (tw *timeoutWriter) Header() http.Header { return tw.h } 124 | 125 | // Push implements the Pusher interface. 126 | func (tw *timeoutWriter) Push(target string, opts *http.PushOptions) error { 127 | if pusher, ok := tw.w.(http.Pusher); ok { 128 | return pusher.Push(target, opts) 129 | } 130 | return http.ErrNotSupported 131 | } 132 | 133 | // Write buffers the data as part of the HTTP reply; the buffer is copied to the connection once the handler completes in time. 134 | // Writes after a timeout fail with http.ErrHandlerTimeout.
135 | func (tw *timeoutWriter) Write(p []byte) (int, error) { 136 | tw.mu.Lock() 137 | defer tw.mu.Unlock() 138 | 139 | if tw.timedOut { 140 | return 0, http.ErrHandlerTimeout 141 | } 142 | 143 | if !tw.wroteHeader { 144 | tw.writeHeaderLocked(http.StatusOK) 145 | } 146 | return tw.wbuf.Write(p) 147 | } 148 | 149 | func (tw *timeoutWriter) writeHeaderLocked(code int) { 150 | checkWriteHeaderCode(code) 151 | 152 | switch { 153 | case tw.timedOut: 154 | return 155 | case tw.wroteHeader: 156 | if tw.req != nil { 157 | caller := relevantCaller() 158 | internal.Errorf(tw.req, "http: superfluous response.WriteHeader call from %s (%s:%d)", 159 | caller.Function, path.Base(caller.File), caller.Line) 160 | } 161 | default: 162 | tw.wroteHeader = true 163 | tw.code = code 164 | } 165 | } 166 | 167 | func (tw *timeoutWriter) WriteHeader(code int) { 168 | tw.mu.Lock() 169 | defer tw.mu.Unlock() 170 | tw.writeHeaderLocked(code) 171 | } 172 | 173 | func checkWriteHeaderCode(code int) { 174 | if code < 100 || code > 599 { 175 | panic(fmt.Sprintf("invalid WriteHeader code %v", code)) 176 | } 177 | } 178 | 179 | // relevantCaller searches the call stack for the first function outside of net/http. 180 | // The purpose of this function is to provide more helpful error messages. 
181 | func relevantCaller() runtime.Frame { 182 | pc := make([]uintptr, 16) 183 | n := runtime.Callers(1, pc) 184 | frames := runtime.CallersFrames(pc[:n]) 185 | var frame runtime.Frame 186 | for { 187 | frame, more := frames.Next() 188 | if !strings.HasPrefix(frame.Function, "net/http.") { 189 | return frame 190 | } 191 | if !more { 192 | break 193 | } 194 | } 195 | return frame 196 | } 197 | -------------------------------------------------------------------------------- /code/rest/rest/handler/tracinghandler.go: -------------------------------------------------------------------------------- 1 | package handler 2 | 3 | import ( 4 | "net/http" 5 | 6 | "github.com/zeromicro/go-zero/core/trace" 7 | "go.opentelemetry.io/otel" 8 | "go.opentelemetry.io/otel/propagation" 9 | semconv "go.opentelemetry.io/otel/semconv/v1.4.0" 10 | oteltrace "go.opentelemetry.io/otel/trace" 11 | ) 12 | 13 | // TracingHandler return a middleware that process the opentelemetry. 14 | // 链路追踪中间件 15 | func TracingHandler(serviceName, path string) func(http.Handler) http.Handler { 16 | return func(next http.Handler) http.Handler { 17 | // GetTextMapPropagator返回全局TextMapPropagator。如果没有 18 | // 设置,将返回一个No-Op TextMapPropagator。 19 | propagator := otel.GetTextMapPropagator() 20 | // GetTracerProvider返回已注册的全局跟踪器。 21 | // 如果没有注册,则返回NoopTracerProvider的实例。 22 | // 使用跟踪提供者来创建一个命名的跟踪器。例如。 23 | // tracer := otel.GetTracerProvider().Tracer("example.com/foo") 24 | // 或 25 | // tracer := otel.Tracer("example.com/foo") 26 | tracer := otel.GetTracerProvider().Tracer(trace.TraceName) 27 | 28 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 29 | ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header)) 30 | spanName := path 31 | if len(spanName) == 0 { 32 | spanName = r.URL.Path 33 | } 34 | spanCtx, span := tracer.Start( 35 | ctx, 36 | spanName, 37 | oteltrace.WithSpanKind(oteltrace.SpanKindServer), 38 | 
oteltrace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest( 39 | serviceName, spanName, r)...), 40 | ) 41 | defer span.End() 42 | 43 | // convenient for tracking error messages 44 | sc := span.SpanContext() 45 | if sc.HasTraceID() { 46 | w.Header().Set(trace.TraceIdKey, sc.TraceID().String()) 47 | } 48 | 49 | next.ServeHTTP(w, r.WithContext(spanCtx)) 50 | }) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /code/rest/rest/httpx/requests.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | "net/textproto" 7 | "strings" 8 | 9 | "github.com/zeromicro/go-zero/core/mapping" 10 | "github.com/zeromicro/go-zero/rest/pathvar" 11 | ) 12 | 13 | const ( 14 | formKey = "form" 15 | pathKey = "path" 16 | headerKey = "header" 17 | maxMemory = 32 << 20 // 32MB 18 | maxBodyLen = 8 << 20 // 8MB 19 | separator = ";" 20 | tokensInAttribute = 2 21 | ) 22 | 23 | var ( 24 | formUnmarshaler = mapping.NewUnmarshaler(formKey, mapping.WithStringValues()) 25 | pathUnmarshaler = mapping.NewUnmarshaler(pathKey, mapping.WithStringValues()) 26 | headerUnmarshaler = mapping.NewUnmarshaler(headerKey, mapping.WithStringValues(), 27 | mapping.WithCanonicalKeyFunc(textproto.CanonicalMIMEHeaderKey)) 28 | ) 29 | 30 | // Parse parses the request. 31 | func Parse(r *http.Request, v interface{}) error { 32 | if err := ParsePath(r, v); err != nil { 33 | return err 34 | } 35 | 36 | if err := ParseForm(r, v); err != nil { 37 | return err 38 | } 39 | 40 | if err := ParseHeaders(r, v); err != nil { 41 | return err 42 | } 43 | 44 | return ParseJsonBody(r, v) 45 | } 46 | 47 | // ParseHeaders parses the headers request. 
48 | func ParseHeaders(r *http.Request, v interface{}) error { 49 | m := make(map[string]interface{}, len(r.Header)) 50 | for k, vals := range r.Header { 51 | if len(vals) == 1 { 52 | m[k] = vals[0] 53 | } else { 54 | m[k] = vals 55 | } 56 | } 57 | 58 | return headerUnmarshaler.Unmarshal(m, v) 59 | } 60 | 61 | // ParseForm parses the query, form and multipart form values of r into v; empty values are skipped. 62 | func ParseForm(r *http.Request, v interface{}) error { 63 | if err := r.ParseForm(); err != nil { 64 | return err 65 | } 66 | 67 | if err := r.ParseMultipartForm(maxMemory); err != nil { 68 | if err != http.ErrNotMultipart { 69 | return err 70 | } 71 | } 72 | 73 | params := make(map[string]interface{}, len(r.Form)) 74 | for name := range r.Form { 75 | formValue := r.Form.Get(name) 76 | if len(formValue) > 0 { 77 | params[name] = formValue 78 | } 79 | } 80 | 81 | return formUnmarshaler.Unmarshal(params, v) 82 | } 83 | 84 | // ParseHeader parses a header value of ;-separated key=value pairs into a map, skipping malformed pairs. 85 | func ParseHeader(headerValue string) map[string]string { 86 | ret := make(map[string]string) 87 | fields := strings.Split(headerValue, separator) 88 | 89 | for _, field := range fields { 90 | field = strings.TrimSpace(field) 91 | if len(field) == 0 { 92 | continue 93 | } 94 | 95 | kv := strings.SplitN(field, "=", tokensInAttribute) 96 | if len(kv) != tokensInAttribute { 97 | continue 98 | } 99 | 100 | ret[kv[0]] = kv[1] 101 | } 102 | 103 | return ret 104 | } 105 | 106 | // ParseJsonBody parses the post request which contains json in body. 107 | func ParseJsonBody(r *http.Request, v interface{}) error { 108 | if withJsonBody(r) { 109 | reader := io.LimitReader(r.Body, maxBodyLen) 110 | return mapping.UnmarshalJsonReader(reader, v) 111 | } 112 | 113 | return mapping.UnmarshalJsonMap(nil, v) 114 | } 115 | 116 | // ParsePath parses the variables that reside in the url path into v.
117 | // Like http://localhost/bag/:name 118 | func ParsePath(r *http.Request, v interface{}) error { 119 | vars := pathvar.Vars(r) 120 | m := make(map[string]interface{}, len(vars)) 121 | for k, val := range vars { 122 | m[k] = val 123 | } 124 | 125 | return pathUnmarshaler.Unmarshal(m, v) 126 | } 127 | 128 | func withJsonBody(r *http.Request) bool { 129 | return r.ContentLength > 0 && strings.Contains(r.Header.Get(ContentType), ApplicationJson) 130 | } 131 | -------------------------------------------------------------------------------- /code/rest/rest/httpx/responses.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "encoding/json" 5 | "net/http" 6 | "sync" 7 | 8 | "github.com/zeromicro/go-zero/core/logx" 9 | ) 10 | 11 | var ( 12 | errorHandler func(error) (int, interface{}) 13 | lock sync.RWMutex 14 | ) 15 | 16 | // Error writes err into w, using the handler set by SetErrorHandler if any; otherwise fns[0] if given, else a 400. 17 | func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) { 18 | lock.RLock() 19 | handler := errorHandler 20 | lock.RUnlock() 21 | 22 | if handler == nil { 23 | if len(fns) > 0 { 24 | fns[0](w, err) 25 | } else { 26 | http.Error(w, err.Error(), http.StatusBadRequest) 27 | } 28 | return 29 | } 30 | 31 | code, body := handler(err) 32 | if body == nil { 33 | w.WriteHeader(code) 34 | return 35 | } 36 | 37 | e, ok := body.(error) 38 | if ok { 39 | http.Error(w, e.Error(), code) 40 | } else { 41 | WriteJson(w, code, body) 42 | } 43 | } 44 | 45 | // Ok writes HTTP 200 OK into w. 46 | func Ok(w http.ResponseWriter) { 47 | w.WriteHeader(http.StatusOK) 48 | } 49 | 50 | // OkJson writes v into w with 200 OK. 51 | func OkJson(w http.ResponseWriter, v interface{}) { 52 | WriteJson(w, http.StatusOK, v) 53 | } 54 | 55 | // SetErrorHandler sets the handler that Error uses to map an error to a status code and body.
56 | func SetErrorHandler(handler func(error) (int, interface{})) { 57 | lock.Lock() 58 | defer lock.Unlock() 59 | errorHandler = handler 60 | } 61 | 62 | // WriteJson writes v as json string into w with code. 63 | // v is marshaled before anything is written, so a marshaling failure is reported as a 64 | // plain 500 instead of a superfluous second WriteHeader after code was already sent. 65 | func WriteJson(w http.ResponseWriter, code int, v interface{}) { 66 | bs, err := json.Marshal(v) 67 | if err != nil { 68 | http.Error(w, err.Error(), http.StatusInternalServerError) 69 | return 70 | } 71 | 72 | w.Header().Set(ContentType, ApplicationJson) 73 | w.WriteHeader(code) 74 | 75 | if n, err := w.Write(bs); err != nil { 76 | // http.ErrHandlerTimeout has been handled by http.TimeoutHandler, 77 | // so it's ignored here. 78 | if err != http.ErrHandlerTimeout { 79 | logx.Errorf("write response failed, error: %s", err) 80 | } 81 | } else if n < len(bs) { 82 | logx.Errorf("actual bytes: %d, written bytes: %d", len(bs), n) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /code/rest/rest/httpx/router.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import "net/http" 4 | 5 | // Router interface represents a http router that handles http requests. 6 | type Router interface { 7 | http.Handler 8 | Handle(method, path string, handler http.Handler) error 9 | SetNotFoundHandler(handler http.Handler) 10 | SetNotAllowedHandler(handler http.Handler) 11 | } 12 | -------------------------------------------------------------------------------- /code/rest/rest/httpx/util.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import "net/http" 4 | 5 | const xForwardedFor = "X-Forwarded-For" 6 | 7 | // GetRemoteAddr returns the peer address, preferring the X-Forwarded-For header (client-controlled unless set by a trusted proxy).
8 | func GetRemoteAddr(r *http.Request) string { 9 | v := r.Header.Get(xForwardedFor) 10 | if len(v) > 0 { 11 | return v 12 | } 13 | 14 | return r.RemoteAddr 15 | } 16 | -------------------------------------------------------------------------------- /code/rest/rest/httpx/vars.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | const ( 4 | // ApplicationJson means application/json. 5 | ApplicationJson = "application/json" 6 | // ContentEncoding means Content-Encoding. 7 | ContentEncoding = "Content-Encoding" 8 | // ContentSecurity means X-Content-Security. 9 | ContentSecurity = "X-Content-Security" 10 | // ContentType means Content-Type. 11 | ContentType = "Content-Type" 12 | // KeyField means key. 13 | KeyField = "key" 14 | // SecretField means secret. 15 | SecretField = "secret" 16 | // TypeField means type. 17 | TypeField = "type" 18 | // CryptionType means cryption. 19 | CryptionType = 1 20 | ) 21 | 22 | const ( 23 | // CodeSignaturePass means signature verification passed. 24 | CodeSignaturePass = iota 25 | // CodeSignatureInvalidHeader means invalid header in signature. 26 | CodeSignatureInvalidHeader 27 | // CodeSignatureWrongTime means wrong timestamp in signature. 28 | CodeSignatureWrongTime 29 | // CodeSignatureInvalidToken means invalid token in signature.
30 | CodeSignatureInvalidToken 31 | ) 32 | -------------------------------------------------------------------------------- /code/rest/rest/internal/cors/handlers.go: -------------------------------------------------------------------------------- 1 | package cors 2 | 3 | import ( 4 | "net/http" 5 | 6 | "gozerosource/code/rest/rest/internal/response" 7 | ) 8 | 9 | const ( 10 | allowOrigin = "Access-Control-Allow-Origin" 11 | allOrigins = "*" 12 | allowMethods = "Access-Control-Allow-Methods" 13 | allowHeaders = "Access-Control-Allow-Headers" 14 | allowCredentials = "Access-Control-Allow-Credentials" 15 | exposeHeaders = "Access-Control-Expose-Headers" 16 | requestMethod = "Access-Control-Request-Method" 17 | requestHeaders = "Access-Control-Request-Headers" 18 | allowHeadersVal = "Content-Type, Origin, X-CSRF-Token, Authorization, AccessToken, Token, Range" 19 | exposeHeadersVal = "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers" 20 | methods = "GET, HEAD, POST, PATCH, PUT, DELETE" 21 | allowTrue = "true" 22 | maxAgeHeader = "Access-Control-Max-Age" 23 | maxAgeHeaderVal = "86400" 24 | varyHeader = "Vary" 25 | originHeader = "Origin" 26 | ) 27 | 28 | // NotAllowedHandler sets the CORS headers, then answers OPTIONS requests with 204 and all others with 404. 29 | // origins lists the allowed origins (every entry is checked); an empty list allows any origin (*). 30 | func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.Handler { 31 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 32 | gw := response.NewHeaderOnceResponseWriter(w) 33 | checkAndSetHeaders(gw, r, origins) 34 | if fn != nil { 35 | fn(gw) 36 | } 37 | 38 | if r.Method == http.MethodOptions { 39 | gw.WriteHeader(http.StatusNoContent) 40 | } else { 41 | gw.WriteHeader(http.StatusNotFound) 42 | } 43 | }) 44 | } 45 | 46 | // Middleware returns a middleware that adds CORS headers to the response and answers OPTIONS requests with 204 without calling next.
47 | func Middleware(fn func(w http.Header), origins ...string) func(http.HandlerFunc) http.HandlerFunc { 48 | return func(next http.HandlerFunc) http.HandlerFunc { 49 | return func(w http.ResponseWriter, r *http.Request) { 50 | checkAndSetHeaders(w, r, origins) 51 | if fn != nil { 52 | fn(w.Header()) 53 | } 54 | 55 | if r.Method == http.MethodOptions { 56 | w.WriteHeader(http.StatusNoContent) 57 | } else { 58 | next(w, r) 59 | } 60 | } 61 | } 62 | } 63 | 64 | func checkAndSetHeaders(w http.ResponseWriter, r *http.Request, origins []string) { 65 | setVaryHeaders(w, r) 66 | 67 | if len(origins) == 0 { 68 | setHeader(w, allOrigins) 69 | return 70 | } 71 | 72 | origin := r.Header.Get(originHeader) 73 | if isOriginAllowed(origins, origin) { 74 | setHeader(w, origin) 75 | } 76 | } 77 | 78 | func isOriginAllowed(allows []string, origin string) bool { 79 | for _, o := range allows { 80 | if o == allOrigins { 81 | return true 82 | } 83 | 84 | if o == origin { 85 | return true 86 | } 87 | } 88 | 89 | return false 90 | } 91 | 92 | func setHeader(w http.ResponseWriter, origin string) { 93 | header := w.Header() 94 | header.Set(allowOrigin, origin) 95 | header.Set(allowMethods, methods) 96 | header.Set(allowHeaders, allowHeadersVal) 97 | header.Set(exposeHeaders, exposeHeadersVal) 98 | if origin != allOrigins { 99 | header.Set(allowCredentials, allowTrue) 100 | } 101 | header.Set(maxAgeHeader, maxAgeHeaderVal) 102 | } 103 | 104 | func setVaryHeaders(w http.ResponseWriter, r *http.Request) { 105 | header := w.Header() 106 | header.Add(varyHeader, originHeader) 107 | if r.Method == http.MethodOptions { 108 | header.Add(varyHeader, requestMethod) 109 | header.Add(varyHeader, requestHeaders) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /code/rest/rest/internal/log.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "net/http" 7 
| "sync" 8 | 9 | "github.com/zeromicro/go-zero/core/logx" 10 | "github.com/zeromicro/go-zero/rest/httpx" 11 | ) 12 | 13 | // LogContext is a context key. 14 | var LogContext = contextKey("request_logs") 15 | 16 | // A LogCollector accumulates per-request log messages; Append and Flush are safe for concurrent use. 17 | type LogCollector struct { 18 | Messages []string 19 | lock sync.Mutex 20 | } 21 | 22 | // Append adds msg to the collected messages while holding the lock. 23 | func (lc *LogCollector) Append(msg string) { 24 | lc.lock.Lock() 25 | lc.Messages = append(lc.Messages, msg) 26 | lc.lock.Unlock() 27 | } 28 | 29 | // Flush drains the collector and returns the collected messages joined by newlines. 30 | func (lc *LogCollector) Flush() string { 31 | var buffer bytes.Buffer 32 | 33 | start := true 34 | for _, message := range lc.takeAll() { 35 | if start { 36 | start = false 37 | } else { 38 | buffer.WriteByte('\n') 39 | } 40 | buffer.WriteString(message) 41 | } 42 | 43 | return buffer.String() 44 | } 45 | // takeAll swaps out the collected messages under the lock and returns them, leaving the collector empty. 46 | func (lc *LogCollector) takeAll() []string { 47 | lc.lock.Lock() 48 | messages := lc.Messages 49 | lc.Messages = nil 50 | lc.lock.Unlock() 51 | 52 | return messages 53 | } 54 | 55 | // Error logs the given v along with r in error log. 56 | func Error(r *http.Request, v ...interface{}) { 57 | logx.WithContext(r.Context()).Error(format(r, v...)) 58 | } 59 | 60 | // Errorf logs the given v with format along with r in error log. 61 | func Errorf(r *http.Request, format string, v ...interface{}) { 62 | logx.WithContext(r.Context()).Error(formatf(r, format, v...)) 63 | } 64 | 65 | // Info formats v together with r's URI and remote address and appends it to the request's LogCollector, if any. 66 | func Info(r *http.Request, v ...interface{}) { 67 | appendLog(r, format(r, v...)) 68 | } 69 | 70 | // Infof is the format-string variant of Info: it appends the formatted message to the request's LogCollector, if any.
71 | func Infof(r *http.Request, format string, v ...interface{}) { 72 | appendLog(r, formatf(r, format, v...)) 73 | } 74 | // appendLog adds message to the *LogCollector stored under LogContext; it is a no-op when the context holds no collector (or a value of another type). 75 | func appendLog(r *http.Request, message string) { 76 | logs, ok := r.Context().Value(LogContext).(*LogCollector) 77 | if ok { 78 | logs.Append(message) 79 | } 80 | } 81 | // format renders v with fmt.Sprint and prefixes it with request details. 82 | func format(r *http.Request, v ...interface{}) string { 83 | return formatWithReq(r, fmt.Sprint(v...)) 84 | } 85 | // formatf renders format and v with fmt.Sprintf and prefixes it with request details. 86 | func formatf(r *http.Request, format string, v ...interface{}) string { 87 | return formatWithReq(r, fmt.Sprintf(format, v...)) 88 | } 89 | // formatWithReq produces "(RequestURI - remote address) v". 90 | func formatWithReq(r *http.Request, v string) string { 91 | return fmt.Sprintf("(%s - %s) %s", r.RequestURI, httpx.GetRemoteAddr(r), v) 92 | } 93 | // contextKey is unexported so this package's keys cannot collide with context keys of other packages. 94 | type contextKey string 95 | 96 | func (c contextKey) String() string { 97 | return "rest/internal context key " + string(c) 98 | } 99 | -------------------------------------------------------------------------------- /code/rest/rest/internal/response/headeronceresponsewriter.go: -------------------------------------------------------------------------------- 1 | package response 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "net" 7 | "net/http" 8 | ) 9 | 10 | // HeaderOnceResponseWriter is a http.ResponseWriter implementation 11 | // in which only the first WriteHeader takes effect. 12 | // Later WriteHeader calls are dropped; Flush and Hijack are forwarded to 13 | // the wrapped writer when it supports them. 14 | type HeaderOnceResponseWriter struct { 15 | w http.ResponseWriter 16 | wroteHeader bool 17 | } 18 | 19 | // NewHeaderOnceResponseWriter returns a HeaderOnceResponseWriter. 20 | func NewHeaderOnceResponseWriter(w http.ResponseWriter) http.ResponseWriter { 21 | return &HeaderOnceResponseWriter{w: w} 22 | } 23 | 24 | // Flush flushes the response writer. 25 | func (w *HeaderOnceResponseWriter) Flush() { 26 | if flusher, ok := w.w.(http.Flusher); ok { 27 | flusher.Flush() 28 | } 29 | } 30 | 31 | // Header returns the http header.
32 | func (w *HeaderOnceResponseWriter) Header() http.Header { 33 | return w.w.Header() 34 | } 35 | 36 | // Hijack implements the http.Hijacker interface. 37 | // This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it. 38 | func (w *HeaderOnceResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 39 | if hijacked, ok := w.w.(http.Hijacker); ok { 40 | return hijacked.Hijack() 41 | } 42 | 43 | return nil, nil, errors.New("server doesn't support hijacking") 44 | } 45 | 46 | // Write writes bytes into w. A Write implies a 200 header when none was sent yet, so it also counts as the first WriteHeader. 47 | func (w *HeaderOnceResponseWriter) Write(bytes []byte) (int, error) { 48 | w.wroteHeader = true 49 | return w.w.Write(bytes) 50 | } 51 | 52 | // WriteHeader writes code into w unless a header (explicit, or implied by Write) was already written. 53 | func (w *HeaderOnceResponseWriter) WriteHeader(code int) { 54 | if w.wroteHeader { 55 | return 56 | } 57 | 58 | w.w.WriteHeader(code) 59 | w.wroteHeader = true 60 | } 61 | -------------------------------------------------------------------------------- /code/rest/rest/internal/response/withcoderesponsewriter.go: -------------------------------------------------------------------------------- 1 | package response 2 | 3 | import ( 4 | "bufio" 5 | "errors" 6 | "net" 7 | "net/http" 8 | ) 9 | 10 | // A WithCodeResponseWriter is a helper to delay sealing a http.ResponseWriter on writing code. 11 | type WithCodeResponseWriter struct { 12 | Writer http.ResponseWriter 13 | Code int 14 | } 15 | 16 | // Flush flushes the response writer. 17 | func (w *WithCodeResponseWriter) Flush() { 18 | if flusher, ok := w.Writer.(http.Flusher); ok { 19 | flusher.Flush() 20 | } 21 | } 22 | 23 | // Header returns the http header. 24 | func (w *WithCodeResponseWriter) Header() http.Header { 25 | return w.Writer.Header() 26 | } 27 | 28 | // Hijack implements the http.Hijacker interface. 29 | // This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
30 | func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 31 | if hijacked, ok := w.Writer.(http.Hijacker); ok { 32 | return hijacked.Hijack() 33 | } 34 | 35 | return nil, nil, errors.New("server doesn't support hijacking") 36 | } 37 | 38 | // Write writes bytes into w. 39 | func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) { 40 | return w.Writer.Write(bytes) 41 | } 42 | 43 | // WriteHeader writes code into w, and not sealing the writer. 44 | func (w *WithCodeResponseWriter) WriteHeader(code int) { 45 | w.Writer.WriteHeader(code) 46 | w.Code = code 47 | } 48 | -------------------------------------------------------------------------------- /code/rest/rest/internal/security/contentsecurity.go: -------------------------------------------------------------------------------- 1 | package security 2 | 3 | import ( 4 | "crypto/sha256" 5 | "encoding/base64" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "net/url" 11 | "strconv" 12 | "strings" 13 | "time" 14 | 15 | "github.com/zeromicro/go-zero/core/codec" 16 | "github.com/zeromicro/go-zero/core/iox" 17 | "github.com/zeromicro/go-zero/core/logx" 18 | "github.com/zeromicro/go-zero/rest/httpx" 19 | ) 20 | 21 | const ( 22 | requestUriHeader = "X-Request-Uri" 23 | signatureField = "signature" 24 | timeField = "time" 25 | ) 26 | 27 | var ( 28 | // ErrInvalidContentType is an error that indicates invalid content type. 29 | ErrInvalidContentType = errors.New("invalid content type") 30 | // ErrInvalidHeader is an error that indicates invalid X-Content-Security header. 31 | ErrInvalidHeader = errors.New("invalid X-Content-Security header") 32 | // ErrInvalidKey is an error that indicates invalid key. 33 | ErrInvalidKey = errors.New("invalid key") 34 | // ErrInvalidPublicKey is an error that indicates invalid public key. 35 | ErrInvalidPublicKey = errors.New("invalid public key") 36 | // ErrInvalidSecret is an error that indicates invalid secret. 
37 | ErrInvalidSecret = errors.New("invalid secret") 38 | ) 39 | 40 | // A ContentSecurityHeader is a content security header. 41 | type ContentSecurityHeader struct { 42 | Key []byte 43 | Timestamp string 44 | ContentType int 45 | Signature string 46 | } 47 | 48 | // Encrypted checks if it's a crypted request. 49 | func (h *ContentSecurityHeader) Encrypted() bool { 50 | return h.ContentType == httpx.CryptionType 51 | } 52 | 53 | // ParseContentSecurity parses content security settings in give r. 54 | func ParseContentSecurity(decrypters map[string]codec.RsaDecrypter, r *http.Request) ( 55 | *ContentSecurityHeader, error, 56 | ) { 57 | contentSecurity := r.Header.Get(httpx.ContentSecurity) 58 | attrs := httpx.ParseHeader(contentSecurity) 59 | fingerprint := attrs[httpx.KeyField] 60 | secret := attrs[httpx.SecretField] 61 | signature := attrs[signatureField] 62 | 63 | if len(fingerprint) == 0 || len(secret) == 0 || len(signature) == 0 { 64 | return nil, ErrInvalidHeader 65 | } 66 | 67 | decrypter, ok := decrypters[fingerprint] 68 | if !ok { 69 | return nil, ErrInvalidPublicKey 70 | } 71 | 72 | decryptedSecret, err := decrypter.DecryptBase64(secret) 73 | if err != nil { 74 | return nil, ErrInvalidSecret 75 | } 76 | 77 | attrs = httpx.ParseHeader(string(decryptedSecret)) 78 | base64Key := attrs[httpx.KeyField] 79 | timestamp := attrs[timeField] 80 | contentType := attrs[httpx.TypeField] 81 | 82 | key, err := base64.StdEncoding.DecodeString(base64Key) 83 | if err != nil { 84 | return nil, ErrInvalidKey 85 | } 86 | 87 | cType, err := strconv.Atoi(contentType) 88 | if err != nil { 89 | return nil, ErrInvalidContentType 90 | } 91 | 92 | return &ContentSecurityHeader{ 93 | Key: key, 94 | Timestamp: timestamp, 95 | ContentType: cType, 96 | Signature: signature, 97 | }, nil 98 | } 99 | 100 | // VerifySignature verifies the signature in given r. 
101 | func VerifySignature(r *http.Request, securityHeader *ContentSecurityHeader, tolerance time.Duration) int { 102 | seconds, err := strconv.ParseInt(securityHeader.Timestamp, 10, 64) 103 | if err != nil { 104 | return httpx.CodeSignatureInvalidHeader 105 | } 106 | 107 | now := time.Now().Unix() 108 | toleranceSeconds := int64(tolerance.Seconds()) 109 | if seconds+toleranceSeconds < now || now+toleranceSeconds < seconds { 110 | return httpx.CodeSignatureWrongTime 111 | } 112 | 113 | reqPath, reqQuery := getPathQuery(r) 114 | signContent := strings.Join([]string{ 115 | securityHeader.Timestamp, 116 | r.Method, 117 | reqPath, 118 | reqQuery, 119 | computeBodySignature(r), 120 | }, "\n") 121 | actualSignature := codec.HmacBase64(securityHeader.Key, signContent) 122 | 123 | if securityHeader.Signature == actualSignature { 124 | return httpx.CodeSignaturePass 125 | } 126 | 127 | logx.Infof("signature different, expect: %s, actual: %s", 128 | securityHeader.Signature, actualSignature) 129 | 130 | return httpx.CodeSignatureInvalidToken 131 | } 132 | 133 | func computeBodySignature(r *http.Request) string { 134 | var dup io.ReadCloser 135 | r.Body, dup = iox.DupReadCloser(r.Body) 136 | sha := sha256.New() 137 | io.Copy(sha, r.Body) 138 | r.Body = dup 139 | return fmt.Sprintf("%x", sha.Sum(nil)) 140 | } 141 | 142 | func getPathQuery(r *http.Request) (string, string) { 143 | requestUri := r.Header.Get(requestUriHeader) 144 | if len(requestUri) == 0 { 145 | return r.URL.Path, r.URL.RawQuery 146 | } 147 | 148 | uri, err := url.Parse(requestUri) 149 | if err != nil { 150 | return r.URL.Path, r.URL.RawQuery 151 | } 152 | 153 | return uri.Path, uri.RawQuery 154 | } 155 | -------------------------------------------------------------------------------- /code/rest/rest/internal/starter.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | 8 | 
"github.com/zeromicro/go-zero/core/logx" 9 | "github.com/zeromicro/go-zero/core/proc" 10 | ) 11 | 12 | // StartOption defines the method to customize http.Server. 13 | type StartOption func(srv *http.Server) 14 | 15 | // StartHttp starts a http server. 16 | func StartHttp(host string, port int, handler http.Handler, opts ...StartOption) error { 17 | return start(host, port, handler, func(srv *http.Server) error { 18 | return srv.ListenAndServe() 19 | }, opts...) 20 | } 21 | 22 | // StartHttps starts a https server. 23 | func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler, 24 | opts ...StartOption, 25 | ) error { 26 | return start(host, port, handler, func(srv *http.Server) error { 27 | // certFile and keyFile are passed straight to ListenAndServeTLS. 28 | return srv.ListenAndServeTLS(certFile, keyFile) 29 | }, opts...) 30 | } 31 | // start builds an http.Server for host:port, applies opts, registers a wrap-up listener that shuts the server down, and returns run's error. 32 | func start(host string, port int, handler http.Handler, run func(srv *http.Server) error, 33 | opts ...StartOption, 34 | ) (err error) { 35 | server := &http.Server{ 36 | Addr: fmt.Sprintf("%s:%d", host, port), 37 | Handler: handler, 38 | } 39 | for _, opt := range opts { 40 | opt(server) 41 | } 42 | 43 | waitForCalled := proc.AddWrapUpListener(func() { 44 | if e := server.Shutdown(context.Background()); e != nil { 45 | logx.Error(e) 46 | } 47 | }) 48 | defer func() { 49 | if err == http.ErrServerClosed { 50 | waitForCalled() 51 | } 52 | }() 53 | 54 | return run(server) 55 | } 56 | -------------------------------------------------------------------------------- /code/rest/rest/pathvar/params.go: -------------------------------------------------------------------------------- 1 | package pathvar 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | ) 7 | 8 | var pathVars = contextKey("pathVars") 9 | 10 | // Vars returns the path variables stored in r's context, or nil if there are none.
11 | func Vars(r *http.Request) map[string]string { 12 | vars, ok := r.Context().Value(pathVars).(map[string]string) 13 | if ok { 14 | return vars 15 | } 16 | 17 | return nil 18 | } 19 | 20 | // WithVars returns a shallow copy of r whose context carries params as the path variables. 21 | func WithVars(r *http.Request, params map[string]string) *http.Request { 22 | return r.WithContext(context.WithValue(r.Context(), pathVars, params)) 23 | } 24 | // contextKey is unexported so pathVars cannot collide with context keys defined by other packages. 25 | type contextKey string 26 | 27 | func (c contextKey) String() string { 28 | return "rest/pathvar/context key: " + string(c) 29 | } 30 | -------------------------------------------------------------------------------- /code/rest/rest/router/patrouter.go: -------------------------------------------------------------------------------- 1 | package router 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "path" 7 | "strings" 8 | 9 | "github.com/zeromicro/go-zero/core/search" 10 | "github.com/zeromicro/go-zero/rest/httpx" 11 | "github.com/zeromicro/go-zero/rest/pathvar" 12 | ) 13 | 14 | const ( 15 | allowHeader = "Allow" 16 | allowMethodSeparator = ", " 17 | ) 18 | 19 | var ( 20 | // ErrInvalidMethod is an error that indicates not a valid http method. 21 | ErrInvalidMethod = errors.New("not a valid http method") 22 | // ErrInvalidPath is an error that indicates path does not start with /. 23 | ErrInvalidPath = errors.New("path must begin with '/'") 24 | ) 25 | // patRouter keeps one search tree of routes per HTTP method, plus optional not-found and not-allowed handlers. 26 | type patRouter struct { 27 | trees map[string]*search.Tree 28 | notFound http.Handler 29 | notAllowed http.Handler 30 | } 31 | 32 | // NewRouter returns a httpx.Router.
33 | func NewRouter() httpx.Router { 34 | return &patRouter{ 35 | trees: make(map[string]*search.Tree), 36 | } 37 | } 38 | // Handle registers handler for method and the cleaned reqPath, creating the method's search tree on first use. 39 | func (pr *patRouter) Handle(method, reqPath string, handler http.Handler) error { 40 | if !validMethod(method) { 41 | return ErrInvalidMethod 42 | } 43 | 44 | if len(reqPath) == 0 || reqPath[0] != '/' { 45 | return ErrInvalidPath 46 | } 47 | 48 | cleanPath := path.Clean(reqPath) 49 | tree, ok := pr.trees[method] 50 | if ok { 51 | return tree.Add(cleanPath, handler) 52 | } 53 | 54 | tree = search.NewTree() 55 | pr.trees[method] = tree 56 | return tree.Add(cleanPath, handler) 57 | } 58 | // ServeHTTP dispatches r to the handler registered for its method and cleaned path; otherwise it answers 405 (or uses the not-allowed handler) when another method matches the path, and not found when none does. 59 | func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { 60 | reqPath := path.Clean(r.URL.Path) // shortest path name equivalent to r.URL.Path 61 | if tree, ok := pr.trees[r.Method]; ok { // tree for the request method 62 | if result, ok := tree.Search(reqPath); ok { // route lookup by path 63 | if len(result.Params) > 0 { 64 | r = pathvar.WithVars(r, result.Params) // attach route params to the request context 65 | } 66 | result.Item.(http.Handler).ServeHTTP(w, r) // dispatch to the matched handler 67 | return 68 | } 69 | } 70 | 71 | allows, ok := pr.methodsAllowed(r.Method, reqPath) 72 | if !ok { 73 | pr.handleNotFound(w, r) 74 | return 75 | } 76 | 77 | if pr.notAllowed != nil { 78 | pr.notAllowed.ServeHTTP(w, r) 79 | } else { 80 | w.Header().Set(allowHeader, allows) 81 | w.WriteHeader(http.StatusMethodNotAllowed) 82 | } 83 | } 84 | // SetNotFoundHandler sets the handler used when no method has a route for the path. 85 | func (pr *patRouter) SetNotFoundHandler(handler http.Handler) { 86 | pr.notFound = handler 87 | } 88 | // SetNotAllowedHandler sets the handler used when the path is routed only under other methods. 89 | func (pr *patRouter) SetNotAllowedHandler(handler http.Handler) { 90 | pr.notAllowed = handler 91 | } 92 | // handleNotFound uses the custom not-found handler if one is set, else http.NotFound. 93 | func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) { 94 | if pr.notFound != nil { 95 | pr.notFound.ServeHTTP(w, r) 96 | } else { 97 | http.NotFound(w, r) 98 | } 99 | } 100 | 101 | // methodsAllowed returns the other methods whose trees match path, joined by allowMethodSeparator, and false when there are none. NOTE(review): the order follows map iteration, so the Allow header value is not deterministic. 102 | func (pr *patRouter) methodsAllowed(method, path string) (string, bool) { 103 | var allows []string 104 | 105 | for treeMethod, tree := range pr.trees {
106 | if treeMethod == method { 107 | continue 108 | } 109 | 110 | _, ok := tree.Search(path) 111 | if ok { 112 | allows = append(allows, treeMethod) 113 | } 114 | } 115 | 116 | if len(allows) > 0 { 117 | return strings.Join(allows, allowMethodSeparator), true 118 | } 119 | 120 | return "", false 121 | } 122 | // validMethod reports whether method is one of the HTTP methods the router accepts (CONNECT and TRACE are rejected). 123 | func validMethod(method string) bool { 124 | return method == http.MethodDelete || method == http.MethodGet || 125 | method == http.MethodHead || method == http.MethodOptions || 126 | method == http.MethodPatch || method == http.MethodPost || 127 | method == http.MethodPut 128 | } 129 | -------------------------------------------------------------------------------- /code/rest/rest/server.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "crypto/tls" 5 | "log" 6 | "net/http" 7 | "path" 8 | "time" 9 | 10 | "gozerosource/code/rest/rest/internal/cors" 11 | 12 | "gozerosource/code/rest/rest/handler" 13 | "gozerosource/code/rest/rest/httpx" 14 | "gozerosource/code/rest/rest/router" 15 | 16 | "github.com/zeromicro/go-zero/core/logx" 17 | ) 18 | 19 | type ( 20 | // RunOption defines the method to customize a Server. 21 | RunOption func(*Server) 22 | 23 | // A Server is a http server. 24 | Server struct { 25 | ngin *engine 26 | router httpx.Router 27 | } 28 | ) 29 | 30 | // MustNewServer returns a server with given config of c and options defined in opts. 31 | // Be aware that later RunOption might overwrite previous one that write the same option. 32 | // The process will exit if error occurs. 33 | // Initialization; the process exits via log.Fatal on error. 34 | func MustNewServer(c RestConf, opts ...RunOption) *Server { 35 | server, err := NewServer(c, opts...) 36 | if err != nil { 37 | log.Fatal(err) 38 | } 39 | 40 | return server 41 | } 42 | 43 | // NewServer returns a server with given config of c and options defined in opts. 44 | // Be aware that later RunOption might overwrite previous one that write the same option.
45 | // Initialization: c.SetUp runs first, then a default not-found handler and opts are applied in order. 46 | func NewServer(c RestConf, opts ...RunOption) (*Server, error) { 47 | if err := c.SetUp(); err != nil { 48 | return nil, err 49 | } 50 | 51 | server := &Server{ 52 | ngin: newEngine(c), // core engine 53 | router: router.NewRouter(), // router 54 | } 55 | 56 | opts = append([]RunOption{WithNotFoundHandler(nil)}, opts...) // default not-found handler goes first so later opts can override it 57 | for _, opt := range opts { 58 | opt(server) // apply each run option 59 | } 60 | 61 | return server, nil 62 | } 63 | 64 | // AddRoutes add given routes into the Server. 65 | // All routes in rs share the features configured by opts. 66 | func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) { 67 | r := featuredRoutes{ 68 | routes: rs, 69 | } 70 | for _, opt := range opts { 71 | opt(&r) 72 | } 73 | s.ngin.addRoutes(r) 74 | } 75 | 76 | // AddRoute adds given route into the Server. 77 | // Shorthand for AddRoutes with a single route. 78 | func (s *Server) AddRoute(r Route, opts ...RouteOption) { 79 | s.AddRoutes([]Route{r}, opts...) 80 | } 81 | 82 | // Start starts the Server. 83 | // Graceful shutdown is enabled by default. 84 | // Use proc.SetTimeToForceQuit to customize the graceful shutdown period. 85 | // Errors other than http.ErrServerClosed are logged and cause a panic (see handleError). 86 | func (s *Server) Start() { 87 | handleError(s.ngin.start(s.router)) 88 | } 89 | 90 | // Stop stops the Server. 91 | // NOTE(review): this only closes logx; it does not shut down the listener itself. 92 | func (s *Server) Stop() { 93 | logx.Close() 94 | } 95 | 96 | // Use adds the given middleware in the Server. 97 | // The middleware is handed to the engine via ngin.use. 98 | func (s *Server) Use(middleware Middleware) { 99 | s.ngin.use(middleware) 100 | } 101 | 102 | // ToMiddleware converts the given handler to a Middleware. 103 | // The returned Middleware calls the wrapped handler's ServeHTTP. 104 | func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware { 105 | return func(handle http.HandlerFunc) http.HandlerFunc { 106 | return handler(handle).ServeHTTP 107 | } 108 | } 109 | 110 | // WithCors returns a func to enable CORS for given origin, or default to all origins (*).
111 | // The same origins configure both the router's not-allowed handler and the CORS middleware. 112 | func WithCors(origin ...string) RunOption { 113 | return func(server *Server) { 114 | server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, origin...)) 115 | server.Use(cors.Middleware(nil, origin...)) 116 | } 117 | } 118 | 119 | // WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*), 120 | // middlewareFn and notAllowedFn let the caller customize the response. 121 | // middlewareFn is handed to cors.Middleware and notAllowedFn to cors.NotAllowedHandler. 122 | func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(http.ResponseWriter), 123 | origin ...string, 124 | ) RunOption { 125 | return func(server *Server) { 126 | server.router.SetNotAllowedHandler(cors.NotAllowedHandler(notAllowedFn, origin...)) 127 | server.Use(cors.Middleware(middlewareFn, origin...)) 128 | } 129 | } 130 | 131 | // WithJwt returns a func to enable jwt authentication in given route. 132 | // It panics (via validateSecret) if secret is shorter than 8 bytes. 133 | func WithJwt(secret string) RouteOption { 134 | return func(r *featuredRoutes) { 135 | validateSecret(secret) 136 | r.jwt.enabled = true 137 | r.jwt.secret = secret 138 | } 139 | } 140 | 141 | // WithJwtTransition returns a func to enable jwt authentication as well as jwt secret transition. 142 | // Which means old and new jwt secrets work together for a period. 143 | func WithJwtTransition(secret, prevSecret string) RouteOption { 144 | return func(r *featuredRoutes) { 145 | // why not validate prevSecret, because prevSecret is an already used one, 146 | // even it not meet our requirement, we still need to allow the transition. 147 | validateSecret(secret) 148 | r.jwt.enabled = true 149 | r.jwt.secret = secret 150 | r.jwt.prevSecret = prevSecret 151 | } 152 | } 153 | 154 | // WithMiddlewares adds given middlewares to given routes. 155 | // Middlewares are applied so that ms[0] becomes the outermost wrapper and runs first. 156 | func WithMiddlewares(ms []Middleware, rs ...Route) []Route { 157 | for i := len(ms) - 1; i >= 0; i-- { 158 | rs = WithMiddleware(ms[i], rs...)
159 | } 160 | return rs 161 | } 162 | 163 | // WithMiddleware adds given middleware to given route. 164 | // It returns new Route values and leaves rs unmodified. 165 | func WithMiddleware(middleware Middleware, rs ...Route) []Route { 166 | routes := make([]Route, len(rs)) 167 | 168 | for i := range rs { 169 | route := rs[i] 170 | routes[i] = Route{ 171 | Method: route.Method, 172 | Path: route.Path, 173 | Handler: middleware(route.Handler), 174 | } 175 | } 176 | 177 | return routes 178 | } 179 | 180 | // WithNotFoundHandler returns a RunOption with not found handler set to given handler. 181 | // handler is wrapped by ngin.notFoundHandler before being installed on the router. 182 | func WithNotFoundHandler(handler http.Handler) RunOption { 183 | return func(server *Server) { 184 | notFoundHandler := server.ngin.notFoundHandler(handler) 185 | server.router.SetNotFoundHandler(notFoundHandler) 186 | } 187 | } 188 | 189 | // WithNotAllowedHandler returns a RunOption with not allowed handler set to given handler. 190 | // The router uses it when the path matches only under other HTTP methods. 191 | func WithNotAllowedHandler(handler http.Handler) RunOption { 192 | return func(server *Server) { 193 | server.router.SetNotAllowedHandler(handler) 194 | } 195 | } 196 | 197 | // WithPrefix adds group as a prefix to the route paths. 198 | // Paths are combined with path.Join, so the result is cleaned (e.g. a trailing slash is dropped). 199 | func WithPrefix(group string) RouteOption { 200 | return func(r *featuredRoutes) { 201 | var routes []Route 202 | for _, rt := range r.routes { 203 | p := path.Join(group, rt.Path) 204 | routes = append(routes, Route{ 205 | Method: rt.Method, 206 | Path: p, 207 | Handler: rt.Handler, 208 | }) 209 | } 210 | r.routes = routes 211 | } 212 | } 213 | 214 | // WithPriority returns a RouteOption that marks the routes as high priority. 215 | // Per the original author, prioritized routes get a higher circuit-breaker threshold; TODO confirm in the engine. 216 | func WithPriority() RouteOption { 217 | return func(r *featuredRoutes) { 218 | r.priority = true 219 | } 220 | } 221 | 222 | // WithRouter returns a RunOption that make server run with given router.
223 | // NOTE(review): options applied before this one (e.g. the default not-found handler NewServer installs first, or WithCors) configured the previous router and are not carried over. 224 | func WithRouter(router httpx.Router) RunOption { 225 | return func(server *Server) { 226 | server.router = router 227 | } 228 | } 229 | 230 | // WithSignature returns a RouteOption to enable signature verification. 231 | // Only the Strict, Expiry and PrivateKeys fields are copied from signature. 232 | func WithSignature(signature SignatureConf) RouteOption { 233 | return func(r *featuredRoutes) { 234 | r.signature.enabled = true 235 | r.signature.Strict = signature.Strict 236 | r.signature.Expiry = signature.Expiry 237 | r.signature.PrivateKeys = signature.PrivateKeys 238 | } 239 | } 240 | 241 | // WithTimeout returns a RouteOption to set timeout with given value. 242 | // The timeout applies to the routes this option is passed with. 243 | func WithTimeout(timeout time.Duration) RouteOption { 244 | return func(r *featuredRoutes) { 245 | r.timeout = timeout 246 | } 247 | } 248 | 249 | // WithTLSConfig returns a RunOption that with given tls config. 250 | // cfg is handed to the engine via setTlsConfig. 251 | func WithTLSConfig(cfg *tls.Config) RunOption { 252 | return func(srv *Server) { 253 | srv.ngin.setTlsConfig(cfg) 254 | } 255 | } 256 | 257 | // WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set. 258 | // The callback is handed to the engine via setUnauthorizedCallback. 259 | func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption { 260 | return func(srv *Server) { 261 | srv.ngin.setUnauthorizedCallback(callback) 262 | } 263 | } 264 | 265 | // WithUnsignedCallback returns a RunOption that with given unsigned callback set.
266 | // The callback is handed to the engine via setUnsignedCallback. 267 | func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption { 268 | return func(srv *Server) { 269 | srv.ngin.setUnsignedCallback(callback) 270 | } 271 | } 272 | 273 | // handleError logs err and panics, except for nil and http.ErrServerClosed. 274 | func handleError(err error) { 275 | // ErrServerClosed means the server is closed manually 276 | if err == nil || err == http.ErrServerClosed { 277 | return 278 | } 279 | 280 | logx.Error(err) 281 | panic(err) 282 | } 283 | 284 | // validateSecret panics when secret is shorter than 8 bytes. 285 | func validateSecret(secret string) { 286 | if len(secret) < 8 { 287 | panic("secret's length can't be less than 8") 288 | } 289 | } 290 | -------------------------------------------------------------------------------- /code/rest/rest/token/tokenparser.go: -------------------------------------------------------------------------------- 1 | package token 2 | 3 | import ( 4 | "net/http" 5 | "sync" 6 | "sync/atomic" 7 | "time" 8 | 9 | "github.com/golang-jwt/jwt/v4" 10 | "github.com/golang-jwt/jwt/v4/request" 11 | "github.com/zeromicro/go-zero/core/timex" 12 | ) 13 | 14 | const claimHistoryResetDuration = time.Hour * 24 15 | 16 | type ( 17 | // ParseOption defines the method to customize a TokenParser. 18 | ParseOption func(parser *TokenParser) 19 | 20 | // A TokenParser is used to parse tokens. 21 | TokenParser struct { 22 | resetTime time.Duration // start of the current history window, as returned by timex.Now() 23 | resetDuration time.Duration // how long hit counts are kept before the history is reset 24 | history sync.Map // secret -> *uint64 hit count 25 | } 26 | ) 27 | 28 | // NewTokenParser returns a TokenParser. 29 | // The history reset period defaults to claimHistoryResetDuration. 30 | func NewTokenParser(opts ...ParseOption) *TokenParser { 31 | parser := &TokenParser{ 32 | resetTime: timex.Now(), 33 | resetDuration: claimHistoryResetDuration, 34 | } 35 | 36 | for _, opt := range opts { 37 | opt(parser) 38 | } 39 | 40 | return parser 41 | } 42 | 43 | // ParseToken parses token from given r, with passed in secret and prevSecret.
44 | // The secret with the higher hit count is tried first, and whichever secret succeeds gets its count bumped. 45 | func (tp *TokenParser) ParseToken(r *http.Request, secret, prevSecret string) (*jwt.Token, error) { 46 | var token *jwt.Token 47 | var err error 48 | 49 | if len(prevSecret) > 0 { 50 | count := tp.loadCount(secret) 51 | prevCount := tp.loadCount(prevSecret) 52 | 53 | var first, second string 54 | if count > prevCount { 55 | first = secret 56 | second = prevSecret 57 | } else { 58 | first = prevSecret 59 | second = secret 60 | } 61 | 62 | token, err = tp.doParseToken(r, first) 63 | if err != nil { 64 | token, err = tp.doParseToken(r, second) 65 | if err != nil { 66 | return nil, err 67 | } 68 | 69 | tp.incrementCount(second) 70 | } else { 71 | tp.incrementCount(first) 72 | } 73 | } else { 74 | token, err = tp.doParseToken(r, secret) 75 | if err != nil { 76 | return nil, err 77 | } 78 | } 79 | 80 | return token, nil 81 | } 82 | 83 | // doParseToken parses the token from r's Authorization header, using secret as the verification key. 84 | func (tp *TokenParser) doParseToken(r *http.Request, secret string) (*jwt.Token, error) { 85 | return request.ParseFromRequest(r, request.AuthorizationHeaderExtractor, 86 | func(token *jwt.Token) (interface{}, error) { 87 | return []byte(secret), nil 88 | }, request.WithParser(newParser())) 89 | } 90 | 91 | // incrementCount records a successful parse with secret. NOTE(review): resetTime is never advanced, so once resetDuration has elapsed every call clears the whole history; confirm intent. 92 | func (tp *TokenParser) incrementCount(secret string) { 93 | now := timex.Now() 94 | if tp.resetTime+tp.resetDuration < now { 95 | tp.history.Range(func(key, value interface{}) bool { 96 | tp.history.Delete(key) 97 | return true 98 | }) 99 | } 100 | 101 | // Seed new entries at 1. LoadOrStore keeps the existing counter when 102 | // another goroutine stored this secret first, so concurrent first hits 103 | // cannot overwrite each other and lose an increment. 104 | initial := uint64(1) 105 | if value, loaded := tp.history.LoadOrStore(secret, &initial); loaded { 106 | atomic.AddUint64(value.(*uint64), 1) 107 | } 108 | } 109 | 110 | // loadCount returns the recorded hit count for secret, or 0 if there is none; it reads atomically because incrementCount updates counters with atomic.AddUint64. 111 | func (tp *TokenParser) loadCount(secret string) uint64 { 112 | value, ok := tp.history.Load(secret) 113 | if ok { 114 | return atomic.LoadUint64(value.(*uint64)) 115 | } 116 | 117 | return 0 118 | } 119 | 120 | // WithResetDuration returns a func to customize a TokenParser with reset duration.
121 | // The given duration replaces the default claimHistoryResetDuration. 122 | func WithResetDuration(duration time.Duration) ParseOption { 123 | return func(parser *TokenParser) { 124 | parser.resetDuration = duration 125 | } 126 | } 127 | 128 | func newParser() *jwt.Parser { 129 | return jwt.NewParser(jwt.WithJSONNumber()) 130 | } 131 | -------------------------------------------------------------------------------- /code/rest/rest/types.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "net/http" 5 | "time" 6 | ) 7 | 8 | type ( 9 | // Middleware defines the middleware method. 10 | // It wraps next and returns the handler to use in its place. 11 | Middleware func(next http.HandlerFunc) http.HandlerFunc 12 | 13 | // A Route is a http route. 14 | // Method and Path select the route; Handler serves it. 15 | Route struct { 16 | Method string // HTTP method 17 | Path string // route path 18 | Handler http.HandlerFunc // handler function 19 | } 20 | 21 | // RouteOption defines the method to customize a featured route. 22 | // It mutates the featuredRoutes being configured. 23 | RouteOption func(r *featuredRoutes) 24 | 25 | // jwt authentication settings 26 | jwtSetting struct { 27 | enabled bool // whether jwt auth is enabled 28 | secret string // jwt secret 29 | prevSecret string // previous secret, still accepted during a secret transition 30 | } 31 | 32 | // signature verification settings 33 | signatureSetting struct { 34 | SignatureConf // signature config 35 | enabled bool // whether verification is enabled 36 | } 37 | 38 | // a group of routes sharing the same feature settings 39 | featuredRoutes struct { 40 | timeout time.Duration // timeout for these routes 41 | priority bool // whether the routes are high priority 42 | jwt jwtSetting // jwt settings 43 | signature signatureSetting // signature settings 44 | routes []Route // the routes the settings apply to 45 | } 46 | ) 47 | -------------------------------------------------------------------------------- /code/rest/rest_server.go: -------------------------------------------------------------------------------- 1 | package rest 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "gozerosource/code/rest/rest" 8 | 9 | "github.com/zeromicro/go-zero/core/logx" 10 | "github.com/zeromicro/go-zero/core/service" 11 | ) 12 | // ServerStart runs a demo server on 127.0.0.1:8081 with GET /ping and GET /check routes, then calls server.Start. 13 | func ServerStart() { 14 | c := rest.RestConf{ 15 | Host: "127.0.0.1", 16 | Port: 8081, 17 | MaxConns: 100, 18 | MaxBytes: 1048576, 19 |
Timeout: 1000, 20 | CpuThreshold: 800, 21 | ServiceConf: service.ServiceConf{ 22 | Log: logx.LogConf{ 23 | Mode: "console", 24 | Path: "./logs", 25 | }, 26 | }, 27 | } 28 | server := rest.MustNewServer(c, rest.WithCors("http://localhost:8080")) // an Origin header value always includes the scheme 29 | defer server.Stop() 30 | server.AddRoutes([]rest.Route{ 31 | { 32 | Method: http.MethodGet, 33 | Path: "/ping", 34 | Handler: func(w http.ResponseWriter, r *http.Request) { 35 | w.Write([]byte("pong")) 36 | }, 37 | }, 38 | { 39 | Method: http.MethodGet, 40 | Path: "/check", 41 | Handler: func(w http.ResponseWriter, r *http.Request) { 42 | w.Write([]byte("ok")) 43 | }, 44 | }, 45 | }) 46 | 47 | fmt.Printf("Starting server at %s:%d...\n", c.Host, c.Port) 48 | server.Start() 49 | } 50 | -------------------------------------------------------------------------------- /code/rest/rest_test.go: -------------------------------------------------------------------------------- 1 | package rest_test 2 | 3 | import ( 4 | "io" 5 | "io/ioutil" 6 | "net/http" 7 | "net/url" 8 | "testing" 9 | "time" 10 | 11 | "gozerosource/code/rest" 12 | ) 13 | 14 | func Test_Rest(t *testing.T) { 15 | go rest.ServerStart() 16 | time.Sleep(100 * time.Millisecond) // crude wait for ServerStart to begin listening 17 | for _, tt := range [...]struct { 18 | name, method, uri string 19 | body io.Reader 20 | want *http.Request 21 | wantBody string 22 | }{ 23 | { 24 | name: "GET with ping url", 25 | method: "GET", 26 | uri: "http://127.0.0.1:8081/ping", 27 | body: nil, 28 | want: &http.Request{ 29 | Method: "GET", 30 | Host: "127.0.0.1:8081", 31 | URL: &url.URL{ 32 | Scheme: "http", 33 | Path: "/ping", 34 | RawPath: "/ping", 35 | Host: "127.0.0.1:8081", 36 | }, 37 | Header: http.Header{}, 38 | Proto: "HTTP/1.1", 39 | }, 40 | wantBody: "pong", 41 | }, 42 | { 43 | name: "GET with check url", 44 | method: "GET", 45 | uri: "http://127.0.0.1:8081/check", 46 | body: nil, 47 | want: &http.Request{ 48 | Method: "GET", 49 | Host: "127.0.0.1:8081", 50 | URL: &url.URL{ 51 | Scheme: "http", 52 | Path: "/check", 53 | RawPath:
"/check", 54 | Host: "127.0.0.1:8081", 55 | }, 56 | Header: http.Header{}, 57 | Proto: "HTTP/1.1", 58 | }, 59 | wantBody: "ok", 60 | }, 61 | } { 62 | t.Run(tt.name, func(t *testing.T) { 63 | body, err := httpRequest(tt.method, tt.uri, tt.body) 64 | if err != nil { 65 | t.Fatalf("request %s %s: %v", tt.method, tt.uri, err) 66 | } 67 | if string(body) != tt.wantBody { 68 | t.Errorf("Body = %q; want %q", body, tt.wantBody) 69 | } 70 | }) 71 | } 72 | } 73 | 74 | func httpRequest(method, url string, bodyRow io.Reader) ([]byte, error) { 75 | req, err := http.NewRequest(method, url, bodyRow) 76 | if err != nil { 77 | return nil, err 78 | } 79 | cli := &http.Client{Timeout: 5 * time.Second} 80 | rsp, err := cli.Do(req) 81 | if err != nil { 82 | return nil, err 83 | } 84 | defer rsp.Body.Close() 85 | body, err := ioutil.ReadAll(rsp.Body) 86 | if err != nil { 87 | return nil, err 88 | } 89 | return body, nil 90 | } 91 | -------------------------------------------------------------------------------- /code/shedding/shedding_test.go: -------------------------------------------------------------------------------- 1 | package shedding_test 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "sync" 7 | "sync/atomic" 8 | "testing" 9 | "time" 10 | 11 | "gozerosource/code/core/load" 12 | "gozerosource/code/core/stat" 13 | 14 | "github.com/zeromicro/go-zero/core/mathx" 15 | ) 16 | 17 | const ( 18 | buckets = 10 19 | bucketDuration = time.Millisecond * 50 20 | ) 21 | 22 | func TestAdaptiveShedder(t *testing.T) { 23 | load.DisableLog() 24 | shedder := load.NewAdaptiveShedder( 25 | load.WithWindow(bucketDuration), 26 | load.WithBuckets(buckets), 27 | load.WithCpuThreshold(100), 28 | ) 29 | var wg sync.WaitGroup 30 | var drop int64 31 | proba := mathx.NewProba() 32 | for i := 0; i < 100; i++ { 33 | wg.Add(1) 34 | go func() { 35 | defer wg.Done() 36 | for i := 0; i < 30; i++ { 37 | promise, err := shedder.Allow() 38 | if err != nil { 39 | atomic.AddInt64(&drop, 1) 40 | } else { 41 | count := rand.Intn(5) 42 | time.Sleep(time.Millisecond *
time.Duration(count)) 43 | if proba.TrueOnProba(0.01) { 44 | promise.Fail() 45 | } else { 46 | promise.Pass() 47 | } 48 | } 49 | } 50 | }() 51 | } 52 | wg.Wait() 53 | } 54 | 55 | func Test_Shedding(t *testing.T) { 56 | shedder := load.NewAdaptiveShedder(load.WithCpuThreshold(6)) 57 | // cpuFull() 58 | for i := 0; i < 100; i++ { 59 | fmt.Println(i, stat.CpuUsage()) 60 | promise, err := shedder.Allow() 61 | if err != nil { 62 | fmt.Println(err) 63 | return 64 | } 65 | promise.Fail() 66 | // promise.Pass() 67 | time.Sleep(10 * time.Millisecond) 68 | } 69 | } 70 | 71 | func cpuFull() { 72 | for i := 0; i < 100; i++ { 73 | go func(i int) { 74 | for { 75 | } 76 | }(i) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module gozerosource 2 | 3 | go 1.17 4 | 5 | require ( 6 | github.com/golang-jwt/jwt/v4 v4.2.0 7 | github.com/justinas/alice v1.2.0 8 | github.com/zeromicro/go-zero v1.3.1 9 | go.opentelemetry.io/otel v1.3.0 10 | go.opentelemetry.io/otel/trace v1.3.0 11 | golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 12 | google.golang.org/grpc v1.44.0 13 | google.golang.org/protobuf v1.27.1 14 | k8s.io/api v0.20.12 15 | k8s.io/apimachinery v0.20.12 16 | k8s.io/client-go v0.20.12 17 | ) 18 | 19 | require ( 20 | github.com/beorn7/perks v1.0.1 // indirect 21 | github.com/cespare/xxhash/v2 v2.1.2 // indirect 22 | github.com/coreos/go-semver v0.3.0 // indirect 23 | github.com/coreos/go-systemd/v22 v22.3.2 // indirect 24 | github.com/davecgh/go-spew v1.1.1 // indirect 25 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 26 | github.com/go-logr/logr v1.2.2 // indirect 27 | github.com/go-logr/stdr v1.2.2 // indirect 28 | github.com/go-redis/redis/v8 v8.11.4 // indirect 29 | github.com/gogo/protobuf v1.3.2 // indirect 30 | github.com/golang/mock v1.6.0 // indirect 31 | github.com/golang/protobuf v1.5.2
// indirect 32 | github.com/google/go-cmp v0.5.6 // indirect 33 | github.com/google/gofuzz v1.1.0 // indirect 34 | github.com/google/uuid v1.3.0 // indirect 35 | github.com/googleapis/gnostic v0.4.1 // indirect 36 | github.com/hashicorp/golang-lru v0.5.1 // indirect 37 | github.com/json-iterator/go v1.1.11 // indirect 38 | github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect 39 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 40 | github.com/modern-go/reflect2 v1.0.1 // indirect 41 | github.com/openzipkin/zipkin-go v0.4.0 // indirect 42 | github.com/prometheus/client_golang v1.11.0 // indirect 43 | github.com/prometheus/client_model v0.2.0 // indirect 44 | github.com/prometheus/common v0.26.0 // indirect 45 | github.com/prometheus/procfs v0.6.0 // indirect 46 | github.com/spaolacci/murmur3 v1.1.0 // indirect 47 | go.etcd.io/etcd/api/v3 v3.5.2 // indirect 48 | go.etcd.io/etcd/client/pkg/v3 v3.5.2 // indirect 49 | go.etcd.io/etcd/client/v3 v3.5.2 // indirect 50 | go.opentelemetry.io/otel/exporters/jaeger v1.3.0 // indirect 51 | go.opentelemetry.io/otel/exporters/zipkin v1.3.0 // indirect 52 | go.opentelemetry.io/otel/sdk v1.3.0 // indirect 53 | go.uber.org/atomic v1.9.0 // indirect 54 | go.uber.org/automaxprocs v1.4.0 // indirect 55 | go.uber.org/multierr v1.8.0 // indirect 56 | go.uber.org/zap v1.21.0 // indirect 57 | golang.org/x/crypto v0.0.0-20210920023735-84f357641f63 // indirect 58 | golang.org/x/net v0.0.0-20220225172249-27dd8689420f // indirect 59 | golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d // indirect 60 | golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 // indirect 61 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect 62 | golang.org/x/text v0.3.7 // indirect 63 | google.golang.org/appengine v1.6.5 // indirect 64 | google.golang.org/genproto v0.0.0-20220228195345-15d65a4533f7 // indirect 65 | gopkg.in/inf.v0 v0.9.1 // indirect 66 | gopkg.in/yaml.v2 v2.4.0 // indirect 67 | 
k8s.io/klog/v2 v2.40.1 // indirect 68 | k8s.io/utils v0.0.0-20201110183641-67b214c5f920 // indirect 69 | sigs.k8s.io/structured-merge-diff/v4 v4.1.2 // indirect 70 | sigs.k8s.io/yaml v1.2.0 // indirect 71 | ) 72 | -------------------------------------------------------------------------------- /images/go-zero-rest-start.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caicaispace/go-zero-source/7ec4cbac641714c947efc86a7c5df1bd3f25c48b/images/go-zero-rest-start.jpg --------------------------------------------------------------------------------