├── go.mod ├── .travis.yml ├── server_jsonrpc_test.go ├── LICENSE ├── internal └── svc │ └── svc.go ├── README.md ├── debug.go ├── client_test.go ├── server_util_test.go ├── client.go ├── server_test.go └── server.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/keegancsmith/rpc 2 | 3 | go 1.18 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | sudo: false 3 | go: 4 | - 1.x 5 | -------------------------------------------------------------------------------- /server_jsonrpc_test.go: -------------------------------------------------------------------------------- 1 | package rpc 2 | 3 | import ( 4 | "io" 5 | ) 6 | 7 | type HTTPReadWriteCloser struct { 8 | In io.Reader 9 | Out io.Writer 10 | } 11 | 12 | func (c *HTTPReadWriteCloser) Read(p []byte) (n int, err error) { return c.In.Read(p) } 13 | func (c *HTTPReadWriteCloser) Write(d []byte) (n int, err error) { return c.Out.Write(d) } 14 | func (c *HTTPReadWriteCloser) Close() error { return nil } 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2009 The Go Authors. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are 5 | met: 6 | 7 | * Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above 10 | copyright notice, this list of conditions and the following disclaimer 11 | in the documentation and/or other materials provided with the 12 | distribution. 13 | * Neither the name of Google Inc. 
nor the names of its 14 | contributors may be used to endorse or promote products derived from 15 | this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 21 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 23 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 25 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /internal/svc/svc.go: -------------------------------------------------------------------------------- 1 | package svc 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | ) 7 | 8 | // Pending manages a map of all pending requests to a rpc.Service for a 9 | // connection (an rpc.ServerCodec). 10 | type Pending struct { 11 | mu sync.Mutex 12 | m map[uint64]context.CancelFunc // seq -> cancel 13 | parent context.Context 14 | } 15 | 16 | func NewPending(parent context.Context) *Pending { 17 | return &Pending{ 18 | m: make(map[uint64]context.CancelFunc), 19 | parent: parent, 20 | } 21 | } 22 | 23 | func (s *Pending) Start(seq uint64) context.Context { 24 | ctx, cancel := context.WithCancel(s.parent) 25 | s.mu.Lock() 26 | // we assume seq is not already in map. If not, the client is broken. 
27 | s.m[seq] = cancel 28 | s.mu.Unlock() 29 | return ctx 30 | } 31 | 32 | func (s *Pending) Cancel(seq uint64) { 33 | s.mu.Lock() 34 | cancel, ok := s.m[seq] 35 | if ok { 36 | delete(s.m, seq) 37 | } 38 | s.mu.Unlock() 39 | if ok { 40 | cancel() 41 | } 42 | } 43 | 44 | type CancelArgs struct { 45 | // Seq is the sequence number for the rpc.Call to cancel. 46 | Seq uint64 47 | 48 | // pending is the DS used by rpc.Server to track the ongoing calls for 49 | // this connection. It should not be set by the client, the Service will 50 | // set it. 51 | pending *Pending 52 | } 53 | 54 | // SetPending sets the pending map for the server to use. Do not use on the 55 | // client. 56 | func (a *CancelArgs) SetPending(p *Pending) { 57 | a.pending = p 58 | } 59 | 60 | // GoRPC is an internal service used by rpc. 61 | type GoRPC struct{} 62 | 63 | func (s *GoRPC) Cancel(ctx context.Context, args *CancelArgs, reply *bool) error { 64 | args.pending.Cancel(args.Seq) 65 | return nil 66 | } 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # rpc [![Build Status](https://travis-ci.org/keegancsmith/rpc.svg?branch=master)](https://travis-ci.org/keegancsmith/rpc) 2 | 3 | This is a fork of the stdlib [net/rpc](https://golang.org/pkg/net/rpc/) which 4 | is frozen. It adds support for `context.Context` on the client and server, 5 | including propagating cancellation. 6 | 7 | The API is exactly the same, except `Client.Call` takes a `context.Context`, 8 | and Server methods are expected to take a `context.Context` as the first 9 | argument. Additionally the wire protocol is unchanged, so is backwards 10 | compatible with `net/rpc` clients. 11 | 12 | `DialHTTPPathTimeout` function is also added. A future release of rpc may 13 | update all Dial functions to instead take a context. 14 | 15 | `ClientTrace` functionality is also added. 
This is for hooking into the rpc 16 | client to enable tracing. 17 | 18 | ## Why use net/rpc 19 | 20 | There are many alternatives for RPC in Go, the most popular being 21 | [GRPC](https://grpc.io/). However, `net/rpc` has the following nice 22 | properties: 23 | 24 | - Nice API 25 | - No need for IDL 26 | - Good performance 27 | 28 | The nice API is subjective. However, the API is small, simple and composable. 29 | which makes it quite powerful. IDL tools are things like GRPC requiring protoc 30 | to generate go code from the protobuf files. `net/rpc` has no third party 31 | dependencies nor code generation step, simplify the use of it. A benchmark 32 | done on the [6 Sep 33 | 2016](https://github.com/golang/go/issues/16844#issuecomment-245261755) 34 | indicated `net/rpc` was 4x faster than GRPC. This is an outdated benchmark, 35 | but is an indication at the surprisingly good performance `net/rpc` provides. 36 | 37 | For more discussion on the pros and cons of `net/rpc` see the issue [proposal: 38 | freeze net/rpc](https://github.com/golang/go/issues/16844). 39 | 40 | ## Details 41 | 42 | Last forked from commit 43 | [bfadd78986](https://github.com/golang/go/commit/bfadd78986) on 7 September 44 | 2022. 45 | 46 | Cancellation implemented via the rpc call `_goRPC_.Cancel`. 47 | -------------------------------------------------------------------------------- /debug.go: -------------------------------------------------------------------------------- 1 | // Copyright 2009 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package rpc 6 | 7 | /* 8 | Some HTML presented at http://machine:port/debug/rpc 9 | Lists services, their methods, and some statistics, still rudimentary. 10 | */ 11 | 12 | import ( 13 | "fmt" 14 | "html/template" 15 | "net/http" 16 | "sort" 17 | ) 18 | 19 | const debugText = ` 20 | 21 | Services 22 | {{range .}} 23 |
24 | Service {{.Name}} 25 |
26 | 27 | 28 | {{range .Method}} 29 | 30 | 31 | 32 | 33 | {{end}} 34 |
MethodCalls
{{.Name}}({{.Type.ArgType}}, {{.Type.ReplyType}}) error{{.Type.NumCalls}}
35 | {{end}} 36 | 37 | ` 38 | 39 | var debug = template.Must(template.New("RPC debug").Parse(debugText)) 40 | 41 | // If set, print log statements for internal and I/O errors. 42 | var debugLog = false 43 | 44 | type debugMethod struct { 45 | Type *methodType 46 | Name string 47 | } 48 | 49 | type methodArray []debugMethod 50 | 51 | type debugService struct { 52 | Service *service 53 | Name string 54 | Method methodArray 55 | } 56 | 57 | type serviceArray []debugService 58 | 59 | func (s serviceArray) Len() int { return len(s) } 60 | func (s serviceArray) Less(i, j int) bool { return s[i].Name < s[j].Name } 61 | func (s serviceArray) Swap(i, j int) { s[i], s[j] = s[j], s[i] } 62 | 63 | func (m methodArray) Len() int { return len(m) } 64 | func (m methodArray) Less(i, j int) bool { return m[i].Name < m[j].Name } 65 | func (m methodArray) Swap(i, j int) { m[i], m[j] = m[j], m[i] } 66 | 67 | type debugHTTP struct { 68 | *Server 69 | } 70 | 71 | // Runs at /debug/rpc 72 | func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { 73 | // Build a sorted version of the data. 74 | var services serviceArray 75 | server.serviceMap.Range(func(snamei, svci any) bool { 76 | svc := svci.(*service) 77 | ds := debugService{svc, snamei.(string), make(methodArray, 0, len(svc.method))} 78 | for mname, method := range svc.method { 79 | ds.Method = append(ds.Method, debugMethod{method, mname}) 80 | } 81 | sort.Sort(ds.Method) 82 | services = append(services, ds) 83 | return true 84 | }) 85 | sort.Sort(services) 86 | err := debug.Execute(w, services) 87 | if err != nil { 88 | fmt.Fprintln(w, "rpc: error executing template:", err.Error()) 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2014 The Go Authors. All rights reserved. 
2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package rpc 6 | 7 | import ( 8 | "context" 9 | "errors" 10 | "fmt" 11 | "net" 12 | "strings" 13 | "testing" 14 | ) 15 | 16 | type shutdownCodec struct { 17 | responded chan int 18 | closed bool 19 | } 20 | 21 | func (c *shutdownCodec) WriteRequest(*Request, any) error { return nil } 22 | func (c *shutdownCodec) ReadResponseBody(any) error { return nil } 23 | func (c *shutdownCodec) ReadResponseHeader(*Response) error { 24 | c.responded <- 1 25 | return errors.New("shutdownCodec ReadResponseHeader") 26 | } 27 | func (c *shutdownCodec) Close() error { 28 | c.closed = true 29 | return nil 30 | } 31 | 32 | func TestCloseCodec(t *testing.T) { 33 | codec := &shutdownCodec{responded: make(chan int)} 34 | client := NewClientWithCodec(codec) 35 | <-codec.responded 36 | client.Close() 37 | if !codec.closed { 38 | t.Error("client.Close did not close codec") 39 | } 40 | } 41 | 42 | // Test that errors in gob shut down the connection. Issue 7689. 43 | 44 | type R struct { 45 | msg []byte // Not exported, so R does not work with gob. 
46 | } 47 | 48 | type S struct{} 49 | 50 | func (s *S) Recv(ctx context.Context, nul *struct{}, reply *R) error { 51 | *reply = R{[]byte("foo")} 52 | return nil 53 | } 54 | 55 | func TestGobError(t *testing.T) { 56 | defer func() { 57 | err := recover() 58 | if err == nil { 59 | t.Fatal("no error") 60 | } 61 | if !strings.Contains(err.(error).Error(), "reading body unexpected EOF") { 62 | t.Fatal("expected `reading body unexpected EOF', got", err) 63 | } 64 | }() 65 | Register(new(S)) 66 | 67 | listen, err := net.Listen("tcp", "127.0.0.1:0") 68 | if err != nil { 69 | panic(err) 70 | } 71 | go Accept(listen) 72 | 73 | client, err := Dial("tcp", listen.Addr().String()) 74 | if err != nil { 75 | panic(err) 76 | } 77 | 78 | var reply Reply 79 | err = client.Call(context.Background(), "S.Recv", &struct{}{}, &reply) 80 | if err != nil { 81 | panic(err) 82 | } 83 | 84 | fmt.Printf("%#v\n", reply) 85 | client.Close() 86 | 87 | listen.Close() 88 | } 89 | 90 | type ClientCodecError struct { 91 | WriteRequestError error 92 | } 93 | 94 | func (c *ClientCodecError) WriteRequest(*Request, any) error { 95 | return c.WriteRequestError 96 | } 97 | func (c *ClientCodecError) ReadResponseHeader(*Response) error { 98 | return nil 99 | } 100 | func (c *ClientCodecError) ReadResponseBody(any) error { 101 | return nil 102 | } 103 | func (c *ClientCodecError) Close() error { 104 | return nil 105 | } 106 | 107 | func TestClientTrace(t *testing.T) { 108 | wantErr := errors.New("test") 109 | client := NewClientWithCodec(&ClientCodecError{WriteRequestError: wantErr}) 110 | defer client.Close() 111 | 112 | startCalled := false 113 | var gotErr error 114 | ctx := WithClientTrace(context.Background(), &ClientTrace{ 115 | WriteRequestStart: func() { startCalled = true }, 116 | WriteRequestDone: func(err error) { gotErr = err }, 117 | }) 118 | 119 | var reply Reply 120 | err := client.Call(ctx, "S.Recv", &struct{}{}, &reply) 121 | if err != wantErr { 122 | t.Fatalf("expected Call to return the 
same error sent to ClientTrace.WriteRequestDone: want %v, got %v", wantErr, err) 123 | } 124 | if gotErr != wantErr { 125 | t.Fatalf("expected ClientTrace.WriteRequestDone to be called with error %v, got %v", wantErr, gotErr) 126 | } 127 | if !startCalled { 128 | t.Fatal("expected ClientTrace.WriteRequestStart to be called") 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /server_util_test.go: -------------------------------------------------------------------------------- 1 | // This code is all of https://golang.org/src/net/rpc/jsonrpc/server.go and some of 2 | // https://golang.org/src/net/rpc/jsonrpc/client.go (both adjusted to use the fork). 3 | // 4 | // Unfortunately but logically the net/rpc/jsonrpc uses net/rpc types which are 5 | // incompatible with this fork, so the code could not be used as-is. 6 | // 7 | // Copyright 2010 The Go Authors. All rights reserved. 8 | // Use of this source code is governed by a BSD-style 9 | // license that can be found in the LICENSE file. 10 | // 11 | package rpc 12 | 13 | import ( 14 | "encoding/json" 15 | "errors" 16 | "io" 17 | "sync" 18 | ) 19 | 20 | var errMissingParams = errors.New("jsonrpc: request body missing params") 21 | 22 | type jsonServerCodec struct { 23 | dec *json.Decoder // for reading JSON values 24 | enc *json.Encoder // for writing JSON values 25 | c io.Closer 26 | 27 | // temporary work space 28 | req jsonServerRequest 29 | 30 | // JSON-RPC clients can use arbitrary json values as request IDs. 31 | // Package rpc expects uint64 request IDs. 32 | // We assign uint64 sequence numbers to incoming requests 33 | // but save the original request ID in the pending map. 34 | // When rpc responds, we use the sequence number in 35 | // the response to find the original request ID. 
36 | mutex sync.Mutex // protects seq, pending 37 | seq uint64 38 | pending map[uint64]*json.RawMessage 39 | } 40 | 41 | // NewServerCodec returns a new rpc.ServerCodec using JSON-RPC on conn. 42 | func NewJsonServerCodec(conn io.ReadWriteCloser) ServerCodec { 43 | return &jsonServerCodec{ 44 | dec: json.NewDecoder(conn), 45 | enc: json.NewEncoder(conn), 46 | c: conn, 47 | pending: make(map[uint64]*json.RawMessage), 48 | } 49 | } 50 | 51 | type jsonServerRequest struct { 52 | Method string `json:"method"` 53 | Params *json.RawMessage `json:"params"` 54 | Id *json.RawMessage `json:"id"` 55 | } 56 | 57 | func (r *jsonServerRequest) reset() { 58 | r.Method = "" 59 | r.Params = nil 60 | r.Id = nil 61 | } 62 | 63 | type jsonServerResponse struct { 64 | Id *json.RawMessage `json:"id"` 65 | Result any `json:"result"` 66 | Error any `json:"error"` 67 | } 68 | 69 | func (c *jsonServerCodec) ReadRequestHeader(r *Request) error { 70 | c.req.reset() 71 | if err := c.dec.Decode(&c.req); err != nil { 72 | return err 73 | } 74 | r.ServiceMethod = c.req.Method 75 | 76 | // JSON request id can be any JSON value; 77 | // RPC package expects uint64. Translate to 78 | // internal uint64 and save JSON on the side. 79 | c.mutex.Lock() 80 | c.seq++ 81 | c.pending[c.seq] = c.req.Id 82 | c.req.Id = nil 83 | r.Seq = c.seq 84 | c.mutex.Unlock() 85 | 86 | return nil 87 | } 88 | 89 | func (c *jsonServerCodec) ReadRequestBody(x any) error { 90 | if x == nil { 91 | return nil 92 | } 93 | if c.req.Params == nil { 94 | return errMissingParams 95 | } 96 | // JSON params is array value. 97 | // RPC params is struct. 98 | // Unmarshal into array containing struct for now. 99 | // Should think about making RPC more general. 
100 | var params [1]any 101 | params[0] = x 102 | return json.Unmarshal(*c.req.Params, ¶ms) 103 | } 104 | 105 | var null = json.RawMessage([]byte("null")) 106 | 107 | func (c *jsonServerCodec) WriteResponse(r *Response, x any) error { 108 | c.mutex.Lock() 109 | b, ok := c.pending[r.Seq] 110 | if !ok { 111 | c.mutex.Unlock() 112 | return errors.New("invalid sequence number in response") 113 | } 114 | delete(c.pending, r.Seq) 115 | c.mutex.Unlock() 116 | 117 | if b == nil { 118 | // Invalid request so no id. Use JSON null. 119 | b = &null 120 | } 121 | resp := jsonServerResponse{Id: b} 122 | if r.Error == "" { 123 | resp.Result = x 124 | } else { 125 | resp.Error = r.Error 126 | } 127 | return c.enc.Encode(resp) 128 | } 129 | 130 | func (c *jsonServerCodec) Close() error { 131 | return c.c.Close() 132 | } 133 | 134 | type jsonClientRequest struct { 135 | Method string `json:"method"` 136 | Params [1]any `json:"params"` 137 | Id uint64 `json:"id"` 138 | } 139 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | // Copyright 2009 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package rpc 6 | 7 | import ( 8 | "bufio" 9 | "context" 10 | "encoding/gob" 11 | "errors" 12 | "io" 13 | "log" 14 | "net" 15 | "net/http" 16 | "sync" 17 | "time" 18 | 19 | "github.com/keegancsmith/rpc/internal/svc" 20 | ) 21 | 22 | // ServerError represents an error that has been returned from 23 | // the remote side of the RPC connection. 24 | type ServerError string 25 | 26 | func (e ServerError) Error() string { 27 | return string(e) 28 | } 29 | 30 | var ErrShutdown = errors.New("connection is shut down") 31 | 32 | // Call represents an active RPC. 33 | type Call struct { 34 | ServiceMethod string // The name of the service and method to call. 
35 | Args any // The argument to the function (*struct). 36 | Reply any // The reply from the function (*struct). 37 | Error error // After completion, the error status. 38 | Done chan *Call // Receives *Call when Go is complete. 39 | seq uint64 // Sequence num used to send. Non-zero when sent. 40 | } 41 | 42 | // ClientTrace is a set of hooks to run at various stages of an outgoing RPC 43 | // request. Any particular hook may be nil. Functions may be called 44 | // concurrently from different goroutines. 45 | // 46 | // ClientTrace currently traces a single RPC request, not the response. 47 | type ClientTrace struct { 48 | // WriteRequestStart is called when start WriteRequest. Concurrent calls to 49 | // Client.Go or Client.Call can cause a queue to form, leading to a delay 50 | // between calls to Client.Go/Client.Call and WriteRequestStart being 51 | // called. 52 | WriteRequestStart func() 53 | 54 | // WriteRequestDone is called once WriteRequest returns with the error 55 | // returned from WriteRequest. 56 | WriteRequestDone func(err error) 57 | } 58 | 59 | // unique type to prevent assignment. 60 | type clientTraceContextKey struct{} 61 | 62 | // WithClientTrace returns a new context based on the provided parent 63 | // ctx. Requests made with the returned context will use the provided trace 64 | // hooks. Previous hooks registered with ctx are ignored. 65 | func WithClientTrace(ctx context.Context, trace *ClientTrace) context.Context { 66 | return context.WithValue(ctx, clientTraceContextKey{}, trace) 67 | } 68 | 69 | func contextClientTrace(ctx context.Context) *ClientTrace { 70 | trace, _ := ctx.Value(clientTraceContextKey{}).(*ClientTrace) 71 | return trace 72 | } 73 | 74 | // Client represents an RPC Client. 75 | // There may be multiple outstanding Calls associated 76 | // with a single Client, and a Client may be used by 77 | // multiple goroutines simultaneously. 
type Client struct {
	codec ClientCodec

	reqMutex sync.Mutex // protects following
	request  Request

	mutex    sync.Mutex // protects following
	seq      uint64
	pending  map[uint64]*Call
	closing  bool // user has called Close
	shutdown bool // server has told us to stop
}

// A ClientCodec implements writing of RPC requests and
// reading of RPC responses for the client side of an RPC session.
// The client calls WriteRequest to write a request to the connection
// and calls ReadResponseHeader and ReadResponseBody in pairs
// to read responses. The client calls Close when finished with the
// connection. ReadResponseBody may be called with a nil
// argument to force the body of the response to be read and then
// discarded.
// See NewClient's comment for information about concurrent access.
type ClientCodec interface {
	WriteRequest(*Request, any) error
	ReadResponseHeader(*Response) error
	ReadResponseBody(any) error

	Close() error
}

// send registers call under a fresh sequence number and writes the request.
// reqMutex is held for the whole function, so writes through client.request
// and the codec are serialized. Lock order is reqMutex, then mutex (input
// uses the same order).
//
// Completion paths:
//   - client is shut down or closing: the call completes with ErrShutdown;
//   - call.seq is already non-zero: the call completes with context.Canceled;
//   - WriteRequest fails: the call is removed from pending and completes with
//     the write error, unless it was already removed elsewhere.
//
// Otherwise the call stays in pending until input() delivers the response.
func (client *Client) send(ctx context.Context, call *Call) {
	trace := contextClientTrace(ctx)

	client.reqMutex.Lock()
	defer client.reqMutex.Unlock()

	// Register this call.
	client.mutex.Lock()
	if client.shutdown || client.closing {
		client.mutex.Unlock()
		call.Error = ErrShutdown
		call.done()
		return
	}
	// NOTE(review): Client.Call sets seq to 1 to block a not-yet-sent call, but
	// GoContext calls send synchronously before Call can observe ctx.Done, so
	// this branch appears unreachable from the code visible here. Confirm.
	if call.seq != 0 {
		// It has already been canceled, don't bother sending
		call.Error = context.Canceled
		client.mutex.Unlock()
		call.done()
		return
	}
	client.seq++
	seq := client.seq
	call.seq = seq
	client.pending[seq] = call
	client.mutex.Unlock()

	if trace != nil && trace.WriteRequestStart != nil {
		trace.WriteRequestStart()
	}

	// Encode and send the request.
	client.request.Seq = seq
	client.request.ServiceMethod = call.ServiceMethod
	err := client.codec.WriteRequest(&client.request, call.Args)
	if trace != nil && trace.WriteRequestDone != nil {
		trace.WriteRequestDone(err)
	}
	if err != nil {
		// Re-read pending: input() or Client.Call may already have removed
		// (and completed) this call, in which case it must not complete twice.
		client.mutex.Lock()
		call = client.pending[seq]
		delete(client.pending, seq)
		client.mutex.Unlock()
		if call != nil {
			call.Error = err
			call.done()
		}
	}
}

// input is the Client's read loop, started by NewClientWithCodec. For each
// response it removes the matching call from pending and completes it. On the
// first read error it marks the client shut down and fails every remaining
// pending call with that error. io.EOF is mapped to ErrShutdown if the user
// called Close, otherwise to io.ErrUnexpectedEOF.
func (client *Client) input() {
	var err error
	var response Response
	for err == nil {
		response = Response{}
		err = client.codec.ReadResponseHeader(&response)
		if err != nil {
			break
		}
		seq := response.Seq
		client.mutex.Lock()
		call := client.pending[seq]
		delete(client.pending, seq)
		client.mutex.Unlock()

		switch {
		case call == nil:
			// We've got no pending call. That usually means that
			// WriteRequest partially failed, and call was already
			// removed; response is a server telling us about an
			// error reading request body. We should still attempt
			// to read error body, but there's no one to give it to.
			// (Client.Call also removes a call from pending once its
			// context is done, so late responses land here too.)
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
		case response.Error != "":
			// We've got an error response. Give this to the request;
			// any subsequent requests will get the ReadResponseBody
			// error if there is one.
			call.Error = ServerError(response.Error)
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
			call.done()
		default:
			err = client.codec.ReadResponseBody(call.Reply)
			if err != nil {
				call.Error = errors.New("reading body " + err.Error())
			}
			call.done()
		}
	}
	// Terminate pending calls.
	client.reqMutex.Lock()
	client.mutex.Lock()
	client.shutdown = true
	closing := client.closing
	if err == io.EOF {
		if closing {
			err = ErrShutdown
		} else {
			err = io.ErrUnexpectedEOF
		}
	}
	for _, call := range client.pending {
		call.Error = err
		call.done()
	}
	client.mutex.Unlock()
	client.reqMutex.Unlock()
	// NOTE(review): err was rewritten from io.EOF just above, so the
	// err != io.EOF check is always true here.
	if debugLog && err != io.EOF && !closing {
		log.Println("rpc: client protocol error:", err)
	}
}

// done delivers call on call.Done without blocking. If the channel is full,
// the notification is dropped (and logged when debugLog is set).
func (call *Call) done() {
	select {
	case call.Done <- call:
		// ok
	default:
		// We don't want to block here. It is the caller's responsibility to make
		// sure the channel has enough buffer space. See comment in Go().
		if debugLog {
			log.Println("rpc: discarding Call reply due to insufficient Done chan capacity")
		}
	}
}

// NewClient returns a new Client to handle requests to the
// set of services at the other end of the connection.
// It adds a buffer to the write side of the connection so
// the header and payload are sent as a unit.
//
// The read and write halves of the connection are serialized independently,
// so no interlocking is required. However each half may be accessed
// concurrently so the implementation of conn should protect against
// concurrent reads or concurrent writes.
func NewClient(conn io.ReadWriteCloser) *Client {
	encBuf := bufio.NewWriter(conn)
	client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
	return NewClientWithCodec(client)
}

// NewClientWithCodec is like NewClient but uses the specified
// codec to encode requests and decode responses.
255 | func NewClientWithCodec(codec ClientCodec) *Client { 256 | client := &Client{ 257 | codec: codec, 258 | pending: make(map[uint64]*Call), 259 | } 260 | go client.input() 261 | return client 262 | } 263 | 264 | type gobClientCodec struct { 265 | rwc io.ReadWriteCloser 266 | dec *gob.Decoder 267 | enc *gob.Encoder 268 | encBuf *bufio.Writer 269 | } 270 | 271 | func (c *gobClientCodec) WriteRequest(r *Request, body any) (err error) { 272 | if err = c.enc.Encode(r); err != nil { 273 | return 274 | } 275 | if err = c.enc.Encode(body); err != nil { 276 | return 277 | } 278 | return c.encBuf.Flush() 279 | } 280 | 281 | func (c *gobClientCodec) ReadResponseHeader(r *Response) error { 282 | return c.dec.Decode(r) 283 | } 284 | 285 | func (c *gobClientCodec) ReadResponseBody(body any) error { 286 | return c.dec.Decode(body) 287 | } 288 | 289 | func (c *gobClientCodec) Close() error { 290 | return c.rwc.Close() 291 | } 292 | 293 | // DialHTTP connects to an HTTP RPC server at the specified network address 294 | // listening on the default HTTP RPC path. 295 | func DialHTTP(network, address string) (*Client, error) { 296 | return DialHTTPPath(network, address, DefaultRPCPath) 297 | } 298 | 299 | // DialHTTPPath connects to an HTTP RPC server 300 | // at the specified network address and path with a default timeout. 301 | func DialHTTPPath(network, address, path string) (*Client, error) { 302 | return DialHTTPPathTimeout(network, address, path, 0) 303 | } 304 | 305 | // DialHTTPPathTimeout connects to an HTTP RPC server 306 | // at the specified network address and path with the specified timeout for Dialing. 
307 | // 308 | // This is a function added by github.com/keegancsmith/rpc 309 | func DialHTTPPathTimeout(network, address, path string, timeout time.Duration) (*Client, error) { 310 | conn, err := net.DialTimeout(network, address, timeout) 311 | if err != nil { 312 | return nil, err 313 | } 314 | io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n") 315 | 316 | // Require successful HTTP response 317 | // before switching to RPC protocol. 318 | resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) 319 | if err == nil && resp.Status == connected { 320 | return NewClient(conn), nil 321 | } 322 | if err == nil { 323 | err = errors.New("unexpected HTTP response: " + resp.Status) 324 | } 325 | conn.Close() 326 | return nil, &net.OpError{ 327 | Op: "dial-http", 328 | Net: network + " " + address, 329 | Addr: nil, 330 | Err: err, 331 | } 332 | } 333 | 334 | // Dial connects to an RPC server at the specified network address. 335 | func Dial(network, address string) (*Client, error) { 336 | conn, err := net.Dial(network, address) 337 | if err != nil { 338 | return nil, err 339 | } 340 | return NewClient(conn), nil 341 | } 342 | 343 | // Close calls the underlying codec's Close method. If the connection is already 344 | // shutting down, ErrShutdown is returned. 345 | func (client *Client) Close() error { 346 | client.mutex.Lock() 347 | if client.closing { 348 | client.mutex.Unlock() 349 | return ErrShutdown 350 | } 351 | client.closing = true 352 | client.mutex.Unlock() 353 | return client.codec.Close() 354 | } 355 | 356 | // Go calls client.GoContext with a background context. See GoContext docstring. 357 | func (client *Client) Go(serviceMethod string, args any, reply any, done chan *Call) *Call { 358 | return client.GoContext(context.Background(), serviceMethod, args, reply, done) 359 | } 360 | 361 | // GoContext invokes the function asynchronously. It returns the Call structure representing 362 | // the invocation. 
The done channel will signal when the call is complete by returning 363 | // the same Call object. If done is nil, Go will allocate a new channel. 364 | // If non-nil, done must be buffered or Go will deliberately crash. 365 | func (client *Client) GoContext(ctx context.Context, serviceMethod string, args any, reply any, done chan *Call) *Call { 366 | call := new(Call) 367 | call.ServiceMethod = serviceMethod 368 | call.Args = args 369 | call.Reply = reply 370 | if done == nil { 371 | done = make(chan *Call, 10) // buffered. 372 | } else { 373 | // If caller passes done != nil, it must arrange that 374 | // done has enough buffer for the number of simultaneous 375 | // RPCs that will be using that channel. If the channel 376 | // is totally unbuffered, it's best not to run at all. 377 | if cap(done) == 0 { 378 | log.Panic("rpc: done channel is unbuffered") 379 | } 380 | } 381 | call.Done = done 382 | client.send(ctx, call) 383 | return call 384 | } 385 | 386 | // Call invokes the named function, waits for it to complete, and returns its error status. 
387 | func (client *Client) Call(ctx context.Context, serviceMethod string, args any, reply any) error { 388 | ch := make(chan *Call, 2) // 2 for this call and cancel 389 | call := client.GoContext(ctx, serviceMethod, args, reply, ch) 390 | select { 391 | case <-call.Done: 392 | return call.Error 393 | case <-ctx.Done(): 394 | // Cancel the pending request on the client 395 | client.mutex.Lock() 396 | seq := call.seq 397 | _, ok := client.pending[seq] 398 | delete(client.pending, seq) 399 | if seq == 0 { 400 | // hasn't been sent yet, non-zero will prevent send 401 | call.seq = 1 402 | } 403 | client.mutex.Unlock() 404 | 405 | // Cancel running request on the server 406 | if seq != 0 && ok { 407 | client.Go("_goRPC_.Cancel", &svc.CancelArgs{Seq: seq}, nil, ch) 408 | } 409 | 410 | return ctx.Err() 411 | } 412 | } 413 | -------------------------------------------------------------------------------- /server_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2009 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package rpc 6 | 7 | import ( 8 | "bytes" 9 | "context" 10 | "encoding/json" 11 | "errors" 12 | "fmt" 13 | "io" 14 | "log" 15 | "net" 16 | "net/http" 17 | "net/http/httptest" 18 | "net/url" 19 | "reflect" 20 | "runtime" 21 | "strings" 22 | "sync" 23 | "sync/atomic" 24 | "testing" 25 | "time" 26 | ) 27 | 28 | var ( 29 | newServer *Server 30 | serverAddr, newServerAddr string 31 | httpServerAddr string 32 | once, newOnce, httpOnce sync.Once 33 | ) 34 | 35 | const ( 36 | newHttpPath = "/foo" 37 | ) 38 | 39 | type Args struct { 40 | A, B int 41 | } 42 | 43 | type Reply struct { 44 | C int 45 | } 46 | 47 | type Arith int 48 | 49 | // Some of Arith's methods have value args, some have pointer args. That's deliberate. 
50 | 51 | func (t *Arith) Add(ctx context.Context, args Args, reply *Reply) error { 52 | reply.C = args.A + args.B 53 | return nil 54 | } 55 | 56 | func (t *Arith) Mul(ctx context.Context, args *Args, reply *Reply) error { 57 | reply.C = args.A * args.B 58 | return nil 59 | } 60 | 61 | func (t *Arith) Div(ctx context.Context, args Args, reply *Reply) error { 62 | if args.B == 0 { 63 | return errors.New("divide by zero") 64 | } 65 | reply.C = args.A / args.B 66 | return nil 67 | } 68 | 69 | func (t *Arith) String(ctx context.Context, args *Args, reply *string) error { 70 | *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) 71 | return nil 72 | } 73 | 74 | func (t *Arith) Scan(ctx context.Context, args string, reply *Reply) (err error) { 75 | _, err = fmt.Sscan(args, &reply.C) 76 | return 77 | } 78 | 79 | func (t *Arith) Error(ctx context.Context, args *Args, reply *Reply) error { 80 | panic("ERROR") 81 | } 82 | 83 | func (t *Arith) SleepMilli(ctx context.Context, args *Args, reply *Reply) error { 84 | time.Sleep(time.Duration(args.A) * time.Millisecond) 85 | return nil 86 | } 87 | 88 | type hidden int 89 | 90 | func (t *hidden) Exported(ctx context.Context, args Args, reply *Reply) error { 91 | reply.C = args.A + args.B 92 | return nil 93 | } 94 | 95 | type Embed struct { 96 | hidden 97 | } 98 | 99 | type BuiltinTypes struct{} 100 | 101 | func (BuiltinTypes) Map(ctx context.Context, args *Args, reply *map[int]int) error { 102 | (*reply)[args.A] = args.B 103 | return nil 104 | } 105 | 106 | func (BuiltinTypes) Slice(ctx context.Context, args *Args, reply *[]int) error { 107 | *reply = append(*reply, args.A, args.B) 108 | return nil 109 | } 110 | 111 | func (BuiltinTypes) Array(ctx context.Context, args *Args, reply *[2]int) error { 112 | (*reply)[0] = args.A 113 | (*reply)[1] = args.B 114 | return nil 115 | } 116 | 117 | func listenTCP() (net.Listener, string) { 118 | l, e := net.Listen("tcp", "127.0.0.1:0") // any available address 119 | if e != nil { 
120 | log.Fatalf("net.Listen tcp :0: %v", e) 121 | } 122 | return l, l.Addr().String() 123 | } 124 | 125 | func startServer() { 126 | Register(new(Arith)) 127 | Register(new(Embed)) 128 | RegisterName("net.rpc.Arith", new(Arith)) 129 | Register(BuiltinTypes{}) 130 | 131 | var l net.Listener 132 | l, serverAddr = listenTCP() 133 | log.Println("Test RPC server listening on", serverAddr) 134 | go Accept(l) 135 | 136 | HandleHTTP() 137 | httpOnce.Do(startHttpServer) 138 | } 139 | 140 | func startNewServer() { 141 | newServer = NewServer() 142 | newServer.Register(new(Arith)) 143 | newServer.Register(new(Embed)) 144 | newServer.RegisterName("net.rpc.Arith", new(Arith)) 145 | newServer.RegisterName("newServer.Arith", new(Arith)) 146 | 147 | var l net.Listener 148 | l, newServerAddr = listenTCP() 149 | log.Println("NewServer test RPC server listening on", newServerAddr) 150 | go newServer.Accept(l) 151 | 152 | newServer.HandleHTTP(newHttpPath, "/bar") 153 | httpOnce.Do(startHttpServer) 154 | } 155 | 156 | func startHttpServer() { 157 | server := httptest.NewServer(nil) 158 | httpServerAddr = server.Listener.Addr().String() 159 | log.Println("Test HTTP RPC server listening on", httpServerAddr) 160 | } 161 | 162 | func TestRPC(t *testing.T) { 163 | once.Do(startServer) 164 | testRPC(t, serverAddr) 165 | newOnce.Do(startNewServer) 166 | testRPC(t, newServerAddr) 167 | testNewServerRPC(t, newServerAddr) 168 | } 169 | 170 | func testRPC(t *testing.T, addr string) { 171 | client, err := Dial("tcp", addr) 172 | if err != nil { 173 | t.Fatal("dialing", err) 174 | } 175 | defer client.Close() 176 | 177 | // Synchronous calls 178 | ctx := context.Background() 179 | args := &Args{7, 8} 180 | reply := new(Reply) 181 | err = client.Call(ctx, "Arith.Add", args, reply) 182 | if err != nil { 183 | t.Errorf("Add: expected no error but got string %q", err.Error()) 184 | } 185 | if reply.C != args.A+args.B { 186 | t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 187 | } 188 | 189 
| // Methods exported from unexported embedded structs 190 | args = &Args{7, 0} 191 | reply = new(Reply) 192 | err = client.Call(ctx, "Embed.Exported", args, reply) 193 | if err != nil { 194 | t.Errorf("Add: expected no error but got string %q", err.Error()) 195 | } 196 | if reply.C != args.A+args.B { 197 | t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 198 | } 199 | 200 | // Nonexistent method 201 | args = &Args{7, 0} 202 | reply = new(Reply) 203 | err = client.Call(ctx, "Arith.BadOperation", args, reply) 204 | // expect an error 205 | if err == nil { 206 | t.Error("BadOperation: expected error") 207 | } else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") { 208 | t.Errorf("BadOperation: expected can't find method error; got %q", err) 209 | } 210 | 211 | // Unknown service 212 | args = &Args{7, 8} 213 | reply = new(Reply) 214 | err = client.Call(ctx, "Arith.Unknown", args, reply) 215 | if err == nil { 216 | t.Error("expected error calling unknown service") 217 | } else if !strings.Contains(err.Error(), "method") { 218 | t.Error("expected error about method; got", err) 219 | } 220 | 221 | // Out of order. 
222 | args = &Args{7, 8} 223 | mulReply := new(Reply) 224 | mulCall := client.Go("Arith.Mul", args, mulReply, nil) 225 | addReply := new(Reply) 226 | addCall := client.Go("Arith.Add", args, addReply, nil) 227 | 228 | addCall = <-addCall.Done 229 | if addCall.Error != nil { 230 | t.Errorf("Add: expected no error but got string %q", addCall.Error.Error()) 231 | } 232 | if addReply.C != args.A+args.B { 233 | t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B) 234 | } 235 | 236 | mulCall = <-mulCall.Done 237 | if mulCall.Error != nil { 238 | t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error()) 239 | } 240 | if mulReply.C != args.A*args.B { 241 | t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B) 242 | } 243 | 244 | // Error test 245 | args = &Args{7, 0} 246 | reply = new(Reply) 247 | err = client.Call(ctx, "Arith.Div", args, reply) 248 | // expect an error: zero divide 249 | if err == nil { 250 | t.Error("Div: expected error") 251 | } else if err.Error() != "divide by zero" { 252 | t.Error("Div: expected divide by zero error; got", err) 253 | } 254 | 255 | // Bad type. 
256 | reply = new(Reply) 257 | err = client.Call(ctx, "Arith.Add", reply, reply) // args, reply would be the correct thing to use 258 | if err == nil { 259 | t.Error("expected error calling Arith.Add with wrong arg type") 260 | } else if !strings.Contains(err.Error(), "type") { 261 | t.Error("expected error about type; got", err) 262 | } 263 | 264 | // Non-struct argument 265 | const Val = 12345 266 | str := fmt.Sprint(Val) 267 | reply = new(Reply) 268 | err = client.Call(ctx, "Arith.Scan", &str, reply) 269 | if err != nil { 270 | t.Errorf("Scan: expected no error but got string %q", err.Error()) 271 | } else if reply.C != Val { 272 | t.Errorf("Scan: expected %d got %d", Val, reply.C) 273 | } 274 | 275 | // Non-struct reply 276 | args = &Args{27, 35} 277 | str = "" 278 | err = client.Call(ctx, "Arith.String", args, &str) 279 | if err != nil { 280 | t.Errorf("String: expected no error but got string %q", err.Error()) 281 | } 282 | expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) 283 | if str != expect { 284 | t.Errorf("String: expected %s got %s", expect, str) 285 | } 286 | 287 | args = &Args{7, 8} 288 | reply = new(Reply) 289 | err = client.Call(ctx, "Arith.Mul", args, reply) 290 | if err != nil { 291 | t.Errorf("Mul: expected no error but got string %q", err.Error()) 292 | } 293 | if reply.C != args.A*args.B { 294 | t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) 295 | } 296 | 297 | // ServiceName contain "." 
character 298 | args = &Args{7, 8} 299 | reply = new(Reply) 300 | err = client.Call(ctx, "net.rpc.Arith.Add", args, reply) 301 | if err != nil { 302 | t.Errorf("Add: expected no error but got string %q", err.Error()) 303 | } 304 | if reply.C != args.A+args.B { 305 | t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 306 | } 307 | } 308 | 309 | func testNewServerRPC(t *testing.T, addr string) { 310 | client, err := Dial("tcp", addr) 311 | if err != nil { 312 | t.Fatal("dialing", err) 313 | } 314 | defer client.Close() 315 | 316 | ctx := context.Background() 317 | 318 | // Synchronous calls 319 | args := &Args{7, 8} 320 | reply := new(Reply) 321 | err = client.Call(ctx, "newServer.Arith.Add", args, reply) 322 | if err != nil { 323 | t.Errorf("Add: expected no error but got string %q", err.Error()) 324 | } 325 | if reply.C != args.A+args.B { 326 | t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 327 | } 328 | } 329 | 330 | func TestHTTP(t *testing.T) { 331 | once.Do(startServer) 332 | testHTTPRPC(t, "") 333 | newOnce.Do(startNewServer) 334 | testHTTPRPC(t, newHttpPath) 335 | } 336 | 337 | func testHTTPRPC(t *testing.T, path string) { 338 | var client *Client 339 | var err error 340 | if path == "" { 341 | client, err = DialHTTP("tcp", httpServerAddr) 342 | } else { 343 | client, err = DialHTTPPath("tcp", httpServerAddr, path) 344 | } 345 | if err != nil { 346 | t.Fatal("dialing", err) 347 | } 348 | defer client.Close() 349 | 350 | ctx := context.Background() 351 | 352 | // Synchronous calls 353 | args := &Args{7, 8} 354 | reply := new(Reply) 355 | err = client.Call(ctx, "Arith.Add", args, reply) 356 | if err != nil { 357 | t.Errorf("Add: expected no error but got string %q", err.Error()) 358 | } 359 | if reply.C != args.A+args.B { 360 | t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 361 | } 362 | } 363 | 364 | func TestBuiltinTypes(t *testing.T) { 365 | once.Do(startServer) 366 | 367 | client, err := DialHTTP("tcp", httpServerAddr) 368 
| if err != nil { 369 | t.Fatal("dialing", err) 370 | } 371 | defer client.Close() 372 | 373 | ctx := context.Background() 374 | 375 | // Map 376 | args := &Args{7, 8} 377 | replyMap := map[int]int{} 378 | err = client.Call(ctx, "BuiltinTypes.Map", args, &replyMap) 379 | if err != nil { 380 | t.Errorf("Map: expected no error but got string %q", err.Error()) 381 | } 382 | if replyMap[args.A] != args.B { 383 | t.Errorf("Map: expected %d got %d", args.B, replyMap[args.A]) 384 | } 385 | 386 | // Slice 387 | args = &Args{7, 8} 388 | replySlice := []int{} 389 | err = client.Call(ctx, "BuiltinTypes.Slice", args, &replySlice) 390 | if err != nil { 391 | t.Errorf("Slice: expected no error but got string %q", err.Error()) 392 | } 393 | if e := []int{args.A, args.B}; !reflect.DeepEqual(replySlice, e) { 394 | t.Errorf("Slice: expected %v got %v", e, replySlice) 395 | } 396 | 397 | // Array 398 | args = &Args{7, 8} 399 | replyArray := [2]int{} 400 | err = client.Call(ctx, "BuiltinTypes.Array", args, &replyArray) 401 | if err != nil { 402 | t.Errorf("Array: expected no error but got string %q", err.Error()) 403 | } 404 | if e := [2]int{args.A, args.B}; !reflect.DeepEqual(replyArray, e) { 405 | t.Errorf("Array: expected %v got %v", e, replyArray) 406 | } 407 | } 408 | 409 | type Context struct { 410 | started chan struct{} 411 | done chan struct{} 412 | } 413 | 414 | func (t *Context) Wait(ctx context.Context, s string, reply *int) error { 415 | close(t.started) 416 | <-ctx.Done() 417 | close(t.done) 418 | return nil 419 | } 420 | 421 | func TestContext(t *testing.T) { 422 | svc := &Context{ 423 | started: make(chan struct{}), 424 | done: make(chan struct{}), 425 | } 426 | 427 | handler := NewServer() 428 | handler.Register(svc) 429 | ts := httptest.NewServer(handler) 430 | defer ts.Close() 431 | u, err := url.Parse(ts.URL) 432 | if err != nil { 433 | t.Fatal(err) 434 | } 435 | cl, err := DialHTTP("tcp", u.Host) 436 | if err != nil { 437 | t.Fatal(err) 438 | } 439 | defer 
cl.Close() 440 | 441 | wait := func(desc string, c chan struct{}) { 442 | select { 443 | case <-c: 444 | return 445 | case <-time.After(5 * time.Second): 446 | t.Fatal("Failed to wait for", desc) 447 | } 448 | } 449 | 450 | ctx, cancel := context.WithCancel(context.Background()) 451 | done := make(chan struct{}) 452 | go func() { 453 | args := "" 454 | err = cl.Call(ctx, "Context.Wait", &args, new(int)) 455 | close(done) 456 | }() 457 | wait("server side to be called", svc.started) 458 | cancel() 459 | wait("client side to be done after cancel", done) 460 | wait("server side to be done after cancel", svc.done) 461 | if err != context.Canceled { 462 | t.Fatalf("expected to fail due to context cancellation: %v", err) 463 | } 464 | } 465 | 466 | type JsonServer struct { 467 | srv *Server 468 | } 469 | 470 | func (server *JsonServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { 471 | rwc := &HTTPReadWriteCloser{ 472 | In: req.Body, 473 | Out: w, 474 | } 475 | codec := NewJsonServerCodec(rwc) 476 | server.srv.ServeRequestContext(req.Context(), codec) 477 | } 478 | 479 | func TestContextCodec(t *testing.T) { 480 | svc := &Context{ 481 | started: make(chan struct{}), 482 | done: make(chan struct{}), 483 | } 484 | 485 | srv := NewServer() 486 | srv.Register(svc) 487 | handler := &JsonServer{srv} 488 | ts := httptest.NewServer(handler) 489 | defer ts.Close() 490 | 491 | wait := func(desc string, c chan struct{}) { 492 | t.Helper() 493 | select { 494 | case <-c: 495 | return 496 | case <-time.After(5 * time.Second): 497 | t.Fatal("Failed to wait for", desc) 498 | } 499 | } 500 | 501 | b, err := json.Marshal(&jsonClientRequest{ 502 | Method: "Context.Wait", 503 | Params: [1]any{""}, 504 | Id: 1234, 505 | }) 506 | if err != nil { 507 | t.Fatal(err) 508 | } 509 | 510 | ctx, cancel := context.WithCancel(context.Background()) 511 | done := make(chan struct{}) 512 | go func() { 513 | defer close(done) 514 | var req *http.Request 515 | req, err = 
http.NewRequestWithContext(ctx, http.MethodPost, ts.URL, bytes.NewBuffer(b)) 516 | if err != nil { 517 | return 518 | } 519 | req.Header.Set("Content-Type", "application/json") 520 | _, err = http.DefaultClient.Do(req) 521 | }() 522 | 523 | wait("server side to be called", svc.started) 524 | cancel() 525 | wait("client side to be done after cancel", done) 526 | wait("server side to be done after cancel", svc.done) 527 | if !errors.Is(err, context.Canceled) { 528 | t.Fatalf("expected to fail due to context cancellation: %v", err) 529 | } 530 | } 531 | 532 | // CodecEmulator provides a client-like api and a ServerCodec interface. 533 | // Can be used to test ServeRequest. 534 | type CodecEmulator struct { 535 | server *Server 536 | serviceMethod string 537 | args *Args 538 | reply *Reply 539 | err error 540 | } 541 | 542 | func (codec *CodecEmulator) Call(ctx context.Context, serviceMethod string, args *Args, reply *Reply) error { 543 | codec.serviceMethod = serviceMethod 544 | codec.args = args 545 | codec.reply = reply 546 | codec.err = nil 547 | var serverError error 548 | if codec.server == nil { 549 | serverError = ServeRequest(codec) 550 | } else { 551 | serverError = codec.server.ServeRequest(codec) 552 | } 553 | if codec.err == nil && serverError != nil { 554 | codec.err = serverError 555 | } 556 | return codec.err 557 | } 558 | 559 | func (codec *CodecEmulator) ReadRequestHeader(req *Request) error { 560 | req.ServiceMethod = codec.serviceMethod 561 | req.Seq = 0 562 | return nil 563 | } 564 | 565 | func (codec *CodecEmulator) ReadRequestBody(argv any) error { 566 | if codec.args == nil { 567 | return io.ErrUnexpectedEOF 568 | } 569 | *(argv.(*Args)) = *codec.args 570 | return nil 571 | } 572 | 573 | func (codec *CodecEmulator) WriteResponse(resp *Response, reply any) error { 574 | if resp.Error != "" { 575 | codec.err = errors.New(resp.Error) 576 | } else { 577 | *codec.reply = *(reply.(*Reply)) 578 | } 579 | return nil 580 | } 581 | 582 | func (codec 
*CodecEmulator) Close() error { 583 | return nil 584 | } 585 | 586 | func TestServeRequest(t *testing.T) { 587 | once.Do(startServer) 588 | testServeRequest(t, nil) 589 | newOnce.Do(startNewServer) 590 | testServeRequest(t, newServer) 591 | } 592 | 593 | func testServeRequest(t *testing.T, server *Server) { 594 | client := CodecEmulator{server: server} 595 | defer client.Close() 596 | 597 | ctx := context.Background() 598 | args := &Args{7, 8} 599 | reply := new(Reply) 600 | err := client.Call(ctx, "Arith.Add", args, reply) 601 | if err != nil { 602 | t.Errorf("Add: expected no error but got string %q", err.Error()) 603 | } 604 | if reply.C != args.A+args.B { 605 | t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 606 | } 607 | 608 | err = client.Call(ctx, "Arith.Add", nil, reply) 609 | if err == nil { 610 | t.Errorf("expected error calling Arith.Add with nil arg") 611 | } 612 | } 613 | 614 | type NeedsCtx int 615 | type ReplyNotPointer int 616 | type ArgNotPublic int 617 | type ReplyNotPublic int 618 | type NeedsPtrType int 619 | type local struct{} 620 | 621 | func (t *NeedsCtx) NeedsCtx(args *Args, reply *Reply) error { 622 | return nil 623 | } 624 | 625 | func (t *ReplyNotPointer) ReplyNotPointer(ctx context.Context, args *Args, reply Reply) error { 626 | return nil 627 | } 628 | 629 | func (t *ArgNotPublic) ArgNotPublic(ctx context.Context, args *local, reply *Reply) error { 630 | return nil 631 | } 632 | 633 | func (t *ReplyNotPublic) ReplyNotPublic(ctx context.Context, args *Args, reply *local) error { 634 | return nil 635 | } 636 | 637 | func (t *NeedsPtrType) NeedsPtrType(ctx context.Context, args *Args, reply *Reply) error { 638 | return nil 639 | } 640 | 641 | // Check that registration handles lots of bad methods and a type with no suitable methods. 
642 | func TestRegistrationError(t *testing.T) { 643 | err := Register(new(NeedsCtx)) 644 | if err == nil { 645 | t.Error("expected error registering NeedsCtx") 646 | } 647 | err = Register(new(ReplyNotPointer)) 648 | if err == nil { 649 | t.Error("expected error registering ReplyNotPointer") 650 | } 651 | err = Register(new(ArgNotPublic)) 652 | if err == nil { 653 | t.Error("expected error registering ArgNotPublic") 654 | } 655 | err = Register(new(ReplyNotPublic)) 656 | if err == nil { 657 | t.Error("expected error registering ReplyNotPublic") 658 | } 659 | err = Register(NeedsPtrType(0)) 660 | if err == nil { 661 | t.Error("expected error registering NeedsPtrType") 662 | } else if !strings.Contains(err.Error(), "pointer") { 663 | t.Error("expected hint when registering NeedsPtrType") 664 | } 665 | } 666 | 667 | type WriteFailCodec int 668 | 669 | func (WriteFailCodec) WriteRequest(*Request, any) error { 670 | // the panic caused by this error used to not unlock a lock. 671 | return errors.New("fail") 672 | } 673 | 674 | func (WriteFailCodec) ReadResponseHeader(*Response) error { 675 | select {} 676 | } 677 | 678 | func (WriteFailCodec) ReadResponseBody(any) error { 679 | select {} 680 | } 681 | 682 | func (WriteFailCodec) Close() error { 683 | return nil 684 | } 685 | 686 | func TestSendDeadlock(t *testing.T) { 687 | client := NewClientWithCodec(WriteFailCodec(0)) 688 | defer client.Close() 689 | 690 | done := make(chan bool) 691 | go func() { 692 | testSendDeadlock(client) 693 | testSendDeadlock(client) 694 | done <- true 695 | }() 696 | select { 697 | case <-done: 698 | return 699 | case <-time.After(5 * time.Second): 700 | t.Fatal("deadlock") 701 | } 702 | } 703 | 704 | func testSendDeadlock(client *Client) { 705 | defer func() { 706 | recover() 707 | }() 708 | ctx := context.Background() 709 | args := &Args{7, 8} 710 | reply := new(Reply) 711 | client.Call(ctx, "Arith.Add", args, reply) 712 | } 713 | 714 | func dialDirect() (*Client, error) { 715 | return 
Dial("tcp", serverAddr) 716 | } 717 | 718 | func dialHTTP() (*Client, error) { 719 | return DialHTTP("tcp", httpServerAddr) 720 | } 721 | 722 | func countMallocs(dial func() (*Client, error), t *testing.T) float64 { 723 | once.Do(startServer) 724 | client, err := dial() 725 | if err != nil { 726 | t.Fatal("error dialing", err) 727 | } 728 | defer client.Close() 729 | 730 | ctx := context.Background() 731 | args := &Args{7, 8} 732 | reply := new(Reply) 733 | return testing.AllocsPerRun(100, func() { 734 | err := client.Call(ctx, "Arith.Add", args, reply) 735 | if err != nil { 736 | t.Errorf("Add: expected no error but got string %q", err.Error()) 737 | } 738 | if reply.C != args.A+args.B { 739 | t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 740 | } 741 | }) 742 | } 743 | 744 | func TestCountMallocs(t *testing.T) { 745 | if testing.Short() { 746 | t.Skip("skipping malloc count in short mode") 747 | } 748 | if runtime.GOMAXPROCS(0) > 1 { 749 | t.Skip("skipping; GOMAXPROCS>1") 750 | } 751 | fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t)) 752 | } 753 | 754 | func TestCountMallocsOverHTTP(t *testing.T) { 755 | if testing.Short() { 756 | t.Skip("skipping malloc count in short mode") 757 | } 758 | if runtime.GOMAXPROCS(0) > 1 { 759 | t.Skip("skipping; GOMAXPROCS>1") 760 | } 761 | fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t)) 762 | } 763 | 764 | type writeCrasher struct { 765 | done chan bool 766 | } 767 | 768 | func (writeCrasher) Close() error { 769 | return nil 770 | } 771 | 772 | func (w *writeCrasher) Read(p []byte) (int, error) { 773 | <-w.done 774 | return 0, io.EOF 775 | } 776 | 777 | func (writeCrasher) Write(p []byte) (int, error) { 778 | return 0, errors.New("fake write failure") 779 | } 780 | 781 | func TestClientWriteError(t *testing.T) { 782 | w := &writeCrasher{done: make(chan bool)} 783 | c := NewClient(w) 784 | defer c.Close() 785 | 786 | ctx := context.Background() 787 | res := 
false 788 | err := c.Call(ctx, "foo", 1, &res) 789 | if err == nil { 790 | t.Fatal("expected error") 791 | } 792 | if err.Error() != "fake write failure" { 793 | t.Error("unexpected value of error:", err) 794 | } 795 | w.done <- true 796 | } 797 | 798 | func TestTCPClose(t *testing.T) { 799 | once.Do(startServer) 800 | 801 | client, err := dialHTTP() 802 | if err != nil { 803 | t.Fatalf("dialing: %v", err) 804 | } 805 | defer client.Close() 806 | 807 | ctx := context.Background() 808 | args := Args{17, 8} 809 | var reply Reply 810 | err = client.Call(ctx, "Arith.Mul", args, &reply) 811 | if err != nil { 812 | t.Fatal("arith error:", err) 813 | } 814 | t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply) 815 | if reply.C != args.A*args.B { 816 | t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B) 817 | } 818 | } 819 | 820 | func TestErrorAfterClientClose(t *testing.T) { 821 | once.Do(startServer) 822 | ctx := context.Background() 823 | 824 | client, err := dialHTTP() 825 | if err != nil { 826 | t.Fatalf("dialing: %v", err) 827 | } 828 | err = client.Close() 829 | if err != nil { 830 | t.Fatal("close error:", err) 831 | } 832 | err = client.Call(ctx, "Arith.Add", &Args{7, 9}, new(Reply)) 833 | if err != ErrShutdown { 834 | t.Errorf("Forever: expected ErrShutdown got %v", err) 835 | } 836 | } 837 | 838 | // Tests the fix to issue 11221. Without the fix, this loops forever or crashes. 
839 | func TestAcceptExitAfterListenerClose(t *testing.T) { 840 | newServer := NewServer() 841 | newServer.Register(new(Arith)) 842 | newServer.RegisterName("net.rpc.Arith", new(Arith)) 843 | newServer.RegisterName("newServer.Arith", new(Arith)) 844 | 845 | var l net.Listener 846 | l, _ = listenTCP() 847 | l.Close() 848 | newServer.Accept(l) 849 | } 850 | 851 | func TestShutdown(t *testing.T) { 852 | var l net.Listener 853 | l, _ = listenTCP() 854 | ch := make(chan net.Conn, 1) 855 | go func() { 856 | defer l.Close() 857 | c, err := l.Accept() 858 | if err != nil { 859 | t.Error(err) 860 | } 861 | ch <- c 862 | }() 863 | c, err := net.Dial("tcp", l.Addr().String()) 864 | if err != nil { 865 | t.Fatal(err) 866 | } 867 | c1 := <-ch 868 | if c1 == nil { 869 | t.Fatal(err) 870 | } 871 | 872 | newServer := NewServer() 873 | newServer.Register(new(Arith)) 874 | go newServer.ServeConn(c1) 875 | 876 | ctx := context.Background() 877 | args := &Args{7, 8} 878 | reply := new(Reply) 879 | client := NewClient(c) 880 | err = client.Call(ctx, "Arith.Add", args, reply) 881 | if err != nil { 882 | t.Fatal(err) 883 | } 884 | 885 | // On an unloaded system 10ms is usually enough to fail 100% of the time 886 | // with a broken server. On a loaded system, a broken server might incorrectly 887 | // be reported as passing, but we're OK with that kind of flakiness. 888 | // If the code is correct, this test will never fail, regardless of timeout. 
889 | args.A = 10 // 10 ms 890 | done := make(chan *Call, 1) 891 | call := client.Go("Arith.SleepMilli", args, reply, done) 892 | c.(*net.TCPConn).CloseWrite() 893 | <-done 894 | if call.Error != nil { 895 | t.Fatal(err) 896 | } 897 | } 898 | 899 | func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { 900 | once.Do(startServer) 901 | client, err := dial() 902 | if err != nil { 903 | b.Fatal("error dialing:", err) 904 | } 905 | defer client.Close() 906 | 907 | // Synchronous calls 908 | ctx := context.Background() 909 | args := &Args{7, 8} 910 | b.ResetTimer() 911 | 912 | b.RunParallel(func(pb *testing.PB) { 913 | reply := new(Reply) 914 | for pb.Next() { 915 | err := client.Call(ctx, "Arith.Add", args, reply) 916 | if err != nil { 917 | b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error()) 918 | } 919 | if reply.C != args.A+args.B { 920 | b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B) 921 | } 922 | } 923 | }) 924 | } 925 | 926 | func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) { 927 | if b.N == 0 { 928 | return 929 | } 930 | const MaxConcurrentCalls = 100 931 | once.Do(startServer) 932 | client, err := dial() 933 | if err != nil { 934 | b.Fatal("error dialing:", err) 935 | } 936 | defer client.Close() 937 | 938 | // Asynchronous calls 939 | args := &Args{7, 8} 940 | procs := 4 * runtime.GOMAXPROCS(-1) 941 | send := int32(b.N) 942 | recv := int32(b.N) 943 | var wg sync.WaitGroup 944 | wg.Add(procs) 945 | gate := make(chan bool, MaxConcurrentCalls) 946 | res := make(chan *Call, MaxConcurrentCalls) 947 | b.ResetTimer() 948 | 949 | for p := 0; p < procs; p++ { 950 | go func() { 951 | for atomic.AddInt32(&send, -1) >= 0 { 952 | gate <- true 953 | reply := new(Reply) 954 | client.Go("Arith.Add", args, reply, res) 955 | } 956 | }() 957 | go func() { 958 | for call := range res { 959 | A := call.Args.(*Args).A 960 | B := call.Args.(*Args).B 961 | C := call.Reply.(*Reply).C 962 | if A+B != 
C { 963 | b.Errorf("incorrect reply: Add: expected %d got %d", A+B, C) 964 | return 965 | } 966 | <-gate 967 | if atomic.AddInt32(&recv, -1) == 0 { 968 | close(res) 969 | } 970 | } 971 | wg.Done() 972 | }() 973 | } 974 | wg.Wait() 975 | } 976 | 977 | func BenchmarkEndToEnd(b *testing.B) { 978 | benchmarkEndToEnd(dialDirect, b) 979 | } 980 | 981 | func BenchmarkEndToEndHTTP(b *testing.B) { 982 | benchmarkEndToEnd(dialHTTP, b) 983 | } 984 | 985 | func BenchmarkEndToEndAsync(b *testing.B) { 986 | benchmarkEndToEndAsync(dialDirect, b) 987 | } 988 | 989 | func BenchmarkEndToEndAsyncHTTP(b *testing.B) { 990 | benchmarkEndToEndAsync(dialHTTP, b) 991 | } 992 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | // Copyright 2009 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | /* 6 | Package rpc is a fork of the stdlib net/rpc which is frozen. It adds 7 | support for context.Context on the client and server, including 8 | propogating cancellation. See the README at 9 | https://github.com/keegancsmith/rpc for motivation why this exists. 10 | 11 | The API is exactly the same, except Client.Call takes a context.Context, 12 | and Server methods are expected to take a context.Context as the first 13 | argument. The following is the original rpc godoc updated to include 14 | context.Context. Additionally the wire protocol is unchanged, so is 15 | backwards compatible with net/rpc clients. 16 | 17 | Package rpc provides access to the exported methods of an object across a 18 | network or other I/O connection. A server registers an object, making it visible 19 | as a service with the name of the type of the object. After registration, exported 20 | methods of the object will be accessible remotely. 
A server may register multiple 21 | objects (services) of different types but it is an error to register multiple 22 | objects of the same type. 23 | 24 | Only methods that satisfy these criteria will be made available for remote access; 25 | other methods will be ignored: 26 | 27 | - the method's type is exported. 28 | - the method is exported. 29 | - the method has three arguments. 30 | - the method's first argument has type context.Context. 31 | - the method's last two arguments are exported (or builtin) types. 32 | - the method's third argument is a pointer. 33 | - the method has return type error. 34 | 35 | In effect, the method must look schematically like 36 | 37 | func (t *T) MethodName(ctx context.Context, argType T1, replyType *T2) error 38 | 39 | where T1 and T2 can be marshaled by encoding/gob. 40 | These requirements apply even if a different codec is used. 41 | (In the future, these requirements may soften for custom codecs.) 42 | 43 | The method's second argument represents the arguments provided by the caller; the 44 | third argument represents the result parameters to be returned to the caller. 45 | The method's return value, if non-nil, is passed back as a string that the client 46 | sees as if created by errors.New. If an error is returned, the reply parameter 47 | will not be sent back to the client. 48 | 49 | The server may handle requests on a single connection by calling ServeConn. More 50 | typically it will create a network listener and call Accept or, for an HTTP 51 | listener, HandleHTTP and http.Serve. 52 | 53 | A client wishing to use the service establishes a connection and then invokes 54 | NewClient on the connection. The convenience function Dial (DialHTTP) performs 55 | both steps for a raw network connection (an HTTP connection). The resulting 56 | Client object has two methods, Call and Go, that specify the service and method to 57 | call, a pointer containing the arguments, and a pointer to receive the result 58 | parameters. 
59 | 60 | The Call method waits for the remote call to complete while the Go method 61 | launches the call asynchronously and signals completion using the Call 62 | structure's Done channel. 63 | 64 | Unless an explicit codec is set up, package encoding/gob is used to 65 | transport the data. 66 | 67 | Here is a simple example. A server wishes to export an object of type Arith: 68 | 69 | package server 70 | 71 | import "errors" 72 | 73 | type Args struct { 74 | A, B int 75 | } 76 | 77 | type Quotient struct { 78 | Quo, Rem int 79 | } 80 | 81 | type Arith int 82 | 83 | func (t *Arith) Multiply(args *Args, reply *int) error { 84 | *reply = args.A * args.B 85 | return nil 86 | } 87 | 88 | func (t *Arith) Divide(args *Args, quo *Quotient) error { 89 | if args.B == 0 { 90 | return errors.New("divide by zero") 91 | } 92 | quo.Quo = args.A / args.B 93 | quo.Rem = args.A % args.B 94 | return nil 95 | } 96 | 97 | The server calls (for HTTP service): 98 | 99 | arith := new(Arith) 100 | rpc.Register(arith) 101 | rpc.HandleHTTP() 102 | l, e := net.Listen("tcp", ":1234") 103 | if e != nil { 104 | log.Fatal("listen error:", e) 105 | } 106 | go http.Serve(l, nil) 107 | 108 | func (t *Arith) Multiply(ctx context.Context, args *Args, reply *int) error { 109 | *reply = args.A * args.B 110 | return nil 111 | } 112 | 113 | At this point, clients can see a service "Arith" with methods "Arith.Multiply" and 114 | "Arith.Divide". 
To invoke one, a client first dials the server: 115 | 116 | func (t *Arith) Divide(ctx context.Context, args *Args, quo *Quotient) error { 117 | if args.B == 0 { 118 | return errors.New("divide by zero") 119 | } 120 | quo.Quo = args.A / args.B 121 | quo.Rem = args.A % args.B 122 | return nil 123 | } 124 | func (t *Arith) Divide(args *Args, quo *Quotient) error { 125 | if args.B == 0 { 126 | return errors.New("divide by zero") 127 | } 128 | quo.Quo = args.A / args.B 129 | quo.Rem = args.A % args.B 130 | return nil 131 | } 132 | 133 | Then it can make a remote call: 134 | 135 | // Synchronous call 136 | args := &server.Args{7,8} 137 | var reply int 138 | err = client.Call(context.Background(), "Arith.Multiply", args, &reply) 139 | if err != nil { 140 | log.Fatal("arith error:", err) 141 | } 142 | fmt.Printf("Arith: %d*%d=%d", args.A, args.B, reply) 143 | 144 | or 145 | 146 | // Asynchronous call 147 | quotient := new(Quotient) 148 | divCall := client.Go("Arith.Divide", args, quotient, nil) 149 | replyCall := <-divCall.Done // will be equal to divCall 150 | // check errors, print, etc. 151 | 152 | A server implementation will often provide a simple, type-safe wrapper for the 153 | client. 154 | 155 | The net/rpc package is frozen and is not accepting new features. 156 | */ 157 | package rpc 158 | 159 | import ( 160 | "bufio" 161 | "context" 162 | "encoding/gob" 163 | "errors" 164 | "go/token" 165 | "io" 166 | "log" 167 | "net" 168 | "net/http" 169 | "reflect" 170 | "strings" 171 | "sync" 172 | 173 | "github.com/keegancsmith/rpc/internal/svc" 174 | ) 175 | 176 | const ( 177 | // Defaults used by HandleHTTP 178 | DefaultRPCPath = "/_goRPC_" 179 | DefaultDebugPath = "/debug/rpc" 180 | ) 181 | 182 | // Precompute the reflect type for error. Can't use error directly 183 | // because Typeof takes an empty interface value. This is annoying. 
184 | var typeOfError = reflect.TypeOf((*error)(nil)).Elem() 185 | var typeOfCtx = reflect.TypeOf((*context.Context)(nil)).Elem() 186 | 187 | type methodType struct { 188 | sync.Mutex // protects counters 189 | method reflect.Method 190 | ArgType reflect.Type 191 | ReplyType reflect.Type 192 | numCalls uint 193 | } 194 | 195 | type service struct { 196 | name string // name of service 197 | rcvr reflect.Value // receiver of methods for the service 198 | typ reflect.Type // type of the receiver 199 | method map[string]*methodType // registered methods 200 | } 201 | 202 | // Request is a header written before every RPC call. It is used internally 203 | // but documented here as an aid to debugging, such as when analyzing 204 | // network traffic. 205 | type Request struct { 206 | ServiceMethod string // format: "Service.Method" 207 | Seq uint64 // sequence number chosen by client 208 | next *Request // for free list in Server 209 | } 210 | 211 | // Response is a header written before every RPC return. It is used internally 212 | // but documented here as an aid to debugging, such as when analyzing 213 | // network traffic. 214 | type Response struct { 215 | ServiceMethod string // echoes that of the Request 216 | Seq uint64 // echoes that of the request 217 | Error string // error, if any. 218 | next *Response // for free list in Server 219 | } 220 | 221 | // Server represents an RPC Server. 222 | type Server struct { 223 | serviceMap sync.Map // map[string]*service 224 | reqLock sync.Mutex // protects freeReq 225 | freeReq *Request 226 | respLock sync.Mutex // protects freeResp 227 | freeResp *Response 228 | } 229 | 230 | // NewServer returns a new Server. 231 | func NewServer() *Server { 232 | s := &Server{} 233 | s.RegisterName("_goRPC_", &svc.GoRPC{}) 234 | return s 235 | } 236 | 237 | // DefaultServer is the default instance of *Server. 238 | var DefaultServer = NewServer() 239 | 240 | // Is this type exported or a builtin? 
241 | func isExportedOrBuiltinType(t reflect.Type) bool { 242 | for t.Kind() == reflect.Pointer { 243 | t = t.Elem() 244 | } 245 | // PkgPath will be non-empty even for an exported type, 246 | // so we need to check the type name as well. 247 | return token.IsExported(t.Name()) || t.PkgPath() == "" 248 | } 249 | 250 | // Register publishes in the server the set of methods of the 251 | // receiver value that satisfy the following conditions: 252 | // - exported method of exported type 253 | // - two arguments, both of exported type 254 | // - the second argument is a pointer 255 | // - one return value, of type error 256 | // 257 | // It returns an error if the receiver is not an exported type or has 258 | // no suitable methods. It also logs the error using package log. 259 | // The client accesses each method using a string of the form "Type.Method", 260 | // where Type is the receiver's concrete type. 261 | func (server *Server) Register(rcvr any) error { 262 | return server.register(rcvr, "", false) 263 | } 264 | 265 | // RegisterName is like Register but uses the provided name for the type 266 | // instead of the receiver's concrete type. 267 | func (server *Server) RegisterName(name string, rcvr any) error { 268 | return server.register(rcvr, name, true) 269 | } 270 | 271 | // logRegisterError specifies whether to log problems during method registration. 272 | // To debug registration, recompile the package with this set to true. 
273 | const logRegisterError = false 274 | 275 | func (server *Server) register(rcvr any, name string, useName bool) error { 276 | s := new(service) 277 | s.typ = reflect.TypeOf(rcvr) 278 | s.rcvr = reflect.ValueOf(rcvr) 279 | sname := name 280 | if !useName { 281 | sname = reflect.Indirect(s.rcvr).Type().Name() 282 | } 283 | if sname == "" { 284 | s := "rpc.Register: no service name for type " + s.typ.String() 285 | log.Print(s) 286 | return errors.New(s) 287 | } 288 | if !useName && !token.IsExported(sname) { 289 | s := "rpc.Register: type " + sname + " is not exported" 290 | log.Print(s) 291 | return errors.New(s) 292 | } 293 | s.name = sname 294 | 295 | // Install the methods 296 | s.method = suitableMethods(s.typ, logRegisterError) 297 | 298 | if len(s.method) == 0 { 299 | str := "" 300 | 301 | // To help the user, see if a pointer receiver would work. 302 | method := suitableMethods(reflect.PointerTo(s.typ), false) 303 | if len(method) != 0 { 304 | str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)" 305 | } else { 306 | str = "rpc.Register: type " + sname + " has no exported methods of suitable type" 307 | } 308 | log.Print(str) 309 | return errors.New(str) 310 | } 311 | 312 | if _, dup := server.serviceMap.LoadOrStore(sname, s); dup { 313 | return errors.New("rpc: service already defined: " + sname) 314 | } 315 | return nil 316 | } 317 | 318 | // suitableMethods returns suitable Rpc methods of typ. It will log 319 | // errors if logErr is true. 320 | func suitableMethods(typ reflect.Type, logErr bool) map[string]*methodType { 321 | methods := make(map[string]*methodType) 322 | for m := 0; m < typ.NumMethod(); m++ { 323 | method := typ.Method(m) 324 | mtype := method.Type 325 | mname := method.Name 326 | // Method must be exported. 327 | if !method.IsExported() { 328 | continue 329 | } 330 | // Method needs four ins: receiver, ctx, *args, *reply. 
331 | if mtype.NumIn() != 4 { 332 | if logErr { 333 | log.Printf("rpc.Register: method %q has %d input parameters; needs exactly three\n", mname, mtype.NumIn()) 334 | } 335 | continue 336 | } 337 | // First arg must be context.Context 338 | if ctxType := mtype.In(1); ctxType != typeOfCtx { 339 | if logErr { 340 | log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, ctxType) 341 | } 342 | continue 343 | } 344 | // Second arg need not be a pointer. 345 | argType := mtype.In(2) 346 | if !isExportedOrBuiltinType(argType) { 347 | if logErr { 348 | log.Printf("rpc.Register: argument type of method %q is not exported: %q\n", mname, argType) 349 | } 350 | continue 351 | } 352 | // Third arg must be a pointer. 353 | replyType := mtype.In(3) 354 | if replyType.Kind() != reflect.Ptr { 355 | if logErr { 356 | log.Printf("rpc.Register: reply type of method %q is not a pointer: %q\n", mname, replyType) 357 | } 358 | continue 359 | } 360 | // Reply type must be exported. 361 | if !isExportedOrBuiltinType(replyType) { 362 | if logErr { 363 | log.Printf("rpc.Register: reply type of method %q is not exported: %q\n", mname, replyType) 364 | } 365 | continue 366 | } 367 | // Method needs one out. 368 | if mtype.NumOut() != 1 { 369 | if logErr { 370 | log.Printf("rpc.Register: method %q has %d output parameters; needs exactly one\n", mname, mtype.NumOut()) 371 | } 372 | continue 373 | } 374 | // The return type of the method must be error. 375 | if returnType := mtype.Out(0); returnType != typeOfError { 376 | if logErr { 377 | log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, returnType) 378 | } 379 | continue 380 | } 381 | methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType} 382 | } 383 | return methods 384 | } 385 | 386 | // A value sent as a placeholder for the server's response value when the server 387 | // receives an invalid request. 
It is never decoded by the client since the Response 388 | // contains an error when it is used. 389 | var invalidRequest = struct{}{} 390 | 391 | func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply any, codec ServerCodec, errmsg string) { 392 | resp := server.getResponse() 393 | // Encode the response header 394 | resp.ServiceMethod = req.ServiceMethod 395 | if errmsg != "" { 396 | resp.Error = errmsg 397 | reply = invalidRequest 398 | } 399 | resp.Seq = req.Seq 400 | sending.Lock() 401 | err := codec.WriteResponse(resp, reply) 402 | if debugLog && err != nil { 403 | log.Println("rpc: writing response:", err) 404 | } 405 | sending.Unlock() 406 | server.freeResponse(resp) 407 | } 408 | 409 | func (m *methodType) NumCalls() (n uint) { 410 | m.Lock() 411 | n = m.numCalls 412 | m.Unlock() 413 | return n 414 | } 415 | 416 | func (s *service) call(server *Server, sending *sync.Mutex, pending *svc.Pending, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) { 417 | if wg != nil { 418 | defer wg.Done() 419 | } 420 | // _goRPC_ service calls require internal state. 421 | if s.name == "_goRPC_" { 422 | switch v := argv.Interface().(type) { 423 | case *svc.CancelArgs: 424 | v.SetPending(pending) 425 | } 426 | } 427 | mtype.Lock() 428 | mtype.numCalls++ 429 | mtype.Unlock() 430 | ctx := pending.Start(req.Seq) 431 | defer pending.Cancel(req.Seq) 432 | function := mtype.method.Func 433 | // Invoke the method, providing a new value for the reply. 434 | returnValues := function.Call([]reflect.Value{s.rcvr, reflect.ValueOf(ctx), argv, replyv}) 435 | // The return value for the method is an error. 
436 | errInter := returnValues[0].Interface() 437 | errmsg := "" 438 | if errInter != nil { 439 | errmsg = errInter.(error).Error() 440 | } 441 | server.sendResponse(sending, req, replyv.Interface(), codec, errmsg) 442 | server.freeRequest(req) 443 | } 444 | 445 | type gobServerCodec struct { 446 | rwc io.ReadWriteCloser 447 | dec *gob.Decoder 448 | enc *gob.Encoder 449 | encBuf *bufio.Writer 450 | closed bool 451 | } 452 | 453 | func (c *gobServerCodec) ReadRequestHeader(r *Request) error { 454 | return c.dec.Decode(r) 455 | } 456 | 457 | func (c *gobServerCodec) ReadRequestBody(body any) error { 458 | return c.dec.Decode(body) 459 | } 460 | 461 | func (c *gobServerCodec) WriteResponse(r *Response, body any) (err error) { 462 | if err = c.enc.Encode(r); err != nil { 463 | if c.encBuf.Flush() == nil { 464 | // Gob couldn't encode the header. Should not happen, so if it does, 465 | // shut down the connection to signal that the connection is broken. 466 | log.Println("rpc: gob error encoding response:", err) 467 | c.Close() 468 | } 469 | return 470 | } 471 | if err = c.enc.Encode(body); err != nil { 472 | if c.encBuf.Flush() == nil { 473 | // Was a gob problem encoding the body but the header has been written. 474 | // Shut down the connection to signal that the connection is broken. 475 | log.Println("rpc: gob error encoding body:", err) 476 | c.Close() 477 | } 478 | return 479 | } 480 | return c.encBuf.Flush() 481 | } 482 | 483 | func (c *gobServerCodec) Close() error { 484 | if c.closed { 485 | // Only call c.rwc.Close once; otherwise the semantics are undefined. 486 | return nil 487 | } 488 | c.closed = true 489 | return c.rwc.Close() 490 | } 491 | 492 | // ServeConn runs the server on a single connection. 493 | // ServeConn blocks, serving the connection until the client hangs up. 494 | // The caller typically invokes ServeConn in a go statement. 495 | // ServeConn uses the gob wire format (see package gob) on the 496 | // connection. 
To use an alternate codec, use ServeCodec. 497 | // See NewClient's comment for information about concurrent access. 498 | func (server *Server) ServeConn(conn io.ReadWriteCloser) { 499 | buf := bufio.NewWriter(conn) 500 | srv := &gobServerCodec{ 501 | rwc: conn, 502 | dec: gob.NewDecoder(conn), 503 | enc: gob.NewEncoder(buf), 504 | encBuf: buf, 505 | } 506 | server.ServeCodec(srv) 507 | } 508 | 509 | // ServeCodec is like ServeConn but uses the specified codec to 510 | // decode requests and encode responses. 511 | func (server *Server) ServeCodec(codec ServerCodec) { 512 | sending := new(sync.Mutex) 513 | ctx, cancel := context.WithCancel(context.Background()) 514 | defer cancel() 515 | pending := svc.NewPending(ctx) 516 | wg := new(sync.WaitGroup) 517 | for { 518 | service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) 519 | if err != nil { 520 | if debugLog && err != io.EOF { 521 | log.Println("rpc:", err) 522 | } 523 | if !keepReading { 524 | break 525 | } 526 | // send a response if we actually managed to read a header. 527 | if req != nil { 528 | server.sendResponse(sending, req, invalidRequest, codec, err.Error()) 529 | server.freeRequest(req) 530 | } 531 | continue 532 | } 533 | wg.Add(1) 534 | go service.call(server, sending, pending, wg, mtype, req, argv, replyv, codec) 535 | } 536 | // We've seen that there are no more requests. 537 | // Wait for responses to be sent before closing codec. 538 | wg.Wait() 539 | codec.Close() 540 | } 541 | 542 | // ServeRequest is like ServeCodec but synchronously serves a single request. 543 | // It does not close the codec upon completion. 544 | func (server *Server) ServeRequest(codec ServerCodec) error { 545 | return server.ServeRequestContext(context.Background(), codec) 546 | } 547 | 548 | // ServeRequest is like ServeCodec but synchronously serves a single request. 549 | // It does not close the codec upon completion. 
550 | // 551 | // Cancelling the context given here will propagate cancellation to the context 552 | // of the called function. 553 | func (server *Server) ServeRequestContext(ctx context.Context, codec ServerCodec) error { 554 | sending := new(sync.Mutex) 555 | pending := svc.NewPending(ctx) 556 | service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) 557 | if err != nil { 558 | if !keepReading { 559 | return err 560 | } 561 | // send a response if we actually managed to read a header. 562 | if req != nil { 563 | server.sendResponse(sending, req, invalidRequest, codec, err.Error()) 564 | server.freeRequest(req) 565 | } 566 | return err 567 | } 568 | service.call(server, sending, pending, nil, mtype, req, argv, replyv, codec) 569 | return nil 570 | } 571 | 572 | func (server *Server) getRequest() *Request { 573 | server.reqLock.Lock() 574 | req := server.freeReq 575 | if req == nil { 576 | req = new(Request) 577 | } else { 578 | server.freeReq = req.next 579 | *req = Request{} 580 | } 581 | server.reqLock.Unlock() 582 | return req 583 | } 584 | 585 | func (server *Server) freeRequest(req *Request) { 586 | server.reqLock.Lock() 587 | req.next = server.freeReq 588 | server.freeReq = req 589 | server.reqLock.Unlock() 590 | } 591 | 592 | func (server *Server) getResponse() *Response { 593 | server.respLock.Lock() 594 | resp := server.freeResp 595 | if resp == nil { 596 | resp = new(Response) 597 | } else { 598 | server.freeResp = resp.next 599 | *resp = Response{} 600 | } 601 | server.respLock.Unlock() 602 | return resp 603 | } 604 | 605 | func (server *Server) freeResponse(resp *Response) { 606 | server.respLock.Lock() 607 | resp.next = server.freeResp 608 | server.freeResp = resp 609 | server.respLock.Unlock() 610 | } 611 | 612 | func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) { 613 | service, mtype, req, keepReading, err = 
server.readRequestHeader(codec) 614 | if err != nil { 615 | if !keepReading { 616 | return 617 | } 618 | // discard body 619 | codec.ReadRequestBody(nil) 620 | return 621 | } 622 | 623 | // Decode the argument value. 624 | argIsValue := false // if true, need to indirect before calling. 625 | if mtype.ArgType.Kind() == reflect.Pointer { 626 | argv = reflect.New(mtype.ArgType.Elem()) 627 | } else { 628 | argv = reflect.New(mtype.ArgType) 629 | argIsValue = true 630 | } 631 | // argv guaranteed to be a pointer now. 632 | if err = codec.ReadRequestBody(argv.Interface()); err != nil { 633 | return 634 | } 635 | if argIsValue { 636 | argv = argv.Elem() 637 | } 638 | 639 | replyv = reflect.New(mtype.ReplyType.Elem()) 640 | 641 | switch mtype.ReplyType.Elem().Kind() { 642 | case reflect.Map: 643 | replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem())) 644 | case reflect.Slice: 645 | replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0)) 646 | } 647 | return 648 | } 649 | 650 | func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) { 651 | // Grab the request header. 652 | req = server.getRequest() 653 | err = codec.ReadRequestHeader(req) 654 | if err != nil { 655 | req = nil 656 | if err == io.EOF || err == io.ErrUnexpectedEOF { 657 | return 658 | } 659 | err = errors.New("rpc: server cannot decode request: " + err.Error()) 660 | return 661 | } 662 | 663 | // We read the header successfully. If we see an error now, 664 | // we can still recover and move on to the next request. 665 | keepReading = true 666 | 667 | dot := strings.LastIndex(req.ServiceMethod, ".") 668 | if dot < 0 { 669 | err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod) 670 | return 671 | } 672 | serviceName := req.ServiceMethod[:dot] 673 | methodName := req.ServiceMethod[dot+1:] 674 | 675 | // Look up the request. 
676 | svci, ok := server.serviceMap.Load(serviceName) 677 | if !ok { 678 | err = errors.New("rpc: can't find service " + req.ServiceMethod) 679 | return 680 | } 681 | svc = svci.(*service) 682 | mtype = svc.method[methodName] 683 | if mtype == nil { 684 | err = errors.New("rpc: can't find method " + req.ServiceMethod) 685 | } 686 | return 687 | } 688 | 689 | // Accept accepts connections on the listener and serves requests 690 | // for each incoming connection. Accept blocks until the listener 691 | // returns a non-nil error. The caller typically invokes Accept in a 692 | // go statement. 693 | func (server *Server) Accept(lis net.Listener) { 694 | for { 695 | conn, err := lis.Accept() 696 | if err != nil { 697 | log.Print("rpc.Serve: accept:", err.Error()) 698 | return 699 | } 700 | go server.ServeConn(conn) 701 | } 702 | } 703 | 704 | // Register publishes the receiver's methods in the DefaultServer. 705 | func Register(rcvr any) error { return DefaultServer.Register(rcvr) } 706 | 707 | // RegisterName is like Register but uses the provided name for the type 708 | // instead of the receiver's concrete type. 709 | func RegisterName(name string, rcvr any) error { 710 | return DefaultServer.RegisterName(name, rcvr) 711 | } 712 | 713 | // A ServerCodec implements reading of RPC requests and writing of 714 | // RPC responses for the server side of an RPC session. 715 | // The server calls ReadRequestHeader and ReadRequestBody in pairs 716 | // to read requests from the connection, and it calls WriteResponse to 717 | // write a response back. The server calls Close when finished with the 718 | // connection. ReadRequestBody may be called with a nil 719 | // argument to force the body of the request to be read and discarded. 720 | // See NewClient's comment for information about concurrent access. 
721 | type ServerCodec interface { 722 | ReadRequestHeader(*Request) error 723 | ReadRequestBody(any) error 724 | WriteResponse(*Response, any) error 725 | 726 | // Close can be called multiple times and must be idempotent. 727 | Close() error 728 | } 729 | 730 | // ServeConn runs the DefaultServer on a single connection. 731 | // ServeConn blocks, serving the connection until the client hangs up. 732 | // The caller typically invokes ServeConn in a go statement. 733 | // ServeConn uses the gob wire format (see package gob) on the 734 | // connection. To use an alternate codec, use ServeCodec. 735 | // See NewClient's comment for information about concurrent access. 736 | func ServeConn(conn io.ReadWriteCloser) { 737 | DefaultServer.ServeConn(conn) 738 | } 739 | 740 | // ServeCodec is like ServeConn but uses the specified codec to 741 | // decode requests and encode responses. 742 | func ServeCodec(codec ServerCodec) { 743 | DefaultServer.ServeCodec(codec) 744 | } 745 | 746 | // ServeRequest is like ServeCodec but synchronously serves a single request. 747 | // It does not close the codec upon completion. 748 | func ServeRequest(codec ServerCodec) error { 749 | return ServeRequestContext(context.Background(), codec) 750 | } 751 | 752 | // ServeRequest is like ServeCodec but synchronously serves a single request. 753 | // It does not close the codec upon completion. 754 | // 755 | // Cancelling the context given here will propagate cancellation to the context 756 | // of the called function. 757 | func ServeRequestContext(ctx context.Context, codec ServerCodec) error { 758 | return DefaultServer.ServeRequestContext(ctx, codec) 759 | } 760 | 761 | // Accept accepts connections on the listener and serves requests 762 | // to DefaultServer for each incoming connection. 763 | // Accept blocks; the caller typically invokes it in a go statement. 
764 | func Accept(lis net.Listener) { DefaultServer.Accept(lis) } 765 | 766 | // Can connect to RPC service using HTTP CONNECT to rpcPath. 767 | var connected = "200 Connected to Go RPC" 768 | 769 | // ServeHTTP implements an http.Handler that answers RPC requests. 770 | func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { 771 | if req.Method != "CONNECT" { 772 | w.Header().Set("Content-Type", "text/plain; charset=utf-8") 773 | w.WriteHeader(http.StatusMethodNotAllowed) 774 | io.WriteString(w, "405 must CONNECT\n") 775 | return 776 | } 777 | conn, _, err := w.(http.Hijacker).Hijack() 778 | if err != nil { 779 | log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error()) 780 | return 781 | } 782 | io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n") 783 | server.ServeConn(conn) 784 | } 785 | 786 | // HandleHTTP registers an HTTP handler for RPC messages on rpcPath, 787 | // and a debugging handler on debugPath. 788 | // It is still necessary to invoke http.Serve(), typically in a go statement. 789 | func (server *Server) HandleHTTP(rpcPath, debugPath string) { 790 | http.Handle(rpcPath, server) 791 | http.Handle(debugPath, debugHTTP{server}) 792 | } 793 | 794 | // HandleHTTP registers an HTTP handler for RPC messages to DefaultServer 795 | // on DefaultRPCPath and a debugging handler on DefaultDebugPath. 796 | // It is still necessary to invoke http.Serve(), typically in a go statement. 797 | func HandleHTTP() { 798 | DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath) 799 | } 800 | --------------------------------------------------------------------------------