├── go.mod
├── .travis.yml
├── server_jsonrpc_test.go
├── LICENSE
├── internal
└── svc
│ └── svc.go
├── README.md
├── debug.go
├── client_test.go
├── server_util_test.go
├── client.go
├── server_test.go
└── server.go
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/keegancsmith/rpc
2 |
3 | go 1.18
4 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: go
2 | sudo: false
3 | go:
4 | - 1.x
5 |
--------------------------------------------------------------------------------
/server_jsonrpc_test.go:
--------------------------------------------------------------------------------
1 | package rpc
2 |
3 | import (
4 | "io"
5 | )
6 |
// HTTPReadWriteCloser adapts a separate reader and writer into a single
// io.ReadWriteCloser: reads come from In and writes go to Out.
type HTTPReadWriteCloser struct {
	In  io.Reader
	Out io.Writer
}

// Read reads from In.
func (c *HTTPReadWriteCloser) Read(p []byte) (n int, err error) {
	n, err = c.In.Read(p)
	return n, err
}

// Write writes to Out.
func (c *HTTPReadWriteCloser) Write(d []byte) (n int, err error) {
	n, err = c.Out.Write(d)
	return n, err
}

// Close is a no-op that always returns nil; neither In nor Out is closed.
func (c *HTTPReadWriteCloser) Close() error {
	return nil
}
15 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2009 The Go Authors. All rights reserved.
2 |
3 | Redistribution and use in source and binary forms, with or without
4 | modification, are permitted provided that the following conditions are
5 | met:
6 |
7 | * Redistributions of source code must retain the above copyright
8 | notice, this list of conditions and the following disclaimer.
9 | * Redistributions in binary form must reproduce the above
10 | copyright notice, this list of conditions and the following disclaimer
11 | in the documentation and/or other materials provided with the
12 | distribution.
13 | * Neither the name of Google Inc. nor the names of its
14 | contributors may be used to endorse or promote products derived from
15 | this software without specific prior written permission.
16 |
17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
21 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
23 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
25 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
--------------------------------------------------------------------------------
/internal/svc/svc.go:
--------------------------------------------------------------------------------
1 | package svc
2 |
3 | import (
4 | "context"
5 | "sync"
6 | )
7 |
// Pending manages a map of all pending requests to a rpc.Service for a
// connection (an rpc.ServerCodec). All map access is guarded by mu, so a
// Pending may be used from multiple goroutines.
type Pending struct {
	mu     sync.Mutex
	m      map[uint64]context.CancelFunc // seq -> cancel
	parent context.Context
}

// NewPending returns an empty Pending whose per-request contexts are all
// derived from parent.
func NewPending(parent context.Context) *Pending {
	return &Pending{
		m:      make(map[uint64]context.CancelFunc),
		parent: parent,
	}
}

// Start registers the request with sequence number seq and returns its
// context. The context is canceled by Cancel(seq) or when parent is canceled.
func (s *Pending) Start(seq uint64) context.Context {
	ctx, cancel := context.WithCancel(s.parent)
	s.mu.Lock()
	// We assume seq is not already in the map. If it is, the client is broken.
	s.m[seq] = cancel
	s.mu.Unlock()
	return ctx
}

// Cancel cancels and forgets the context registered for seq. Unknown
// sequence numbers are ignored, so calling it more than once is harmless.
func (s *Pending) Cancel(seq uint64) {
	s.mu.Lock()
	cancel, ok := s.m[seq]
	if ok {
		delete(s.m, seq)
	}
	s.mu.Unlock()
	// Cancel outside the lock; there is no need to hold mu while canceling.
	if ok {
		cancel()
	}
}

// CancelArgs is the argument of the internal cancel call.
type CancelArgs struct {
	// Seq is the sequence number for the rpc.Call to cancel.
	Seq uint64

	// pending is the DS used by rpc.Server to track the ongoing calls for
	// this connection. It should not be set by the client, the Service will
	// set it.
	pending *Pending
}

// SetPending sets the pending map for the server to use. Do not use on the
// client.
func (a *CancelArgs) SetPending(p *Pending) {
	a.pending = p
}

// GoRPC is an internal service used by rpc.
type GoRPC struct{}

