├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── annotations ├── Makefile ├── annotations.pb.go └── annotations.proto ├── examples ├── Makefile ├── main.go ├── proto │ └── example.proto └── server │ └── server.go ├── grpcmw ├── client_interceptor.go ├── client_level.go ├── client_router.go ├── registry │ ├── client.go │ └── server.go ├── route.go ├── server_interceptor.go ├── server_level.go ├── server_router.go └── wrappers.go └── protoc-gen-grpc-middleware ├── descriptor ├── file.go ├── interceptors.go ├── method.go ├── parse.go └── service.go ├── main.go └── template ├── init.go ├── interceptors.go ├── method.go ├── package.go └── service.go /.gitignore: -------------------------------------------------------------------------------- 1 | /examples/proto/*.pb.go 2 | /examples/proto/*.pb.mw.go 3 | /examples/examples 4 | !/annotations/*.pb.go 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 MARQUIS 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | EXAMPLES_DIR = examples/ 2 | ANNOTATIONS_DIR = annotations/ 3 | 4 | .PHONY: all annotations examples clean 5 | 6 | all: annotations examples 7 | 8 | annotations: 9 | @make -C $(ANNOTATIONS_DIR) 10 | 11 | examples: 12 | @make -C $(EXAMPLES_DIR) 13 | 14 | clean: 15 | @make clean -C $(ANNOTATIONS_DIR) 16 | @make clean -C $(EXAMPLES_DIR) 17 | 18 | re: clean all 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-grpcmw 2 | 3 | `go-grpcmw` provides a package and a protobuf generator for managing easily 4 | grpc interceptors. 5 | 6 | The package can be used without the protobuf generator. However, using both 7 | together will allow you to avoid writing redundant code. 8 | 9 | ## Prerequisites 10 | 11 | * Protobuf 3.0.0 or later. 
12 | 13 | ## Installation 14 | 15 | ```shell 16 | go get -u github.com/MarquisIO/go-grpcmw/protoc-gen-grpc-middleware 17 | go get -u github.com/golang/protobuf/protoc-gen-go 18 | ``` 19 | 20 | ## Quick start 21 | 22 | Write your gRPC service definition: 23 | 24 | ```protobuf 25 | syntax = "proto3"; 26 | 27 | import "github.com/MarquisIO/go-grpcmw/annotations/annotations.proto"; 28 | 29 | package pb; 30 | 31 | option (grpcmw.package_interceptors) = { 32 | indexes: ["index"] 33 | }; 34 | 35 | service SomeService { 36 | option (grpcmw.service_interceptors) = { 37 | indexes: ["index"] 38 | }; 39 | 40 | rpc SomeMethod (Message) returns (Message) { 41 | option (grpcmw.method_interceptors) = { 42 | indexes: ["index"] 43 | }; 44 | } 45 | } 46 | 47 | message Message { 48 | string msg = 1; 49 | } 50 | ``` 51 | 52 | Generate the stubs: 53 | 54 | ```shell 55 | protoc --go_out=plugins=grpc:. --grpc-middleware_out=:. path/to/your/file.proto 56 | ``` 57 | 58 | Use the generated code to add your own middlewares: 59 | 60 | ```go 61 | // Register an interceptor in the registry 62 | registry.GetClientInterceptor("index"). 63 | AddGRPCUnaryInterceptor(SomeUnaryClientInterceptor). 64 | AddGRPCStreamInterceptor(SomeStreamClientInterceptor) 65 | registry.GetServerInterceptor("index"). 66 | AddGRPCUnaryInterceptor(SomeUnaryServerInterceptor). 67 | AddGRPCStreamInterceptor(SomeStreamServerInterceptor) 68 | 69 | // Client 70 | clientRouter := grpcmw.NewClientRouter() 71 | clientStub := pb.RegisterClientInterceptors(clientRouter) 72 | clientStub.RegisterSomeService(). 73 | SomeMethod(). 74 | AddGRPCInterceptor(clientUnaryMiddleware) 75 | grpc.Dial(address, 76 | grpc.WithStreamInterceptor(clientRouter.StreamResolver()), 77 | grpc.WithUnaryInterceptor(clientRouter.UnaryResolver()), 78 | ) 79 | 80 | // Server 81 | serverRouter := grpcmw.NewServerRouter() 82 | serverStub := pb.RegisterServerInterceptors(serverRouter) 83 | serverStub.RegisterSomeService(). 84 | SomeMethod().
85 | AddGRPCInterceptor(serverUnaryMiddleware) 86 | grpc.NewServer( 87 | grpc.UnaryInterceptor(serverRouter.UnaryResolver()), 88 | grpc.StreamInterceptor(serverRouter.StreamResolver()), 89 | ) 90 | ``` 91 | 92 | ## Chaining 93 | 94 | Four types of interceptors are provided: `ServerUnaryInterceptor`, 95 | `ServerStreamInterceptor`, `ClientUnaryInterceptor` and 96 | `ClientStreamInterceptor` (corresponding to those defined in 97 | [google.golang.org/grpc](https://godoc.org/google.golang.org/grpc)). They allow 98 | you to chain multiple gRPC interceptors of the same type: 99 | 100 | ```go 101 | // grpcInterceptor1 -> grpcInterceptor2 -> interceptor1 -> grpcInterceptor3 102 | intcp := grpcmw.NewUnaryServerInterceptor(grpcInterceptor1, grpcInterceptor2). 103 | AddInterceptor(interceptor1). 104 | AddGRPCInterceptor(grpcInterceptor3) 105 | ``` 106 | 107 | ## Routing 108 | 109 | This package also provides a routing feature so that interceptors can be bound 110 | either to: 111 | * a protobuf package: all requests to any service that have been declared in 112 | this package will go through the interceptor. 113 | * a gRPC service: all requests to this service will go through the interceptor. 114 | * a gRPC method: all requests to this method will go through the interceptor. 115 | 116 | `ServerRouter` and `ClientRouter` provide one more global level in addition to 117 | the three described above. 118 | These implementations are based on the route construction as defined 119 | [in the official gRPC repository](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md). 120 | 121 | ```go 122 | serverRouter := grpcmw.NewServerRouter() 123 | pkgInterceptor := grpcmw.NewServerInterceptor("pb") 124 | serviceInterceptor := grpcmw.NewServerInterceptor("Service") 125 | 126 | // Interceptors that have been added to `serviceInterceptor` will be called 127 | // each time the gRPC service `Service` will be requested. 
In other words, all 128 | // requests to "/pb.Service/*" will go through these interceptors. 129 | pkgInterceptor.Register(serviceInterceptor) 130 | serverRouter.GetRegister(). 131 | Register(pkgInterceptor) 132 | 133 | // In order to use the router, you have to create the server with it. 134 | grpc.NewServer( 135 | grpc.UnaryInterceptor(serverRouter.UnaryResolver()), 136 | grpc.StreamInterceptor(serverRouter.StreamResolver()), 137 | ) 138 | ``` 139 | 140 | ## Registry 141 | 142 | The `registry` package provides an interceptor registry for both server and 143 | client side. 144 | 145 | ```go 146 | registry.GetServerInterceptor("index"). 147 | AddGRPCUnaryInterceptor(SomeUnaryServerInterceptor). 148 | AddGRPCStreamInterceptor(SomeStreamServerInterceptor) 149 | ``` 150 | 151 | ## Protobuf generation 152 | 153 | In order to ease the use of the registry and routing features, a protobuf 154 | generator is provided which you can use with the following command: 155 | 156 | ```shell 157 | protoc --grpc-middleware_out=:. path/to/your/file.proto 158 | ``` 159 | 160 | ### Routing 161 | 162 | Say we have the following protobuf file: 163 | 164 | ```protobuf 165 | syntax = "proto3"; 166 | 167 | package pb; 168 | 169 | service SomeService { 170 | rpc SomeMethod (Message) returns (Message) {} 171 | } 172 | 173 | message Message { 174 | string msg = 1; 175 | } 176 | ``` 177 | 178 | It will create some helpers for adding interceptors to a package, service or 179 | method.
180 | 181 | ```go 182 | serverRouter := grpcmw.NewServerRouter() 183 | serverStub := pb.RegisterServerInterceptors(serverRouter) 184 | serverStub.AddGRPCInterceptor(pkgUnaryMiddleware) 185 | serviceStub := serverStub.RegisterSomeService() 186 | serviceStub.AddGRPCInterceptor(serviceUnaryMiddleware) 187 | methodStub := serviceStub.SomeMethod() 188 | methodStub.AddGRPCInterceptor(methodUnaryMiddleware) 189 | ``` 190 | 191 | ### Registry 192 | 193 | Three annotations are provided in 194 | [annotations/annotations.proto](./annotations/annotations.proto): 195 | * `package_interceptors`: for the package level. 196 | * `service_interceptors`: for the service level. 197 | * `method_interceptors`: for the method level. 198 | 199 | These annotations have an array of index (`indexes`) that tells the generator 200 | which interceptors from the registry have to be added to the router. 201 | 202 | Say we have the following protobuf file: 203 | 204 | ```protobuf 205 | syntax = "proto3"; 206 | 207 | import "github.com/MarquisIO/go-grpcmw/annotations/annotations.proto"; 208 | 209 | package pb; 210 | 211 | option (grpcmw.package_interceptors) = { 212 | indexes: ["index"] 213 | }; 214 | 215 | service SomeService { 216 | option (grpcmw.service_interceptors) = { 217 | indexes: ["index"] 218 | }; 219 | 220 | rpc SomeMethod (Message) returns (Message) { 221 | option (grpcmw.method_interceptors) = { 222 | indexes: ["index"] 223 | }; 224 | } 225 | } 226 | 227 | message Message { 228 | string msg = 1; 229 | } 230 | ``` 231 | 232 | You can then register interceptors in the registry at the index "index". 233 | 234 | ```go 235 | // Register an interceptor in the registry 236 | registry.GetServerInterceptor("index"). 237 | AddGRPCUnaryInterceptor(SomeUnaryServerInterceptor). 
238 | AddGRPCStreamInterceptor(SomeStreamServerInterceptor) 239 | 240 | // You have to call `RegisterServerInterceptors` and `RegisterSomeService` so 241 | // that interceptors are added to the router at the package, service and method 242 | // levels. 243 | serverRouter := grpcmw.NewServerRouter() 244 | serverStub := pb.RegisterServerInterceptors(serverRouter) 245 | serverStub.RegisterSomeService() 246 | ``` 247 | -------------------------------------------------------------------------------- /annotations/Makefile: -------------------------------------------------------------------------------- 1 | PROTOC = protoc 2 | 3 | PROTO_SRC = annotations.proto 4 | PROTO_PB_GO = $(PROTO_SRC:.proto=.pb.go) 5 | 6 | DESCRIPTOR = github.com/golang/protobuf/protoc-gen-go/descriptor 7 | 8 | GO_PACKAGE = annotations 9 | 10 | .PHONY: all clean re 11 | 12 | all: $(PROTO_PB_GO) 13 | 14 | clean: 15 | $(RM) $(PROTO_PB_GO) 16 | 17 | re: clean all 18 | 19 | %.pb.go: %.proto 20 | $(PROTOC) --go_out=import_path=$(GO_PACKAGE),Mgoogle/protobuf/descriptor.proto=$(DESCRIPTOR),plugins=grpc:. $^ 21 | -------------------------------------------------------------------------------- /annotations/annotations.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. DO NOT EDIT. 2 | // source: annotations.proto 3 | 4 | /* 5 | Package annotations is a generated protocol buffer package. 6 | 7 | It is generated from these files: 8 | annotations.proto 9 | 10 | It has these top-level messages: 11 | Interceptors 12 | */ 13 | package annotations 14 | 15 | import proto "github.com/golang/protobuf/proto" 16 | import fmt "fmt" 17 | import math "math" 18 | import google_protobuf "github.com/golang/protobuf/protoc-gen-go/descriptor" 19 | 20 | // Reference imports to suppress errors if they are not otherwise used.
21 | var _ = proto.Marshal 22 | var _ = fmt.Errorf 23 | var _ = math.Inf 24 | 25 | // This is a compile-time assertion to ensure that this generated file 26 | // is compatible with the proto package it is being compiled against. 27 | // A compilation error at this line likely means your copy of the 28 | // proto package needs to be updated. 29 | const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package 30 | 31 | type Interceptors struct { 32 | Indexes []string `protobuf:"bytes,1,rep,name=indexes" json:"indexes,omitempty"` 33 | XXX_unrecognized []byte `json:"-"` 34 | } 35 | 36 | func (m *Interceptors) Reset() { *m = Interceptors{} } 37 | func (m *Interceptors) String() string { return proto.CompactTextString(m) } 38 | func (*Interceptors) ProtoMessage() {} 39 | func (*Interceptors) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } 40 | 41 | func (m *Interceptors) GetIndexes() []string { 42 | if m != nil { 43 | return m.Indexes 44 | } 45 | return nil 46 | } 47 | 48 | var E_PackageInterceptors = &proto.ExtensionDesc{ 49 | ExtendedType: (*google_protobuf.FileOptions)(nil), 50 | ExtensionType: (*Interceptors)(nil), 51 | Field: 1041, 52 | Name: "grpcmw.package_interceptors", 53 | Tag: "bytes,1041,opt,name=package_interceptors,json=packageInterceptors", 54 | Filename: "annotations.proto", 55 | } 56 | 57 | var E_ServiceInterceptors = &proto.ExtensionDesc{ 58 | ExtendedType: (*google_protobuf.ServiceOptions)(nil), 59 | ExtensionType: (*Interceptors)(nil), 60 | Field: 1041, 61 | Name: "grpcmw.service_interceptors", 62 | Tag: "bytes,1041,opt,name=service_interceptors,json=serviceInterceptors", 63 | Filename: "annotations.proto", 64 | } 65 | 66 | var E_MethodInterceptors = &proto.ExtensionDesc{ 67 | ExtendedType: (*google_protobuf.MethodOptions)(nil), 68 | ExtensionType: (*Interceptors)(nil), 69 | Field: 1041, 70 | Name: "grpcmw.method_interceptors", 71 | Tag: "bytes,1041,opt,name=method_interceptors,json=methodInterceptors", 72 | Filename: 
"annotations.proto", 73 | } 74 | 75 | func init() { 76 | proto.RegisterType((*Interceptors)(nil), "grpcmw.Interceptors") 77 | proto.RegisterExtension(E_PackageInterceptors) 78 | proto.RegisterExtension(E_ServiceInterceptors) 79 | proto.RegisterExtension(E_MethodInterceptors) 80 | } 81 | 82 | func init() { proto.RegisterFile("annotations.proto", fileDescriptor0) } 83 | 84 | var fileDescriptor0 = []byte{ 85 | // 215 bytes of a gzipped FileDescriptorProto 86 | 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x4c, 0xcc, 0xcb, 0xcb, 87 | 0x2f, 0x49, 0x2c, 0xc9, 0xcc, 0xcf, 0x2b, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x4b, 88 | 0x2f, 0x2a, 0x48, 0xce, 0x2d, 0x97, 0x52, 0x48, 0xcf, 0xcf, 0x4f, 0xcf, 0x49, 0xd5, 0x07, 0x8b, 89 | 0x26, 0x95, 0xa6, 0xe9, 0xa7, 0xa4, 0x16, 0x27, 0x17, 0x65, 0x16, 0x94, 0xe4, 0x17, 0x41, 0x54, 90 | 0x2a, 0x69, 0x70, 0xf1, 0x78, 0xe6, 0x95, 0xa4, 0x16, 0x25, 0xa7, 0x82, 0x04, 0x8b, 0x85, 0x24, 91 | 0xb8, 0xd8, 0x33, 0xf3, 0x52, 0x52, 0x2b, 0x52, 0x8b, 0x25, 0x18, 0x15, 0x98, 0x35, 0x38, 0x83, 92 | 0x60, 0x5c, 0xab, 0x34, 0x2e, 0x91, 0x82, 0xc4, 0xe4, 0xec, 0xc4, 0xf4, 0xd4, 0xf8, 0x4c, 0x64, 93 | 0x1d, 0x32, 0x7a, 0x10, 0x4b, 0xf4, 0x60, 0x96, 0xe8, 0xb9, 0x65, 0xe6, 0xa4, 0xfa, 0x17, 0x80, 94 | 0xdd, 0x23, 0x31, 0x91, 0x43, 0x81, 0x51, 0x83, 0xdb, 0x48, 0x44, 0x0f, 0xe2, 0x22, 0x3d, 0x64, 95 | 0xcb, 0x82, 0x84, 0xa1, 0x06, 0x22, 0x0b, 0x5a, 0x65, 0x72, 0x89, 0x14, 0xa7, 0x16, 0x95, 0x65, 96 | 0x26, 0xa3, 0xd9, 0x23, 0x8f, 0x61, 0x4f, 0x30, 0x44, 0x19, 0x71, 0x56, 0x41, 0xcd, 0x44, 0xb1, 97 | 0x2a, 0x8d, 0x4b, 0x38, 0x37, 0xb5, 0x24, 0x23, 0x3f, 0x05, 0xd5, 0x26, 0x39, 0x0c, 0x9b, 0x7c, 98 | 0xc1, 0xaa, 0x88, 0xb2, 0x48, 0x08, 0x62, 0x22, 0xb2, 0x18, 0x20, 0x00, 0x00, 0xff, 0xff, 0xc0, 99 | 0x67, 0xfd, 0xf1, 0xa2, 0x01, 0x00, 0x00, 100 | } 101 | -------------------------------------------------------------------------------- /annotations/annotations.proto: 
-------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | import "google/protobuf/descriptor.proto"; 4 | 5 | package grpcmw; 6 | 7 | extend google.protobuf.FileOptions { 8 | optional Interceptors package_interceptors = 1041; 9 | } 10 | 11 | extend google.protobuf.ServiceOptions { 12 | optional Interceptors service_interceptors = 1041; 13 | } 14 | 15 | extend google.protobuf.MethodOptions { 16 | optional Interceptors method_interceptors = 1041; 17 | } 18 | 19 | message Interceptors { 20 | repeated string indexes = 1; 21 | } 22 | -------------------------------------------------------------------------------- /examples/Makefile: -------------------------------------------------------------------------------- 1 | PROTOC = protoc 2 | 3 | PROTO_SRC = proto/example.proto 4 | PROTO_PB_GO = $(PROTO_SRC:.proto=.pb.go) 5 | PROTO_PB_MW_GO = $(PROTO_SRC:.proto=.pb.mw.go) 6 | 7 | .PHONY: all clean re 8 | 9 | all: $(PROTO_PB_GO) $(PROTO_PB_MW_GO) 10 | 11 | clean: 12 | $(RM) $(PROTO_PB_GO) $(PROTO_PB_MW_GO) 13 | 14 | re: clean all 15 | 16 | 17 | %.pb.go: %.proto 18 | $(PROTOC) -I $(GOPATH)/src:. --go_out=plugins=grpc:. $^ 19 | 20 | %.pb.mw.go: %.proto 21 | $(PROTOC) -I $(GOPATH)/src:. --grpc-middleware_out=:. 
$^ 22 | -------------------------------------------------------------------------------- /examples/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "net" 7 | 8 | "golang.org/x/net/context" 9 | 10 | "github.com/MarquisIO/go-grpcmw/examples/proto" 11 | "github.com/MarquisIO/go-grpcmw/examples/server" 12 | "github.com/MarquisIO/go-grpcmw/grpcmw" 13 | "github.com/MarquisIO/go-grpcmw/grpcmw/registry" 14 | "google.golang.org/grpc" 15 | ) 16 | 17 | func serverMiddlewareRegistry(level string) grpc.UnaryServerInterceptor { 18 | return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 19 | fmt.Printf("enter server : %s level of middleware (registry)\n", level) 20 | defer fmt.Printf("leave server : %s level of middleware (registry)\n", level) 21 | return handler(ctx, req) 22 | } 23 | } 24 | 25 | func clientMiddlewareRegistry(level string) grpc.UnaryClientInterceptor { 26 | return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 27 | fmt.Printf("enter client : %s level of middleware (registry)\n", level) 28 | defer fmt.Printf("leave client : %s level of middleware (registry)\n", level) 29 | return invoker(ctx, method, req, reply, cc, opts...) 
30 | } 31 | } 32 | 33 | func serverMiddleware(level string) grpc.UnaryServerInterceptor { 34 | return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 35 | fmt.Printf("enter server : %s level of middleware\n", level) 36 | defer fmt.Printf("leave server : %s level of middleware\n", level) 37 | return handler(ctx, req) 38 | } 39 | } 40 | 41 | func clientMiddleware(level string) grpc.UnaryClientInterceptor { 42 | return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 43 | fmt.Printf("enter client : %s level of middleware\n", level) 44 | defer fmt.Printf("leave client : %s level of middleware\n", level) 45 | return invoker(ctx, method, req, reply, cc, opts...) 46 | } 47 | } 48 | 49 | func startServer(port uint16) (*grpc.Server, net.Listener) { 50 | // Server 51 | // Setup global server router 52 | r := grpcmw.NewServerRouter() 53 | r.GetRegister().AddGRPCUnaryInterceptor(serverMiddleware("global")) 54 | 55 | pkgInterceptors := pb.RegisterServerInterceptors(r) 56 | pkgInterceptors.AddGRPCUnaryInterceptor(serverMiddleware("package")) 57 | pkgInterceptors.RegisterService().AddGRPCUnaryInterceptor(serverMiddleware("service")) 58 | pkgInterceptors.RegisterService().Method().AddGRPCInterceptor(serverMiddleware("method")) 59 | 60 | // Create gRPC server and register the service 61 | var e serverpb.Example 62 | server := grpc.NewServer(grpc.UnaryInterceptor(r.UnaryResolver())) 63 | pb.RegisterServiceServer(server, &e) 64 | 65 | // Start listening 66 | lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) 67 | if err != nil { 68 | log.Fatalf("Could not create listener on port %d: %v", port, err) 69 | } 70 | go server.Serve(lis) 71 | return server, lis 72 | } 73 | 74 | func startClient(port uint16) (*grpc.ClientConn, pb.ServiceClient) { 75 | // Client 76 | // Setup global client router 77 | r := 
grpcmw.NewClientRouter() 78 | r.GetRegister().AddGRPCUnaryInterceptor(clientMiddleware("global")) 79 | 80 | pkgInterceptors := pb.RegisterClientInterceptors(r) 81 | pkgInterceptors.AddGRPCUnaryInterceptor(clientMiddleware("package")) 82 | pkgInterceptors.RegisterService().AddGRPCUnaryInterceptor(clientMiddleware("service")) 83 | pkgInterceptors.RegisterService().Method().AddGRPCInterceptor(clientMiddleware("method")) 84 | 85 | // Setup connection to the server 86 | target := fmt.Sprintf("127.0.0.1:%d", port) 87 | conn, err := grpc.Dial(target, 88 | grpc.WithInsecure(), 89 | grpc.WithUnaryInterceptor(r.UnaryResolver())) 90 | if err != nil { 91 | log.Fatalf("Could not dial \"%s\": %v", target, err) 92 | } 93 | return conn, pb.NewServiceClient(conn) 94 | } 95 | 96 | func main() { 97 | var port uint16 = 4242 98 | 99 | // Register middlewares on registry 100 | registry.GetClientInterceptor("pkg").AddGRPCUnaryInterceptor(clientMiddlewareRegistry("package")) 101 | registry.GetClientInterceptor("srv").AddGRPCUnaryInterceptor(clientMiddlewareRegistry("service")) 102 | registry.GetClientInterceptor("meth").AddGRPCUnaryInterceptor(clientMiddlewareRegistry("method")) 103 | registry.GetServerInterceptor("pkg").AddGRPCUnaryInterceptor(serverMiddlewareRegistry("package")) 104 | registry.GetServerInterceptor("srv").AddGRPCUnaryInterceptor(serverMiddlewareRegistry("service")) 105 | registry.GetServerInterceptor("meth").AddGRPCUnaryInterceptor(serverMiddlewareRegistry("method")) 106 | 107 | server, lis := startServer(port) 108 | defer lis.Close() 109 | defer server.GracefulStop() 110 | 111 | conn, client := startClient(port) 112 | defer conn.Close() 113 | 114 | msg, err := client.Method(context.Background(), &pb.Message{Msg: "message"}) 115 | if err != nil { 116 | log.Fatalf("Call to Method failed: %v", err) 117 | } 118 | fmt.Printf("Received : %s\n", msg.Msg) 119 | } 120 | -------------------------------------------------------------------------------- 
/examples/proto/example.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | import "github.com/MarquisIO/go-grpcmw/annotations/annotations.proto"; 4 | 5 | package pb; 6 | 7 | option (grpcmw.package_interceptors) = { 8 | indexes: ["pkg"] 9 | }; 10 | 11 | service Service { 12 | option (grpcmw.service_interceptors) = { 13 | indexes: ["srv"] 14 | }; 15 | 16 | rpc Method (Message) returns (Message) { 17 | option (grpcmw.method_interceptors) = { 18 | indexes: ["meth"] 19 | }; 20 | } 21 | } 22 | 23 | message Message { 24 | string msg = 1; 25 | } 26 | -------------------------------------------------------------------------------- /examples/server/server.go: -------------------------------------------------------------------------------- 1 | package serverpb 2 | 3 | import ( 4 | "fmt" 5 | 6 | "golang.org/x/net/context" 7 | 8 | "github.com/MarquisIO/go-grpcmw/examples/proto" 9 | ) 10 | 11 | // Example implements the `pb.Service` gRPC service. 12 | type Example struct{} 13 | 14 | // Method prints "Received : " followed by the message content, then 15 | // returns the received message unchanged. 16 | func (e *Example) Method(ctx context.Context, msg *pb.Message) (*pb.Message, error) { 17 | fmt.Printf("Received : %s\n", msg.Msg) 18 | return msg, nil 19 | } 20 | -------------------------------------------------------------------------------- /grpcmw/client_interceptor.go: -------------------------------------------------------------------------------- 1 | package grpcmw 2 | 3 | import ( 4 | "sync" 5 | 6 | "golang.org/x/net/context" 7 | 8 | "google.golang.org/grpc" 9 | ) 10 | 11 | // StreamClientInterceptor represents a client interceptor for gRPC methods that 12 | // return a stream. It allows chaining of `grpc.StreamClientInterceptor` 13 | // and other `StreamClientInterceptor`. 14 | type StreamClientInterceptor interface { 15 | // Interceptor chains all added interceptors into a single 16 | // `grpc.StreamClientInterceptor`.
17 | Interceptor() grpc.StreamClientInterceptor 18 | // AddGRPCInterceptor adds given interceptors to the chain. 19 | AddGRPCInterceptor(i ...grpc.StreamClientInterceptor) StreamClientInterceptor 20 | // AddInterceptor is a convenient way for adding `StreamClientInterceptor` 21 | // to the chain of interceptors. 22 | AddInterceptor(i ...StreamClientInterceptor) StreamClientInterceptor 23 | } 24 | 25 | // UnaryClientInterceptor represents a client interceptor for gRPC methods that 26 | // return a single value. It allows chaining of `grpc.UnaryClientInterceptor` 27 | // and other `UnaryClientInterceptor`. 28 | type UnaryClientInterceptor interface { 29 | // Interceptor chains all added interceptors into a single 30 | // `grpc.UnaryClientInterceptor`. 31 | Interceptor() grpc.UnaryClientInterceptor 32 | // AddGRPCInterceptor adds `arr` to the chain of interceptors. 33 | AddGRPCInterceptor(i ...grpc.UnaryClientInterceptor) UnaryClientInterceptor 34 | // AddInterceptor is a convenient way for adding `UnaryClientInterceptor` 35 | // to the chain of interceptors. 36 | AddInterceptor(i ...UnaryClientInterceptor) UnaryClientInterceptor 37 | } 38 | 39 | type streamClientInterceptor struct { 40 | lock *sync.RWMutex 41 | interceptors []grpc.StreamClientInterceptor 42 | } 43 | 44 | type unaryClientInterceptor struct { 45 | lock *sync.RWMutex 46 | interceptors []grpc.UnaryClientInterceptor 47 | } 48 | 49 | // NewStreamClientInterceptor returns a new `StreamClientInterceptor`. 50 | // It initializes its interceptor chain with `arr`. 51 | // This implementation is thread-safe. 
52 | func NewStreamClientInterceptor(arr ...grpc.StreamClientInterceptor) StreamClientInterceptor { 53 | return &streamClientInterceptor{ 54 | interceptors: arr, 55 | lock: &sync.RWMutex{}, 56 | } 57 | } 58 | 59 | func chainStreamClientInterceptor(current grpc.StreamClientInterceptor, next grpc.Streamer) grpc.Streamer { 60 | return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { 61 | return current(ctx, desc, cc, method, next, opts...) 62 | } 63 | } 64 | 65 | // Interceptor chains all added interceptors into a single 66 | // `grpc.StreamClientInterceptor`. 67 | // 68 | // The `streamer` passed to each interceptor is either the next interceptor or, 69 | // for the last element of the chain, the target method. 70 | func (si streamClientInterceptor) Interceptor() grpc.StreamClientInterceptor { 71 | return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 72 | // TODO: Find a more efficient way 73 | interceptor := streamer 74 | si.lock.RLock() 75 | for idx := len(si.interceptors) - 1; idx >= 0; idx-- { 76 | interceptor = chainStreamClientInterceptor(si.interceptors[idx], interceptor) 77 | } 78 | si.lock.RUnlock() 79 | return interceptor(ctx, desc, cc, method, opts...) 80 | } 81 | } 82 | 83 | // AddGRPCInterceptor adds `arr` to the chain of interceptors. 84 | func (si *streamClientInterceptor) AddGRPCInterceptor(arr ...grpc.StreamClientInterceptor) StreamClientInterceptor { 85 | si.lock.Lock() 86 | defer si.lock.Unlock() 87 | si.interceptors = append(si.interceptors, arr...) 88 | return si 89 | } 90 | 91 | // AddInterceptor is a convenient way for adding `StreamClientInterceptor` 92 | // to the chain of interceptors. It only calls the method `Interceptor` 93 | // for each of them and append the return value to the chain. 
94 | func (si *streamClientInterceptor) AddInterceptor(arr ...StreamClientInterceptor) StreamClientInterceptor { 95 | si.lock.Lock() 96 | defer si.lock.Unlock() 97 | for _, i := range arr { 98 | si.interceptors = append(si.interceptors, i.Interceptor()) 99 | } 100 | return si 101 | } 102 | 103 | // NewUnaryClientInterceptor returns a new `UnaryClientInterceptor`. 104 | // It initializes its interceptor chain with `arr`. 105 | // This implementation is thread-safe. 106 | func NewUnaryClientInterceptor(arr ...grpc.UnaryClientInterceptor) UnaryClientInterceptor { 107 | return &unaryClientInterceptor{ 108 | interceptors: arr, 109 | lock: &sync.RWMutex{}, 110 | } 111 | } 112 | 113 | func chainUnaryClientInterceptor(current grpc.UnaryClientInterceptor, next grpc.UnaryInvoker) grpc.UnaryInvoker { 114 | return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 115 | return current(ctx, method, req, reply, cc, next, opts...) 116 | } 117 | } 118 | 119 | // Interceptor chains all added interceptors into a single 120 | // `grpc.UnaryClientInterceptor`. 121 | // 122 | // The `streamer` passed to each interceptor is either the next interceptor or, 123 | // for the last element of the chain, the target method. 124 | func (ui *unaryClientInterceptor) Interceptor() grpc.UnaryClientInterceptor { 125 | return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 126 | // TODO: Find a more efficient way 127 | interceptor := invoker 128 | ui.lock.RLock() 129 | for idx := len(ui.interceptors) - 1; idx >= 0; idx-- { 130 | interceptor = chainUnaryClientInterceptor(ui.interceptors[idx], interceptor) 131 | } 132 | ui.lock.RUnlock() 133 | return interceptor(ctx, method, req, reply, cc, opts...) 134 | } 135 | } 136 | 137 | // AddGRPCInterceptor adds `arr` to the chain of interceptors. 
138 | func (ui *unaryClientInterceptor) AddGRPCInterceptor(arr ...grpc.UnaryClientInterceptor) UnaryClientInterceptor { 139 | ui.lock.Lock() 140 | defer ui.lock.Unlock() 141 | ui.interceptors = append(ui.interceptors, arr...) 142 | return ui 143 | } 144 | 145 | // AddInterceptor is a convenient way for adding `UnaryClientInterceptor` 146 | // to the chain of interceptors. It only calls the method `Interceptor` 147 | // for each of them and append the return value to the chain. 148 | func (ui *unaryClientInterceptor) AddInterceptor(arr ...UnaryClientInterceptor) UnaryClientInterceptor { 149 | ui.lock.Lock() 150 | defer ui.lock.Unlock() 151 | for _, i := range arr { 152 | ui.interceptors = append(ui.interceptors, i.Interceptor()) 153 | } 154 | return ui 155 | } 156 | -------------------------------------------------------------------------------- /grpcmw/client_level.go: -------------------------------------------------------------------------------- 1 | package grpcmw 2 | 3 | import ( 4 | "sync" 5 | 6 | "google.golang.org/grpc" 7 | ) 8 | 9 | // ClientInterceptor represents a client interceptor that uses both 10 | // `UnaryClientInterceptor` and `StreamClientInterceptor` and that can be 11 | // indexed. 12 | type ClientInterceptor interface { 13 | // AddGRPCUnaryInterceptor adds given unary interceptors to the chain. 14 | AddGRPCUnaryInterceptor(i ...grpc.UnaryClientInterceptor) ClientInterceptor 15 | // AddUnaryInterceptor is a convenient way for adding `UnaryClientInterceptor` 16 | // to the chain of unary interceptors. 17 | AddUnaryInterceptor(i ...UnaryClientInterceptor) ClientInterceptor 18 | // UnaryClientInterceptor returns the chain of unary interceptors. 19 | UnaryClientInterceptor() UnaryClientInterceptor 20 | // AddGRPCStreamInterceptor adds given stream interceptors to the chain. 
21 | AddGRPCStreamInterceptor(i ...grpc.StreamClientInterceptor) ClientInterceptor 22 | // AddStreamInterceptor is a convenient way for adding 23 | // `StreamClientInterceptor` to the chain of stream interceptors. 24 | AddStreamInterceptor(i ...StreamClientInterceptor) ClientInterceptor 25 | // StreamClientInterceptor returns the chain of stream interceptors. 26 | StreamClientInterceptor() StreamClientInterceptor 27 | // Merge merges the given interceptors with the current interceptor. 28 | Merge(i ...ClientInterceptor) ClientInterceptor 29 | // Index returns the index of the `ClientInterceptor`. 30 | Index() string 31 | } 32 | 33 | // ClientInterceptorRegister represents a register of `ClientInterceptor`, 34 | // indexing them by using their method `Index`. 35 | // It also implements `ClientInterceptor`. 36 | type ClientInterceptorRegister interface { 37 | ClientInterceptor 38 | // Register registers `level` at the index returned by its method `Index`. 39 | Register(level ClientInterceptor) 40 | // Get returns the `ClientInterceptor` registered at the index `key`. If 41 | // nothing is found, it returns (nil, false). 42 | Get(key string) (ClientInterceptor, bool) 43 | } 44 | 45 | type lowerClientInterceptor struct { 46 | unaries UnaryClientInterceptor 47 | streams StreamClientInterceptor 48 | index string 49 | } 50 | 51 | type higherClientInterceptorLevel struct { 52 | ClientInterceptor 53 | sublevels map[string]ClientInterceptor 54 | lock *sync.RWMutex 55 | } 56 | 57 | // NewClientInterceptor initializes a new `ClientInterceptor` with `index` 58 | // as its index. It initializes the underlying `UnaryClientInterceptor` and 59 | // `StreamClientInterceptor`. 60 | // This implementation is thread-safe. 
61 | func NewClientInterceptor(index string) ClientInterceptor { 62 | return &lowerClientInterceptor{ 63 | unaries: NewUnaryClientInterceptor(), 64 | streams: NewStreamClientInterceptor(), 65 | index: index, 66 | } 67 | } 68 | 69 | // Index returns the index of the `ClientInterceptor`. 70 | func (l lowerClientInterceptor) Index() string { 71 | return l.index 72 | } 73 | 74 | // AddGRPCUnaryInterceptor calls `AddGRPCInterceptor` of the underlying 75 | // `UnaryClientInterceptor`. It returns the current instance of 76 | // `ClientInterceptor` to allow chaining. 77 | func (l *lowerClientInterceptor) AddGRPCUnaryInterceptor(arr ...grpc.UnaryClientInterceptor) ClientInterceptor { 78 | l.unaries.AddGRPCInterceptor(arr...) 79 | return l 80 | } 81 | 82 | // AddUnaryInterceptor calls `AddInterceptor` of the underlying 83 | // `UnaryClientInterceptor`. It returns the current instance of 84 | // `ClientInterceptor` to allow chaining. 85 | func (l *lowerClientInterceptor) AddUnaryInterceptor(arr ...UnaryClientInterceptor) ClientInterceptor { 86 | l.unaries.AddInterceptor(arr...) 87 | return l 88 | } 89 | 90 | // UnaryClientInterceptor returns the underlying instance of 91 | // `UnaryClientInterceptor`. 92 | func (l *lowerClientInterceptor) UnaryClientInterceptor() UnaryClientInterceptor { 93 | return l.unaries 94 | } 95 | 96 | // AddGRPCStreamInterceptor calls `AddGRPCInterceptor` of the underlying 97 | // `StreamClientInterceptor`. It returns the current instance of 98 | // `ClientInterceptor` to allow chaining. 99 | func (l *lowerClientInterceptor) AddGRPCStreamInterceptor(arr ...grpc.StreamClientInterceptor) ClientInterceptor { 100 | l.streams.AddGRPCInterceptor(arr...) 101 | return l 102 | } 103 | 104 | // AddStreamInterceptor calls `AddGRPCInterceptor` of the underlying 105 | // `StreamClientInterceptor`. It returns the current instance of 106 | // `ClientInterceptor` to allow chaining. 
107 | func (l *lowerClientInterceptor) AddStreamInterceptor(arr ...StreamClientInterceptor) ClientInterceptor { 108 | l.streams.AddInterceptor(arr...) 109 | return l 110 | } 111 | 112 | // StreamClientInterceptor returns the underlying instance of 113 | // `StreamClientInterceptor`. 114 | func (l *lowerClientInterceptor) StreamClientInterceptor() StreamClientInterceptor { 115 | return l.streams 116 | } 117 | 118 | // Merge merges the given interceptors with the current interceptor. 119 | func (l *lowerClientInterceptor) Merge(interceptors ...ClientInterceptor) ClientInterceptor { 120 | for _, interceptor := range interceptors { 121 | l.AddUnaryInterceptor(interceptor.UnaryClientInterceptor()). 122 | AddStreamInterceptor(interceptor.StreamClientInterceptor()) 123 | } 124 | return l 125 | } 126 | 127 | // NewClientInterceptorRegister initializes a `ClientInterceptorRegister` with 128 | // an empty register and `index` as index as its index. 129 | // This implementation is thread-safe. 130 | func NewClientInterceptorRegister(index string) ClientInterceptorRegister { 131 | return &higherClientInterceptorLevel{ 132 | ClientInterceptor: NewClientInterceptor(index), 133 | sublevels: make(map[string]ClientInterceptor), 134 | lock: &sync.RWMutex{}, 135 | } 136 | } 137 | 138 | // Get returns the `ClientInterceptor` registered at the index `key`. If nothing 139 | // is found, it returns (nil, false). 140 | func (l higherClientInterceptorLevel) Get(key string) (interceptor ClientInterceptor, exists bool) { 141 | l.lock.RLock() 142 | defer l.lock.RUnlock() 143 | interceptor, exists = l.sublevels[key] 144 | return 145 | } 146 | 147 | // Register registers `level` at the index returned by its method `Index`. 148 | // It overwrites any interceptor that has already been registered at this index. 
149 | func (l *higherClientInterceptorLevel) Register(level ClientInterceptor) { 150 | l.lock.Lock() 151 | defer l.lock.Unlock() 152 | l.sublevels[level.Index()] = level 153 | } 154 | -------------------------------------------------------------------------------- /grpcmw/client_router.go: -------------------------------------------------------------------------------- 1 | package grpcmw 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "golang.org/x/net/context" 8 | 9 | "google.golang.org/grpc" 10 | "google.golang.org/grpc/codes" 11 | ) 12 | 13 | // ClientRouter represents route resolver that allows to use the appropriate 14 | // chain of interceptors for a given gRPC request with an interceptor register. 15 | type ClientRouter interface { 16 | // GetRegister returns the interceptor register of the router. 17 | GetRegister() ClientInterceptorRegister 18 | // SetRegister sets the interceptor register of the router. 19 | SetRegister(reg ClientInterceptorRegister) 20 | // UnaryResolver returns a `grpc.UnaryClientInterceptor` that uses the 21 | // appropriate chain of interceptors with the given unary gRPC request. 22 | UnaryResolver() grpc.UnaryClientInterceptor 23 | // StreamResolver returns a `grpc.StreamClientInterceptor` that uses the 24 | // appropriate chain of interceptors with the given stream gRPC request. 25 | StreamResolver() grpc.StreamClientInterceptor 26 | } 27 | 28 | type clientRouter struct { 29 | interceptors ClientInterceptorRegister 30 | } 31 | 32 | // NewClientRouter initializes a `ClientRouter`. 33 | // This implementation is based on the official route format used by gRPC as 34 | // defined here : 35 | // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md 36 | // 37 | // Based on this format, this implementation splits the interceptors into four 38 | // levels: 39 | // - the global level: these are the interceptors called at each request. 
40 | // - the package level: these are the interceptors called at each request to 41 | // a service from the corresponding package. 42 | // - the service level: these are the interceptors called at each request to 43 | // a method from the corresponding service. 44 | // - the method level: these are the interceptors called at each request to 45 | // the specific method. 46 | func NewClientRouter() ClientRouter { 47 | return &clientRouter{ 48 | interceptors: NewClientInterceptorRegister("global"), 49 | } 50 | } 51 | 52 | func resolveClientInterceptorRec(pathTokens []string, lvl ClientInterceptor, cb func(lvl ClientInterceptor), force bool) (ClientInterceptor, error) { 53 | if cb != nil { 54 | cb(lvl) 55 | } 56 | if len(pathTokens) == 0 || len(pathTokens[0]) == 0 { 57 | return lvl, nil 58 | } 59 | reg, ok := lvl.(ClientInterceptorRegister) 60 | if !ok { 61 | return nil, fmt.Errorf("Level %s does not implement grpcmw.ClientInterceptorRegister", lvl.Index()) 62 | } 63 | sub, exists := reg.Get(pathTokens[0]) 64 | if !exists { 65 | if force { 66 | if len(pathTokens) == 1 { 67 | sub = NewClientInterceptor(pathTokens[0]) 68 | } else { 69 | sub = NewClientInterceptorRegister(pathTokens[0]) 70 | } 71 | reg.Register(sub) 72 | } else { 73 | return nil, nil 74 | } 75 | } 76 | return resolveClientInterceptorRec(pathTokens[1:], sub, cb, force) 77 | } 78 | 79 | func resolveClientInterceptor(route string, lvl ClientInterceptor, cb func(lvl ClientInterceptor), force bool) (ClientInterceptor, error) { 80 | // TODO: Find a more efficient way to resolve the route 81 | matchs := routeRegexp.FindStringSubmatch(route) 82 | if len(matchs) == 0 { 83 | return nil, errors.New("Invalid route") 84 | } 85 | return resolveClientInterceptorRec(matchs[1:], lvl, cb, force) 86 | } 87 | 88 | // UnaryResolver returns a `grpc.UnaryClientInterceptor` that uses the 89 | // appropriate chain of interceptors with the given gRPC request. 
90 | func (r *clientRouter) UnaryResolver() grpc.UnaryClientInterceptor { 91 | return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { 92 | // TODO: Find a more efficient way to chain the interceptors 93 | interceptor := NewUnaryClientInterceptor() 94 | _, err := resolveClientInterceptor(method, r.interceptors, func(lvl ClientInterceptor) { 95 | interceptor.AddInterceptor(lvl.UnaryClientInterceptor()) 96 | }, false) 97 | if err != nil { 98 | return grpc.Errorf(codes.Internal, err.Error()) 99 | } 100 | return interceptor.Interceptor()(ctx, method, req, reply, cc, invoker, opts...) 101 | } 102 | } 103 | 104 | // StreamResolver returns a `grpc.StreamClientInterceptor` that uses the 105 | // appropriate chain of interceptors with the given stream gRPC request. 106 | func (r *clientRouter) StreamResolver() grpc.StreamClientInterceptor { 107 | return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { 108 | // TODO: Find a more efficient way to chain the interceptors 109 | interceptor := NewStreamClientInterceptor() 110 | _, err := resolveClientInterceptor(method, r.interceptors, func(lvl ClientInterceptor) { 111 | interceptor.AddInterceptor(lvl.StreamClientInterceptor()) 112 | }, false) 113 | if err != nil { 114 | return nil, grpc.Errorf(codes.Internal, err.Error()) 115 | } 116 | return interceptor.Interceptor()(ctx, desc, cc, method, streamer, opts...) 117 | } 118 | } 119 | 120 | // GetRegister returns the underlying `ClientInterceptorRegister` which is the 121 | // global level in the interceptor chain. 122 | func (r *clientRouter) GetRegister() ClientInterceptorRegister { 123 | return r.interceptors 124 | } 125 | 126 | // SetRegister sets the interceptor register of the router. 
127 | func (r *clientRouter) SetRegister(reg ClientInterceptorRegister) { 128 | r.interceptors = reg 129 | } 130 | -------------------------------------------------------------------------------- /grpcmw/registry/client.go: -------------------------------------------------------------------------------- 1 | package registry 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/MarquisIO/go-grpcmw/grpcmw" 7 | ) 8 | 9 | var ( 10 | clientLock sync.Mutex 11 | clientRegistry = make(map[string]grpcmw.ClientInterceptor) 12 | ) 13 | 14 | // GetClientInterceptor returns the `grpcmw.ClientInterceptor` registered at 15 | // `index`. If nothing is at this `index`, it registers a new one using 16 | // `grpcmw.NewClientInterceptor` and returns it. 17 | // This is thread-safe. 18 | func GetClientInterceptor(index string) grpcmw.ClientInterceptor { 19 | clientLock.Lock() 20 | defer clientLock.Unlock() 21 | intcp, ok := clientRegistry[index] 22 | if !ok { 23 | intcp = grpcmw.NewClientInterceptor(index) 24 | clientRegistry[index] = intcp 25 | } 26 | return intcp 27 | } 28 | 29 | // SetClientInterceptor registers `interceptor` at `index`. It replaces any 30 | // interceptor that has been previously registered at this `index`. 31 | // This is thread-safe. 32 | func SetClientInterceptor(index string, interceptor grpcmw.ClientInterceptor) { 33 | clientLock.Lock() 34 | defer clientLock.Unlock() 35 | clientRegistry[index] = interceptor 36 | } 37 | 38 | // DeleteClientInterceptor deletes any interceptor registered at `index`. 39 | // This is thread-safe. 
40 | func DeleteClientInterceptor(index string) { 41 | clientLock.Lock() 42 | defer clientLock.Unlock() 43 | delete(clientRegistry, index) 44 | } 45 | -------------------------------------------------------------------------------- /grpcmw/registry/server.go: -------------------------------------------------------------------------------- 1 | package registry 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/MarquisIO/go-grpcmw/grpcmw" 7 | ) 8 | 9 | var ( 10 | serverLock sync.Mutex 11 | serverRegistry = make(map[string]grpcmw.ServerInterceptor) 12 | ) 13 | 14 | // GetServerInterceptor returns the `grpcmw.ServerInterceptor` registered at 15 | // `index`. If nothing is at this `index`, it registers a new one using 16 | // `grpcmw.NewServerInterceptor` and returns it. 17 | // This is thread-safe. 18 | func GetServerInterceptor(index string) grpcmw.ServerInterceptor { 19 | serverLock.Lock() 20 | defer serverLock.Unlock() 21 | intcp, ok := serverRegistry[index] 22 | if !ok { 23 | intcp = grpcmw.NewServerInterceptor(index) 24 | serverRegistry[index] = intcp 25 | } 26 | return intcp 27 | } 28 | 29 | // SetServerInterceptor registers `interceptor` at `index`. It replaces any 30 | // interceptor that has been previously registered at this `index`. 31 | // This is thread-safe. 32 | func SetServerInterceptor(index string, interceptor grpcmw.ServerInterceptor) { 33 | serverLock.Lock() 34 | defer serverLock.Unlock() 35 | serverRegistry[index] = interceptor 36 | } 37 | 38 | // DeleteServerInterceptor deletes any interceptor registered at `index`. 39 | // This is thread-safe. 
40 | func DeleteServerInterceptor(index string) { 41 | serverLock.Lock() 42 | defer serverLock.Unlock() 43 | delete(serverRegistry, index) 44 | } 45 | -------------------------------------------------------------------------------- /grpcmw/route.go: -------------------------------------------------------------------------------- 1 | package grpcmw 2 | 3 | import "regexp" 4 | 5 | var ( 6 | routeRegexp = regexp.MustCompile(`\/(?:(.+)\.)?(.+)\/(.+)`) 7 | ) 8 | -------------------------------------------------------------------------------- /grpcmw/server_interceptor.go: -------------------------------------------------------------------------------- 1 | package grpcmw 2 | 3 | import ( 4 | "sync" 5 | 6 | "golang.org/x/net/context" 7 | 8 | "google.golang.org/grpc" 9 | ) 10 | 11 | // StreamServerInterceptor represents a server interceptor for gRPC methods that 12 | // return a stream. It allows chaining of `grpc.StreamServerInterceptor` 13 | // and other `StreamServerInterceptor`. 14 | type StreamServerInterceptor interface { 15 | // Interceptor chains all added interceptors into a single 16 | // `grpc.StreamServerInterceptor`. 17 | Interceptor() grpc.StreamServerInterceptor 18 | // AddGRPCInterceptor adds given interceptors to the chain. 19 | AddGRPCInterceptor(i ...grpc.StreamServerInterceptor) StreamServerInterceptor 20 | // AddInterceptor is a convenient way for adding `StreamServerInterceptor` 21 | // to the chain of interceptors. 22 | AddInterceptor(i ...StreamServerInterceptor) StreamServerInterceptor 23 | } 24 | 25 | // UnaryServerInterceptor represents a server interceptor for gRPC methods that 26 | // return a single value. It allows chaining of `grpc.UnaryServerInterceptor` 27 | // and other `UnaryServerInterceptor`. 28 | type UnaryServerInterceptor interface { 29 | // Interceptor chains all added interceptors into a single 30 | // `grpc.UnaryServerInterceptor`. 
31 | Interceptor() grpc.UnaryServerInterceptor 32 | // AddGRPCInterceptor adds given interceptors to the chain. 33 | AddGRPCInterceptor(i ...grpc.UnaryServerInterceptor) UnaryServerInterceptor 34 | // AddInterceptor is a convenient way for adding `UnaryServerInterceptor` 35 | // to the chain of interceptors. 36 | AddInterceptor(i ...UnaryServerInterceptor) UnaryServerInterceptor 37 | } 38 | 39 | type streamServerInterceptor struct { 40 | interceptors []grpc.StreamServerInterceptor 41 | lock *sync.RWMutex 42 | } 43 | 44 | type unaryServerInterceptor struct { 45 | interceptors []grpc.UnaryServerInterceptor 46 | lock *sync.RWMutex 47 | } 48 | 49 | // NewStreamServerInterceptor returns a new `StreamServerInterceptor`. 50 | // It initializes its interceptor chain with `arr`. 51 | // This implementation is thread-safe. 52 | func NewStreamServerInterceptor(arr ...grpc.StreamServerInterceptor) StreamServerInterceptor { 53 | return &streamServerInterceptor{ 54 | interceptors: arr, 55 | lock: &sync.RWMutex{}, 56 | } 57 | } 58 | 59 | func chainStreamServerInterceptor(current grpc.StreamServerInterceptor, info *grpc.StreamServerInfo, next grpc.StreamHandler) grpc.StreamHandler { 60 | return func(srv interface{}, stream grpc.ServerStream) error { 61 | return current(srv, stream, info, next) 62 | } 63 | } 64 | 65 | // Interceptor chains all added interceptors into a single 66 | // `grpc.StreamServerInterceptor`. 67 | // 68 | // The `handler` passed to each interceptor is either the next interceptor or, 69 | // for the last element of the chain, the target method. 
70 | func (si streamServerInterceptor) Interceptor() grpc.StreamServerInterceptor { 71 | return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 72 | // TODO: Find a more efficient way 73 | interceptor := handler 74 | si.lock.RLock() 75 | for idx := len(si.interceptors) - 1; idx >= 0; idx-- { 76 | interceptor = chainStreamServerInterceptor(si.interceptors[idx], info, interceptor) 77 | } 78 | si.lock.RUnlock() 79 | return interceptor(srv, ss) 80 | } 81 | } 82 | 83 | // AddGRPCInterceptor adds `arr` to the chain of interceptors. 84 | func (si *streamServerInterceptor) AddGRPCInterceptor(arr ...grpc.StreamServerInterceptor) StreamServerInterceptor { 85 | si.lock.Lock() 86 | defer si.lock.Unlock() 87 | si.interceptors = append(si.interceptors, arr...) 88 | return si 89 | } 90 | 91 | // AddInterceptor is a convenient way for adding `StreamServerInterceptor` 92 | // to the chain of interceptors. It only calls the method `Interceptor` 93 | // for each of them and append the return value to the chain. 94 | func (si *streamServerInterceptor) AddInterceptor(arr ...StreamServerInterceptor) StreamServerInterceptor { 95 | si.lock.Lock() 96 | defer si.lock.Unlock() 97 | for _, i := range arr { 98 | si.interceptors = append(si.interceptors, i.Interceptor()) 99 | } 100 | return si 101 | } 102 | 103 | // NewUnaryServerInterceptor returns a new `UnaryServerInterceptor`. 104 | // It initializes its interceptor chain with `arr`. 105 | // This implementation is thread-safe. 
106 | func NewUnaryServerInterceptor(arr ...grpc.UnaryServerInterceptor) UnaryServerInterceptor { 107 | return &unaryServerInterceptor{ 108 | interceptors: arr, 109 | lock: &sync.RWMutex{}, 110 | } 111 | } 112 | 113 | func chainUnaryServerInterceptor(current grpc.UnaryServerInterceptor, info *grpc.UnaryServerInfo, next grpc.UnaryHandler) grpc.UnaryHandler { 114 | return func(ctx context.Context, req interface{}) (interface{}, error) { 115 | return current(ctx, req, info, next) 116 | } 117 | } 118 | 119 | // Interceptor chains all added interceptors into a single 120 | // `grpc.UnaryServerInterceptor`. 121 | // 122 | // The `handler` passed to each interceptor is either the next interceptor or, 123 | // for the last element of the chain, the target method. 124 | func (ui *unaryServerInterceptor) Interceptor() grpc.UnaryServerInterceptor { 125 | return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 126 | // TODO: Find a more efficient way 127 | interceptor := handler 128 | ui.lock.RLock() 129 | for idx := len(ui.interceptors) - 1; idx >= 0; idx-- { 130 | interceptor = chainUnaryServerInterceptor(ui.interceptors[idx], info, interceptor) 131 | } 132 | ui.lock.RUnlock() 133 | return interceptor(ctx, req) 134 | } 135 | } 136 | 137 | // AddGRPCInterceptor adds `arr` to the chain of interceptors. 138 | func (ui *unaryServerInterceptor) AddGRPCInterceptor(arr ...grpc.UnaryServerInterceptor) UnaryServerInterceptor { 139 | ui.lock.Lock() 140 | defer ui.lock.Unlock() 141 | ui.interceptors = append(ui.interceptors, arr...) 142 | return ui 143 | } 144 | 145 | // AddInterceptor is a convenient way for adding `UnaryServerInterceptor` 146 | // to the chain of interceptors. It only calls the method `Interceptor` 147 | // for each of them and append the return value to the chain. 
148 | func (ui *unaryServerInterceptor) AddInterceptor(arr ...UnaryServerInterceptor) UnaryServerInterceptor { 149 | ui.lock.Lock() 150 | defer ui.lock.Unlock() 151 | for _, i := range arr { 152 | ui.interceptors = append(ui.interceptors, i.Interceptor()) 153 | } 154 | return ui 155 | } 156 | -------------------------------------------------------------------------------- /grpcmw/server_level.go: -------------------------------------------------------------------------------- 1 | package grpcmw 2 | 3 | import ( 4 | "sync" 5 | 6 | "google.golang.org/grpc" 7 | ) 8 | 9 | // ServerInterceptor represent a server interceptor that uses both 10 | // `UnaryServerInterceptor` and `StreamServerInterceptor` and that can be 11 | // indexed. 12 | type ServerInterceptor interface { 13 | // AddGRPCUnaryInterceptor adds given unary interceptors to the chain. 14 | AddGRPCUnaryInterceptor(i ...grpc.UnaryServerInterceptor) ServerInterceptor 15 | // AddUnaryInterceptor is a convenient way for adding `UnaryServerInterceptor` 16 | // to the chain of unary interceptors. 17 | AddUnaryInterceptor(i ...UnaryServerInterceptor) ServerInterceptor 18 | // UnaryServerInterceptor returns the chain of unary interceptors. 19 | UnaryServerInterceptor() UnaryServerInterceptor 20 | // AddGRPCStreamInterceptor adds given stream interceptors to the chain. 21 | AddGRPCStreamInterceptor(i ...grpc.StreamServerInterceptor) ServerInterceptor 22 | // AddStreamInterceptor is a convenient way for adding 23 | // `StreamServerInterceptor` to the chain of stream interceptors. 24 | AddStreamInterceptor(i ...StreamServerInterceptor) ServerInterceptor 25 | // StreamServerInterceptor returns the chain of stream interceptors. 26 | StreamServerInterceptor() StreamServerInterceptor 27 | // Merge merges the given interceptors with the current interceptor. 28 | Merge(interceptors ...ServerInterceptor) ServerInterceptor 29 | // Index returns the index of the `ServerInterceptor`. 
30 | Index() string 31 | } 32 | 33 | // ServerInterceptorRegister represents a register of `ServerInterceptor`, 34 | // indexing them by using their method `Index`. 35 | // It also implements `ServerInterceptor`. 36 | type ServerInterceptorRegister interface { 37 | ServerInterceptor 38 | // Register registers `level` at the index returned by its method `Index`. 39 | Register(level ServerInterceptor) 40 | // Get returns the `ServerInterceptor` registered at the index `key`. If 41 | // nothing is found, it returns (nil, false). 42 | Get(key string) (ServerInterceptor, bool) 43 | } 44 | 45 | type lowerServerInterceptor struct { 46 | unaries UnaryServerInterceptor 47 | streams StreamServerInterceptor 48 | index string 49 | } 50 | 51 | type higherServerInterceptorLevel struct { 52 | ServerInterceptor 53 | sublevels map[string]ServerInterceptor 54 | lock *sync.RWMutex 55 | } 56 | 57 | // NewServerInterceptor initializes a new `ServerInterceptor` with `index` 58 | // as its index. It initializes the underlying `UnaryServerInterceptor` and 59 | // `StreamServerInterceptor`. 60 | // This implementation is thread-safe. 61 | func NewServerInterceptor(index string) ServerInterceptor { 62 | return &lowerServerInterceptor{ 63 | unaries: NewUnaryServerInterceptor(), 64 | streams: NewStreamServerInterceptor(), 65 | index: index, 66 | } 67 | } 68 | 69 | // Index returns the index of the `ServerInterceptor`. 70 | func (l lowerServerInterceptor) Index() string { 71 | return l.index 72 | } 73 | 74 | // AddGRPCUnaryInterceptor calls `AddGRPCInterceptor` of the underlying 75 | // `UnaryServerInterceptor`. It returns the current instance of 76 | // `ServerInterceptor` to allow chaining. 77 | func (l *lowerServerInterceptor) AddGRPCUnaryInterceptor(arr ...grpc.UnaryServerInterceptor) ServerInterceptor { 78 | l.unaries.AddGRPCInterceptor(arr...) 79 | return l 80 | } 81 | 82 | // AddUnaryInterceptor calls `AddInterceptor` of the underlying 83 | // `UnaryServerInterceptor`. 
It returns the current instance of 84 | // `ServerInterceptor` to allow chaining. 85 | func (l *lowerServerInterceptor) AddUnaryInterceptor(arr ...UnaryServerInterceptor) ServerInterceptor { 86 | l.unaries.AddInterceptor(arr...) 87 | return l 88 | } 89 | 90 | // UnaryServerInterceptor returns the underlying instance of 91 | // `UnaryServerInterceptor`. 92 | func (l *lowerServerInterceptor) UnaryServerInterceptor() UnaryServerInterceptor { 93 | return l.unaries 94 | } 95 | 96 | // AddGRPCStreamInterceptor calls `AddGRPCInterceptor` of the underlying 97 | // `StreamServerInterceptor`. It returns the current instance of 98 | // `ServerInterceptor` to allow chaining. 99 | func (l *lowerServerInterceptor) AddGRPCStreamInterceptor(arr ...grpc.StreamServerInterceptor) ServerInterceptor { 100 | l.streams.AddGRPCInterceptor(arr...) 101 | return l 102 | } 103 | 104 | // AddStreamInterceptor calls `AddGRPCInterceptor` of the underlying 105 | // `StreamServerInterceptor`. It returns the current instance of 106 | // `ServerInterceptor` to allow chaining. 107 | func (l *lowerServerInterceptor) AddStreamInterceptor(arr ...StreamServerInterceptor) ServerInterceptor { 108 | l.streams.AddInterceptor(arr...) 109 | return l 110 | } 111 | 112 | // StreamServerInterceptor returns the underlying instance of 113 | // `StreamServerInterceptor`. 114 | func (l *lowerServerInterceptor) StreamServerInterceptor() StreamServerInterceptor { 115 | return l.streams 116 | } 117 | 118 | // Merge merges the given interceptors with the current interceptor. 119 | func (l *lowerServerInterceptor) Merge(interceptors ...ServerInterceptor) ServerInterceptor { 120 | for _, interceptor := range interceptors { 121 | l.AddUnaryInterceptor(interceptor.UnaryServerInterceptor()). 
122 | AddStreamInterceptor(interceptor.StreamServerInterceptor()) 123 | } 124 | return l 125 | } 126 | 127 | // NewServerInterceptorRegister initializes a `ServerInterceptorRegister` with 128 | // an empty register and `index` as index as its index. 129 | // This implementation is thread-safe. 130 | func NewServerInterceptorRegister(index string) ServerInterceptorRegister { 131 | return &higherServerInterceptorLevel{ 132 | ServerInterceptor: NewServerInterceptor(index), 133 | sublevels: make(map[string]ServerInterceptor), 134 | lock: &sync.RWMutex{}, 135 | } 136 | } 137 | 138 | // Get returns the `ServerInterceptor` registered at the index `key`. If nothing 139 | // is found, it returns (nil, false). 140 | func (l higherServerInterceptorLevel) Get(key string) (interceptor ServerInterceptor, exists bool) { 141 | l.lock.RLock() 142 | defer l.lock.RUnlock() 143 | interceptor, exists = l.sublevels[key] 144 | return 145 | } 146 | 147 | // Register registers `level` at the index returned by its method `Index`. 148 | // It overwrites any interceptor that has already been registered at this index. 149 | func (l *higherServerInterceptorLevel) Register(level ServerInterceptor) { 150 | l.lock.Lock() 151 | defer l.lock.Unlock() 152 | l.sublevels[level.Index()] = level 153 | } 154 | -------------------------------------------------------------------------------- /grpcmw/server_router.go: -------------------------------------------------------------------------------- 1 | package grpcmw 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "golang.org/x/net/context" 8 | 9 | "google.golang.org/grpc" 10 | "google.golang.org/grpc/codes" 11 | ) 12 | 13 | // ServerRouter represents route resolver that allows to use the appropriate 14 | // chain of interceptors for a given gRPC request with an interceptor register. 15 | type ServerRouter interface { 16 | // GetRegister returns the interceptor register of the router. 
17 | GetRegister() ServerInterceptorRegister 18 | // SetRegister sets the interceptor register of the router. 19 | SetRegister(reg ServerInterceptorRegister) 20 | // UnaryResolver returns a `grpc.UnaryServerInterceptor` that uses the 21 | // appropriate chain of interceptors with the given unary gRPC request. 22 | UnaryResolver() grpc.UnaryServerInterceptor 23 | // StreamResolver returns a `grpc.StreamServerInterceptor` that uses the 24 | // appropriate chain of interceptors with the given stream gRPC request. 25 | StreamResolver() grpc.StreamServerInterceptor 26 | } 27 | 28 | type serverRouter struct { 29 | interceptors ServerInterceptorRegister 30 | } 31 | 32 | // NewServerRouter initializes a `ServerRouter`. 33 | // This implementation is based on the official route format used by gRPC as 34 | // defined here : 35 | // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md 36 | // 37 | // Based on this format, this implementation splits the interceptors into four 38 | // levels: 39 | // - the global level: these are the interceptors called at each request. 40 | // - the package level: these are the interceptors called at each request to 41 | // a service from the corresponding package. 42 | // - the service level: these are the interceptors called at each request to 43 | // a method from the corresponding service. 44 | // - the method level: these are the interceptors called at each request to 45 | // the specific method. 
46 | func NewServerRouter() ServerRouter { 47 | return &serverRouter{ 48 | interceptors: NewServerInterceptorRegister("global"), 49 | } 50 | } 51 | 52 | func resolveServerInterceptorRec(pathTokens []string, lvl ServerInterceptor, cb func(lvl ServerInterceptor), force bool) (ServerInterceptor, error) { 53 | if cb != nil { 54 | cb(lvl) 55 | } 56 | if len(pathTokens) == 0 || len(pathTokens[0]) == 0 { 57 | return lvl, nil 58 | } 59 | reg, ok := lvl.(ServerInterceptorRegister) 60 | if !ok { 61 | return nil, fmt.Errorf("Level %s does not implement grpcmw.ServerInterceptorRegister", lvl.Index()) 62 | } 63 | sub, exists := reg.Get(pathTokens[0]) 64 | if !exists { 65 | if force { 66 | if len(pathTokens) == 1 { 67 | sub = NewServerInterceptor(pathTokens[0]) 68 | } else { 69 | sub = NewServerInterceptorRegister(pathTokens[0]) 70 | } 71 | reg.Register(sub) 72 | } else { 73 | return nil, nil 74 | } 75 | } 76 | return resolveServerInterceptorRec(pathTokens[1:], sub, cb, force) 77 | } 78 | 79 | func resolveServerInterceptor(route string, lvl ServerInterceptor, cb func(lvl ServerInterceptor), force bool) (ServerInterceptor, error) { 80 | // TODO: Find a more efficient way to resolve the route 81 | matchs := routeRegexp.FindStringSubmatch(route) 82 | if len(matchs) == 0 { 83 | return nil, errors.New("Invalid route") 84 | } 85 | return resolveServerInterceptorRec(matchs[1:], lvl, cb, force) 86 | } 87 | 88 | // UnaryResolver returns a `grpc.UnaryServerInterceptor` that uses the 89 | // appropriate chain of interceptors with the given gRPC request. 
90 | func (r *serverRouter) UnaryResolver() grpc.UnaryServerInterceptor { 91 | return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 92 | // TODO: Find a more efficient way to chain the interceptors 93 | interceptor := NewUnaryServerInterceptor() 94 | _, err := resolveServerInterceptor(info.FullMethod, r.interceptors, func(lvl ServerInterceptor) { 95 | interceptor.AddInterceptor(lvl.UnaryServerInterceptor()) 96 | }, false) 97 | if err != nil { 98 | return nil, grpc.Errorf(codes.Internal, err.Error()) 99 | } 100 | return interceptor.Interceptor()(ctx, req, info, handler) 101 | } 102 | } 103 | 104 | // StreamResolver returns a `grpc.StreamServerInterceptor` that uses the 105 | // appropriate chain of interceptors with the given stream gRPC request. 106 | func (r *serverRouter) StreamResolver() grpc.StreamServerInterceptor { 107 | return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 108 | // TODO: Find a more efficient way to chain the interceptors 109 | interceptor := NewStreamServerInterceptor() 110 | _, err := resolveServerInterceptor(info.FullMethod, r.interceptors, func(lvl ServerInterceptor) { 111 | interceptor.AddInterceptor(lvl.StreamServerInterceptor()) 112 | }, false) 113 | if err != nil { 114 | return grpc.Errorf(codes.Internal, err.Error()) 115 | } 116 | return interceptor.Interceptor()(srv, ss, info, handler) 117 | } 118 | } 119 | 120 | // GetRegister returns the underlying `ServerInterceptorRegister` which is the 121 | // global level in the interceptor chain. 122 | func (r *serverRouter) GetRegister() ServerInterceptorRegister { 123 | return r.interceptors 124 | } 125 | 126 | // SetRegister sets the interceptor register of the router. 
127 | func (r *serverRouter) SetRegister(reg ServerInterceptorRegister) { 128 | r.interceptors = reg 129 | } 130 | -------------------------------------------------------------------------------- /grpcmw/wrappers.go: -------------------------------------------------------------------------------- 1 | package grpcmw 2 | 3 | import ( 4 | "golang.org/x/net/context" 5 | 6 | "google.golang.org/grpc" 7 | ) 8 | 9 | // ServerStreamWrapper represents a wrapper for `grpc.ServerStream` that allows 10 | // to modify the context. 11 | type ServerStreamWrapper struct { 12 | grpc.ServerStream 13 | ctx context.Context 14 | } 15 | 16 | // WrapServerStream returns checks if `ss` is already a `*ServerStreamWrapper`. 17 | // If it is, it returns the `ss`, otherwise it returns a new wrapper for 18 | // `grpc.ServerStream`. 19 | func WrapServerStream(ss grpc.ServerStream) *ServerStreamWrapper { 20 | if ret, ok := ss.(*ServerStreamWrapper); ok { 21 | return ret 22 | } 23 | return &ServerStreamWrapper{ 24 | ServerStream: ss, 25 | ctx: ss.Context(), 26 | } 27 | } 28 | 29 | // Context returns the context of the wrapper0 30 | func (w ServerStreamWrapper) Context() context.Context { 31 | return w.ctx 32 | } 33 | 34 | // SetContext set the context of the wrapper to `ctx`. 35 | func (w *ServerStreamWrapper) SetContext(ctx context.Context) { 36 | w.ctx = ctx 37 | } 38 | -------------------------------------------------------------------------------- /protoc-gen-grpc-middleware/descriptor/file.go: -------------------------------------------------------------------------------- 1 | package descriptor 2 | 3 | import ( 4 | "github.com/MarquisIO/go-grpcmw/annotations" 5 | "github.com/golang/protobuf/protoc-gen-go/descriptor" 6 | ) 7 | 8 | // File represents a protobuf file. 9 | type File struct { 10 | Package string 11 | Name string 12 | Services []*Service 13 | Interceptors *Interceptors 14 | } 15 | 16 | // GetFile parses `pb` and builds a `File` object from it. 
17 | // If the file does not define any service nor any interceptor option, it does 18 | // not return anything. 19 | func GetFile(pb *descriptor.FileDescriptorProto) (f *File, err error) { 20 | services := pb.GetService() 21 | f = &File{ 22 | Name: pb.GetName(), 23 | Package: pb.GetPackage(), 24 | Services: make([]*Service, len(services)), 25 | } 26 | if pb.Options != nil { 27 | if f.Interceptors, err = GetInterceptors(pb.Options, annotations.E_PackageInterceptors); err != nil { 28 | return nil, err 29 | } 30 | } 31 | for idx, service := range services { 32 | if f.Services[idx], err = GetService(service, f.Package); err != nil { 33 | return nil, err 34 | } 35 | } 36 | if f.Interceptors == nil && len(f.Services) == 0 { 37 | return nil, nil 38 | } 39 | return 40 | } 41 | -------------------------------------------------------------------------------- /protoc-gen-grpc-middleware/descriptor/interceptors.go: -------------------------------------------------------------------------------- 1 | package descriptor 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/golang/protobuf/proto" 7 | 8 | "github.com/MarquisIO/go-grpcmw/annotations" 9 | ) 10 | 11 | // Interceptors defines interceptors to use. 12 | type Interceptors struct { 13 | Indexes []string 14 | } 15 | 16 | // GetInterceptors extracts the `Interceptors` extension (described by `desc`) 17 | // from `pb`. 
18 | func GetInterceptors(pb proto.Message, desc *proto.ExtensionDesc) (*Interceptors, error) { 19 | if !proto.HasExtension(pb, desc) { 20 | return nil, nil 21 | } 22 | ext, err := proto.GetExtension(pb, desc) 23 | if err != nil { 24 | return nil, err 25 | } 26 | interceptors, ok := ext.(*annotations.Interceptors) 27 | if !ok { 28 | return nil, fmt.Errorf("extension is %T; want an Interceptors", ext) 29 | } else if len(interceptors.GetIndexes()) == 0 { 30 | return nil, nil 31 | } 32 | return &Interceptors{ 33 | Indexes: interceptors.GetIndexes(), 34 | }, nil 35 | } 36 | -------------------------------------------------------------------------------- /protoc-gen-grpc-middleware/descriptor/method.go: -------------------------------------------------------------------------------- 1 | package descriptor 2 | 3 | import ( 4 | "github.com/golang/protobuf/protoc-gen-go/descriptor" 5 | 6 | "github.com/MarquisIO/go-grpcmw/annotations" 7 | ) 8 | 9 | // Method represents a method from a grpc service. 10 | type Method struct { 11 | Package string 12 | Service string 13 | Method string 14 | Stream bool 15 | Interceptors *Interceptors 16 | } 17 | 18 | // GetMethod parses `pb` and builds from it a `Method` object. 
19 | func GetMethod(pb *descriptor.MethodDescriptorProto, service, pkg string) (method *Method, err error) { 20 | method = &Method{ 21 | Package: pkg, 22 | Service: service, 23 | Method: pb.GetName(), 24 | Stream: pb.GetClientStreaming() || pb.GetServerStreaming(), 25 | } 26 | if pb.Options != nil { 27 | if method.Interceptors, err = GetInterceptors(pb.Options, annotations.E_MethodInterceptors); err != nil { 28 | return nil, err 29 | } 30 | } 31 | return 32 | } 33 | -------------------------------------------------------------------------------- /protoc-gen-grpc-middleware/descriptor/parse.go: -------------------------------------------------------------------------------- 1 | package descriptor 2 | 3 | import ( 4 | plugin "github.com/golang/protobuf/protoc-gen-go/plugin" 5 | ) 6 | 7 | // Parse parses the given protobuf request into a map of packages (key) and of 8 | // files information (value). 9 | func Parse(pb *plugin.CodeGeneratorRequest) (pkgs map[string][]*File, err error) { 10 | // TODO: Do this in multiple goroutines 11 | filesToGenerate := make(map[string]struct{}, len(pb.GetFileToGenerate())) 12 | for _, f := range pb.GetFileToGenerate() { 13 | filesToGenerate[f] = struct{}{} 14 | } 15 | pkgs = make(map[string][]*File) 16 | for _, file := range pb.GetProtoFile() { 17 | if _, ok := filesToGenerate[file.GetName()]; ok { 18 | if parsed, err := GetFile(file); err != nil { 19 | return nil, err 20 | } else if parsed != nil { 21 | pkgs[file.GetPackage()] = append(pkgs[file.GetPackage()], parsed) 22 | } 23 | } 24 | } 25 | return 26 | } 27 | -------------------------------------------------------------------------------- /protoc-gen-grpc-middleware/descriptor/service.go: -------------------------------------------------------------------------------- 1 | package descriptor 2 | 3 | import ( 4 | "github.com/MarquisIO/go-grpcmw/annotations" 5 | "github.com/golang/protobuf/protoc-gen-go/descriptor" 6 | ) 7 | 8 | // Service represents a grpc service definition from a 
protobuf file. 9 | type Service struct { 10 | Package string 11 | Service string 12 | Methods []*Method 13 | Interceptors *Interceptors 14 | } 15 | 16 | // GetService parses `pb` and builds a `Service` object from it. 17 | func GetService(pb *descriptor.ServiceDescriptorProto, pkg string) (s *Service, err error) { 18 | methods := pb.GetMethod() 19 | s = &Service{ 20 | Package: pkg, 21 | Service: pb.GetName(), 22 | Methods: make([]*Method, len(methods)), 23 | } 24 | if pb.Options != nil { 25 | if s.Interceptors, err = GetInterceptors(pb.Options, annotations.E_ServiceInterceptors); err != nil { 26 | return nil, err 27 | } 28 | } 29 | for idx, method := range methods { 30 | if s.Methods[idx], err = GetMethod(method, s.Service, pkg); err != nil { 31 | return nil, err 32 | } 33 | } 34 | return 35 | } 36 | -------------------------------------------------------------------------------- /protoc-gen-grpc-middleware/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io" 5 | "io/ioutil" 6 | "log" 7 | "os" 8 | 9 | "github.com/MarquisIO/go-grpcmw/protoc-gen-grpc-middleware/descriptor" 10 | "github.com/MarquisIO/go-grpcmw/protoc-gen-grpc-middleware/template" 11 | "github.com/golang/protobuf/proto" 12 | plugin "github.com/golang/protobuf/protoc-gen-go/plugin" 13 | ) 14 | 15 | func parseRequest(r io.Reader) (*plugin.CodeGeneratorRequest, error) { 16 | input, err := ioutil.ReadAll(r) 17 | if err != nil { 18 | return nil, err 19 | } 20 | req := new(plugin.CodeGeneratorRequest) 21 | if err = proto.Unmarshal(input, req); err != nil { 22 | return nil, err 23 | } 24 | return req, nil 25 | } 26 | 27 | func getResponseFromError(err error) *plugin.CodeGeneratorResponse { 28 | ret := err.Error() 29 | return &plugin.CodeGeneratorResponse{Error: &ret} 30 | } 31 | 32 | func main() { 33 | var res *plugin.CodeGeneratorResponse 34 | if req, err := parseRequest(os.Stdin); err != nil { 35 | res = 
getResponseFromError(err) 36 | } else if pkgs, err := descriptor.Parse(req); err != nil { 37 | res = getResponseFromError(err) 38 | } else if res, err = template.Apply(pkgs); err != nil { 39 | res = getResponseFromError(err) 40 | } 41 | if buf, err := proto.Marshal(res); err != nil { 42 | log.Fatalf("Could not marshal response: %v", err) 43 | } else if _, err = os.Stdout.Write(buf); err != nil { 44 | log.Fatalf("Could not write response to stdout: %v", err) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /protoc-gen-grpc-middleware/template/init.go: -------------------------------------------------------------------------------- 1 | package template 2 | 3 | import ( 4 | "bytes" 5 | "path/filepath" 6 | "strings" 7 | "text/template" 8 | 9 | "github.com/MarquisIO/go-grpcmw/protoc-gen-grpc-middleware/descriptor" 10 | plugin "github.com/golang/protobuf/protoc-gen-go/plugin" 11 | ) 12 | 13 | // Code template keys 14 | const ( 15 | initKey = "init" 16 | ) 17 | 18 | // Code templates 19 | const ( 20 | initCode = `package {{.Package}} 21 | 22 | import ( 23 | grpcmw "github.com/MarquisIO/go-grpcmw/grpcmw" 24 | registry "github.com/MarquisIO/go-grpcmw/grpcmw/registry" 25 | ) 26 | 27 | type server{{template "pkgType" .}} struct { 28 | grpcmw.ServerInterceptor 29 | } 30 | 31 | type client{{template "pkgType" .}} struct { 32 | grpcmw.ClientInterceptor 33 | } 34 | 35 | var pkgInterceptors []string 36 | {{with .Interceptors}}{{template "pkgInterceptors" .}}{{end}} 37 | func RegisterServerInterceptors(router grpcmw.ServerRouter) *server{{template "pkgType" .}} { 38 | register := router.GetRegister() 39 | lvl, ok := register.Get("{{.Package}}") 40 | if !ok { 41 | lvl = grpcmw.NewServerInterceptorRegister("{{.Package}}") 42 | register.Register(lvl) 43 | for _, interceptor := range pkgInterceptors { 44 | lvl.Merge(registry.GetServerInterceptor(interceptor)) 45 | } 46 | } 47 | return &server{{template "pkgType" .}}{ 48 | 
ServerInterceptor: lvl, 49 | } 50 | } 51 | 52 | func RegisterClientInterceptors(router grpcmw.ClientRouter) *client{{template "pkgType" .}} { 53 | register := router.GetRegister() 54 | lvl, ok := register.Get("{{.Package}}") 55 | if !ok { 56 | lvl = grpcmw.NewClientInterceptorRegister("{{.Package}}") 57 | register.Register(lvl) 58 | for _, interceptor := range pkgInterceptors { 59 | lvl.Merge(registry.GetClientInterceptor(interceptor)) 60 | } 61 | } 62 | return &client{{template "pkgType" .}}{ 63 | ClientInterceptor: lvl, 64 | } 65 | } 66 | 67 | {{range .Services}}{{template "service" .}}{{end}} 68 | ` 69 | ) 70 | 71 | var initCodeTpl = template.Must(template.New(initKey).Parse(initCode)) 72 | 73 | // Apply applies the given package descriptors and generates the appropriate 74 | // code using go templates. 75 | func Apply(pkgs map[string][]*descriptor.File) (*plugin.CodeGeneratorResponse, error) { 76 | res := &plugin.CodeGeneratorResponse{} 77 | for _, files := range pkgs { 78 | for idx, file := range files { 79 | buf := new(bytes.Buffer) 80 | dest := &plugin.CodeGeneratorResponse_File{} 81 | destName := strings.TrimSuffix(file.Name, filepath.Ext(file.Name)) + ".pb.mw.go" 82 | dest.Name = &destName 83 | templateKey := pkgKey 84 | if idx == 0 { 85 | templateKey = initKey 86 | } 87 | if err := initCodeTpl.ExecuteTemplate(buf, templateKey, file); err != nil { 88 | return nil, err 89 | } 90 | ct := buf.String() 91 | dest.Content = &ct 92 | res.File = append(res.File, dest) 93 | } 94 | } 95 | return res, nil 96 | } 97 | -------------------------------------------------------------------------------- /protoc-gen-grpc-middleware/template/interceptors.go: -------------------------------------------------------------------------------- 1 | package template 2 | 3 | import "text/template" 4 | 5 | // Code template keys 6 | const ( 7 | pkgInterceptorsKey = "pkgInterceptors" 8 | ) 9 | 10 | // Code templates 11 | const ( 12 | pkgInterceptorsCode = ` 13 | {{if .Indexes}}func 
init() { 14 | pkgInterceptors = append( 15 | pkgInterceptors,{{range .Indexes}} 16 | "{{.}}",{{end}} 17 | ) 18 | }{{end}} 19 | ` 20 | ) 21 | 22 | func init() { 23 | template.Must(initCodeTpl.New(pkgInterceptorsKey).Parse(pkgInterceptorsCode)) 24 | } 25 | -------------------------------------------------------------------------------- /protoc-gen-grpc-middleware/template/method.go: -------------------------------------------------------------------------------- 1 | package template 2 | 3 | import "text/template" 4 | 5 | // Code template keys 6 | const ( 7 | methodTypeKey = "methodType" 8 | methodKey = "method" 9 | ) 10 | 11 | // Code templates 12 | const ( 13 | methodTypeCode = `{{if .}}Stream{{else}}Unary{{end}}` 14 | 15 | methodCode = ` 16 | func (s *server{{template "serviceType" .}}) {{.Method}}() grpcmw.{{template "methodType" .Stream}}ServerInterceptor { 17 | method, ok := s.ServerInterceptor.(grpcmw.ServerInterceptorRegister).Get("{{.Method}}") 18 | if !ok { 19 | method = grpcmw.NewServerInterceptorRegister("{{.Method}}") 20 | s.ServerInterceptor.(grpcmw.ServerInterceptorRegister).Register(method) 21 | } 22 | return method.{{template "methodType" .Stream}}ServerInterceptor() 23 | } 24 | 25 | func (s *client{{template "serviceType" .}}) {{.Method}}() grpcmw.{{template "methodType" .Stream}}ClientInterceptor { 26 | method, ok := s.ClientInterceptor.(grpcmw.ClientInterceptorRegister).Get("{{.Method}}") 27 | if !ok { 28 | method = grpcmw.NewClientInterceptorRegister("{{.Method}}") 29 | s.ClientInterceptor.(grpcmw.ClientInterceptorRegister).Register(method) 30 | } 31 | return method.{{template "methodType" .Stream}}ClientInterceptor() 32 | }` 33 | ) 34 | 35 | func init() { 36 | template.Must(initCodeTpl.New(methodKey).Parse(methodCode)) 37 | template.Must(initCodeTpl.New(methodTypeKey).Parse(methodTypeCode)) 38 | } 39 | -------------------------------------------------------------------------------- /protoc-gen-grpc-middleware/template/package.go: 
-------------------------------------------------------------------------------- 1 | package template 2 | 3 | import ( 4 | "text/template" 5 | ) 6 | 7 | // Code template keys 8 | const ( 9 | pkgKey = "pkg" 10 | pkgTypeKey = "pkgType" 11 | ) 12 | 13 | // Code templates 14 | const ( 15 | pkgTypeCode = `Interceptor_{{.Package}}` 16 | 17 | pkgCode = `package {{.Package}} 18 | 19 | import ( 20 | grpcmw "github.com/MarquisIO/go-grpcmw/grpcmw" 21 | registry "github.com/MarquisIO/go-grpcmw/grpcmw/registry" 22 | ) 23 | 24 | var ( 25 | _ = registry.GetClientInterceptor 26 | ) 27 | 28 | {{with .Interceptors}}{{template "pkgInterceptors" .}}{{end}} 29 | {{range .Services}}{{template "service" .}}{{end}} 30 | ` 31 | ) 32 | 33 | func init() { 34 | template.Must(initCodeTpl.New(pkgKey).Parse(pkgCode)) 35 | template.Must(initCodeTpl.New(pkgTypeKey).Parse(pkgTypeCode)) 36 | } 37 | -------------------------------------------------------------------------------- /protoc-gen-grpc-middleware/template/service.go: -------------------------------------------------------------------------------- 1 | package template 2 | 3 | import "text/template" 4 | 5 | // Code template keys 6 | const ( 7 | serviceKey = "service" 8 | serviceTypeKey = "serviceType" 9 | ) 10 | 11 | // Code templates 12 | const ( 13 | serviceTypeCode = `Interceptor_{{.Package}}{{.Service}}` 14 | 15 | serviceCode = ` 16 | type server{{template "serviceType" .}} struct { 17 | grpcmw.ServerInterceptor 18 | } 19 | 20 | type client{{template "serviceType" .}} struct { 21 | grpcmw.ClientInterceptor 22 | } 23 | 24 | func (i *server{{template "pkgType" .}}) Register{{.Service}}() *server{{template "serviceType" .}} { 25 | service, ok := i.ServerInterceptor.(grpcmw.ServerInterceptorRegister).Get("{{.Service}}") 26 | if !ok { 27 | ret := &server{{template "serviceType" .}}{ 28 | ServerInterceptor: grpcmw.NewServerInterceptorRegister("{{.Service}}"), 29 | } 30 | 
i.ServerInterceptor.(grpcmw.ServerInterceptorRegister).Register(ret.ServerInterceptor) 31 | {{with .Interceptors}}ret.ServerInterceptor.Merge({{range .Indexes}} 32 | registry.GetServerInterceptor("{{.}}"),{{end}} 33 | ){{end}} 34 | {{range .Methods}}{{if .Interceptors}} 35 | ret.{{.Method}}().AddInterceptor({{$method := .}}{{range .Interceptors.Indexes}} 36 | registry.GetServerInterceptor("{{.}}").{{template "methodType" $method.Stream}}ServerInterceptor(),{{end}} 37 | ){{end}}{{end}} 38 | return ret 39 | } 40 | return &server{{template "serviceType" .}}{ 41 | ServerInterceptor: service, 42 | } 43 | } 44 | 45 | func (i *client{{template "pkgType" .}}) Register{{.Service}}() *client{{template "serviceType" .}} { 46 | service, ok := i.ClientInterceptor.(grpcmw.ClientInterceptorRegister).Get("{{.Service}}") 47 | if !ok { 48 | ret := &client{{template "serviceType" .}}{ 49 | ClientInterceptor: grpcmw.NewClientInterceptorRegister("{{.Service}}"), 50 | } 51 | i.ClientInterceptor.(grpcmw.ClientInterceptorRegister).Register(ret.ClientInterceptor) 52 | {{with .Interceptors}}ret.ClientInterceptor.Merge({{range .Indexes}} 53 | registry.GetClientInterceptor("{{.}}"),{{end}} 54 | ){{end}} 55 | {{range .Methods}}{{if .Interceptors}} 56 | ret.{{.Method}}().AddInterceptor({{$method := .}}{{range .Interceptors.Indexes}} 57 | registry.GetClientInterceptor("{{.}}").{{template "methodType" $method.Stream}}ClientInterceptor(),{{end}} 58 | ){{end}}{{end}} 59 | return ret 60 | } 61 | return &client{{template "serviceType" .}}{ 62 | ClientInterceptor: service, 63 | } 64 | } 65 | 66 | {{range .Methods}}{{template "method" .}}{{end}} 67 | ` 68 | ) 69 | 70 | func init() { 71 | template.Must(initCodeTpl.New(serviceKey).Parse(serviceCode)) 72 | template.Must(initCodeTpl.New(serviceTypeKey).Parse(serviceTypeCode)) 73 | } 74 | --------------------------------------------------------------------------------