├── .gitignore ├── message.go ├── testproto ├── dummy.proto └── dummy.pb.go ├── service ├── config.go ├── service.go └── serverclient_test.go ├── .travis.yml ├── requesttree ├── utils.go ├── middleware.go └── middleware_test.go ├── client ├── utils.go ├── sugar.go ├── call.go ├── interface.go ├── errorset_test.go ├── errorset.go ├── client.go └── client_test.go ├── transport └── transport.go ├── compat ├── compat.go └── message.go ├── LICENSE ├── marshaling └── registry.go ├── README.md ├── response.go ├── server ├── endpoint.go ├── interface.go ├── srv.go └── srv_test.go └── request.go /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | -------------------------------------------------------------------------------- /message.go: -------------------------------------------------------------------------------- 1 | package mercury 2 | 3 | import ( 4 | tmsg "github.com/monzo/typhon/message" 5 | ) 6 | 7 | type Message tmsg.Message 8 | -------------------------------------------------------------------------------- /testproto/dummy.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | package testproto; 3 | 4 | message DummyRequest { 5 | string ping = 1; 6 | } 7 | 8 | message DummyResponse { 9 | string pong = 1; 10 | } 11 | -------------------------------------------------------------------------------- /service/config.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "github.com/monzo/mercury/transport" 5 | ) 6 | 7 | type Config struct { 8 | Name string 9 | Description string 10 | // Transport specifies a transport to run the server on. If none is specified, a mock transport is used. 
11 | Transport transport.Transport 12 | } 13 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.6.3 5 | - 1.7 6 | 7 | install: 8 | - export PATH=${PATH}:${HOME}/gopath/bin 9 | - go get -v -t ./... 10 | - go get -v github.com/golang/lint/golint 11 | 12 | before_script: 13 | - go vet ./... 14 | - golint . 15 | 16 | script: 17 | - go test -v ./... 18 | - go test -v -test.race ./... 19 | -------------------------------------------------------------------------------- /requesttree/utils.go: -------------------------------------------------------------------------------- 1 | package requesttree 2 | 3 | import ( 4 | "golang.org/x/net/context" 5 | ) 6 | 7 | const parentIdCtxKey = parentIdHeader 8 | 9 | // ParentRequestIdFor returns the parent request ID for the provided context (if any). 10 | func ParentRequestIdFor(ctx context.Context) string { 11 | switch v := ctx.Value(parentIdCtxKey).(type) { 12 | case string: 13 | return v 14 | } 15 | return "" 16 | } 17 | -------------------------------------------------------------------------------- /client/utils.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | func mergeContexts(maps ...map[string]string) map[string]string { 4 | result := make(map[string]string) 5 | for _, m := range maps { 6 | for k, v := range m { 7 | if _, ok := result[k]; !ok { 8 | result[k] = v 9 | } 10 | } 11 | } 12 | return result 13 | } 14 | 15 | func stringsMap(strings ...string) map[string]struct{} { 16 | result := make(map[string]struct{}, len(strings)) 17 | for _, s := range strings { 18 | result[s] = struct{}{} 19 | } 20 | return result 21 | } 22 | -------------------------------------------------------------------------------- /client/sugar.go: -------------------------------------------------------------------------------- 1 | 
package client

import (
	"golang.org/x/net/context"
)

// Req is a one-shot helper: it builds a fresh client, issues a single call (uid "default") to service/endpoint
// with req as the body, waits for it to complete, and unmarshals the reply into res. The client's combined error
// set is returned.
func Req(ctx context.Context, service, endpoint string, req, res interface{}) error {
	call := Call{
		Uid:      "default",
		Service:  service,
		Endpoint: endpoint,
		Body:     req,
		Response: res,
		Context:  ctx,
	}
	errs := NewClient().Add(call).Execute().Errors()
	return errs.Combined()
}

--------------------------------------------------------------------------------
/transport/transport.go:
--------------------------------------------------------------------------------
package transport

import (
	"sync"

	ttrans "github.com/monzo/typhon/transport"
)

// Transport is Mercury's name for Typhon's transport.Transport type.
type Transport ttrans.Transport

// defaultTransport is the process-wide default; defaultTransportM guards every read and write of it.
var defaultTransport Transport
var defaultTransportM sync.RWMutex

// DefaultTransport reports the process-wide default transport, over which servers and clients should run unless
// told otherwise. It is nil until SetDefaultTransport has been called.
func DefaultTransport() Transport {
	defaultTransportM.RLock()
	t := defaultTransport
	defaultTransportM.RUnlock()
	return t
}