// Cancel cancels the in-flight call args.Seq on this connection. It is
// best-effort and always returns nil. Unknown sequence numbers are ignored,
// as are args whose pending tracker was never set via SetPending. Previously
// that case dereferenced a nil *Pending and panicked inside the server.
func (s *GoRPC) Cancel(ctx context.Context, args *CancelArgs, reply *bool) error {
	if args == nil || args.pending == nil {
		return nil
	}
	args.pending.Cancel(args.Seq)
	return nil
}
67 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # rpc [![Build Status](https://travis-ci.org/keegancsmith/rpc.svg?branch=master)](https://travis-ci.org/keegancsmith/rpc)
2 |
3 | This is a fork of the stdlib [net/rpc](https://golang.org/pkg/net/rpc/) which
4 | is frozen. It adds support for `context.Context` on the client and server,
5 | including propagating cancellation.
6 |
7 | The API is exactly the same, except `Client.Call` takes a `context.Context`,
8 | and Server methods are expected to take a `context.Context` as the first
9 | argument. Additionally, the wire protocol is unchanged, so it is backwards
10 | compatible with `net/rpc` clients.
11 |
12 | A `DialHTTPPathTimeout` function is also added. A future release of rpc may
13 | update all Dial functions to instead take a context.
14 |
15 | `ClientTrace` functionality is also added. This is for hooking into the rpc
16 | client to enable tracing.
17 |
18 | ## Why use net/rpc
19 |
20 | There are many alternatives for RPC in Go, the most popular being
21 | [GRPC](https://grpc.io/). However, `net/rpc` has the following nice
22 | properties:
23 |
24 | - Nice API
25 | - No need for IDL
26 | - Good performance
27 |
28 | The nice API is subjective. However, the API is small, simple and composable,
29 | which makes it quite powerful. IDL tools are things like GRPC, which requires
30 | protoc to generate Go code from the protobuf files. `net/rpc` has no
31 | third-party dependencies or code generation step, which simplifies its use. A benchmark
32 | done on the [6 Sep
33 | 2016](https://github.com/golang/go/issues/16844#issuecomment-245261755)
34 | indicated `net/rpc` was 4x faster than GRPC. This is an outdated benchmark,
35 | but it is an indication of the surprisingly good performance `net/rpc` provides.
36 |
37 | For more discussion on the pros and cons of `net/rpc` see the issue [proposal:
38 | freeze net/rpc](https://github.com/golang/go/issues/16844).
39 |
40 | ## Details
41 |
42 | Last forked from commit
43 | [bfadd78986](https://github.com/golang/go/commit/bfadd78986) on 7 September
44 | 2022.
45 |
46 | Cancellation is implemented via the RPC call `_goRPC_.Cancel`.
47 |
--------------------------------------------------------------------------------
/debug.go:
--------------------------------------------------------------------------------
1 | // Copyright 2009 The Go Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package rpc
6 |
7 | /*
8 | Some HTML presented at http://machine:port/debug/rpc
9 | Lists services, their methods, and some statistics, still rudimentary.
10 | */
11 |
12 | import (
13 | "fmt"
14 | "html/template"
15 | "net/http"
16 | "sort"
17 | )
18 |
// debugText is the html/template source for the /debug/rpc page. It renders
// one table per service, listing each method's signature and call count.
// The HTML markup had been stripped from this literal, leaving a page with
// no structure; this restores the net/rpc template.
const debugText = `<html>
	<body>
	<title>Services</title>
	{{range .}}
	<hr>
	Service {{.Name}}
	<hr>
		<table>
		<th align=center>Method</th><th align=center>Calls</th>
		{{range .Method}}
			<tr>
			<td align=left font=fixed>{{.Name}}({{.Type.ArgType}}, {{.Type.ReplyType}}) error</td>
			<td align=center>{{.Type.NumCalls}}</td>
			</tr>
		{{end}}
		</table>
	{{end}}
	</body>
	</html>`
38 |
39 | var debug = template.Must(template.New("RPC debug").Parse(debugText))
40 |
41 | // If set, print log statements for internal and I/O errors.
42 | var debugLog = false
43 |
44 | type debugMethod struct {
45 | Type *methodType
46 | Name string
47 | }
48 |
49 | type methodArray []debugMethod
50 |
51 | type debugService struct {
52 | Service *service
53 | Name string
54 | Method methodArray
55 | }
56 |
57 | type serviceArray []debugService
58 |
59 | func (s serviceArray) Len() int { return len(s) }
60 | func (s serviceArray) Less(i, j int) bool { return s[i].Name < s[j].Name }
61 | func (s serviceArray) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
62 |
63 | func (m methodArray) Len() int { return len(m) }
64 | func (m methodArray) Less(i, j int) bool { return m[i].Name < m[j].Name }
65 | func (m methodArray) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
66 |
67 | type debugHTTP struct {
68 | *Server
69 | }
70 |
71 | // Runs at /debug/rpc
72 | func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) {
73 | // Build a sorted version of the data.
74 | var services serviceArray
75 | server.serviceMap.Range(func(snamei, svci any) bool {
76 | svc := svci.(*service)
77 | ds := debugService{svc, snamei.(string), make(methodArray, 0, len(svc.method))}
78 | for mname, method := range svc.method {
79 | ds.Method = append(ds.Method, debugMethod{method, mname})
80 | }
81 | sort.Sort(ds.Method)
82 | services = append(services, ds)
83 | return true
84 | })
85 | sort.Sort(services)
86 | err := debug.Execute(w, services)
87 | if err != nil {
88 | fmt.Fprintln(w, "rpc: error executing template:", err.Error())
89 | }
90 | }
91 |
--------------------------------------------------------------------------------
/client_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2014 The Go Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package rpc
6 |
7 | import (
8 | "context"
9 | "errors"
10 | "fmt"
11 | "net"
12 | "strings"
13 | "testing"
14 | )
15 |
16 | type shutdownCodec struct {
17 | responded chan int
18 | closed bool
19 | }
20 |
21 | func (c *shutdownCodec) WriteRequest(*Request, any) error { return nil }
22 | func (c *shutdownCodec) ReadResponseBody(any) error { return nil }
23 | func (c *shutdownCodec) ReadResponseHeader(*Response) error {
24 | c.responded <- 1
25 | return errors.New("shutdownCodec ReadResponseHeader")
26 | }
27 | func (c *shutdownCodec) Close() error {
28 | c.closed = true
29 | return nil
30 | }
31 |
32 | func TestCloseCodec(t *testing.T) {
33 | codec := &shutdownCodec{responded: make(chan int)}
34 | client := NewClientWithCodec(codec)
35 | <-codec.responded
36 | client.Close()
37 | if !codec.closed {
38 | t.Error("client.Close did not close codec")
39 | }
40 | }
41 |
42 | // Test that errors in gob shut down the connection. Issue 7689.
43 |
// R keeps its payload in an unexported field, so gob cannot transmit it.
type R struct {
	msg []byte // Not exported, so R does not work with gob.
}

// S is a test service whose Recv reply cannot be encoded by gob.
type S struct{}

// Recv sets reply to an R holding "foo" and reports success.
func (s *S) Recv(ctx context.Context, nul *struct{}, reply *R) error {
	out := R{msg: []byte("foo")}
	*reply = out
	return nil
}
54 |
// TestGobError registers S, whose Recv reply type R has no exported fields,
// and calls S.Recv over a real loopback TCP connection. The call is expected
// to fail with an error containing "reading body unexpected EOF". Every
// failure in the body is raised with panic(err), and the deferred recover
// checks the message.
func TestGobError(t *testing.T) {
	defer func() {
		err := recover()
		// No panic means the call unexpectedly succeeded.
		if err == nil {
			t.Fatal("no error")
		}
		// NOTE(review): assumes every panic value is an error; a non-error
		// panic would make this type assertion panic again.
		if !strings.Contains(err.(error).Error(), "reading body unexpected EOF") {
			t.Fatal("expected `reading body unexpected EOF', got", err)
		}
	}()
	Register(new(S))

	listen, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	go Accept(listen)

	client, err := Dial("tcp", listen.Addr().String())
	if err != nil {
		panic(err)
	}

	// The client decodes into a Reply; the body never decodes because the
	// server side cannot encode R.
	var reply Reply
	err = client.Call(context.Background(), "S.Recv", &struct{}{}, &reply)
	if err != nil {
		panic(err)
	}

	// Only reached if the call unexpectedly succeeds.
	// NOTE(review): on the expected (panicking) path the client and listener
	// below are never closed.
	fmt.Printf("%#v\n", reply)
	client.Close()

	listen.Close()
}
89 |
90 | type ClientCodecError struct {
91 | WriteRequestError error
92 | }
93 |
94 | func (c *ClientCodecError) WriteRequest(*Request, any) error {
95 | return c.WriteRequestError
96 | }
97 | func (c *ClientCodecError) ReadResponseHeader(*Response) error {
98 | return nil
99 | }
100 | func (c *ClientCodecError) ReadResponseBody(any) error {
101 | return nil
102 | }
103 | func (c *ClientCodecError) Close() error {
104 | return nil
105 | }
106 |
107 | func TestClientTrace(t *testing.T) {
108 | wantErr := errors.New("test")
109 | client := NewClientWithCodec(&ClientCodecError{WriteRequestError: wantErr})
110 | defer client.Close()
111 |
112 | startCalled := false
113 | var gotErr error
114 | ctx := WithClientTrace(context.Background(), &ClientTrace{
115 | WriteRequestStart: func() { startCalled = true },
116 | WriteRequestDone: func(err error) { gotErr = err },
117 | })
118 |
119 | var reply Reply
120 | err := client.Call(ctx, "S.Recv", &struct{}{}, &reply)
121 | if err != wantErr {
122 | t.Fatalf("expected Call to return the same error sent to ClientTrace.WriteRequestDone: want %v, got %v", wantErr, err)
123 | }
124 | if gotErr != wantErr {
125 | t.Fatalf("expected ClientTrace.WriteRequestDone to be called with error %v, got %v", wantErr, gotErr)
126 | }
127 | if !startCalled {
128 | t.Fatal("expected ClientTrace.WriteRequestStart to be called")
129 | }
130 | }
131 |
--------------------------------------------------------------------------------
/server_util_test.go:
--------------------------------------------------------------------------------
1 | // This code is all of https://golang.org/src/net/rpc/jsonrpc/server.go and some of
2 | // https://golang.org/src/net/rpc/jsonrpc/client.go (both adjusted to use the fork).
3 | //
4 | // Unfortunately but logically the net/rpc/jsonrpc uses net/rpc types which are
5 | // incompatible with this fork, so the code could not be used as-is.
6 | //
7 | // Copyright 2010 The Go Authors. All rights reserved.
8 | // Use of this source code is governed by a BSD-style
9 | // license that can be found in the LICENSE file.
10 | //
11 | package rpc
12 |
13 | import (
14 | "encoding/json"
15 | "errors"
16 | "io"
17 | "sync"
18 | )
19 |
20 | var errMissingParams = errors.New("jsonrpc: request body missing params")
21 |
22 | type jsonServerCodec struct {
23 | dec *json.Decoder // for reading JSON values
24 | enc *json.Encoder // for writing JSON values
25 | c io.Closer
26 |
27 | // temporary work space
28 | req jsonServerRequest
29 |
30 | // JSON-RPC clients can use arbitrary json values as request IDs.
31 | // Package rpc expects uint64 request IDs.
32 | // We assign uint64 sequence numbers to incoming requests
33 | // but save the original request ID in the pending map.
34 | // When rpc responds, we use the sequence number in
35 | // the response to find the original request ID.
36 | mutex sync.Mutex // protects seq, pending
37 | seq uint64
38 | pending map[uint64]*json.RawMessage
39 | }
40 |
41 | // NewServerCodec returns a new rpc.ServerCodec using JSON-RPC on conn.
42 | func NewJsonServerCodec(conn io.ReadWriteCloser) ServerCodec {
43 | return &jsonServerCodec{
44 | dec: json.NewDecoder(conn),
45 | enc: json.NewEncoder(conn),
46 | c: conn,
47 | pending: make(map[uint64]*json.RawMessage),
48 | }
49 | }
50 |
// jsonServerRequest is the JSON-RPC request object decoded by
// ReadRequestHeader. Params and Id are kept raw so they can be decoded or
// echoed back later.
type jsonServerRequest struct {
	Method string           `json:"method"`
	Params *json.RawMessage `json:"params"`
	Id     *json.RawMessage `json:"id"`
}

// reset clears every field so the struct can be reused for the next decode.
func (r *jsonServerRequest) reset() {
	*r = jsonServerRequest{}
}
62 |
// jsonServerResponse is the JSON-RPC response object encoded by
// WriteResponse. Field order determines key order in the encoded JSON.
type jsonServerResponse struct {
	Id     *json.RawMessage `json:"id"`     // the request's original id, or JSON null
	Result any              `json:"result"` // the reply; left nil (null) on error
	Error  any              `json:"error"`  // the error string; left nil (null) on success
}
68 |
69 | func (c *jsonServerCodec) ReadRequestHeader(r *Request) error {
70 | c.req.reset()
71 | if err := c.dec.Decode(&c.req); err != nil {
72 | return err
73 | }
74 | r.ServiceMethod = c.req.Method
75 |
76 | // JSON request id can be any JSON value;
77 | // RPC package expects uint64. Translate to
78 | // internal uint64 and save JSON on the side.
79 | c.mutex.Lock()
80 | c.seq++
81 | c.pending[c.seq] = c.req.Id
82 | c.req.Id = nil
83 | r.Seq = c.seq
84 | c.mutex.Unlock()
85 |
86 | return nil
87 | }
88 |
89 | func (c *jsonServerCodec) ReadRequestBody(x any) error {
90 | if x == nil {
91 | return nil
92 | }
93 | if c.req.Params == nil {
94 | return errMissingParams
95 | }
96 | // JSON params is array value.
97 | // RPC params is struct.
98 | // Unmarshal into array containing struct for now.
99 | // Should think about making RPC more general.
100 | var params [1]any
101 | params[0] = x
102 | return json.Unmarshal(*c.req.Params, ¶ms)
103 | }
104 |
105 | var null = json.RawMessage([]byte("null"))
106 |
107 | func (c *jsonServerCodec) WriteResponse(r *Response, x any) error {
108 | c.mutex.Lock()
109 | b, ok := c.pending[r.Seq]
110 | if !ok {
111 | c.mutex.Unlock()
112 | return errors.New("invalid sequence number in response")
113 | }
114 | delete(c.pending, r.Seq)
115 | c.mutex.Unlock()
116 |
117 | if b == nil {
118 | // Invalid request so no id. Use JSON null.
119 | b = &null
120 | }
121 | resp := jsonServerResponse{Id: b}
122 | if r.Error == "" {
123 | resp.Result = x
124 | } else {
125 | resp.Error = r.Error
126 | }
127 | return c.enc.Encode(resp)
128 | }
129 |
130 | func (c *jsonServerCodec) Close() error {
131 | return c.c.Close()
132 | }
133 |
// jsonClientRequest is a JSON-RPC request object in the client's wire shape.
// Params is a one-element array holding the call's args, which is what
// jsonServerCodec.ReadRequestBody unwraps. NOTE(review): no user of this
// type is visible here; presumably the test client code encodes it.
type jsonClientRequest struct {
	Method string `json:"method"`
	Params [1]any `json:"params"`
	Id     uint64 `json:"id"`
}
139 |
--------------------------------------------------------------------------------
/client.go:
--------------------------------------------------------------------------------
1 | // Copyright 2009 The Go Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package rpc
6 |
7 | import (
8 | "bufio"
9 | "context"
10 | "encoding/gob"
11 | "errors"
12 | "io"
13 | "log"
14 | "net"
15 | "net/http"
16 | "sync"
17 | "time"
18 |
19 | "github.com/keegancsmith/rpc/internal/svc"
20 | )
21 |
// ServerError represents an error that has been returned from
// the remote side of the RPC connection.
type ServerError string

// Error returns the server's message unchanged.
func (e ServerError) Error() string {
	msg := string(e)
	return msg
}

// ErrShutdown is returned for calls made on a Client that has been closed or
// whose read loop has stopped, and by a second call to Client.Close.
var ErrShutdown = errors.New("connection is shut down")
31 |
// Call represents an active RPC.
type Call struct {
	ServiceMethod string     // The name of the service and method to call.
	Args          any        // The argument to the function (*struct).
	Reply         any        // The reply from the function (*struct).
	Error         error      // After completion, the error status.
	Done          chan *Call // Receives *Call when Go is complete.
	// seq is read and written under Client.mutex. Client.Call sets it to 1
	// when canceling a call that was never sent, so send skips the call.
	seq uint64 // Sequence num used to send. Non-zero when sent.
}
41 |
// ClientTrace is a set of hooks to run at various stages of an outgoing RPC
// request. Any particular hook may be nil. Functions may be called
// concurrently from different goroutines.
//
// ClientTrace currently traces a single RPC request, not the response.
type ClientTrace struct {
	// WriteRequestStart is called when start WriteRequest. Concurrent calls to
	// Client.Go or Client.Call can cause a queue to form, leading to a delay
	// between calls to Client.Go/Client.Call and WriteRequestStart being
	// called.
	WriteRequestStart func()

	// WriteRequestDone is called once WriteRequest returns with the error
	// returned from WriteRequest.
	WriteRequestDone func(err error)
}

// clientTraceContextKey is the context key under which a *ClientTrace is
// stored. Being an unexported type, no other package can produce the same key.
type clientTraceContextKey struct{}

// WithClientTrace returns a new context based on the provided parent
// ctx. Requests made with the returned context will use the provided trace
// hooks. Previous hooks registered with ctx are ignored.
func WithClientTrace(ctx context.Context, trace *ClientTrace) context.Context {
	key := clientTraceContextKey{}
	return context.WithValue(ctx, key, trace)
}

// contextClientTrace returns the trace stored by WithClientTrace, or nil if
// ctx carries none.
func contextClientTrace(ctx context.Context) *ClientTrace {
	if trace, ok := ctx.Value(clientTraceContextKey{}).(*ClientTrace); ok {
		return trace
	}
	return nil
}
73 |
// Client represents an RPC Client.
// There may be multiple outstanding Calls associated
// with a single Client, and a Client may be used by
// multiple goroutines simultaneously.
type Client struct {
	codec ClientCodec

	// reqMutex serializes request writes: send holds it for its whole
	// duration, and input takes it before failing the remaining calls.
	reqMutex sync.Mutex // protects following
	request  Request

	mutex    sync.Mutex // protects following
	seq      uint64
	pending  map[uint64]*Call
	closing  bool // user has called Close
	shutdown bool // server has told us to stop (set when input's read loop ends for any error)
}
90 |
// A ClientCodec implements writing of RPC requests and
// reading of RPC responses for the client side of an RPC session.
// The client calls WriteRequest to write a request to the connection
// and calls ReadResponseHeader and ReadResponseBody in pairs
// to read responses. The client calls Close when finished with the
// connection. ReadResponseBody may be called with a nil
// argument to force the body of the response to be read and then
// discarded.
// See NewClient's comment for information about concurrent access.
type ClientCodec interface {
	// WriteRequest writes one request header and its body. Client calls it
	// with reqMutex held, so calls from a single Client never overlap.
	WriteRequest(*Request, any) error
	// ReadResponseHeader reads the next response header. Client calls it only
	// from its input goroutine.
	ReadResponseHeader(*Response) error
	// ReadResponseBody reads the body belonging to the header just read.
	ReadResponseBody(any) error

	// Close closes the connection; Client.Close calls it.
	Close() error
}
107 |
// send registers call in client.pending under a fresh sequence number and
// writes the request through the codec. It never blocks on call.Done. Any
// failure (client shut down, call already canceled, write error) is reported
// by setting call.Error and calling call.done().
//
// reqMutex is held for the whole function, so requests are written one at a
// time and client.request can be reused for every write.
func (client *Client) send(ctx context.Context, call *Call) {
	trace := contextClientTrace(ctx)

	client.reqMutex.Lock()
	defer client.reqMutex.Unlock()

	// Register this call.
	client.mutex.Lock()
	if client.shutdown || client.closing {
		client.mutex.Unlock()
		call.Error = ErrShutdown
		call.done()
		return
	}
	if call.seq != 0 {
		// It has already been canceled, don't bother sending
		call.Error = context.Canceled
		client.mutex.Unlock()
		call.done()
		return
	}
	client.seq++
	seq := client.seq
	call.seq = seq
	client.pending[seq] = call
	client.mutex.Unlock()

	if trace != nil && trace.WriteRequestStart != nil {
		trace.WriteRequestStart()
	}

	// Encode and send the request.
	client.request.Seq = seq
	client.request.ServiceMethod = call.ServiceMethod
	err := client.codec.WriteRequest(&client.request, call.Args)
	if trace != nil && trace.WriteRequestDone != nil {
		trace.WriteRequestDone(err)
	}
	if err != nil {
		// The write failed. Fail the call here, unless it has already been
		// removed from pending (by input or by Client.Call's cancellation),
		// in which case someone else owns its completion.
		client.mutex.Lock()
		call = client.pending[seq]
		delete(client.pending, seq)
		client.mutex.Unlock()
		if call != nil {
			call.Error = err
			call.done()
		}
	}
}
157 |
// input is the client's read loop, started by NewClientWithCodec. For each
// response it reads the header, looks up and removes the matching pending
// Call by sequence number, reads the body into that Call's Reply and
// signals completion. When reading fails it marks the client shut down and
// fails every remaining pending call with the error. io.EOF is reported as
// ErrShutdown if the user closed the client, and as io.ErrUnexpectedEOF
// otherwise.
func (client *Client) input() {
	var err error
	var response Response
	for err == nil {
		response = Response{}
		err = client.codec.ReadResponseHeader(&response)
		if err != nil {
			break
		}
		seq := response.Seq
		client.mutex.Lock()
		call := client.pending[seq]
		delete(client.pending, seq)
		client.mutex.Unlock()

		switch {
		case call == nil:
			// We've got no pending call. That usually means that
			// WriteRequest partially failed, and call was already
			// removed; response is a server telling us about an
			// error reading request body. We should still attempt
			// to read error body, but there's no one to give it to.
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
		case response.Error != "":
			// We've got an error response. Give this to the request;
			// any subsequent requests will get the ReadResponseBody
			// error if there is one.
			call.Error = ServerError(response.Error)
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
			call.done()
		default:
			// A body decode error is reported to this call and also ends the loop.
			err = client.codec.ReadResponseBody(call.Reply)
			if err != nil {
				call.Error = errors.New("reading body " + err.Error())
			}
			call.done()
		}
	}
	// Terminate pending calls.
	// reqMutex is taken as well so that no send is part-way through a write
	// while the remaining calls are failed.
	client.reqMutex.Lock()
	client.mutex.Lock()
	client.shutdown = true
	closing := client.closing
	if err == io.EOF {
		if closing {
			err = ErrShutdown
		} else {
			err = io.ErrUnexpectedEOF
		}
	}
	for _, call := range client.pending {
		call.Error = err
		call.done()
	}
	client.mutex.Unlock()
	client.reqMutex.Unlock()
	// NOTE(review): err was already mapped away from io.EOF above, so the
	// err != io.EOF test below is always true.
	if debugLog && err != io.EOF && !closing {
		log.Println("rpc: client protocol error:", err)
	}
}
224 |
225 | func (call *Call) done() {
226 | select {
227 | case call.Done <- call:
228 | // ok
229 | default:
230 | // We don't want to block here. It is the caller's responsibility to make
231 | // sure the channel has enough buffer space. See comment in Go().
232 | if debugLog {
233 | log.Println("rpc: discarding Call reply due to insufficient Done chan capacity")
234 | }
235 | }
236 | }
237 |
238 | // NewClient returns a new Client to handle requests to the
239 | // set of services at the other end of the connection.
240 | // It adds a buffer to the write side of the connection so
241 | // the header and payload are sent as a unit.
242 | //
243 | // The read and write halves of the connection are serialized independently,
244 | // so no interlocking is required. However each half may be accessed
245 | // concurrently so the implementation of conn should protect against
246 | // concurrent reads or concurrent writes.
247 | func NewClient(conn io.ReadWriteCloser) *Client {
248 | encBuf := bufio.NewWriter(conn)
249 | client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
250 | return NewClientWithCodec(client)
251 | }
252 |
253 | // NewClientWithCodec is like NewClient but uses the specified
254 | // codec to encode requests and decode responses.
255 | func NewClientWithCodec(codec ClientCodec) *Client {
256 | client := &Client{
257 | codec: codec,
258 | pending: make(map[uint64]*Call),
259 | }
260 | go client.input()
261 | return client
262 | }
263 |
264 | type gobClientCodec struct {
265 | rwc io.ReadWriteCloser
266 | dec *gob.Decoder
267 | enc *gob.Encoder
268 | encBuf *bufio.Writer
269 | }
270 |
271 | func (c *gobClientCodec) WriteRequest(r *Request, body any) (err error) {
272 | if err = c.enc.Encode(r); err != nil {
273 | return
274 | }
275 | if err = c.enc.Encode(body); err != nil {
276 | return
277 | }
278 | return c.encBuf.Flush()
279 | }
280 |
281 | func (c *gobClientCodec) ReadResponseHeader(r *Response) error {
282 | return c.dec.Decode(r)
283 | }
284 |
285 | func (c *gobClientCodec) ReadResponseBody(body any) error {
286 | return c.dec.Decode(body)
287 | }
288 |
289 | func (c *gobClientCodec) Close() error {
290 | return c.rwc.Close()
291 | }
292 |
293 | // DialHTTP connects to an HTTP RPC server at the specified network address
294 | // listening on the default HTTP RPC path.
295 | func DialHTTP(network, address string) (*Client, error) {
296 | return DialHTTPPath(network, address, DefaultRPCPath)
297 | }
298 |
299 | // DialHTTPPath connects to an HTTP RPC server
300 | // at the specified network address and path with a default timeout.
301 | func DialHTTPPath(network, address, path string) (*Client, error) {
302 | return DialHTTPPathTimeout(network, address, path, 0)
303 | }
304 |
305 | // DialHTTPPathTimeout connects to an HTTP RPC server
306 | // at the specified network address and path with the specified timeout for Dialing.
307 | //
308 | // This is a function added by github.com/keegancsmith/rpc
309 | func DialHTTPPathTimeout(network, address, path string, timeout time.Duration) (*Client, error) {
310 | conn, err := net.DialTimeout(network, address, timeout)
311 | if err != nil {
312 | return nil, err
313 | }
314 | io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")
315 |
316 | // Require successful HTTP response
317 | // before switching to RPC protocol.
318 | resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
319 | if err == nil && resp.Status == connected {
320 | return NewClient(conn), nil
321 | }
322 | if err == nil {
323 | err = errors.New("unexpected HTTP response: " + resp.Status)
324 | }
325 | conn.Close()
326 | return nil, &net.OpError{
327 | Op: "dial-http",
328 | Net: network + " " + address,
329 | Addr: nil,
330 | Err: err,
331 | }
332 | }
333 |
334 | // Dial connects to an RPC server at the specified network address.
335 | func Dial(network, address string) (*Client, error) {
336 | conn, err := net.Dial(network, address)
337 | if err != nil {
338 | return nil, err
339 | }
340 | return NewClient(conn), nil
341 | }
342 |
343 | // Close calls the underlying codec's Close method. If the connection is already
344 | // shutting down, ErrShutdown is returned.
345 | func (client *Client) Close() error {
346 | client.mutex.Lock()
347 | if client.closing {
348 | client.mutex.Unlock()
349 | return ErrShutdown
350 | }
351 | client.closing = true
352 | client.mutex.Unlock()
353 | return client.codec.Close()
354 | }
355 |
356 | // Go calls client.GoContext with a background context. See GoContext docstring.
357 | func (client *Client) Go(serviceMethod string, args any, reply any, done chan *Call) *Call {
358 | return client.GoContext(context.Background(), serviceMethod, args, reply, done)
359 | }
360 |
361 | // GoContext invokes the function asynchronously. It returns the Call structure representing
362 | // the invocation. The done channel will signal when the call is complete by returning
363 | // the same Call object. If done is nil, Go will allocate a new channel.
364 | // If non-nil, done must be buffered or Go will deliberately crash.
365 | func (client *Client) GoContext(ctx context.Context, serviceMethod string, args any, reply any, done chan *Call) *Call {
366 | call := new(Call)
367 | call.ServiceMethod = serviceMethod
368 | call.Args = args
369 | call.Reply = reply
370 | if done == nil {
371 | done = make(chan *Call, 10) // buffered.
372 | } else {
373 | // If caller passes done != nil, it must arrange that
374 | // done has enough buffer for the number of simultaneous
375 | // RPCs that will be using that channel. If the channel
376 | // is totally unbuffered, it's best not to run at all.
377 | if cap(done) == 0 {
378 | log.Panic("rpc: done channel is unbuffered")
379 | }
380 | }
381 | call.Done = done
382 | client.send(ctx, call)
383 | return call
384 | }
385 |
// Call invokes the named function, waits for it to complete, and returns its error status.
//
// If ctx is done first, Call stops waiting and returns ctx.Err(). The call
// is removed from the client's pending set, so any late reply is read and
// discarded. If the request had been sent and was still pending, a
// best-effort "_goRPC_.Cancel" request is also sent so the server can cancel
// the in-flight call.
func (client *Client) Call(ctx context.Context, serviceMethod string, args any, reply any) error {
	ch := make(chan *Call, 2) // 2 for this call and cancel
	call := client.GoContext(ctx, serviceMethod, args, reply, ch)
	select {
	case <-call.Done:
		return call.Error
	case <-ctx.Done():
		// Cancel the pending request on the client
		client.mutex.Lock()
		seq := call.seq
		// ok is false if the response already arrived or send's write
		// failure path already removed the call.
		_, ok := client.pending[seq]
		delete(client.pending, seq)
		if seq == 0 {
			// hasn't been sent yet, non-zero will prevent send
			call.seq = 1
		}
		client.mutex.Unlock()

		// Cancel running request on the server
		// The cancel call reuses ch, whose spare buffer slot means its
		// completion is never dropped by done(); its reply is discarded.
		if seq != 0 && ok {
			client.Go("_goRPC_.Cancel", &svc.CancelArgs{Seq: seq}, nil, ch)
		}

		return ctx.Err()
	}
}
413 |
--------------------------------------------------------------------------------
/server_test.go:
--------------------------------------------------------------------------------
1 | // Copyright 2009 The Go Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | package rpc
6 |
7 | import (
8 | "bytes"
9 | "context"
10 | "encoding/json"
11 | "errors"
12 | "fmt"
13 | "io"
14 | "log"
15 | "net"
16 | "net/http"
17 | "net/http/httptest"
18 | "net/url"
19 | "reflect"
20 | "runtime"
21 | "strings"
22 | "sync"
23 | "sync/atomic"
24 | "testing"
25 | "time"
26 | )
27 |
28 | var (
29 | newServer *Server
30 | serverAddr, newServerAddr string
31 | httpServerAddr string
32 | once, newOnce, httpOnce sync.Once
33 | )
34 |
35 | const (
36 | newHttpPath = "/foo"
37 | )
38 |
// Args holds the two integer operands of an Arith request.
type Args struct {
	A, B int
}

// Reply holds the single integer result of an Arith request.
type Reply struct {
	C int
}

// Arith is the arithmetic service exercised by the tests.
type Arith int

// Some of Arith's methods have value args, some have pointer args. That's deliberate.

// Add stores A+B in reply.
func (a *Arith) Add(ctx context.Context, args Args, reply *Reply) error {
	*reply = Reply{C: args.A + args.B}
	return nil
}

// Mul stores A*B in reply.
func (a *Arith) Mul(ctx context.Context, args *Args, reply *Reply) error {
	product := args.A * args.B
	reply.C = product
	return nil
}

// Div stores A/B in reply, or fails when B is zero.
func (a *Arith) Div(ctx context.Context, args Args, reply *Reply) error {
	if args.B != 0 {
		reply.C = args.A / args.B
		return nil
	}
	return errors.New("divide by zero")
}

// String formats the sum of A and B as "A+B=sum".
func (a *Arith) String(ctx context.Context, args *Args, reply *string) error {
	sum := args.A + args.B
	*reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, sum)
	return nil
}

// Scan parses a non-struct string argument into reply.C.
func (a *Arith) Scan(ctx context.Context, args string, reply *Reply) (err error) {
	_, err = fmt.Sscan(args, &reply.C)
	return err
}

// Error always panics, to exercise the server's handling of a panicking method.
func (a *Arith) Error(ctx context.Context, args *Args, reply *Reply) error {
	panic("ERROR")
}

// SleepMilli sleeps for args.A milliseconds.
func (a *Arith) SleepMilli(ctx context.Context, args *Args, reply *Reply) error {
	time.Sleep(time.Millisecond * time.Duration(args.A))
	return nil
}
87 |
88 | type hidden int
89 |
90 | func (t *hidden) Exported(ctx context.Context, args Args, reply *Reply) error {
91 | reply.C = args.A + args.B
92 | return nil
93 | }
94 |
95 | type Embed struct {
96 | hidden
97 | }
98 |
99 | type BuiltinTypes struct{}
100 |
101 | func (BuiltinTypes) Map(ctx context.Context, args *Args, reply *map[int]int) error {
102 | (*reply)[args.A] = args.B
103 | return nil
104 | }
105 |
106 | func (BuiltinTypes) Slice(ctx context.Context, args *Args, reply *[]int) error {
107 | *reply = append(*reply, args.A, args.B)
108 | return nil
109 | }
110 |
111 | func (BuiltinTypes) Array(ctx context.Context, args *Args, reply *[2]int) error {
112 | (*reply)[0] = args.A
113 | (*reply)[1] = args.B
114 | return nil
115 | }
116 |
// listenTCP listens on an ephemeral loopback TCP port and returns the
// listener with its address. A failure is fatal (log.Fatalf).
func listenTCP() (net.Listener, string) {
	listener, err := net.Listen("tcp", "127.0.0.1:0") // any available address
	if err != nil {
		log.Fatalf("net.Listen tcp :0: %v", err)
	}
	addr := listener.Addr().String()
	return listener, addr
}
124 |
125 | func startServer() {
126 | Register(new(Arith))
127 | Register(new(Embed))
128 | RegisterName("net.rpc.Arith", new(Arith))
129 | Register(BuiltinTypes{})
130 |
131 | var l net.Listener
132 | l, serverAddr = listenTCP()
133 | log.Println("Test RPC server listening on", serverAddr)
134 | go Accept(l)
135 |
136 | HandleHTTP()
137 | httpOnce.Do(startHttpServer)
138 | }
139 |
140 | func startNewServer() {
141 | newServer = NewServer()
142 | newServer.Register(new(Arith))
143 | newServer.Register(new(Embed))
144 | newServer.RegisterName("net.rpc.Arith", new(Arith))
145 | newServer.RegisterName("newServer.Arith", new(Arith))
146 |
147 | var l net.Listener
148 | l, newServerAddr = listenTCP()
149 | log.Println("NewServer test RPC server listening on", newServerAddr)
150 | go newServer.Accept(l)
151 |
152 | newServer.HandleHTTP(newHttpPath, "/bar")
153 | httpOnce.Do(startHttpServer)
154 | }
155 |
156 | func startHttpServer() {
157 | server := httptest.NewServer(nil)
158 | httpServerAddr = server.Listener.Addr().String()
159 | log.Println("Test HTTP RPC server listening on", httpServerAddr)
160 | }
161 |
162 | func TestRPC(t *testing.T) {
163 | once.Do(startServer)
164 | testRPC(t, serverAddr)
165 | newOnce.Do(startNewServer)
166 | testRPC(t, newServerAddr)
167 | testNewServerRPC(t, newServerAddr)
168 | }
169 |
170 | func testRPC(t *testing.T, addr string) {
171 | client, err := Dial("tcp", addr)
172 | if err != nil {
173 | t.Fatal("dialing", err)
174 | }
175 | defer client.Close()
176 |
177 | // Synchronous calls
178 | ctx := context.Background()
179 | args := &Args{7, 8}
180 | reply := new(Reply)
181 | err = client.Call(ctx, "Arith.Add", args, reply)
182 | if err != nil {
183 | t.Errorf("Add: expected no error but got string %q", err.Error())
184 | }
185 | if reply.C != args.A+args.B {
186 | t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
187 | }
188 |
189 | // Methods exported from unexported embedded structs
190 | args = &Args{7, 0}
191 | reply = new(Reply)
192 | err = client.Call(ctx, "Embed.Exported", args, reply)
193 | if err != nil {
194 | t.Errorf("Add: expected no error but got string %q", err.Error())
195 | }
196 | if reply.C != args.A+args.B {
197 | t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
198 | }
199 |
200 | // Nonexistent method
201 | args = &Args{7, 0}
202 | reply = new(Reply)
203 | err = client.Call(ctx, "Arith.BadOperation", args, reply)
204 | // expect an error
205 | if err == nil {
206 | t.Error("BadOperation: expected error")
207 | } else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") {
208 | t.Errorf("BadOperation: expected can't find method error; got %q", err)
209 | }
210 |
211 | // Unknown service
212 | args = &Args{7, 8}
213 | reply = new(Reply)
214 | err = client.Call(ctx, "Arith.Unknown", args, reply)
215 | if err == nil {
216 | t.Error("expected error calling unknown service")
217 | } else if !strings.Contains(err.Error(), "method") {
218 | t.Error("expected error about method; got", err)
219 | }
220 |
221 | // Out of order.
222 | args = &Args{7, 8}
223 | mulReply := new(Reply)
224 | mulCall := client.Go("Arith.Mul", args, mulReply, nil)
225 | addReply := new(Reply)
226 | addCall := client.Go("Arith.Add", args, addReply, nil)
227 |
228 | addCall = <-addCall.Done
229 | if addCall.Error != nil {
230 | t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
231 | }
232 | if addReply.C != args.A+args.B {
233 | t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
234 | }
235 |
236 | mulCall = <-mulCall.Done
237 | if mulCall.Error != nil {
238 | t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
239 | }
240 | if mulReply.C != args.A*args.B {
241 | t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
242 | }
243 |
244 | // Error test
245 | args = &Args{7, 0}
246 | reply = new(Reply)
247 | err = client.Call(ctx, "Arith.Div", args, reply)
248 | // expect an error: zero divide
249 | if err == nil {
250 | t.Error("Div: expected error")
251 | } else if err.Error() != "divide by zero" {
252 | t.Error("Div: expected divide by zero error; got", err)
253 | }
254 |
255 | // Bad type.
256 | reply = new(Reply)
257 | err = client.Call(ctx, "Arith.Add", reply, reply) // args, reply would be the correct thing to use
258 | if err == nil {
259 | t.Error("expected error calling Arith.Add with wrong arg type")
260 | } else if !strings.Contains(err.Error(), "type") {
261 | t.Error("expected error about type; got", err)
262 | }
263 |
264 | // Non-struct argument
265 | const Val = 12345
266 | str := fmt.Sprint(Val)
267 | reply = new(Reply)
268 | err = client.Call(ctx, "Arith.Scan", &str, reply)
269 | if err != nil {
270 | t.Errorf("Scan: expected no error but got string %q", err.Error())
271 | } else if reply.C != Val {
272 | t.Errorf("Scan: expected %d got %d", Val, reply.C)
273 | }
274 |
275 | // Non-struct reply
276 | args = &Args{27, 35}
277 | str = ""
278 | err = client.Call(ctx, "Arith.String", args, &str)
279 | if err != nil {
280 | t.Errorf("String: expected no error but got string %q", err.Error())
281 | }
282 | expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
283 | if str != expect {
284 | t.Errorf("String: expected %s got %s", expect, str)
285 | }
286 |
287 | args = &Args{7, 8}
288 | reply = new(Reply)
289 | err = client.Call(ctx, "Arith.Mul", args, reply)
290 | if err != nil {
291 | t.Errorf("Mul: expected no error but got string %q", err.Error())
292 | }
293 | if reply.C != args.A*args.B {
294 | t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
295 | }
296 |
297 | // ServiceName contain "." character
298 | args = &Args{7, 8}
299 | reply = new(Reply)
300 | err = client.Call(ctx, "net.rpc.Arith.Add", args, reply)
301 | if err != nil {
302 | t.Errorf("Add: expected no error but got string %q", err.Error())
303 | }
304 | if reply.C != args.A+args.B {
305 | t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
306 | }
307 | }
308 |
309 | func testNewServerRPC(t *testing.T, addr string) {
310 | client, err := Dial("tcp", addr)
311 | if err != nil {
312 | t.Fatal("dialing", err)
313 | }
314 | defer client.Close()
315 |
316 | ctx := context.Background()
317 |
318 | // Synchronous calls
319 | args := &Args{7, 8}
320 | reply := new(Reply)
321 | err = client.Call(ctx, "newServer.Arith.Add", args, reply)
322 | if err != nil {
323 | t.Errorf("Add: expected no error but got string %q", err.Error())
324 | }
325 | if reply.C != args.A+args.B {
326 | t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
327 | }
328 | }
329 |
330 | func TestHTTP(t *testing.T) {
331 | once.Do(startServer)
332 | testHTTPRPC(t, "")
333 | newOnce.Do(startNewServer)
334 | testHTTPRPC(t, newHttpPath)
335 | }
336 |
337 | func testHTTPRPC(t *testing.T, path string) {
338 | var client *Client
339 | var err error
340 | if path == "" {
341 | client, err = DialHTTP("tcp", httpServerAddr)
342 | } else {
343 | client, err = DialHTTPPath("tcp", httpServerAddr, path)
344 | }
345 | if err != nil {
346 | t.Fatal("dialing", err)
347 | }
348 | defer client.Close()
349 |
350 | ctx := context.Background()
351 |
352 | // Synchronous calls
353 | args := &Args{7, 8}
354 | reply := new(Reply)
355 | err = client.Call(ctx, "Arith.Add", args, reply)
356 | if err != nil {
357 | t.Errorf("Add: expected no error but got string %q", err.Error())
358 | }
359 | if reply.C != args.A+args.B {
360 | t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
361 | }
362 | }
363 |
364 | func TestBuiltinTypes(t *testing.T) {
365 | once.Do(startServer)
366 |
367 | client, err := DialHTTP("tcp", httpServerAddr)
368 | if err != nil {
369 | t.Fatal("dialing", err)
370 | }
371 | defer client.Close()
372 |
373 | ctx := context.Background()
374 |
375 | // Map
376 | args := &Args{7, 8}
377 | replyMap := map[int]int{}
378 | err = client.Call(ctx, "BuiltinTypes.Map", args, &replyMap)
379 | if err != nil {
380 | t.Errorf("Map: expected no error but got string %q", err.Error())
381 | }
382 | if replyMap[args.A] != args.B {
383 | t.Errorf("Map: expected %d got %d", args.B, replyMap[args.A])
384 | }
385 |
386 | // Slice
387 | args = &Args{7, 8}
388 | replySlice := []int{}
389 | err = client.Call(ctx, "BuiltinTypes.Slice", args, &replySlice)
390 | if err != nil {
391 | t.Errorf("Slice: expected no error but got string %q", err.Error())
392 | }
393 | if e := []int{args.A, args.B}; !reflect.DeepEqual(replySlice, e) {
394 | t.Errorf("Slice: expected %v got %v", e, replySlice)
395 | }
396 |
397 | // Array
398 | args = &Args{7, 8}
399 | replyArray := [2]int{}
400 | err = client.Call(ctx, "BuiltinTypes.Array", args, &replyArray)
401 | if err != nil {
402 | t.Errorf("Array: expected no error but got string %q", err.Error())
403 | }
404 | if e := [2]int{args.A, args.B}; !reflect.DeepEqual(replyArray, e) {
405 | t.Errorf("Array: expected %v got %v", e, replyArray)
406 | }
407 | }
408 |
// Context is a test service whose Wait method lets a test observe when a
// request starts and when its context has been cancelled.
type Context struct {
	started chan struct{} // closed when Wait begins
	done    chan struct{} // closed after Wait sees ctx.Done
}

// Wait closes started, blocks until ctx is done, then closes done.
func (c *Context) Wait(ctx context.Context, s string, reply *int) error {
	close(c.started)
	<-ctx.Done()
	close(c.done)
	return nil
}
420 |
421 | func TestContext(t *testing.T) {
422 | svc := &Context{
423 | started: make(chan struct{}),
424 | done: make(chan struct{}),
425 | }
426 |
427 | handler := NewServer()
428 | handler.Register(svc)
429 | ts := httptest.NewServer(handler)
430 | defer ts.Close()
431 | u, err := url.Parse(ts.URL)
432 | if err != nil {
433 | t.Fatal(err)
434 | }
435 | cl, err := DialHTTP("tcp", u.Host)
436 | if err != nil {
437 | t.Fatal(err)
438 | }
439 | defer cl.Close()
440 |
441 | wait := func(desc string, c chan struct{}) {
442 | select {
443 | case <-c:
444 | return
445 | case <-time.After(5 * time.Second):
446 | t.Fatal("Failed to wait for", desc)
447 | }
448 | }
449 |
450 | ctx, cancel := context.WithCancel(context.Background())
451 | done := make(chan struct{})
452 | go func() {
453 | args := ""
454 | err = cl.Call(ctx, "Context.Wait", &args, new(int))
455 | close(done)
456 | }()
457 | wait("server side to be called", svc.started)
458 | cancel()
459 | wait("client side to be done after cancel", done)
460 | wait("server side to be done after cancel", svc.done)
461 | if err != context.Canceled {
462 | t.Fatalf("expected to fail due to context cancellation: %v", err)
463 | }
464 | }
465 |
466 | type JsonServer struct {
467 | srv *Server
468 | }
469 |
470 | func (server *JsonServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
471 | rwc := &HTTPReadWriteCloser{
472 | In: req.Body,
473 | Out: w,
474 | }
475 | codec := NewJsonServerCodec(rwc)
476 | server.srv.ServeRequestContext(req.Context(), codec)
477 | }
478 |
479 | func TestContextCodec(t *testing.T) {
480 | svc := &Context{
481 | started: make(chan struct{}),
482 | done: make(chan struct{}),
483 | }
484 |
485 | srv := NewServer()
486 | srv.Register(svc)
487 | handler := &JsonServer{srv}
488 | ts := httptest.NewServer(handler)
489 | defer ts.Close()
490 |
491 | wait := func(desc string, c chan struct{}) {
492 | t.Helper()
493 | select {
494 | case <-c:
495 | return
496 | case <-time.After(5 * time.Second):
497 | t.Fatal("Failed to wait for", desc)
498 | }
499 | }
500 |
501 | b, err := json.Marshal(&jsonClientRequest{
502 | Method: "Context.Wait",
503 | Params: [1]any{""},
504 | Id: 1234,
505 | })
506 | if err != nil {
507 | t.Fatal(err)
508 | }
509 |
510 | ctx, cancel := context.WithCancel(context.Background())
511 | done := make(chan struct{})
512 | go func() {
513 | defer close(done)
514 | var req *http.Request
515 | req, err = http.NewRequestWithContext(ctx, http.MethodPost, ts.URL, bytes.NewBuffer(b))
516 | if err != nil {
517 | return
518 | }
519 | req.Header.Set("Content-Type", "application/json")
520 | _, err = http.DefaultClient.Do(req)
521 | }()
522 |
523 | wait("server side to be called", svc.started)
524 | cancel()
525 | wait("client side to be done after cancel", done)
526 | wait("server side to be done after cancel", svc.done)
527 | if !errors.Is(err, context.Canceled) {
528 | t.Fatalf("expected to fail due to context cancellation: %v", err)
529 | }
530 | }
531 |
532 | // CodecEmulator provides a client-like api and a ServerCodec interface.
533 | // Can be used to test ServeRequest.
534 | type CodecEmulator struct {
535 | server *Server
536 | serviceMethod string
537 | args *Args
538 | reply *Reply
539 | err error
540 | }
541 |
542 | func (codec *CodecEmulator) Call(ctx context.Context, serviceMethod string, args *Args, reply *Reply) error {
543 | codec.serviceMethod = serviceMethod
544 | codec.args = args
545 | codec.reply = reply
546 | codec.err = nil
547 | var serverError error
548 | if codec.server == nil {
549 | serverError = ServeRequest(codec)
550 | } else {
551 | serverError = codec.server.ServeRequest(codec)
552 | }
553 | if codec.err == nil && serverError != nil {
554 | codec.err = serverError
555 | }
556 | return codec.err
557 | }
558 |
559 | func (codec *CodecEmulator) ReadRequestHeader(req *Request) error {
560 | req.ServiceMethod = codec.serviceMethod
561 | req.Seq = 0
562 | return nil
563 | }
564 |
565 | func (codec *CodecEmulator) ReadRequestBody(argv any) error {
566 | if codec.args == nil {
567 | return io.ErrUnexpectedEOF
568 | }
569 | *(argv.(*Args)) = *codec.args
570 | return nil
571 | }
572 |
573 | func (codec *CodecEmulator) WriteResponse(resp *Response, reply any) error {
574 | if resp.Error != "" {
575 | codec.err = errors.New(resp.Error)
576 | } else {
577 | *codec.reply = *(reply.(*Reply))
578 | }
579 | return nil
580 | }
581 |
582 | func (codec *CodecEmulator) Close() error {
583 | return nil
584 | }
585 |
586 | func TestServeRequest(t *testing.T) {
587 | once.Do(startServer)
588 | testServeRequest(t, nil)
589 | newOnce.Do(startNewServer)
590 | testServeRequest(t, newServer)
591 | }
592 |
593 | func testServeRequest(t *testing.T, server *Server) {
594 | client := CodecEmulator{server: server}
595 | defer client.Close()
596 |
597 | ctx := context.Background()
598 | args := &Args{7, 8}
599 | reply := new(Reply)
600 | err := client.Call(ctx, "Arith.Add", args, reply)
601 | if err != nil {
602 | t.Errorf("Add: expected no error but got string %q", err.Error())
603 | }
604 | if reply.C != args.A+args.B {
605 | t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
606 | }
607 |
608 | err = client.Call(ctx, "Arith.Add", nil, reply)
609 | if err == nil {
610 | t.Errorf("expected error calling Arith.Add with nil arg")
611 | }
612 | }
613 |
614 | type NeedsCtx int
615 | type ReplyNotPointer int
616 | type ArgNotPublic int
617 | type ReplyNotPublic int
618 | type NeedsPtrType int
619 | type local struct{}
620 |
621 | func (t *NeedsCtx) NeedsCtx(args *Args, reply *Reply) error {
622 | return nil
623 | }
624 |
625 | func (t *ReplyNotPointer) ReplyNotPointer(ctx context.Context, args *Args, reply Reply) error {
626 | return nil
627 | }
628 |
629 | func (t *ArgNotPublic) ArgNotPublic(ctx context.Context, args *local, reply *Reply) error {
630 | return nil
631 | }
632 |
633 | func (t *ReplyNotPublic) ReplyNotPublic(ctx context.Context, args *Args, reply *local) error {
634 | return nil
635 | }
636 |
637 | func (t *NeedsPtrType) NeedsPtrType(ctx context.Context, args *Args, reply *Reply) error {
638 | return nil
639 | }
640 |
641 | // Check that registration handles lots of bad methods and a type with no suitable methods.
642 | func TestRegistrationError(t *testing.T) {
643 | err := Register(new(NeedsCtx))
644 | if err == nil {
645 | t.Error("expected error registering NeedsCtx")
646 | }
647 | err = Register(new(ReplyNotPointer))
648 | if err == nil {
649 | t.Error("expected error registering ReplyNotPointer")
650 | }
651 | err = Register(new(ArgNotPublic))
652 | if err == nil {
653 | t.Error("expected error registering ArgNotPublic")
654 | }
655 | err = Register(new(ReplyNotPublic))
656 | if err == nil {
657 | t.Error("expected error registering ReplyNotPublic")
658 | }
659 | err = Register(NeedsPtrType(0))
660 | if err == nil {
661 | t.Error("expected error registering NeedsPtrType")
662 | } else if !strings.Contains(err.Error(), "pointer") {
663 | t.Error("expected hint when registering NeedsPtrType")
664 | }
665 | }
666 |
667 | type WriteFailCodec int
668 |
669 | func (WriteFailCodec) WriteRequest(*Request, any) error {
670 | // the panic caused by this error used to not unlock a lock.
671 | return errors.New("fail")
672 | }
673 |
674 | func (WriteFailCodec) ReadResponseHeader(*Response) error {
675 | select {}
676 | }
677 |
678 | func (WriteFailCodec) ReadResponseBody(any) error {
679 | select {}
680 | }
681 |
682 | func (WriteFailCodec) Close() error {
683 | return nil
684 | }
685 |
686 | func TestSendDeadlock(t *testing.T) {
687 | client := NewClientWithCodec(WriteFailCodec(0))
688 | defer client.Close()
689 |
690 | done := make(chan bool)
691 | go func() {
692 | testSendDeadlock(client)
693 | testSendDeadlock(client)
694 | done <- true
695 | }()
696 | select {
697 | case <-done:
698 | return
699 | case <-time.After(5 * time.Second):
700 | t.Fatal("deadlock")
701 | }
702 | }
703 |
704 | func testSendDeadlock(client *Client) {
705 | defer func() {
706 | recover()
707 | }()
708 | ctx := context.Background()
709 | args := &Args{7, 8}
710 | reply := new(Reply)
711 | client.Call(ctx, "Arith.Add", args, reply)
712 | }
713 |
714 | func dialDirect() (*Client, error) {
715 | return Dial("tcp", serverAddr)
716 | }
717 |
718 | func dialHTTP() (*Client, error) {
719 | return DialHTTP("tcp", httpServerAddr)
720 | }
721 |
722 | func countMallocs(dial func() (*Client, error), t *testing.T) float64 {
723 | once.Do(startServer)
724 | client, err := dial()
725 | if err != nil {
726 | t.Fatal("error dialing", err)
727 | }
728 | defer client.Close()
729 |
730 | ctx := context.Background()
731 | args := &Args{7, 8}
732 | reply := new(Reply)
733 | return testing.AllocsPerRun(100, func() {
734 | err := client.Call(ctx, "Arith.Add", args, reply)
735 | if err != nil {
736 | t.Errorf("Add: expected no error but got string %q", err.Error())
737 | }
738 | if reply.C != args.A+args.B {
739 | t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
740 | }
741 | })
742 | }
743 |
744 | func TestCountMallocs(t *testing.T) {
745 | if testing.Short() {
746 | t.Skip("skipping malloc count in short mode")
747 | }
748 | if runtime.GOMAXPROCS(0) > 1 {
749 | t.Skip("skipping; GOMAXPROCS>1")
750 | }
751 | fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t))
752 | }
753 |
754 | func TestCountMallocsOverHTTP(t *testing.T) {
755 | if testing.Short() {
756 | t.Skip("skipping malloc count in short mode")
757 | }
758 | if runtime.GOMAXPROCS(0) > 1 {
759 | t.Skip("skipping; GOMAXPROCS>1")
760 | }
761 | fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t))
762 | }
763 |
// writeCrasher is a fake connection. Its writes always fail. Its reads block
// until a value arrives on done and then report io.EOF.
type writeCrasher struct {
	done chan bool
}

