├── .github └── workflows │ ├── go-check.yml │ ├── go-test.yml │ ├── release-check.yml │ ├── releaser.yml │ └── tagpush.yml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── auth ├── auth.go └── handler.go ├── client.go ├── errors.go ├── go.mod ├── go.sum ├── handler.go ├── httpio ├── README ├── reader.go └── reader_test.go ├── method_formatter.go ├── method_formatter_test.go ├── metrics └── metrics.go ├── options.go ├── options_server.go ├── resp_error_test.go ├── response.go ├── rpc_test.go ├── server.go ├── util.go ├── version.json └── websocket.go /.github/workflows/go-check.yml: -------------------------------------------------------------------------------- 1 | name: Go Checks 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: ["main"] 7 | workflow_dispatch: 8 | 9 | permissions: 10 | contents: read 11 | 12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.event_name == 'push' && github.sha || github.ref }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | go-check: 18 | uses: ipdxco/unified-github-workflows/.github/workflows/go-check.yml@v1.0 19 | -------------------------------------------------------------------------------- /.github/workflows/go-test.yml: -------------------------------------------------------------------------------- 1 | name: Go Test 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: ["main"] 7 | workflow_dispatch: 8 | 9 | permissions: 10 | contents: read 11 | 12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.event_name == 'push' && github.sha || github.ref }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | go-test: 18 | uses: ipdxco/unified-github-workflows/.github/workflows/go-test.yml@v1.0 19 | secrets: 20 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 21 | -------------------------------------------------------------------------------- /.github/workflows/release-check.yml: 
-------------------------------------------------------------------------------- 1 | name: Release Checker 2 | 3 | on: 4 | pull_request_target: 5 | paths: [ 'version.json' ] 6 | types: [ opened, synchronize, reopened, labeled, unlabeled ] 7 | workflow_dispatch: 8 | 9 | permissions: 10 | contents: write 11 | pull-requests: write 12 | 13 | concurrency: 14 | group: ${{ github.workflow }}-${{ github.ref }} 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | release-check: 19 | uses: ipdxco/unified-github-workflows/.github/workflows/release-check.yml@v1.0 20 | -------------------------------------------------------------------------------- /.github/workflows/releaser.yml: -------------------------------------------------------------------------------- 1 | name: Releaser 2 | 3 | on: 4 | push: 5 | paths: [ 'version.json' ] 6 | workflow_dispatch: 7 | 8 | permissions: 9 | contents: write 10 | 11 | concurrency: 12 | group: ${{ github.workflow }}-${{ github.sha }} 13 | cancel-in-progress: true 14 | 15 | jobs: 16 | releaser: 17 | uses: ipdxco/unified-github-workflows/.github/workflows/releaser.yml@v1.0 18 | -------------------------------------------------------------------------------- /.github/workflows/tagpush.yml: -------------------------------------------------------------------------------- 1 | name: Tag Push Checker 2 | 3 | on: 4 | push: 5 | tags: 6 | - v* 7 | 8 | permissions: 9 | contents: read 10 | issues: write 11 | 12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.ref }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | releaser: 18 | uses: ipdxco/unified-github-workflows/.github/workflows/tagpush.yml@v1.0 19 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at 2 | 3 | http://www.apache.org/licenses/LICENSE-2.0 4 | 5 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 6 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | go-jsonrpc 2 | ================== 3 | 4 | [![go.dev reference](https://img.shields.io/badge/go.dev-reference-007d9c?logo=go&logoColor=white&style=flat-square)](https://pkg.go.dev/github.com/filecoin-project/go-jsonrpc) 5 | [![](https://img.shields.io/badge/made%20by-Protocol%20Labs-blue.svg?style=flat-square)](https://protocol.ai) 6 | 7 | > Low Boilerplate JSON-RPC 2.0 library 8 | 9 | ## Usage examples 10 | 11 | ### Server 12 | 13 | ```go 14 | // Have a type with some exported methods 15 | type SimpleServerHandler struct { 16 | n int 17 | } 18 | 19 | func (h *SimpleServerHandler) AddGet(in int) int { 20 | h.n += in 21 | return h.n 22 | } 23 | 24 | func main() { 25 | // create a new server instance 26 | rpcServer := jsonrpc.NewServer() 27 | 28 | // create a handler instance and register it 29 | serverHandler := &SimpleServerHandler{} 30 | rpcServer.Register("SimpleServerHandler", serverHandler) 31 | 32 | // rpcServer is now http.Handler which will serve jsonrpc calls to SimpleServerHandler.AddGet 33 | // a method with a single int param, and an int response. The server supports both http and websockets. 34 | 35 | // serve the api 36 | testServ := httptest.NewServer(rpcServer) 37 | defer testServ.Close() 38 | 39 | fmt.Println("URL: ", "ws://"+testServ.Listener.Addr().String()) 40 | 41 | [..do other app stuff / wait..] 
42 | } 43 | ``` 44 | 45 | ### Client 46 | ```go 47 | func start() error { 48 | // Create a struct where each field is an exported function with signatures matching rpc calls 49 | var client struct { 50 | AddGet func(int) int 51 | } 52 | 53 | // Make jsonrpc populate func fields in the struct with JSONRPC calls 54 | closer, err := jsonrpc.NewClient(context.Background(), rpcURL, "SimpleServerHandler", &client, nil) 55 | if err != nil { 56 | return err 57 | } 58 | defer closer() 59 | 60 | ... 61 | 62 | n := client.AddGet(10) 63 | // if the server is the one from the example above, n = 10 64 | 65 | n = client.AddGet(2) 66 | // if the server is the one from the example above, n = 12 67 | } 68 | ``` 69 | 70 | ### Supported function signatures 71 | 72 | ```go 73 | type _ interface { 74 | // No Params / Return val 75 | Func1() 76 | 77 | // With Params 78 | // Note: If param types implement json.[Un]Marshaler, go-jsonrpc will use it 79 | Func2(param1 int, param2 string, param3 struct{A int}) 80 | 81 | // Returning errors 82 | // * For some connection errors, go-jsonrpc will return jsonrpc.RPCConnectionError{}. 83 | // * RPC-returned errors will be constructed with basic errors.New(__"string message"__) 84 | // * JSON-RPC error codes can be mapped to typed errors with jsonrpc.Errors - https://pkg.go.dev/github.com/filecoin-project/go-jsonrpc#Errors 85 | // * For typed errors to work, the server needs to be constructed with the `WithServerErrors` 86 | // option, and the client needs to be constructed with the `WithErrors` option 87 | Func3() error 88 | 89 | // Returning a value 90 | // Note: The value must be serializable with encoding/json. 91 | Func4() int 92 | 93 | // Returning a value and an error 94 | // Note: if the handler returns an error and a non-zero value, the value will not 95 | // be returned to the client - the client will see a zero value.
96 | Func4() (int, error) 97 | 98 | // With context 99 | // * Context isn't passed as a JSONRPC param; instead it has a number of different uses 100 | // * When the context is cancelled on the client side, context cancellation should propagate to the server handler 101 | // * In http mode the http request will be aborted 102 | // * In websocket mode the client will send an `xrpc.cancel` message with a single param containing the ID of the cancelled request 103 | // * If the context contains an opencensus trace span, it will be propagated to the server through a 104 | // `"Meta": {"SpanContext": base64.StdEncoding.EncodeToString(propagation.Binary(span.SpanContext()))}` field in 105 | // the jsonrpc request 106 | // 107 | Func5(ctx context.Context, param1 string) error 108 | 109 | // With non-json-serializable (e.g. interface) params 110 | // * There are client and server options which make it possible to register transformers for types 111 | // to make them json-(de)serializable 112 | // * Server side: jsonrpc.WithParamDecoder(new(io.Reader), func(ctx context.Context, b []byte) (reflect.Value, error) { ... }) 113 | // * Client side: jsonrpc.WithParamEncoder(new(io.Reader), func(value reflect.Value) (reflect.Value, error) { ... }) 114 | // * For io.Reader specifically there's a simple param encoder/decoder implementation in go-jsonrpc/httpio package 115 | // which will pass reader data through separate http streams on a different handler.
116 | // * Note: a similar mechanism for return value transformation isn't supported yet 117 | Func6(r io.Reader) 118 | 119 | // Returning a channel 120 | // * Only supported in websocket mode 121 | // * If no error is returned, the return value will be an int channel ID 122 | // * When the server handler writes values into the channel, the client will receive `xrpc.ch.val` notifications 123 | // with 2 params: [chanID: int, value: any] 124 | // * When the channel is closed, the client will receive an `xrpc.ch.close` notification with a single param: [chanID: int] 125 | // * The client-side channel will be closed when the websocket connection breaks; the server side will discard writes to 126 | // the channel. Handlers should rely on the context to know when to stop writing to the returned channel. 127 | // NOTE: There is no good backpressure mechanism implemented for channels; returning values faster than the client can 128 | // receive them may cause memory leaks. 129 | Func7(ctx context.Context, param1 int, param2 string) (<-chan int, error) 130 | } 131 | 132 | ``` 133 | 134 | ### Custom Transport Feature 135 | The go-jsonrpc library supports creating clients with custom transport mechanisms (e.g. for IPC). This allows for greater flexibility in how requests are sent and received, enabling the use of custom protocols, special handling of requests, or integration with other systems.
136 | 137 | #### Example Usage of Custom Transport 138 | 139 | Here is an example demonstrating how to create a custom client with a custom transport mechanism: 140 | 141 | ```go 142 | // Set up server 143 | serverHandler := &SimpleServerHandler{} // some type with methods 144 | 145 | rpcServer := jsonrpc.NewServer() 146 | rpcServer.Register("SimpleServerHandler", serverHandler) 147 | 148 | // Custom doRequest function 149 | doRequest := func(ctx context.Context, body []byte) (io.ReadCloser, error) { 150 | reader := bytes.NewReader(body) 151 | pr, pw := io.Pipe() 152 | go func() { 153 | defer pw.Close() 154 | rpcServer.HandleRequest(ctx, reader, pw) // handle the rpc frame 155 | }() 156 | return pr, nil 157 | } 158 | 159 | var client struct { 160 | Add func(int) error 161 | AddGet func(int) int 162 | } 163 | // Create custom client 164 | closer, err := jsonrpc.NewCustomClient("SimpleServerHandler", []interface{}{&client}, doRequest) 165 | if err != nil { 166 | log.Fatalf("Failed to create client: %v", err) 167 | } 168 | defer closer() 169 | 170 | // Use the client 171 | if err := client.Add(10); err != nil { 172 | log.Fatalf("Failed to call Add: %v", err) 173 | } 174 | fmt.Printf("Current value: %d\n", client.AddGet(5)) 175 | ``` 176 | 177 | ### Reverse Calling Feature 178 | The go-jsonrpc library also supports reverse calling, where the server can make calls to the client. This is useful in scenarios where the server needs to notify or request data from the client.
179 | 180 | NOTE: Reverse calling only works in websocket mode 181 | 182 | #### Example Usage of Reverse Calling 183 | 184 | Here is an example demonstrating how to set up reverse calling: 185 | 186 | ```go 187 | // Define the client handler interface 188 | type ClientHandler struct { 189 | CallOnClient func(int) (int, error) 190 | } 191 | 192 | // Define the server handler 193 | type ServerHandler struct {} 194 | 195 | func (h *ServerHandler) Call(ctx context.Context) error { 196 | revClient, ok := jsonrpc.ExtractReverseClient[ClientHandler](ctx) 197 | if !ok { 198 | return fmt.Errorf("no reverse client") 199 | } 200 | 201 | result, err := revClient.CallOnClient(7) // Multiply by 2 on client 202 | if err != nil { 203 | return fmt.Errorf("call on client: %w", err) 204 | } 205 | 206 | if result != 14 { 207 | return fmt.Errorf("unexpected result: %d", result) 208 | } 209 | 210 | return nil 211 | } 212 | 213 | // Define client handler 214 | type RevCallTestClientHandler struct { 215 | } 216 | 217 | func (h *RevCallTestClientHandler) CallOnClient(a int) (int, error) { 218 | return a * 2, nil 219 | } 220 | 221 | // Setup server with reverse client capability 222 | rpcServer := jsonrpc.NewServer(jsonrpc.WithReverseClient[ClientHandler]("Client")) 223 | rpcServer.Register("ServerHandler", &ServerHandler{}) 224 | 225 | testServ := httptest.NewServer(rpcServer) 226 | defer testServ.Close() 227 | 228 | // Setup client with reverse call handler 229 | var client struct { 230 | Call func() error 231 | } 232 | 233 | closer, err := jsonrpc.NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "ServerHandler", []interface{}{ 234 | &client, 235 | }, nil, jsonrpc.WithClientHandler("Client", &RevCallTestClientHandler{})) 236 | if err != nil { 237 | log.Fatalf("Failed to create client: %v", err) 238 | } 239 | defer closer() 240 | 241 | // Make a call from the client to the server, which will trigger a reverse call 242 | if err := client.Call(); err != nil { 
243 | log.Fatalf("Failed to call server: %v", err) 244 | } 245 | ``` 246 | 247 | ## Options 248 | 249 | ### Using `WithServerMethodNameFormatter` 250 | 251 | `WithServerMethodNameFormatter` allows you to customize a function that formats the JSON-RPC method name, given namespace and method name. 252 | 253 | There are four possible options: 254 | - `jsonrpc.DefaultMethodNameFormatter` - default method name formatter, e.g. `SimpleServerHandler.AddGet` 255 | - `jsonrpc.NewMethodNameFormatter(true, jsonrpc.LowerFirstCharCase)` - method name formatter with namespace, e.g. `SimpleServerHandler.addGet` 256 | - `jsonrpc.NewMethodNameFormatter(false, jsonrpc.OriginalCase)` - method name formatter without namespace, e.g. `AddGet` 257 | - `jsonrpc.NewMethodNameFormatter(false, jsonrpc.LowerFirstCharCase)` - method name formatter without namespace and with the first char lowercased, e.g. `addGet` 258 | 259 | > [!NOTE] 260 | > The default method name formatter concatenates the namespace and method name with a dot. 261 | > Go exported methods are capitalized, so, the method name will be capitalized as well. 262 | > e.g. `SimpleServerHandler.AddGet` (capital "A" in "AddGet") 263 | 264 | ```go 265 | func main() { 266 | // create a new server instance with a custom separator 267 | rpcServer := jsonrpc.NewServer(jsonrpc.WithServerMethodNameFormatter( 268 | func(namespace, method string) string { 269 | return namespace + "_" + method 270 | }), 271 | ) 272 | 273 | // create a handler instance and register it 274 | serverHandler := &SimpleServerHandler{} 275 | rpcServer.Register("SimpleServerHandler", serverHandler) 276 | 277 | // serve the api 278 | testServ := httptest.NewServer(rpcServer) 279 | defer testServ.Close() 280 | 281 | fmt.Println("URL: ", "ws://"+testServ.Listener.Addr().String()) 282 | 283 | // rpc method becomes SimpleServerHandler_AddGet 284 | 285 | [..do other app stuff / wait..] 
286 | } 287 | ``` 288 | 289 | ### Using `WithMethodNameFormatter` 290 | 291 | `WithMethodNameFormatter` is the client-side counterpart to `WithServerMethodNameFormatter`. 292 | 293 | ```go 294 | func main() { 295 | closer, err := NewMergeClient( 296 | context.Background(), 297 | "http://example.com", 298 | "SimpleServerHandler", 299 | []any{&client}, 300 | nil, 301 | WithMethodNameFormatter(jsonrpc.NewMethodNameFormatter(false, OriginalCase)), 302 | ) 303 | defer closer() 304 | } 305 | ``` 306 | 307 | ## Contribute 308 | 309 | PRs are welcome! 310 | 311 | ## License 312 | 313 | Dual-licensed under [MIT](https://github.com/filecoin-project/go-jsonrpc/blob/master/LICENSE-MIT) + [Apache 2.0](https://github.com/filecoin-project/go-jsonrpc/blob/master/LICENSE-APACHE) 314 | -------------------------------------------------------------------------------- /auth/auth.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "context" 5 | "reflect" 6 | 7 | "golang.org/x/xerrors" 8 | ) 9 | 10 | type Permission string 11 | 12 | type permKey int 13 | 14 | var permCtxKey permKey 15 | 16 | func WithPerm(ctx context.Context, perms []Permission) context.Context { 17 | return context.WithValue(ctx, permCtxKey, perms) 18 | } 19 | 20 | func HasPerm(ctx context.Context, defaultPerms []Permission, perm Permission) bool { 21 | callerPerms, ok := ctx.Value(permCtxKey).([]Permission) 22 | if !ok { 23 | callerPerms = defaultPerms 24 | } 25 | 26 | for _, callerPerm := range callerPerms { 27 | if callerPerm == perm { 28 | return true 29 | } 30 | } 31 | return false 32 | } 33 | 34 | func PermissionedProxy(validPerms, defaultPerms []Permission, in interface{}, out interface{}) { 35 | rint := reflect.ValueOf(out).Elem() 36 | ra := reflect.ValueOf(in) 37 | 38 | for f := 0; f < rint.NumField(); f++ { 39 | field := rint.Type().Field(f) 40 | requiredPerm := Permission(field.Tag.Get("perm")) 41 | if requiredPerm == "" { 42 | 
panic("missing 'perm' tag on " + field.Name) // ok 43 | } 44 | 45 | // Validate perm tag 46 | ok := false 47 | for _, perm := range validPerms { 48 | if requiredPerm == perm { 49 | ok = true 50 | break 51 | } 52 | } 53 | if !ok { 54 | panic("unknown 'perm' tag on " + field.Name) // ok 55 | } 56 | 57 | fn := ra.MethodByName(field.Name) 58 | 59 | rint.Field(f).Set(reflect.MakeFunc(field.Type, func(args []reflect.Value) (results []reflect.Value) { 60 | ctx := args[0].Interface().(context.Context) 61 | if HasPerm(ctx, defaultPerms, requiredPerm) { 62 | return fn.Call(args) 63 | } 64 | 65 | err := xerrors.Errorf("missing permission to invoke '%s' (need '%s')", field.Name, requiredPerm) 66 | rerr := reflect.ValueOf(&err).Elem() 67 | 68 | if field.Type.NumOut() == 2 { 69 | return []reflect.Value{ 70 | reflect.Zero(field.Type.Out(0)), 71 | rerr, 72 | } 73 | } else { 74 | return []reflect.Value{rerr} 75 | } 76 | })) 77 | 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /auth/handler.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "strings" 7 | 8 | logging "github.com/ipfs/go-log/v2" 9 | ) 10 | 11 | var log = logging.Logger("auth") 12 | 13 | type Handler struct { 14 | Verify func(ctx context.Context, token string) ([]Permission, error) 15 | Next http.HandlerFunc 16 | } 17 | 18 | func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 19 | ctx := r.Context() 20 | 21 | token := r.Header.Get("Authorization") 22 | if token == "" { 23 | token = r.FormValue("token") 24 | if token != "" { 25 | token = "Bearer " + token 26 | } 27 | } 28 | 29 | if token != "" { 30 | if !strings.HasPrefix(token, "Bearer ") { 31 | log.Warn("missing Bearer prefix in auth header") 32 | w.WriteHeader(401) 33 | return 34 | } 35 | token = strings.TrimPrefix(token, "Bearer ") 36 | 37 | allow, err := h.Verify(ctx, token) 38 | if err != nil { 
39 | log.Warnf("JWT Verification failed (originating from %s): %s", r.RemoteAddr, err) 40 | w.WriteHeader(401) 41 | return 42 | } 43 | 44 | ctx = WithPerm(ctx, allow) 45 | } 46 | 47 | h.Next(w, r.WithContext(ctx)) 48 | } 49 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package jsonrpc 2 | 3 | import ( 4 | "bytes" 5 | "container/list" 6 | "context" 7 | "encoding/base64" 8 | "encoding/json" 9 | "fmt" 10 | "io" 11 | "net" 12 | "net/http" 13 | "net/url" 14 | "reflect" 15 | "runtime/pprof" 16 | "sync/atomic" 17 | "time" 18 | 19 | "github.com/google/uuid" 20 | "github.com/gorilla/websocket" 21 | logging "github.com/ipfs/go-log/v2" 22 | "go.opencensus.io/trace" 23 | "go.opencensus.io/trace/propagation" 24 | "golang.org/x/xerrors" 25 | ) 26 | 27 | const ( 28 | methodMinRetryDelay = 100 * time.Millisecond 29 | methodMaxRetryDelay = 10 * time.Minute 30 | ) 31 | 32 | var ( 33 | errorType = reflect.TypeOf(new(error)).Elem() 34 | contextType = reflect.TypeOf(new(context.Context)).Elem() 35 | 36 | log = logging.Logger("rpc") 37 | 38 | _defaultHTTPClient = &http.Client{ 39 | Transport: &http.Transport{ 40 | Proxy: http.ProxyFromEnvironment, 41 | DialContext: (&net.Dialer{ 42 | Timeout: 30 * time.Second, 43 | KeepAlive: 30 * time.Second, 44 | DualStack: true, 45 | }).DialContext, 46 | ForceAttemptHTTP2: true, 47 | MaxIdleConns: 100, 48 | MaxIdleConnsPerHost: 100, 49 | IdleConnTimeout: 90 * time.Second, 50 | TLSHandshakeTimeout: 10 * time.Second, 51 | ExpectContinueTimeout: 1 * time.Second, 52 | }, 53 | } 54 | ) 55 | 56 | // ErrClient is an error which occurred on the client side the library 57 | type ErrClient struct { 58 | err error 59 | } 60 | 61 | func (e *ErrClient) Error() string { 62 | return fmt.Sprintf("RPC client error: %s", e.err) 63 | } 64 | 65 | // Unwrap unwraps the actual error 66 | func (e *ErrClient) Unwrap() error { 67 | return 
e.err 68 | } 69 | 70 | type clientResponse struct { 71 | Jsonrpc string `json:"jsonrpc"` 72 | Result json.RawMessage `json:"result"` 73 | ID interface{} `json:"id"` 74 | Error *JSONRPCError `json:"error,omitempty"` 75 | } 76 | 77 | type makeChanSink func() (context.Context, func([]byte, bool)) 78 | 79 | type clientRequest struct { 80 | req request 81 | ready chan clientResponse 82 | 83 | // retCh provides a context and sink for handling incoming channel messages 84 | retCh makeChanSink 85 | } 86 | 87 | // ClientCloser is used to close Client from further use 88 | type ClientCloser func() 89 | 90 | // NewClient creates new jsonrpc 2.0 client 91 | // 92 | // handler must be pointer to a struct with function fields 93 | // Returned value closes the client connection 94 | // TODO: Example 95 | func NewClient(ctx context.Context, addr string, namespace string, handler interface{}, requestHeader http.Header) (ClientCloser, error) { 96 | return NewMergeClient(ctx, addr, namespace, []interface{}{handler}, requestHeader) 97 | } 98 | 99 | type client struct { 100 | namespace string 101 | paramEncoders map[reflect.Type]ParamEncoder 102 | errors *Errors 103 | 104 | doRequest func(context.Context, clientRequest) (clientResponse, error) 105 | exiting <-chan struct{} 106 | idCtr int64 107 | 108 | methodNameFormatter MethodNameFormatter 109 | } 110 | 111 | // NewMergeClient is like NewClient, but allows to specify multiple structs 112 | // to be filled in the same namespace, using one connection 113 | func NewMergeClient(ctx context.Context, addr string, namespace string, outs []interface{}, requestHeader http.Header, opts ...Option) (ClientCloser, error) { 114 | config := defaultConfig() 115 | for _, o := range opts { 116 | o(&config) 117 | } 118 | 119 | u, err := url.Parse(addr) 120 | if err != nil { 121 | return nil, xerrors.Errorf("parsing address: %w", err) 122 | } 123 | 124 | switch u.Scheme { 125 | case "ws", "wss": 126 | return websocketClient(ctx, addr, namespace, outs, 
requestHeader, config) 127 | case "http", "https": 128 | return httpClient(ctx, addr, namespace, outs, requestHeader, config) 129 | default: 130 | return nil, xerrors.Errorf("unknown url scheme '%s'", u.Scheme) 131 | } 132 | 133 | } 134 | 135 | // NewCustomClient is like NewMergeClient in single-request (http) mode, except it allows for a custom doRequest function 136 | func NewCustomClient(namespace string, outs []interface{}, doRequest func(ctx context.Context, body []byte) (io.ReadCloser, error), opts ...Option) (ClientCloser, error) { 137 | config := defaultConfig() 138 | for _, o := range opts { 139 | o(&config) 140 | } 141 | 142 | c := client{ 143 | namespace: namespace, 144 | paramEncoders: config.paramEncoders, 145 | errors: config.errors, 146 | methodNameFormatter: config.methodNamer, 147 | } 148 | 149 | stop := make(chan struct{}) 150 | c.exiting = stop 151 | 152 | c.doRequest = func(ctx context.Context, cr clientRequest) (clientResponse, error) { 153 | b, err := json.Marshal(&cr.req) 154 | if err != nil { 155 | return clientResponse{}, xerrors.Errorf("marshalling request: %w", err) 156 | } 157 | 158 | if ctx == nil { 159 | ctx = context.Background() 160 | } 161 | 162 | rawResp, err := doRequest(ctx, b) 163 | if err != nil { 164 | return clientResponse{}, xerrors.Errorf("doRequest failed: %w", err) 165 | } 166 | 167 | defer rawResp.Close() 168 | 169 | var resp clientResponse 170 | if cr.req.ID != nil { // non-notification 171 | if err := json.NewDecoder(rawResp).Decode(&resp); err != nil { 172 | return clientResponse{}, xerrors.Errorf("unmarshaling response: %w", err) 173 | } 174 | 175 | if resp.ID, err = normalizeID(resp.ID); err != nil { 176 | return clientResponse{}, xerrors.Errorf("failed to response ID: %w", err) 177 | } 178 | 179 | if resp.ID != cr.req.ID { 180 | return clientResponse{}, xerrors.New("request and response id didn't match") 181 | } 182 | } 183 | 184 | return resp, nil 185 | } 186 | 187 | if err := c.provide(outs); err != nil { 188 | 
return nil, err 189 | } 190 | 191 | return func() { 192 | close(stop) 193 | }, nil 194 | } 195 | 196 | func httpClient(ctx context.Context, addr string, namespace string, outs []interface{}, requestHeader http.Header, config Config) (ClientCloser, error) { 197 | c := client{ 198 | namespace: namespace, 199 | paramEncoders: config.paramEncoders, 200 | errors: config.errors, 201 | methodNameFormatter: config.methodNamer, 202 | } 203 | 204 | stop := make(chan struct{}) 205 | c.exiting = stop 206 | 207 | if requestHeader == nil { 208 | requestHeader = http.Header{} 209 | } 210 | 211 | c.doRequest = func(ctx context.Context, cr clientRequest) (clientResponse, error) { 212 | b, err := json.Marshal(&cr.req) 213 | if err != nil { 214 | return clientResponse{}, xerrors.Errorf("marshalling request: %w", err) 215 | } 216 | 217 | hreq, err := http.NewRequest("POST", addr, bytes.NewReader(b)) 218 | if err != nil { 219 | return clientResponse{}, &RPCConnectionError{err} 220 | } 221 | 222 | hreq.Header = requestHeader.Clone() 223 | 224 | if ctx != nil { 225 | hreq = hreq.WithContext(ctx) 226 | } 227 | 228 | hreq.Header.Set("Content-Type", "application/json") 229 | 230 | httpResp, err := config.httpClient.Do(hreq) 231 | if err != nil { 232 | return clientResponse{}, &RPCConnectionError{err} 233 | } 234 | 235 | // likely a failure outside of our control and ability to inspect; jsonrpc server only ever 236 | // returns json format errors with either a StatusBadRequest or a StatusInternalServerError 237 | if httpResp.StatusCode > http.StatusBadRequest && httpResp.StatusCode != http.StatusInternalServerError { 238 | return clientResponse{}, xerrors.Errorf("request failed, http status %s", httpResp.Status) 239 | } 240 | 241 | defer httpResp.Body.Close() 242 | 243 | var resp clientResponse 244 | if cr.req.ID != nil { // non-notification 245 | if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil { 246 | return clientResponse{}, xerrors.Errorf("http status %s unmarshaling 
response: %w", httpResp.Status, err) 247 | } 248 | 249 | if resp.ID, err = normalizeID(resp.ID); err != nil { 250 | return clientResponse{}, xerrors.Errorf("failed to response ID: %w", err) 251 | } 252 | 253 | if resp.ID != cr.req.ID { 254 | return clientResponse{}, xerrors.New("request and response id didn't match") 255 | } 256 | } 257 | 258 | return resp, nil 259 | } 260 | 261 | if err := c.provide(outs); err != nil { 262 | return nil, err 263 | } 264 | 265 | return func() { 266 | close(stop) 267 | }, nil 268 | } 269 | 270 | func websocketClient(ctx context.Context, addr string, namespace string, outs []interface{}, requestHeader http.Header, config Config) (ClientCloser, error) { 271 | connFactory := func() (*websocket.Conn, error) { 272 | conn, _, err := websocket.DefaultDialer.Dial(addr, requestHeader) 273 | if err != nil { 274 | return nil, &RPCConnectionError{xerrors.Errorf("cannot dial address %s for %w", addr, err)} 275 | } 276 | return conn, nil 277 | } 278 | 279 | if config.proxyConnFactory != nil { 280 | // used in tests 281 | connFactory = config.proxyConnFactory(connFactory) 282 | } 283 | 284 | conn, err := connFactory() 285 | if err != nil { 286 | return nil, err 287 | } 288 | 289 | if config.noReconnect { 290 | connFactory = nil 291 | } 292 | 293 | c := client{ 294 | namespace: namespace, 295 | paramEncoders: config.paramEncoders, 296 | errors: config.errors, 297 | methodNameFormatter: config.methodNamer, 298 | } 299 | 300 | requests := c.setupRequestChan() 301 | 302 | stop := make(chan struct{}) 303 | exiting := make(chan struct{}) 304 | c.exiting = exiting 305 | 306 | var hnd reqestHandler 307 | if len(config.reverseHandlers) > 0 { 308 | h := makeHandler(defaultServerConfig()) 309 | h.aliasedMethods = config.aliasedHandlerMethods 310 | for _, reverseHandler := range config.reverseHandlers { 311 | h.register(reverseHandler.ns, reverseHandler.hnd) 312 | } 313 | hnd = h 314 | } 315 | 316 | wconn := &wsConn{ 317 | conn: conn, 318 | connFactory: 
connFactory, 319 | reconnectBackoff: config.reconnectBackoff, 320 | pingInterval: config.pingInterval, 321 | timeout: config.timeout, 322 | handler: hnd, 323 | requests: requests, 324 | stop: stop, 325 | exiting: exiting, 326 | } 327 | 328 | go func() { 329 | lbl := pprof.Labels("jrpc-mode", "wsclient", "jrpc-remote", addr, "jrpc-local", conn.LocalAddr().String(), "jrpc-uuid", uuid.New().String()) 330 | pprof.Do(ctx, lbl, func(ctx context.Context) { 331 | wconn.handleWsConn(ctx) 332 | }) 333 | }() 334 | 335 | if err := c.provide(outs); err != nil { 336 | return nil, err 337 | } 338 | 339 | return func() { 340 | close(stop) 341 | <-exiting 342 | }, nil 343 | } 344 | 345 | func (c *client) setupRequestChan() chan clientRequest { 346 | requests := make(chan clientRequest) 347 | 348 | c.doRequest = func(ctx context.Context, cr clientRequest) (clientResponse, error) { 349 | select { 350 | case requests <- cr: 351 | case <-c.exiting: 352 | return clientResponse{}, fmt.Errorf("websocket routine exiting") 353 | } 354 | 355 | var ctxDone <-chan struct{} 356 | var resp clientResponse 357 | 358 | if ctx != nil { 359 | ctxDone = ctx.Done() 360 | } 361 | 362 | // wait for response, handle context cancellation 363 | loop: 364 | for { 365 | select { 366 | case resp = <-cr.ready: 367 | break loop 368 | case <-ctxDone: // send cancel request 369 | ctxDone = nil 370 | 371 | rp, err := json.Marshal([]param{{v: reflect.ValueOf(cr.req.ID)}}) 372 | if err != nil { 373 | return clientResponse{}, xerrors.Errorf("marshalling cancel request: %w", err) 374 | } 375 | 376 | cancelReq := clientRequest{ 377 | req: request{ 378 | Jsonrpc: "2.0", 379 | Method: wsCancel, 380 | Params: rp, 381 | }, 382 | ready: make(chan clientResponse, 1), 383 | } 384 | select { 385 | case requests <- cancelReq: 386 | case <-c.exiting: 387 | log.Warn("failed to send request cancellation, websocket routing exited") 388 | } 389 | 390 | } 391 | } 392 | 393 | return resp, nil 394 | } 395 | 396 | return requests 397 | } 
398 | 399 | func (c *client) provide(outs []interface{}) error { 400 | for _, handler := range outs { 401 | htyp := reflect.TypeOf(handler) 402 | if htyp.Kind() != reflect.Ptr { 403 | return xerrors.New("expected handler to be a pointer") 404 | } 405 | typ := htyp.Elem() 406 | if typ.Kind() != reflect.Struct { 407 | return xerrors.New("handler should be a struct") 408 | } 409 | 410 | val := reflect.ValueOf(handler) 411 | 412 | for i := 0; i < typ.NumField(); i++ { 413 | fn, err := c.makeRpcFunc(typ.Field(i)) 414 | if err != nil { 415 | return err 416 | } 417 | 418 | val.Elem().Field(i).Set(fn) 419 | } 420 | } 421 | 422 | return nil 423 | } 424 | 425 | func (c *client) makeOutChan(ctx context.Context, ftyp reflect.Type, valOut int) (func() reflect.Value, makeChanSink) { 426 | retVal := reflect.Zero(ftyp.Out(valOut)) 427 | 428 | chCtor := func() (context.Context, func([]byte, bool)) { 429 | // unpack chan type to make sure it's reflect.BothDir 430 | ctyp := reflect.ChanOf(reflect.BothDir, ftyp.Out(valOut).Elem()) 431 | ch := reflect.MakeChan(ctyp, 0) // todo: buffer? 
432 | retVal = ch.Convert(ftyp.Out(valOut)) 433 | 434 | incoming := make(chan reflect.Value, 32) 435 | 436 | // gorotuine to handle buffering of items 437 | go func() { 438 | buf := (&list.List{}).Init() 439 | 440 | for { 441 | front := buf.Front() 442 | 443 | cases := []reflect.SelectCase{ 444 | { 445 | Dir: reflect.SelectRecv, 446 | Chan: reflect.ValueOf(ctx.Done()), 447 | }, 448 | { 449 | Dir: reflect.SelectRecv, 450 | Chan: reflect.ValueOf(incoming), 451 | }, 452 | } 453 | 454 | if front != nil { 455 | cases = append(cases, reflect.SelectCase{ 456 | Dir: reflect.SelectSend, 457 | Chan: ch, 458 | Send: front.Value.(reflect.Value).Elem(), 459 | }) 460 | } 461 | 462 | chosen, val, ok := reflect.Select(cases) 463 | 464 | switch chosen { 465 | case 0: 466 | ch.Close() 467 | return 468 | case 1: 469 | if ok { 470 | vvval := val.Interface().(reflect.Value) 471 | buf.PushBack(vvval) 472 | if buf.Len() > 1 { 473 | if buf.Len() > 10 { 474 | log.Warnw("rpc output message buffer", "n", buf.Len()) 475 | } else { 476 | log.Debugw("rpc output message buffer", "n", buf.Len()) 477 | } 478 | } 479 | } else { 480 | incoming = nil 481 | } 482 | 483 | case 2: 484 | buf.Remove(front) 485 | } 486 | 487 | if incoming == nil && buf.Len() == 0 { 488 | ch.Close() 489 | return 490 | } 491 | } 492 | }() 493 | 494 | return ctx, func(result []byte, ok bool) { 495 | if !ok { 496 | close(incoming) 497 | return 498 | } 499 | 500 | val := reflect.New(ftyp.Out(valOut).Elem()) 501 | if err := json.Unmarshal(result, val.Interface()); err != nil { 502 | log.Errorf("error unmarshaling chan response: %s", err) 503 | return 504 | } 505 | 506 | if ctx.Err() != nil { 507 | log.Errorf("got rpc message with cancelled context: %s", ctx.Err()) 508 | return 509 | } 510 | 511 | select { 512 | case incoming <- val: 513 | case <-ctx.Done(): 514 | } 515 | } 516 | } 517 | 518 | return func() reflect.Value { return retVal }, chCtor 519 | } 520 | 521 | func (c *client) sendRequest(ctx context.Context, req request, 
chCtor makeChanSink) (clientResponse, error) { 522 | creq := clientRequest{ 523 | req: req, 524 | ready: make(chan clientResponse, 1), 525 | 526 | retCh: chCtor, 527 | } 528 | 529 | return c.doRequest(ctx, creq) 530 | } 531 | 532 | type rpcFunc struct { 533 | client *client 534 | 535 | ftyp reflect.Type 536 | name string 537 | 538 | nout int 539 | valOut int 540 | errOut int 541 | 542 | // hasCtx is 1 if the function has a context.Context as its first argument. 543 | // Used as the number of the first non-context argument. 544 | hasCtx int 545 | 546 | hasRawParams bool 547 | returnValueIsChannel bool 548 | 549 | retry bool 550 | notify bool 551 | } 552 | 553 | func (fn *rpcFunc) processResponse(resp clientResponse, rval reflect.Value) []reflect.Value { 554 | out := make([]reflect.Value, fn.nout) 555 | 556 | if fn.valOut != -1 { 557 | out[fn.valOut] = rval 558 | } 559 | if fn.errOut != -1 { 560 | out[fn.errOut] = reflect.New(errorType).Elem() 561 | if resp.Error != nil { 562 | 563 | out[fn.errOut].Set(resp.Error.val(fn.client.errors)) 564 | } 565 | } 566 | 567 | return out 568 | } 569 | 570 | func (fn *rpcFunc) processError(err error) []reflect.Value { 571 | out := make([]reflect.Value, fn.nout) 572 | 573 | if fn.valOut != -1 { 574 | out[fn.valOut] = reflect.New(fn.ftyp.Out(fn.valOut)).Elem() 575 | } 576 | if fn.errOut != -1 { 577 | out[fn.errOut] = reflect.New(errorType).Elem() 578 | out[fn.errOut].Set(reflect.ValueOf(&ErrClient{err})) 579 | } 580 | 581 | return out 582 | } 583 | 584 | func (fn *rpcFunc) handleRpcCall(args []reflect.Value) (results []reflect.Value) { 585 | var id interface{} 586 | if !fn.notify { 587 | id = atomic.AddInt64(&fn.client.idCtr, 1) 588 | 589 | // Prepare the ID to send on the wire. 590 | // We track int64 ids as float64 in the inflight map (because that's what 591 | // they'll be decoded to). encoding/json outputs numbers with their minimal 592 | // encoding, avoding the decimal point when possible, i.e. 
3 will never get 593 | // converted to 3.0. 594 | var err error 595 | id, err = normalizeID(id) 596 | if err != nil { 597 | return fn.processError(fmt.Errorf("failed to normalize id")) // should probably panic 598 | } 599 | } 600 | 601 | var serializedParams json.RawMessage 602 | 603 | if fn.hasRawParams { 604 | serializedParams = json.RawMessage(args[fn.hasCtx].Interface().(RawParams)) 605 | } else { 606 | params := make([]param, len(args)-fn.hasCtx) 607 | for i, arg := range args[fn.hasCtx:] { 608 | enc, found := fn.client.paramEncoders[arg.Type()] 609 | if found { 610 | // custom param encoder 611 | var err error 612 | arg, err = enc(arg) 613 | if err != nil { 614 | return fn.processError(fmt.Errorf("sendRequest failed: %w", err)) 615 | } 616 | } 617 | 618 | params[i] = param{ 619 | v: arg, 620 | } 621 | } 622 | var err error 623 | serializedParams, err = json.Marshal(params) 624 | if err != nil { 625 | return fn.processError(fmt.Errorf("marshaling params failed: %w", err)) 626 | } 627 | } 628 | 629 | var ctx context.Context 630 | var span *trace.Span 631 | if fn.hasCtx == 1 { 632 | ctx = args[0].Interface().(context.Context) 633 | ctx, span = trace.StartSpan(ctx, "api.call") 634 | defer span.End() 635 | } 636 | 637 | retVal := func() reflect.Value { return reflect.Value{} } 638 | 639 | // if the function returns a channel, we need to provide a sink for the 640 | // messages 641 | var chCtor makeChanSink 642 | if fn.returnValueIsChannel { 643 | retVal, chCtor = fn.client.makeOutChan(ctx, fn.ftyp, fn.valOut) 644 | } 645 | 646 | req := request{ 647 | Jsonrpc: "2.0", 648 | ID: id, 649 | Method: fn.name, 650 | Params: serializedParams, 651 | } 652 | 653 | if span != nil { 654 | span.AddAttributes(trace.StringAttribute("method", req.Method)) 655 | 656 | eSC := base64.StdEncoding.EncodeToString( 657 | propagation.Binary(span.SpanContext())) 658 | req.Meta = map[string]string{ 659 | "SpanContext": eSC, 660 | } 661 | } 662 | 663 | b := backoff{ 664 | maxDelay: 
methodMaxRetryDelay, 665 | minDelay: methodMinRetryDelay, 666 | } 667 | 668 | var err error 669 | var resp clientResponse 670 | // keep retrying if got a forced closed websocket conn and calling method 671 | // has retry annotation 672 | for attempt := 0; true; attempt++ { 673 | resp, err = fn.client.sendRequest(ctx, req, chCtor) 674 | if err != nil { 675 | return fn.processError(fmt.Errorf("sendRequest failed: %w", err)) 676 | } 677 | 678 | if !fn.notify && resp.ID != req.ID { 679 | return fn.processError(xerrors.New("request and response id didn't match")) 680 | } 681 | 682 | if fn.valOut != -1 && !fn.returnValueIsChannel { 683 | val := reflect.New(fn.ftyp.Out(fn.valOut)) 684 | 685 | if resp.Result != nil { 686 | log.Debugw("rpc result", "type", fn.ftyp.Out(fn.valOut)) 687 | if err := json.Unmarshal(resp.Result, val.Interface()); err != nil { 688 | log.Warnw("unmarshaling failed", "message", string(resp.Result)) 689 | return fn.processError(xerrors.Errorf("unmarshaling result: %w", err)) 690 | } 691 | } 692 | 693 | retVal = func() reflect.Value { return val.Elem() } 694 | } 695 | retry := resp.Error != nil && resp.Error.Code == eTempWSError && fn.retry 696 | if !retry { 697 | break 698 | } 699 | 700 | time.Sleep(b.next(attempt)) 701 | } 702 | 703 | return fn.processResponse(resp, retVal()) 704 | } 705 | 706 | const ( 707 | ProxyTagRetry = "retry" 708 | ProxyTagNotify = "notify" 709 | ProxyTagRPCMethod = "rpc_method" 710 | ) 711 | 712 | func (c *client) makeRpcFunc(f reflect.StructField) (reflect.Value, error) { 713 | ftyp := f.Type 714 | if ftyp.Kind() != reflect.Func { 715 | return reflect.Value{}, xerrors.New("handler field not a func") 716 | } 717 | 718 | name := c.methodNameFormatter(c.namespace, f.Name) 719 | if tag, ok := f.Tag.Lookup(ProxyTagRPCMethod); ok { 720 | name = tag 721 | } 722 | 723 | fun := &rpcFunc{ 724 | client: c, 725 | ftyp: ftyp, 726 | name: name, 727 | retry: f.Tag.Get(ProxyTagRetry) == "true", 728 | notify: f.Tag.Get(ProxyTagNotify) == 
"true", 729 | } 730 | fun.valOut, fun.errOut, fun.nout = processFuncOut(ftyp) 731 | 732 | if fun.valOut != -1 && fun.notify { 733 | return reflect.Value{}, xerrors.New("notify methods cannot return values") 734 | } 735 | 736 | fun.returnValueIsChannel = fun.valOut != -1 && ftyp.Out(fun.valOut).Kind() == reflect.Chan 737 | 738 | if ftyp.NumIn() > 0 && ftyp.In(0) == contextType { 739 | fun.hasCtx = 1 740 | } 741 | // note: hasCtx is also the number of the first non-context argument 742 | if ftyp.NumIn() > fun.hasCtx && ftyp.In(fun.hasCtx) == rtRawParams { 743 | if ftyp.NumIn() > fun.hasCtx+1 { 744 | return reflect.Value{}, xerrors.New("raw params can't be mixed with other arguments") 745 | } 746 | fun.hasRawParams = true 747 | } 748 | 749 | return reflect.MakeFunc(ftyp, fun.handleRpcCall), nil 750 | } 751 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package jsonrpc 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "reflect" 7 | ) 8 | 9 | const eTempWSError = -1111111 10 | 11 | type RPCConnectionError struct { 12 | err error 13 | } 14 | 15 | func (e *RPCConnectionError) Error() string { 16 | if e.err != nil { 17 | return e.err.Error() 18 | } 19 | return "RPCConnectionError" 20 | } 21 | 22 | func (e *RPCConnectionError) Unwrap() error { 23 | if e.err != nil { 24 | return e.err 25 | } 26 | return errors.New("RPCConnectionError") 27 | } 28 | 29 | type Errors struct { 30 | byType map[reflect.Type]ErrorCode 31 | byCode map[ErrorCode]reflect.Type 32 | } 33 | 34 | type ErrorCode int 35 | 36 | const FirstUserCode = 2 37 | 38 | func NewErrors() Errors { 39 | return Errors{ 40 | byType: map[reflect.Type]ErrorCode{}, 41 | byCode: map[ErrorCode]reflect.Type{ 42 | -1111111: reflect.TypeOf(&RPCConnectionError{}), 43 | }, 44 | } 45 | } 46 | 47 | func (e *Errors) Register(c ErrorCode, typ interface{}) { 48 | rt := reflect.TypeOf(typ).Elem() 49 | 
if !rt.Implements(errorType) { 50 | panic("can't register non-error types") 51 | } 52 | 53 | e.byType[rt] = c 54 | e.byCode[c] = rt 55 | } 56 | 57 | type marshalable interface { 58 | json.Marshaler 59 | json.Unmarshaler 60 | } 61 | 62 | type RPCErrorCodec interface { 63 | FromJSONRPCError(JSONRPCError) error 64 | ToJSONRPCError() (JSONRPCError, error) 65 | } 66 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/filecoin-project/go-jsonrpc 2 | 3 | go 1.23.0 4 | 5 | require ( 6 | github.com/google/uuid v1.1.1 7 | github.com/gorilla/mux v1.7.4 8 | github.com/gorilla/websocket v1.4.2 9 | github.com/ipfs/go-log/v2 v2.0.8 10 | github.com/stretchr/testify v1.5.1 11 | go.opencensus.io v0.22.3 12 | go.uber.org/zap v1.14.1 13 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 14 | ) 15 | 16 | require ( 17 | github.com/davecgh/go-spew v1.1.1 // indirect 18 | github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 // indirect 19 | github.com/pmezard/go-difflib v1.0.0 // indirect 20 | go.uber.org/atomic v1.6.0 // indirect 21 | go.uber.org/multierr v1.5.0 // indirect 22 | gopkg.in/yaml.v2 v2.2.2 // indirect 23 | ) 24 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= 2 | github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= 3 | github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= 4 | github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= 5 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 7 | 
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 8 | github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= 9 | github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 h1:ZgQEtGgCBiWRM39fZuwSd1LwSqqSW0hOdXCYYDX0R3I= 10 | github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= 11 | github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= 12 | github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 13 | github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 14 | github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= 15 | github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= 16 | github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= 17 | github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= 18 | github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 19 | github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc= 20 | github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= 21 | github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= 22 | github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 23 | github.com/ipfs/go-log/v2 v2.0.8 h1:3b3YNopMHlj4AvyhWAx0pDxqSQWYi4/WuWO7yRV6/Qg= 24 | github.com/ipfs/go-log/v2 v2.0.8/go.mod h1:eZs4Xt4ZUJQFM3DlanGhy7TkwwawCZcSByscwkWG+dw= 25 | github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= 26 | github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= 27 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 28 | github.com/kr/pty 
v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 29 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 30 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 31 | github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= 32 | github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 33 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 34 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 35 | github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= 36 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 37 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 38 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 39 | github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= 40 | github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= 41 | go.opencensus.io v0.22.3 h1:8sGtKOrtQqkN1bp2AtX+misvLIlOmsEsNd+9NIcPEm8= 42 | go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= 43 | go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= 44 | go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= 45 | go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A= 46 | go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= 47 | go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= 48 | go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= 49 | go.uber.org/zap v1.14.1 h1:nYDKopTbvAPq/NrUVZwT15y2lpROBiLLyoRTbXOYWOo= 50 | go.uber.org/zap v1.14.1/go.mod 
h1:Mb2vm2krFEG5DV0W9qcHBYFtp/Wku1cvYaqPsS/WYfc= 51 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 52 | golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 53 | golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= 54 | golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= 55 | golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= 56 | golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= 57 | golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= 58 | golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= 59 | golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= 60 | golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 61 | golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 62 | golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= 63 | golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 64 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 65 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 66 | golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= 67 | golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 68 | golang.org/x/sync 
v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 69 | golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 70 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 71 | golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 72 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 73 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 74 | golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 75 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 76 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= 77 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 78 | golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 79 | golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= 80 | golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= 81 | golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= 82 | golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 83 | golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5 h1:hKsoRgsbwY1NafxrwTs+k64bikrLBkAgPir1TNCj3Zs= 84 | golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 85 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 86 | golang.org/x/xerrors 
v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= 87 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 88 | google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= 89 | google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= 90 | google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= 91 | google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= 92 | google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= 93 | google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= 94 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 95 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= 96 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 97 | gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= 98 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 99 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 100 | honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= 101 | honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= 102 | honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= 103 | -------------------------------------------------------------------------------- /handler.go: -------------------------------------------------------------------------------- 1 | package jsonrpc 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/base64" 7 | "encoding/json" 8 | "fmt" 9 | "io" 
10 | "reflect" 11 | 12 | "go.opencensus.io/stats" 13 | "go.opencensus.io/tag" 14 | "go.opencensus.io/trace" 15 | "go.opencensus.io/trace/propagation" 16 | "go.uber.org/zap" 17 | "go.uber.org/zap/zapcore" 18 | "golang.org/x/xerrors" 19 | 20 | "github.com/filecoin-project/go-jsonrpc/metrics" 21 | ) 22 | 23 | type RawParams json.RawMessage 24 | 25 | var rtRawParams = reflect.TypeOf(RawParams{}) 26 | 27 | // todo is there a better way to tell 'struct with any number of fields'? 28 | func DecodeParams[T any](p RawParams) (T, error) { 29 | var t T 30 | err := json.Unmarshal(p, &t) 31 | 32 | // todo also handle list-encoding automagically (json.Unmarshal doesn't do that, does it?) 33 | 34 | return t, err 35 | } 36 | 37 | // methodHandler is a handler for a single method 38 | type methodHandler struct { 39 | paramReceivers []reflect.Type 40 | nParams int 41 | 42 | receiver reflect.Value 43 | handlerFunc reflect.Value 44 | 45 | hasCtx int 46 | hasRawParams bool 47 | 48 | errOut int 49 | valOut int 50 | } 51 | 52 | // Request / response 53 | 54 | type request struct { 55 | Jsonrpc string `json:"jsonrpc"` 56 | ID interface{} `json:"id,omitempty"` 57 | Method string `json:"method"` 58 | Params json.RawMessage `json:"params"` 59 | Meta map[string]string `json:"meta,omitempty"` 60 | } 61 | 62 | // Limit request size. Ideally this limit should be specific for each field 63 | // in the JSON request but as a simple defensive measure we just limit the 64 | // entire HTTP body. 65 | // Configured by WithMaxRequestSize. 66 | const DEFAULT_MAX_REQUEST_SIZE = 100 << 20 // 100 MiB 67 | 68 | type handler struct { 69 | methods map[string]methodHandler 70 | errors *Errors 71 | 72 | maxRequestSize int64 73 | 74 | // aliasedMethods contains a map of alias:original method names. 75 | // These are used as fallbacks if a method is not found by the given method name. 
76 | aliasedMethods map[string]string 77 | 78 | paramDecoders map[reflect.Type]ParamDecoder 79 | 80 | methodNameFormatter MethodNameFormatter 81 | 82 | tracer Tracer 83 | } 84 | 85 | type Tracer func(method string, params []reflect.Value, results []reflect.Value, err error) 86 | 87 | func makeHandler(sc ServerConfig) *handler { 88 | return &handler{ 89 | methods: make(map[string]methodHandler), 90 | errors: sc.errors, 91 | 92 | aliasedMethods: map[string]string{}, 93 | paramDecoders: sc.paramDecoders, 94 | 95 | methodNameFormatter: sc.methodNameFormatter, 96 | 97 | maxRequestSize: sc.maxRequestSize, 98 | 99 | tracer: sc.tracer, 100 | } 101 | } 102 | 103 | // Register 104 | 105 | func (s *handler) register(namespace string, r interface{}) { 106 | val := reflect.ValueOf(r) 107 | // TODO: expect ptr 108 | 109 | for i := 0; i < val.NumMethod(); i++ { 110 | method := val.Type().Method(i) 111 | 112 | funcType := method.Func.Type() 113 | hasCtx := 0 114 | if funcType.NumIn() >= 2 && funcType.In(1) == contextType { 115 | hasCtx = 1 116 | } 117 | 118 | hasRawParams := false 119 | ins := funcType.NumIn() - 1 - hasCtx 120 | recvs := make([]reflect.Type, ins) 121 | for i := 0; i < ins; i++ { 122 | if hasRawParams && i > 0 { 123 | panic("raw params must be the last parameter") 124 | } 125 | if funcType.In(i+1+hasCtx) == rtRawParams { 126 | hasRawParams = true 127 | } 128 | recvs[i] = method.Type.In(i + 1 + hasCtx) 129 | } 130 | 131 | valOut, errOut, _ := processFuncOut(funcType) 132 | 133 | s.methods[s.methodNameFormatter(namespace, method.Name)] = methodHandler{ 134 | paramReceivers: recvs, 135 | nParams: ins, 136 | 137 | handlerFunc: method.Func, 138 | receiver: val, 139 | 140 | hasCtx: hasCtx, 141 | hasRawParams: hasRawParams, 142 | 143 | errOut: errOut, 144 | valOut: valOut, 145 | } 146 | } 147 | } 148 | 149 | // Handle 150 | 151 | type rpcErrFunc func(w func(func(io.Writer)), req *request, code ErrorCode, err error) 152 | type chanOut func(reflect.Value, interface{}) error 
153 | 154 | func (s *handler) handleReader(ctx context.Context, r io.Reader, w io.Writer, rpcError rpcErrFunc) { 155 | wf := func(cb func(io.Writer)) { 156 | cb(w) 157 | } 158 | 159 | // We read the entire request upfront in a buffer to be able to tell if the 160 | // client sent more than maxRequestSize and report it back as an explicit error, 161 | // instead of just silently truncating it and reporting a more vague parsing 162 | // error. 163 | bufferedRequest := new(bytes.Buffer) 164 | // We use LimitReader to enforce maxRequestSize. Since it won't return an 165 | // EOF we can't actually know if the client sent more than the maximum or 166 | // not, so we read one byte more over the limit to explicitly query that. 167 | // FIXME: Maybe there's a cleaner way to do this. 168 | reqSize, err := bufferedRequest.ReadFrom(io.LimitReader(r, s.maxRequestSize+1)) 169 | if err != nil { 170 | // ReadFrom will discard EOF so any error here is unexpected and should 171 | // be reported. 172 | rpcError(wf, nil, rpcParseError, xerrors.Errorf("reading request: %w", err)) 173 | return 174 | } 175 | if reqSize > s.maxRequestSize { 176 | rpcError(wf, nil, rpcParseError, 177 | // rpcParseError is the closest we have from the standard errors defined 178 | // in [jsonrpc spec](https://www.jsonrpc.org/specification#error_object) 179 | // to report the maximum limit. 180 | xerrors.Errorf("request bigger than maximum %d allowed", 181 | s.maxRequestSize)) 182 | return 183 | } 184 | 185 | // Trim spaces to avoid issues with batch request detection. 
186 | bufferedRequest = bytes.NewBuffer(bytes.TrimSpace(bufferedRequest.Bytes())) 187 | reqSize = int64(bufferedRequest.Len()) 188 | 189 | if reqSize == 0 { 190 | rpcError(wf, nil, rpcInvalidRequest, xerrors.New("Invalid request")) 191 | return 192 | } 193 | 194 | if bufferedRequest.Bytes()[0] == '[' && bufferedRequest.Bytes()[reqSize-1] == ']' { 195 | var reqs []request 196 | 197 | if err := json.NewDecoder(bufferedRequest).Decode(&reqs); err != nil { 198 | rpcError(wf, nil, rpcParseError, xerrors.New("Parse error")) 199 | return 200 | } 201 | 202 | if len(reqs) == 0 { 203 | rpcError(wf, nil, rpcInvalidRequest, xerrors.New("Invalid request")) 204 | return 205 | } 206 | 207 | _, _ = w.Write([]byte("[")) // todo consider handling this error 208 | for idx, req := range reqs { 209 | if req.ID, err = normalizeID(req.ID); err != nil { 210 | rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err)) 211 | return 212 | } 213 | 214 | s.handle(ctx, req, wf, rpcError, func(bool) {}, nil) 215 | 216 | if idx != len(reqs)-1 { 217 | _, _ = w.Write([]byte(",")) // todo consider handling this error 218 | } 219 | } 220 | _, _ = w.Write([]byte("]")) // todo consider handling this error 221 | } else { 222 | var req request 223 | if err := json.NewDecoder(bufferedRequest).Decode(&req); err != nil { 224 | rpcError(wf, &req, rpcParseError, xerrors.New("Parse error")) 225 | return 226 | } 227 | 228 | if req.ID, err = normalizeID(req.ID); err != nil { 229 | rpcError(wf, &req, rpcParseError, xerrors.Errorf("failed to parse ID: %w", err)) 230 | return 231 | } 232 | 233 | s.handle(ctx, req, wf, rpcError, func(bool) {}, nil) 234 | } 235 | } 236 | 237 | func doCall(methodName string, f reflect.Value, params []reflect.Value) (out []reflect.Value, err error) { 238 | defer func() { 239 | if i := recover(); i != nil { 240 | err = xerrors.Errorf("panic in rpc method '%s': %s", methodName, i) 241 | 
log.Desugar().WithOptions(zap.AddStacktrace(zapcore.ErrorLevel)).Sugar().Error(err) 242 | } 243 | }() 244 | 245 | out = f.Call(params) 246 | return out, nil 247 | } 248 | 249 | func (s *handler) getSpan(ctx context.Context, req request) (context.Context, *trace.Span) { 250 | if req.Meta == nil { 251 | return ctx, nil 252 | } 253 | 254 | var span *trace.Span 255 | if eSC, ok := req.Meta["SpanContext"]; ok { 256 | bSC := make([]byte, base64.StdEncoding.DecodedLen(len(eSC))) 257 | _, err := base64.StdEncoding.Decode(bSC, []byte(eSC)) 258 | if err != nil { 259 | log.Errorf("SpanContext: decode", "error", err) 260 | return ctx, nil 261 | } 262 | sc, ok := propagation.FromBinary(bSC) 263 | if !ok { 264 | log.Errorf("SpanContext: could not create span", "data", bSC) 265 | return ctx, nil 266 | } 267 | ctx, span = trace.StartSpanWithRemoteParent(ctx, "api.handle", sc) 268 | } else { 269 | ctx, span = trace.StartSpan(ctx, "api.handle") 270 | } 271 | 272 | span.AddAttributes(trace.StringAttribute("method", req.Method)) 273 | return ctx, span 274 | } 275 | 276 | func (s *handler) createError(err error) *JSONRPCError { 277 | var code ErrorCode = 1 278 | if s.errors != nil { 279 | c, ok := s.errors.byType[reflect.TypeOf(err)] 280 | if ok { 281 | code = c 282 | } 283 | } 284 | 285 | out := &JSONRPCError{ 286 | Code: code, 287 | Message: err.Error(), 288 | } 289 | 290 | switch m := err.(type) { 291 | case RPCErrorCodec: 292 | o, err := m.ToJSONRPCError() 293 | if err != nil { 294 | log.Errorf("Failed to convert error to JSONRPCError: %w", err) 295 | } else { 296 | out = &o 297 | } 298 | case marshalable: 299 | meta, marshalErr := m.MarshalJSON() 300 | if marshalErr == nil { 301 | out.Meta = meta 302 | } else { 303 | log.Errorf("Failed to marshal error metadata: %w", marshalErr) 304 | } 305 | } 306 | 307 | return out 308 | } 309 | 310 | func (s *handler) handle(ctx context.Context, req request, w func(func(io.Writer)), rpcError rpcErrFunc, done func(keepCtx bool), chOut chanOut) { 
311 | // Not sure if we need to sanitize the incoming req.Method or not. 312 | ctx, span := s.getSpan(ctx, req) 313 | ctx, _ = tag.New(ctx, tag.Insert(metrics.RPCMethod, req.Method)) 314 | defer span.End() 315 | 316 | handler, ok := s.methods[req.Method] 317 | if !ok { 318 | aliasTo, ok := s.aliasedMethods[req.Method] 319 | if ok { 320 | handler, ok = s.methods[aliasTo] 321 | } 322 | if !ok { 323 | rpcError(w, &req, rpcMethodNotFound, fmt.Errorf("method '%s' not found", req.Method)) 324 | stats.Record(ctx, metrics.RPCInvalidMethod.M(1)) 325 | done(false) 326 | return 327 | } 328 | } 329 | 330 | outCh := handler.valOut != -1 && handler.handlerFunc.Type().Out(handler.valOut).Kind() == reflect.Chan 331 | defer done(outCh) 332 | 333 | if chOut == nil && outCh { 334 | rpcError(w, &req, rpcMethodNotFound, fmt.Errorf("method '%s' not supported in this mode (no out channel support)", req.Method)) 335 | stats.Record(ctx, metrics.RPCRequestError.M(1)) 336 | return 337 | } 338 | 339 | callParams := make([]reflect.Value, 1+handler.hasCtx+handler.nParams) 340 | callParams[0] = handler.receiver 341 | if handler.hasCtx == 1 { 342 | callParams[1] = reflect.ValueOf(ctx) 343 | } 344 | 345 | if handler.hasRawParams { 346 | // When hasRawParams is true, there is only one parameter and it is a 347 | // json.RawMessage. 
348 | 349 | callParams[1+handler.hasCtx] = reflect.ValueOf(RawParams(req.Params)) 350 | } else { 351 | // "normal" param list; no good way to do named params in Golang 352 | 353 | var ps []param 354 | if len(req.Params) > 0 { 355 | err := json.Unmarshal(req.Params, &ps) 356 | if err != nil { 357 | rpcError(w, &req, rpcParseError, xerrors.Errorf("unmarshaling param array: %w", err)) 358 | stats.Record(ctx, metrics.RPCRequestError.M(1)) 359 | return 360 | } 361 | } 362 | 363 | if len(ps) != handler.nParams { 364 | rpcError(w, &req, rpcInvalidParams, fmt.Errorf("wrong param count (method '%s'): %d != %d", req.Method, len(ps), handler.nParams)) 365 | stats.Record(ctx, metrics.RPCRequestError.M(1)) 366 | done(false) 367 | return 368 | } 369 | 370 | for i := 0; i < handler.nParams; i++ { 371 | var rp reflect.Value 372 | 373 | typ := handler.paramReceivers[i] 374 | dec, found := s.paramDecoders[typ] 375 | if !found { 376 | rp = reflect.New(typ) 377 | if err := json.NewDecoder(bytes.NewReader(ps[i].data)).Decode(rp.Interface()); err != nil { 378 | rpcError(w, &req, rpcParseError, xerrors.Errorf("unmarshaling params for '%s' (param: %T): %w", req.Method, rp.Interface(), err)) 379 | stats.Record(ctx, metrics.RPCRequestError.M(1)) 380 | return 381 | } 382 | rp = rp.Elem() 383 | } else { 384 | var err error 385 | rp, err = dec(ctx, ps[i].data) 386 | if err != nil { 387 | rpcError(w, &req, rpcParseError, xerrors.Errorf("decoding params for '%s' (param: %d; custom decoder): %w", req.Method, i, err)) 388 | stats.Record(ctx, metrics.RPCRequestError.M(1)) 389 | return 390 | } 391 | } 392 | 393 | callParams[i+1+handler.hasCtx] = reflect.ValueOf(rp.Interface()) 394 | } 395 | } 396 | 397 | // ///////////////// 398 | 399 | callResult, err := doCall(req.Method, handler.handlerFunc, callParams) 400 | if err != nil { 401 | rpcError(w, &req, 0, xerrors.Errorf("fatal error calling '%s': %w", req.Method, err)) 402 | stats.Record(ctx, metrics.RPCRequestError.M(1)) 403 | if s.tracer != nil { 
404 | s.tracer(req.Method, callParams, nil, err) 405 | } 406 | return 407 | } 408 | if req.ID == nil { 409 | return // notification 410 | } 411 | 412 | if s.tracer != nil { 413 | s.tracer(req.Method, callParams, callResult, nil) 414 | } 415 | // ///////////////// 416 | 417 | resp := response{ 418 | Jsonrpc: "2.0", 419 | ID: req.ID, 420 | } 421 | 422 | if handler.errOut != -1 { 423 | err := callResult[handler.errOut].Interface() 424 | if err != nil { 425 | log.Warnf("error in RPC call to '%s': %+v", req.Method, err) 426 | stats.Record(ctx, metrics.RPCResponseError.M(1)) 427 | 428 | resp.Error = s.createError(err.(error)) 429 | } 430 | } 431 | 432 | var kind reflect.Kind 433 | var res interface{} 434 | var nonZero bool 435 | if handler.valOut != -1 { 436 | res = callResult[handler.valOut].Interface() 437 | kind = callResult[handler.valOut].Kind() 438 | nonZero = !callResult[handler.valOut].IsZero() 439 | } 440 | 441 | // check error as JSON-RPC spec prohibits error and value at the same time 442 | if resp.Error == nil { 443 | if res != nil && kind == reflect.Chan { 444 | // Channel responses are sent from channel control goroutine. 
445 | // Sending responses here could cause deadlocks on writeLk, or allow 446 | // sending channel messages before this rpc call returns 447 | 448 | //noinspection GoNilness // already checked above 449 | err = chOut(callResult[handler.valOut], req.ID) 450 | if err == nil { 451 | return // channel goroutine handles responding 452 | } 453 | 454 | log.Warnf("failed to setup channel in RPC call to '%s': %+v", req.Method, err) 455 | stats.Record(ctx, metrics.RPCResponseError.M(1)) 456 | 457 | resp.Error = &JSONRPCError{ 458 | Code: 1, 459 | Message: err.Error(), 460 | } 461 | } else { 462 | resp.Result = res 463 | } 464 | } 465 | if resp.Error != nil && nonZero { 466 | log.Errorw("error and res returned", "request", req, "r.err", resp.Error, "res", res) 467 | } 468 | 469 | withLazyWriter(w, func(w io.Writer) { 470 | if err := json.NewEncoder(w).Encode(resp); err != nil { 471 | log.Error(err) 472 | stats.Record(ctx, metrics.RPCResponseError.M(1)) 473 | return 474 | } 475 | }) 476 | } 477 | 478 | // withLazyWriter makes it possible to defer acquiring a writer until the first write. 479 | // This is useful because json.Encode needs to marshal the response fully before writing, which may be 480 | // a problem for very large responses. 
481 | func withLazyWriter(withWriterFunc func(func(io.Writer)), cb func(io.Writer)) { 482 | lw := &lazyWriter{ 483 | withWriterFunc: withWriterFunc, 484 | 485 | done: make(chan struct{}), 486 | } 487 | 488 | defer close(lw.done) 489 | cb(lw) 490 | } 491 | 492 | type lazyWriter struct { 493 | withWriterFunc func(func(io.Writer)) 494 | 495 | w io.Writer 496 | done chan struct{} 497 | } 498 | 499 | func (lw *lazyWriter) Write(p []byte) (n int, err error) { 500 | if lw.w == nil { 501 | acquired := make(chan struct{}) 502 | go func() { 503 | lw.withWriterFunc(func(w io.Writer) { 504 | lw.w = w 505 | close(acquired) 506 | <-lw.done 507 | }) 508 | }() 509 | <-acquired 510 | } 511 | 512 | return lw.w.Write(p) 513 | } 514 | -------------------------------------------------------------------------------- /httpio/README: -------------------------------------------------------------------------------- 1 | This package provides param encoders / decoders for `io.Reader` which proxy 2 | data over temporary http endpoints -------------------------------------------------------------------------------- /httpio/reader.go: -------------------------------------------------------------------------------- 1 | package httpio 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "net/url" 10 | "path" 11 | "reflect" 12 | "sync" 13 | 14 | "github.com/google/uuid" 15 | logging "github.com/ipfs/go-log/v2" 16 | "golang.org/x/xerrors" 17 | 18 | "github.com/filecoin-project/go-jsonrpc" 19 | ) 20 | 21 | var log = logging.Logger("rpc") 22 | 23 | func ReaderParamEncoder(addr string) jsonrpc.Option { 24 | return jsonrpc.WithParamEncoder(new(io.Reader), func(value reflect.Value) (reflect.Value, error) { 25 | r := value.Interface().(io.Reader) 26 | 27 | reqID := uuid.New() 28 | u, _ := url.Parse(addr) 29 | u.Path = path.Join(u.Path, reqID.String()) 30 | 31 | go func() { 32 | // TODO: figure out errors here 33 | 34 | resp, err := http.Post(u.String(), 
"application/octet-stream", r) 35 | if err != nil { 36 | log.Errorf("sending reader param: %+v", err) 37 | return 38 | } 39 | 40 | defer resp.Body.Close() 41 | 42 | if resp.StatusCode != 200 { 43 | log.Errorf("sending reader param: non-200 status: ", resp.Status) 44 | return 45 | } 46 | 47 | }() 48 | 49 | return reflect.ValueOf(reqID), nil 50 | }) 51 | } 52 | 53 | type waitReadCloser struct { 54 | io.ReadCloser 55 | wait chan struct{} 56 | } 57 | 58 | func (w *waitReadCloser) Read(p []byte) (int, error) { 59 | n, err := w.ReadCloser.Read(p) 60 | if err != nil { 61 | close(w.wait) 62 | } 63 | return n, err 64 | } 65 | 66 | func (w *waitReadCloser) Close() error { 67 | close(w.wait) 68 | return w.ReadCloser.Close() 69 | } 70 | 71 | func ReaderParamDecoder() (http.HandlerFunc, jsonrpc.ServerOption) { 72 | var readersLk sync.Mutex 73 | readers := map[uuid.UUID]chan *waitReadCloser{} 74 | 75 | hnd := func(resp http.ResponseWriter, req *http.Request) { 76 | strId := path.Base(req.URL.Path) 77 | u, err := uuid.Parse(strId) 78 | if err != nil { 79 | http.Error(resp, fmt.Sprintf("parsing reader uuid: %s", err), 400) 80 | } 81 | 82 | readersLk.Lock() 83 | ch, found := readers[u] 84 | if !found { 85 | ch = make(chan *waitReadCloser) 86 | readers[u] = ch 87 | } 88 | readersLk.Unlock() 89 | 90 | wr := &waitReadCloser{ 91 | ReadCloser: req.Body, 92 | wait: make(chan struct{}), 93 | } 94 | 95 | select { 96 | case ch <- wr: 97 | case <-req.Context().Done(): 98 | log.Error("context error in reader stream handler (1): %v", req.Context().Err()) 99 | resp.WriteHeader(500) 100 | return 101 | } 102 | 103 | select { 104 | case <-wr.wait: 105 | case <-req.Context().Done(): 106 | log.Error("context error in reader stream handler (2): %v", req.Context().Err()) 107 | resp.WriteHeader(500) 108 | return 109 | } 110 | 111 | resp.WriteHeader(200) 112 | } 113 | 114 | dec := jsonrpc.WithParamDecoder(new(io.Reader), func(ctx context.Context, b []byte) (reflect.Value, error) { 115 | var strId string 
116 | if err := json.Unmarshal(b, &strId); err != nil { 117 | return reflect.Value{}, xerrors.Errorf("unmarshaling reader id: %w", err) 118 | } 119 | 120 | u, err := uuid.Parse(strId) 121 | if err != nil { 122 | return reflect.Value{}, xerrors.Errorf("parsing reader UUDD: %w", err) 123 | } 124 | 125 | readersLk.Lock() 126 | ch, found := readers[u] 127 | if !found { 128 | ch = make(chan *waitReadCloser) 129 | readers[u] = ch 130 | } 131 | readersLk.Unlock() 132 | 133 | select { 134 | case wr := <-ch: 135 | return reflect.ValueOf(wr), nil 136 | case <-ctx.Done(): 137 | return reflect.Value{}, ctx.Err() 138 | } 139 | }) 140 | 141 | return hnd, dec 142 | } 143 | -------------------------------------------------------------------------------- /httpio/reader_test.go: -------------------------------------------------------------------------------- 1 | package httpio 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net/http/httptest" 7 | "strings" 8 | "testing" 9 | 10 | "github.com/gorilla/mux" 11 | "github.com/stretchr/testify/require" 12 | 13 | "github.com/filecoin-project/go-jsonrpc" 14 | ) 15 | 16 | type ReaderHandler struct { 17 | } 18 | 19 | func (h *ReaderHandler) ReadAll(ctx context.Context, r io.Reader) ([]byte, error) { 20 | return io.ReadAll(r) 21 | } 22 | 23 | func (h *ReaderHandler) ReadUrl(ctx context.Context, u string) (string, error) { 24 | return u, nil 25 | } 26 | 27 | func TestReaderProxy(t *testing.T) { 28 | var client struct { 29 | ReadAll func(ctx context.Context, r io.Reader) ([]byte, error) 30 | } 31 | 32 | serverHandler := &ReaderHandler{} 33 | 34 | readerHandler, readerServerOpt := ReaderParamDecoder() 35 | rpcServer := jsonrpc.NewServer(readerServerOpt) 36 | rpcServer.Register("ReaderHandler", serverHandler) 37 | 38 | mux := mux.NewRouter() 39 | mux.Handle("/rpc/v0", rpcServer) 40 | mux.Handle("/rpc/streams/v0/push/{uuid}", readerHandler) 41 | 42 | testServ := httptest.NewServer(mux) 43 | defer testServ.Close() 44 | 45 | re := 
ReaderParamEncoder("http://" + testServ.Listener.Addr().String() + "/rpc/streams/v0/push") 46 | closer, err := jsonrpc.NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String()+"/rpc/v0", "ReaderHandler", []interface{}{&client}, nil, re) 47 | require.NoError(t, err) 48 | 49 | defer closer() 50 | 51 | read, err := client.ReadAll(context.TODO(), strings.NewReader("pooooootato")) 52 | require.NoError(t, err) 53 | require.Equal(t, "pooooootato", string(read), "potatos weren't equal") 54 | } 55 | -------------------------------------------------------------------------------- /method_formatter.go: -------------------------------------------------------------------------------- 1 | package jsonrpc 2 | 3 | import "strings" 4 | 5 | // MethodNameFormatter is a function that takes a namespace and a method name and returns the full method name, sent via JSON-RPC. 6 | // This is useful if you want to customize the default behaviour, e.g. send without the namespace or make it lowercase. 7 | type MethodNameFormatter func(namespace, method string) string 8 | 9 | // CaseStyle represents the case style for method names. 10 | type CaseStyle int 11 | 12 | const ( 13 | OriginalCase CaseStyle = iota 14 | LowerFirstCharCase 15 | ) 16 | 17 | // NewMethodNameFormatter creates a new method name formatter based on the provided options. 18 | func NewMethodNameFormatter(includeNamespace bool, nameCase CaseStyle) MethodNameFormatter { 19 | return func(namespace, method string) string { 20 | formattedMethod := method 21 | if nameCase == LowerFirstCharCase && len(method) > 0 { 22 | formattedMethod = strings.ToLower(method[:1]) + method[1:] 23 | } 24 | if includeNamespace { 25 | return namespace + "." + formattedMethod 26 | } 27 | return formattedMethod 28 | } 29 | } 30 | 31 | // DefaultMethodNameFormatter is a pass-through formatter with default options. 
32 | var DefaultMethodNameFormatter = NewMethodNameFormatter(true, OriginalCase) 33 | -------------------------------------------------------------------------------- /method_formatter_test.go: -------------------------------------------------------------------------------- 1 | package jsonrpc 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/stretchr/testify/require" 7 | "net/http" 8 | "net/http/httptest" 9 | "strings" 10 | "testing" 11 | ) 12 | 13 | func TestDifferentMethodNamers(t *testing.T) { 14 | tests := map[string]struct { 15 | namer MethodNameFormatter 16 | 17 | requestedMethod string 18 | }{ 19 | "default namer": { 20 | namer: DefaultMethodNameFormatter, 21 | requestedMethod: "SimpleServerHandler.Inc", 22 | }, 23 | "lower fist char": { 24 | namer: NewMethodNameFormatter(true, LowerFirstCharCase), 25 | requestedMethod: "SimpleServerHandler.inc", 26 | }, 27 | "no namespace namer": { 28 | namer: NewMethodNameFormatter(false, OriginalCase), 29 | requestedMethod: "Inc", 30 | }, 31 | "no namespace & lower fist char": { 32 | namer: NewMethodNameFormatter(false, LowerFirstCharCase), 33 | requestedMethod: "inc", 34 | }, 35 | } 36 | for name, test := range tests { 37 | t.Run(name, func(t *testing.T) { 38 | rpcServer := NewServer(WithServerMethodNameFormatter(test.namer)) 39 | 40 | serverHandler := &SimpleServerHandler{} 41 | rpcServer.Register("SimpleServerHandler", serverHandler) 42 | 43 | testServ := httptest.NewServer(rpcServer) 44 | defer testServ.Close() 45 | 46 | req := fmt.Sprintf(`{"jsonrpc": "2.0", "method": "%s", "params": [], "id": 1}`, test.requestedMethod) 47 | 48 | res, err := http.Post(testServ.URL, "application/json", strings.NewReader(req)) 49 | require.NoError(t, err) 50 | 51 | require.Equal(t, http.StatusOK, res.StatusCode) 52 | require.Equal(t, int32(1), serverHandler.n) 53 | }) 54 | } 55 | } 56 | 57 | func TestDifferentMethodNamersWithClient(t *testing.T) { 58 | tests := map[string]struct { 59 | namer MethodNameFormatter 60 | urlPrefix 
string 61 | }{ 62 | "default namer & http": { 63 | namer: DefaultMethodNameFormatter, 64 | urlPrefix: "http://", 65 | }, 66 | "default namer & ws": { 67 | namer: DefaultMethodNameFormatter, 68 | urlPrefix: "ws://", 69 | }, 70 | "lower first char namer & http": { 71 | namer: NewMethodNameFormatter(true, LowerFirstCharCase), 72 | urlPrefix: "http://", 73 | }, 74 | "lower first char namer & ws": { 75 | namer: NewMethodNameFormatter(true, LowerFirstCharCase), 76 | urlPrefix: "ws://", 77 | }, 78 | "no namespace namer & http": { 79 | namer: NewMethodNameFormatter(false, OriginalCase), 80 | urlPrefix: "http://", 81 | }, 82 | "no namespace namer & ws": { 83 | namer: NewMethodNameFormatter(false, OriginalCase), 84 | urlPrefix: "ws://", 85 | }, 86 | "no namespace & lower first char & http": { 87 | namer: NewMethodNameFormatter(false, LowerFirstCharCase), 88 | urlPrefix: "http://", 89 | }, 90 | "no namespace & lower first char & ws": { 91 | namer: NewMethodNameFormatter(false, LowerFirstCharCase), 92 | urlPrefix: "ws://", 93 | }, 94 | } 95 | for name, test := range tests { 96 | t.Run(name, func(t *testing.T) { 97 | rpcServer := NewServer(WithServerMethodNameFormatter(test.namer)) 98 | 99 | serverHandler := &SimpleServerHandler{} 100 | rpcServer.Register("SimpleServerHandler", serverHandler) 101 | 102 | testServ := httptest.NewServer(rpcServer) 103 | defer testServ.Close() 104 | 105 | var client struct { 106 | AddGet func(int) int 107 | } 108 | 109 | closer, err := NewMergeClient( 110 | context.Background(), 111 | test.urlPrefix+testServ.Listener.Addr().String(), 112 | "SimpleServerHandler", 113 | []any{&client}, 114 | nil, 115 | WithHTTPClient(testServ.Client()), 116 | WithMethodNameFormatter(test.namer), 117 | ) 118 | require.NoError(t, err) 119 | defer closer() 120 | 121 | n := client.AddGet(123) 122 | require.Equal(t, 123, n) 123 | }) 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /metrics/metrics.go: 
--------------------------------------------------------------------------------
// Package metrics defines the OpenCensus tag key, measures and views used to
// count RPC request and response errors.
package metrics

import (
	"go.opencensus.io/stats"
	"go.opencensus.io/stats/view"
	"go.opencensus.io/tag"
)

// Global Tags
var (
	// RPCMethod is the tag key carrying the JSON-RPC method name. The error
	// returned by tag.NewKey is discarded.
	RPCMethod, _ = tag.NewKey("method")
)

// Measures
var (
	// RPCInvalidMethod counts calls to invalid (unregistered) RPC methods.
	RPCInvalidMethod = stats.Int64("rpc/invalid_method", "Total number of invalid RPC methods called", stats.UnitDimensionless)
	// RPCRequestError counts request errors handled.
	RPCRequestError = stats.Int64("rpc/request_error", "Total number of request errors handled", stats.UnitDimensionless)
	// RPCResponseError counts response errors handled.
	RPCResponseError = stats.Int64("rpc/response_error", "Total number of responses errors handled", stats.UnitDimensionless)
)

var (
	// All RPC related metrics should at the very least tag the RPCMethod

	// RPCInvalidMethodView counts RPCInvalidMethod per method.
	RPCInvalidMethodView = &view.View{
		Measure:     RPCInvalidMethod,
		Aggregation: view.Count(),
		TagKeys:     []tag.Key{RPCMethod},
	}
	// RPCRequestErrorView counts RPCRequestError per method.
	RPCRequestErrorView = &view.View{
		Measure:     RPCRequestError,
		Aggregation: view.Count(),
		TagKeys:     []tag.Key{RPCMethod},
	}
	// RPCResponseErrorView counts RPCResponseError per method.
	RPCResponseErrorView = &view.View{
		Measure:     RPCResponseError,
		Aggregation: view.Count(),
		TagKeys:     []tag.Key{RPCMethod},
	}
)

// DefaultViews is an array of OpenCensus views for metric gathering purposes
var DefaultViews = []*view.View{
	RPCInvalidMethodView,
	RPCRequestErrorView,
	RPCResponseErrorView,
}

--------------------------------------------------------------------------------
/options.go:
--------------------------------------------------------------------------------
package jsonrpc

import (
	"net/http"
	"reflect"
	"time"

	"github.com/gorilla/websocket"
)

// ParamEncoder converts a client-side parameter value into the value sent in
// its place (see WithParamEncoder).
type ParamEncoder func(reflect.Value) (reflect.Value, error)

// clientHandler is a namespace/handler pair registered via WithClientHandler.
// NOTE: constructed positionally in WithClientHandler, so field order matters.
type clientHandler struct {
	ns  string
	hnd interface{}
}

// Config holds client-side options; Option functions mutate it.
type Config struct {
	reconnectBackoff backoff
	pingInterval     time.Duration
	timeout          time.Duration

	// keyed by the parameter type (the Elem of the pointer passed to WithParamEncoder)
	paramEncoders map[reflect.Type]ParamEncoder
	errors        *Errors

	reverseHandlers       []clientHandler
	aliasedHandlerMethods map[string]string

	httpClient *http.Client

	noReconnect      bool
	proxyConnFactory func(func() (*websocket.Conn, error)) func() (*websocket.Conn, error) // for testing

	methodNamer MethodNameFormatter
}

// defaultConfig returns the default client configuration: 100ms-5s reconnect
// backoff, 5s ping interval, 30s timeout, empty encoder/alias maps,
// _defaultHTTPClient and DefaultMethodNameFormatter.
func defaultConfig() Config {
	return Config{
		reconnectBackoff: backoff{
			minDelay: 100 * time.Millisecond,
			maxDelay: 5 * time.Second,
		},
		pingInterval: 5 * time.Second,
		timeout:      30 * time.Second,

		aliasedHandlerMethods: map[string]string{},

		paramEncoders: map[reflect.Type]ParamEncoder{},

		httpClient: _defaultHTTPClient,

		methodNamer: DefaultMethodNameFormatter,
	}
}

// Option configures a client Config.
type Option func(c *Config)

// WithReconnectBackoff sets the reconnect backoff bounds (minDelay, maxDelay).
func WithReconnectBackoff(minDelay, maxDelay time.Duration) func(c *Config) {
	return func(c *Config) {
		c.reconnectBackoff = backoff{
			minDelay: minDelay,
			maxDelay: maxDelay,
		}
	}
}

// WithPingInterval sets the client ping interval (default 5s).
// Must be < Timeout/2
func WithPingInterval(d time.Duration) func(c *Config) {
	return func(c *Config) {
		c.pingInterval = d
	}
}

// WithTimeout sets the client timeout (default 30s).
func WithTimeout(d time.Duration) func(c *Config) {
	return func(c *Config) {
		c.timeout = d
	}
}

// WithNoReconnect sets the noReconnect flag.
func WithNoReconnect() func(c *Config) {
	return func(c *Config) {
		c.noReconnect = true
	}
}

// WithParamEncoder registers encoder for parameters of the type pointed to by
// t. t must be a pointer such as new(io.Reader); reflect.TypeOf(t).Elem() is
// used as the lookup key.
func WithParamEncoder(t interface{}, encoder ParamEncoder) func(c *Config) {
	return func(c *Config) {
		c.paramEncoders[reflect.TypeOf(t).Elem()] = encoder
	}
}

// WithErrors sets the error registry used by the client (a copy of es is stored).
func WithErrors(es Errors) func(c *Config) {
	return func(c *Config) {
		c.errors = &es
	}
}

// WithClientHandler appends hnd as a client-side (reverse) handler under
// namespace ns.
func WithClientHandler(ns string, hnd interface{}) func(c *Config) {
	return func(c *Config) {
		c.reverseHandlers = append(c.reverseHandlers, clientHandler{ns, hnd})
	}
}

// WithClientHandlerAlias creates an alias for a client HANDLER method - for handlers created
// with WithClientHandler
func WithClientHandlerAlias(alias, original string) func(c *Config) {
	return func(c *Config) {
		c.aliasedHandlerMethods[alias] = original
	}
}

// WithHTTPClient sets the *http.Client used by the client (default _defaultHTTPClient).
func WithHTTPClient(h *http.Client) func(c *Config) {
	return func(c *Config) {
		c.httpClient = h
	}
}

// WithMethodNameFormatter sets how the client builds method names
// (default DefaultMethodNameFormatter).
func WithMethodNameFormatter(namer MethodNameFormatter) func(c *Config) {
	return func(c *Config) {
		c.methodNamer = namer
	}
}

--------------------------------------------------------------------------------
/options_server.go:
--------------------------------------------------------------------------------
package jsonrpc

import (
	"context"
	"reflect"
	"time"

	"golang.org/x/xerrors"
)

// jsonrpcReverseClient is the context key under which a reverse-client proxy of
// a given type is stored (see WithReverseClient / ExtractReverseClient).
// note: we embed reflect.Type because proxy-structs are not comparable
type jsonrpcReverseClient struct{ reflect.Type }

// ParamDecoder converts the raw JSON of one parameter into the value passed to
// the RPC method (see WithParamDecoder).
type ParamDecoder func(ctx context.Context, json []byte) (reflect.Value, error)

// ServerConfig holds server-side options; ServerOption functions mutate it.
type ServerConfig struct {
	maxRequestSize int64
	pingInterval   time.Duration

	// keyed by the parameter type (the Elem of the pointer passed to WithParamDecoder)
	paramDecoders map[reflect.Type]ParamDecoder
	errors        *Errors

	reverseClientBuilder func(context.Context, *wsConn) (context.Context, error)
	tracer               Tracer
	methodNameFormatter  MethodNameFormatter
}

// ServerOption configures a ServerConfig.
type ServerOption func(c *ServerConfig)

// defaultServerConfig returns the default server configuration:
// DEFAULT_MAX_REQUEST_SIZE, 5s ping interval, no param decoders and
// DefaultMethodNameFormatter.
func defaultServerConfig() ServerConfig {
	return ServerConfig{
		paramDecoders:  map[reflect.Type]ParamDecoder{},
		maxRequestSize: DEFAULT_MAX_REQUEST_SIZE,

		pingInterval:        5 * time.Second,
		methodNameFormatter: DefaultMethodNameFormatter,
	}
}

// WithParamDecoder registers decoder for parameters of the type pointed to by
// t. t must be a pointer such as new(io.Reader); reflect.TypeOf(t).Elem() is
// used as the lookup key.
func WithParamDecoder(t interface{}, decoder ParamDecoder) ServerOption {
	return func(c *ServerConfig) {
		c.paramDecoders[reflect.TypeOf(t).Elem()] = decoder
	}
}

// WithMaxRequestSize sets the server's maximum request size
// (default DEFAULT_MAX_REQUEST_SIZE).
func WithMaxRequestSize(max int64) ServerOption {
	return func(c *ServerConfig) {
		c.maxRequestSize = max
	}
}

// WithServerErrors sets the error registry used by the server (a copy of es is stored).
func WithServerErrors(es Errors) ServerOption {
	return func(c *ServerConfig) {
		c.errors = &es
	}
}

// WithServerPingInterval sets the server ping interval (default 5s).
func WithServerPingInterval(d time.Duration) ServerOption {
	return func(c *ServerConfig) {
		c.pingInterval = d
	}
}

// WithServerMethodNameFormatter sets how the server builds the method names it
// registers (default DefaultMethodNameFormatter).
func WithServerMethodNameFormatter(formatter MethodNameFormatter) ServerOption {
	return func(c *ServerConfig) {
		c.methodNameFormatter = formatter
	}
}

// WithTracer allows the instantiator to trace the method calls and results.
// This is useful for debugging a client-server interaction.
func WithTracer(l Tracer) ServerOption {
	return func(c *ServerConfig) {
		c.tracer = l
	}
}

// WithReverseClient will allow extracting reverse client on **WEBSOCKET** calls.
// RP is a proxy-struct type, much like the one passed to NewClient.
80 | func WithReverseClient[RP any](namespace string) ServerOption { 81 | return func(c *ServerConfig) { 82 | c.reverseClientBuilder = func(ctx context.Context, conn *wsConn) (context.Context, error) { 83 | cl := client{ 84 | namespace: namespace, 85 | paramEncoders: map[reflect.Type]ParamEncoder{}, 86 | methodNameFormatter: c.methodNameFormatter, 87 | } 88 | 89 | // todo test that everything is closing correctly 90 | cl.exiting = conn.exiting 91 | 92 | requests := cl.setupRequestChan() 93 | conn.requests = requests 94 | 95 | calls := new(RP) 96 | 97 | err := cl.provide([]interface{}{ 98 | calls, 99 | }) 100 | if err != nil { 101 | return nil, xerrors.Errorf("provide reverse client calls: %w", err) 102 | } 103 | 104 | return context.WithValue(ctx, jsonrpcReverseClient{reflect.TypeOf(calls).Elem()}, calls), nil 105 | } 106 | } 107 | } 108 | 109 | // ExtractReverseClient will extract reverse client from context. Reverse client for the type 110 | // will only be present if the server was constructed with a matching WithReverseClient option 111 | // and the connection was a websocket connection. 112 | // If there is no reverse client, the call will return a zero value and `false`. Otherwise a reverse 113 | // client and `true` will be returned. 
114 | func ExtractReverseClient[C any](ctx context.Context) (C, bool) { 115 | c, ok := ctx.Value(jsonrpcReverseClient{reflect.TypeOf(new(C)).Elem()}).(*C) 116 | if !ok { 117 | return *new(C), false 118 | } 119 | if c == nil { 120 | // something is very wrong, but don't panic 121 | return *new(C), false 122 | } 123 | 124 | return *c, ok 125 | } 126 | -------------------------------------------------------------------------------- /resp_error_test.go: -------------------------------------------------------------------------------- 1 | package jsonrpc 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | ) 10 | 11 | type ComplexData struct { 12 | Foo string `json:"foo"` 13 | Bar int `json:"bar"` 14 | } 15 | 16 | type StaticError struct{} 17 | 18 | func (e *StaticError) Error() string { return "static error" } 19 | 20 | // Define the error types 21 | type SimpleError struct { 22 | Message string 23 | } 24 | 25 | func (e *SimpleError) Error() string { 26 | return e.Message 27 | } 28 | 29 | func (e *SimpleError) FromJSONRPCError(jerr JSONRPCError) error { 30 | e.Message = jerr.Message 31 | return nil 32 | } 33 | 34 | func (e *SimpleError) ToJSONRPCError() (JSONRPCError, error) { 35 | return JSONRPCError{Message: e.Message}, nil 36 | } 37 | 38 | var _ RPCErrorCodec = (*SimpleError)(nil) 39 | 40 | type DataStringError struct { 41 | Message string `json:"message"` 42 | Data string `json:"data"` 43 | } 44 | 45 | func (e *DataStringError) Error() string { 46 | return e.Message 47 | } 48 | 49 | func (e *DataStringError) FromJSONRPCError(jerr JSONRPCError) error { 50 | e.Message = jerr.Message 51 | data, ok := jerr.Data.(string) 52 | if !ok { 53 | return fmt.Errorf("expected string data, got %T", jerr.Data) 54 | } 55 | 56 | e.Data = data 57 | 58 | return nil 59 | } 60 | 61 | func (e *DataStringError) ToJSONRPCError() (JSONRPCError, error) { 62 | return JSONRPCError{Message: e.Message, Data: e.Data}, nil 63 | } 64 | 65 | 
var _ RPCErrorCodec = (*DataStringError)(nil) 66 | 67 | type DataComplexError struct { 68 | Message string 69 | internalData ComplexData 70 | } 71 | 72 | func (e *DataComplexError) Error() string { 73 | return e.Message 74 | } 75 | 76 | func (e *DataComplexError) FromJSONRPCError(jerr JSONRPCError) error { 77 | e.Message = jerr.Message 78 | data, ok := jerr.Data.(json.RawMessage) 79 | if !ok { 80 | return fmt.Errorf("expected string data, got %T", jerr.Data) 81 | } 82 | 83 | if err := json.Unmarshal(data, &e.internalData); err != nil { 84 | return err 85 | } 86 | return nil 87 | } 88 | 89 | func (e *DataComplexError) ToJSONRPCError() (JSONRPCError, error) { 90 | data, err := json.Marshal(e.internalData) 91 | if err != nil { 92 | return JSONRPCError{}, err 93 | } 94 | return JSONRPCError{Message: e.Message, Data: data}, nil 95 | } 96 | 97 | var _ RPCErrorCodec = (*DataComplexError)(nil) 98 | 99 | type MetaError struct { 100 | Message string 101 | Details string 102 | } 103 | 104 | func (e *MetaError) Error() string { 105 | return e.Message 106 | } 107 | 108 | func (e *MetaError) MarshalJSON() ([]byte, error) { 109 | return json.Marshal(struct { 110 | Message string `json:"message"` 111 | Details string `json:"details"` 112 | }{ 113 | Message: e.Message, 114 | Details: e.Details, 115 | }) 116 | } 117 | 118 | func (e *MetaError) UnmarshalJSON(data []byte) error { 119 | var temp struct { 120 | Message string `json:"message"` 121 | Details string `json:"details"` 122 | } 123 | if err := json.Unmarshal(data, &temp); err != nil { 124 | return err 125 | } 126 | 127 | e.Message = temp.Message 128 | e.Details = temp.Details 129 | return nil 130 | } 131 | 132 | type ComplexError struct { 133 | Message string 134 | Data ComplexData 135 | Details string 136 | } 137 | 138 | func (e *ComplexError) Error() string { 139 | return e.Message 140 | } 141 | 142 | func (e *ComplexError) MarshalJSON() ([]byte, error) { 143 | return json.Marshal(struct { 144 | Message string 
`json:"message"` 145 | Details string `json:"details"` 146 | Data any `json:"data"` 147 | }{ 148 | Details: e.Details, 149 | Message: e.Message, 150 | Data: e.Data, 151 | }) 152 | } 153 | 154 | func (e *ComplexError) UnmarshalJSON(data []byte) error { 155 | var temp struct { 156 | Message string `json:"message"` 157 | Details string `json:"details"` 158 | Data ComplexData `json:"data"` 159 | } 160 | if err := json.Unmarshal(data, &temp); err != nil { 161 | return err 162 | } 163 | e.Details = temp.Details 164 | e.Message = temp.Message 165 | e.Data = temp.Data 166 | return nil 167 | } 168 | 169 | func TestRespErrorVal(t *testing.T) { 170 | // Initialize the Errors struct and register error types 171 | errorsMap := NewErrors() 172 | errorsMap.Register(1000, new(*StaticError)) 173 | errorsMap.Register(1001, new(*SimpleError)) 174 | errorsMap.Register(1002, new(*DataStringError)) 175 | errorsMap.Register(1003, new(*DataComplexError)) 176 | errorsMap.Register(1004, new(*MetaError)) 177 | errorsMap.Register(1005, new(*ComplexError)) 178 | 179 | // Define test cases 180 | testCases := []struct { 181 | name string 182 | respError *JSONRPCError 183 | expectedType interface{} 184 | expectedMessage string 185 | verify func(t *testing.T, err error) 186 | }{ 187 | { 188 | name: "StaticError", 189 | respError: &JSONRPCError{ 190 | Code: 1000, 191 | Message: "this is ignored", 192 | }, 193 | expectedType: &StaticError{}, 194 | expectedMessage: "static error", 195 | }, 196 | { 197 | name: "SimpleError", 198 | respError: &JSONRPCError{ 199 | Code: 1001, 200 | Message: "simple error occurred", 201 | }, 202 | expectedType: &SimpleError{}, 203 | expectedMessage: "simple error occurred", 204 | }, 205 | { 206 | name: "DataStringError", 207 | respError: &JSONRPCError{ 208 | Code: 1002, 209 | Message: "data error occurred", 210 | Data: "additional data", 211 | }, 212 | expectedType: &DataStringError{}, 213 | expectedMessage: "data error occurred", 214 | verify: func(t *testing.T, err 
error) { 215 | require.IsType(t, &DataStringError{}, err) 216 | require.Equal(t, "data error occurred", err.Error()) 217 | require.Equal(t, "additional data", err.(*DataStringError).Data) 218 | }, 219 | }, 220 | { 221 | name: "DataComplexError", 222 | respError: &JSONRPCError{ 223 | Code: 1003, 224 | Message: "data error occurred", 225 | Data: json.RawMessage(`{"foo":"boop","bar":101}`), 226 | }, 227 | expectedType: &DataComplexError{}, 228 | expectedMessage: "data error occurred", 229 | verify: func(t *testing.T, err error) { 230 | require.Equal(t, ComplexData{Foo: "boop", Bar: 101}, err.(*DataComplexError).internalData) 231 | }, 232 | }, 233 | { 234 | name: "MetaError", 235 | respError: &JSONRPCError{ 236 | Code: 1004, 237 | Message: "meta error occurred", 238 | Meta: func() json.RawMessage { 239 | me := &MetaError{ 240 | Message: "meta error occurred", 241 | Details: "meta details", 242 | } 243 | metaData, _ := me.MarshalJSON() 244 | return metaData 245 | }(), 246 | }, 247 | expectedType: &MetaError{}, 248 | expectedMessage: "meta error occurred", 249 | verify: func(t *testing.T, err error) { 250 | // details will also be included in the error message since it implements the marshable interface 251 | require.Equal(t, "meta details", err.(*MetaError).Details) 252 | }, 253 | }, 254 | { 255 | name: "ComplexError", 256 | respError: &JSONRPCError{ 257 | Code: 1005, 258 | Message: "complex error occurred", 259 | Data: json.RawMessage(`"complex data"`), 260 | Meta: func() json.RawMessage { 261 | ce := &ComplexError{ 262 | Message: "complex error occurred", 263 | Details: "complex details", 264 | Data: ComplexData{Foo: "foo", Bar: 42}, 265 | } 266 | metaData, _ := ce.MarshalJSON() 267 | return metaData 268 | }(), 269 | }, 270 | expectedType: &ComplexError{}, 271 | expectedMessage: "complex error occurred", 272 | verify: func(t *testing.T, err error) { 273 | require.Equal(t, ComplexData{Foo: "foo", Bar: 42}, err.(*ComplexError).Data) 274 | require.Equal(t, "complex 
details", err.(*ComplexError).Details) 275 | }, 276 | }, 277 | { 278 | name: "UnregisteredError", 279 | respError: &JSONRPCError{ 280 | Code: 9999, 281 | Message: "unregistered error occurred", 282 | Data: json.RawMessage(`"some data"`), 283 | }, 284 | expectedType: &JSONRPCError{}, 285 | expectedMessage: "unregistered error occurred", 286 | verify: func(t *testing.T, err error) { 287 | require.Equal(t, json.RawMessage(`"some data"`), err.(*JSONRPCError).Data) 288 | }, 289 | }, 290 | } 291 | 292 | for _, tc := range testCases { 293 | tc := tc 294 | t.Run(tc.name, func(t *testing.T) { 295 | errValue := tc.respError.val(&errorsMap) 296 | errInterface := errValue.Interface() 297 | err, ok := errInterface.(error) 298 | require.True(t, ok, "returned value does not implement error interface") 299 | require.IsType(t, tc.expectedType, err) 300 | require.Equal(t, tc.expectedMessage, err.Error()) 301 | if tc.verify != nil { 302 | tc.verify(t, err) 303 | } 304 | }) 305 | } 306 | } 307 | -------------------------------------------------------------------------------- /response.go: -------------------------------------------------------------------------------- 1 | package jsonrpc 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "reflect" 7 | ) 8 | 9 | type response struct { 10 | Jsonrpc string `json:"jsonrpc"` 11 | Result interface{} `json:"result,omitempty"` 12 | ID interface{} `json:"id"` 13 | Error *JSONRPCError `json:"error,omitempty"` 14 | } 15 | 16 | func (r response) MarshalJSON() ([]byte, error) { 17 | // Custom marshal logic as per JSON-RPC 2.0 spec: 18 | // > `result`: 19 | // > This member is REQUIRED on success. 20 | // > This member MUST NOT exist if there was an error invoking the method. 21 | // 22 | // > `error`: 23 | // > This member is REQUIRED on error. 24 | // > This member MUST NOT exist if there was no error triggered during invocation. 
25 | data := map[string]interface{}{ 26 | "jsonrpc": r.Jsonrpc, 27 | "id": r.ID, 28 | } 29 | 30 | if r.Error != nil { 31 | data["error"] = r.Error 32 | } else { 33 | data["result"] = r.Result 34 | } 35 | return json.Marshal(data) 36 | } 37 | 38 | type JSONRPCError struct { 39 | Code ErrorCode `json:"code"` 40 | Message string `json:"message"` 41 | Meta json.RawMessage `json:"meta,omitempty"` 42 | Data interface{} `json:"data,omitempty"` 43 | } 44 | 45 | func (e *JSONRPCError) Error() string { 46 | if e.Code >= -32768 && e.Code <= -32000 { 47 | return fmt.Sprintf("RPC error (%d): %s", e.Code, e.Message) 48 | } 49 | return e.Message 50 | } 51 | 52 | var ( 53 | _ error = (*JSONRPCError)(nil) 54 | marshalableRT = reflect.TypeOf(new(marshalable)).Elem() 55 | errorCodecRT = reflect.TypeOf(new(RPCErrorCodec)).Elem() 56 | ) 57 | 58 | func (e *JSONRPCError) val(errors *Errors) reflect.Value { 59 | if errors != nil { 60 | t, ok := errors.byCode[e.Code] 61 | if ok { 62 | var v reflect.Value 63 | if t.Kind() == reflect.Ptr { 64 | v = reflect.New(t.Elem()) 65 | } else { 66 | v = reflect.New(t) 67 | } 68 | 69 | if v.Type().Implements(errorCodecRT) { 70 | if err := v.Interface().(RPCErrorCodec).FromJSONRPCError(*e); err != nil { 71 | log.Errorf("Error converting JSONRPCError to custom error type '%s' (code %d): %w", t.String(), e.Code, err) 72 | return reflect.ValueOf(e) 73 | } 74 | } else if len(e.Meta) > 0 && v.Type().Implements(marshalableRT) { 75 | if err := v.Interface().(marshalable).UnmarshalJSON(e.Meta); err != nil { 76 | log.Errorf("Error unmarshalling error metadata to custom error type '%s' (code %d): %w", t.String(), e.Code, err) 77 | return reflect.ValueOf(e) 78 | } 79 | } 80 | 81 | if t.Kind() != reflect.Ptr { 82 | v = v.Elem() 83 | } 84 | return v 85 | } 86 | } 87 | 88 | return reflect.ValueOf(e) 89 | } 90 | -------------------------------------------------------------------------------- /rpc_test.go: 
-------------------------------------------------------------------------------- 1 | package jsonrpc 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "net" 11 | "net/http" 12 | "net/http/httptest" 13 | "os" 14 | "reflect" 15 | "strconv" 16 | "strings" 17 | "sync" 18 | "sync/atomic" 19 | "testing" 20 | "time" 21 | 22 | "github.com/gorilla/websocket" 23 | logging "github.com/ipfs/go-log/v2" 24 | "github.com/stretchr/testify/assert" 25 | "github.com/stretchr/testify/require" 26 | "golang.org/x/xerrors" 27 | ) 28 | 29 | func init() { 30 | if _, exists := os.LookupEnv("GOLOG_LOG_LEVEL"); !exists { 31 | if err := logging.SetLogLevel("rpc", "DEBUG"); err != nil { 32 | panic(err) 33 | } 34 | } 35 | 36 | debugTrace = true 37 | } 38 | 39 | type SimpleServerHandler struct { 40 | n int32 41 | } 42 | 43 | type TestType struct { 44 | S string 45 | I int 46 | } 47 | 48 | type TestOut struct { 49 | TestType 50 | Ok bool 51 | } 52 | 53 | func (h *SimpleServerHandler) Inc() error { 54 | h.n++ 55 | 56 | return nil 57 | } 58 | 59 | func (h *SimpleServerHandler) Add(in int) error { 60 | if in == -3546 { 61 | return errors.New("test") 62 | } 63 | 64 | atomic.AddInt32(&h.n, int32(in)) 65 | 66 | return nil 67 | } 68 | 69 | func (h *SimpleServerHandler) AddGet(in int) int { 70 | atomic.AddInt32(&h.n, int32(in)) 71 | return int(h.n) 72 | } 73 | 74 | func (h *SimpleServerHandler) StringMatch(t TestType, i2 int64) (out TestOut, err error) { 75 | if strconv.FormatInt(i2, 10) == t.S { 76 | out.Ok = true 77 | } 78 | if i2 != int64(t.I) { 79 | return TestOut{}, errors.New(":(") 80 | } 81 | out.I = t.I 82 | out.S = t.S 83 | return 84 | } 85 | 86 | func TestRawRequests(t *testing.T) { 87 | rpcHandler := SimpleServerHandler{} 88 | 89 | rpcServer := NewServer() 90 | rpcServer.Register("SimpleServerHandler", &rpcHandler) 91 | 92 | testServ := httptest.NewServer(rpcServer) 93 | defer testServ.Close() 94 | 95 | removeSpaces := func(jsonStr string) 
(string, error) { 96 | var jsonObj interface{} 97 | err := json.Unmarshal([]byte(jsonStr), &jsonObj) 98 | if err != nil { 99 | return "", err 100 | } 101 | 102 | compactJSONBytes, err := json.Marshal(jsonObj) 103 | if err != nil { 104 | return "", err 105 | } 106 | 107 | return string(compactJSONBytes), nil 108 | } 109 | 110 | tc := func(req, resp string, n int32, statusCode int) func(t *testing.T) { 111 | return func(t *testing.T) { 112 | rpcHandler.n = 0 113 | 114 | res, err := http.Post(testServ.URL, "application/json", strings.NewReader(req)) 115 | require.NoError(t, err) 116 | 117 | b, err := io.ReadAll(res.Body) 118 | require.NoError(t, err) 119 | 120 | expectedResp, err := removeSpaces(resp) 121 | require.NoError(t, err) 122 | 123 | responseBody, err := removeSpaces(string(b)) 124 | require.NoError(t, err) 125 | 126 | assert.Equal(t, expectedResp, responseBody) 127 | require.Equal(t, n, rpcHandler.n) 128 | require.Equal(t, statusCode, res.StatusCode) 129 | } 130 | } 131 | 132 | t.Run("inc", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "params": [], "id": 1}`, `{"jsonrpc":"2.0","id":1,"result":null}`, 1, 200)) 133 | t.Run("inc-null", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "params": null, "id": 1}`, `{"jsonrpc":"2.0","id":1,"result":null}`, 1, 200)) 134 | t.Run("inc-noparam", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Inc", "id": 2}`, `{"jsonrpc":"2.0","id":2,"result":null}`, 1, 200)) 135 | t.Run("add", tc(`{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [10], "id": 4}`, `{"jsonrpc":"2.0","id":4,"result":null}`, 10, 200)) 136 | // Batch requests 137 | t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 5}`, `{"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"Parse error"}}`, 0, 500)) 138 | t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 6}]`, `[{"jsonrpc":"2.0","id":6,"result":null}]`, 123, 
200)) 139 | t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 7},{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [-122], "id": 8}]`, `[{"jsonrpc":"2.0","id":7,"result":null},{"jsonrpc":"2.0","id":8,"result":null}]`, 1, 200)) 140 | t.Run("add", tc(`[{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [123], "id": 9},{"jsonrpc": "2.0", "params": [-122], "id": 10}]`, `[{"jsonrpc":"2.0","id":9,"result":null},{"error":{"code":-32601,"message":"method '' not found"},"id":10,"jsonrpc":"2.0"}]`, 123, 200)) 141 | t.Run("add", tc(` [{"jsonrpc": "2.0", "method": "SimpleServerHandler.Add", "params": [-1], "id": 11}] `, `[{"jsonrpc":"2.0","id":11,"result":null}]`, -1, 200)) 142 | t.Run("add", tc(``, `{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"Invalid request"}}`, 0, 400)) 143 | } 144 | 145 | func TestReconnection(t *testing.T) { 146 | var rpcClient struct { 147 | Add func(int) error 148 | } 149 | 150 | rpcHandler := SimpleServerHandler{} 151 | 152 | rpcServer := NewServer() 153 | rpcServer.Register("SimpleServerHandler", &rpcHandler) 154 | 155 | testServ := httptest.NewServer(rpcServer) 156 | defer testServ.Close() 157 | 158 | // capture connection attempts for this duration 159 | captureDuration := 3 * time.Second 160 | 161 | // run the test until the timer expires 162 | timer := time.NewTimer(captureDuration) 163 | 164 | // record the number of connection attempts during this test 165 | connectionAttempts := int64(1) 166 | 167 | closer, err := NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", []interface{}{&rpcClient}, nil, func(c *Config) { 168 | c.proxyConnFactory = func(f func() (*websocket.Conn, error)) func() (*websocket.Conn, error) { 169 | return func() (*websocket.Conn, error) { 170 | defer func() { 171 | atomic.AddInt64(&connectionAttempts, 1) 172 | }() 173 | 174 | if atomic.LoadInt64(&connectionAttempts) > 1 { 175 | 
return nil, errors.New("simulates a failed reconnect attempt") 176 | } 177 | 178 | c, err := f() 179 | if err != nil { 180 | return nil, err 181 | } 182 | 183 | // closing the connection here triggers the reconnect logic 184 | _ = c.Close() 185 | 186 | return c, nil 187 | } 188 | } 189 | }) 190 | require.NoError(t, err) 191 | defer closer() 192 | 193 | // let the JSON-RPC library attempt to reconnect until the timer runs out 194 | <-timer.C 195 | 196 | // do some math 197 | attemptsPerSecond := atomic.LoadInt64(&connectionAttempts) / int64(captureDuration/time.Second) 198 | 199 | assert.Less(t, attemptsPerSecond, int64(50)) 200 | } 201 | 202 | func (h *SimpleServerHandler) ErrChanSub(ctx context.Context) (<-chan int, error) { 203 | return nil, errors.New("expect to return an error") 204 | } 205 | 206 | func TestRPCBadConnection(t *testing.T) { 207 | // setup server 208 | 209 | serverHandler := &SimpleServerHandler{} 210 | 211 | rpcServer := NewServer() 212 | rpcServer.Register("SimpleServerHandler", serverHandler) 213 | 214 | // httptest stuff 215 | testServ := httptest.NewServer(rpcServer) 216 | defer testServ.Close() 217 | // setup client 218 | 219 | var client struct { 220 | Add func(int) error 221 | AddGet func(int) int 222 | StringMatch func(t TestType, i2 int64) (out TestOut, err error) 223 | ErrChanSub func(context.Context) (<-chan int, error) 224 | } 225 | closer, err := NewClient(context.Background(), "http://"+testServ.Listener.Addr().String()+"0", "SimpleServerHandler", &client, nil) 226 | require.NoError(t, err) 227 | err = client.Add(2) 228 | require.True(t, errors.As(err, new(*RPCConnectionError))) 229 | 230 | defer closer() 231 | 232 | } 233 | 234 | func TestRPC(t *testing.T) { 235 | // setup server 236 | 237 | serverHandler := &SimpleServerHandler{} 238 | 239 | rpcServer := NewServer() 240 | rpcServer.Register("SimpleServerHandler", serverHandler) 241 | 242 | // httptest stuff 243 | testServ := httptest.NewServer(rpcServer) 244 | defer 
testServ.Close() 245 | // setup client 246 | 247 | var client struct { 248 | Add func(int) error 249 | AddGet func(int) int 250 | StringMatch func(t TestType, i2 int64) (out TestOut, err error) 251 | ErrChanSub func(context.Context) (<-chan int, error) 252 | } 253 | closer, err := NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &client, nil) 254 | require.NoError(t, err) 255 | defer closer() 256 | 257 | // Add(int) error 258 | 259 | require.NoError(t, client.Add(2)) 260 | require.Equal(t, 2, int(serverHandler.n)) 261 | 262 | err = client.Add(-3546) 263 | require.EqualError(t, err, "test") 264 | 265 | // AddGet(int) int 266 | 267 | n := client.AddGet(3) 268 | require.Equal(t, 5, n) 269 | require.Equal(t, 5, int(serverHandler.n)) 270 | 271 | // StringMatch 272 | 273 | o, err := client.StringMatch(TestType{S: "0"}, 0) 274 | require.NoError(t, err) 275 | require.Equal(t, "0", o.S) 276 | require.Equal(t, 0, o.I) 277 | 278 | _, err = client.StringMatch(TestType{S: "5"}, 5) 279 | require.EqualError(t, err, ":(") 280 | 281 | o, err = client.StringMatch(TestType{S: "8", I: 8}, 8) 282 | require.NoError(t, err) 283 | require.Equal(t, "8", o.S) 284 | require.Equal(t, 8, o.I) 285 | 286 | // ErrChanSub 287 | ctx := context.TODO() 288 | _, err = client.ErrChanSub(ctx) 289 | if err == nil { 290 | t.Fatal("expect an err return, but got nil") 291 | } 292 | 293 | // Invalid client handlers 294 | 295 | var noret struct { 296 | Add func(int) 297 | } 298 | closer, err = NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &noret, nil) 299 | require.NoError(t, err) 300 | 301 | // this one should actually work 302 | noret.Add(4) 303 | require.Equal(t, 9, int(serverHandler.n)) 304 | closer() 305 | 306 | var noparam struct { 307 | Add func() 308 | } 309 | closer, err = NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &noparam, nil) 310 | 
require.NoError(t, err) 311 | 312 | // shouldn't panic 313 | noparam.Add() 314 | closer() 315 | 316 | var erronly struct { 317 | AddGet func() (int, error) 318 | } 319 | closer, err = NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &erronly, nil) 320 | require.NoError(t, err) 321 | 322 | _, err = erronly.AddGet() 323 | if err == nil || err.Error() != "RPC error (-32602): wrong param count (method 'SimpleServerHandler.AddGet'): 0 != 1" { 324 | t.Error("wrong error:", err) 325 | } 326 | closer() 327 | 328 | var wrongtype struct { 329 | Add func(string) error 330 | } 331 | closer, err = NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &wrongtype, nil) 332 | require.NoError(t, err) 333 | 334 | err = wrongtype.Add("not an int") 335 | if err == nil || !strings.Contains(err.Error(), "RPC error (-32700):") || !strings.Contains(err.Error(), "json: cannot unmarshal string into Go value of type int") { 336 | t.Error("wrong error:", err) 337 | } 338 | closer() 339 | 340 | var notfound struct { 341 | NotThere func(string) error 342 | } 343 | closer, err = NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", ¬found, nil) 344 | require.NoError(t, err) 345 | 346 | err = notfound.NotThere("hello?") 347 | if err == nil || err.Error() != "RPC error (-32601): method 'SimpleServerHandler.NotThere' not found" { 348 | t.Error("wrong error:", err) 349 | } 350 | closer() 351 | } 352 | 353 | func TestRPCHttpClient(t *testing.T) { 354 | // setup server 355 | 356 | serverHandler := &SimpleServerHandler{} 357 | 358 | rpcServer := NewServer() 359 | rpcServer.Register("SimpleServerHandler", serverHandler) 360 | 361 | // httptest stuff 362 | testServ := httptest.NewServer(rpcServer) 363 | defer testServ.Close() 364 | // setup client 365 | 366 | var client struct { 367 | Add func(int) error 368 | AddGet func(int) int 369 | StringMatch func(t TestType, i2 
int64) (out TestOut, err error) 370 | } 371 | closer, err := NewClient(context.Background(), "http://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &client, nil) 372 | require.NoError(t, err) 373 | defer closer() 374 | 375 | // Add(int) error 376 | 377 | require.NoError(t, client.Add(2)) 378 | require.Equal(t, 2, int(serverHandler.n)) 379 | 380 | err = client.Add(-3546) 381 | require.EqualError(t, err, "test") 382 | 383 | // AddGet(int) int 384 | 385 | n := client.AddGet(3) 386 | require.Equal(t, 5, n) 387 | require.Equal(t, 5, int(serverHandler.n)) 388 | 389 | // StringMatch 390 | 391 | o, err := client.StringMatch(TestType{S: "0"}, 0) 392 | require.NoError(t, err) 393 | require.Equal(t, "0", o.S) 394 | require.Equal(t, 0, o.I) 395 | 396 | _, err = client.StringMatch(TestType{S: "5"}, 5) 397 | require.EqualError(t, err, ":(") 398 | 399 | o, err = client.StringMatch(TestType{S: "8", I: 8}, 8) 400 | require.NoError(t, err) 401 | require.Equal(t, "8", o.S) 402 | require.Equal(t, 8, o.I) 403 | 404 | // Invalid client handlers 405 | 406 | var noret struct { 407 | Add func(int) 408 | } 409 | closer, err = NewClient(context.Background(), "http://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &noret, nil) 410 | require.NoError(t, err) 411 | 412 | // this one should actually work 413 | noret.Add(4) 414 | require.Equal(t, 9, int(serverHandler.n)) 415 | closer() 416 | 417 | var noparam struct { 418 | Add func() 419 | } 420 | closer, err = NewClient(context.Background(), "http://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &noparam, nil) 421 | require.NoError(t, err) 422 | 423 | // shouldn't panic 424 | noparam.Add() 425 | closer() 426 | 427 | var erronly struct { 428 | AddGet func() (int, error) 429 | } 430 | closer, err = NewClient(context.Background(), "http://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &erronly, nil) 431 | require.NoError(t, err) 432 | 433 | _, err = erronly.AddGet() 434 | if err == nil || 
err.Error() != "RPC error (-32602): wrong param count (method 'SimpleServerHandler.AddGet'): 0 != 1" { 435 | t.Error("wrong error:", err) 436 | } 437 | closer() 438 | 439 | var wrongtype struct { 440 | Add func(string) error 441 | } 442 | closer, err = NewClient(context.Background(), "http://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &wrongtype, nil) 443 | require.NoError(t, err) 444 | 445 | err = wrongtype.Add("not an int") 446 | if err == nil || !strings.Contains(err.Error(), "RPC error (-32700):") || !strings.Contains(err.Error(), "json: cannot unmarshal string into Go value of type int") { 447 | t.Error("wrong error:", err) 448 | } 449 | closer() 450 | 451 | var notfound struct { 452 | NotThere func(string) error 453 | } 454 | closer, err = NewClient(context.Background(), "http://"+testServ.Listener.Addr().String(), "SimpleServerHandler", ¬found, nil) 455 | require.NoError(t, err) 456 | 457 | err = notfound.NotThere("hello?") 458 | if err == nil || err.Error() != "RPC error (-32601): method 'SimpleServerHandler.NotThere' not found" { 459 | t.Error("wrong error:", err) 460 | } 461 | closer() 462 | } 463 | 464 | func TestParallelRPC(t *testing.T) { 465 | // setup server 466 | 467 | serverHandler := &SimpleServerHandler{} 468 | 469 | rpcServer := NewServer() 470 | rpcServer.Register("SimpleServerHandler", serverHandler) 471 | 472 | // httptest stuff 473 | testServ := httptest.NewServer(rpcServer) 474 | defer testServ.Close() 475 | // setup client 476 | 477 | var client struct { 478 | Add func(int) error 479 | } 480 | closer, err := NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &client, nil) 481 | require.NoError(t, err) 482 | defer closer() 483 | 484 | var wg sync.WaitGroup 485 | for i := 0; i < 100; i++ { 486 | wg.Add(1) 487 | go func() { 488 | defer wg.Done() 489 | for j := 0; j < 100; j++ { 490 | require.NoError(t, client.Add(2)) 491 | } 492 | }() 493 | } 494 | wg.Wait() 495 | 496 | 
require.Equal(t, 20000, int(serverHandler.n)) 497 | } 498 | 499 | type CtxHandler struct { 500 | lk sync.Mutex 501 | 502 | cancelled bool 503 | i int 504 | connectionType ConnectionType 505 | } 506 | 507 | func (h *CtxHandler) Test(ctx context.Context) { 508 | h.lk.Lock() 509 | defer h.lk.Unlock() 510 | timeout := time.After(300 * time.Millisecond) 511 | h.i++ 512 | h.connectionType = GetConnectionType(ctx) 513 | 514 | select { 515 | case <-timeout: 516 | case <-ctx.Done(): 517 | h.cancelled = true 518 | } 519 | } 520 | 521 | func TestCtx(t *testing.T) { 522 | // setup server 523 | 524 | serverHandler := &CtxHandler{} 525 | 526 | rpcServer := NewServer() 527 | rpcServer.Register("CtxHandler", serverHandler) 528 | 529 | // httptest stuff 530 | testServ := httptest.NewServer(rpcServer) 531 | defer testServ.Close() 532 | 533 | // setup client 534 | 535 | var client struct { 536 | Test func(ctx context.Context) 537 | } 538 | closer, err := NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "CtxHandler", &client, nil) 539 | require.NoError(t, err) 540 | 541 | ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 542 | defer cancel() 543 | 544 | client.Test(ctx) 545 | serverHandler.lk.Lock() 546 | 547 | if !serverHandler.cancelled { 548 | t.Error("expected cancellation on the server side") 549 | } 550 | if serverHandler.connectionType != ConnectionTypeWS { 551 | t.Error("wrong connection type") 552 | } 553 | 554 | serverHandler.cancelled = false 555 | 556 | serverHandler.lk.Unlock() 557 | closer() 558 | 559 | var noCtxClient struct { 560 | Test func() 561 | } 562 | closer, err = NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "CtxHandler", &noCtxClient, nil) 563 | if err != nil { 564 | t.Fatal(err) 565 | } 566 | 567 | noCtxClient.Test() 568 | 569 | serverHandler.lk.Lock() 570 | 571 | if serverHandler.cancelled || serverHandler.i != 2 { 572 | t.Error("wrong serverHandler state") 573 | } 574 | if 
serverHandler.connectionType != ConnectionTypeWS { 575 | t.Error("wrong connection type") 576 | } 577 | 578 | serverHandler.lk.Unlock() 579 | closer() 580 | } 581 | 582 | func TestCtxHttp(t *testing.T) { 583 | // setup server 584 | 585 | serverHandler := &CtxHandler{} 586 | 587 | rpcServer := NewServer() 588 | rpcServer.Register("CtxHandler", serverHandler) 589 | 590 | // httptest stuff 591 | testServ := httptest.NewServer(rpcServer) 592 | defer testServ.Close() 593 | 594 | // setup client 595 | 596 | var client struct { 597 | Test func(ctx context.Context) 598 | } 599 | closer, err := NewClient(context.Background(), "http://"+testServ.Listener.Addr().String(), "CtxHandler", &client, nil) 600 | require.NoError(t, err) 601 | 602 | ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 603 | defer cancel() 604 | 605 | client.Test(ctx) 606 | serverHandler.lk.Lock() 607 | 608 | if !serverHandler.cancelled { 609 | t.Error("expected cancellation on the server side") 610 | } 611 | if serverHandler.connectionType != ConnectionTypeHTTP { 612 | t.Error("wrong connection type") 613 | } 614 | 615 | serverHandler.cancelled = false 616 | 617 | serverHandler.lk.Unlock() 618 | closer() 619 | 620 | var noCtxClient struct { 621 | Test func() 622 | } 623 | closer, err = NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "CtxHandler", &noCtxClient, nil) 624 | if err != nil { 625 | t.Fatal(err) 626 | } 627 | 628 | noCtxClient.Test() 629 | 630 | serverHandler.lk.Lock() 631 | 632 | if serverHandler.cancelled || serverHandler.i != 2 { 633 | t.Error("wrong serverHandler state") 634 | } 635 | // connection type should have switched to WS 636 | if serverHandler.connectionType != ConnectionTypeWS { 637 | t.Error("wrong connection type") 638 | } 639 | 640 | serverHandler.lk.Unlock() 641 | closer() 642 | } 643 | 644 | type UnUnmarshalable int 645 | 646 | func (*UnUnmarshalable) UnmarshalJSON([]byte) error { 647 | return errors.New("nope") 648 | } 
649 | 650 | type UnUnmarshalableHandler struct{} 651 | 652 | func (*UnUnmarshalableHandler) GetUnUnmarshalableStuff() (UnUnmarshalable, error) { 653 | return UnUnmarshalable(5), nil 654 | } 655 | 656 | func TestUnmarshalableResult(t *testing.T) { 657 | var client struct { 658 | GetUnUnmarshalableStuff func() (UnUnmarshalable, error) 659 | } 660 | 661 | rpcServer := NewServer() 662 | rpcServer.Register("Handler", &UnUnmarshalableHandler{}) 663 | 664 | testServ := httptest.NewServer(rpcServer) 665 | defer testServ.Close() 666 | 667 | closer, err := NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "Handler", &client, nil) 668 | require.NoError(t, err) 669 | defer closer() 670 | 671 | _, err = client.GetUnUnmarshalableStuff() 672 | require.EqualError(t, err, "RPC client error: unmarshaling result: nope") 673 | } 674 | 675 | type ChanHandler struct { 676 | wait chan struct{} 677 | ctxdone <-chan struct{} 678 | } 679 | 680 | func (h *ChanHandler) Sub(ctx context.Context, i int, eq int) (<-chan int, error) { 681 | out := make(chan int) 682 | h.ctxdone = ctx.Done() 683 | 684 | wait := h.wait 685 | 686 | log.Warnf("SERVER SUB!") 687 | go func() { 688 | defer close(out) 689 | var n int 690 | 691 | for { 692 | select { 693 | case <-ctx.Done(): 694 | fmt.Println("ctxdone1", i, eq) 695 | return 696 | case <-wait: 697 | //fmt.Println("CONSUMED WAIT: ", i) 698 | } 699 | 700 | n += i 701 | 702 | if n == eq { 703 | fmt.Println("eq") 704 | return 705 | } 706 | 707 | select { 708 | case <-ctx.Done(): 709 | fmt.Println("ctxdone2") 710 | return 711 | case out <- n: 712 | } 713 | } 714 | }() 715 | 716 | return out, nil 717 | } 718 | 719 | func TestChan(t *testing.T) { 720 | var client struct { 721 | Sub func(context.Context, int, int) (<-chan int, error) 722 | } 723 | 724 | serverHandler := &ChanHandler{ 725 | wait: make(chan struct{}, 5), 726 | } 727 | 728 | rpcServer := NewServer() 729 | rpcServer.Register("ChanHandler", serverHandler) 730 | 731 | testServ := 
httptest.NewServer(rpcServer) 732 | defer testServ.Close() 733 | 734 | closer, err := NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "ChanHandler", &client, nil) 735 | require.NoError(t, err) 736 | 737 | defer closer() 738 | 739 | serverHandler.wait <- struct{}{} 740 | 741 | ctx, cancel := context.WithCancel(context.Background()) 742 | 743 | // sub 744 | 745 | sub, err := client.Sub(ctx, 2, -1) 746 | require.NoError(t, err) 747 | 748 | // recv one 749 | 750 | require.Equal(t, 2, <-sub) 751 | 752 | // recv many (order) 753 | 754 | serverHandler.wait <- struct{}{} 755 | serverHandler.wait <- struct{}{} 756 | serverHandler.wait <- struct{}{} 757 | 758 | require.Equal(t, 4, <-sub) 759 | require.Equal(t, 6, <-sub) 760 | require.Equal(t, 8, <-sub) 761 | 762 | // close (through ctx) 763 | cancel() 764 | 765 | _, ok := <-sub 766 | require.Equal(t, false, ok) 767 | 768 | // sub (again) 769 | 770 | serverHandler.wait = make(chan struct{}, 5) 771 | serverHandler.wait <- struct{}{} 772 | 773 | ctx, cancel = context.WithCancel(context.Background()) 774 | defer cancel() 775 | 776 | log.Warnf("last sub") 777 | sub, err = client.Sub(ctx, 3, 6) 778 | require.NoError(t, err) 779 | 780 | log.Warnf("waiting for value now") 781 | require.Equal(t, 3, <-sub) 782 | log.Warnf("not equal") 783 | 784 | // close (remote) 785 | serverHandler.wait <- struct{}{} 786 | _, ok = <-sub 787 | require.Equal(t, false, ok) 788 | } 789 | 790 | func TestChanClosing(t *testing.T) { 791 | var client struct { 792 | Sub func(context.Context, int, int) (<-chan int, error) 793 | } 794 | 795 | serverHandler := &ChanHandler{ 796 | wait: make(chan struct{}, 5), 797 | } 798 | 799 | rpcServer := NewServer() 800 | rpcServer.Register("ChanHandler", serverHandler) 801 | 802 | testServ := httptest.NewServer(rpcServer) 803 | defer testServ.Close() 804 | 805 | closer, err := NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "ChanHandler", &client, nil) 806 | 
require.NoError(t, err) 807 | 808 | defer closer() 809 | 810 | ctx1, cancel1 := context.WithCancel(context.Background()) 811 | ctx2, cancel2 := context.WithCancel(context.Background()) 812 | 813 | // sub 814 | 815 | sub1, err := client.Sub(ctx1, 2, -1) 816 | require.NoError(t, err) 817 | 818 | sub2, err := client.Sub(ctx2, 3, -1) 819 | require.NoError(t, err) 820 | 821 | // recv one 822 | 823 | serverHandler.wait <- struct{}{} 824 | serverHandler.wait <- struct{}{} 825 | 826 | require.Equal(t, 2, <-sub1) 827 | require.Equal(t, 3, <-sub2) 828 | 829 | cancel1() 830 | 831 | require.Equal(t, 0, <-sub1) 832 | time.Sleep(time.Millisecond * 50) // make sure the loop has exited (having a shared wait channel makes this annoying) 833 | 834 | serverHandler.wait <- struct{}{} 835 | require.Equal(t, 6, <-sub2) 836 | 837 | cancel2() 838 | require.Equal(t, 0, <-sub2) 839 | } 840 | 841 | func TestChanServerClose(t *testing.T) { 842 | var client struct { 843 | Sub func(context.Context, int, int) (<-chan int, error) 844 | } 845 | 846 | serverHandler := &ChanHandler{ 847 | wait: make(chan struct{}, 5), 848 | } 849 | 850 | rpcServer := NewServer() 851 | rpcServer.Register("ChanHandler", serverHandler) 852 | 853 | tctx, tcancel := context.WithCancel(context.Background()) 854 | 855 | testServ := httptest.NewUnstartedServer(rpcServer) 856 | testServ.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { 857 | return tctx 858 | } 859 | testServ.Start() 860 | 861 | closer, err := NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "ChanHandler", &client, nil) 862 | require.NoError(t, err) 863 | 864 | defer closer() 865 | 866 | serverHandler.wait <- struct{}{} 867 | 868 | ctx, cancel := context.WithCancel(context.Background()) 869 | defer cancel() 870 | 871 | // sub 872 | 873 | sub, err := client.Sub(ctx, 2, -1) 874 | require.NoError(t, err) 875 | 876 | // recv one 877 | 878 | require.Equal(t, 2, <-sub) 879 | 880 | // make sure we're blocked 881 
| 882 | select { 883 | case <-time.After(200 * time.Millisecond): 884 | case <-sub: 885 | t.Fatal("didn't expect to get anything from sub") 886 | } 887 | 888 | // close server 889 | 890 | tcancel() 891 | testServ.Close() 892 | 893 | _, ok := <-sub 894 | require.Equal(t, false, ok) 895 | } 896 | 897 | func TestServerChanLockClose(t *testing.T) { 898 | var client struct { 899 | Sub func(context.Context, int, int) (<-chan int, error) 900 | } 901 | 902 | serverHandler := &ChanHandler{ 903 | wait: make(chan struct{}), 904 | } 905 | 906 | rpcServer := NewServer() 907 | rpcServer.Register("ChanHandler", serverHandler) 908 | 909 | testServ := httptest.NewServer(rpcServer) 910 | 911 | var closeConn func() error 912 | 913 | _, err := NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), 914 | "ChanHandler", 915 | []interface{}{&client}, nil, 916 | func(c *Config) { 917 | c.proxyConnFactory = func(f func() (*websocket.Conn, error)) func() (*websocket.Conn, error) { 918 | return func() (*websocket.Conn, error) { 919 | c, err := f() 920 | if err != nil { 921 | return nil, err 922 | } 923 | 924 | closeConn = c.UnderlyingConn().Close 925 | 926 | return c, nil 927 | } 928 | } 929 | }) 930 | require.NoError(t, err) 931 | 932 | ctx, cancel := context.WithCancel(context.Background()) 933 | defer cancel() 934 | 935 | // sub 936 | 937 | sub, err := client.Sub(ctx, 2, -1) 938 | require.NoError(t, err) 939 | 940 | // recv one 941 | 942 | go func() { 943 | serverHandler.wait <- struct{}{} 944 | }() 945 | require.Equal(t, 2, <-sub) 946 | 947 | for i := 0; i < 100; i++ { 948 | serverHandler.wait <- struct{}{} 949 | } 950 | 951 | if err := closeConn(); err != nil { 952 | t.Fatal(err) 953 | } 954 | 955 | <-serverHandler.ctxdone 956 | } 957 | 958 | type StreamingHandler struct { 959 | } 960 | 961 | func (h *StreamingHandler) GetData(ctx context.Context, n int) (<-chan int, error) { 962 | out := make(chan int) 963 | 964 | go func() { 965 | defer close(out) 966 | 967 | 
for i := 0; i < n; i++ { 968 | out <- i 969 | } 970 | }() 971 | 972 | return out, nil 973 | } 974 | 975 | func TestChanClientReceiveAll(t *testing.T) { 976 | var client struct { 977 | GetData func(context.Context, int) (<-chan int, error) 978 | } 979 | 980 | serverHandler := &StreamingHandler{} 981 | 982 | rpcServer := NewServer() 983 | rpcServer.Register("ChanHandler", serverHandler) 984 | 985 | tctx, tcancel := context.WithCancel(context.Background()) 986 | 987 | testServ := httptest.NewUnstartedServer(rpcServer) 988 | testServ.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { 989 | return tctx 990 | } 991 | testServ.Start() 992 | 993 | closer, err := NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "ChanHandler", &client, nil) 994 | require.NoError(t, err) 995 | 996 | defer closer() 997 | 998 | ctx, cancel := context.WithCancel(context.Background()) 999 | defer cancel() 1000 | 1001 | // sub 1002 | 1003 | sub, err := client.GetData(ctx, 100) 1004 | require.NoError(t, err) 1005 | 1006 | for i := 0; i < 100; i++ { 1007 | select { 1008 | case v, ok := <-sub: 1009 | if !ok { 1010 | t.Fatal("channel closed", i) 1011 | } 1012 | 1013 | if v != i { 1014 | t.Fatal("got wrong value", v, i) 1015 | } 1016 | case <-time.After(time.Second): 1017 | t.Fatal("timed out waiting for values") 1018 | } 1019 | } 1020 | 1021 | tcancel() 1022 | testServ.Close() 1023 | 1024 | } 1025 | 1026 | func TestControlChanDeadlock(t *testing.T) { 1027 | if _, exists := os.LookupEnv("GOLOG_LOG_LEVEL"); !exists { 1028 | _ = logging.SetLogLevel("rpc", "error") 1029 | defer func() { 1030 | _ = logging.SetLogLevel("rpc", "DEBUG") 1031 | }() 1032 | } 1033 | 1034 | for r := 0; r < 20; r++ { 1035 | testControlChanDeadlock(t) 1036 | } 1037 | } 1038 | 1039 | func testControlChanDeadlock(t *testing.T) { 1040 | var client struct { 1041 | Sub func(context.Context, int, int) (<-chan int, error) 1042 | } 1043 | 1044 | n := 5000 1045 | 1046 | serverHandler := 
&ChanHandler{ 1047 | wait: make(chan struct{}, n), 1048 | } 1049 | 1050 | rpcServer := NewServer() 1051 | rpcServer.Register("ChanHandler", serverHandler) 1052 | 1053 | testServ := httptest.NewServer(rpcServer) 1054 | defer testServ.Close() 1055 | 1056 | closer, err := NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "ChanHandler", &client, nil) 1057 | require.NoError(t, err) 1058 | 1059 | defer closer() 1060 | 1061 | for i := 0; i < n; i++ { 1062 | serverHandler.wait <- struct{}{} 1063 | } 1064 | 1065 | ctx, cancel := context.WithCancel(context.Background()) 1066 | defer cancel() 1067 | 1068 | sub, err := client.Sub(ctx, 1, -1) 1069 | require.NoError(t, err) 1070 | 1071 | done := make(chan struct{}) 1072 | 1073 | go func() { 1074 | defer close(done) 1075 | for i := 0; i < n; i++ { 1076 | if <-sub != i+1 { 1077 | panic("bad!") 1078 | // require.Equal(t, i+1, <-sub) 1079 | } 1080 | } 1081 | }() 1082 | 1083 | // reset this channel so its not shared between the sub requests... 
1084 | serverHandler.wait = make(chan struct{}, n) 1085 | for i := 0; i < n; i++ { 1086 | serverHandler.wait <- struct{}{} 1087 | } 1088 | 1089 | _, err = client.Sub(ctx, 2, -1) 1090 | require.NoError(t, err) 1091 | <-done 1092 | } 1093 | 1094 | type InterfaceHandler struct { 1095 | } 1096 | 1097 | func (h *InterfaceHandler) ReadAll(ctx context.Context, r io.Reader) ([]byte, error) { 1098 | return io.ReadAll(r) 1099 | } 1100 | 1101 | func TestInterfaceHandler(t *testing.T) { 1102 | var client struct { 1103 | ReadAll func(ctx context.Context, r io.Reader) ([]byte, error) 1104 | } 1105 | 1106 | serverHandler := &InterfaceHandler{} 1107 | 1108 | rpcServer := NewServer(WithParamDecoder(new(io.Reader), readerDec)) 1109 | rpcServer.Register("InterfaceHandler", serverHandler) 1110 | 1111 | testServ := httptest.NewServer(rpcServer) 1112 | defer testServ.Close() 1113 | 1114 | closer, err := NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "InterfaceHandler", []interface{}{&client}, nil, WithParamEncoder(new(io.Reader), readerEnc)) 1115 | require.NoError(t, err) 1116 | 1117 | defer closer() 1118 | 1119 | read, err := client.ReadAll(context.TODO(), strings.NewReader("pooooootato")) 1120 | require.NoError(t, err) 1121 | require.Equal(t, "pooooootato", string(read), "potatos weren't equal") 1122 | } 1123 | 1124 | var ( 1125 | readerRegistery = map[int]io.Reader{} 1126 | readerRegisteryN = 31 1127 | readerRegisteryLk sync.Mutex 1128 | ) 1129 | 1130 | func readerEnc(rin reflect.Value) (reflect.Value, error) { 1131 | reader := rin.Interface().(io.Reader) 1132 | 1133 | readerRegisteryLk.Lock() 1134 | defer readerRegisteryLk.Unlock() 1135 | 1136 | n := readerRegisteryN 1137 | readerRegisteryN++ 1138 | 1139 | readerRegistery[n] = reader 1140 | return reflect.ValueOf(n), nil 1141 | } 1142 | 1143 | func readerDec(ctx context.Context, rin []byte) (reflect.Value, error) { 1144 | var id int 1145 | if err := json.Unmarshal(rin, &id); err != nil { 1146 | 
return reflect.Value{}, err 1147 | } 1148 | 1149 | readerRegisteryLk.Lock() 1150 | defer readerRegisteryLk.Unlock() 1151 | 1152 | return reflect.ValueOf(readerRegistery[id]), nil 1153 | } 1154 | 1155 | type ErrSomethingBad struct{} 1156 | 1157 | func (e ErrSomethingBad) Error() string { 1158 | return "something bad has happened" 1159 | } 1160 | 1161 | type ErrMyErr struct{ str string } 1162 | 1163 | var _ error = ErrSomethingBad{} 1164 | 1165 | func (e *ErrMyErr) UnmarshalJSON(data []byte) error { 1166 | return json.Unmarshal(data, &e.str) 1167 | } 1168 | 1169 | func (e *ErrMyErr) MarshalJSON() ([]byte, error) { 1170 | return json.Marshal(e.str) 1171 | } 1172 | 1173 | func (e *ErrMyErr) Error() string { 1174 | return fmt.Sprintf("this happened: %s", e.str) 1175 | } 1176 | 1177 | type ErrHandler struct{} 1178 | 1179 | func (h *ErrHandler) Test() error { 1180 | return ErrSomethingBad{} 1181 | } 1182 | 1183 | func (h *ErrHandler) TestP() error { 1184 | return &ErrSomethingBad{} 1185 | } 1186 | 1187 | func (h *ErrHandler) TestMy(s string) error { 1188 | return &ErrMyErr{ 1189 | str: s, 1190 | } 1191 | } 1192 | 1193 | func TestUserError(t *testing.T) { 1194 | // setup server 1195 | 1196 | serverHandler := &ErrHandler{} 1197 | 1198 | const ( 1199 | EBad = iota + FirstUserCode 1200 | EBad2 1201 | EMy 1202 | ) 1203 | 1204 | errs := NewErrors() 1205 | errs.Register(EBad, new(ErrSomethingBad)) 1206 | errs.Register(EBad2, new(*ErrSomethingBad)) 1207 | errs.Register(EMy, new(*ErrMyErr)) 1208 | 1209 | rpcServer := NewServer(WithServerErrors(errs)) 1210 | rpcServer.Register("ErrHandler", serverHandler) 1211 | 1212 | // httptest stuff 1213 | testServ := httptest.NewServer(rpcServer) 1214 | defer testServ.Close() 1215 | 1216 | // setup client 1217 | 1218 | var client struct { 1219 | Test func() error 1220 | TestP func() error 1221 | TestMy func(s string) error 1222 | } 1223 | closer, err := NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), 
"ErrHandler", []interface{}{ 1224 | &client, 1225 | }, nil, WithErrors(errs)) 1226 | require.NoError(t, err) 1227 | 1228 | e := client.Test() 1229 | require.True(t, xerrors.Is(e, ErrSomethingBad{})) 1230 | 1231 | e = client.TestP() 1232 | require.True(t, xerrors.Is(e, &ErrSomethingBad{})) 1233 | 1234 | e = client.TestMy("some event") 1235 | require.Error(t, e) 1236 | require.Equal(t, "this happened: some event", e.Error()) 1237 | require.Equal(t, "this happened: some event", e.(*ErrMyErr).Error()) 1238 | 1239 | closer() 1240 | } 1241 | 1242 | // Unit test for request/response ID translation. 1243 | func TestIDHandling(t *testing.T) { 1244 | var decoded request 1245 | 1246 | cases := []struct { 1247 | str string 1248 | expect interface{} 1249 | expectErr bool 1250 | }{ 1251 | { 1252 | `{"id":"8116d306-56cc-4637-9dd7-39ce1548a5a0","jsonrpc":"2.0","method":"eth_blockNumber","params":[]}`, 1253 | "8116d306-56cc-4637-9dd7-39ce1548a5a0", 1254 | false, 1255 | }, 1256 | {`{"id":1234,"jsonrpc":"2.0","method":"eth_blockNumber","params":[]}`, float64(1234), false}, 1257 | {`{"id":null,"jsonrpc":"2.0","method":"eth_blockNumber","params":[]}`, nil, false}, 1258 | {`{"id":1234.0,"jsonrpc":"2.0","method":"eth_blockNumber","params":[]}`, 1234.0, false}, 1259 | {`{"id":1.2,"jsonrpc":"2.0","method":"eth_blockNumber","params":[]}`, 1.2, false}, 1260 | {`{"id":["1"],"jsonrpc":"2.0","method":"eth_blockNumber","params":[]}`, nil, true}, 1261 | {`{"id":{"a":"b"},"jsonrpc":"2.0","method":"eth_blockNumber","params":[]}`, nil, true}, 1262 | } 1263 | 1264 | for _, tc := range cases { 1265 | t.Run(fmt.Sprintf("%v", tc.expect), func(t *testing.T) { 1266 | dec := json.NewDecoder(strings.NewReader(tc.str)) 1267 | require.NoError(t, dec.Decode(&decoded)) 1268 | if id, err := normalizeID(decoded.ID); !tc.expectErr { 1269 | require.NoError(t, err) 1270 | require.Equal(t, tc.expect, id) 1271 | } else { 1272 | require.Error(t, err) 1273 | } 1274 | }) 1275 | } 1276 | } 1277 | 1278 | func 
TestAliasedCall(t *testing.T) { 1279 | // setup server 1280 | 1281 | rpcServer := NewServer() 1282 | rpcServer.Register("ServName", &SimpleServerHandler{n: 3}) 1283 | 1284 | // httptest stuff 1285 | testServ := httptest.NewServer(rpcServer) 1286 | defer testServ.Close() 1287 | 1288 | // setup client 1289 | var client struct { 1290 | WhateverMethodName func(int) (int, error) `rpc_method:"ServName.AddGet"` 1291 | } 1292 | closer, err := NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "Server", []interface{}{ 1293 | &client, 1294 | }, nil) 1295 | require.NoError(t, err) 1296 | 1297 | // do the call! 1298 | 1299 | n, err := client.WhateverMethodName(1) 1300 | require.NoError(t, err) 1301 | 1302 | require.Equal(t, 4, n) 1303 | 1304 | closer() 1305 | } 1306 | 1307 | type NotifHandler struct { 1308 | notified chan struct{} 1309 | } 1310 | 1311 | func (h *NotifHandler) Notif() { 1312 | close(h.notified) 1313 | } 1314 | 1315 | func TestNotif(t *testing.T) { 1316 | tc := func(proto string) func(t *testing.T) { 1317 | return func(t *testing.T) { 1318 | // setup server 1319 | 1320 | nh := &NotifHandler{ 1321 | notified: make(chan struct{}), 1322 | } 1323 | 1324 | rpcServer := NewServer() 1325 | rpcServer.Register("Notif", nh) 1326 | 1327 | // httptest stuff 1328 | testServ := httptest.NewServer(rpcServer) 1329 | defer testServ.Close() 1330 | 1331 | // setup client 1332 | var client struct { 1333 | Notif func() error `notify:"true"` 1334 | } 1335 | closer, err := NewMergeClient(context.Background(), proto+"://"+testServ.Listener.Addr().String(), "Notif", []interface{}{ 1336 | &client, 1337 | }, nil) 1338 | require.NoError(t, err) 1339 | 1340 | // do the call! 
1341 | 1342 | // this will block if it's not sent as a notification 1343 | err = client.Notif() 1344 | require.NoError(t, err) 1345 | 1346 | <-nh.notified 1347 | 1348 | closer() 1349 | } 1350 | } 1351 | 1352 | t.Run("ws", tc("ws")) 1353 | t.Run("http", tc("http")) 1354 | } 1355 | 1356 | type RawParamHandler struct { 1357 | } 1358 | 1359 | type CustomParams struct { 1360 | I int 1361 | } 1362 | 1363 | func (h *RawParamHandler) Call(ctx context.Context, ps RawParams) (int, error) { 1364 | p, err := DecodeParams[CustomParams](ps) 1365 | if err != nil { 1366 | return 0, err 1367 | } 1368 | return p.I + 1, nil 1369 | } 1370 | 1371 | func TestCallWithRawParams(t *testing.T) { 1372 | // setup server 1373 | 1374 | rpcServer := NewServer() 1375 | rpcServer.Register("Raw", &RawParamHandler{}) 1376 | 1377 | // httptest stuff 1378 | testServ := httptest.NewServer(rpcServer) 1379 | defer testServ.Close() 1380 | 1381 | // setup client 1382 | var client struct { 1383 | Call func(ctx context.Context, ps RawParams) (int, error) 1384 | } 1385 | closer, err := NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "Raw", []interface{}{ 1386 | &client, 1387 | }, nil) 1388 | require.NoError(t, err) 1389 | 1390 | // do the call! 
1391 | 1392 | // this will block if it's not sent as a notification 1393 | n, err := client.Call(context.Background(), []byte(`{"I": 1}`)) 1394 | require.NoError(t, err) 1395 | require.Equal(t, 2, n) 1396 | 1397 | closer() 1398 | } 1399 | 1400 | type RevCallTestServerHandler struct { 1401 | } 1402 | 1403 | func (h *RevCallTestServerHandler) Call(ctx context.Context) error { 1404 | revClient, ok := ExtractReverseClient[RevCallTestClientProxy](ctx) 1405 | if !ok { 1406 | return fmt.Errorf("no reverse client") 1407 | } 1408 | 1409 | r, err := revClient.CallOnClient(7) // multiply by 2 on client 1410 | if err != nil { 1411 | return xerrors.Errorf("call on client: %w", err) 1412 | } 1413 | 1414 | if r != 14 { 1415 | return fmt.Errorf("unexpected result: %d", r) 1416 | } 1417 | 1418 | return nil 1419 | } 1420 | 1421 | type RevCallTestClientProxy struct { 1422 | CallOnClient func(int) (int, error) 1423 | } 1424 | 1425 | type RevCallTestClientHandler struct { 1426 | } 1427 | 1428 | func (h *RevCallTestClientHandler) CallOnClient(a int) (int, error) { 1429 | return a * 2, nil 1430 | } 1431 | 1432 | func TestReverseCall(t *testing.T) { 1433 | // setup server 1434 | 1435 | rpcServer := NewServer(WithReverseClient[RevCallTestClientProxy]("Client")) 1436 | rpcServer.Register("Server", &RevCallTestServerHandler{}) 1437 | 1438 | // httptest stuff 1439 | testServ := httptest.NewServer(rpcServer) 1440 | defer testServ.Close() 1441 | 1442 | // setup client 1443 | 1444 | var client struct { 1445 | Call func() error 1446 | } 1447 | closer, err := NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "Server", []interface{}{ 1448 | &client, 1449 | }, nil, WithClientHandler("Client", &RevCallTestClientHandler{})) 1450 | require.NoError(t, err) 1451 | 1452 | // do the call! 
1453 | 1454 | e := client.Call() 1455 | require.NoError(t, e) 1456 | 1457 | closer() 1458 | } 1459 | 1460 | type RevCallTestServerHandlerAliased struct { 1461 | } 1462 | 1463 | func (h *RevCallTestServerHandlerAliased) Call(ctx context.Context) error { 1464 | revClient, ok := ExtractReverseClient[RevCallTestClientProxyAliased](ctx) 1465 | if !ok { 1466 | return fmt.Errorf("no reverse client") 1467 | } 1468 | 1469 | r, err := revClient.CallOnClient(8) // multiply by 2 on client 1470 | if err != nil { 1471 | return xerrors.Errorf("call on client: %w", err) 1472 | } 1473 | 1474 | if r != 16 { 1475 | return fmt.Errorf("unexpected result: %d", r) 1476 | } 1477 | 1478 | return nil 1479 | } 1480 | 1481 | type RevCallTestClientProxyAliased struct { 1482 | CallOnClient func(int) (int, error) `rpc_method:"rpc_thing"` 1483 | } 1484 | 1485 | func TestReverseCallAliased(t *testing.T) { 1486 | // setup server 1487 | 1488 | rpcServer := NewServer(WithReverseClient[RevCallTestClientProxyAliased]("Client")) 1489 | rpcServer.Register("Server", &RevCallTestServerHandlerAliased{}) 1490 | 1491 | // httptest stuff 1492 | testServ := httptest.NewServer(rpcServer) 1493 | defer testServ.Close() 1494 | 1495 | // setup client 1496 | 1497 | var client struct { 1498 | Call func() error 1499 | } 1500 | closer, err := NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "Server", []interface{}{ 1501 | &client, 1502 | }, nil, WithClientHandler("Client", &RevCallTestClientHandler{}), WithClientHandlerAlias("rpc_thing", "Client.CallOnClient")) 1503 | require.NoError(t, err) 1504 | 1505 | // do the call! 1506 | 1507 | e := client.Call() 1508 | require.NoError(t, e) 1509 | 1510 | closer() 1511 | } 1512 | 1513 | // RevCallDropTestServerHandler attempts to make a client call on a closed connection. 
1514 | type RevCallDropTestServerHandler struct { 1515 | closeConn func() 1516 | res chan error 1517 | } 1518 | 1519 | func (h *RevCallDropTestServerHandler) Call(ctx context.Context) error { 1520 | revClient, ok := ExtractReverseClient[RevCallTestClientProxy](ctx) 1521 | if !ok { 1522 | return fmt.Errorf("no reverse client") 1523 | } 1524 | 1525 | h.closeConn() 1526 | time.Sleep(time.Second) 1527 | 1528 | _, err := revClient.CallOnClient(7) 1529 | h.res <- err 1530 | 1531 | return nil 1532 | } 1533 | 1534 | func TestReverseCallDroppedConn(t *testing.T) { 1535 | // setup server 1536 | 1537 | hnd := &RevCallDropTestServerHandler{ 1538 | res: make(chan error), 1539 | } 1540 | 1541 | rpcServer := NewServer(WithReverseClient[RevCallTestClientProxy]("Client")) 1542 | rpcServer.Register("Server", hnd) 1543 | 1544 | // httptest stuff 1545 | testServ := httptest.NewServer(rpcServer) 1546 | defer testServ.Close() 1547 | 1548 | // setup client 1549 | 1550 | var client struct { 1551 | Call func() error 1552 | } 1553 | closer, err := NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "Server", []interface{}{ 1554 | &client, 1555 | }, nil, WithClientHandler("Client", &RevCallTestClientHandler{})) 1556 | require.NoError(t, err) 1557 | 1558 | hnd.closeConn = closer 1559 | 1560 | // do the call! 
1561 | e := client.Call() 1562 | 1563 | require.Error(t, e) 1564 | require.Contains(t, e.Error(), "websocket connection closed") 1565 | 1566 | res := <-hnd.res 1567 | require.Error(t, res) 1568 | require.Contains(t, res.Error(), "RPC client error: sendRequest failed: websocket routine exiting") 1569 | time.Sleep(100 * time.Millisecond) 1570 | } 1571 | 1572 | type BigCallTestServerHandler struct { 1573 | } 1574 | 1575 | type RecRes struct { 1576 | I int 1577 | R []RecRes 1578 | } 1579 | 1580 | func (h *BigCallTestServerHandler) Do() (RecRes, error) { 1581 | var res RecRes 1582 | res.I = 123 1583 | 1584 | for i := 0; i < 15000; i++ { 1585 | var ires RecRes 1586 | ires.I = i 1587 | 1588 | for j := 0; j < 15000; j++ { 1589 | var jres RecRes 1590 | jres.I = j 1591 | 1592 | ires.R = append(ires.R, jres) 1593 | } 1594 | 1595 | res.R = append(res.R, ires) 1596 | } 1597 | 1598 | fmt.Println("sending result") 1599 | 1600 | return res, nil 1601 | } 1602 | 1603 | func (h *BigCallTestServerHandler) Ch(ctx context.Context) (<-chan int, error) { 1604 | out := make(chan int) 1605 | 1606 | go func() { 1607 | var i int 1608 | for { 1609 | select { 1610 | case <-ctx.Done(): 1611 | fmt.Println("closing") 1612 | close(out) 1613 | return 1614 | case <-time.After(time.Second): 1615 | } 1616 | fmt.Println("sending") 1617 | out <- i 1618 | i++ 1619 | } 1620 | }() 1621 | 1622 | return out, nil 1623 | } 1624 | 1625 | // TestBigResult tests that the connection doesn't die when sending a large result, 1626 | // and that requests which happen while a large result is being sent don't fail. 
1627 | func TestBigResult(t *testing.T) { 1628 | if os.Getenv("I_HAVE_A_LOT_OF_MEMORY_AND_TIME") != "1" { 1629 | // needs ~40GB of memory and ~4 minutes to run 1630 | t.Skip("skipping test due to required resources, set I_HAVE_A_LOT_OF_MEMORY_AND_TIME=1 to run") 1631 | } 1632 | 1633 | // setup server 1634 | 1635 | serverHandler := &BigCallTestServerHandler{} 1636 | 1637 | rpcServer := NewServer() 1638 | rpcServer.Register("SimpleServerHandler", serverHandler) 1639 | 1640 | // httptest stuff 1641 | testServ := httptest.NewServer(rpcServer) 1642 | defer testServ.Close() 1643 | // setup client 1644 | 1645 | var client struct { 1646 | Do func() (RecRes, error) 1647 | Ch func(ctx context.Context) (<-chan int, error) 1648 | } 1649 | closer, err := NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &client, nil) 1650 | require.NoError(t, err) 1651 | defer closer() 1652 | 1653 | chctx, cancel := context.WithCancel(context.Background()) 1654 | defer cancel() 1655 | 1656 | // client.Ch will generate some requests, which will require websocket locks, 1657 | // and before fixes in #97 would cause deadlocks / timeouts when combined with 1658 | // the large result processing from client.Do 1659 | ch, err := client.Ch(chctx) 1660 | require.NoError(t, err) 1661 | 1662 | prevN := <-ch 1663 | 1664 | go func() { 1665 | for n := range ch { 1666 | if n != prevN+1 { 1667 | panic("bad order") 1668 | } 1669 | prevN = n 1670 | } 1671 | }() 1672 | 1673 | _, err = client.Do() 1674 | require.NoError(t, err) 1675 | 1676 | fmt.Println("done") 1677 | } 1678 | 1679 | func TestNewCustomClient(t *testing.T) { 1680 | // Setup server 1681 | serverHandler := &SimpleServerHandler{} 1682 | rpcServer := NewServer() 1683 | rpcServer.Register("SimpleServerHandler", serverHandler) 1684 | 1685 | // Custom doRequest function 1686 | doRequest := func(ctx context.Context, body []byte) (io.ReadCloser, error) { 1687 | reader := bytes.NewReader(body) 1688 | pr, pw := 
io.Pipe() 1689 | go func() { 1690 | defer pw.Close() 1691 | rpcServer.HandleRequest(ctx, reader, pw) 1692 | }() 1693 | return pr, nil 1694 | } 1695 | 1696 | var client struct { 1697 | Add func(int) error 1698 | AddGet func(int) int 1699 | } 1700 | 1701 | // Create custom client 1702 | closer, err := NewCustomClient("SimpleServerHandler", []interface{}{&client}, doRequest) 1703 | require.NoError(t, err) 1704 | defer closer() 1705 | 1706 | // Add(int) error 1707 | require.NoError(t, client.Add(10)) 1708 | require.Equal(t, int32(10), serverHandler.n) 1709 | 1710 | err = client.Add(-3546) 1711 | require.EqualError(t, err, "test") 1712 | 1713 | // AddGet(int) int 1714 | n := client.AddGet(3) 1715 | require.Equal(t, 13, n) 1716 | require.Equal(t, int32(13), serverHandler.n) 1717 | } 1718 | 1719 | func TestReverseCallWithCustomMethodName(t *testing.T) { 1720 | // setup server 1721 | 1722 | rpcServer := NewServer(WithServerMethodNameFormatter(func(namespace, method string) string { return namespace + "_" + method })) 1723 | rpcServer.Register("Server", &RawParamHandler{}) 1724 | 1725 | // httptest stuff 1726 | testServ := httptest.NewServer(rpcServer) 1727 | defer testServ.Close() 1728 | 1729 | // setup client 1730 | 1731 | var client struct { 1732 | Call func(ctx context.Context, ps RawParams) error `rpc_method:"Server_Call"` 1733 | } 1734 | closer, err := NewMergeClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "Server", []interface{}{ 1735 | &client, 1736 | }, nil) 1737 | require.NoError(t, err) 1738 | 1739 | // do the call! 
1740 | 1741 | e := client.Call(context.Background(), []byte(`{"I": 1}`)) 1742 | require.NoError(t, e) 1743 | 1744 | closer() 1745 | } 1746 | 1747 | type MethodTransformedHandler struct{} 1748 | 1749 | func (h *RawParamHandler) CallSomethingInSnakeCase(ctx context.Context, v int) (int, error) { 1750 | return v + 1, nil 1751 | } 1752 | -------------------------------------------------------------------------------- /server.go: -------------------------------------------------------------------------------- 1 | package jsonrpc 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "io" 7 | "net/http" 8 | "runtime/pprof" 9 | "strings" 10 | "time" 11 | 12 | "github.com/google/uuid" 13 | "github.com/gorilla/websocket" 14 | ) 15 | 16 | const ( 17 | rpcParseError = -32700 18 | rpcInvalidRequest = -32600 19 | rpcMethodNotFound = -32601 20 | rpcInvalidParams = -32602 21 | ) 22 | 23 | // ConnectionType indicates the type of connection, this is set in the context and can be retrieved 24 | // with GetConnectionType. 25 | type ConnectionType string 26 | 27 | const ( 28 | // ConnectionTypeUnknown indicates that the connection type cannot be determined, likely because 29 | // it hasn't passed through an RPCServer. 30 | ConnectionTypeUnknown ConnectionType = "unknown" 31 | // ConnectionTypeHTTP indicates that the connection is an HTTP connection. 32 | ConnectionTypeHTTP ConnectionType = "http" 33 | // ConnectionTypeWS indicates that the connection is a WebSockets connection. 34 | ConnectionTypeWS ConnectionType = "websockets" 35 | ) 36 | 37 | var connectionTypeCtxKey = &struct{ name string }{"jsonrpc-connection-type"} 38 | 39 | // GetConnectionType returns the connection type of the request if it was set by an RPCServer. 40 | // A connection type of ConnectionTypeUnknown means that the connection type was not set. 
41 | func GetConnectionType(ctx context.Context) ConnectionType { 42 | if v := ctx.Value(connectionTypeCtxKey); v != nil { 43 | return v.(ConnectionType) 44 | } 45 | return ConnectionTypeUnknown 46 | } 47 | 48 | // RPCServer provides a jsonrpc 2.0 http server handler 49 | type RPCServer struct { 50 | *handler 51 | reverseClientBuilder func(context.Context, *wsConn) (context.Context, error) 52 | 53 | pingInterval time.Duration 54 | } 55 | 56 | // NewServer creates new RPCServer instance 57 | func NewServer(opts ...ServerOption) *RPCServer { 58 | config := defaultServerConfig() 59 | for _, o := range opts { 60 | o(&config) 61 | } 62 | 63 | return &RPCServer{ 64 | handler: makeHandler(config), 65 | reverseClientBuilder: config.reverseClientBuilder, 66 | 67 | pingInterval: config.pingInterval, 68 | } 69 | } 70 | 71 | var upgrader = websocket.Upgrader{ 72 | CheckOrigin: func(r *http.Request) bool { 73 | return true 74 | }, 75 | } 76 | 77 | func (s *RPCServer) handleWS(ctx context.Context, w http.ResponseWriter, r *http.Request) { 78 | // TODO: allow setting 79 | // (note that we still are mostly covered by jwt tokens) 80 | w.Header().Set("Access-Control-Allow-Origin", "*") 81 | if r.Header.Get("Sec-WebSocket-Protocol") != "" { 82 | w.Header().Set("Sec-WebSocket-Protocol", r.Header.Get("Sec-WebSocket-Protocol")) 83 | } 84 | 85 | c, err := upgrader.Upgrade(w, r, nil) 86 | if err != nil { 87 | log.Errorw("upgrading connection", "error", err) 88 | // note that upgrader.Upgrade will set http error if there is an error 89 | return 90 | } 91 | 92 | wc := &wsConn{ 93 | conn: c, 94 | handler: s, 95 | pingInterval: s.pingInterval, 96 | exiting: make(chan struct{}), 97 | } 98 | 99 | if s.reverseClientBuilder != nil { 100 | ctx, err = s.reverseClientBuilder(ctx, wc) 101 | if err != nil { 102 | log.Errorf("failed to build reverse client: %s", err) 103 | w.WriteHeader(500) 104 | return 105 | } 106 | } 107 | 108 | lbl := pprof.Labels("jrpc-mode", "wsserver", "jrpc-remote", 
r.RemoteAddr, "jrpc-uuid", uuid.New().String()) 109 | pprof.Do(ctx, lbl, func(ctx context.Context) { 110 | wc.handleWsConn(ctx) 111 | }) 112 | 113 | if err := c.Close(); err != nil { 114 | log.Errorw("closing websocket connection", "error", err) 115 | return 116 | } 117 | } 118 | 119 | // TODO: return errors to clients per spec 120 | func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { 121 | ctx := r.Context() 122 | 123 | h := strings.ToLower(r.Header.Get("Connection")) 124 | if strings.Contains(h, "upgrade") { 125 | ctx = context.WithValue(ctx, connectionTypeCtxKey, ConnectionTypeWS) 126 | s.handleWS(ctx, w, r) 127 | return 128 | } 129 | 130 | ctx = context.WithValue(ctx, connectionTypeCtxKey, ConnectionTypeHTTP) 131 | s.handleReader(ctx, r.Body, w, rpcError) 132 | } 133 | 134 | func (s *RPCServer) HandleRequest(ctx context.Context, r io.Reader, w io.Writer) { 135 | s.handleReader(ctx, r, w, rpcError) 136 | } 137 | 138 | func rpcError(wf func(func(io.Writer)), req *request, code ErrorCode, err error) { 139 | log.Errorf("RPC Error: %s", err) 140 | wf(func(w io.Writer) { 141 | if hw, ok := w.(http.ResponseWriter); ok { 142 | if code == rpcInvalidRequest { 143 | hw.WriteHeader(http.StatusBadRequest) 144 | } else { 145 | hw.WriteHeader(http.StatusInternalServerError) 146 | } 147 | } 148 | 149 | log.Warnf("rpc error: %s", err) 150 | 151 | if req == nil { 152 | req = &request{} 153 | } 154 | 155 | resp := response{ 156 | Jsonrpc: "2.0", 157 | ID: req.ID, 158 | Error: &JSONRPCError{ 159 | Code: code, 160 | Message: err.Error(), 161 | }, 162 | } 163 | 164 | err = json.NewEncoder(w).Encode(resp) 165 | if err != nil { 166 | log.Warnf("failed to write rpc error: %s", err) 167 | return 168 | } 169 | }) 170 | } 171 | 172 | // Register registers new RPC handler 173 | // 174 | // Handler is any value with methods defined 175 | func (s *RPCServer) Register(namespace string, handler interface{}) { 176 | s.register(namespace, handler) 177 | } 178 | 179 | func (s 
*RPCServer) AliasMethod(alias, original string) { 180 | s.aliasedMethods[alias] = original 181 | } 182 | 183 | var _ error = &JSONRPCError{} 184 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | package jsonrpc 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "math" 7 | "math/rand" 8 | "reflect" 9 | "time" 10 | ) 11 | 12 | type param struct { 13 | data []byte // from unmarshal 14 | 15 | v reflect.Value // to marshal 16 | } 17 | 18 | func (p *param) UnmarshalJSON(raw []byte) error { 19 | p.data = make([]byte, len(raw)) 20 | copy(p.data, raw) 21 | return nil 22 | } 23 | 24 | func (p *param) MarshalJSON() ([]byte, error) { 25 | if p.v.Kind() == reflect.Invalid { 26 | return p.data, nil 27 | } 28 | 29 | return json.Marshal(p.v.Interface()) 30 | } 31 | 32 | // processFuncOut finds value and error Outs in function 33 | func processFuncOut(funcType reflect.Type) (valOut int, errOut int, n int) { 34 | errOut = -1 // -1 if not found 35 | valOut = -1 36 | n = funcType.NumOut() 37 | 38 | switch n { 39 | case 0: 40 | case 1: 41 | if funcType.Out(0) == errorType { 42 | errOut = 0 43 | } else { 44 | valOut = 0 45 | } 46 | case 2: 47 | valOut = 0 48 | errOut = 1 49 | if funcType.Out(1) != errorType { 50 | panic("expected error as second return value") 51 | } 52 | default: 53 | errstr := fmt.Sprintf("too many return values: %s", funcType) 54 | panic(errstr) 55 | } 56 | 57 | return 58 | } 59 | 60 | type backoff struct { 61 | minDelay time.Duration 62 | maxDelay time.Duration 63 | } 64 | 65 | func (b *backoff) next(attempt int) time.Duration { 66 | if attempt < 0 { 67 | return b.minDelay 68 | } 69 | 70 | minf := float64(b.minDelay) 71 | durf := minf * math.Pow(1.5, float64(attempt)) 72 | durf = durf + rand.Float64()*minf 73 | 74 | delay := time.Duration(durf) 75 | 76 | if delay > b.maxDelay { 77 | return b.maxDelay 78 | } 79 | 80 | return delay 81 | } 82 
| -------------------------------------------------------------------------------- /version.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "v0.8.0" 3 | } 4 | -------------------------------------------------------------------------------- /websocket.go: -------------------------------------------------------------------------------- 1 | package jsonrpc 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "os" 10 | "reflect" 11 | "sync" 12 | "sync/atomic" 13 | "time" 14 | 15 | "github.com/gorilla/websocket" 16 | "golang.org/x/xerrors" 17 | ) 18 | 19 | const wsCancel = "xrpc.cancel" 20 | const chValue = "xrpc.ch.val" 21 | const chClose = "xrpc.ch.close" 22 | 23 | var debugTrace = os.Getenv("JSONRPC_ENABLE_DEBUG_TRACE") == "1" 24 | 25 | type frame struct { 26 | // common 27 | Jsonrpc string `json:"jsonrpc"` 28 | ID interface{} `json:"id,omitempty"` 29 | Meta map[string]string `json:"meta,omitempty"` 30 | 31 | // request 32 | Method string `json:"method,omitempty"` 33 | Params json.RawMessage `json:"params,omitempty"` 34 | 35 | // response 36 | Result json.RawMessage `json:"result,omitempty"` 37 | Error *JSONRPCError `json:"error,omitempty"` 38 | } 39 | 40 | type outChanReg struct { 41 | reqID interface{} 42 | 43 | chID uint64 44 | ch reflect.Value 45 | } 46 | 47 | type reqestHandler interface { 48 | handle(ctx context.Context, req request, w func(func(io.Writer)), rpcError rpcErrFunc, done func(keepCtx bool), chOut chanOut) 49 | } 50 | 51 | type wsConn struct { 52 | // outside params 53 | conn *websocket.Conn 54 | connFactory func() (*websocket.Conn, error) 55 | reconnectBackoff backoff 56 | pingInterval time.Duration 57 | timeout time.Duration 58 | handler reqestHandler 59 | requests <-chan clientRequest 60 | pongs chan struct{} 61 | stopPings func() 62 | stop <-chan struct{} 63 | exiting chan struct{} 64 | 65 | // incoming messages 66 | incoming chan io.Reader 67 | 
incomingErr error 68 | errLk sync.Mutex 69 | 70 | readError chan error 71 | 72 | frameExecQueue chan []byte 73 | 74 | // outgoing messages 75 | writeLk sync.Mutex 76 | 77 | // //// 78 | // Client related 79 | 80 | // inflight are requests we've sent to the remote 81 | inflight map[interface{}]clientRequest 82 | inflightLk sync.Mutex 83 | 84 | // chanHandlers is a map of client-side channel handlers 85 | chanHandlersLk sync.Mutex 86 | chanHandlers map[uint64]*chanHandler 87 | 88 | // //// 89 | // Server related 90 | 91 | // handling are the calls we handle 92 | handling map[interface{}]context.CancelFunc 93 | handlingLk sync.Mutex 94 | 95 | spawnOutChanHandlerOnce sync.Once 96 | 97 | // chanCtr is a counter used for identifying output channels on the server side 98 | chanCtr uint64 99 | 100 | registerCh chan outChanReg 101 | } 102 | 103 | type chanHandler struct { 104 | // take inside chanHandlersLk 105 | lk sync.Mutex 106 | 107 | cb func(m []byte, ok bool) 108 | } 109 | 110 | // // 111 | // WebSocket Message utils // 112 | // // 113 | 114 | // nextMessage wait for one message and puts it to the incoming channel 115 | func (c *wsConn) nextMessage() { 116 | c.resetReadDeadline() 117 | msgType, r, err := c.conn.NextReader() 118 | if err != nil { 119 | c.errLk.Lock() 120 | c.incomingErr = err 121 | c.errLk.Unlock() 122 | close(c.incoming) 123 | return 124 | } 125 | if msgType != websocket.BinaryMessage && msgType != websocket.TextMessage { 126 | c.errLk.Lock() 127 | c.incomingErr = errors.New("unsupported message type") 128 | c.errLk.Unlock() 129 | close(c.incoming) 130 | return 131 | } 132 | c.incoming <- r 133 | } 134 | 135 | // nextWriter waits for writeLk and invokes the cb callback with WS message 136 | // writer when the lock is acquired 137 | func (c *wsConn) nextWriter(cb func(io.Writer)) { 138 | c.writeLk.Lock() 139 | defer c.writeLk.Unlock() 140 | 141 | wcl, err := c.conn.NextWriter(websocket.TextMessage) 142 | if err != nil { 143 | log.Error("handle me:", 
err) 144 | return 145 | } 146 | 147 | cb(wcl) 148 | 149 | if err := wcl.Close(); err != nil { 150 | log.Error("handle me:", err) 151 | return 152 | } 153 | } 154 | 155 | func (c *wsConn) sendRequest(req request) error { 156 | c.writeLk.Lock() 157 | defer c.writeLk.Unlock() 158 | 159 | if debugTrace { 160 | log.Debugw("sendRequest", "req", req.Method, "id", req.ID) 161 | } 162 | 163 | if err := c.conn.WriteJSON(req); err != nil { 164 | return err 165 | } 166 | return nil 167 | } 168 | 169 | // // 170 | // Output channels // 171 | // // 172 | 173 | // handleOutChans handles channel communication on the server side 174 | // (forwards channel messages to client) 175 | func (c *wsConn) handleOutChans() { 176 | regV := reflect.ValueOf(c.registerCh) 177 | exitV := reflect.ValueOf(c.exiting) 178 | 179 | cases := []reflect.SelectCase{ 180 | { // registration chan always 0 181 | Dir: reflect.SelectRecv, 182 | Chan: regV, 183 | }, 184 | { // exit chan always 1 185 | Dir: reflect.SelectRecv, 186 | Chan: exitV, 187 | }, 188 | } 189 | internal := len(cases) 190 | var caseToID []uint64 191 | 192 | for { 193 | chosen, val, ok := reflect.Select(cases) 194 | 195 | switch chosen { 196 | case 0: // registration channel 197 | if !ok { 198 | // control channel closed - signals closed connection 199 | // This shouldn't happen, instead the exiting channel should get closed 200 | log.Warn("control channel closed") 201 | return 202 | } 203 | 204 | registration := val.Interface().(outChanReg) 205 | 206 | caseToID = append(caseToID, registration.chID) 207 | cases = append(cases, reflect.SelectCase{ 208 | Dir: reflect.SelectRecv, 209 | Chan: registration.ch, 210 | }) 211 | 212 | c.nextWriter(func(w io.Writer) { 213 | resp := &response{ 214 | Jsonrpc: "2.0", 215 | ID: registration.reqID, 216 | Result: registration.chID, 217 | } 218 | 219 | if err := json.NewEncoder(w).Encode(resp); err != nil { 220 | log.Error(err) 221 | return 222 | } 223 | }) 224 | 225 | continue 226 | case 1: // exiting 
channel 227 | if !ok { 228 | // exiting channel closed - signals closed connection 229 | // 230 | // We're not closing any channels as we're on receiving end. 231 | // Also, context cancellation below should take care of any running 232 | // requests 233 | return 234 | } 235 | log.Warn("exiting channel received a message") 236 | continue 237 | } 238 | 239 | if !ok { 240 | // Output channel closed, cleanup, and tell remote that this happened 241 | 242 | id := caseToID[chosen-internal] 243 | 244 | n := len(cases) - 1 245 | if n > 0 { 246 | cases[chosen] = cases[n] 247 | caseToID[chosen-internal] = caseToID[n-internal] 248 | } 249 | 250 | cases = cases[:n] 251 | caseToID = caseToID[:n-internal] 252 | 253 | rp, err := json.Marshal([]param{{v: reflect.ValueOf(id)}}) 254 | if err != nil { 255 | log.Error(err) 256 | continue 257 | } 258 | 259 | if err := c.sendRequest(request{ 260 | Jsonrpc: "2.0", 261 | ID: nil, // notification 262 | Method: chClose, 263 | Params: rp, 264 | }); err != nil { 265 | log.Warnf("closed out channel sendRequest failed: %s", err) 266 | } 267 | continue 268 | } 269 | 270 | // forward message 271 | rp, err := json.Marshal([]param{{v: reflect.ValueOf(caseToID[chosen-internal])}, {v: val}}) 272 | if err != nil { 273 | log.Errorw("marshaling params for sendRequest failed", "err", err) 274 | continue 275 | } 276 | 277 | if err := c.sendRequest(request{ 278 | Jsonrpc: "2.0", 279 | ID: nil, // notification 280 | Method: chValue, 281 | Params: rp, 282 | }); err != nil { 283 | log.Warnf("sendRequest failed: %s", err) 284 | return 285 | } 286 | } 287 | } 288 | 289 | // handleChanOut registers output channel for forwarding to client 290 | func (c *wsConn) handleChanOut(ch reflect.Value, req interface{}) error { 291 | c.spawnOutChanHandlerOnce.Do(func() { 292 | go c.handleOutChans() 293 | }) 294 | id := atomic.AddUint64(&c.chanCtr, 1) 295 | 296 | select { 297 | case c.registerCh <- outChanReg{ 298 | reqID: req, 299 | 300 | chID: id, 301 | ch: ch, 302 | }: 
303 | return nil 304 | case <-c.exiting: 305 | return xerrors.New("connection closing") 306 | } 307 | } 308 | 309 | // // 310 | // Context.Done propagation // 311 | // // 312 | 313 | // handleCtxAsync handles context lifetimes for client 314 | // TODO: this should be aware of events going through chanHandlers, and quit 315 | // 316 | // when the related channel is closed. 317 | // This should also probably be a single goroutine, 318 | // Note that not doing this should be fine for now as long as we are using 319 | // contexts correctly (cancelling when async functions are no longer is use) 320 | func (c *wsConn) handleCtxAsync(actx context.Context, id interface{}) { 321 | <-actx.Done() 322 | 323 | rp, err := json.Marshal([]param{{v: reflect.ValueOf(id)}}) 324 | if err != nil { 325 | log.Errorw("marshaling params for sendRequest failed", "err", err) 326 | return 327 | } 328 | 329 | if err := c.sendRequest(request{ 330 | Jsonrpc: "2.0", 331 | Method: wsCancel, 332 | Params: rp, 333 | }); err != nil { 334 | log.Warnw("failed to send request", "method", wsCancel, "id", id, "error", err.Error()) 335 | } 336 | } 337 | 338 | // cancelCtx is a built-in rpc which handles context cancellation over rpc 339 | func (c *wsConn) cancelCtx(req frame) { 340 | if req.ID != nil { 341 | log.Warnf("%s call with ID set, won't respond", wsCancel) 342 | } 343 | 344 | var params []param 345 | if err := json.Unmarshal(req.Params, ¶ms); err != nil { 346 | log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) 347 | return 348 | } 349 | 350 | var id interface{} 351 | if err := json.Unmarshal(params[0].data, &id); err != nil { 352 | log.Error("handle me:", err) 353 | return 354 | } 355 | 356 | c.handlingLk.Lock() 357 | defer c.handlingLk.Unlock() 358 | 359 | cf, ok := c.handling[id] 360 | if ok { 361 | cf() 362 | } 363 | } 364 | 365 | // // 366 | // Main Handling logic // 367 | // // 368 | 369 | func (c *wsConn) handleChanMessage(frame frame) { 370 | var params []param 371 | if 
err := json.Unmarshal(frame.Params, ¶ms); err != nil { 372 | log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) 373 | return 374 | } 375 | 376 | var chid uint64 377 | if err := json.Unmarshal(params[0].data, &chid); err != nil { 378 | log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) 379 | return 380 | } 381 | 382 | c.chanHandlersLk.Lock() 383 | hnd, ok := c.chanHandlers[chid] 384 | if !ok { 385 | c.chanHandlersLk.Unlock() 386 | log.Errorf("xrpc.ch.val: handler %d not found", chid) 387 | return 388 | } 389 | 390 | hnd.lk.Lock() 391 | defer hnd.lk.Unlock() 392 | 393 | c.chanHandlersLk.Unlock() 394 | 395 | hnd.cb(params[1].data, true) 396 | } 397 | 398 | func (c *wsConn) handleChanClose(frame frame) { 399 | var params []param 400 | if err := json.Unmarshal(frame.Params, ¶ms); err != nil { 401 | log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) 402 | return 403 | } 404 | 405 | var chid uint64 406 | if err := json.Unmarshal(params[0].data, &chid); err != nil { 407 | log.Error("failed to unmarshal channel id in xrpc.ch.val: %s", err) 408 | return 409 | } 410 | 411 | c.chanHandlersLk.Lock() 412 | hnd, ok := c.chanHandlers[chid] 413 | if !ok { 414 | c.chanHandlersLk.Unlock() 415 | log.Errorf("xrpc.ch.val: handler %d not found", chid) 416 | return 417 | } 418 | 419 | hnd.lk.Lock() 420 | defer hnd.lk.Unlock() 421 | 422 | delete(c.chanHandlers, chid) 423 | 424 | c.chanHandlersLk.Unlock() 425 | 426 | hnd.cb(nil, false) 427 | } 428 | 429 | func (c *wsConn) handleResponse(frame frame) { 430 | c.inflightLk.Lock() 431 | req, ok := c.inflight[frame.ID] 432 | c.inflightLk.Unlock() 433 | if !ok { 434 | log.Error("client got unknown ID in response") 435 | return 436 | } 437 | 438 | if req.retCh != nil && frame.Result != nil { 439 | // output is channel 440 | var chid uint64 441 | if err := json.Unmarshal(frame.Result, &chid); err != nil { 442 | log.Errorf("failed to unmarshal channel id response: %s, data '%s'", err, 
string(frame.Result)) 443 | return 444 | } 445 | 446 | chanCtx, chHnd := req.retCh() 447 | 448 | c.chanHandlersLk.Lock() 449 | c.chanHandlers[chid] = &chanHandler{cb: chHnd} 450 | c.chanHandlersLk.Unlock() 451 | 452 | go c.handleCtxAsync(chanCtx, frame.ID) 453 | } 454 | 455 | req.ready <- clientResponse{ 456 | Jsonrpc: frame.Jsonrpc, 457 | Result: frame.Result, 458 | ID: frame.ID, 459 | Error: frame.Error, 460 | } 461 | c.inflightLk.Lock() 462 | delete(c.inflight, frame.ID) 463 | c.inflightLk.Unlock() 464 | } 465 | 466 | func (c *wsConn) handleCall(ctx context.Context, frame frame) { 467 | if c.handler == nil { 468 | log.Error("handleCall on client with no reverse handler") 469 | return 470 | } 471 | 472 | req := request{ 473 | Jsonrpc: frame.Jsonrpc, 474 | ID: frame.ID, 475 | Meta: frame.Meta, 476 | Method: frame.Method, 477 | Params: frame.Params, 478 | } 479 | 480 | ctx, cancel := context.WithCancel(ctx) 481 | 482 | nextWriter := func(cb func(io.Writer)) { 483 | cb(io.Discard) 484 | } 485 | done := func(keepCtx bool) { 486 | if !keepCtx { 487 | cancel() 488 | } 489 | } 490 | if frame.ID != nil { 491 | nextWriter = c.nextWriter 492 | 493 | c.handlingLk.Lock() 494 | c.handling[frame.ID] = cancel 495 | c.handlingLk.Unlock() 496 | 497 | done = func(keepctx bool) { 498 | c.handlingLk.Lock() 499 | defer c.handlingLk.Unlock() 500 | 501 | if !keepctx { 502 | cancel() 503 | delete(c.handling, frame.ID) 504 | } 505 | } 506 | } 507 | 508 | go c.handler.handle(ctx, req, nextWriter, rpcError, done, c.handleChanOut) 509 | } 510 | 511 | // handleFrame handles all incoming messages (calls and responses) 512 | func (c *wsConn) handleFrame(ctx context.Context, frame frame) { 513 | // Get message type by method name: 514 | // "" - response 515 | // "xrpc.*" - builtin 516 | // anything else - incoming remote call 517 | switch frame.Method { 518 | case "": // Response to our call 519 | c.handleResponse(frame) 520 | case wsCancel: 521 | c.cancelCtx(frame) 522 | case chValue: 523 | 
c.handleChanMessage(frame) 524 | case chClose: 525 | c.handleChanClose(frame) 526 | default: // Remote call 527 | c.handleCall(ctx, frame) 528 | } 529 | } 530 | 531 | func (c *wsConn) closeInFlight() { 532 | c.inflightLk.Lock() 533 | for id, req := range c.inflight { 534 | req.ready <- clientResponse{ 535 | Jsonrpc: "2.0", 536 | ID: id, 537 | Error: &JSONRPCError{ 538 | Message: "handler: websocket connection closed", 539 | Code: eTempWSError, 540 | }, 541 | } 542 | } 543 | c.inflight = map[interface{}]clientRequest{} 544 | c.inflightLk.Unlock() 545 | 546 | c.handlingLk.Lock() 547 | for _, cancel := range c.handling { 548 | cancel() 549 | } 550 | c.handling = map[interface{}]context.CancelFunc{} 551 | c.handlingLk.Unlock() 552 | 553 | } 554 | 555 | func (c *wsConn) closeChans() { 556 | c.chanHandlersLk.Lock() 557 | defer c.chanHandlersLk.Unlock() 558 | 559 | for chid := range c.chanHandlers { 560 | hnd := c.chanHandlers[chid] 561 | 562 | hnd.lk.Lock() 563 | 564 | delete(c.chanHandlers, chid) 565 | 566 | c.chanHandlersLk.Unlock() 567 | 568 | hnd.cb(nil, false) 569 | 570 | hnd.lk.Unlock() 571 | c.chanHandlersLk.Lock() 572 | } 573 | } 574 | 575 | func (c *wsConn) setupPings() func() { 576 | if c.pingInterval == 0 { 577 | return func() {} 578 | } 579 | 580 | c.conn.SetPongHandler(func(appData string) error { 581 | select { 582 | case c.pongs <- struct{}{}: 583 | default: 584 | } 585 | return nil 586 | }) 587 | c.conn.SetPingHandler(func(appData string) error { 588 | // treat pings as pongs - this lets us register server activity even if it's too busy to respond to our pings 589 | select { 590 | case c.pongs <- struct{}{}: 591 | default: 592 | } 593 | return nil 594 | }) 595 | 596 | stop := make(chan struct{}) 597 | 598 | go func() { 599 | for { 600 | select { 601 | case <-time.After(c.pingInterval): 602 | c.writeLk.Lock() 603 | if err := c.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { 604 | log.Errorf("sending ping message: %+v", err) 605 | } 606 | 
c.writeLk.Unlock() 607 | case <-stop: 608 | return 609 | } 610 | } 611 | }() 612 | 613 | var o sync.Once 614 | return func() { 615 | o.Do(func() { 616 | close(stop) 617 | }) 618 | } 619 | } 620 | 621 | // returns true if reconnected 622 | func (c *wsConn) tryReconnect(ctx context.Context) bool { 623 | if c.connFactory == nil { // server side 624 | return false 625 | } 626 | 627 | // connection dropped unexpectedly, do our best to recover it 628 | c.closeInFlight() 629 | c.closeChans() 630 | c.incoming = make(chan io.Reader) // listen again for responses 631 | go func() { 632 | c.stopPings() 633 | 634 | attempts := 0 635 | var conn *websocket.Conn 636 | for conn == nil { 637 | time.Sleep(c.reconnectBackoff.next(attempts)) 638 | if ctx.Err() != nil { 639 | return 640 | } 641 | var err error 642 | if conn, err = c.connFactory(); err != nil { 643 | log.Debugw("websocket connection retry failed", "error", err) 644 | } 645 | select { 646 | case <-ctx.Done(): 647 | return 648 | default: 649 | } 650 | attempts++ 651 | } 652 | 653 | c.writeLk.Lock() 654 | c.conn = conn 655 | c.errLk.Lock() 656 | c.incomingErr = nil 657 | c.errLk.Unlock() 658 | 659 | c.stopPings = c.setupPings() 660 | 661 | c.writeLk.Unlock() 662 | 663 | go c.nextMessage() 664 | }() 665 | 666 | return true 667 | } 668 | 669 | func (c *wsConn) readFrame(ctx context.Context, r io.Reader) { 670 | // debug util - dump all messages to stderr 671 | // r = io.TeeReader(r, os.Stderr) 672 | 673 | // json.NewDecoder(r).Decode would read the whole frame as well, so might as well do it 674 | // with ReadAll which should be much faster 675 | // use a autoResetReader in case the read takes a long time 676 | buf, err := io.ReadAll(c.autoResetReader(r)) // todo buffer pool 677 | if err != nil { 678 | c.readError <- xerrors.Errorf("reading frame into a buffer: %w", err) 679 | return 680 | } 681 | 682 | c.frameExecQueue <- buf 683 | if len(c.frameExecQueue) > 2*cap(c.frameExecQueue)/3 { // warn at 2/3 capacity 684 | 
log.Warnw("frame executor queue is backlogged", "queued", len(c.frameExecQueue), "cap", cap(c.frameExecQueue)) 685 | } 686 | 687 | // got the whole frame, can start reading the next one in background 688 | go c.nextMessage() 689 | } 690 | 691 | func (c *wsConn) frameExecutor(ctx context.Context) { 692 | for { 693 | select { 694 | case <-ctx.Done(): 695 | return 696 | case buf := <-c.frameExecQueue: 697 | var frame frame 698 | if err := json.Unmarshal(buf, &frame); err != nil { 699 | log.Warnw("failed to unmarshal frame", "error", err) 700 | // todo send invalid request response 701 | continue 702 | } 703 | 704 | var err error 705 | frame.ID, err = normalizeID(frame.ID) 706 | if err != nil { 707 | log.Warnw("failed to normalize frame id", "error", err) 708 | // todo send invalid request response 709 | continue 710 | } 711 | 712 | c.handleFrame(ctx, frame) 713 | } 714 | } 715 | } 716 | 717 | var maxQueuedFrames = 256 718 | 719 | func (c *wsConn) handleWsConn(ctx context.Context) { 720 | ctx, cancel := context.WithCancel(ctx) 721 | defer cancel() 722 | 723 | c.incoming = make(chan io.Reader) 724 | c.readError = make(chan error, 1) 725 | c.frameExecQueue = make(chan []byte, maxQueuedFrames) 726 | c.inflight = map[interface{}]clientRequest{} 727 | c.handling = map[interface{}]context.CancelFunc{} 728 | c.chanHandlers = map[uint64]*chanHandler{} 729 | c.pongs = make(chan struct{}, 1) 730 | 731 | c.registerCh = make(chan outChanReg) 732 | defer close(c.exiting) 733 | 734 | // //// 735 | 736 | // on close, make sure to return from all pending calls, and cancel context 737 | // on all calls we handle 738 | defer c.closeInFlight() 739 | defer c.closeChans() 740 | 741 | // setup pings 742 | 743 | c.stopPings = c.setupPings() 744 | defer c.stopPings() 745 | 746 | var timeoutTimer *time.Timer 747 | if c.timeout != 0 { 748 | timeoutTimer = time.NewTimer(c.timeout) 749 | defer timeoutTimer.Stop() 750 | } 751 | 752 | // start frame executor 753 | go c.frameExecutor(ctx) 754 | 755 
| // wait for the first message 756 | go c.nextMessage() 757 | for { 758 | var timeoutCh <-chan time.Time 759 | if timeoutTimer != nil { 760 | if !timeoutTimer.Stop() { 761 | select { 762 | case <-timeoutTimer.C: 763 | default: 764 | } 765 | } 766 | timeoutTimer.Reset(c.timeout) 767 | 768 | timeoutCh = timeoutTimer.C 769 | } 770 | 771 | start := time.Now() 772 | action := "" 773 | 774 | select { 775 | case r, ok := <-c.incoming: 776 | action = "incoming" 777 | c.errLk.Lock() 778 | err := c.incomingErr 779 | c.errLk.Unlock() 780 | 781 | if ok { 782 | go c.readFrame(ctx, r) 783 | break 784 | } 785 | 786 | if err == nil { 787 | return // remote closed 788 | } 789 | 790 | log.Debugw("websocket error", "error", err, "lastAction", action, "time", time.Since(start)) 791 | // only client needs to reconnect 792 | if !c.tryReconnect(ctx) { 793 | return // failed to reconnect 794 | } 795 | case rerr := <-c.readError: 796 | action = "read-error" 797 | 798 | log.Debugw("websocket error", "error", rerr, "lastAction", action, "time", time.Since(start)) 799 | if !c.tryReconnect(ctx) { 800 | return // failed to reconnect 801 | } 802 | case <-ctx.Done(): 803 | log.Debugw("context cancelled", "error", ctx.Err(), "lastAction", action, "time", time.Since(start)) 804 | return 805 | case req := <-c.requests: 806 | action = fmt.Sprintf("send-request(%s,%v)", req.req.Method, req.req.ID) 807 | 808 | c.writeLk.Lock() 809 | if req.req.ID != nil { // non-notification 810 | c.errLk.Lock() 811 | hasErr := c.incomingErr != nil 812 | c.errLk.Unlock() 813 | if hasErr { // No conn?, immediate fail 814 | req.ready <- clientResponse{ 815 | Jsonrpc: "2.0", 816 | ID: req.req.ID, 817 | Error: &JSONRPCError{ 818 | Message: "handler: websocket connection closed", 819 | Code: eTempWSError, 820 | }, 821 | } 822 | c.writeLk.Unlock() 823 | break 824 | } 825 | c.inflightLk.Lock() 826 | c.inflight[req.req.ID] = req 827 | c.inflightLk.Unlock() 828 | } 829 | c.writeLk.Unlock() 830 | serr := c.sendRequest(req.req) 
831 | if serr != nil { 832 | log.Errorf("sendReqest failed (Handle me): %s", serr) 833 | } 834 | if req.req.ID == nil { // notification, return immediately 835 | resp := clientResponse{ 836 | Jsonrpc: "2.0", 837 | } 838 | if serr != nil { 839 | resp.Error = &JSONRPCError{ 840 | Code: eTempWSError, 841 | Message: fmt.Sprintf("sendRequest: %s", serr), 842 | } 843 | } 844 | req.ready <- resp 845 | } 846 | 847 | case <-c.pongs: 848 | action = "pong" 849 | 850 | c.resetReadDeadline() 851 | case <-timeoutCh: 852 | if c.pingInterval == 0 { 853 | // pings not running, this is perfectly normal 854 | continue 855 | } 856 | 857 | c.writeLk.Lock() 858 | if err := c.conn.Close(); err != nil { 859 | log.Warnw("timed-out websocket close error", "error", err) 860 | } 861 | c.writeLk.Unlock() 862 | log.Errorw("Connection timeout", "remote", c.conn.RemoteAddr(), "lastAction", action) 863 | // The server side does not perform the reconnect operation, so need to exit 864 | if c.connFactory == nil { 865 | return 866 | } 867 | // The client performs the reconnect operation, and if it exits it cannot start a handleWsConn again, so it does not need to exit 868 | continue 869 | case <-c.stop: 870 | c.writeLk.Lock() 871 | cmsg := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") 872 | if err := c.conn.WriteMessage(websocket.CloseMessage, cmsg); err != nil { 873 | log.Warn("failed to write close message: ", err) 874 | } 875 | if err := c.conn.Close(); err != nil { 876 | log.Warnw("websocket close error", "error", err) 877 | } 878 | c.writeLk.Unlock() 879 | return 880 | } 881 | 882 | if c.pingInterval > 0 && time.Since(start) > c.pingInterval*2 { 883 | log.Warnw("websocket long time no response", "lastAction", action, "time", time.Since(start)) 884 | } 885 | if debugTrace { 886 | log.Debugw("websocket action", "lastAction", action, "time", time.Since(start)) 887 | } 888 | } 889 | } 890 | 891 | var onReadDeadlineResetInterval = 5 * time.Second 892 | 893 | // autoResetReader wraps 
a reader and resets the read deadline on if needed when doing large reads. 894 | func (c *wsConn) autoResetReader(reader io.Reader) io.Reader { 895 | return &deadlineResetReader{ 896 | r: reader, 897 | reset: c.resetReadDeadline, 898 | 899 | lastReset: time.Now(), 900 | } 901 | } 902 | 903 | type deadlineResetReader struct { 904 | r io.Reader 905 | reset func() 906 | 907 | lastReset time.Time 908 | } 909 | 910 | func (r *deadlineResetReader) Read(p []byte) (n int, err error) { 911 | n, err = r.r.Read(p) 912 | if time.Since(r.lastReset) > onReadDeadlineResetInterval { 913 | log.Warnw("slow/large read, resetting deadline while reading the frame", "since", time.Since(r.lastReset), "n", n, "err", err, "p", len(p)) 914 | 915 | r.reset() 916 | r.lastReset = time.Now() 917 | } 918 | return 919 | } 920 | 921 | func (c *wsConn) resetReadDeadline() { 922 | if c.timeout > 0 { 923 | if err := c.conn.SetReadDeadline(time.Now().Add(c.timeout)); err != nil { 924 | log.Error("setting read deadline", err) 925 | } 926 | } 927 | } 928 | 929 | // Takes an ID as received on the wire, validates it, and translates it to a 930 | // normalized ID appropriate for keying. 931 | func normalizeID(id interface{}) (interface{}, error) { 932 | switch v := id.(type) { 933 | case string, float64, nil: 934 | return v, nil 935 | case int64: // clients sending int64 need to normalize to float64 936 | return float64(v), nil 937 | default: 938 | return nil, xerrors.Errorf("invalid id type: %T", id) 939 | } 940 | } 941 | --------------------------------------------------------------------------------