// SetDefaultTransport installs t as the global default transport. The previously installed transport is not closed.
24 | func SetDefaultTransport(t Transport) { 25 | defaultTransportM.Lock() 26 | defer defaultTransportM.Unlock() 27 | defaultTransport = t 28 | } 29 | -------------------------------------------------------------------------------- /compat/compat.go: -------------------------------------------------------------------------------- 1 | package mercurycompat 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/monzo/mercury/server" 8 | "github.com/monzo/typhon" 9 | ) 10 | 11 | func CompatServer(srv server.Server) typhon.Filter { 12 | return func(req typhon.Request, svc typhon.Service) typhon.Response { 13 | eps := []string{ 14 | fmt.Sprintf("%s %s", req.Method, req.URL.Path), 15 | fmt.Sprintf("%s %s", req.Method, strings.TrimPrefix(req.URL.Path, "/"))} 16 | if req.Method == "POST" { 17 | eps = append(eps, fmt.Sprintf("%s", strings.TrimPrefix(req.URL.Path, "/"))) 18 | } 19 | 20 | for _, epName := range eps { 21 | ep, ok := srv.Endpoint(epName) 22 | if ok { 23 | oldRsp, err := ep.Handle(new2OldRequest(req)) 24 | if err != nil { 25 | return typhon.Response{ 26 | Error: err} 27 | } 28 | return old2NewResponse(req, oldRsp) 29 | } 30 | } 31 | 32 | // No matching endpoint found; send it to the lower-level service 33 | return svc(req) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Oliver Beattie 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above 
copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /testproto/dummy.pb.go: -------------------------------------------------------------------------------- 1 | // Code generated by protoc-gen-go. 2 | // source: github.com/monzo/mercury/testproto/dummy.proto 3 | // DO NOT EDIT! 4 | 5 | /* 6 | Package testproto is a generated protocol buffer package. 7 | 8 | It is generated from these files: 9 | github.com/monzo/mercury/testproto/dummy.proto 10 | 11 | It has these top-level messages: 12 | DummyRequest 13 | DummyResponse 14 | */ 15 | package testproto 16 | 17 | import proto "github.com/golang/protobuf/proto" 18 | 19 | // Reference imports to suppress errors if they are not otherwise used. 
20 | var _ = proto.Marshal 21 | 22 | type DummyRequest struct { 23 | Ping string `protobuf:"bytes,1,opt,name=ping" json:"ping,omitempty"` 24 | } 25 | 26 | func (m *DummyRequest) Reset() { *m = DummyRequest{} } 27 | func (m *DummyRequest) String() string { return proto.CompactTextString(m) } 28 | func (*DummyRequest) ProtoMessage() {} 29 | 30 | type DummyResponse struct { 31 | Pong string `protobuf:"bytes,1,opt,name=pong" json:"pong,omitempty"` 32 | } 33 | 34 | func (m *DummyResponse) Reset() { *m = DummyResponse{} } 35 | func (m *DummyResponse) String() string { return proto.CompactTextString(m) } 36 | func (*DummyResponse) ProtoMessage() {} 37 | 38 | func init() { 39 | } 40 | -------------------------------------------------------------------------------- /marshaling/registry.go: -------------------------------------------------------------------------------- 1 | package marshaling 2 | 3 | import ( 4 | "sync" 5 | 6 | tmsg "github.com/monzo/typhon/message" 7 | ) 8 | 9 | const ( 10 | ContentTypeHeader = "Content-Type" 11 | AcceptHeader = "Accept" 12 | JSONContentType = tmsg.JSONContentType 13 | ) 14 | 15 | type MarshalerFactory func() tmsg.Marshaler 16 | type UnmarshalerFactory func(interface{}) tmsg.Unmarshaler 17 | 18 | type marshalerPair struct { 19 | m MarshalerFactory 20 | u UnmarshalerFactory 21 | } 22 | 23 | var ( 24 | marshalerRegistryM sync.RWMutex 25 | marshalerRegistry = map[string]marshalerPair{ 26 | JSONContentType: { 27 | m: tmsg.JSONMarshaler, 28 | u: tmsg.JSONUnmarshaler}, 29 | } 30 | ) 31 | 32 | func Register(contentType string, mc MarshalerFactory, uc UnmarshalerFactory) { 33 | if contentType == "" || mc == nil || uc == nil { 34 | return 35 | } 36 | 37 | marshalerRegistryM.Lock() 38 | defer marshalerRegistryM.Unlock() 39 | marshalerRegistry[contentType] = marshalerPair{ 40 | m: mc, 41 | u: uc, 42 | } 43 | } 44 | 45 | func Marshaler(contentType string) tmsg.Marshaler { 46 | marshalerRegistryM.RLock() 47 | defer marshalerRegistryM.RUnlock() 48 | if 
mp, ok := marshalerRegistry[contentType]; ok { 49 | return mp.m() 50 | } 51 | return nil 52 | } 53 | 54 | func Unmarshaler(contentType string, protocol interface{}) tmsg.Unmarshaler { 55 | marshalerRegistryM.RLock() 56 | defer marshalerRegistryM.RUnlock() 57 | if mp, ok := marshalerRegistry[contentType]; ok { 58 | return mp.u(protocol) 59 | } 60 | return nil 61 | } 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mercury 2 | 3 | 🚨 **Mercury is deprecated and is no longer maintained.** 🚨 4 | 5 | [![Build Status](https://travis-ci.org/monzo/mercury.svg?branch=master)](https://travis-ci.org/monzo/mercury) 6 | [![GoDoc](https://godoc.org/github.com/monzo/mercury?status.svg)](https://godoc.org/github.com/monzo/mercury) 7 | 8 | An RPC client/server implementation using [Typhon](https://github.com/monzo/typhon), intended for building microservices. 9 | 10 | ## Server 11 | 12 | A [`Server`](http://godoc.org/github.com/monzo/mercury/server) receives RPC requests, routes them to an [`Endpoint`](http://godoc.org/github.com/monzo/mercury/server#Endpoint), calls a handler function to "do work," and returns a response back to a caller. 13 | 14 | ### Server middleware 15 | 16 | Server middleware offers hooks into request processing for globally altering a server's input or output. They could be used to provide authentication or distributed tracing functionality, for example. 17 | 18 | ## Client 19 | 20 | A [`Client`](http://godoc.org/github.com/monzo/mercury/client#Client) offers a convenient way, atop a Typhon transport, to make requests to other servers. They co-ordinate the execution of many parallel requests, deal with response and error unmarshaling, and provide convenient ways of dealing with response errors. 
21 | 22 | ### Client middleware 23 | 24 | Like server middleware, clients too have hooks for altering outbound requests or inbound responses. 25 | 26 | ## Service 27 | 28 | A [`Service`](http://godoc.org/github.com/monzo/mercury/service#Service) is a lightweight wrapper around a server, which also sets up some global defaults (for instance, to use the same default client transport as the server). 29 | -------------------------------------------------------------------------------- /response.go: -------------------------------------------------------------------------------- 1 | package mercury 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/monzo/terrors" 7 | tperrors "github.com/monzo/terrors/proto" 8 | tmsg "github.com/monzo/typhon/message" 9 | 10 | "github.com/monzo/mercury/marshaling" 11 | ) 12 | 13 | type Response interface { 14 | tmsg.Response 15 | // IsError returns whether this response contains an error 16 | IsError() bool 17 | // SetIsError modifies whether the flag specifying whether this response contains an error 18 | SetIsError(bool) 19 | // Error returns the error contained within the response 20 | Error() error 21 | } 22 | 23 | type response struct { 24 | tmsg.Response 25 | } 26 | 27 | func (r *response) IsError() bool { 28 | return r.Headers()[errHeader] == "1" 29 | } 30 | 31 | func (r *response) SetIsError(v bool) { 32 | r.SetHeader(errHeader, "1") 33 | } 34 | 35 | func (r *response) Error() error { 36 | if !r.IsError() { 37 | return nil 38 | } 39 | r2 := r.Copy() 40 | um := marshaling.Unmarshaler(r2.Headers()[marshaling.ContentTypeHeader], &tperrors.Error{}) 41 | if um == nil { 42 | um = marshaling.Unmarshaler(marshaling.JSONContentType, &tperrors.Error{}) 43 | } 44 | if umErr := um.UnmarshalPayload(r2); umErr != nil { 45 | return umErr 46 | } 47 | if err := terrors.Unmarshal(r2.Body().(*tperrors.Error)); err != nil { 48 | return err // Don't return a nil but typed result 49 | } 50 | return nil 51 | } 52 | 53 | func (r *response) String() string { 
54 | return fmt.Sprintf("%v", r.Response) 55 | } 56 | 57 | func NewResponse() Response { 58 | return FromTyphonResponse(tmsg.NewResponse()) 59 | } 60 | 61 | func FromTyphonResponse(rsp tmsg.Response) Response { 62 | switch rsp := rsp.(type) { 63 | case Response: 64 | return rsp 65 | default: 66 | return &response{ 67 | Response: rsp} 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /client/call.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "github.com/monzo/terrors" 5 | tmsg "github.com/monzo/typhon/message" 6 | "golang.org/x/net/context" 7 | 8 | "github.com/monzo/mercury" 9 | "github.com/monzo/mercury/marshaling" 10 | ) 11 | 12 | // A Call is a convenient way to form a Request for an RPC call. 13 | type Call struct { 14 | // Uid represents a unique identifier for this call within the scope of a client. 15 | Uid string 16 | // Service to receive the call. 17 | Service string 18 | // Endpoint of the receiving service. 19 | Endpoint string 20 | // Body will be serialised to form the Payload of the request. 21 | Body interface{} 22 | // Headers to send on the request (these may be augmented by the client). 23 | Headers map[string]string 24 | // Response is a protocol into which the response's Payload should be unmarshaled. 25 | Response interface{} 26 | // Context is a context for the request. This should nearly always be the parent request (if any). 
27 | Context context.Context 28 | } 29 | 30 | func (c Call) marshaler() tmsg.Marshaler { 31 | result := tmsg.Marshaler(nil) 32 | if c.Headers != nil && c.Headers[marshaling.ContentTypeHeader] != "" { 33 | result = marshaling.Marshaler(c.Headers[marshaling.ContentTypeHeader]) 34 | } 35 | if result == nil { 36 | result = tmsg.JSONMarshaler() 37 | } 38 | return result 39 | } 40 | 41 | // Request yields a Request formed from this Call 42 | func (c Call) Request() (mercury.Request, error) { 43 | req := mercury.NewRequest() 44 | req.SetService(c.Service) 45 | req.SetEndpoint(c.Endpoint) 46 | req.SetHeaders(c.Headers) 47 | if c.Context != nil { 48 | req.SetContext(c.Context) 49 | } 50 | if c.Body != nil { 51 | req.SetBody(c.Body) 52 | if err := c.marshaler().MarshalBody(req); err != nil { 53 | return nil, terrors.WrapWithCode(err, nil, terrors.ErrBadRequest) 54 | } 55 | } 56 | return req, nil 57 | } 58 | -------------------------------------------------------------------------------- /server/endpoint.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | log "github.com/monzo/slog" 5 | "github.com/monzo/terrors" 6 | tmsg "github.com/monzo/typhon/message" 7 | 8 | "github.com/monzo/mercury" 9 | "github.com/monzo/mercury/marshaling" 10 | ) 11 | 12 | type Handler func(req mercury.Request) (mercury.Response, error) 13 | 14 | // An Endpoint represents a handler function bound to a particular endpoint name. 15 | type Endpoint struct { 16 | // Name is the Endpoint's unique name, and is used to route requests to it. 17 | Name string 18 | // Handler is a function to be invoked upon receiving a request, to generate a response. 19 | Handler Handler 20 | // Request is a "template" object for the Endpoint's request format. 21 | Request interface{} 22 | // Response is a "template" object for the Endpoint's response format. 
23 | Response interface{} 24 | } 25 | 26 | func (e Endpoint) unmarshaler(req mercury.Request) tmsg.Unmarshaler { 27 | result := marshaling.Unmarshaler(req.Headers()[marshaling.ContentTypeHeader], e.Request) 28 | if result == nil { // Default to json 29 | result = marshaling.Unmarshaler(marshaling.JSONContentType, e.Request) 30 | } 31 | return result 32 | } 33 | 34 | // Handle takes an inbound Request, unmarshals it, dispatches it to the handler, and serialises the result as a 35 | // Response. Note that the response may be nil. 36 | func (e Endpoint) Handle(req mercury.Request) (mercury.Response, error) { 37 | // Unmarshal the request body (unless there already is one) 38 | if req.Body() == nil && e.Request != nil { 39 | if um := e.unmarshaler(req); um != nil { 40 | if werr := terrors.Wrap(um.UnmarshalPayload(req), nil); werr != nil { 41 | log.Warn(req, "[Mercury:Server] Cannot unmarshal request payload: %v", werr) 42 | terr := werr.(*terrors.Error) 43 | terr.Code = terrors.ErrBadRequest 44 | return nil, terr 45 | } 46 | } 47 | } 48 | 49 | return e.Handler(req) 50 | } 51 | -------------------------------------------------------------------------------- /compat/message.go: -------------------------------------------------------------------------------- 1 | package mercurycompat 2 | 3 | import ( 4 | "bytes" 5 | "io/ioutil" 6 | "net/http" 7 | "net/url" 8 | "strings" 9 | 10 | "github.com/monzo/mercury" 11 | "github.com/monzo/typhon" 12 | ) 13 | 14 | const legacyIdHeader = "Legacy-Id" 15 | 16 | func toHeader(m map[string]string) http.Header { 17 | h := make(http.Header, len(m)) 18 | for k, v := range m { 19 | h.Set(k, v) 20 | } 21 | return h 22 | } 23 | 24 | func fromHeader(h http.Header) map[string]string { 25 | m := make(map[string]string, len(h)) 26 | for k, v := range h { 27 | if len(v) < 1 { 28 | continue 29 | } 30 | m[k] = v[0] 31 | } 32 | return m 33 | } 34 | 35 | func old2NewRequest(oldReq mercury.Request) typhon.Request { 36 | ep := oldReq.Endpoint() 37 | if 
!strings.HasPrefix(ep, "/") { 38 | ep = "/" + ep 39 | } 40 | v := typhon.Request{ 41 | Context: oldReq.Context(), 42 | Request: http.Request{ 43 | Method: "POST", 44 | URL: &url.URL{ 45 | Scheme: "http", 46 | Host: oldReq.Service(), 47 | Path: ep}, 48 | Proto: "HTTP/1.1", 49 | ProtoMajor: 1, 50 | ProtoMinor: 1, 51 | Header: toHeader(oldReq.Headers()), 52 | Host: oldReq.Service(), 53 | Body: ioutil.NopCloser(bytes.NewReader(oldReq.Payload())), 54 | ContentLength: int64(len(oldReq.Payload()))}} 55 | v.Header.Set(legacyIdHeader, oldReq.Id()) 56 | return v 57 | } 58 | 59 | func new2OldRequest(newReq typhon.Request) mercury.Request { 60 | req := mercury.NewRequest() 61 | req.SetService(newReq.Host) 62 | req.SetEndpoint(newReq.URL.Path) 63 | req.SetHeaders(fromHeader(newReq.Header)) 64 | b, _ := newReq.BodyBytes(true) 65 | req.SetPayload(b) 66 | req.SetId(newReq.Header.Get(legacyIdHeader)) 67 | req.SetContext(newReq) 68 | return req 69 | } 70 | 71 | func old2NewResponse(req typhon.Request, oldRsp mercury.Response) typhon.Response { 72 | rsp := typhon.NewResponse(req) 73 | rsp.Header = toHeader(oldRsp.Headers()) 74 | rsp.Write(oldRsp.Payload()) 75 | rsp.Error = oldRsp.Error() 76 | return rsp 77 | } 78 | -------------------------------------------------------------------------------- /service/service.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/monzo/mercury/client" 7 | "github.com/monzo/mercury/requesttree" 8 | "github.com/monzo/mercury/server" 9 | "github.com/monzo/mercury/transport" 10 | ) 11 | 12 | var ( 13 | defaultService Service 14 | defaultServiceM sync.RWMutex 15 | ) 16 | 17 | func init() { 18 | client.SetDefaultMiddleware(DefaultClientMiddleware()) 19 | server.SetDefaultMiddleware(DefaultServerMiddleware()) 20 | } 21 | 22 | type Service interface { 23 | Server() server.Server 24 | Run() 25 | Transport() transport.Transport 26 | } 27 | 28 | type svc struct { 
29 | srv server.Server 30 | config Config 31 | } 32 | 33 | func (s *svc) Server() server.Server { 34 | return s.srv 35 | } 36 | 37 | func (s *svc) Run() { 38 | s.srv.Run(s.config.Transport) 39 | } 40 | 41 | func (s *svc) Transport() transport.Transport { 42 | return s.config.Transport 43 | } 44 | 45 | // DefaultServerMiddleware returns the complement of server middleware provided by Mercury 46 | func DefaultServerMiddleware() []server.ServerMiddleware { 47 | return []server.ServerMiddleware{ 48 | requesttree.Middleware()} 49 | } 50 | 51 | // DefaultClientMiddleware returns the complement of client middleware provided by Mercury 52 | func DefaultClientMiddleware() []client.ClientMiddleware { 53 | return []client.ClientMiddleware{ 54 | requesttree.Middleware()} 55 | } 56 | 57 | // DefaultService returns the global default Service. 58 | func DefaultService() Service { 59 | defaultServiceM.RLock() 60 | defer defaultServiceM.RUnlock() 61 | return defaultService 62 | } 63 | 64 | // New creates a new service with default middleware 65 | func New(cfg Config) Service { 66 | if cfg.Transport == nil { 67 | cfg.Transport = transport.DefaultTransport() 68 | } 69 | 70 | srv := server.NewServer(cfg.Name) 71 | srv.SetMiddleware(DefaultServerMiddleware()) 72 | 73 | return &svc{ 74 | srv: srv, 75 | config: cfg, 76 | } 77 | } 78 | 79 | // Init performs any global initialisation that is usually required for Mercury services. 
Namely it: 80 | // 81 | // * Sets up a server with middleware (request tree) 82 | // * Sets the created service as the default service 83 | func Init(cfg Config) Service { 84 | impl := New(cfg) 85 | 86 | defaultServiceM.Lock() 87 | defaultService = impl 88 | defaultServiceM.Unlock() 89 | 90 | <-impl.Transport().Ready() 91 | 92 | return impl 93 | } 94 | -------------------------------------------------------------------------------- /server/interface.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/monzo/mercury" 5 | "github.com/monzo/mercury/transport" 6 | ) 7 | 8 | // A Server provides Endpoint RPC functionality atop a typhon Transport. 9 | type Server interface { 10 | // Name returns the service name. It must be set at construction time and is immutable. 11 | Name() string 12 | // AddEndpoints registers new Endpoint. If any name conflicts with an existing endpoint, the old endpoint(s) will be 13 | // removed. Errors raised as panics. 14 | AddEndpoints(eps ...Endpoint) 15 | // RemoveEndpoints removes the Endpoints given (if they are registered). 16 | RemoveEndpoints(eps ...Endpoint) 17 | // Endpoint returns a registered endpoint (if there is one) for the given Name. 18 | Endpoint(name string) (Endpoint, bool) 19 | // Endpoints returns all Endpoints registered. 20 | Endpoints() []Endpoint 21 | // Start starts the server on the given transport, and returns once the server is ready for work. The server will 22 | // continue until purposefully stopped, or until a terminal error occurs. The transport should be pre-initialised. 23 | Start(trans transport.Transport) error 24 | // Run starts the server and blocks until it stops. As this function is intended to support the main run loop of a 25 | // service, an error results in a panic. 26 | Run(trans transport.Transport) 27 | // Stop forcefully stops the server. It does not terminate the underlying transport. 
28 | Stop() 29 | 30 | // Middleware returns a copy of the ServerMiddleware stack currently installed. 31 | // 32 | // Server middleware is used to act upon or transform a handler's input or output. Middleware is applied in order 33 | // during the request phase, and in reverse order during the response phase. 34 | Middleware() []ServerMiddleware 35 | // SetMiddleware replaces the server's ServerMiddleware stack. 36 | SetMiddleware([]ServerMiddleware) 37 | // AddMiddleware appends the given ServerMiddleware to the stack. 38 | AddMiddleware(ServerMiddleware) 39 | } 40 | 41 | type ServerMiddleware interface { 42 | // ProcessServerRequest is called on each inbound request, before it is routed to an Endpoint. If a response or an 43 | // error is returned, Mercury does not bother calling any other request middleware. It will apply response 44 | // middleware and respond to the caller with the result. 45 | // 46 | // If an error is to be returned, use `ErrorResponse`. 47 | ProcessServerRequest(req mercury.Request) (mercury.Request, mercury.Response) 48 | // ProcessServerResponse is called on all responses before they are returned to a caller. Unlike request middleware, 49 | // response middleware is always called. If an error is returned, it will be marshaled to a response and will 50 | // continue to other response middleware. 51 | // 52 | // Nil responses MUST be handled. If an error is to be returned, use `ErrorResponse`. 53 | // 54 | // Note that response middleware are applied in reverse order. 
55 | ProcessServerResponse(rsp mercury.Response, req mercury.Request) mercury.Response 56 | } 57 | -------------------------------------------------------------------------------- /request.go: -------------------------------------------------------------------------------- 1 | package mercury 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | "time" 7 | 8 | log "github.com/monzo/slog" 9 | tmsg "github.com/monzo/typhon/message" 10 | "golang.org/x/net/context" 11 | 12 | "github.com/monzo/mercury/marshaling" 13 | ) 14 | 15 | const ( 16 | errHeader = "Content-Error" 17 | ) 18 | 19 | // A Request is a representation of an RPC call (inbound or outbound). It extends Typhon's Request to provide a 20 | // Context, and also helpers for constructing a response. 21 | type Request interface { 22 | tmsg.Request 23 | context.Context 24 | 25 | // Response constructs a response to this request, with the (optional) given body. The response will share 26 | // the request's ID, and be destined for the originator. 27 | Response(body interface{}) Response 28 | // A Context for the Request. 29 | Context() context.Context 30 | // SetContext replaces the Request's Context. 
31 | SetContext(ctx context.Context) 32 | } 33 | 34 | func responseFromRequest(req Request, body interface{}) Response { 35 | rsp := NewResponse() 36 | rsp.SetId(req.Id()) 37 | if body != nil { 38 | rsp.SetBody(body) 39 | 40 | ct := req.Headers()[marshaling.AcceptHeader] 41 | marshaler := marshaling.Marshaler(ct) 42 | if marshaler == nil { // Fall back to JSON 43 | marshaler = marshaling.Marshaler(marshaling.JSONContentType) 44 | } 45 | if marshaler == nil { 46 | log.Error(req, "[Mercury] No marshaler for response %s: %s", rsp.Id(), ct) 47 | } else if err := marshaler.MarshalBody(rsp); err != nil { 48 | log.Error(req, "[Mercury] Failed to marshal response %s: %v", rsp.Id(), err) 49 | } 50 | } 51 | return rsp 52 | } 53 | 54 | type request struct { 55 | sync.RWMutex 56 | tmsg.Request 57 | ctx context.Context 58 | } 59 | 60 | func (r *request) Response(body interface{}) Response { 61 | return responseFromRequest(r, body) 62 | } 63 | 64 | func (r *request) Context() context.Context { 65 | if r == nil { 66 | return nil 67 | } 68 | r.RLock() 69 | defer r.RUnlock() 70 | return r.ctx 71 | } 72 | 73 | func (r *request) SetContext(ctx context.Context) { 74 | r.Lock() 75 | defer r.Unlock() 76 | r.ctx = ctx 77 | } 78 | 79 | func (r *request) Copy() tmsg.Request { 80 | r.RLock() 81 | defer r.RUnlock() 82 | return &request{ 83 | Request: r.Request.Copy(), 84 | ctx: r.ctx, 85 | } 86 | } 87 | 88 | func (r *request) String() string { 89 | return fmt.Sprintf("%v", r.Request) 90 | } 91 | 92 | // Context implementation 93 | 94 | func (r *request) Deadline() (time.Time, bool) { 95 | return r.Context().Deadline() 96 | } 97 | 98 | func (r *request) Done() <-chan struct{} { 99 | return r.Context().Done() 100 | } 101 | 102 | func (r *request) Err() error { 103 | return r.Context().Err() 104 | } 105 | 106 | func (r *request) Value(key interface{}) interface{} { 107 | return r.Context().Value(key) 108 | } 109 | 110 | func NewRequest() Request { 111 | return 
FromTyphonRequest(tmsg.NewRequest()) 112 | } 113 | 114 | func FromTyphonRequest(req tmsg.Request) Request { 115 | switch req := req.(type) { 116 | case Request: 117 | return req 118 | default: 119 | return &request{ 120 | Request: req, 121 | ctx: context.Background(), 122 | } 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /client/interface.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/monzo/terrors" 7 | 8 | "github.com/monzo/mercury" 9 | "github.com/monzo/mercury/transport" 10 | ) 11 | 12 | // A Client is a convenient way to make Requests (potentially in parallel) and access their Responses/Errors. 13 | type Client interface { 14 | // Add a Call to the internal request set. 15 | Add(Call) Client 16 | // Add a Request to the internal request set (requests added this way will not benefit from automatic body 17 | // unmarshaling). 18 | AddRequest(uid string, req mercury.Request) Client 19 | // SetTimeout sets a timeout within which all requests must be received. Any response not received within this 20 | // window will result in an error being added to the error set. 21 | SetTimeout(timeout time.Duration) Client 22 | // Go fires off the requests. It does not wait until the requests have completed to return. 23 | Go() Client 24 | // Wait blocks until all requests have finished executing. 25 | Wait() Client 26 | // Execute fires off all requests and waits until all requests have completed before returning. 27 | Execute() Client 28 | // SetTransport configures a Transport to use with this Client. By default, it uses the default transport. 29 | SetTransport(t transport.Transport) Client 30 | 31 | // WaitC returns a channel which will be closed when all requests have finished. 32 | WaitC() <-chan struct{} 33 | // Errors returns an ErrorSet of all errors generated during execution (if any). 
34 | Errors() ErrorSet 35 | // Response retrieves the Response for the request given by its uid. If no such uid is known, returns nil. 36 | Response(uid string) mercury.Response 37 | 38 | // Middleware returns the ClientMiddleware stack currently installed. This is not a copy, so it's advisable not to 39 | // f*ck around with it. 40 | // 41 | // Client middleware is used to act upon or transform an RPC request or its response. Middleware is applied in order 42 | // during the request phase, and in reverse order during the response phase. 43 | // 44 | // Beware: client middleware can cause the timeouts to be exceeded. They must be fast, and certainly should not 45 | // make any remote calls themselves. 46 | Middleware() []ClientMiddleware 47 | // SetMiddleware replaces the client's Client stack. 48 | SetMiddleware([]ClientMiddleware) Client 49 | // AddMiddleware appends the given ClientMiddleware to the stack. 50 | AddMiddleware(ClientMiddleware) Client 51 | } 52 | 53 | type ClientMiddleware interface { 54 | // ProcessClientRequest is called on every outbound request, before it is sent to a transport. 55 | // 56 | // The middleware may mutate the request, or by returning nil, prevent the request from being sent entirely. 57 | ProcessClientRequest(req mercury.Request) mercury.Request 58 | // ProcessClientResponse is called on responses before they are available to the caller. If a call fails, or 59 | // returns an error, ProcessClientError is invoked instead of this method for that request. 60 | // 61 | // Note that response middleware are applied in reverse order. 62 | ProcessClientResponse(rsp mercury.Response, req mercury.Request) mercury.Response 63 | // ProcessClientError is called whenever a remote call results in an error (either local or remote). 64 | // 65 | // Note that error middleware are applied in reverse order. 
66 | ProcessClientError(err *terrors.Error, req mercury.Request) 67 | } 68 | -------------------------------------------------------------------------------- /requesttree/middleware.go: -------------------------------------------------------------------------------- 1 | package requesttree 2 | 3 | import ( 4 | "github.com/monzo/terrors" 5 | "golang.org/x/net/context" 6 | 7 | "github.com/monzo/mercury" 8 | ) 9 | // NOTE(review): several of these strings are also used as context.WithValue keys; plain string keys can collide across packages (go vet/staticcheck warn about this). The tests in this package set "Current-Service"/"Current-Endpoint" via raw strings, so changing the key type needs care. 10 | const ( 11 | parentIdHeader = "Parent-Request-ID" 12 | reqIdCtxKey = "Request-ID" 13 | 14 | currentServiceHeader = "Current-Service" 15 | currentEndpointHeader = "Current-Endpoint" 16 | originServiceHeader = "Origin-Service" 17 | originEndpointHeader = "Origin-Endpoint" 18 | ) 19 | // requestTreeMiddleware propagates request-tree metadata (parent request ID and current/origin service and endpoint) through request headers and context values. 20 | type requestTreeMiddleware struct{} 21 | // ProcessClientRequest sets the Parent-Request-ID header from the context's Request-ID value (unless the header is already set) and copies the context's current service/endpoint into the Origin-Service/Origin-Endpoint headers. 22 | func (m requestTreeMiddleware) ProcessClientRequest(req mercury.Request) mercury.Request { 23 | if req.Headers()[parentIdHeader] == "" { // Don't overwrite an existing header 24 | if parentId, ok := req.Context().Value(reqIdCtxKey).(string); ok && parentId != "" { 25 | req.SetHeader(parentIdHeader, parentId) 26 | } 27 | } 28 | 29 | // Pass through the current service and endpoint as the origin of this request 30 | req.SetHeader(originServiceHeader, CurrentServiceFor(req)) 31 | req.SetHeader(originEndpointHeader, CurrentEndpointFor(req)) 32 | 33 | return req 34 | } 35 | // ProcessClientResponse passes responses through unchanged. 
36 | func (m requestTreeMiddleware) ProcessClientResponse(rsp mercury.Response, req mercury.Request) mercury.Response { 37 | return rsp 38 | } 39 | // ProcessClientError does nothing. 40 | func (m requestTreeMiddleware) ProcessClientError(err *terrors.Error, req mercury.Request) { 41 | } 42 | // ProcessServerRequest records the request ID, the parent request ID (when present), the current service/endpoint and the origin service/endpoint as context values; it never returns a response. 43 | func (m requestTreeMiddleware) ProcessServerRequest(req mercury.Request) (mercury.Request, mercury.Response) { 44 | req.SetContext(context.WithValue(req.Context(), reqIdCtxKey, req.Id())) 45 | if v := req.Headers()[parentIdHeader]; v != "" { 46 | req.SetContext(context.WithValue(req.Context(), parentIdCtxKey, v)) 47 | } 48 | 49 | // Set the current service and endpoint into the context 50 |
req.SetContext(context.WithValue(req.Context(), currentServiceHeader, req.Service())) 51 | req.SetContext(context.WithValue(req.Context(), currentEndpointHeader, req.Endpoint())) 52 | 53 | // Set the originator into the context 54 | req.SetContext(context.WithValue(req.Context(), originServiceHeader, req.Headers()[originServiceHeader])) 55 | req.SetContext(context.WithValue(req.Context(), originEndpointHeader, req.Headers()[originEndpointHeader])) 56 | 57 | return req, nil 58 | } 59 | // ProcessServerResponse copies the parent request ID (if any) from the request context onto the response headers. 60 | func (m requestTreeMiddleware) ProcessServerResponse(rsp mercury.Response, req mercury.Request) mercury.Response { 61 | if v, ok := req.Value(parentIdCtxKey).(string); ok && v != "" && rsp != nil { 62 | rsp.SetHeader(parentIdHeader, v) 63 | } 64 | return rsp 65 | } 66 | // Middleware returns the request-tree middleware. It provides both the client hooks (ProcessClient*) and the server hooks (ProcessServer*). 67 | func Middleware() requestTreeMiddleware { 68 | return requestTreeMiddleware{} 69 | } 70 | 71 | // OriginServiceFor returns the originating service for this context 72 | func OriginServiceFor(ctx context.Context) string { 73 | if s, ok := ctx.Value(originServiceHeader).(string); ok { 74 | return s 75 | } 76 | return "" 77 | } 78 | 79 | // OriginEndpointFor returns the originating endpoint for this context 80 | func OriginEndpointFor(ctx context.Context) string { 81 | if e, ok := ctx.Value(originEndpointHeader).(string); ok { 82 | return e 83 | } 84 | return "" 85 | } 86 | 87 | // CurrentServiceFor returns the current service that this context is executing within 88 | func CurrentServiceFor(ctx context.Context) string { 89 | if s, ok := ctx.Value(currentServiceHeader).(string); ok { 90 | return s 91 | } 92 | return "" 93 | } 94 | 95 | // CurrentEndpointFor returns the current endpoint that this context is executing within 96 | func CurrentEndpointFor(ctx context.Context) string { 97 | if e, ok := 
ctx.Value(currentEndpointHeader).(string); ok { 98 | return e 99 | } 100 | return "" 101 | } 102 | -------------------------------------------------------------------------------- /requesttree/middleware_test.go:
-------------------------------------------------------------------------------- 1 | package requesttree 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/monzo/typhon/mock" 7 | "github.com/stretchr/testify/suite" 8 | "golang.org/x/net/context" 9 | 10 | "github.com/monzo/mercury" 11 | "github.com/monzo/mercury/client" 12 | "github.com/monzo/mercury/server" 13 | "github.com/monzo/mercury/testproto" 14 | "github.com/monzo/mercury/transport" 15 | ) 16 | 17 | const testOriginServiceName = "service.requesttree-origin" 18 | const testServiceName = "service.requesttree-example" 19 | // TestParentRequestIdMiddlewareSuite runs the request-tree middleware suite. 20 | func TestParentRequestIdMiddlewareSuite(t *testing.T) { 21 | suite.Run(t, new(parentRequestIdMiddlewareSuite)) 22 | } 23 | // parentRequestIdMiddlewareSuite holds a mock transport and a server with the request-tree middleware installed. 24 | type parentRequestIdMiddlewareSuite struct { 25 | suite.Suite 26 | trans transport.Transport 27 | srv server.Server 28 | } 29 | // SetupTest starts a server with two endpoints: "foo" checks the origin/current values and calls "foo-2", which checks them again and echoes its parent request ID as Pong. 30 | func (suite *parentRequestIdMiddlewareSuite) SetupTest() { 31 | suite.trans = mock.NewTransport() 32 | suite.srv = server.NewServer(testServiceName) 33 | suite.srv.AddMiddleware(Middleware()) 34 | 35 | suite.srv.AddEndpoints( 36 | server.Endpoint{ 37 | Name: "foo", 38 | Request: &testproto.DummyRequest{}, 39 | Response: &testproto.DummyResponse{}, 40 | Handler: func(req mercury.Request) (mercury.Response, error) { 41 | 42 | // Assert first call has correct origin 43 | suite.Assert().Equal(testOriginServiceName, OriginServiceFor(req)) 44 | suite.Assert().Equal("e2etest", OriginEndpointFor(req)) 45 | 46 | // Assert first call has updated to the current service 47 | suite.Assert().Equal(testServiceName, CurrentServiceFor(req)) 48 | suite.Assert().Equal("foo", CurrentEndpointFor(req)) 49 | 50 | cl := client.NewClient(). 
51 | SetTransport(suite.trans). 52 | SetMiddleware([]client.ClientMiddleware{Middleware()}). 53 | Add(client.Call{ 54 | Uid: "call", 55 | Service: testServiceName, 56 | Endpoint: "foo-2", 57 | Body: &testproto.DummyRequest{}, 58 | Response: &testproto.DummyResponse{}, 59 | Context: req, 60 | }).
61 | Execute() 62 | return cl.Response("call"), cl.Errors().Combined() 63 | }}, 64 | server.Endpoint{ 65 | Name: "foo-2", 66 | Request: &testproto.DummyRequest{}, 67 | Response: &testproto.DummyResponse{}, 68 | Handler: func(req mercury.Request) (mercury.Response, error) { 69 | 70 | // Assert origin headers were set correctly as previous service 71 | suite.Assert().Equal(testServiceName, OriginServiceFor(req)) 72 | suite.Assert().Equal("foo", OriginEndpointFor(req)) 73 | 74 | // And that our current service's headers were set 75 | suite.Assert().Equal(testServiceName, CurrentServiceFor(req)) 76 | suite.Assert().Equal("foo-2", CurrentEndpointFor(req)) 77 | 78 | return req.Response(&testproto.DummyResponse{ 79 | Pong: ParentRequestIdFor(req)}), nil 80 | }}) 81 | suite.srv.Start(suite.trans) 82 | } 83 | // TearDownTest stops the server and shuts down the mock transport. 84 | func (suite *parentRequestIdMiddlewareSuite) TearDownTest() { 85 | suite.srv.Stop() 86 | suite.srv = nil 87 | suite.trans.Tomb().Killf("test ending") 88 | suite.trans.Tomb().Wait() 89 | suite.trans = nil 90 | } 91 | 92 | // TestE2E verifies parent request IDs are properly set on child requests 93 | func (suite *parentRequestIdMiddlewareSuite) TestE2E() { 94 | cli := client. 95 | NewClient(). 96 | SetTransport(suite.trans).
97 | SetMiddleware([]client.ClientMiddleware{Middleware()}) 98 | // Build a fake origin request whose context carries the current service/endpoint under the same string keys the middleware reads. 99 | dummyOrigin := mercury.NewRequest() 100 | dummyOrigin.SetId("foobarbaz") 101 | ctx := context.WithValue(dummyOrigin.Context(), "Current-Service", testOriginServiceName) 102 | ctx = context.WithValue(ctx, "Current-Endpoint", "e2etest") 103 | dummyOrigin.SetContext(ctx) 104 | 105 | cli.Add(client.Call{ 106 | Uid: "call", 107 | Service: testServiceName, 108 | Endpoint: "foo", 109 | Context: dummyOrigin, 110 | Response: &testproto.DummyResponse{}, 111 | Body: &testproto.DummyRequest{}}) 112 | cli.Execute() 113 | 114 | suite.Assert().NoError(cli.Errors().Combined()) 115 | rsp := cli.Response("call") 116 | response := rsp.Body().(*testproto.DummyResponse) 117 | suite.Assert().NotEmpty(response.Pong) 118 | suite.Assert().Equal(response.Pong, rsp.Headers()[parentIdHeader]) 119 | } 120 | -------------------------------------------------------------------------------- /client/errorset_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/monzo/terrors" 7 | "github.com/stretchr/testify/suite" 8 | ) 9 | // TestErrorSetSuite runs the ErrorSet tests. 10 | func TestErrorSetSuite(t *testing.T) { 11 | suite.Run(t, new(errorSetSuite)) 12 | } 13 | // errorSetSuite holds the same errors both as an ErrorSet and keyed by request uid. 14 | type errorSetSuite struct { 15 | suite.Suite 16 | errs ErrorSet 17 | rawErrs map[string]*terrors.Error 18 | } 19 | // SetupTest builds three InternalService errors tagged with uid/service/endpoint params; uid2 and uid3 share a service. 20 | func (suite *errorSetSuite) SetupTest() { 21 | suite.rawErrs = map[string]*terrors.Error{} 22 | suite.errs = nil 23 | 24 | err := terrors.InternalService("", "uid1", nil) 25 | err.Params[errUidField] = "uid1" 26 | err.Params[errServiceField] = "service.uid1" 27 | 
err.Params[errEndpointField] = "uid1" 28 | suite.errs = append(suite.errs, err) 29 | suite.rawErrs["uid1"] = err 30 | 31 | err = terrors.InternalService("", "uid2", nil) 32 | err.Params[errUidField] = "uid2" 33 | err.Params[errServiceField] = "service.uid2" 34 | err.Params[errEndpointField] = "uid2" 35 | suite.errs =
append(suite.errs, err) 36 | suite.rawErrs["uid2"] = err 37 | 38 | err = terrors.InternalService("", "uid3", nil) 39 | err.Params[errUidField] = "uid3" 40 | err.Params[errServiceField] = "service.uid2" // Same service as uid2 41 | err.Params[errEndpointField] = "uid3" 42 | suite.errs = append(suite.errs, err) 43 | suite.rawErrs["uid3"] = err 44 | } 45 | // TestBasic checks ForUid lookups and Any on the full set. 46 | func (suite *errorSetSuite) TestBasic() { 47 | errs := suite.errs 48 | err1 := suite.rawErrs["uid1"] 49 | err2 := suite.rawErrs["uid2"] 50 | err3 := suite.rawErrs["uid3"] 51 | 52 | suite.Assert().Len(errs, 3) 53 | suite.Assert().Equal(err1, errs.ForUid("uid1")) 54 | suite.Assert().Equal(err2, errs.ForUid("uid2")) 55 | suite.Assert().Equal(err3, errs.ForUid("uid3")) 56 | suite.Assert().True(errs.Any()) 57 | } 58 | // TestIgnoreUid checks that IgnoreUid returns a filtered set and leaves the receiver untouched. 59 | func (suite *errorSetSuite) TestIgnoreUid() { 60 | errs := suite.errs 61 | err2 := suite.rawErrs["uid2"] 62 | err3 := suite.rawErrs["uid3"] 63 | 64 | errs = errs.IgnoreUid("uid1") 65 | suite.Assert().Len(errs, 2) 66 | suite.Assert().Len(suite.errs, 3) 67 | suite.Assert().True(errs.Any()) 68 | suite.Assert().Nil(errs.ForUid("uid1")) 69 | suite.Assert().Equal(err2, errs.ForUid("uid2")) 70 | suite.Assert().Equal(err3, errs.ForUid("uid3")) 71 | 72 | errs = errs.IgnoreUid("uid2", "uid3") 73 | suite.Assert().Empty(errs) 74 | suite.Assert().Len(suite.errs, 3) 75 | suite.Assert().False(errs.Any()) 76 | suite.Assert().Nil(errs.ForUid("uid1")) 77 | suite.Assert().Nil(errs.ForUid("uid2")) 78 | suite.Assert().Nil(errs.ForUid("uid3")) 79 | } 80 | // TestIgnoreService checks that ignoring a service drops every error from that service. 
81 | func (suite *errorSetSuite) TestIgnoreService() { 82 | errs := suite.errs 83 | err2 := suite.rawErrs["uid2"] 84 | err3 := suite.rawErrs["uid3"] 85 | 86 | errs = errs.IgnoreService("service.uid1") 87 | suite.Assert().Len(errs, 2) 88 | suite.Assert().Len(suite.errs, 3) 89 | suite.Assert().Nil(errs.ForUid("uid1")) 90 | suite.Assert().Equal(err2, errs.ForUid("uid2")) 91 | suite.Assert().Equal(err3, errs.ForUid("uid3")) 92 | suite.Assert().True(errs.Any()) 93 | 94 |
errs = errs.IgnoreService("service.uid2") // uid2 and uid3 have the same service 95 | suite.Assert().Empty(errs) 96 | suite.Assert().Len(suite.errs, 3) 97 | suite.Assert().Nil(errs.ForUid("uid1")) 98 | suite.Assert().Nil(errs.ForUid("uid2")) 99 | suite.Assert().Nil(errs.ForUid("uid3")) 100 | suite.Assert().False(errs.Any()) 101 | } 102 | // TestIgnoreEndpoint checks that only errors matching both the service and the endpoint are dropped. 103 | func (suite *errorSetSuite) TestIgnoreEndpoint() { 104 | errs := suite.errs 105 | err2 := suite.rawErrs["uid2"] 106 | err3 := suite.rawErrs["uid3"] 107 | 108 | errs = errs.IgnoreEndpoint("service.uid1", "uid1") 109 | suite.Assert().Len(errs, 2) 110 | suite.Assert().Len(suite.errs, 3) 111 | suite.Assert().True(errs.Any()) 112 | suite.Assert().Nil(errs.ForUid("uid1")) 113 | suite.Assert().Equal(err2, errs.ForUid("uid2")) 114 | suite.Assert().Equal(err3, errs.ForUid("uid3")) 115 | 116 | errs = errs.IgnoreEndpoint("service.uid1", "uid10") // Doesn't exist 117 | suite.Assert().Len(errs, 2) 118 | suite.Assert().Len(suite.errs, 3) 119 | suite.Assert().True(errs.Any()) 120 | suite.Assert().Nil(errs.ForUid("uid1")) 121 | suite.Assert().Equal(err2, errs.ForUid("uid2")) 122 | suite.Assert().Equal(err3, errs.ForUid("uid3")) 123 | } 124 | // TestIgnoreCode checks that ignoring a code drops every error carrying it. 125 | func (suite *errorSetSuite) TestIgnoreCode() { 126 | errs := suite.errs 127 | 128 | errs = errs.IgnoreCode(terrors.ErrInternalService) 129 | suite.Assert().Nil(errs.ForUid("uid1")) 130 | suite.Assert().Nil(errs.ForUid("uid2")) 131 | suite.Assert().Nil(errs.ForUid("uid3")) 132 | suite.Assert().Empty(errs) 133 | suite.Assert().Len(suite.errs, 3) 134 | suite.Assert().False(errs.Any()) 135 | } 136 | // TestForUid checks that each uid maps back to its own error. 
137 | func (suite *errorSetSuite) TestForUid() { 138 | errs := suite.errs 139 | err1 := suite.rawErrs["uid1"] 140 | err2 := suite.rawErrs["uid2"] 141 | err3 := suite.rawErrs["uid3"] 142 | 143 | suite.Assert().Equal(err1, errs.ForUid("uid1")) 144 | suite.Assert().Equal(err2, errs.ForUid("uid2")) 145 | suite.Assert().Equal(err3, errs.ForUid("uid3")) 146 | } 147 | // TestErrors checks that Errors() returns a uid-to-error map matching the originals. 148 | func (suite *errorSetSuite) TestErrors()
{ 149 | suite.Assert().Equal(suite.rawErrs, suite.errs.Errors()) 150 | } 151 | // TestMultiErrorPriority checks code priority ordering and that Combined() keeps the highest-priority code. 152 | func (suite *errorSetSuite) TestMultiErrorPriority() { 153 | br := terrors.BadRequest("missing_param", "foo bar", nil) 154 | is := terrors.InternalService("something_broke", "hello world", nil) 155 | suite.Assert().True(higherPriority(is.Code, br.Code)) 156 | se := terrors.New("something_else", "baz", nil) 157 | suite.Assert().True(higherPriority(is.Code, se.Code)) 158 | suite.Assert().True(higherPriority(br.Code, se.Code)) 159 | 160 | es := ErrorSet{se, is, br} 161 | suite.Assert().Equal(is.Code, es.Combined().(*terrors.Error).Code) 162 | } 163 | -------------------------------------------------------------------------------- /client/errorset.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/monzo/terrors" 8 | ) 9 | // Params keys holding client request metadata (uid, service, endpoint) on each error; sanitiseContext strips them from combined errors. 10 | const ( 11 | errUidField = "Client-Uid" 12 | errServiceField = "Client-Service" 13 | errEndpointField = "Client-Endpoint" 14 | ) 15 | 16 | var ( 17 | // used to work out which err to use when merging multiple.
Lower number = higher priority 18 | codePriority = map[string]int{ 19 | terrors.ErrUnknown: 0, 20 | terrors.ErrInternalService: 1, 21 | terrors.ErrBadRequest: 2, 22 | terrors.ErrBadResponse: 3, 23 | terrors.ErrForbidden: 4, 24 | terrors.ErrUnauthorized: 5, 25 | terrors.ErrNotFound: 6, 26 | terrors.ErrTimeout: 7, 27 | } 28 | ) 29 | 30 | type ErrorSet []*terrors.Error 31 | 32 | // Copy returns a new ErrorSet containing the same errors as the receiver 33 | func (es ErrorSet) Copy() ErrorSet { 34 | result := make(ErrorSet, len(es)) 35 | copy(result, es) 36 | return result 37 | } 38 | 39 | // ForUid returns the error for a given request uid (or nil) 40 | func (es ErrorSet) ForUid(uid string) *terrors.Error { 41 | for _, e := range es { 42 | if euid, ok := e.Params[errUidField]; ok && euid == uid { 43 | return e 44 | } 45 | } 46 | return nil 47 | } 48 | 49 | // Any returns whether there are any contained errors 50 | func (es ErrorSet) Any() bool { 51 | return len(es) > 0 52 | } 53 | 54 | // Errors returns a map of request uids to their error, for requests which had errors 55 | func (es ErrorSet) Errors() map[string]*terrors.Error { 56 | result := make(map[string]*terrors.Error, len(es)) // Never return nil; with a map it's just fraught 57 | for _, err := range es { 58 | result[err.Params[errUidField]] = err 59 | } 60 | return result 61 | } 62 | 63 | // IgnoreCode returns a new ErrorSet without errors of the given codes 64 | func (es ErrorSet) IgnoreCode(codes ...string) ErrorSet { 65 | if len(codes) == 0 { 66 | return es 67 | } 68 | codesMap := make(map[string]struct{}, len(codes)) 69 | for _, c := range codes { 70 | codesMap[c] = struct{}{} 71 | } 72 | 73 | result := make(ErrorSet, 0, len(es)-len(codes)) 74 | for _, err := range es { 75 | if _, excluded := codesMap[err.Code]; !excluded { 76 | result = append(result, err) 77 | } 78 | } 79 | return result 80 | } 81 | 82 | // IgnoreEndpoint returns a new ErrorSet without errors from the given service endpoint 83 | func (es 
ErrorSet) IgnoreEndpoint(service, endpoint string) ErrorSet { 84 | result := make(ErrorSet, 0, len(es)-1) 85 | for _, err := range es { 86 | if !(err.Params[errServiceField] == service && err.Params[errEndpointField] == endpoint) { 87 | result = append(result, err) 88 | } 89 | } 90 | return result 91 | } 92 | 93 | // IgnoreService returns a new ErrorSet without errors from the given service(s) 94 | func (es ErrorSet) IgnoreService(services ...string) ErrorSet { 95 | if len(services) == 0 { 96 | return es 97 | } 98 | servicesMap := stringsMap(services...) 99 | result := make(ErrorSet, 0, len(es)-len(services)) 100 | for _, err := range es { 101 | if _, excluded := servicesMap[err.Params[errServiceField]]; !excluded { 102 | result = append(result, err) 103 | } 104 | } 105 | return result 106 | } 107 | 108 | // IgnoreUid returns a new ErrorSet without errors from the given request uid(s) 109 | func (es ErrorSet) IgnoreUid(uids ...string) ErrorSet { 110 | if len(uids) == 0 { 111 | return es 112 | } 113 | uidsMap := stringsMap(uids...) 
114 | result := make(ErrorSet, 0, len(es)-len(uids)) 115 | for _, err := range es { 116 | if _, excluded := uidsMap[err.Params[errUidField]]; !excluded { 117 | result = append(result, err) 118 | } 119 | } 120 | return result 121 | } 122 | 123 | // sanitiseContext takes an error context and removes client-specific things from it (in-place) 124 | func (es ErrorSet) sanitiseContext(ctx map[string]string) { 125 | delete(ctx, errUidField) 126 | delete(ctx, errServiceField) 127 | delete(ctx, errEndpointField) 128 | } 129 | 130 | // returns true if this has higher priority than that 131 | func higherPriority(this, that string) bool { 132 | // code priority is based on first part of the dotted code before the first dot 133 | thisPr, ok := codePriority[strings.Split(this, ".")[0]] 134 | if !ok { 135 | thisPr = 1000 136 | } 137 | thatPr, ok := codePriority[strings.Split(that, ".")[0]] 138 | if !ok { 139 | thatPr = 1000 140 | } 141 | return thisPr < thatPr 142 | } 143 | 144 | // Combined returns a combined error from the set. If there is only one error, it is returned unmolested. If there are 145 | // more, they are all "flattened" into a single error. Where codes differ, they are normalised to that with the lowest 146 | // index. 147 | func (es ErrorSet) Combined() error { 148 | switch len(es) { 149 | case 0: 150 | return nil 151 | 152 | case 1: 153 | return es[0] 154 | 155 | default: 156 | msg := fmt.Sprintf("%s, and %d more errors", es[0].Message, len(es)-1) 157 | result := terrors.New(es[0].Code, msg, nil) 158 | 159 | params := []map[string]string{} 160 | for _, err := range es { 161 | if higherPriority(err.Code, result.Code) { 162 | result.Code = err.Code 163 | } 164 | params = append(params, err.Params) 165 | } 166 | 167 | result.Params = mergeContexts(params...) 
168 | es.sanitiseContext(result.Params) 169 | return result 170 | } 171 | } 172 | 173 | // Error satisfies Go's error interface; it returns the Error() text of Combined(), or "" for an empty set 174 | func (es ErrorSet) Error() string { 175 | if err := es.Combined(); err != nil { 176 | return err.Error() 177 | } 178 | return "" 179 | } 180 | // String satisfies fmt.Stringer; it returns the same text as Error. 181 | func (es ErrorSet) String() string { 182 | return es.Error() 183 | } 184 | -------------------------------------------------------------------------------- /service/serverclient_test.go: -------------------------------------------------------------------------------- 1 | package service 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | "time" 7 | 8 | "github.com/monzo/terrors" 9 | "github.com/monzo/typhon/mock" 10 | "github.com/stretchr/testify/suite" 11 | 12 | "github.com/monzo/mercury" 13 | "github.com/monzo/mercury/client" 14 | "github.com/monzo/mercury/marshaling" 15 | "github.com/monzo/mercury/server" 16 | "github.com/monzo/mercury/testproto" 17 | "github.com/monzo/mercury/transport" 18 | ) 19 | 20 | const testServiceName = "service.client-server-example" 21 | // TestClientServerSuite_MockTransport runs the client/server suite over a mock transport. 22 | func TestClientServerSuite_MockTransport(t *testing.T) { 23 | suite.Run(t, &clientServerSuite{ 24 | TransF: func() transport.Transport { 25 | return mock.NewTransport() 26 | }}) 27 | } 28 | // clientServerSuite runs client/server tests against a shared transport built by TransF. 29 | type clientServerSuite struct { 30 | suite.Suite 31 | TransF func() transport.Transport 32 | trans transport.Transport 33 | server server.Server 34 | } 35 | // SetupSuite builds the transport and panics if it is not ready within 2 seconds. 
36 | func (suite *clientServerSuite) SetupSuite() { 37 | trans := suite.TransF() 38 | select { 39 | case <-trans.Ready(): 40 | case <-time.After(2 * time.Second): 41 | panic("transport not ready") 42 | } 43 | suite.trans = trans 44 | } 45 | // SetupTest starts a new server with DefaultServerMiddleware before each test. 46 | func (suite *clientServerSuite) SetupTest() { 47 | suite.server = server.NewServer(testServiceName) 48 | suite.server.SetMiddleware(DefaultServerMiddleware()) 49 | suite.server.Start(suite.trans) 50 | } 51 | // TearDownTest stops the per-test server. 52 | func (suite *clientServerSuite) TearDownTest() { 53 | suite.server.Stop() 54 | suite.server = nil 55 | } 56 | // TearDownSuite kills the shared transport and waits for it to exit. 57 | func
(suite *clientServerSuite) TearDownSuite() { 58 | suite.trans.Tomb().Killf("Test ending") 59 | suite.trans.Tomb().Wait() 60 | suite.trans = nil 61 | } 62 | // TestE2E verifies a protobuf request/response round trip through a registered handler. 63 | func (suite *clientServerSuite) TestE2E() { 64 | suite.server.AddEndpoints( 65 | server.Endpoint{ 66 | Name: "test", 67 | Request: new(testproto.DummyRequest), 68 | Response: new(testproto.DummyResponse), 69 | Handler: func(req mercury.Request) (mercury.Response, error) { 70 | return req.Response(&testproto.DummyResponse{ 71 | Pong: "teste2e", 72 | }), nil 73 | }}) 74 | 75 | cl := client.NewClient(). 76 | SetMiddleware(DefaultClientMiddleware()). 77 | Add(client.Call{ 78 | Uid: "call", 79 | Service: testServiceName, 80 | Endpoint: "test", 81 | Body: &testproto.DummyRequest{}, 82 | Response: &testproto.DummyResponse{}, 83 | }). 84 | SetTransport(suite.trans). 85 | SetTimeout(time.Second). 86 | Execute() 87 | 88 | suite.Assert().False(cl.Errors().Any()) 89 | rsp := cl.Response("call") 90 | suite.Assert().NotNil(rsp) 91 | response := rsp.Body().(*testproto.DummyResponse) 92 | suite.Assert().Equal("teste2e", response.Pong) 93 | suite.Assert().False(rsp.IsError()) 94 | suite.Assert().Nil(rsp.Error()) 95 | } 96 | 97 | // TestErrors verifies that an error sent from a handler is correctly returned by a client 98 | func (suite *clientServerSuite) TestErrors() { 99 | suite.server.AddEndpoints(server.Endpoint{ 100 | Name: "error", 101 | Request: new(testproto.DummyRequest), 102 | Response: new(testproto.DummyResponse), 103 | Handler: func(req mercury.Request) (mercury.Response, error) { 104 | return nil, terrors.BadRequest("", "naughty naughty", nil) 105 | }}) 106 | 107 | cl := client.NewClient(). 108 | SetMiddleware(DefaultClientMiddleware()). 109 | Add( 110 | client.Call{ 111 | Uid: "call", 112 | Service: testServiceName, 113 | Endpoint: "error", 114 | Body: &testproto.DummyRequest{}, 115 | Response: &testproto.DummyResponse{}, 116 | }). 117 | SetTransport(suite.trans). 
118 | SetTimeout(time.Second).
119 | Execute() 120 | 121 | suite.Assert().True(cl.Errors().Any()) 122 | err := cl.Errors().ForUid("call") 123 | suite.Require().NotNil(err) 124 | suite.Assert().Equal(terrors.ErrBadRequest, err.Code) 125 | // NOTE(review): the Require().NotNil(rsp) below runs only after rsp is used; a nil cl.Response("call") would already panic on .Copy(). 126 | rsp := mercury.FromTyphonResponse(cl.Response("call").Copy()) 127 | rsp.SetBody("FOO") // Deliberately set this to verify it is not mutated while accessing the error 128 | suite.Require().NotNil(rsp) 129 | suite.Assert().True(rsp.IsError()) 130 | suite.Assert().NotNil(rsp.Error()) 131 | suite.Assert().IsType(&terrors.Error{}, rsp.Error()) 132 | err = rsp.Error().(*terrors.Error) 133 | suite.Assert().Equal(terrors.ErrBadRequest, err.Code) 134 | suite.Assert().Equal("FOO", rsp.Body()) 135 | } 136 | 137 | // TestJSON verifies a JSON request and response can be received from a protobuf handler 138 | func (suite *clientServerSuite) TestJSON() { 139 | suite.server.AddEndpoints( 140 | server.Endpoint{ 141 | Name: "test", 142 | Request: new(testproto.DummyRequest), 143 | Response: new(testproto.DummyResponse), 144 | Handler: func(req mercury.Request) (mercury.Response, error) { 145 | request := req.Body().(*testproto.DummyRequest) 146 | return req.Response(&testproto.DummyResponse{ 147 | Pong: request.Ping, 148 | }), nil 149 | }}) 150 | 151 | req := mercury.NewRequest() 152 | req.SetService(testServiceName) 153 | req.SetEndpoint("test") 154 | req.SetPayload([]byte(`{ "ping": "blah blah blah" }`)) 155 | req.SetHeader(marshaling.ContentTypeHeader, "application/json") 156 | req.SetHeader(marshaling.AcceptHeader, "application/json") 157 | 158 | cl := client.NewClient(). 159 | SetMiddleware(DefaultClientMiddleware()). 160 | AddRequest("call", req). 161 | SetTransport(suite.trans). 162 | SetTimeout(time.Second).
163 | Execute() 164 | 165 | suite.Assert().False(cl.Errors().Any()) 166 | rsp := cl.Response("call") 167 | suite.Assert().NotNil(rsp) 168 | var body map[string]string 169 | suite.Assert().NoError(json.Unmarshal(rsp.Payload(), &body)) 170 | suite.Assert().NotNil(body) 171 | suite.Assert().Equal(1, len(body)) 172 | suite.Assert().Equal("blah blah blah", body["pong"]) 173 | } 174 | -------------------------------------------------------------------------------- /server/srv.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "sync" 7 | "time" 8 | 9 | log "github.com/monzo/slog" 10 | "github.com/monzo/terrors" 11 | tmsg "github.com/monzo/typhon/message" 12 | ttrans "github.com/monzo/typhon/transport" 13 | "golang.org/x/net/context" 14 | "gopkg.in/tomb.v2" 15 | 16 | "github.com/monzo/mercury" 17 | "github.com/monzo/mercury/transport" 18 | ) 19 | 20 | const ( 21 | connectTimeout = 30 * time.Second 22 | ) 23 | 24 | var ( 25 | ErrAlreadyRunning error = terrors.InternalService("", "Server is already running", nil) // empty dotted code so impl details don't leak outside 26 | ErrTransportClosed error = terrors.InternalService("", "Transport closed", nil) 27 | errEndpointNotFound = terrors.BadRequest("endpoint_not_found", "Endpoint not found", nil) 28 | defaultMiddleware []ServerMiddleware 29 | defaultMiddlewareM sync.RWMutex 30 | ) 31 | // NewServer returns a Server with the given name whose middleware starts as the current default middleware (the slice itself is shared, not copied). 32 | func NewServer(name string) Server { 33 | defaultMiddlewareM.RLock() 34 | middleware := defaultMiddleware 35 | defaultMiddlewareM.RUnlock() 36 | 37 | return &server{ 38 | name: name, 39 | middleware: middleware, 40 | } 41 | } 42 | // SetDefaultMiddleware sets the middleware that servers created afterwards by NewServer start with. 
43 | func SetDefaultMiddleware(middleware []ServerMiddleware) { 44 | defaultMiddlewareM.Lock() 45 | defer defaultMiddlewareM.Unlock() 46 | defaultMiddleware = middleware 47 | } 48 | // server is the Server implementation returned by NewServer. 49 | type server struct { 50 | name string // server name (registered with the transport; immutable) 51 | endpoints map[string]Endpoint
// endpoint name: Endpoint 52 | endpointsM sync.RWMutex // protects endpoints 53 | workerTomb *tomb.Tomb // runs as long as there is a worker consuming Requests 54 | workerTombM sync.RWMutex // protects workerTomb 55 | middleware []ServerMiddleware // applied in-order for requests, reverse-order for responses 56 | middlewareM sync.RWMutex // protects middleware 57 | } 58 | // Name returns the server's name. 59 | func (s *server) Name() string { 60 | return s.name 61 | } 62 | // AddEndpoints registers eps, replacing any existing endpoint with the same name. It panics if any endpoint has a nil Handler (checked before any are added). 63 | func (s *server) AddEndpoints(eps ...Endpoint) { 64 | // Check the endpoint is valid (panic if not) 65 | for _, ep := range eps { 66 | if ep.Handler == nil { 67 | panic(fmt.Sprintf("Endpoint %s has no handler function", ep.Name)) 68 | } 69 | } 70 | 71 | s.endpointsM.Lock() 72 | defer s.endpointsM.Unlock() 73 | if s.endpoints == nil { 74 | s.endpoints = make(map[string]Endpoint, len(eps)) 75 | } 76 | for _, e := range eps { 77 | // if e.Request == nil || e.Response == nil { 78 | // panic(fmt.Sprintf("Endpoint \"%s\" must have Request and Response defined", e.Name)) 79 | // } 80 | s.endpoints[e.Name] = e 81 | } 82 | } 83 | // RemoveEndpoints unregisters the endpoints with the given names. 84 | func (s *server) RemoveEndpoints(eps ...Endpoint) { 85 | s.endpointsM.Lock() 86 | defer s.endpointsM.Unlock() 87 | for _, e := range eps { 88 | delete(s.endpoints, e.Name) 89 | } 90 | } 91 | func (s *server) Endpoints() []Endpoint { // returns a newly allocated slice of the registered endpoints; order is unspecified (map iteration) 92 | s.endpointsM.RLock() 93 | defer s.endpointsM.RUnlock() 94 | result := make([]Endpoint, 0, len(s.endpoints)) 95 | for _, ep := range s.endpoints { 96 | result = append(result, ep) 97 | } 98 | return result 99 | } 100 | // Endpoint looks up an endpoint by exact path; if none matches and path starts with "/", it retries without the leading slash. 
101 | func (s *server) Endpoint(path string) (Endpoint, bool) { 102 | s.endpointsM.RLock() 103 | defer s.endpointsM.RUnlock() 104 | ep, ok := s.endpoints[path] 105 | if !ok && strings.HasPrefix(path, "/") { // Try looking for a "legacy" match without the leading slash 106 | ep, ok = s.endpoints[strings.TrimPrefix(path, "/")] 107 | } 108 | return ep, ok 109 | } 110 | 111 | func (s *server) start(trans transport.Transport) (*tomb.Tomb, error) { 112 | ctx :=
context.Background() 113 | 114 | s.workerTombM.Lock() 115 | if s.workerTomb != nil { 116 | s.workerTombM.Unlock() 117 | return nil, ErrAlreadyRunning 118 | } 119 | tm := new(tomb.Tomb) 120 | s.workerTomb = tm 121 | s.workerTombM.Unlock() 122 | 123 | stop := func() { 124 | trans.StopListening(s.Name()) 125 | s.workerTombM.Lock() 126 | s.workerTomb = nil 127 | s.workerTombM.Unlock() 128 | } 129 | 130 | var inbound chan tmsg.Request 131 | connect := func() error { 132 | select { 133 | case <-trans.Ready(): 134 | inbound = make(chan tmsg.Request, 500) 135 | return trans.Listen(s.Name(), inbound) 136 | 137 | case <-time.After(connectTimeout): 138 | log.Warn(ctx, "[Mercury:Server] Timed out after %v waiting for transport readiness", connectTimeout) 139 | return ttrans.ErrTimeout 140 | } 141 | } 142 | 143 | // Block here purposefully (deliberately not in the goroutine below, because we want to report a connection error 144 | // to the caller) 145 | if err := connect(); err != nil { 146 | stop() 147 | return nil, err 148 | } 149 | 150 | tm.Go(func() error { 151 | defer stop() 152 | for { 153 | select { 154 | case req, ok := <-inbound: 155 | if !ok { 156 | // Received because the channel closed; try to reconnect 157 | log.Warn(ctx, "[Mercury:Server] Inbound channel closed; trying to reconnect…") 158 | if err := connect(); err != nil { 159 | log.Critical(ctx, "[Mercury:Server] Could not reconnect after channel close: %s", err) 160 | return err 161 | } 162 | } else { 163 | go s.handle(trans, req) 164 | } 165 | 166 | case <-tm.Dying(): 167 | return tomb.ErrDying 168 | } 169 | } 170 | }) 171 | return tm, nil 172 | } 173 | 174 | func (s *server) Start(trans transport.Transport) error { 175 | _, err := s.start(trans) 176 | return err 177 | } 178 | 179 | func (s *server) Run(trans transport.Transport) { 180 | if tm, err := s.start(trans); err != nil || tm == nil { 181 | panic(err) 182 | } else if err := tm.Wait(); err != nil { 183 | panic(err) 184 | } 185 | } 186 | 187 | func (s 
*server) Stop() { 188 | s.workerTombM.RLock() 189 | tm := s.workerTomb 190 | s.workerTombM.RUnlock() 191 | if tm != nil { 192 | tm.Killf("Stop() called") 193 | tm.Wait() 194 | } 195 | } 196 | 197 | func (s *server) applyRequestMiddleware(req mercury.Request) (mercury.Request, mercury.Response) { 198 | s.middlewareM.RLock() 199 | mws := s.middleware 200 | s.middlewareM.RUnlock() 201 | for _, mw := range mws { 202 | if req_, rsp := mw.ProcessServerRequest(req); rsp != nil { 203 | return req_, rsp 204 | } else { 205 | req = req_ 206 | } 207 | } 208 | return req, nil 209 | } 210 | 211 | func (s *server) applyResponseMiddleware(rsp mercury.Response, req mercury.Request) mercury.Response { 212 | s.middlewareM.RLock() 213 | mws := s.middleware 214 | s.middlewareM.RUnlock() 215 | for i := len(mws) - 1; i >= 0; i-- { // reverse order 216 | mw := mws[i] 217 | rsp = mw.ProcessServerResponse(rsp, req) 218 | } 219 | return rsp 220 | } 221 | 222 | func (s *server) handle(trans transport.Transport, req_ tmsg.Request) { 223 | req := mercury.FromTyphonRequest(req_) 224 | req, rsp := s.applyRequestMiddleware(req) 225 | 226 | if rsp == nil { 227 | if ep, ok := s.Endpoint(req.Endpoint()); !ok { 228 | log.Warn(req, "[Mercury:Server] Received request %s for unknown endpoint %s", req.Id(), req.Endpoint()) 229 | rsp = ErrorResponse(req, errEndpointNotFound) 230 | } else { 231 | if rsp_, err := ep.Handle(req); err != nil { 232 | rsp = ErrorResponse(req, err) 233 | log.Info(req, "[Mercury:Server] Error from endpoint %s for %v: %v", ep.Name, req, err, map[string]string{ 234 | "request_payload": string(req.Payload())}) 235 | } else if rsp_ == nil { 236 | rsp = req.Response(nil) 237 | } else { 238 | rsp = rsp_ 239 | } 240 | } 241 | } 242 | rsp = s.applyResponseMiddleware(rsp, req) 243 | if rsp != nil { 244 | trans.Respond(req, rsp) 245 | } 246 | } 247 | 248 | func (s *server) Middleware() []ServerMiddleware { 249 | // Note that no operation exists that mutates a particular element; this is 
very deliberate and means we do not 250 | // need to hold a read lock when iterating over the middleware slice, only when getting a reference to the slice. 251 | s.middlewareM.RLock() 252 | mws := s.middleware 253 | s.middlewareM.RUnlock() 254 | result := make([]ServerMiddleware, len(mws)) 255 | copy(result, mws) 256 | return result 257 | } 258 | 259 | func (s *server) SetMiddleware(mws []ServerMiddleware) { 260 | s.middlewareM.Lock() 261 | defer s.middlewareM.Unlock() 262 | s.middleware = mws 263 | } 264 | 265 | func (s *server) AddMiddleware(mw ServerMiddleware) { 266 | s.middlewareM.Lock() 267 | defer s.middlewareM.Unlock() 268 | s.middleware = append(s.middleware, mw) 269 | } 270 | 271 | // ErrorResponse constructs a response for the given request, with the given error as its contents. Mercury clients 272 | // know how to unmarshal these errors. 273 | func ErrorResponse(req mercury.Request, err error) mercury.Response { 274 | rsp := req.Response(nil) 275 | var terr *terrors.Error 276 | if err != nil { 277 | terr = terrors.Wrap(err, nil).(*terrors.Error) 278 | } 279 | rsp.SetBody(terrors.Marshal(terr)) 280 | if err := tmsg.JSONMarshaler().MarshalBody(rsp); err != nil { 281 | log.Error(req, "[Mercury:Server] Failed to marshal error response: %v", err) 282 | return nil // Not much we can do here 283 | } 284 | rsp.SetIsError(true) 285 | return rsp 286 | } 287 | -------------------------------------------------------------------------------- /client/client.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "reflect" 5 | "sync" 6 | "time" 7 | 8 | log "github.com/monzo/slog" 9 | "github.com/monzo/terrors" 10 | tperrors "github.com/monzo/terrors/proto" 11 | tmsg "github.com/monzo/typhon/message" 12 | "github.com/nu7hatch/gouuid" 13 | 14 | "github.com/monzo/mercury" 15 | "github.com/monzo/mercury/marshaling" 16 | "github.com/monzo/mercury/transport" 17 | ) 18 | 19 | const defaultTimeout = 10 * 
time.Second 20 | 21 | var ( 22 | defaultMiddleware []ClientMiddleware 23 | defaultMiddlewareM sync.RWMutex 24 | ) 25 | 26 | type clientCall struct { 27 | uid string // unique identifier within a client 28 | req mercury.Request // may be nil in the case of a request marshaling failure 29 | rsp mercury.Response // set when a response is received 30 | rspProto interface{} // shared with rsp when unmarshaled 31 | err *terrors.Error // execution error or unmarshalled (remote) error 32 | } 33 | 34 | type client struct { 35 | sync.RWMutex 36 | calls map[string]clientCall // uid: call 37 | doneC chan struct{} // closed when execution has finished; immutable 38 | execC chan struct{} // closed when execution begins; immutable 39 | execOnce sync.Once // ensures execution only happens once 40 | timeout time.Duration // default: defaultTimeout 41 | trans transport.Transport // defaults to the global default 42 | middleware []ClientMiddleware // applied in-order for requests, reverse-order for responses 43 | } 44 | 45 | func NewClient() Client { 46 | defaultMiddlewareM.RLock() 47 | middleware := defaultMiddleware 48 | defaultMiddlewareM.RUnlock() 49 | 50 | return &client{ 51 | calls: make(map[string]clientCall), 52 | doneC: make(chan struct{}), 53 | execC: make(chan struct{}), 54 | timeout: defaultTimeout, 55 | middleware: middleware, 56 | } 57 | } 58 | 59 | func SetDefaultMiddleware(middleware []ClientMiddleware) { 60 | defaultMiddlewareM.Lock() 61 | defer defaultMiddlewareM.Unlock() 62 | defaultMiddleware = middleware 63 | } 64 | 65 | func (c *client) transport() transport.Transport { 66 | // Callers must hold (at least) a read lock 67 | if c.trans != nil { 68 | return c.trans 69 | } else { 70 | return transport.DefaultTransport() 71 | } 72 | } 73 | 74 | func (c *client) addCall(cc clientCall) { 75 | select { 76 | case <-c.execC: 77 | log.Warn(cc.req, "[Mercury:Client] Request added after client execution; discarding") 78 | return 79 | default: 80 | } 81 | 82 | c.Lock() 83 | 
defer c.Unlock() 84 | c.calls[cc.uid] = cc 85 | } 86 | 87 | func (c *client) Add(cl Call) Client { 88 | cc := clientCall{ 89 | uid: cl.Uid, 90 | rspProto: cl.Response, 91 | } 92 | req, err := cl.Request() 93 | if err != nil { 94 | cc.err = terrors.Wrap(err, nil).(*terrors.Error) 95 | } else { 96 | cc.req = req 97 | } 98 | c.addCall(cc) 99 | return c 100 | } 101 | 102 | func (c *client) AddRequest(uid string, req mercury.Request) Client { 103 | c.addCall(clientCall{ 104 | uid: uid, 105 | req: req, 106 | rspProto: nil, 107 | }) 108 | return c 109 | } 110 | 111 | func (c *client) Errors() ErrorSet { 112 | c.RLock() 113 | defer c.RUnlock() 114 | errs := ErrorSet(nil) 115 | for uid, call := range c.calls { 116 | if call.err != nil { 117 | err := *(call.err) // Modify a copy 118 | copyParams := make(map[string]string, len(err.Params)) 119 | for k, v := range err.Params { 120 | copyParams[k] = v 121 | } 122 | err.Params = copyParams 123 | err.Params[errUidField] = uid 124 | if call.req != nil { 125 | if err.Params[errServiceField] == "" { 126 | err.Params[errServiceField] = call.req.Service() 127 | } 128 | if err.Params[errEndpointField] == "" { 129 | err.Params[errEndpointField] = call.req.Endpoint() 130 | } 131 | } 132 | errs = append(errs, &err) 133 | } 134 | } 135 | return errs 136 | } 137 | 138 | func (c *client) unmarshaler(rsp mercury.Response, protocol interface{}) tmsg.Unmarshaler { 139 | result := marshaling.Unmarshaler(rsp.Headers()[marshaling.ContentTypeHeader], protocol) 140 | if result == nil { // Default to json 141 | result = marshaling.Unmarshaler(marshaling.JSONContentType, protocol) 142 | } 143 | return result 144 | } 145 | 146 | // performCall executes a single Call, unmarshals the response (if there is a response proto), and pushes the updted 147 | // clientCall down the response channel 148 | func (c *client) performCall(call clientCall, middleware []ClientMiddleware, trans transport.Transport, 149 | timeout time.Duration, completion chan<- 
clientCall) { 150 | 151 | req := call.req 152 | 153 | // Ensure we have a request ID before the request middleware is executed 154 | if id := req.Id(); id == "" { 155 | _uuid, err := uuid.NewV4() 156 | if err != nil { 157 | log.Error(call.req, "[Mercury:Client] Failed to generate request uuid: %v", err) 158 | call.err = terrors.Wrap(err, nil).(*terrors.Error) 159 | completion <- call 160 | return 161 | } 162 | req.SetId(_uuid.String()) 163 | } 164 | 165 | // Apply request middleware 166 | for _, md := range middleware { 167 | req = md.ProcessClientRequest(req) 168 | } 169 | 170 | rsp_, err := trans.Send(req, timeout) 171 | if err != nil { 172 | call.err = terrors.Wrap(err, nil).(*terrors.Error) 173 | } else if rsp_ != nil { 174 | rsp := mercury.FromTyphonResponse(rsp_) 175 | 176 | // For error responses, unmarshal the error, leaving the call's response nil 177 | if rsp.IsError() { 178 | errRsp := rsp.Copy() 179 | if unmarshalErr := c.unmarshaler(rsp, &tperrors.Error{}).UnmarshalPayload(errRsp); unmarshalErr != nil { 180 | call.err = terrors.WrapWithCode(unmarshalErr, nil, terrors.ErrBadResponse).(*terrors.Error) 181 | } else { 182 | err := errRsp.Body().(*tperrors.Error) 183 | call.err = terrors.Unmarshal(err) 184 | } 185 | 186 | // Set the response Body to a nil – but typed – interface to avoid type conversion panics if Body 187 | // properties are accessed in spite of the error 188 | // Relevant: http://golang.org/doc/faq#nil_error 189 | if call.rspProto != nil { 190 | bodyT := reflect.TypeOf(call.rspProto) 191 | rsp.SetBody(reflect.New(bodyT.Elem()).Interface()) 192 | } 193 | 194 | } else if call.rspProto != nil { 195 | rsp.SetBody(call.rspProto) 196 | if err := c.unmarshaler(rsp, call.rspProto).UnmarshalPayload(rsp); err != nil { 197 | call.err = terrors.WrapWithCode(err, nil, terrors.ErrBadResponse).(*terrors.Error) 198 | } 199 | } 200 | 201 | call.rsp = rsp 202 | } 203 | 204 | // Apply response/error middleware (in reverse order) 205 | for i := 
len(middleware) - 1; i >= 0; i-- { 206 | mw := middleware[i] 207 | if call.err != nil { 208 | mw.ProcessClientError(call.err, call.req) 209 | } else { 210 | call.rsp = mw.ProcessClientResponse(call.rsp, call.req) 211 | } 212 | } 213 | 214 | completion <- call 215 | } 216 | 217 | // exec actually executes the requests; called by Go() within a sync.Once. 218 | func (c *client) exec() { 219 | defer close(c.doneC) 220 | 221 | c.RLock() 222 | timeout := c.timeout 223 | calls := c.calls // We don't need to make a copy as calls cannot be mutated once execution begins 224 | trans := c.transport() 225 | middleware := c.middleware 226 | c.RUnlock() 227 | 228 | completedCallsC := make(chan clientCall, len(calls)) 229 | for _, call := range calls { 230 | if call.err != nil { 231 | completedCallsC <- call 232 | continue 233 | } else if trans == nil { 234 | call.err = terrors.InternalService("no_transport", "Client has no transport", nil) 235 | completedCallsC <- call 236 | continue 237 | } 238 | go c.performCall(call, middleware, trans, timeout, completedCallsC) 239 | } 240 | 241 | // Collect completed calls into a new map 242 | completedCalls := make(map[string]clientCall, cap(completedCallsC)) 243 | for i := 0; i < cap(completedCallsC); i++ { 244 | call := <-completedCallsC 245 | completedCalls[call.uid] = call 246 | } 247 | close(completedCallsC) 248 | 249 | c.Lock() 250 | defer c.Unlock() 251 | c.calls = completedCalls 252 | } 253 | 254 | func (c *client) Go() Client { 255 | c.execOnce.Do(func() { 256 | close(c.execC) 257 | go c.exec() 258 | }) 259 | return c 260 | } 261 | 262 | func (c *client) WaitC() <-chan struct{} { 263 | return c.doneC 264 | } 265 | 266 | func (c *client) Wait() Client { 267 | <-c.WaitC() 268 | return c 269 | } 270 | 271 | func (c *client) Execute() Client { 272 | return c.Go().Wait() 273 | } 274 | 275 | func (c *client) Response(uid string) mercury.Response { 276 | c.RLock() 277 | defer c.RUnlock() 278 | if call, ok := c.calls[uid]; ok { 279 | return 
call.rsp 280 | } 281 | return nil 282 | } 283 | 284 | func (c *client) SetTimeout(to time.Duration) Client { 285 | c.Lock() 286 | defer c.Unlock() 287 | c.timeout = to 288 | return c 289 | } 290 | 291 | func (c *client) SetTransport(trans transport.Transport) Client { 292 | c.Lock() 293 | defer c.Unlock() 294 | c.trans = trans 295 | return c 296 | } 297 | 298 | func (c *client) Middleware() []ClientMiddleware { 299 | c.RLock() 300 | defer c.RUnlock() 301 | return c.middleware 302 | } 303 | 304 | func (c *client) AddMiddleware(md ClientMiddleware) Client { 305 | c.Lock() 306 | defer c.Unlock() 307 | c.middleware = append(c.middleware, md) 308 | return c 309 | } 310 | 311 | func (c *client) SetMiddleware(mds []ClientMiddleware) Client { 312 | c.Lock() 313 | defer c.Unlock() 314 | c.middleware = mds 315 | return c 316 | } 317 | -------------------------------------------------------------------------------- /server/srv_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "encoding/json" 5 | "math/rand" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | "github.com/monzo/terrors" 11 | pe "github.com/monzo/terrors/proto" 12 | tmsg "github.com/monzo/typhon/message" 13 | "github.com/monzo/typhon/mock" 14 | "github.com/stretchr/testify/suite" 15 | 16 | "github.com/monzo/mercury" 17 | "github.com/monzo/mercury/marshaling" 18 | "github.com/monzo/mercury/testproto" 19 | "github.com/monzo/mercury/transport" 20 | ) 21 | 22 | const testServiceName = "service.server-example" 23 | 24 | func TestServerSuite_MockTransport(t *testing.T) { 25 | suite.Run(t, &serverSuite{ 26 | TransF: func() transport.Transport { 27 | return mock.NewTransport() 28 | }, 29 | }) 30 | } 31 | 32 | type serverSuite struct { 33 | suite.Suite 34 | TransF func() transport.Transport 35 | trans transport.Transport 36 | server Server 37 | } 38 | 39 | func (suite *serverSuite) SetupSuite() { 40 | // Share a single transport between all tests 
(which each use a different server). This deliberately tests the 41 | // underlying transport's ability to connect and disconnect a service while leaving the connection open. 42 | suite.trans = suite.TransF() 43 | select { 44 | case <-suite.trans.Ready(): 45 | case <-time.After(5 * time.Second): 46 | panic("transport not ready") 47 | } 48 | } 49 | 50 | func (suite *serverSuite) SetupTest() { 51 | suite.server = NewServer(testServiceName) 52 | suite.server.Start(suite.trans) 53 | } 54 | 55 | func (suite *serverSuite) TearDownTest() { 56 | suite.server.Stop() 57 | suite.server = nil 58 | } 59 | 60 | func (suite *serverSuite) TearDownSuite() { 61 | suite.trans.Tomb().Killf("Test ending") 62 | suite.trans.Tomb().Wait() 63 | suite.trans = nil 64 | } 65 | 66 | // TestRouting verifies a registered endpoint receives messages destined for it, and that responses are sent 67 | // appropriately 68 | func (suite *serverSuite) TestRouting() { 69 | srv := suite.server 70 | srv.AddEndpoints(Endpoint{ 71 | Name: "dummy", 72 | Request: new(testproto.DummyRequest), 73 | Response: new(testproto.DummyResponse), 74 | Handler: func(req mercury.Request) (mercury.Response, error) { 75 | request := req.Body().(*testproto.DummyRequest) 76 | rsp := req.Response(&testproto.DummyResponse{ 77 | Pong: request.Ping, 78 | }) 79 | rsp.SetHeader("X-Ping-Pong", request.Ping) 80 | return rsp, nil 81 | }}) 82 | 83 | req := mercury.NewRequest() 84 | req.SetService(testServiceName) 85 | req.SetEndpoint("dummy") 86 | req.SetBody(&testproto.DummyRequest{ 87 | Ping: "routing"}) 88 | suite.Assert().NoError(tmsg.JSONMarshaler().MarshalBody(req)) 89 | 90 | rsp, err := suite.trans.Send(req, time.Second) 91 | suite.Require().NoError(err) 92 | suite.Require().NotNil(rsp) 93 | 94 | suite.Require().NoError(tmsg.JSONUnmarshaler(new(testproto.DummyResponse)).UnmarshalPayload(rsp)) 95 | suite.Require().NotNil(rsp.Body()) 96 | suite.Require().IsType(new(testproto.DummyResponse), rsp.Body()) 97 | response := 
rsp.Body().(*testproto.DummyResponse) 98 | suite.Assert().Equal("routing", response.Pong) 99 | suite.Assert().Equal("routing", rsp.Headers()["X-Ping-Pong"]) 100 | } 101 | 102 | // TestErrorResponse tests that errors are serialised and returned to callers appropriately (as we are using the 103 | // transport directly here, we actually get a response containing an error, not a transport error) 104 | func (suite *serverSuite) TestErrorResponse() { 105 | srv := suite.server 106 | srv.AddEndpoints(Endpoint{ 107 | Name: "err", 108 | Request: new(testproto.DummyRequest), 109 | Response: new(testproto.DummyResponse), 110 | Handler: func(req mercury.Request) (mercury.Response, error) { 111 | request := req.Body().(*testproto.DummyRequest) 112 | return nil, terrors.NotFound("", request.Ping, nil) 113 | }}) 114 | 115 | req := mercury.NewRequest() 116 | req.SetService(testServiceName) 117 | req.SetEndpoint("err") 118 | req.SetBody(&testproto.DummyRequest{ 119 | Ping: "Foo bar baz"}) 120 | suite.Assert().NoError(tmsg.JSONMarshaler().MarshalBody(req)) 121 | 122 | rsp_, err := suite.trans.Send(req, time.Second) 123 | suite.Assert().NoError(err) 124 | suite.Assert().NotNil(rsp_) 125 | rsp := mercury.FromTyphonResponse(rsp_) 126 | suite.Assert().True(rsp.IsError()) 127 | 128 | errResponse := &pe.Error{} 129 | suite.Assert().NoError(json.Unmarshal(rsp.Payload(), errResponse)) 130 | terr := terrors.Unmarshal(errResponse) 131 | suite.Require().NotNil(terr) 132 | suite.Assert().Equal("Foo bar baz", terr.Message, string(rsp.Payload())) 133 | suite.Assert().Equal(terrors.ErrNotFound, terr.Code) 134 | } 135 | 136 | // TestNilResponse tests that a nil response correctly returns a Response with an empty payload to the caller 137 | func (suite *serverSuite) TestNilResponse() { 138 | srv := suite.server 139 | srv.AddEndpoints(Endpoint{ 140 | Name: "nil", 141 | Request: new(testproto.DummyRequest), 142 | Response: new(testproto.DummyResponse), 143 | Handler: func(req mercury.Request) 
(mercury.Response, error) { 144 | return nil, nil 145 | }}) 146 | 147 | req := mercury.NewRequest() 148 | req.SetService(testServiceName) 149 | req.SetBody(&testproto.DummyRequest{}) 150 | suite.Assert().NoError(tmsg.JSONMarshaler().MarshalBody(req)) 151 | req.SetEndpoint("nil") 152 | 153 | rsp, err := suite.trans.Send(req, time.Second) 154 | suite.Require().NoError(err) 155 | suite.Require().NotNil(rsp) 156 | suite.Assert().Len(rsp.Payload(), 0) 157 | } 158 | 159 | // TestEndpointNotFound tests that a Bad Request error is correctly returned on receiving an event for an unknown 160 | // endpoing 161 | func (suite *serverSuite) TestEndpointNotFound() { 162 | req := mercury.NewRequest() 163 | req.SetService(testServiceName) 164 | req.SetEndpoint("dummy") 165 | req.SetBody(&testproto.DummyRequest{ 166 | Ping: "routing"}) 167 | suite.Assert().NoError(tmsg.JSONMarshaler().MarshalBody(req)) 168 | 169 | rsp_, err := suite.trans.Send(req, time.Second) 170 | rsp := mercury.FromTyphonResponse(rsp_) 171 | suite.Require().NoError(err) 172 | suite.Require().NotNil(rsp) 173 | suite.Assert().True(rsp.IsError()) 174 | 175 | suite.Assert().NoError(tmsg.JSONUnmarshaler(new(pe.Error)).UnmarshalPayload(rsp)) 176 | suite.Assert().IsType(new(pe.Error), rsp.Body()) 177 | terr := terrors.Unmarshal(rsp.Body().(*pe.Error)) 178 | suite.Assert().Equal(terrors.ErrBadRequest+".endpoint_not_found", terr.Code) 179 | suite.Assert().Contains(terr.Error(), "Endpoint not found") 180 | } 181 | 182 | // TestRegisteringInvalidEndpoint tests that appropriate panics are raised when registering invalid Endpoints 183 | func (suite *serverSuite) TestRegisteringInvalidEndpoint() { 184 | srv := suite.server 185 | 186 | // An endpoint with no handler 187 | suite.Assert().Panics(func() { 188 | srv.AddEndpoints(Endpoint{ 189 | Name: "foo", 190 | Request: new(testproto.DummyRequest), 191 | Response: new(testproto.DummyResponse)}) 192 | }) 193 | } 194 | 195 | // TestRoutingParallel sends a bunch of requests in 
parallel to different endpoints and checks that the responses match 196 | // correctly. 200 workers, 100 requests each. 197 | func (suite *serverSuite) TestRoutingParallel() { 198 | if testing.Short() { 199 | suite.T().Skip("Skipping TestRoutingParallel in short mode") 200 | } 201 | 202 | names := [...]string{"1", "2", "3"} 203 | srv := suite.server 204 | srv.AddEndpoints( 205 | Endpoint{ 206 | Name: names[0], 207 | Request: new(testproto.DummyRequest), 208 | Response: new(testproto.DummyResponse), 209 | Handler: func(req mercury.Request) (mercury.Response, error) { 210 | return req.Response(&testproto.DummyResponse{Pong: names[0]}), nil 211 | }}, 212 | Endpoint{ 213 | Name: names[1], 214 | Request: new(testproto.DummyRequest), 215 | Response: new(testproto.DummyResponse), 216 | Handler: func(req mercury.Request) (mercury.Response, error) { 217 | return req.Response(&testproto.DummyResponse{Pong: names[1]}), nil 218 | }}, 219 | Endpoint{ 220 | Name: names[2], 221 | Request: new(testproto.DummyRequest), 222 | Response: new(testproto.DummyResponse), 223 | Handler: func(req mercury.Request) (mercury.Response, error) { 224 | return req.Response(&testproto.DummyResponse{Pong: names[2]}), nil 225 | }}) 226 | 227 | workers := 200 228 | wg := sync.WaitGroup{} 229 | wg.Add(workers) 230 | unmarshaler := tmsg.JSONUnmarshaler(new(testproto.DummyResponse)) 231 | work := func(i int) { 232 | defer wg.Done() 233 | rng := rand.New(rand.NewSource(time.Now().UnixNano())) 234 | ep := names[rng.Int()%len(names)] 235 | for i := 0; i < 100; i++ { 236 | req := mercury.NewRequest() 237 | req.SetService(testServiceName) 238 | req.SetEndpoint(ep) 239 | req.SetBody(&testproto.DummyRequest{}) 240 | suite.Assert().NoError(tmsg.JSONMarshaler().MarshalBody(req)) 241 | 242 | rsp, err := suite.trans.Send(req, time.Second) 243 | suite.Require().NoError(err) 244 | suite.Require().NotNil(rsp) 245 | suite.Require().NoError(unmarshaler.UnmarshalPayload(rsp)) 246 | response := 
rsp.Body().(*testproto.DummyResponse) 247 | suite.Assert().Equal(ep, response.Pong) 248 | } 249 | } 250 | 251 | for i := 0; i < workers; i++ { 252 | go work(i) 253 | } 254 | wg.Wait() 255 | } 256 | 257 | // TestJSONResponse verifies that a request sent with JSON content and an accepting JSON response, does in fact yield a 258 | // JSON response (from a proto handler) 259 | func (suite *serverSuite) TestJSONResponse() { 260 | srv := suite.server 261 | srv.AddEndpoints(Endpoint{ 262 | Name: "dummy", 263 | Request: new(testproto.DummyRequest), 264 | Response: new(testproto.DummyResponse), 265 | Handler: func(req mercury.Request) (mercury.Response, error) { 266 | request := req.Body().(*testproto.DummyRequest) 267 | return req.Response(&testproto.DummyResponse{ 268 | Pong: request.Ping, 269 | }), nil 270 | }}) 271 | 272 | req := mercury.NewRequest() 273 | req.SetService(testServiceName) 274 | req.SetEndpoint("dummy") 275 | req.SetBody(map[string]string{ 276 | "ping": "json"}) 277 | req.SetHeader(marshaling.AcceptHeader, "application/json") 278 | suite.Assert().NoError(tmsg.JSONMarshaler().MarshalBody(req)) 279 | 280 | rsp, err := suite.trans.Send(req, time.Second) 281 | suite.Require().NoError(err) 282 | suite.Require().NotNil(rsp) 283 | 284 | suite.Assert().Equal("application/json", rsp.Headers()[marshaling.ContentTypeHeader]) 285 | suite.Assert().Equal(`{"pong":"json"}`, string(rsp.Payload())) 286 | } 287 | -------------------------------------------------------------------------------- /client/client_test.go: -------------------------------------------------------------------------------- 1 | package client 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | "time" 7 | 8 | "github.com/monzo/terrors" 9 | tmsg "github.com/monzo/typhon/message" 10 | "github.com/monzo/typhon/mock" 11 | "github.com/stretchr/testify/suite" 12 | 13 | "github.com/monzo/mercury" 14 | "github.com/monzo/mercury/marshaling" 15 | "github.com/monzo/mercury/testproto" 16 | 
"github.com/monzo/mercury/transport" 17 | ) 18 | 19 | const testServiceName = "service.client-example" 20 | 21 | func TestClientSuite_MockTransport(t *testing.T) { 22 | suite.Run(t, &clientSuite{ 23 | TransF: func() transport.Transport { 24 | return mock.NewTransport() 25 | }, 26 | }) 27 | } 28 | 29 | type clientSuite struct { 30 | suite.Suite 31 | TransF func() transport.Transport 32 | trans transport.Transport 33 | } 34 | 35 | func (suite *clientSuite) SetupSuite() { 36 | trans := suite.TransF() 37 | select { 38 | case <-trans.Ready(): 39 | case <-time.After(2 * time.Second): 40 | panic("transport not ready") 41 | } 42 | suite.trans = trans 43 | 44 | // Add a listener that responds blindly to all messages 45 | inboundChan := make(chan tmsg.Request, 10) 46 | trans.Listen(testServiceName, inboundChan) 47 | go func() { 48 | for { 49 | select { 50 | case _req := <-inboundChan: 51 | req := mercury.FromTyphonRequest(_req) 52 | switch req.Endpoint() { 53 | case "timeout": 54 | continue 55 | 56 | case "invalid-payload": 57 | // Wrong proto here 58 | rsp := req.Response(nil) 59 | rsp.SetPayload([]byte("†HÎß ßHøÜ¬∂ÑT ∑ø®K")) 60 | suite.Require().NoError(trans.Respond(req, rsp)) 61 | 62 | case "error": 63 | err := terrors.BadRequest("", "foo bar", nil) 64 | rsp := req.Response(terrors.Marshal(err)) 65 | rsp.SetHeaders(req.Headers()) 66 | rsp.SetIsError(true) 67 | suite.Require().NoError(trans.Respond(req, rsp)) 68 | 69 | case "bulls--t": 70 | rsp := req.Response(map[string]string{}) 71 | rsp.SetHeaders(req.Headers()) 72 | rsp.SetHeader(marshaling.ContentTypeHeader, "application/bulls--t") 73 | suite.Require().NoError(trans.Respond(req, rsp)) 74 | 75 | default: 76 | rsp := req.Response(&testproto.DummyResponse{ 77 | Pong: "Pong"}) 78 | rsp.SetHeaders(req.Headers()) 79 | suite.Require().NoError(tmsg.JSONMarshaler().MarshalBody(rsp)) 80 | suite.Require().NoError(trans.Respond(req, rsp)) 81 | } 82 | 83 | case <-trans.Tomb().Dying(): 84 | return 85 | } 86 | } 87 | }() 88 | } 89 
| 90 | func (suite *clientSuite) TearDownSuite() { 91 | trans := suite.trans 92 | trans.Tomb().Killf("Test ending") 93 | trans.Tomb().Wait() 94 | suite.trans = nil 95 | } 96 | 97 | // TestExecuting tests an end-to-end flow of one request 98 | func (suite *clientSuite) TestExecuting() { 99 | response := new(testproto.DummyResponse) 100 | client := NewClient().Add(Call{ 101 | Uid: "call1", 102 | Service: testServiceName, 103 | Endpoint: "foo", 104 | Response: response, 105 | }).SetTransport(suite.trans).Execute() 106 | 107 | rsp := client.Response("call1") 108 | 109 | suite.Assert().Empty(client.Errors()) 110 | suite.Require().NotNil(rsp) 111 | suite.Assert().Equal("Pong", response.Pong) 112 | suite.Assert().Equal(response, rsp.Body()) 113 | suite.Assert().Equal("Pong", rsp.Body().(*testproto.DummyResponse).Pong) 114 | } 115 | 116 | // TestTimeout verifies the timeout functionality of the client behaves as expected (especially with multiple calls, 117 | // some of which succeed and some of which fail). 118 | func (suite *clientSuite) TestTimeout() { 119 | client := NewClient().Add(Call{ 120 | Uid: "call1", 121 | Service: testServiceName, 122 | Endpoint: "timeout", 123 | Response: new(testproto.DummyResponse), 124 | }).SetTransport(suite.trans).SetTimeout(time.Second).Go() 125 | 126 | select { 127 | case <-client.WaitC(): 128 | case <-time.After(time.Second + 50*time.Millisecond): 129 | suite.Fail("Should have timed out") 130 | } 131 | 132 | suite.Assert().Len(client.Errors(), 1) 133 | err := client.Errors().ForUid("call1") 134 | suite.Assert().Error(err) 135 | suite.Assert().Equal(terrors.ErrTimeout, err.Code, err.Message) 136 | } 137 | 138 | // TestRawRequest verifies that adding raw requests (rather than Calls) works as expected. 139 | 140 | // TestResponseUnmarshalingError verifies that unmarshaling errors are handled appropriately (in this case by expecting 141 | // a different response protocol to what is received). 
142 | // 143 | // This also conveniently verifies that Clients use custom transports appropriately. 144 | func (suite *clientSuite) TestResponseUnmarshalingError() { 145 | client := NewClient().Add(Call{ 146 | Uid: "call1", 147 | Service: testServiceName, 148 | Endpoint: "invalid-payload", 149 | Response: new(testproto.DummyResponse), 150 | }). 151 | SetTimeout(time.Second). 152 | SetTransport(suite.trans). 153 | Execute() 154 | 155 | suite.Assert().Len(client.Errors(), 1) 156 | err := client.Errors().ForUid("call1") 157 | suite.Assert().Equal(terrors.ErrBadResponse, err.Code) 158 | 159 | rsp := client.Response("call1") 160 | suite.Require().NotNil(rsp) 161 | response := rsp.Body().(*testproto.DummyResponse) 162 | suite.Assert().Equal("", response.Pong) 163 | } 164 | 165 | type testMw struct { 166 | err *terrors.Error 167 | } 168 | 169 | func (m *testMw) ProcessClientRequest(req mercury.Request) mercury.Request { 170 | req.SetHeader("X-Foo", "X-Bar") 171 | return req 172 | } 173 | 174 | func (m *testMw) ProcessClientResponse(rsp mercury.Response, req mercury.Request) mercury.Response { 175 | rsp.SetHeader("X-Boop", "Boop") 176 | return rsp 177 | } 178 | 179 | func (m *testMw) ProcessClientError(err *terrors.Error, req mercury.Request) { 180 | m.err = err 181 | } 182 | 183 | // TestMiddleware verifies client middleware methods are executed as expected 184 | func (suite *clientSuite) TestMiddleware() { 185 | mw := &testMw{} 186 | client := NewClient(). 187 | AddMiddleware(mw). 188 | Add(Call{ 189 | Uid: "call1", 190 | Service: testServiceName, 191 | Endpoint: "ping", 192 | Response: new(testproto.DummyResponse), 193 | }). 194 | SetTimeout(time.Second). 195 | SetTransport(suite.trans). 
196 | Execute() 197 | 198 | suite.Assert().Empty(client.Errors()) 199 | rsp := client.Response("call1") 200 | suite.Require().NotNil(rsp) 201 | // ProcessClientRequest should have set X-Foo: Bar (and ping echoes the headers) 202 | suite.Assert().Equal("X-Bar", rsp.Headers()["X-Foo"]) 203 | // ProcessClientResponse should have set X-Boop: Boop 204 | suite.Assert().Equal("Boop", rsp.Headers()["X-Boop"]) 205 | suite.Assert().Nil(mw.err) 206 | client = NewClient(). 207 | AddMiddleware(mw). 208 | Add(Call{ 209 | Uid: "call1", 210 | Service: testServiceName, 211 | Endpoint: "error", 212 | Response: new(testproto.DummyResponse), 213 | }). 214 | SetTimeout(time.Second). 215 | SetTransport(suite.trans). 216 | Execute() 217 | 218 | rsp = client.Response("call1") 219 | suite.Require().NotNil(rsp) 220 | suite.Assert().Len(client.Errors(), 1) 221 | err := client.Errors().ForUid("call1") 222 | suite.Require().Error(err) 223 | // ProcessClientError should have stored the error 224 | suite.Assert().Equal(err.Code, mw.err.Code) 225 | suite.Assert().Equal(err.Message, mw.err.Message) 226 | // ProcessClientRequest should have set X-Foo: Bar (and ping echoes the headers) 227 | suite.Assert().Equal("X-Bar", rsp.Headers()["X-Foo"]) 228 | // ProcessClientResponse should not have run 229 | suite.Assert().Empty(rsp.Headers()["X-Boop"]) 230 | } 231 | 232 | // TestParallelCalls verifies that many calls made in parallel are routed correctly, and their responses/errors are 233 | // available in the proper places. 234 | func (suite *clientSuite) TestParallelCalls() { 235 | client := NewClient(). 236 | SetTimeout(5 * time.Second). 
237 | SetTransport(suite.trans) 238 | 239 | for i := 0; i < 100; i++ { 240 | uid := fmt.Sprintf("call%d", i) 241 | client = client.Add(Call{ 242 | Uid: uid, 243 | Service: testServiceName, 244 | Endpoint: "foo", 245 | Response: new(testproto.DummyResponse), 246 | Headers: map[string]string{ 247 | "Iteration": uid}}) 248 | } 249 | 250 | client.Execute() 251 | suite.Require().Empty(client.Errors()) 252 | 253 | for i := 0; i < 100; i++ { 254 | uid := fmt.Sprintf("call%d", i) 255 | rsp := client.Response(uid) 256 | suite.Assert().Equal(uid, rsp.Headers()["Iteration"]) 257 | } 258 | } 259 | 260 | type bsMarshaler struct{} 261 | 262 | func (m bsMarshaler) MarshalBody(msg tmsg.Message) error { 263 | msg.SetPayload([]byte("total garbage")) 264 | return nil 265 | } 266 | 267 | func (m bsMarshaler) UnmarshalPayload(msg tmsg.Message) error { 268 | msg.SetBody(map[string]string{ 269 | "1": "2", 270 | }) 271 | return nil 272 | } 273 | 274 | // TestCustomMarshaler registers a custom marshaler and then checks a request can be made using it 275 | func (suite *clientSuite) TestCustomMarshaler() { 276 | marshaling.Register( 277 | "application/bulls--t", 278 | func() tmsg.Marshaler { return bsMarshaler{} }, 279 | func(_ interface{}) tmsg.Unmarshaler { return bsMarshaler{} }, 280 | ) 281 | 282 | cl := NewClient(). 283 | Add(Call{ 284 | Uid: "foo", 285 | Service: testServiceName, 286 | Endpoint: "bulls--t", 287 | Body: map[string]string{}, 288 | Response: map[string]string{}, 289 | Headers: map[string]string{ 290 | marshaling.ContentTypeHeader: "application/bulls--t", 291 | marshaling.AcceptHeader: "application/bulls--t"}}). 292 | SetTransport(suite.trans). 
293 | SetTimeout(time.Second) 294 | 295 | suite.Require().NoError(cl.Execute().Errors().Combined()) 296 | rsp := cl.Response("foo") 297 | suite.Require().NotNil(rsp) 298 | suite.Require().IsType(map[string]string{}, rsp.Body()) 299 | suite.Require().Equal(map[string]string{ 300 | "1": "2", 301 | }, rsp.Body().(map[string]string)) 302 | } 303 | 304 | // TestInvalidBody verifies that an incorrect type passed as the `Body` returns a "bad request" error 305 | func (suite *clientSuite) TestInvalidBody() { 306 | cl := NewClient().Add(Call{ 307 | Uid: "call", 308 | Service: "notathing", // We would get a timeout if the service *did not* exist 309 | Endpoint: "reallynotathing", 310 | Body: make(chan struct{}), 311 | Response: &testproto.DummyResponse{}, 312 | }). 313 | SetTransport(suite.trans) 314 | 315 | err := cl.Execute().Errors().ForUid("call") 316 | suite.Require().Error(err) 317 | suite.Assert().Equal(terrors.ErrBadRequest, err.Code) 318 | } 319 | 320 | // TestEmpty verifies that an empty call-set results in no errors 321 | func (suite *clientSuite) TestEmpty() { 322 | cl := NewClient() 323 | suite.Require().Empty(cl.Execute().Errors()) 324 | } 325 | --------------------------------------------------------------------------------