// Read waits for a signal on done, then reports end of stream.
func (w *writeCrasher) Read(p []byte) (int, error) {
	<-w.done
	return 0, io.EOF
}

// Write always fails without consuming p.
func (writeCrasher) Write(p []byte) (int, error) {
	return 0, errors.New("fake write failure")
}

// Close is a no-op.
func (writeCrasher) Close() error {
	return nil
}
780 |
781 | func TestClientWriteError(t *testing.T) {
782 | w := &writeCrasher{done: make(chan bool)}
783 | c := NewClient(w)
784 | defer c.Close()
785 |
786 | ctx := context.Background()
787 | res := false
788 | err := c.Call(ctx, "foo", 1, &res)
789 | if err == nil {
790 | t.Fatal("expected error")
791 | }
792 | if err.Error() != "fake write failure" {
793 | t.Error("unexpected value of error:", err)
794 | }
795 | w.done <- true
796 | }
797 |
798 | func TestTCPClose(t *testing.T) {
799 | once.Do(startServer)
800 |
801 | client, err := dialHTTP()
802 | if err != nil {
803 | t.Fatalf("dialing: %v", err)
804 | }
805 | defer client.Close()
806 |
807 | ctx := context.Background()
808 | args := Args{17, 8}
809 | var reply Reply
810 | err = client.Call(ctx, "Arith.Mul", args, &reply)
811 | if err != nil {
812 | t.Fatal("arith error:", err)
813 | }
814 | t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply)
815 | if reply.C != args.A*args.B {
816 | t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B)
817 | }
818 | }
819 |
820 | func TestErrorAfterClientClose(t *testing.T) {
821 | once.Do(startServer)
822 | ctx := context.Background()
823 |
824 | client, err := dialHTTP()
825 | if err != nil {
826 | t.Fatalf("dialing: %v", err)
827 | }
828 | err = client.Close()
829 | if err != nil {
830 | t.Fatal("close error:", err)
831 | }
832 | err = client.Call(ctx, "Arith.Add", &Args{7, 9}, new(Reply))
833 | if err != ErrShutdown {
834 | t.Errorf("Forever: expected ErrShutdown got %v", err)
835 | }
836 | }
837 |
838 | // Tests the fix to issue 11221. Without the fix, this loops forever or crashes.
839 | func TestAcceptExitAfterListenerClose(t *testing.T) {
840 | newServer := NewServer()
841 | newServer.Register(new(Arith))
842 | newServer.RegisterName("net.rpc.Arith", new(Arith))
843 | newServer.RegisterName("newServer.Arith", new(Arith))
844 |
845 | var l net.Listener
846 | l, _ = listenTCP()
847 | l.Close()
848 | newServer.Accept(l)
849 | }
850 |
851 | func TestShutdown(t *testing.T) {
852 | var l net.Listener
853 | l, _ = listenTCP()
854 | ch := make(chan net.Conn, 1)
855 | go func() {
856 | defer l.Close()
857 | c, err := l.Accept()
858 | if err != nil {
859 | t.Error(err)
860 | }
861 | ch <- c
862 | }()
863 | c, err := net.Dial("tcp", l.Addr().String())
864 | if err != nil {
865 | t.Fatal(err)
866 | }
867 | c1 := <-ch
868 | if c1 == nil {
869 | t.Fatal(err)
870 | }
871 |
872 | newServer := NewServer()
873 | newServer.Register(new(Arith))
874 | go newServer.ServeConn(c1)
875 |
876 | ctx := context.Background()
877 | args := &Args{7, 8}
878 | reply := new(Reply)
879 | client := NewClient(c)
880 | err = client.Call(ctx, "Arith.Add", args, reply)
881 | if err != nil {
882 | t.Fatal(err)
883 | }
884 |
885 | // On an unloaded system 10ms is usually enough to fail 100% of the time
886 | // with a broken server. On a loaded system, a broken server might incorrectly
887 | // be reported as passing, but we're OK with that kind of flakiness.
888 | // If the code is correct, this test will never fail, regardless of timeout.
889 | args.A = 10 // 10 ms
890 | done := make(chan *Call, 1)
891 | call := client.Go("Arith.SleepMilli", args, reply, done)
892 | c.(*net.TCPConn).CloseWrite()
893 | <-done
894 | if call.Error != nil {
895 | t.Fatal(err)
896 | }
897 | }
898 |
899 | func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
900 | once.Do(startServer)
901 | client, err := dial()
902 | if err != nil {
903 | b.Fatal("error dialing:", err)
904 | }
905 | defer client.Close()
906 |
907 | // Synchronous calls
908 | ctx := context.Background()
909 | args := &Args{7, 8}
910 | b.ResetTimer()
911 |
912 | b.RunParallel(func(pb *testing.PB) {
913 | reply := new(Reply)
914 | for pb.Next() {
915 | err := client.Call(ctx, "Arith.Add", args, reply)
916 | if err != nil {
917 | b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error())
918 | }
919 | if reply.C != args.A+args.B {
920 | b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B)
921 | }
922 | }
923 | })
924 | }
925 |
926 | func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) {
927 | if b.N == 0 {
928 | return
929 | }
930 | const MaxConcurrentCalls = 100
931 | once.Do(startServer)
932 | client, err := dial()
933 | if err != nil {
934 | b.Fatal("error dialing:", err)
935 | }
936 | defer client.Close()
937 |
938 | // Asynchronous calls
939 | args := &Args{7, 8}
940 | procs := 4 * runtime.GOMAXPROCS(-1)
941 | send := int32(b.N)
942 | recv := int32(b.N)
943 | var wg sync.WaitGroup
944 | wg.Add(procs)
945 | gate := make(chan bool, MaxConcurrentCalls)
946 | res := make(chan *Call, MaxConcurrentCalls)
947 | b.ResetTimer()
948 |
949 | for p := 0; p < procs; p++ {
950 | go func() {
951 | for atomic.AddInt32(&send, -1) >= 0 {
952 | gate <- true
953 | reply := new(Reply)
954 | client.Go("Arith.Add", args, reply, res)
955 | }
956 | }()
957 | go func() {
958 | for call := range res {
959 | A := call.Args.(*Args).A
960 | B := call.Args.(*Args).B
961 | C := call.Reply.(*Reply).C
962 | if A+B != C {
963 | b.Errorf("incorrect reply: Add: expected %d got %d", A+B, C)
964 | return
965 | }
966 | <-gate
967 | if atomic.AddInt32(&recv, -1) == 0 {
968 | close(res)
969 | }
970 | }
971 | wg.Done()
972 | }()
973 | }
974 | wg.Wait()
975 | }
976 |
977 | func BenchmarkEndToEnd(b *testing.B) {
978 | benchmarkEndToEnd(dialDirect, b)
979 | }
980 |
981 | func BenchmarkEndToEndHTTP(b *testing.B) {
982 | benchmarkEndToEnd(dialHTTP, b)
983 | }
984 |
985 | func BenchmarkEndToEndAsync(b *testing.B) {
986 | benchmarkEndToEndAsync(dialDirect, b)
987 | }
988 |
989 | func BenchmarkEndToEndAsyncHTTP(b *testing.B) {
990 | benchmarkEndToEndAsync(dialHTTP, b)
991 | }
992 |
--------------------------------------------------------------------------------
/server.go:
--------------------------------------------------------------------------------
1 | // Copyright 2009 The Go Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | /*
6 | Package rpc is a fork of the stdlib net/rpc which is frozen. It adds
7 | support for context.Context on the client and server, including
propagating cancellation. See the README at
9 | https://github.com/keegancsmith/rpc for motivation why this exists.
10 |
11 | The API is exactly the same, except Client.Call takes a context.Context,
12 | and Server methods are expected to take a context.Context as the first
13 | argument. The following is the original rpc godoc updated to include
14 | context.Context. Additionally the wire protocol is unchanged, so is
15 | backwards compatible with net/rpc clients.
16 |
17 | Package rpc provides access to the exported methods of an object across a
18 | network or other I/O connection. A server registers an object, making it visible
19 | as a service with the name of the type of the object. After registration, exported
20 | methods of the object will be accessible remotely. A server may register multiple
21 | objects (services) of different types but it is an error to register multiple
22 | objects of the same type.
23 |
24 | Only methods that satisfy these criteria will be made available for remote access;
25 | other methods will be ignored:
26 |
27 | - the method's type is exported.
28 | - the method is exported.
29 | - the method has three arguments.
30 | - the method's first argument has type context.Context.
31 | - the method's last two arguments are exported (or builtin) types.
32 | - the method's third argument is a pointer.
33 | - the method has return type error.
34 |
35 | In effect, the method must look schematically like
36 |
37 | func (t *T) MethodName(ctx context.Context, argType T1, replyType *T2) error
38 |
39 | where T1 and T2 can be marshaled by encoding/gob.
40 | These requirements apply even if a different codec is used.
41 | (In the future, these requirements may soften for custom codecs.)
42 |
43 | The method's second argument represents the arguments provided by the caller; the
44 | third argument represents the result parameters to be returned to the caller.
45 | The method's return value, if non-nil, is passed back as a string that the client
46 | sees as if created by errors.New. If an error is returned, the reply parameter
47 | will not be sent back to the client.
48 |
49 | The server may handle requests on a single connection by calling ServeConn. More
50 | typically it will create a network listener and call Accept or, for an HTTP
51 | listener, HandleHTTP and http.Serve.
52 |
53 | A client wishing to use the service establishes a connection and then invokes
54 | NewClient on the connection. The convenience function Dial (DialHTTP) performs
55 | both steps for a raw network connection (an HTTP connection). The resulting
56 | Client object has two methods, Call and Go, that specify the service and method to
57 | call, a pointer containing the arguments, and a pointer to receive the result
58 | parameters.
59 |
60 | The Call method waits for the remote call to complete while the Go method
61 | launches the call asynchronously and signals completion using the Call
62 | structure's Done channel.
63 |
64 | Unless an explicit codec is set up, package encoding/gob is used to
65 | transport the data.
66 |
Here is a simple example. A server wishes to export an object of type Arith:

	package server

	import (
		"context"
		"errors"
	)

	type Args struct {
		A, B int
	}

	type Quotient struct {
		Quo, Rem int
	}

	type Arith int

	func (t *Arith) Multiply(ctx context.Context, args *Args, reply *int) error {
		*reply = args.A * args.B
		return nil
	}

	func (t *Arith) Divide(ctx context.Context, args *Args, quo *Quotient) error {
		if args.B == 0 {
			return errors.New("divide by zero")
		}
		quo.Quo = args.A / args.B
		quo.Rem = args.A % args.B
		return nil
	}

The server calls (for HTTP service):

	arith := new(Arith)
	rpc.Register(arith)
	rpc.HandleHTTP()
	l, e := net.Listen("tcp", ":1234")
	if e != nil {
		log.Fatal("listen error:", e)
	}
	go http.Serve(l, nil)

At this point, clients can see a service "Arith" with methods "Arith.Multiply" and
"Arith.Divide". To invoke one, a client first dials the server:

	client, err := rpc.DialHTTP("tcp", serverAddress+":1234")
	if err != nil {
		log.Fatal("dialing:", err)
	}
132 |
133 | Then it can make a remote call:
134 |
135 | // Synchronous call
136 | args := &server.Args{7,8}
137 | var reply int
138 | err = client.Call(context.Background(), "Arith.Multiply", args, &reply)
139 | if err != nil {
140 | log.Fatal("arith error:", err)
141 | }
142 | fmt.Printf("Arith: %d*%d=%d", args.A, args.B, reply)
143 |
144 | or
145 |
146 | // Asynchronous call
	quotient := new(server.Quotient)
148 | divCall := client.Go("Arith.Divide", args, quotient, nil)
149 | replyCall := <-divCall.Done // will be equal to divCall
150 | // check errors, print, etc.
151 |
152 | A server implementation will often provide a simple, type-safe wrapper for the
153 | client.
154 |
155 | The net/rpc package is frozen and is not accepting new features.
156 | */
157 | package rpc
158 |
159 | import (
160 | "bufio"
161 | "context"
162 | "encoding/gob"
163 | "errors"
164 | "go/token"
165 | "io"
166 | "log"
167 | "net"
168 | "net/http"
169 | "reflect"
170 | "strings"
171 | "sync"
172 |
173 | "github.com/keegancsmith/rpc/internal/svc"
174 | )
175 |
// Default paths used by HandleHTTP.
const (
	// DefaultRPCPath is where HandleHTTP mounts the RPC (CONNECT) handler.
	DefaultRPCPath = "/_goRPC_"
	// DefaultDebugPath is where HandleHTTP mounts the debugging handler.
	DefaultDebugPath = "/debug/rpc"
)

// Reflect types used when validating method signatures at registration.
// Passing an interface value to reflect.TypeOf yields its dynamic type, so a
// typed nil pointer to the interface is used and Elem recovers the
// interface type itself.
var (
	typeOfError = reflect.TypeOf((*error)(nil)).Elem()
	typeOfCtx   = reflect.TypeOf((*context.Context)(nil)).Elem()
)
186 |
// methodType describes one RPC-callable method: its reflected definition,
// the types of its argument and reply parameters, and a call counter.
type methodType struct {
	sync.Mutex // guards numCalls

	method    reflect.Method
	ArgType   reflect.Type
	ReplyType reflect.Type

	numCalls uint
}

// service is a receiver registered with a Server together with its
// RPC-callable methods.
type service struct {
	name   string                 // name clients use in "Service.Method"
	rcvr   reflect.Value          // receiver on which methods are invoked
	typ    reflect.Type           // reflect type of rcvr
	method map[string]*methodType // callable methods keyed by method name
}
201 |
// Request is the header written before every RPC call. The server uses it
// internally; it is documented here to help with debugging, for example
// when inspecting network traffic.
type Request struct {
	// ServiceMethod names the target in the form "Service.Method".
	ServiceMethod string
	// Seq is the sequence number chosen by the client.
	Seq uint64

	next *Request // links unused headers in the Server's free list
}

// Response is the header written before every RPC return. The server uses
// it internally; it is documented here to help with debugging, for example
// when inspecting network traffic.
type Response struct {
	// ServiceMethod echoes the ServiceMethod of the Request.
	ServiceMethod string
	// Seq echoes the Seq of the Request.
	Seq uint64
	// Error holds the error message, if any.
	Error string

	next *Response // links unused headers in the Server's free list
}

// Server is an RPC server. NewServer returns one with the package's
// internal "_goRPC_" service already registered.
type Server struct {
	serviceMap sync.Map // service name (string) -> *service

	reqLock sync.Mutex // guards freeReq
	freeReq *Request   // recycled request headers

	respLock sync.Mutex // guards freeResp
	freeResp *Response  // recycled response headers
}
229 |
230 | // NewServer returns a new Server.
231 | func NewServer() *Server {
232 | s := &Server{}
233 | s.RegisterName("_goRPC_", &svc.GoRPC{})
234 | return s
235 | }
236 |
237 | // DefaultServer is the default instance of *Server.
238 | var DefaultServer = NewServer()
239 |
// isExportedOrBuiltinType reports whether t, after removing all pointer
// indirections, is an exported named type or has an empty package path
// (predeclared types and unnamed composite types).
func isExportedOrBuiltinType(t reflect.Type) bool {
	base := t
	for base.Kind() == reflect.Pointer {
		base = base.Elem()
	}
	// Exported named types also have a non-empty PkgPath, so the name has
	// to be checked separately.
	if token.IsExported(base.Name()) {
		return true
	}
	return base.PkgPath() == ""
}
249 |
250 | // Register publishes in the server the set of methods of the
251 | // receiver value that satisfy the following conditions:
252 | // - exported method of exported type
253 | // - two arguments, both of exported type
254 | // - the second argument is a pointer
255 | // - one return value, of type error
256 | //
257 | // It returns an error if the receiver is not an exported type or has
258 | // no suitable methods. It also logs the error using package log.
259 | // The client accesses each method using a string of the form "Type.Method",
260 | // where Type is the receiver's concrete type.
261 | func (server *Server) Register(rcvr any) error {
262 | return server.register(rcvr, "", false)
263 | }
264 |
265 | // RegisterName is like Register but uses the provided name for the type
266 | // instead of the receiver's concrete type.
267 | func (server *Server) RegisterName(name string, rcvr any) error {
268 | return server.register(rcvr, name, true)
269 | }
270 |
271 | // logRegisterError specifies whether to log problems during method registration.
272 | // To debug registration, recompile the package with this set to true.
273 | const logRegisterError = false
274 |
275 | func (server *Server) register(rcvr any, name string, useName bool) error {
276 | s := new(service)
277 | s.typ = reflect.TypeOf(rcvr)
278 | s.rcvr = reflect.ValueOf(rcvr)
279 | sname := name
280 | if !useName {
281 | sname = reflect.Indirect(s.rcvr).Type().Name()
282 | }
283 | if sname == "" {
284 | s := "rpc.Register: no service name for type " + s.typ.String()
285 | log.Print(s)
286 | return errors.New(s)
287 | }
288 | if !useName && !token.IsExported(sname) {
289 | s := "rpc.Register: type " + sname + " is not exported"
290 | log.Print(s)
291 | return errors.New(s)
292 | }
293 | s.name = sname
294 |
295 | // Install the methods
296 | s.method = suitableMethods(s.typ, logRegisterError)
297 |
298 | if len(s.method) == 0 {
299 | str := ""
300 |
301 | // To help the user, see if a pointer receiver would work.
302 | method := suitableMethods(reflect.PointerTo(s.typ), false)
303 | if len(method) != 0 {
304 | str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
305 | } else {
306 | str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
307 | }
308 | log.Print(str)
309 | return errors.New(str)
310 | }
311 |
312 | if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
313 | return errors.New("rpc: service already defined: " + sname)
314 | }
315 | return nil
316 | }
317 |
318 | // suitableMethods returns suitable Rpc methods of typ. It will log
319 | // errors if logErr is true.
320 | func suitableMethods(typ reflect.Type, logErr bool) map[string]*methodType {
321 | methods := make(map[string]*methodType)
322 | for m := 0; m < typ.NumMethod(); m++ {
323 | method := typ.Method(m)
324 | mtype := method.Type
325 | mname := method.Name
326 | // Method must be exported.
327 | if !method.IsExported() {
328 | continue
329 | }
330 | // Method needs four ins: receiver, ctx, *args, *reply.
331 | if mtype.NumIn() != 4 {
332 | if logErr {
333 | log.Printf("rpc.Register: method %q has %d input parameters; needs exactly three\n", mname, mtype.NumIn())
334 | }
335 | continue
336 | }
337 | // First arg must be context.Context
338 | if ctxType := mtype.In(1); ctxType != typeOfCtx {
339 | if logErr {
340 | log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, ctxType)
341 | }
342 | continue
343 | }
344 | // Second arg need not be a pointer.
345 | argType := mtype.In(2)
346 | if !isExportedOrBuiltinType(argType) {
347 | if logErr {
348 | log.Printf("rpc.Register: argument type of method %q is not exported: %q\n", mname, argType)
349 | }
350 | continue
351 | }
352 | // Third arg must be a pointer.
353 | replyType := mtype.In(3)
354 | if replyType.Kind() != reflect.Ptr {
355 | if logErr {
356 | log.Printf("rpc.Register: reply type of method %q is not a pointer: %q\n", mname, replyType)
357 | }
358 | continue
359 | }
360 | // Reply type must be exported.
361 | if !isExportedOrBuiltinType(replyType) {
362 | if logErr {
363 | log.Printf("rpc.Register: reply type of method %q is not exported: %q\n", mname, replyType)
364 | }
365 | continue
366 | }
367 | // Method needs one out.
368 | if mtype.NumOut() != 1 {
369 | if logErr {
370 | log.Printf("rpc.Register: method %q has %d output parameters; needs exactly one\n", mname, mtype.NumOut())
371 | }
372 | continue
373 | }
374 | // The return type of the method must be error.
375 | if returnType := mtype.Out(0); returnType != typeOfError {
376 | if logErr {
377 | log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, returnType)
378 | }
379 | continue
380 | }
381 | methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
382 | }
383 | return methods
384 | }
385 |
// invalidRequest is sent in place of the reply body when the response
// header carries an error, for example when the server receives an invalid
// request. Clients never decode it, because the header's error tells them
// the call failed.
var invalidRequest struct{}
390 |
391 | func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply any, codec ServerCodec, errmsg string) {
392 | resp := server.getResponse()
393 | // Encode the response header
394 | resp.ServiceMethod = req.ServiceMethod
395 | if errmsg != "" {
396 | resp.Error = errmsg
397 | reply = invalidRequest
398 | }
399 | resp.Seq = req.Seq
400 | sending.Lock()
401 | err := codec.WriteResponse(resp, reply)
402 | if debugLog && err != nil {
403 | log.Println("rpc: writing response:", err)
404 | }
405 | sending.Unlock()
406 | server.freeResponse(resp)
407 | }
408 |
409 | func (m *methodType) NumCalls() (n uint) {
410 | m.Lock()
411 | n = m.numCalls
412 | m.Unlock()
413 | return n
414 | }
415 |
416 | func (s *service) call(server *Server, sending *sync.Mutex, pending *svc.Pending, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
417 | if wg != nil {
418 | defer wg.Done()
419 | }
420 | // _goRPC_ service calls require internal state.
421 | if s.name == "_goRPC_" {
422 | switch v := argv.Interface().(type) {
423 | case *svc.CancelArgs:
424 | v.SetPending(pending)
425 | }
426 | }
427 | mtype.Lock()
428 | mtype.numCalls++
429 | mtype.Unlock()
430 | ctx := pending.Start(req.Seq)
431 | defer pending.Cancel(req.Seq)
432 | function := mtype.method.Func
433 | // Invoke the method, providing a new value for the reply.
434 | returnValues := function.Call([]reflect.Value{s.rcvr, reflect.ValueOf(ctx), argv, replyv})
435 | // The return value for the method is an error.
436 | errInter := returnValues[0].Interface()
437 | errmsg := ""
438 | if errInter != nil {
439 | errmsg = errInter.(error).Error()
440 | }
441 | server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
442 | server.freeRequest(req)
443 | }
444 |
445 | type gobServerCodec struct {
446 | rwc io.ReadWriteCloser
447 | dec *gob.Decoder
448 | enc *gob.Encoder
449 | encBuf *bufio.Writer
450 | closed bool
451 | }
452 |
453 | func (c *gobServerCodec) ReadRequestHeader(r *Request) error {
454 | return c.dec.Decode(r)
455 | }
456 |
457 | func (c *gobServerCodec) ReadRequestBody(body any) error {
458 | return c.dec.Decode(body)
459 | }
460 |
461 | func (c *gobServerCodec) WriteResponse(r *Response, body any) (err error) {
462 | if err = c.enc.Encode(r); err != nil {
463 | if c.encBuf.Flush() == nil {
464 | // Gob couldn't encode the header. Should not happen, so if it does,
465 | // shut down the connection to signal that the connection is broken.
466 | log.Println("rpc: gob error encoding response:", err)
467 | c.Close()
468 | }
469 | return
470 | }
471 | if err = c.enc.Encode(body); err != nil {
472 | if c.encBuf.Flush() == nil {
473 | // Was a gob problem encoding the body but the header has been written.
474 | // Shut down the connection to signal that the connection is broken.
475 | log.Println("rpc: gob error encoding body:", err)
476 | c.Close()
477 | }
478 | return
479 | }
480 | return c.encBuf.Flush()
481 | }
482 |
483 | func (c *gobServerCodec) Close() error {
484 | if c.closed {
485 | // Only call c.rwc.Close once; otherwise the semantics are undefined.
486 | return nil
487 | }
488 | c.closed = true
489 | return c.rwc.Close()
490 | }
491 |
492 | // ServeConn runs the server on a single connection.
493 | // ServeConn blocks, serving the connection until the client hangs up.
494 | // The caller typically invokes ServeConn in a go statement.
495 | // ServeConn uses the gob wire format (see package gob) on the
496 | // connection. To use an alternate codec, use ServeCodec.
497 | // See NewClient's comment for information about concurrent access.
498 | func (server *Server) ServeConn(conn io.ReadWriteCloser) {
499 | buf := bufio.NewWriter(conn)
500 | srv := &gobServerCodec{
501 | rwc: conn,
502 | dec: gob.NewDecoder(conn),
503 | enc: gob.NewEncoder(buf),
504 | encBuf: buf,
505 | }
506 | server.ServeCodec(srv)
507 | }
508 |
509 | // ServeCodec is like ServeConn but uses the specified codec to
510 | // decode requests and encode responses.
511 | func (server *Server) ServeCodec(codec ServerCodec) {
512 | sending := new(sync.Mutex)
513 | ctx, cancel := context.WithCancel(context.Background())
514 | defer cancel()
515 | pending := svc.NewPending(ctx)
516 | wg := new(sync.WaitGroup)
517 | for {
518 | service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
519 | if err != nil {
520 | if debugLog && err != io.EOF {
521 | log.Println("rpc:", err)
522 | }
523 | if !keepReading {
524 | break
525 | }
526 | // send a response if we actually managed to read a header.
527 | if req != nil {
528 | server.sendResponse(sending, req, invalidRequest, codec, err.Error())
529 | server.freeRequest(req)
530 | }
531 | continue
532 | }
533 | wg.Add(1)
534 | go service.call(server, sending, pending, wg, mtype, req, argv, replyv, codec)
535 | }
536 | // We've seen that there are no more requests.
537 | // Wait for responses to be sent before closing codec.
538 | wg.Wait()
539 | codec.Close()
540 | }
541 |
542 | // ServeRequest is like ServeCodec but synchronously serves a single request.
543 | // It does not close the codec upon completion.
544 | func (server *Server) ServeRequest(codec ServerCodec) error {
545 | return server.ServeRequestContext(context.Background(), codec)
546 | }
547 |
548 | // ServeRequest is like ServeCodec but synchronously serves a single request.
549 | // It does not close the codec upon completion.
550 | //
551 | // Cancelling the context given here will propagate cancellation to the context
552 | // of the called function.
553 | func (server *Server) ServeRequestContext(ctx context.Context, codec ServerCodec) error {
554 | sending := new(sync.Mutex)
555 | pending := svc.NewPending(ctx)
556 | service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
557 | if err != nil {
558 | if !keepReading {
559 | return err
560 | }
561 | // send a response if we actually managed to read a header.
562 | if req != nil {
563 | server.sendResponse(sending, req, invalidRequest, codec, err.Error())
564 | server.freeRequest(req)
565 | }
566 | return err
567 | }
568 | service.call(server, sending, pending, nil, mtype, req, argv, replyv, codec)
569 | return nil
570 | }
571 |
572 | func (server *Server) getRequest() *Request {
573 | server.reqLock.Lock()
574 | req := server.freeReq
575 | if req == nil {
576 | req = new(Request)
577 | } else {
578 | server.freeReq = req.next
579 | *req = Request{}
580 | }
581 | server.reqLock.Unlock()
582 | return req
583 | }
584 |
585 | func (server *Server) freeRequest(req *Request) {
586 | server.reqLock.Lock()
587 | req.next = server.freeReq
588 | server.freeReq = req
589 | server.reqLock.Unlock()
590 | }
591 |
592 | func (server *Server) getResponse() *Response {
593 | server.respLock.Lock()
594 | resp := server.freeResp
595 | if resp == nil {
596 | resp = new(Response)
597 | } else {
598 | server.freeResp = resp.next
599 | *resp = Response{}
600 | }
601 | server.respLock.Unlock()
602 | return resp
603 | }
604 |
605 | func (server *Server) freeResponse(resp *Response) {
606 | server.respLock.Lock()
607 | resp.next = server.freeResp
608 | server.freeResp = resp
609 | server.respLock.Unlock()
610 | }
611 |
612 | func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
613 | service, mtype, req, keepReading, err = server.readRequestHeader(codec)
614 | if err != nil {
615 | if !keepReading {
616 | return
617 | }
618 | // discard body
619 | codec.ReadRequestBody(nil)
620 | return
621 | }
622 |
623 | // Decode the argument value.
624 | argIsValue := false // if true, need to indirect before calling.
625 | if mtype.ArgType.Kind() == reflect.Pointer {
626 | argv = reflect.New(mtype.ArgType.Elem())
627 | } else {
628 | argv = reflect.New(mtype.ArgType)
629 | argIsValue = true
630 | }
631 | // argv guaranteed to be a pointer now.
632 | if err = codec.ReadRequestBody(argv.Interface()); err != nil {
633 | return
634 | }
635 | if argIsValue {
636 | argv = argv.Elem()
637 | }
638 |
639 | replyv = reflect.New(mtype.ReplyType.Elem())
640 |
641 | switch mtype.ReplyType.Elem().Kind() {
642 | case reflect.Map:
643 | replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem()))
644 | case reflect.Slice:
645 | replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0))
646 | }
647 | return
648 | }
649 |
650 | func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
651 | // Grab the request header.
652 | req = server.getRequest()
653 | err = codec.ReadRequestHeader(req)
654 | if err != nil {
655 | req = nil
656 | if err == io.EOF || err == io.ErrUnexpectedEOF {
657 | return
658 | }
659 | err = errors.New("rpc: server cannot decode request: " + err.Error())
660 | return
661 | }
662 |
663 | // We read the header successfully. If we see an error now,
664 | // we can still recover and move on to the next request.
665 | keepReading = true
666 |
667 | dot := strings.LastIndex(req.ServiceMethod, ".")
668 | if dot < 0 {
669 | err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
670 | return
671 | }
672 | serviceName := req.ServiceMethod[:dot]
673 | methodName := req.ServiceMethod[dot+1:]
674 |
675 | // Look up the request.
676 | svci, ok := server.serviceMap.Load(serviceName)
677 | if !ok {
678 | err = errors.New("rpc: can't find service " + req.ServiceMethod)
679 | return
680 | }
681 | svc = svci.(*service)
682 | mtype = svc.method[methodName]
683 | if mtype == nil {
684 | err = errors.New("rpc: can't find method " + req.ServiceMethod)
685 | }
686 | return
687 | }
688 |
689 | // Accept accepts connections on the listener and serves requests
690 | // for each incoming connection. Accept blocks until the listener
691 | // returns a non-nil error. The caller typically invokes Accept in a
692 | // go statement.
693 | func (server *Server) Accept(lis net.Listener) {
694 | for {
695 | conn, err := lis.Accept()
696 | if err != nil {
697 | log.Print("rpc.Serve: accept:", err.Error())
698 | return
699 | }
700 | go server.ServeConn(conn)
701 | }
702 | }
703 |
704 | // Register publishes the receiver's methods in the DefaultServer.
705 | func Register(rcvr any) error { return DefaultServer.Register(rcvr) }
706 |
707 | // RegisterName is like Register but uses the provided name for the type
708 | // instead of the receiver's concrete type.
709 | func RegisterName(name string, rcvr any) error {
710 | return DefaultServer.RegisterName(name, rcvr)
711 | }
712 |
713 | // A ServerCodec implements reading of RPC requests and writing of
714 | // RPC responses for the server side of an RPC session.
715 | // The server calls ReadRequestHeader and ReadRequestBody in pairs
716 | // to read requests from the connection, and it calls WriteResponse to
717 | // write a response back. The server calls Close when finished with the
718 | // connection. ReadRequestBody may be called with a nil
719 | // argument to force the body of the request to be read and discarded.
720 | // See NewClient's comment for information about concurrent access.
721 | type ServerCodec interface {
722 | ReadRequestHeader(*Request) error
723 | ReadRequestBody(any) error
724 | WriteResponse(*Response, any) error
725 |
726 | // Close can be called multiple times and must be idempotent.
727 | Close() error
728 | }
729 |
730 | // ServeConn runs the DefaultServer on a single connection.
731 | // ServeConn blocks, serving the connection until the client hangs up.
732 | // The caller typically invokes ServeConn in a go statement.
733 | // ServeConn uses the gob wire format (see package gob) on the
734 | // connection. To use an alternate codec, use ServeCodec.
735 | // See NewClient's comment for information about concurrent access.
736 | func ServeConn(conn io.ReadWriteCloser) {
737 | DefaultServer.ServeConn(conn)
738 | }
739 |
740 | // ServeCodec is like ServeConn but uses the specified codec to
741 | // decode requests and encode responses.
742 | func ServeCodec(codec ServerCodec) {
743 | DefaultServer.ServeCodec(codec)
744 | }
745 |
746 | // ServeRequest is like ServeCodec but synchronously serves a single request.
747 | // It does not close the codec upon completion.
748 | func ServeRequest(codec ServerCodec) error {
749 | return ServeRequestContext(context.Background(), codec)
750 | }
751 |
752 | // ServeRequest is like ServeCodec but synchronously serves a single request.
753 | // It does not close the codec upon completion.
754 | //
755 | // Cancelling the context given here will propagate cancellation to the context
756 | // of the called function.
757 | func ServeRequestContext(ctx context.Context, codec ServerCodec) error {
758 | return DefaultServer.ServeRequestContext(ctx, codec)
759 | }
760 |
761 | // Accept accepts connections on the listener and serves requests
762 | // to DefaultServer for each incoming connection.
763 | // Accept blocks; the caller typically invokes it in a go statement.
764 | func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
765 |
766 | // Can connect to RPC service using HTTP CONNECT to rpcPath.
767 | var connected = "200 Connected to Go RPC"
768 |
769 | // ServeHTTP implements an http.Handler that answers RPC requests.
770 | func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
771 | if req.Method != "CONNECT" {
772 | w.Header().Set("Content-Type", "text/plain; charset=utf-8")
773 | w.WriteHeader(http.StatusMethodNotAllowed)
774 | io.WriteString(w, "405 must CONNECT\n")
775 | return
776 | }
777 | conn, _, err := w.(http.Hijacker).Hijack()
778 | if err != nil {
779 | log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
780 | return
781 | }
782 | io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
783 | server.ServeConn(conn)
784 | }
785 |
786 | // HandleHTTP registers an HTTP handler for RPC messages on rpcPath,
787 | // and a debugging handler on debugPath.
788 | // It is still necessary to invoke http.Serve(), typically in a go statement.
789 | func (server *Server) HandleHTTP(rpcPath, debugPath string) {
790 | http.Handle(rpcPath, server)
791 | http.Handle(debugPath, debugHTTP{server})
792 | }
793 |
794 | // HandleHTTP registers an HTTP handler for RPC messages to DefaultServer
795 | // on DefaultRPCPath and a debugging handler on DefaultDebugPath.
796 | // It is still necessary to invoke http.Serve(), typically in a go statement.
797 | func HandleHTTP() {
798 | DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
799 | }
800 |
--------------------------------------------------------------------------------